chore: Fix Typing Issues (#27)
* typing.Any and friends
* message
* chore: Import Message model in llm_client
* fix: 💄 mypy errors
* clean up mypy stuff
* mypy
* format
* mypy
* mypy
* mypy
---------
Co-authored-by: paulpaliychuk <pavlo.paliychuk.ca@gmail.com>
Co-authored-by: prestonrasmussen <prasmuss15@gmail.com>
This commit is contained in:
parent
7152a211ae
commit
9cc9883e66
24 changed files with 134 additions and 587 deletions
2
Makefile
2
Makefile
|
|
@ -22,7 +22,7 @@ format:
|
||||||
# Lint code
|
# Lint code
|
||||||
lint:
|
lint:
|
||||||
$(RUFF) check
|
$(RUFF) check
|
||||||
$(MYPY) . --show-column-numbers --show-error-codes --pretty
|
$(MYPY) ./core --show-column-numbers --show-error-codes --pretty
|
||||||
|
|
||||||
# Run tests
|
# Run tests
|
||||||
test:
|
test:
|
||||||
|
|
|
||||||
|
|
@ -56,7 +56,7 @@ class Graphiti:
|
||||||
else:
|
else:
|
||||||
self.llm_client = OpenAIClient(
|
self.llm_client = OpenAIClient(
|
||||||
LLMConfig(
|
LLMConfig(
|
||||||
api_key=os.getenv('OPENAI_API_KEY'),
|
api_key=os.getenv('OPENAI_API_KEY', default=''),
|
||||||
model='gpt-4o-mini',
|
model='gpt-4o-mini',
|
||||||
base_url='https://api.openai.com/v1',
|
base_url='https://api.openai.com/v1',
|
||||||
)
|
)
|
||||||
|
|
@ -72,28 +72,16 @@ class Graphiti:
|
||||||
self,
|
self,
|
||||||
reference_time: datetime,
|
reference_time: datetime,
|
||||||
last_n: int = EPISODE_WINDOW_LEN,
|
last_n: int = EPISODE_WINDOW_LEN,
|
||||||
sources: list[str] | None = 'messages',
|
|
||||||
) -> list[EpisodicNode]:
|
) -> list[EpisodicNode]:
|
||||||
"""Retrieve the last n episodic nodes from the graph"""
|
"""Retrieve the last n episodic nodes from the graph"""
|
||||||
return await retrieve_episodes(self.driver, reference_time, last_n, sources)
|
return await retrieve_episodes(self.driver, reference_time, last_n)
|
||||||
|
|
||||||
# Invalidate edges that are no longer valid
|
|
||||||
async def invalidate_edges(
|
|
||||||
self,
|
|
||||||
episode: EpisodicNode,
|
|
||||||
new_nodes: list[EntityNode],
|
|
||||||
new_edges: list[EntityEdge],
|
|
||||||
relevant_schema: dict[str, any],
|
|
||||||
previous_episodes: list[EpisodicNode],
|
|
||||||
): ...
|
|
||||||
|
|
||||||
async def add_episode(
|
async def add_episode(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
episode_body: str,
|
episode_body: str,
|
||||||
source_description: str,
|
source_description: str,
|
||||||
reference_time: datetime | None = None,
|
reference_time: datetime,
|
||||||
episode_type: str | None = 'string', # TODO: this field isn't used yet?
|
|
||||||
success_callback: Callable | None = None,
|
success_callback: Callable | None = None,
|
||||||
error_callback: Callable | None = None,
|
error_callback: Callable | None = None,
|
||||||
):
|
):
|
||||||
|
|
@ -104,7 +92,7 @@ class Graphiti:
|
||||||
nodes: list[EntityNode] = []
|
nodes: list[EntityNode] = []
|
||||||
entity_edges: list[EntityEdge] = []
|
entity_edges: list[EntityEdge] = []
|
||||||
episodic_edges: list[EpisodicEdge] = []
|
episodic_edges: list[EpisodicEdge] = []
|
||||||
embedder = self.llm_client.client.embeddings
|
embedder = self.llm_client.get_embedder()
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
|
|
||||||
previous_episodes = await self.retrieve_episodes(reference_time)
|
previous_episodes = await self.retrieve_episodes(reference_time)
|
||||||
|
|
@ -234,7 +222,7 @@ class Graphiti:
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
start = time()
|
start = time()
|
||||||
embedder = self.llm_client.client.embeddings
|
embedder = self.llm_client.get_embedder()
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
|
|
||||||
episodes = [
|
episodes = [
|
||||||
|
|
@ -276,14 +264,22 @@ class Graphiti:
|
||||||
await asyncio.gather(*[node.save(self.driver) for node in nodes])
|
await asyncio.gather(*[node.save(self.driver) for node in nodes])
|
||||||
|
|
||||||
# re-map edge pointers so that they don't point to discard dupe nodes
|
# re-map edge pointers so that they don't point to discard dupe nodes
|
||||||
extracted_edges: list[EntityEdge] = resolve_edge_pointers(extracted_edges, uuid_map)
|
extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
|
||||||
episodic_edges: list[EpisodicEdge] = resolve_edge_pointers(episodic_edges, uuid_map)
|
extracted_edges, uuid_map
|
||||||
|
)
|
||||||
|
episodic_edges_with_resolved_pointers: list[EpisodicEdge] = resolve_edge_pointers(
|
||||||
|
episodic_edges, uuid_map
|
||||||
|
)
|
||||||
|
|
||||||
# save episodic edges to KG
|
# save episodic edges to KG
|
||||||
await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
|
await asyncio.gather(
|
||||||
|
*[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers]
|
||||||
|
)
|
||||||
|
|
||||||
# Dedupe extracted edges
|
# Dedupe extracted edges
|
||||||
edges = await dedupe_edges_bulk(self.driver, self.llm_client, extracted_edges)
|
edges = await dedupe_edges_bulk(
|
||||||
|
self.driver, self.llm_client, extracted_edges_with_resolved_pointers
|
||||||
|
)
|
||||||
logger.info(f'extracted edge length: {len(edges)}')
|
logger.info(f'extracted edge length: {len(edges)}')
|
||||||
|
|
||||||
# invalidate edges
|
# invalidate edges
|
||||||
|
|
@ -302,12 +298,12 @@ class Graphiti:
|
||||||
edges = (
|
edges = (
|
||||||
await hybrid_search(
|
await hybrid_search(
|
||||||
self.driver,
|
self.driver,
|
||||||
self.llm_client.client.embeddings,
|
self.llm_client.get_embedder(),
|
||||||
query,
|
query,
|
||||||
datetime.now(),
|
datetime.now(),
|
||||||
search_config,
|
search_config,
|
||||||
)
|
)
|
||||||
)['edges']
|
).edges
|
||||||
|
|
||||||
facts = [edge.fact for edge in edges]
|
facts = [edge.fact for edge in edges]
|
||||||
|
|
||||||
|
|
@ -315,5 +311,5 @@ class Graphiti:
|
||||||
|
|
||||||
async def _search(self, query: str, timestamp: datetime, config: SearchConfig):
|
async def _search(self, query: str, timestamp: datetime, config: SearchConfig):
|
||||||
return await hybrid_search(
|
return await hybrid_search(
|
||||||
self.driver, self.llm_client.client.embeddings, query, timestamp, config
|
self.driver, self.llm_client.get_embedder(), query, timestamp, config
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
|
import typing
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from ..prompts.models import Message
|
||||||
from .config import LLMConfig
|
from .config import LLMConfig
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -9,5 +11,9 @@ class LLMClient(ABC):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def generate_response(self, messages: list[dict[str, str]]) -> dict[str, any]:
|
def get_embedder(self) -> typing.Any:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,11 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import typing
|
||||||
|
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
|
from openai.types.chat import ChatCompletionMessageParam
|
||||||
|
|
||||||
|
from ..prompts.models import Message
|
||||||
from .client import LLMClient
|
from .client import LLMClient
|
||||||
from .config import LLMConfig
|
from .config import LLMConfig
|
||||||
|
|
||||||
|
|
@ -14,16 +17,26 @@ class OpenAIClient(LLMClient):
|
||||||
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||||
self.model = config.model
|
self.model = config.model
|
||||||
|
|
||||||
async def generate_response(self, messages: list[dict[str, str]]) -> dict[str, any]:
|
def get_embedder(self) -> typing.Any:
|
||||||
|
return self.client.embeddings
|
||||||
|
|
||||||
|
async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
|
||||||
|
openai_messages: list[ChatCompletionMessageParam] = []
|
||||||
|
for m in messages:
|
||||||
|
if m.role == 'user':
|
||||||
|
openai_messages.append({'role': 'user', 'content': m.content})
|
||||||
|
elif m.role == 'system':
|
||||||
|
openai_messages.append({'role': 'system', 'content': m.content})
|
||||||
try:
|
try:
|
||||||
response = await self.client.chat.completions.create(
|
response = await self.client.chat.completions.create(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=openai_messages,
|
||||||
temperature=0.1,
|
temperature=0.1,
|
||||||
max_tokens=3000,
|
max_tokens=3000,
|
||||||
response_format={'type': 'json_object'},
|
response_format={'type': 'json_object'},
|
||||||
)
|
)
|
||||||
return json.loads(response.choices[0].message.content)
|
result = response.choices[0].message.content or ''
|
||||||
|
return json.loads(result)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f'Error in generating LLM response: {e}')
|
logger.error(f'Error in generating LLM response: {e}')
|
||||||
raise
|
raise
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import json
|
import json
|
||||||
from typing import Protocol, TypedDict
|
from typing import Any, Protocol, TypedDict
|
||||||
|
|
||||||
from .models import Message, PromptFunction, PromptVersion
|
from .models import Message, PromptFunction, PromptVersion
|
||||||
|
|
||||||
|
|
@ -7,6 +7,7 @@ from .models import Message, PromptFunction, PromptVersion
|
||||||
class Prompt(Protocol):
|
class Prompt(Protocol):
|
||||||
v1: PromptVersion
|
v1: PromptVersion
|
||||||
v2: PromptVersion
|
v2: PromptVersion
|
||||||
|
edge_list: PromptVersion
|
||||||
|
|
||||||
|
|
||||||
class Versions(TypedDict):
|
class Versions(TypedDict):
|
||||||
|
|
@ -15,7 +16,7 @@ class Versions(TypedDict):
|
||||||
edge_list: PromptFunction
|
edge_list: PromptFunction
|
||||||
|
|
||||||
|
|
||||||
def v1(context: dict[str, any]) -> list[Message]:
|
def v1(context: dict[str, Any]) -> list[Message]:
|
||||||
return [
|
return [
|
||||||
Message(
|
Message(
|
||||||
role='system',
|
role='system',
|
||||||
|
|
@ -55,7 +56,7 @@ def v1(context: dict[str, any]) -> list[Message]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def v2(context: dict[str, any]) -> list[Message]:
|
def v2(context: dict[str, Any]) -> list[Message]:
|
||||||
return [
|
return [
|
||||||
Message(
|
Message(
|
||||||
role='system',
|
role='system',
|
||||||
|
|
@ -97,7 +98,7 @@ def v2(context: dict[str, any]) -> list[Message]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def edge_list(context: dict[str, any]) -> list[Message]:
|
def edge_list(context: dict[str, Any]) -> list[Message]:
|
||||||
return [
|
return [
|
||||||
Message(
|
Message(
|
||||||
role='system',
|
role='system',
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import json
|
import json
|
||||||
from typing import Protocol, TypedDict
|
from typing import Any, Protocol, TypedDict
|
||||||
|
|
||||||
from .models import Message, PromptFunction, PromptVersion
|
from .models import Message, PromptFunction, PromptVersion
|
||||||
|
|
||||||
|
|
@ -16,7 +16,7 @@ class Versions(TypedDict):
|
||||||
node_list: PromptVersion
|
node_list: PromptVersion
|
||||||
|
|
||||||
|
|
||||||
def v1(context: dict[str, any]) -> list[Message]:
|
def v1(context: dict[str, Any]) -> list[Message]:
|
||||||
return [
|
return [
|
||||||
Message(
|
Message(
|
||||||
role='system',
|
role='system',
|
||||||
|
|
@ -56,7 +56,7 @@ def v1(context: dict[str, any]) -> list[Message]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def v2(context: dict[str, any]) -> list[Message]:
|
def v2(context: dict[str, Any]) -> list[Message]:
|
||||||
return [
|
return [
|
||||||
Message(
|
Message(
|
||||||
role='system',
|
role='system',
|
||||||
|
|
@ -96,7 +96,7 @@ def v2(context: dict[str, any]) -> list[Message]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def node_list(context: dict[str, any]) -> list[Message]:
|
def node_list(context: dict[str, Any]) -> list[Message]:
|
||||||
return [
|
return [
|
||||||
Message(
|
Message(
|
||||||
role='system',
|
role='system',
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import json
|
import json
|
||||||
from typing import Protocol, TypedDict
|
from typing import Any, Protocol, TypedDict
|
||||||
|
|
||||||
from .models import Message, PromptFunction, PromptVersion
|
from .models import Message, PromptFunction, PromptVersion
|
||||||
|
|
||||||
|
|
@ -14,7 +14,7 @@ class Versions(TypedDict):
|
||||||
v2: PromptFunction
|
v2: PromptFunction
|
||||||
|
|
||||||
|
|
||||||
def v1(context: dict[str, any]) -> list[Message]:
|
def v1(context: dict[str, Any]) -> list[Message]:
|
||||||
return [
|
return [
|
||||||
Message(
|
Message(
|
||||||
role='system',
|
role='system',
|
||||||
|
|
@ -70,7 +70,7 @@ def v1(context: dict[str, any]) -> list[Message]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def v2(context: dict[str, any]) -> list[Message]:
|
def v2(context: dict[str, Any]) -> list[Message]:
|
||||||
return [
|
return [
|
||||||
Message(
|
Message(
|
||||||
role='system',
|
role='system',
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import json
|
import json
|
||||||
from typing import Protocol, TypedDict
|
from typing import Any, Protocol, TypedDict
|
||||||
|
|
||||||
from .models import Message, PromptFunction, PromptVersion
|
from .models import Message, PromptFunction, PromptVersion
|
||||||
|
|
||||||
|
|
@ -16,7 +16,7 @@ class Versions(TypedDict):
|
||||||
v3: PromptFunction
|
v3: PromptFunction
|
||||||
|
|
||||||
|
|
||||||
def v1(context: dict[str, any]) -> list[Message]:
|
def v1(context: dict[str, Any]) -> list[Message]:
|
||||||
return [
|
return [
|
||||||
Message(
|
Message(
|
||||||
role='system',
|
role='system',
|
||||||
|
|
@ -64,7 +64,7 @@ def v1(context: dict[str, any]) -> list[Message]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def v2(context: dict[str, any]) -> list[Message]:
|
def v2(context: dict[str, Any]) -> list[Message]:
|
||||||
return [
|
return [
|
||||||
Message(
|
Message(
|
||||||
role='system',
|
role='system',
|
||||||
|
|
@ -105,7 +105,7 @@ def v2(context: dict[str, any]) -> list[Message]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def v3(context: dict[str, any]) -> list[Message]:
|
def v3(context: dict[str, Any]) -> list[Message]:
|
||||||
sys_prompt = """You are an AI assistant that extracts entity nodes from conversational text. Your primary task is to identify and extract the speaker and other significant entities mentioned in the conversation."""
|
sys_prompt = """You are an AI assistant that extracts entity nodes from conversational text. Your primary task is to identify and extract the speaker and other significant entities mentioned in the conversation."""
|
||||||
|
|
||||||
user_prompt = f"""
|
user_prompt = f"""
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Protocol, TypedDict
|
from typing import Any, Protocol, TypedDict
|
||||||
|
|
||||||
from .models import Message, PromptFunction, PromptVersion
|
from .models import Message, PromptFunction, PromptVersion
|
||||||
|
|
||||||
|
|
@ -11,7 +11,7 @@ class Versions(TypedDict):
|
||||||
v1: PromptFunction
|
v1: PromptFunction
|
||||||
|
|
||||||
|
|
||||||
def v1(context: dict[str, any]) -> list[Message]:
|
def v1(context: dict[str, Any]) -> list[Message]:
|
||||||
return [
|
return [
|
||||||
Message(
|
Message(
|
||||||
role='system',
|
role='system',
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Protocol, TypedDict
|
from typing import Any, Protocol, TypedDict
|
||||||
|
|
||||||
from .dedupe_edges import (
|
from .dedupe_edges import (
|
||||||
Prompt as DedupeEdgesPrompt,
|
Prompt as DedupeEdgesPrompt,
|
||||||
|
|
@ -68,7 +68,7 @@ class VersionWrapper:
|
||||||
def __init__(self, func: PromptFunction):
|
def __init__(self, func: PromptFunction):
|
||||||
self.func = func
|
self.func = func
|
||||||
|
|
||||||
def __call__(self, context: dict[str, any]) -> list[Message]:
|
def __call__(self, context: dict[str, Any]) -> list[Message]:
|
||||||
return self.func(context)
|
return self.func(context)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -81,7 +81,7 @@ class PromptTypeWrapper:
|
||||||
class PromptLibraryWrapper:
|
class PromptLibraryWrapper:
|
||||||
def __init__(self, library: PromptLibraryImpl):
|
def __init__(self, library: PromptLibraryImpl):
|
||||||
for prompt_type, versions in library.items():
|
for prompt_type, versions in library.items():
|
||||||
setattr(self, prompt_type, PromptTypeWrapper(versions))
|
setattr(self, prompt_type, PromptTypeWrapper(versions)) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
PROMPT_LIBRARY_IMPL: PromptLibraryImpl = {
|
PROMPT_LIBRARY_IMPL: PromptLibraryImpl = {
|
||||||
|
|
@ -91,5 +91,4 @@ PROMPT_LIBRARY_IMPL: PromptLibraryImpl = {
|
||||||
'dedupe_edges': dedupe_edges_versions,
|
'dedupe_edges': dedupe_edges_versions,
|
||||||
'invalidate_edges': invalidate_edges_versions,
|
'invalidate_edges': invalidate_edges_versions,
|
||||||
}
|
}
|
||||||
|
prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL) # type: ignore[assignment]
|
||||||
prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL)
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Callable, Protocol
|
from typing import Any, Callable, Protocol
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
@ -9,7 +9,7 @@ class Message(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class PromptVersion(Protocol):
|
class PromptVersion(Protocol):
|
||||||
def __call__(self, context: dict[str, any]) -> list[Message]: ...
|
def __call__(self, context: dict[str, Any]) -> list[Message]: ...
|
||||||
|
|
||||||
|
|
||||||
PromptFunction = Callable[[dict[str, any]], list[Message]]
|
PromptFunction = Callable[[dict[str, Any]], list[Message]]
|
||||||
|
|
|
||||||
|
|
@ -5,9 +5,9 @@ from time import time
|
||||||
from neo4j import AsyncDriver
|
from neo4j import AsyncDriver
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from core.edges import Edge
|
from core.edges import EntityEdge
|
||||||
from core.llm_client.config import EMBEDDING_DIM
|
from core.llm_client.config import EMBEDDING_DIM
|
||||||
from core.nodes import Node
|
from core.nodes import EntityNode, EpisodicNode
|
||||||
from core.search.search_utils import (
|
from core.search.search_utils import (
|
||||||
edge_fulltext_search,
|
edge_fulltext_search,
|
||||||
edge_similarity_search,
|
edge_similarity_search,
|
||||||
|
|
@ -28,9 +28,15 @@ class SearchConfig(BaseModel):
|
||||||
reranker: str = 'rrf'
|
reranker: str = 'rrf'
|
||||||
|
|
||||||
|
|
||||||
|
class SearchResults(BaseModel):
|
||||||
|
episodes: list[EpisodicNode]
|
||||||
|
nodes: list[EntityNode]
|
||||||
|
edges: list[EntityEdge]
|
||||||
|
|
||||||
|
|
||||||
async def hybrid_search(
|
async def hybrid_search(
|
||||||
driver: AsyncDriver, embedder, query: str, timestamp: datetime, config: SearchConfig
|
driver: AsyncDriver, embedder, query: str, timestamp: datetime, config: SearchConfig
|
||||||
) -> dict[str, [Node | Edge]]:
|
) -> SearchResults:
|
||||||
start = time()
|
start = time()
|
||||||
|
|
||||||
episodes = []
|
episodes = []
|
||||||
|
|
@ -86,11 +92,7 @@ async def hybrid_search(
|
||||||
reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
|
reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
|
||||||
edges.extend(reranked_edges)
|
edges.extend(reranked_edges)
|
||||||
|
|
||||||
context = {
|
context = SearchResults(episodes=episodes, nodes=nodes, edges=edges)
|
||||||
'episodes': episodes,
|
|
||||||
'nodes': nodes,
|
|
||||||
'edges': edges,
|
|
||||||
}
|
|
||||||
|
|
||||||
end = time()
|
end = time()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import typing
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from time import time
|
from time import time
|
||||||
|
|
@ -15,7 +16,7 @@ logger = logging.getLogger(__name__)
|
||||||
RELEVANT_SCHEMA_LIMIT = 3
|
RELEVANT_SCHEMA_LIMIT = 3
|
||||||
|
|
||||||
|
|
||||||
def parse_db_date(neo_date: neo4j_time.Date | None) -> datetime | None:
|
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
|
||||||
return neo_date.to_native() if neo_date else None
|
return neo_date.to_native() if neo_date else None
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -41,7 +42,7 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode])
|
||||||
uuid=record['uuid'],
|
uuid=record['uuid'],
|
||||||
name=record['name'],
|
name=record['name'],
|
||||||
labels=['Entity'],
|
labels=['Entity'],
|
||||||
created_at=datetime.now(),
|
created_at=record['created_at'].to_native(),
|
||||||
summary=record['summary'],
|
summary=record['summary'],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
@ -74,7 +75,7 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):
|
||||||
node_ids=node_ids,
|
node_ids=node_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
context = {}
|
context: dict[str, typing.Any] = {}
|
||||||
|
|
||||||
for record in records:
|
for record in records:
|
||||||
n_uuid = record['source_node_uuid']
|
n_uuid = record['source_node_uuid']
|
||||||
|
|
@ -173,7 +174,7 @@ async def entity_similarity_search(
|
||||||
uuid=record['uuid'],
|
uuid=record['uuid'],
|
||||||
name=record['name'],
|
name=record['name'],
|
||||||
labels=['Entity'],
|
labels=['Entity'],
|
||||||
created_at=datetime.now(),
|
created_at=record['created_at'].to_native(),
|
||||||
summary=record['summary'],
|
summary=record['summary'],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
@ -208,7 +209,7 @@ async def entity_fulltext_search(
|
||||||
uuid=record['uuid'],
|
uuid=record['uuid'],
|
||||||
name=record['name'],
|
name=record['name'],
|
||||||
labels=['Entity'],
|
labels=['Entity'],
|
||||||
created_at=datetime.now(),
|
created_at=record['created_at'].to_native(),
|
||||||
summary=record['summary'],
|
summary=record['summary'],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
@ -277,7 +278,11 @@ async def get_relevant_nodes(
|
||||||
|
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
*[entity_fulltext_search(node.name, driver) for node in nodes],
|
*[entity_fulltext_search(node.name, driver) for node in nodes],
|
||||||
*[entity_similarity_search(node.name_embedding, driver) for node in nodes],
|
*[
|
||||||
|
entity_similarity_search(node.name_embedding, driver)
|
||||||
|
for node in nodes
|
||||||
|
if node.name_embedding is not None
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
for result in results:
|
for result in results:
|
||||||
|
|
@ -303,7 +308,11 @@ async def get_relevant_edges(
|
||||||
relevant_edge_uuids = set()
|
relevant_edge_uuids = set()
|
||||||
|
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
*[edge_similarity_search(edge.fact_embedding, driver) for edge in edges],
|
*[
|
||||||
|
edge_similarity_search(edge.fact_embedding, driver)
|
||||||
|
for edge in edges
|
||||||
|
if edge.fact_embedding is not None
|
||||||
|
],
|
||||||
*[edge_fulltext_search(edge.fact, driver) for edge in edges],
|
*[edge_fulltext_search(edge.fact, driver) for edge in edges],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,15 @@
|
||||||
from .maintenance import (
|
from .maintenance import (
|
||||||
build_episodic_edges,
|
build_episodic_edges,
|
||||||
clear_data,
|
clear_data,
|
||||||
extract_new_edges,
|
extract_edges,
|
||||||
extract_new_nodes,
|
extract_nodes,
|
||||||
retrieve_episodes,
|
retrieve_episodes,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'extract_new_edges',
|
'extract_edges',
|
||||||
'build_episodic_edges',
|
'build_episodic_edges',
|
||||||
'extract_new_nodes',
|
'extract_nodes',
|
||||||
'clear_data',
|
'clear_data',
|
||||||
'retrieve_episodes',
|
'retrieve_episodes',
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import typing
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from neo4j import AsyncDriver
|
from neo4j import AsyncDriver
|
||||||
|
|
@ -121,8 +122,8 @@ async def dedupe_edges_bulk(
|
||||||
|
|
||||||
|
|
||||||
def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str, str]]:
|
def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str, str]]:
|
||||||
uuid_map = {}
|
uuid_map: dict[str, str] = {}
|
||||||
name_map = {}
|
name_map: dict[str, EntityNode] = {}
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
if node.name in name_map:
|
if node.name in name_map:
|
||||||
uuid_map[node.uuid] = name_map[node.name].uuid
|
uuid_map[node.uuid] = name_map[node.name].uuid
|
||||||
|
|
@ -182,7 +183,10 @@ def compress_uuid_map(uuid_map: dict[str, str]) -> dict[str, str]:
|
||||||
return compressed_map
|
return compressed_map
|
||||||
|
|
||||||
|
|
||||||
def resolve_edge_pointers(edges: list[Edge], uuid_map: dict[str, str]):
|
E = typing.TypeVar('E', bound=Edge)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
|
||||||
for edge in edges:
|
for edge in edges:
|
||||||
source_uuid = edge.source_node_uuid
|
source_uuid = edge.source_node_uuid
|
||||||
target_uuid = edge.target_node_uuid
|
target_uuid = edge.target_node_uuid
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,15 @@
|
||||||
from .edge_operations import build_episodic_edges, extract_new_edges
|
from .edge_operations import build_episodic_edges, extract_edges
|
||||||
from .graph_data_operations import (
|
from .graph_data_operations import (
|
||||||
clear_data,
|
clear_data,
|
||||||
retrieve_episodes,
|
retrieve_episodes,
|
||||||
)
|
)
|
||||||
from .node_operations import extract_new_nodes
|
from .node_operations import extract_nodes
|
||||||
from .temporal_operations import invalidate_edges
|
from .temporal_operations import invalidate_edges
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'extract_new_edges',
|
'extract_edges',
|
||||||
'build_episodic_edges',
|
'build_episodic_edges',
|
||||||
'extract_new_nodes',
|
'extract_nodes',
|
||||||
'clear_data',
|
'clear_data',
|
||||||
'retrieve_episodes',
|
'retrieve_episodes',
|
||||||
'invalidate_edges',
|
'invalidate_edges',
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from time import time
|
from time import time
|
||||||
|
|
@ -8,7 +7,6 @@ from core.edges import EntityEdge, EpisodicEdge
|
||||||
from core.llm_client import LLMClient
|
from core.llm_client import LLMClient
|
||||||
from core.nodes import EntityNode, EpisodicNode
|
from core.nodes import EntityNode, EpisodicNode
|
||||||
from core.prompts import prompt_library
|
from core.prompts import prompt_library
|
||||||
from core.utils.maintenance.temporal_operations import NodeEdgeNodeTriplet
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -31,103 +29,6 @@ def build_episodic_edges(
|
||||||
return edges
|
return edges
|
||||||
|
|
||||||
|
|
||||||
async def extract_new_edges(
|
|
||||||
llm_client: LLMClient,
|
|
||||||
episode: EpisodicNode,
|
|
||||||
new_nodes: list[EntityNode],
|
|
||||||
relevant_schema: dict[str, any],
|
|
||||||
previous_episodes: list[EpisodicNode],
|
|
||||||
) -> tuple[list[EntityEdge], list[EntityNode]]:
|
|
||||||
# Prepare context for LLM
|
|
||||||
context = {
|
|
||||||
'episode_content': episode.content,
|
|
||||||
'episode_timestamp': (episode.valid_at.isoformat() if episode.valid_at else None),
|
|
||||||
'relevant_schema': json.dumps(relevant_schema, indent=2),
|
|
||||||
'new_nodes': [{'name': node.name, 'summary': node.summary} for node in new_nodes],
|
|
||||||
'previous_episodes': [
|
|
||||||
{
|
|
||||||
'content': ep.content,
|
|
||||||
'timestamp': ep.valid_at.isoformat() if ep.valid_at else None,
|
|
||||||
}
|
|
||||||
for ep in previous_episodes
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
llm_response = await llm_client.generate_response(prompt_library.extract_edges.v1(context))
|
|
||||||
new_edges_data = llm_response.get('new_edges', [])
|
|
||||||
logger.info(f'Extracted new edges: {new_edges_data}')
|
|
||||||
|
|
||||||
# Convert the extracted data into EntityEdge objects
|
|
||||||
new_edges = []
|
|
||||||
for edge_data in new_edges_data:
|
|
||||||
source_node = next(
|
|
||||||
(node for node in new_nodes if node.name == edge_data['source_node']),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
target_node = next(
|
|
||||||
(node for node in new_nodes if node.name == edge_data['target_node']),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# If source or target is not in new_nodes, check if it's an existing node
|
|
||||||
if source_node is None and edge_data['source_node'] in relevant_schema['nodes']:
|
|
||||||
existing_node_data = relevant_schema['nodes'][edge_data['source_node']]
|
|
||||||
source_node = EntityNode(
|
|
||||||
uuid=existing_node_data['uuid'],
|
|
||||||
name=edge_data['source_node'],
|
|
||||||
labels=[existing_node_data['label']],
|
|
||||||
summary='',
|
|
||||||
created_at=datetime.now(),
|
|
||||||
)
|
|
||||||
if target_node is None and edge_data['target_node'] in relevant_schema['nodes']:
|
|
||||||
existing_node_data = relevant_schema['nodes'][edge_data['target_node']]
|
|
||||||
target_node = EntityNode(
|
|
||||||
uuid=existing_node_data['uuid'],
|
|
||||||
name=edge_data['target_node'],
|
|
||||||
labels=[existing_node_data['label']],
|
|
||||||
summary='',
|
|
||||||
created_at=datetime.now(),
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
source_node
|
|
||||||
and target_node
|
|
||||||
and not (
|
|
||||||
source_node.name.startswith('Message') or target_node.name.startswith('Message')
|
|
||||||
)
|
|
||||||
):
|
|
||||||
valid_at = (
|
|
||||||
datetime.fromisoformat(edge_data['valid_at'])
|
|
||||||
if edge_data['valid_at']
|
|
||||||
else episode.valid_at or datetime.now()
|
|
||||||
)
|
|
||||||
invalid_at = (
|
|
||||||
datetime.fromisoformat(edge_data['invalid_at']) if edge_data['invalid_at'] else None
|
|
||||||
)
|
|
||||||
|
|
||||||
new_edge = EntityEdge(
|
|
||||||
source_node=source_node,
|
|
||||||
target_node=target_node,
|
|
||||||
name=edge_data['relation_type'],
|
|
||||||
fact=edge_data['fact'],
|
|
||||||
episodes=[episode.uuid],
|
|
||||||
created_at=datetime.now(),
|
|
||||||
valid_at=valid_at,
|
|
||||||
invalid_at=invalid_at,
|
|
||||||
)
|
|
||||||
new_edges.append(new_edge)
|
|
||||||
logger.info(
|
|
||||||
f'Created new edge: {new_edge.name} from {source_node.name} (UUID: {source_node.uuid}) to {target_node.name} (UUID: {target_node.uuid})'
|
|
||||||
)
|
|
||||||
|
|
||||||
affected_nodes = set()
|
|
||||||
|
|
||||||
for edge in new_edges:
|
|
||||||
affected_nodes.add(edge.source_node)
|
|
||||||
affected_nodes.add(edge.target_node)
|
|
||||||
return new_edges, list(affected_nodes)
|
|
||||||
|
|
||||||
|
|
||||||
async def extract_edges(
|
async def extract_edges(
|
||||||
llm_client: LLMClient,
|
llm_client: LLMClient,
|
||||||
episode: EpisodicNode,
|
episode: EpisodicNode,
|
||||||
|
|
@ -186,45 +87,6 @@ def create_edge_identifier(
|
||||||
return f'{source_node.name}-{edge.name}-{target_node.name}'
|
return f'{source_node.name}-{edge.name}-{target_node.name}'
|
||||||
|
|
||||||
|
|
||||||
async def dedupe_extracted_edges_v2(
|
|
||||||
llm_client: LLMClient,
|
|
||||||
extracted_edges: list[NodeEdgeNodeTriplet],
|
|
||||||
existing_edges: list[NodeEdgeNodeTriplet],
|
|
||||||
) -> list[NodeEdgeNodeTriplet]:
|
|
||||||
# Create edge map
|
|
||||||
edge_map = {}
|
|
||||||
for n1, edge, n2 in existing_edges:
|
|
||||||
edge_map[create_edge_identifier(n1, edge, n2)] = edge
|
|
||||||
for n1, edge, n2 in extracted_edges:
|
|
||||||
if create_edge_identifier(n1, edge, n2) in edge_map:
|
|
||||||
continue
|
|
||||||
edge_map[create_edge_identifier(n1, edge, n2)] = edge
|
|
||||||
|
|
||||||
# Prepare context for LLM
|
|
||||||
context = {
|
|
||||||
'extracted_edges': [
|
|
||||||
{'triplet': create_edge_identifier(n1, edge, n2), 'fact': edge.fact}
|
|
||||||
for n1, edge, n2 in extracted_edges
|
|
||||||
],
|
|
||||||
'existing_edges': [
|
|
||||||
{'triplet': create_edge_identifier(n1, edge, n2), 'fact': edge.fact}
|
|
||||||
for n1, edge, n2 in extracted_edges
|
|
||||||
],
|
|
||||||
}
|
|
||||||
logger.info(prompt_library.dedupe_edges.v2(context))
|
|
||||||
llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v2(context))
|
|
||||||
new_edges_data = llm_response.get('new_edges', [])
|
|
||||||
logger.info(f'Extracted new edges: {new_edges_data}')
|
|
||||||
|
|
||||||
# Get full edge data
|
|
||||||
edges = []
|
|
||||||
for edge_data in new_edges_data:
|
|
||||||
edge = edge_map[edge_data['triplet']]
|
|
||||||
edges.append(edge)
|
|
||||||
|
|
||||||
return edges
|
|
||||||
|
|
||||||
|
|
||||||
async def dedupe_extracted_edges(
|
async def dedupe_extracted_edges(
|
||||||
llm_client: LLMClient,
|
llm_client: LLMClient,
|
||||||
extracted_edges: list[EntityEdge],
|
extracted_edges: list[EntityEdge],
|
||||||
|
|
|
||||||
|
|
@ -52,9 +52,7 @@ async def build_indices_and_constraints(driver: AsyncDriver):
|
||||||
}}
|
}}
|
||||||
""",
|
""",
|
||||||
]
|
]
|
||||||
index_queries: list[LiteralString] = (
|
index_queries: list[LiteralString] = range_indices + fulltext_indices + vector_indices
|
||||||
range_indices + fulltext_indices + vector_indices
|
|
||||||
)
|
|
||||||
|
|
||||||
await asyncio.gather(*[driver.execute_query(query) for query in index_queries])
|
await asyncio.gather(*[driver.execute_query(query) for query in index_queries])
|
||||||
|
|
||||||
|
|
@ -72,7 +70,6 @@ async def retrieve_episodes(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
reference_time: datetime,
|
reference_time: datetime,
|
||||||
last_n: int = EPISODE_WINDOW_LEN,
|
last_n: int = EPISODE_WINDOW_LEN,
|
||||||
sources: list[str] | None = 'messages',
|
|
||||||
) -> list[EpisodicNode]:
|
) -> list[EpisodicNode]:
|
||||||
"""Retrieve the last n episodic nodes from the graph"""
|
"""Retrieve the last n episodic nodes from the graph"""
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
|
|
@ -97,14 +94,7 @@ async def retrieve_episodes(
|
||||||
created_at=datetime.fromtimestamp(
|
created_at=datetime.fromtimestamp(
|
||||||
record['created_at'].to_native().timestamp(), timezone.utc
|
record['created_at'].to_native().timestamp(), timezone.utc
|
||||||
),
|
),
|
||||||
valid_at=(
|
valid_at=(record['valid_at'].to_native()),
|
||||||
datetime.fromtimestamp(
|
|
||||||
record['valid_at'].to_native().timestamp(),
|
|
||||||
timezone.utc,
|
|
||||||
)
|
|
||||||
if record['valid_at'] is not None
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
uuid=record['uuid'],
|
uuid=record['uuid'],
|
||||||
source=record['source'],
|
source=record['source'],
|
||||||
name=record['name'],
|
name=record['name'],
|
||||||
|
|
|
||||||
|
|
@ -9,53 +9,6 @@ from core.prompts import prompt_library
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def extract_new_nodes(
|
|
||||||
llm_client: LLMClient,
|
|
||||||
episode: EpisodicNode,
|
|
||||||
relevant_schema: dict[str, any],
|
|
||||||
previous_episodes: list[EpisodicNode],
|
|
||||||
) -> list[EntityNode]:
|
|
||||||
# Prepare context for LLM
|
|
||||||
existing_nodes = [
|
|
||||||
{'name': node_name, 'label': node_info['label'], 'uuid': node_info['uuid']}
|
|
||||||
for node_name, node_info in relevant_schema['nodes'].items()
|
|
||||||
]
|
|
||||||
|
|
||||||
context = {
|
|
||||||
'episode_content': episode.content,
|
|
||||||
'episode_timestamp': (episode.valid_at.isoformat() if episode.valid_at else None),
|
|
||||||
'existing_nodes': existing_nodes,
|
|
||||||
'previous_episodes': [
|
|
||||||
{
|
|
||||||
'content': ep.content,
|
|
||||||
'timestamp': ep.valid_at.isoformat() if ep.valid_at else None,
|
|
||||||
}
|
|
||||||
for ep in previous_episodes
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
llm_response = await llm_client.generate_response(prompt_library.extract_nodes.v1(context))
|
|
||||||
new_nodes_data = llm_response.get('new_nodes', [])
|
|
||||||
logger.info(f'Extracted new nodes: {new_nodes_data}')
|
|
||||||
# Convert the extracted data into EntityNode objects
|
|
||||||
new_nodes = []
|
|
||||||
for node_data in new_nodes_data:
|
|
||||||
# Check if the node already exists
|
|
||||||
if not any(existing_node['name'] == node_data['name'] for existing_node in existing_nodes):
|
|
||||||
new_node = EntityNode(
|
|
||||||
name=node_data['name'],
|
|
||||||
labels=node_data['labels'],
|
|
||||||
summary=node_data['summary'],
|
|
||||||
created_at=datetime.now(),
|
|
||||||
)
|
|
||||||
new_nodes.append(new_node)
|
|
||||||
logger.info(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
|
|
||||||
else:
|
|
||||||
logger.info(f"Node {node_data['name']} already exists, skipping creation.")
|
|
||||||
|
|
||||||
return new_nodes
|
|
||||||
|
|
||||||
|
|
||||||
async def extract_nodes(
|
async def extract_nodes(
|
||||||
llm_client: LLMClient,
|
llm_client: LLMClient,
|
||||||
episode: EpisodicNode,
|
episode: EpisodicNode,
|
||||||
|
|
@ -100,16 +53,16 @@ async def dedupe_extracted_nodes(
|
||||||
llm_client: LLMClient,
|
llm_client: LLMClient,
|
||||||
extracted_nodes: list[EntityNode],
|
extracted_nodes: list[EntityNode],
|
||||||
existing_nodes: list[EntityNode],
|
existing_nodes: list[EntityNode],
|
||||||
) -> tuple[list[EntityNode], dict[str, str]]:
|
) -> tuple[list[EntityNode], dict[str, str], list[EntityNode]]:
|
||||||
start = time()
|
start = time()
|
||||||
|
|
||||||
# build existing node map
|
# build existing node map
|
||||||
node_map = {}
|
node_map: dict[str, EntityNode] = {}
|
||||||
for node in existing_nodes:
|
for node in existing_nodes:
|
||||||
node_map[node.name] = node
|
node_map[node.name] = node
|
||||||
|
|
||||||
# Temp hack
|
# Temp hack
|
||||||
new_nodes_map = {}
|
new_nodes_map: dict[str, EntityNode] = {}
|
||||||
for node in extracted_nodes:
|
for node in extracted_nodes:
|
||||||
new_nodes_map[node.name] = node
|
new_nodes_map[node.name] = node
|
||||||
|
|
||||||
|
|
@ -134,14 +87,14 @@ async def dedupe_extracted_nodes(
|
||||||
end = time()
|
end = time()
|
||||||
logger.info(f'Deduplicated nodes: {duplicate_data} in {(end - start) * 1000} ms')
|
logger.info(f'Deduplicated nodes: {duplicate_data} in {(end - start) * 1000} ms')
|
||||||
|
|
||||||
uuid_map = {}
|
uuid_map: dict[str, str] = {}
|
||||||
for duplicate in duplicate_data:
|
for duplicate in duplicate_data:
|
||||||
uuid = new_nodes_map[duplicate['name']].uuid
|
uuid = new_nodes_map[duplicate['name']].uuid
|
||||||
uuid_value = node_map[duplicate['duplicate_of']].uuid
|
uuid_value = node_map[duplicate['duplicate_of']].uuid
|
||||||
uuid_map[uuid] = uuid_value
|
uuid_map[uuid] = uuid_value
|
||||||
|
|
||||||
nodes = []
|
nodes: list[EntityNode] = []
|
||||||
brand_new_nodes = []
|
brand_new_nodes: list[EntityNode] = []
|
||||||
for node in extracted_nodes:
|
for node in extracted_nodes:
|
||||||
if node.uuid in uuid_map:
|
if node.uuid in uuid_map:
|
||||||
existing_uuid = uuid_map[node.uuid]
|
existing_uuid = uuid_map[node.uuid]
|
||||||
|
|
@ -149,7 +102,9 @@ async def dedupe_extracted_nodes(
|
||||||
# can you revisit the node dedup function and make it somewhat cleaner and add more comments/tests please?
|
# can you revisit the node dedup function and make it somewhat cleaner and add more comments/tests please?
|
||||||
# find an existing node by the uuid from the nodes_map (each key is name, so we need to iterate by uuid value)
|
# find an existing node by the uuid from the nodes_map (each key is name, so we need to iterate by uuid value)
|
||||||
existing_node = next((v for k, v in node_map.items() if v.uuid == existing_uuid), None)
|
existing_node = next((v for k, v in node_map.items() if v.uuid == existing_uuid), None)
|
||||||
nodes.append(existing_node)
|
if existing_node:
|
||||||
|
nodes.append(existing_node)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
brand_new_nodes.append(node)
|
brand_new_nodes.append(node)
|
||||||
nodes.append(node)
|
nodes.append(node)
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,8 @@ def extract_node_edge_node_triplet(
|
||||||
) -> NodeEdgeNodeTriplet:
|
) -> NodeEdgeNodeTriplet:
|
||||||
source_node = next((node for node in nodes if node.uuid == edge.source_node_uuid), None)
|
source_node = next((node for node in nodes if node.uuid == edge.source_node_uuid), None)
|
||||||
target_node = next((node for node in nodes if node.uuid == edge.target_node_uuid), None)
|
target_node = next((node for node in nodes if node.uuid == edge.target_node_uuid), None)
|
||||||
|
if not source_node or not target_node:
|
||||||
|
raise ValueError(f'Source or target node not found for edge {edge.uuid}')
|
||||||
return (source_node, edge, target_node)
|
return (source_node, edge, target_node)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -31,11 +33,8 @@ def prepare_edges_for_invalidation(
|
||||||
new_edges: list[EntityEdge],
|
new_edges: list[EntityEdge],
|
||||||
nodes: list[EntityNode],
|
nodes: list[EntityNode],
|
||||||
) -> tuple[list[NodeEdgeNodeTriplet], list[NodeEdgeNodeTriplet]]:
|
) -> tuple[list[NodeEdgeNodeTriplet], list[NodeEdgeNodeTriplet]]:
|
||||||
existing_edges_pending_invalidation = [] # TODO: this is not yet used?
|
existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet] = []
|
||||||
new_edges_with_nodes = [] # TODO: this is not yet used?
|
new_edges_with_nodes: list[NodeEdgeNodeTriplet] = []
|
||||||
|
|
||||||
existing_edges_pending_invalidation = []
|
|
||||||
new_edges_with_nodes = []
|
|
||||||
|
|
||||||
for edge_list, result_list in [
|
for edge_list, result_list in [
|
||||||
(existing_edges, existing_edges_pending_invalidation),
|
(existing_edges, existing_edges_pending_invalidation),
|
||||||
|
|
|
||||||
|
|
@ -1,292 +0,0 @@
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
from datetime import datetime
|
|
||||||
from time import time
|
|
||||||
|
|
||||||
from neo4j import AsyncDriver
|
|
||||||
from neo4j import time as neo4j_time
|
|
||||||
|
|
||||||
from core.edges import EntityEdge
|
|
||||||
from core.nodes import EntityNode
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
RELEVANT_SCHEMA_LIMIT = 3
|
|
||||||
|
|
||||||
|
|
||||||
async def bfs(node_ids: list[str], driver: AsyncDriver):
|
|
||||||
records, _, _ = await driver.execute_query(
|
|
||||||
"""
|
|
||||||
MATCH (n WHERE n.uuid in $node_ids)-[r]->(m)
|
|
||||||
RETURN
|
|
||||||
n.uuid AS source_node_uuid,
|
|
||||||
n.name AS source_name,
|
|
||||||
n.summary AS source_summary,
|
|
||||||
m.uuid AS target_node_uuid,
|
|
||||||
m.name AS target_name,
|
|
||||||
m.summary AS target_summary,
|
|
||||||
r.uuid AS uuid,
|
|
||||||
r.created_at AS created_at,
|
|
||||||
r.name AS name,
|
|
||||||
r.fact AS fact,
|
|
||||||
r.fact_embedding AS fact_embedding,
|
|
||||||
r.episodes AS episodes,
|
|
||||||
r.expired_at AS expired_at,
|
|
||||||
r.valid_at AS valid_at,
|
|
||||||
r.invalid_at AS invalid_at
|
|
||||||
|
|
||||||
""",
|
|
||||||
node_ids=node_ids,
|
|
||||||
)
|
|
||||||
|
|
||||||
context = {}
|
|
||||||
|
|
||||||
for record in records:
|
|
||||||
n_uuid = record['source_node_uuid']
|
|
||||||
if n_uuid in context:
|
|
||||||
context[n_uuid]['facts'].append(record['fact'])
|
|
||||||
else:
|
|
||||||
context[n_uuid] = {
|
|
||||||
'name': record['source_name'],
|
|
||||||
'summary': record['source_summary'],
|
|
||||||
'facts': [record['fact']],
|
|
||||||
}
|
|
||||||
|
|
||||||
m_uuid = record['target_node_uuid']
|
|
||||||
if m_uuid not in context:
|
|
||||||
context[m_uuid] = {
|
|
||||||
'name': record['target_name'],
|
|
||||||
'summary': record['target_summary'],
|
|
||||||
'facts': [],
|
|
||||||
}
|
|
||||||
logger.info(f'bfs search returned context: {context}')
|
|
||||||
return context
|
|
||||||
|
|
||||||
|
|
||||||
async def edge_similarity_search(
|
|
||||||
search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
|
|
||||||
) -> list[EntityEdge]:
|
|
||||||
# vector similarity search over embedded facts
|
|
||||||
records, _, _ = await driver.execute_query(
|
|
||||||
"""
|
|
||||||
CALL db.index.vector.queryRelationships("fact_embedding", 5, $search_vector)
|
|
||||||
YIELD relationship AS r, score
|
|
||||||
MATCH (n)-[r:RELATES_TO]->(m)
|
|
||||||
RETURN
|
|
||||||
r.uuid AS uuid,
|
|
||||||
n.uuid AS source_node_uuid,
|
|
||||||
m.uuid AS target_node_uuid,
|
|
||||||
r.created_at AS created_at,
|
|
||||||
r.name AS name,
|
|
||||||
r.fact AS fact,
|
|
||||||
r.fact_embedding AS fact_embedding,
|
|
||||||
r.episodes AS episodes,
|
|
||||||
r.expired_at AS expired_at,
|
|
||||||
r.valid_at AS valid_at,
|
|
||||||
r.invalid_at AS invalid_at
|
|
||||||
ORDER BY score DESC LIMIT $limit
|
|
||||||
""",
|
|
||||||
search_vector=search_vector,
|
|
||||||
limit=limit,
|
|
||||||
)
|
|
||||||
|
|
||||||
edges: list[EntityEdge] = []
|
|
||||||
|
|
||||||
for record in records:
|
|
||||||
edge = EntityEdge(
|
|
||||||
uuid=record['uuid'],
|
|
||||||
source_node_uuid=record['source_node_uuid'],
|
|
||||||
target_node_uuid=record['target_node_uuid'],
|
|
||||||
fact=record['fact'],
|
|
||||||
name=record['name'],
|
|
||||||
episodes=record['episodes'],
|
|
||||||
fact_embedding=record['fact_embedding'],
|
|
||||||
created_at=safely_parse_db_date(record['created_at']),
|
|
||||||
expired_at=safely_parse_db_date(record['expired_at']),
|
|
||||||
valid_at=safely_parse_db_date(record['valid_at']),
|
|
||||||
invalid_At=safely_parse_db_date(record['invalid_at']),
|
|
||||||
)
|
|
||||||
|
|
||||||
edges.append(edge)
|
|
||||||
|
|
||||||
return edges
|
|
||||||
|
|
||||||
|
|
||||||
async def entity_similarity_search(
|
|
||||||
search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
|
|
||||||
) -> list[EntityNode]:
|
|
||||||
# vector similarity search over entity names
|
|
||||||
records, _, _ = await driver.execute_query(
|
|
||||||
"""
|
|
||||||
CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
|
|
||||||
YIELD node AS n, score
|
|
||||||
RETURN
|
|
||||||
n.uuid As uuid,
|
|
||||||
n.name AS name,
|
|
||||||
n.created_at AS created_at,
|
|
||||||
n.summary AS summary
|
|
||||||
ORDER BY score DESC
|
|
||||||
""",
|
|
||||||
search_vector=search_vector,
|
|
||||||
limit=limit,
|
|
||||||
)
|
|
||||||
nodes: list[EntityNode] = []
|
|
||||||
|
|
||||||
for record in records:
|
|
||||||
nodes.append(
|
|
||||||
EntityNode(
|
|
||||||
uuid=record['uuid'],
|
|
||||||
name=record['name'],
|
|
||||||
labels=[],
|
|
||||||
created_at=safely_parse_db_date(record['created_at']),
|
|
||||||
summary=record['summary'],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return nodes
|
|
||||||
|
|
||||||
|
|
||||||
async def entity_fulltext_search(
|
|
||||||
query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
|
|
||||||
) -> list[EntityNode]:
|
|
||||||
# BM25 search to get top nodes
|
|
||||||
fuzzy_query = query + '~'
|
|
||||||
records, _, _ = await driver.execute_query(
|
|
||||||
"""
|
|
||||||
CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
|
|
||||||
RETURN
|
|
||||||
node.uuid As uuid,
|
|
||||||
node.name AS name,
|
|
||||||
node.created_at AS created_at,
|
|
||||||
node.summary AS summary
|
|
||||||
ORDER BY score DESC
|
|
||||||
LIMIT $limit
|
|
||||||
""",
|
|
||||||
query=fuzzy_query,
|
|
||||||
limit=limit,
|
|
||||||
)
|
|
||||||
nodes: list[EntityNode] = []
|
|
||||||
|
|
||||||
for record in records:
|
|
||||||
nodes.append(
|
|
||||||
EntityNode(
|
|
||||||
uuid=record['uuid'],
|
|
||||||
name=record['name'],
|
|
||||||
labels=[],
|
|
||||||
created_at=safely_parse_db_date(record['created_at']),
|
|
||||||
summary=record['summary'],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return nodes
|
|
||||||
|
|
||||||
|
|
||||||
async def edge_fulltext_search(
|
|
||||||
query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
|
|
||||||
) -> list[EntityEdge]:
|
|
||||||
# fulltext search over facts
|
|
||||||
fuzzy_query = query + '~'
|
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
|
||||||
"""
|
|
||||||
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
|
|
||||||
YIELD relationship AS r, score
|
|
||||||
MATCH (n:Entity)-[r]->(m:Entity)
|
|
||||||
RETURN
|
|
||||||
r.uuid AS uuid,
|
|
||||||
n.uuid AS source_node_uuid,
|
|
||||||
m.uuid AS target_node_uuid,
|
|
||||||
r.created_at AS created_at,
|
|
||||||
r.name AS name,
|
|
||||||
r.fact AS fact,
|
|
||||||
r.fact_embedding AS fact_embedding,
|
|
||||||
r.episodes AS episodes,
|
|
||||||
r.expired_at AS expired_at,
|
|
||||||
r.valid_at AS valid_at,
|
|
||||||
r.invalid_at AS invalid_at
|
|
||||||
ORDER BY score DESC LIMIT $limit
|
|
||||||
""",
|
|
||||||
query=fuzzy_query,
|
|
||||||
limit=limit,
|
|
||||||
)
|
|
||||||
|
|
||||||
edges: list[EntityEdge] = []
|
|
||||||
|
|
||||||
for record in records:
|
|
||||||
edge = EntityEdge(
|
|
||||||
uuid=record['uuid'],
|
|
||||||
source_node_uuid=record['source_node_uuid'],
|
|
||||||
target_node_uuid=record['target_node_uuid'],
|
|
||||||
fact=record['fact'],
|
|
||||||
name=record['name'],
|
|
||||||
episodes=record['episodes'],
|
|
||||||
fact_embedding=record['fact_embedding'],
|
|
||||||
created_at=safely_parse_db_date(record['created_at']),
|
|
||||||
expired_at=safely_parse_db_date(record['expired_at']),
|
|
||||||
valid_at=safely_parse_db_date(record['valid_at']),
|
|
||||||
invalid_At=safely_parse_db_date(record['invalid_at']),
|
|
||||||
)
|
|
||||||
|
|
||||||
edges.append(edge)
|
|
||||||
|
|
||||||
return edges
|
|
||||||
|
|
||||||
|
|
||||||
def safely_parse_db_date(date_str: neo4j_time.Date) -> datetime:
|
|
||||||
if date_str:
|
|
||||||
return datetime.fromisoformat(date_str.iso_format())
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def get_relevant_nodes(
|
|
||||||
nodes: list[EntityNode],
|
|
||||||
driver: AsyncDriver,
|
|
||||||
) -> list[EntityNode]:
|
|
||||||
start = time()
|
|
||||||
relevant_nodes: list[EntityNode] = []
|
|
||||||
relevant_node_uuids = set()
|
|
||||||
|
|
||||||
results = await asyncio.gather(
|
|
||||||
*[entity_fulltext_search(node.name, driver) for node in nodes],
|
|
||||||
*[entity_similarity_search(node.name_embedding, driver) for node in nodes],
|
|
||||||
)
|
|
||||||
|
|
||||||
for result in results:
|
|
||||||
for node in result:
|
|
||||||
if node.uuid in relevant_node_uuids:
|
|
||||||
continue
|
|
||||||
|
|
||||||
relevant_node_uuids.add(node.uuid)
|
|
||||||
relevant_nodes.append(node)
|
|
||||||
|
|
||||||
end = time()
|
|
||||||
logger.info(f'Found relevant nodes: {relevant_node_uuids} in {(end - start) * 1000} ms')
|
|
||||||
|
|
||||||
return relevant_nodes
|
|
||||||
|
|
||||||
|
|
||||||
async def get_relevant_edges(
|
|
||||||
edges: list[EntityEdge],
|
|
||||||
driver: AsyncDriver,
|
|
||||||
) -> list[EntityEdge]:
|
|
||||||
start = time()
|
|
||||||
relevant_edges: list[EntityEdge] = []
|
|
||||||
relevant_edge_uuids = set()
|
|
||||||
|
|
||||||
results = await asyncio.gather(
|
|
||||||
*[edge_similarity_search(edge.fact_embedding, driver) for edge in edges],
|
|
||||||
*[edge_fulltext_search(edge.fact, driver) for edge in edges],
|
|
||||||
)
|
|
||||||
|
|
||||||
for result in results:
|
|
||||||
for edge in result:
|
|
||||||
if edge.uuid in relevant_edge_uuids:
|
|
||||||
continue
|
|
||||||
|
|
||||||
relevant_edge_uuids.add(edge.uuid)
|
|
||||||
relevant_edges.append(edge)
|
|
||||||
|
|
||||||
end = time()
|
|
||||||
logger.info(f'Found relevant edges: {relevant_edge_uuids} in {(end - start) * 1000} ms')
|
|
||||||
|
|
||||||
return relevant_edges
|
|
||||||
|
|
@ -14,8 +14,8 @@ def build_episodic_edges(
|
||||||
for node in entity_nodes:
|
for node in entity_nodes:
|
||||||
edges.append(
|
edges.append(
|
||||||
EpisodicEdge(
|
EpisodicEdge(
|
||||||
source_node_uuid=episode,
|
source_node_uuid=episode.uuid,
|
||||||
target_node_uuid=node,
|
target_node_uuid=node.uuid,
|
||||||
created_at=episode.created_at,
|
created_at=episode.created_at,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,10 @@
|
||||||
name = "graphiti"
|
name = "graphiti"
|
||||||
version = "0.0.1"
|
version = "0.0.1"
|
||||||
description = "Graph building library"
|
description = "Graph building library"
|
||||||
authors = ["Paul Paliychuk <paul@getzep.com>", "Preston Rasmussen <preston@getzep.com>"]
|
authors = [
|
||||||
|
"Paul Paliychuk <paul@getzep.com>",
|
||||||
|
"Preston Rasmussen <preston@getzep.com>",
|
||||||
|
]
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
|
|
@ -56,4 +59,4 @@ ignore = ["E501"]
|
||||||
[tool.ruff.format]
|
[tool.ruff.format]
|
||||||
quote-style = "single"
|
quote-style = "single"
|
||||||
indent-style = "tab"
|
indent-style = "tab"
|
||||||
docstring-code-format = true
|
docstring-code-format = true
|
||||||
|
|
|
||||||
|
|
@ -103,11 +103,11 @@ async def test_graph_integration():
|
||||||
bob_node = EntityNode(name='Bob', labels=[], created_at=now, summary='Bob summary')
|
bob_node = EntityNode(name='Bob', labels=[], created_at=now, summary='Bob summary')
|
||||||
|
|
||||||
episodic_edge_1 = EpisodicEdge(
|
episodic_edge_1 = EpisodicEdge(
|
||||||
source_node_uuid=episode, target_node_uuid=alice_node, created_at=now
|
source_node_uuid=episode.uuid, target_node_uuid=alice_node.uuid, created_at=now
|
||||||
)
|
)
|
||||||
|
|
||||||
episodic_edge_2 = EpisodicEdge(
|
episodic_edge_2 = EpisodicEdge(
|
||||||
source_node_uuid=episode, target_node_uuid=bob_node, created_at=now
|
source_node_uuid=episode.uuid, target_node_uuid=bob_node.uuid, created_at=now
|
||||||
)
|
)
|
||||||
|
|
||||||
entity_edge = EntityEdge(
|
entity_edge = EntityEdge(
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue