Update Maintenance LLM Queries and Partial Schema Retrieval (#6)

* search updates

* add search_utils

* updates

* graph maintenance updates

* revert extract_new_nodes

* revert extract_new_edges

* parallelize node searching

* add edge fulltext search

* search optimizations
This commit is contained in:
Preston Rasmussen 2024-08-18 13:22:31 -04:00 committed by GitHub
parent ad552b527e
commit 4db3906049
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 953 additions and 119 deletions

View file

@ -5,15 +5,16 @@ from neo4j import AsyncDriver
from uuid import uuid4 from uuid import uuid4
import logging import logging
from core.llm_client.config import EMBEDDING_DIM
from core.nodes import Node from core.nodes import Node
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Edge(BaseModel, ABC): class Edge(BaseModel, ABC):
uuid: str = Field(default_factory=lambda: str(uuid4())) uuid: str = Field(default_factory=lambda: uuid4().hex)
source_node: Node source_node_uuid: str
target_node: Node target_node_uuid: str
created_at: datetime created_at: datetime
@abstractmethod @abstractmethod
@ -30,11 +31,6 @@ class Edge(BaseModel, ABC):
class EpisodicEdge(Edge): class EpisodicEdge(Edge):
async def save(self, driver: AsyncDriver): async def save(self, driver: AsyncDriver):
if self.uuid is None:
uuid = uuid4()
logger.info(f"Created uuid: {uuid} for episodic edge")
self.uuid = str(uuid)
result = await driver.execute_query( result = await driver.execute_query(
""" """
MATCH (episode:Episodic {uuid: $episode_uuid}) MATCH (episode:Episodic {uuid: $episode_uuid})
@ -42,8 +38,8 @@ class EpisodicEdge(Edge):
MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node) MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
SET r = {uuid: $uuid, created_at: $created_at} SET r = {uuid: $uuid, created_at: $created_at}
RETURN r.uuid AS uuid""", RETURN r.uuid AS uuid""",
episode_uuid=self.source_node.uuid, episode_uuid=self.source_node_uuid,
entity_uuid=self.target_node.uuid, entity_uuid=self.target_node_uuid,
uuid=self.uuid, uuid=self.uuid,
created_at=self.created_at, created_at=self.created_at,
) )
@ -79,10 +75,10 @@ class EntityEdge(Edge):
default=None, description="datetime of when the fact stopped being true" default=None, description="datetime of when the fact stopped being true"
) )
def generate_embedding(self, embedder, model="text-embedding-3-large"): async def generate_embedding(self, embedder, model="text-embedding-3-small"):
text = self.fact.replace("\n", " ") text = self.fact.replace("\n", " ")
embedding = embedder.create(input=[text], model=model).data[0].embedding embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
self.fact_embedding = embedding self.fact_embedding = embedding[:EMBEDDING_DIM]
return embedding return embedding
@ -96,8 +92,8 @@ class EntityEdge(Edge):
episodes: $episodes, created_at: $created_at, expired_at: $expired_at, episodes: $episodes, created_at: $created_at, expired_at: $expired_at,
valid_at: $valid_at, invalid_at: $invalid_at} valid_at: $valid_at, invalid_at: $invalid_at}
RETURN r.uuid AS uuid""", RETURN r.uuid AS uuid""",
source_uuid=self.source_node.uuid, source_uuid=self.source_node_uuid,
target_uuid=self.target_node.uuid, target_uuid=self.target_node_uuid,
uuid=self.uuid, uuid=self.uuid,
name=self.name, name=self.name,
fact=self.fact, fact=self.fact,

View file

@ -1,12 +1,14 @@
import asyncio import asyncio
from datetime import datetime from datetime import datetime
import logging import logging
from typing import Callable, LiteralString, Tuple from typing import Callable, LiteralString
from neo4j import AsyncGraphDatabase from neo4j import AsyncGraphDatabase
from dotenv import load_dotenv from dotenv import load_dotenv
import os import os
from core.llm_client.config import EMBEDDING_DIM
from core.nodes import EntityNode, EpisodicNode, Node from core.nodes import EntityNode, EpisodicNode, Node
from core.edges import EntityEdge, Edge from core.edges import EntityEdge, Edge, EpisodicEdge
from core.utils import ( from core.utils import (
build_episodic_edges, build_episodic_edges,
retrieve_relevant_schema, retrieve_relevant_schema,
@ -16,6 +18,15 @@ from core.utils import (
retrieve_episodes, retrieve_episodes,
) )
from core.llm_client import LLMClient, OpenAIClient, LLMConfig from core.llm_client import LLMClient, OpenAIClient, LLMConfig
from core.utils.maintenance.edge_operations import extract_edges, dedupe_extracted_edges
from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
from core.utils.search.search_utils import (
edge_similarity_search,
entity_fulltext_search,
bfs,
get_relevant_nodes,
get_relevant_edges,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -34,7 +45,7 @@ class Graphiti:
self.llm_client = OpenAIClient( self.llm_client = OpenAIClient(
LLMConfig( LLMConfig(
api_key=os.getenv("OPENAI_API_KEY"), api_key=os.getenv("OPENAI_API_KEY"),
model="gpt-4o", model="gpt-4o-mini",
base_url="https://api.openai.com/v1", base_url="https://api.openai.com/v1",
) )
) )
@ -75,8 +86,12 @@ class Graphiti:
): ):
"""Process an episode and update the graph""" """Process an episode and update the graph"""
try: try:
nodes: list[Node] = [] nodes: list[EntityNode] = []
edges: list[Edge] = [] entity_edges: list[EntityEdge] = []
episodic_edges: list[EpisodicEdge] = []
embedder = self.llm_client.client.embeddings
now = datetime.now()
previous_episodes = await self.retrieve_episodes(last_n=3) previous_episodes = await self.retrieve_episodes(last_n=3)
episode = EpisodicNode( episode = EpisodicNode(
name=name, name=name,
@ -84,38 +99,65 @@ class Graphiti:
source="messages", source="messages",
content=episode_body, content=episode_body,
source_description=source_description, source_description=source_description,
created_at=datetime.now(), created_at=now,
valid_at=reference_time, valid_at=reference_time,
) )
# await episode.save(self.driver) # relevant_schema = await self.retrieve_relevant_schema(episode.content)
relevant_schema = await self.retrieve_relevant_schema(episode.content)
new_nodes = await extract_new_nodes( extracted_nodes = await extract_nodes(
self.llm_client, episode, relevant_schema, previous_episodes self.llm_client, episode, previous_episodes
)
# Calculate Embeddings
await asyncio.gather(
*[node.generate_name_embedding(embedder) for node in extracted_nodes]
)
existing_nodes = await get_relevant_nodes(extracted_nodes, self.driver)
new_nodes = await dedupe_extracted_nodes(
self.llm_client, extracted_nodes, existing_nodes
) )
nodes.extend(new_nodes) nodes.extend(new_nodes)
new_edges, affected_nodes = await extract_new_edges(
self.llm_client, episode, new_nodes, relevant_schema, previous_episodes extracted_edges = await extract_edges(
self.llm_client, episode, new_nodes, previous_episodes
) )
edges.extend(new_edges)
episodic_edges = build_episodic_edges( await asyncio.gather(
# There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them *[edge.generate_embedding(embedder) for edge in extracted_edges]
list(set(nodes + affected_nodes)), )
episode,
datetime.now(), existing_edges = await get_relevant_edges(extracted_edges, self.driver)
new_edges = await dedupe_extracted_edges(
self.llm_client, extracted_edges, existing_edges
)
entity_edges.extend(new_edges)
episodic_edges.extend(
build_episodic_edges(
# There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them
nodes,
episode,
now,
)
) )
# Important to append the episode to the nodes at the end so that self referencing episodic edges are not built # Important to append the episode to the nodes at the end so that self referencing episodic edges are not built
nodes.append(episode)
logger.info(f"Built episodic edges: {episodic_edges}") logger.info(f"Built episodic edges: {episodic_edges}")
edges.extend(episodic_edges)
# invalidated_edges = await self.invalidate_edges( # invalidated_edges = await self.invalidate_edges(
# episode, new_nodes, new_edges, relevant_schema, previous_episodes # episode, new_nodes, new_edges, relevant_schema, previous_episodes
# ) # )
# edges.extend(invalidated_edges) # edges.extend(invalidated_edges)
# Future optimization would be using batch operations to save nodes and edges # Future optimization would be using batch operations to save nodes and edges
await episode.save(self.driver)
await asyncio.gather(*[node.save(self.driver) for node in nodes]) await asyncio.gather(*[node.save(self.driver) for node in nodes])
await asyncio.gather(*[edge.save(self.driver) for edge in edges]) await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])
# for node in nodes: # for node in nodes:
# if isinstance(node, EntityNode): # if isinstance(node, EntityNode):
# await node.update_summary(self.driver) # await node.update_summary(self.driver)
@ -129,6 +171,10 @@ class Graphiti:
async def build_indices(self): async def build_indices(self):
index_queries: list[LiteralString] = [ index_queries: list[LiteralString] = [
"CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)",
"CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)",
"CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.uuid)",
"CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[r:MENTIONS]-() ON (r.uuid)",
"CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)", "CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)",
"CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)", "CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)",
"CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)", "CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)",
@ -143,13 +189,19 @@ class Graphiti:
for query in index_queries: for query in index_queries:
await self.driver.execute_query(query) await self.driver.execute_query(query)
# Add the entity indices # Add the semantic indices
await self.driver.execute_query( await self.driver.execute_query(
""" """
CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary] CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]
""" """
) )
await self.driver.execute_query(
"""
CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON EACH [r.name, r.fact]
"""
)
await self.driver.execute_query( await self.driver.execute_query(
""" """
CREATE VECTOR INDEX fact_embedding IF NOT EXISTS CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
@ -161,29 +213,40 @@ class Graphiti:
""" """
) )
async def search( await self.driver.execute_query(
self, query: str, config """
) -> (list)[tuple[EntityNode, list[EntityEdge]]]: CREATE VECTOR INDEX name_embedding IF NOT EXISTS
(vec_nodes, vec_edges) = similarity_search(query, embedder) FOR (n:Entity) ON (n.name_embedding)
(text_nodes, text_edges) = fulltext_search(query) OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
"""
)
nodes = vec_nodes.extend(text_nodes) async def search(self, query: str) -> list[tuple[EntityNode, list[EntityEdge]]]:
edges = vec_edges.extend(text_edges) text = query.replace("\n", " ")
search_vector = (
(
await self.llm_client.client.embeddings.create(
input=[text], model="text-embedding-3-small"
)
)
.data[0]
.embedding[:EMBEDDING_DIM]
)
results = bfs(nodes, edges, k=1) edges = await edge_similarity_search(search_vector, self.driver)
nodes = await entity_fulltext_search(query, self.driver)
episode_ids = ["Mode of episode ids"] node_ids = [node.uuid for node in nodes]
episodes = get_episodes(episode_ids[:episode_count]) for edge in edges:
node_ids.append(edge.source_node_uuid)
node_ids.append(edge.target_node_uuid)
return [(node, edges)], episodes node_ids = list(dict.fromkeys(node_ids))
# Invalidate edges that are no longer valid context = await bfs(node_ids, self.driver)
async def invalidate_edges(
self, return context
episode: EpisodicNode,
new_nodes: list[EntityNode],
new_edges: list[EntityEdge],
relevant_schema: dict[str, any],
previous_episodes: list[EpisodicNode],
): ...

View file

@ -1,3 +1,6 @@
EMBEDDING_DIM = 1024
class LLMConfig: class LLMConfig:
""" """
Configuration class for the Language Learning Model (LLM). Configuration class for the Language Learning Model (LLM).
@ -10,7 +13,7 @@ class LLMConfig:
def __init__( def __init__(
self, self,
api_key: str, api_key: str,
model: str = "gpt-4o", model: str = "gpt-4o-mini",
base_url: str = "https://api.openai.com", base_url: str = "https://api.openai.com",
): ):
""" """
@ -21,7 +24,7 @@ class LLMConfig:
This is required for making authorized requests. This is required for making authorized requests.
model (str, optional): The specific LLM model to use for generating responses. model (str, optional): The specific LLM model to use for generating responses.
Defaults to "gpt-4o", which appears to be a custom model name. Defaults to "gpt-4o-mini", which appears to be a custom model name.
Common values might include "gpt-3.5-turbo" or "gpt-4". Common values might include "gpt-3.5-turbo" or "gpt-4".
base_url (str, optional): The base URL of the LLM API service. base_url (str, optional): The base URL of the LLM API service.

View file

@ -8,11 +8,13 @@ from pydantic import BaseModel, Field
from neo4j import AsyncDriver from neo4j import AsyncDriver
import logging import logging
from core.llm_client.config import EMBEDDING_DIM
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Node(BaseModel, ABC): class Node(BaseModel, ABC):
uuid: str = Field(default_factory=lambda: str(uuid4())) uuid: str = Field(default_factory=lambda: uuid4().hex)
name: str name: str
labels: list[str] = Field(default_factory=list) labels: list[str] = Field(default_factory=list)
created_at: datetime created_at: datetime
@ -66,21 +68,32 @@ class EpisodicNode(Node):
class EntityNode(Node): class EntityNode(Node):
name_embedding: list[float] | None = Field(
default=None, description="embedding of the name"
)
summary: str = Field(description="regional summary of surrounding edges") summary: str = Field(description="regional summary of surrounding edges")
async def update_summary(self, driver: AsyncDriver): ... async def update_summary(self, driver: AsyncDriver): ...
async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ... async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ...
async def generate_name_embedding(self, embedder, model="text-embedding-3-small"):
text = self.name.replace("\n", " ")
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
self.name_embedding = embedding[:EMBEDDING_DIM]
return embedding
async def save(self, driver: AsyncDriver): async def save(self, driver: AsyncDriver):
result = await driver.execute_query( result = await driver.execute_query(
""" """
MERGE (n:Entity {uuid: $uuid}) MERGE (n:Entity {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, summary: $summary, created_at: $created_at} SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, summary: $summary, created_at: $created_at}
RETURN n.uuid AS uuid""", RETURN n.uuid AS uuid""",
uuid=self.uuid, uuid=self.uuid,
name=self.name, name=self.name,
summary=self.summary, summary=self.summary,
name_embedding=self.name_embedding,
created_at=self.created_at, created_at=self.created_at,
) )

View file

@ -0,0 +1,56 @@
import json
from typing import TypedDict, Protocol
from .models import Message, PromptVersion, PromptFunction
class Prompt(Protocol):
v1: PromptVersion
class Versions(TypedDict):
v1: PromptFunction
def v1(context: dict[str, any]) -> list[Message]:
return [
Message(
role="system",
content="You are a helpful assistant that de-duplicates relationship from edge lists.",
),
Message(
role="user",
content=f"""
Given the following context, deduplicate edges from a list of new edges given a list of existing edges:
Existing Edges:
{json.dumps(context['existing_edges'], indent=2)}
New Edges:
{json.dumps(context['extracted_edges'], indent=2)}
Task:
1. start with the list of edges from New Edges
2. If any edge in New Edges is a duplicate of an edge in Existing Edges, replace the new edge with the existing
edge in the list
3. Respond with the resulting list of edges
Guidelines:
1. Use both the name and fact of edges to determine if they are duplicates,
duplicate edges may have different names
Respond with a JSON object in the following format:
{{
"new_edges": [
{{
"name": "Unique identifier for the edge",
"fact": "one sentence description of the fact"
}}
]
}}
""",
),
]
versions: Versions = {"v1": v1}

View file

@ -0,0 +1,56 @@
import json
from typing import TypedDict, Protocol
from .models import Message, PromptVersion, PromptFunction
class Prompt(Protocol):
v1: PromptVersion
class Versions(TypedDict):
v1: PromptFunction
def v1(context: dict[str, any]) -> list[Message]:
return [
Message(
role="system",
content="You are a helpful assistant that de-duplicates nodes from node lists.",
),
Message(
role="user",
content=f"""
Given the following context, deduplicate nodes from a list of new nodes given a list of existing nodes:
Existing Nodes:
{json.dumps(context['existing_nodes'], indent=2)}
New Nodes:
{json.dumps(context['extracted_nodes'], indent=2)}
Task:
1. start with the list of nodes from New Nodes
2. If any node in New Nodes is a duplicate of a node in Existing Nodes, replace the new node with the existing
node in the list
3. Respond with the resulting list of nodes
Guidelines:
1. Use both the name and summary of nodes to determine if they are duplicates,
duplicate nodes may have different names
Respond with a JSON object in the following format:
{{
"new_nodes": [
{{
"name": "Unique identifier for the node",
"summary": "Brief summary of the node's role or significance"
}}
]
}}
""",
),
]
versions: Versions = {"v1": v1}

View file

@ -6,10 +6,12 @@ from .models import Message, PromptVersion, PromptFunction
class Prompt(Protocol): class Prompt(Protocol):
v1: PromptVersion v1: PromptVersion
v2: PromptVersion
class Versions(TypedDict): class Versions(TypedDict):
v1: PromptFunction v1: PromptFunction
v2: PromptFunction
def v1(context: dict[str, any]) -> list[Message]: def v1(context: dict[str, any]) -> list[Message]:
@ -68,6 +70,108 @@ def v1(context: dict[str, any]) -> list[Message]:
] ]
versions: Versions = { def v1(context: dict[str, any]) -> list[Message]:
"v1": v1, return [
} Message(
role="system",
content="You are a helpful assistant that extracts graph edges from provided context.",
),
Message(
role="user",
content=f"""
Given the following context, extract new semantic edges (relationships) that need to be added to the knowledge graph:
Current Graph Structure:
{context['relevant_schema']}
New Nodes:
{json.dumps(context['new_nodes'], indent=2)}
New Episode:
Content: {context['episode_content']}
Timestamp: {context['episode_timestamp']}
Previous Episodes:
{json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)}
Extract new semantic edges based on the content of the current episode, considering the existing graph structure, new nodes, and context from previous episodes.
Guidelines:
1. Create edges only between semantic nodes (not episodic nodes like messages).
2. Each edge should represent a clear relationship between two semantic nodes.
3. The relation_type should be a concise, all-caps description of the relationship (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
4. Provide a more detailed fact describing the relationship.
5. If a relationship seems to update an existing one, create a new edge with the updated information.
6. Consider temporal aspects of relationships when relevant.
7. Do not create edges involving episodic nodes (like Message 1 or Message 2).
8. Use existing nodes from the current graph structure when appropriate.
Respond with a JSON object in the following format:
{{
"new_edges": [
{{
"relation_type": "RELATION_TYPE_IN_CAPS",
"source_node": "Name of the source semantic node",
"target_node": "Name of the target semantic node",
"fact": "Detailed description of the relationship",
"valid_at": "YYYY-MM-DDTHH:MM:SSZ or null if not explicitly mentioned",
"invalid_at": "YYYY-MM-DDTHH:MM:SSZ or null if ongoing or not explicitly mentioned"
}}
]
}}
If no new edges need to be added, return an empty list for "new_edges".
""",
),
]
def v2(context: dict[str, any]) -> list[Message]:
return [
Message(
role="system",
content="You are a helpful assistant that extracts graph edges from provided context.",
),
Message(
role="user",
content=f"""
Given the following context, extract new edges (relationships) that need to be added to the knowledge graph:
Nodes:
{json.dumps(context['nodes'], indent=2)}
New Episode:
Content: {context['episode_content']}
Previous Episodes:
{json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)}
Extract new entity edges based on the content of the current episode, the given nodes, and context from previous episodes.
Guidelines:
1. Create edges only between the provided nodes.
2. Each edge should represent a clear relationship between two nodes.
3. The relation_type should be a concise, all-caps description of the relationship (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
4. Provide a more detailed fact describing the relationship.
5. Consider temporal aspects of relationships when relevant.
Respond with a JSON object in the following format:
{{
"edges": [
{{
"relation_type": "RELATION_TYPE_IN_CAPS",
"source_node_uuid": "uuid of the source entity node",
"target_node_uuid": "uuid of the target entity node",
"fact": "Detailed description of the relationship",
"valid_at": "YYYY-MM-DDTHH:MM:SSZ or null if not explicitly mentioned",
"invalid_at": "YYYY-MM-DDTHH:MM:SSZ or null if ongoing or not explicitly mentioned"
}}
]
}}
If no new edges need to be added, return an empty list for "new_edges".
""",
),
]
versions: Versions = {"v1": v1, "v2": v2}

View file

@ -6,10 +6,12 @@ from .models import Message, PromptVersion, PromptFunction
class Prompt(Protocol): class Prompt(Protocol):
v1: PromptVersion v1: PromptVersion
v2: PromptVersion
class Versions(TypedDict): class Versions(TypedDict):
v1: PromptFunction v1: PromptFunction
v2: PromptFunction
def v1(context: dict[str, any]) -> list[Message]: def v1(context: dict[str, any]) -> list[Message]:
@ -60,6 +62,45 @@ def v1(context: dict[str, any]) -> list[Message]:
] ]
versions: Versions = { def v2(context: dict[str, any]) -> list[Message]:
"v1": v1, return [
} Message(
role="system",
content="You are a helpful assistant that extracts graph nodes from provided context.",
),
Message(
role="user",
content=f"""
Given the following context, extract new entity nodes that need to be added to the knowledge graph:
Previous Episodes:
{json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)}
New Episode:
Content: {context["episode_content"]}
Extract new entity nodes based on the content of the current episode, while considering the context from previous episodes.
Guidelines:
1. Focus on entities, concepts, or actors that are central to the current episode.
2. Avoid creating nodes for relationships or actions (these will be handled as edges later).
3. Provide a brief but informative summary for each node.
Respond with a JSON object in the following format:
{{
"new_nodes": [
{{
"name": "Unique identifier for the node",
"labels": ["Entity", "OptionalAdditionalLabel"],
"summary": "Brief summary of the node's role or significance"
}}
]
}}
If no new nodes need to be added, return an empty list for "new_nodes".
""",
),
]
versions: Versions = {"v1": v1, "v2": v2}

View file

@ -8,21 +8,37 @@ from .extract_nodes import (
versions as extract_nodes_versions, versions as extract_nodes_versions,
) )
from .dedupe_nodes import (
Prompt as DedupeNodesPrompt,
Versions as DedupeNodesVersions,
versions as dedupe_nodes_versions,
)
from .extract_edges import ( from .extract_edges import (
Prompt as ExtractEdgesPrompt, Prompt as ExtractEdgesPrompt,
Versions as ExtractEdgesVersions, Versions as ExtractEdgesVersions,
versions as extract_edges_versions, versions as extract_edges_versions,
) )
from .dedupe_edges import (
Prompt as DedupeEdgesPrompt,
Versions as DedupeEdgesVersions,
versions as dedupe_edges_versions,
)
class PromptLibrary(Protocol): class PromptLibrary(Protocol):
extract_nodes: ExtractNodesPrompt extract_nodes: ExtractNodesPrompt
dedupe_nodes: DedupeNodesPrompt
extract_edges: ExtractEdgesPrompt extract_edges: ExtractEdgesPrompt
dedupe_edges: DedupeEdgesPrompt
class PromptLibraryImpl(TypedDict): class PromptLibraryImpl(TypedDict):
extract_nodes: ExtractNodesVersions extract_nodes: ExtractNodesVersions
dedupe_nodes: DedupeNodesVersions
extract_edges: ExtractEdgesVersions extract_edges: ExtractEdgesVersions
dedupe_edges: DedupeEdgesVersions
class VersionWrapper: class VersionWrapper:
@ -47,7 +63,9 @@ class PromptLibraryWrapper:
PROMPT_LIBRARY_IMPL: PromptLibraryImpl = { PROMPT_LIBRARY_IMPL: PromptLibraryImpl = {
"extract_nodes": extract_nodes_versions, "extract_nodes": extract_nodes_versions,
"dedupe_nodes": dedupe_nodes_versions,
"extract_edges": extract_edges_versions, "extract_edges": extract_edges_versions,
"dedupe_edges": dedupe_edges_versions,
} }
prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL) prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL)

View file

@ -1,45 +0,0 @@
from typing import Tuple
from core.edges import EpisodicEdge, EntityEdge, Edge
from core.nodes import EntityNode, EpisodicNode, Node
async def bfs(
nodes: list[Node], edges: list[Edge], k: int
) -> Tuple[list[EntityNode], list[EntityEdge]]: ...
# Breadth first search over nodes and edges with desired depth
async def similarity_search(
query: str, embedder
) -> Tuple[list[EntityNode], list[EntityEdge]]: ...
# vector similarity search over embedded facts
async def fulltext_search(
query: str,
) -> Tuple[list[EntityNode], list[EntityEdge]]: ...
# fulltext search over names and summary
def build_episodic_edges(
entity_nodes: list[EntityNode], episode: EpisodicNode
) -> list[EpisodicEdge]:
edges: list[EpisodicEdge] = []
for node in entity_nodes:
edges.append(
EpisodicEdge(
source_node=episode,
target_node=node,
created_at=episode.created_at,
)
)
return edges

View file

@ -13,15 +13,17 @@ logger = logging.getLogger(__name__)
def build_episodic_edges( def build_episodic_edges(
semantic_nodes: List[EntityNode], entity_nodes: List[EntityNode],
episode: EpisodicNode, episode: EpisodicNode,
transaction_from: datetime, transaction_from: datetime,
) -> List[EpisodicEdge]: ) -> List[EpisodicEdge]:
edges: List[EpisodicEdge] = [] edges: List[EpisodicEdge] = []
for node in semantic_nodes: for node in entity_nodes:
edge = EpisodicEdge( edge = EpisodicEdge(
source_node=episode, target_node=node, created_at=transaction_from source_node_uuid=episode.uuid,
target_node_uuid=node.uuid,
created_at=transaction_from,
) )
edges.append(edge) edges.append(edge)
@ -132,3 +134,94 @@ async def extract_new_edges(
affected_nodes.add(edge.source_node) affected_nodes.add(edge.source_node)
affected_nodes.add(edge.target_node) affected_nodes.add(edge.target_node)
return new_edges, list(affected_nodes) return new_edges, list(affected_nodes)
async def extract_edges(
llm_client: LLMClient,
episode: EpisodicNode,
nodes: list[EntityNode],
previous_episodes: list[EpisodicNode],
) -> list[EntityEdge]:
# Prepare context for LLM
context = {
"episode_content": episode.content,
"episode_timestamp": (
episode.valid_at.isoformat() if episode.valid_at else None
),
"nodes": [
{"uuid": node.uuid, "name": node.name, "summary": node.summary}
for node in nodes
],
"previous_episodes": [
{
"content": ep.content,
"timestamp": ep.valid_at.isoformat() if ep.valid_at else None,
}
for ep in previous_episodes
],
}
llm_response = await llm_client.generate_response(
prompt_library.extract_edges.v2(context)
)
edges_data = llm_response.get("edges", [])
logger.info(f"Extracted new edges: {edges_data}")
# Convert the extracted data into EntityEdge objects
edges = []
for edge_data in edges_data:
edge = EntityEdge(
source_node_uuid=edge_data["source_node_uuid"],
target_node_uuid=edge_data["target_node_uuid"],
name=edge_data["relation_type"],
fact=edge_data["fact"],
episodes=[episode.uuid],
created_at=datetime.now(),
valid_at=edge_data["valid_at"],
invalid_at=edge_data["invalid_at"],
)
edges.append(edge)
logger.info(
f"Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})"
)
return edges
async def dedupe_extracted_edges(
llm_client: LLMClient,
extracted_edges: list[EntityEdge],
existing_edges: list[EntityEdge],
) -> list[EntityEdge]:
# Create edge map
edge_map = {}
for edge in existing_edges:
edge_map[edge.name] = edge
for edge in extracted_edges:
if edge.name in edge_map.keys():
continue
edge_map[edge.name] = edge
# Prepare context for LLM
context = {
"extracted_edges": [
{"name": edge.name, "fact": edge.fact} for edge in extracted_edges
],
"existing_edges": [
{"name": edge.name, "fact": edge.fact} for edge in extracted_edges
],
}
llm_response = await llm_client.generate_response(
prompt_library.dedupe_edges.v1(context)
)
new_edges_data = llm_response.get("new_edges", [])
logger.info(f"Extracted new edges: {new_edges_data}")
# Get full edge data
edges = []
for edge_data in new_edges_data:
edge = edge_map[edge_data["name"]]
edges.append(edge)
return edges

View file

@ -61,3 +61,87 @@ async def extract_new_nodes(
logger.info(f"Node {node_data['name']} already exists, skipping creation.") logger.info(f"Node {node_data['name']} already exists, skipping creation.")
return new_nodes return new_nodes
async def extract_nodes(
llm_client: LLMClient,
episode: EpisodicNode,
previous_episodes: list[EpisodicNode],
) -> list[EntityNode]:
# Prepare context for LLM
context = {
"episode_content": episode.content,
"episode_timestamp": (
episode.valid_at.isoformat() if episode.valid_at else None
),
"previous_episodes": [
{
"content": ep.content,
"timestamp": ep.valid_at.isoformat() if ep.valid_at else None,
}
for ep in previous_episodes
],
}
llm_response = await llm_client.generate_response(
prompt_library.extract_nodes.v2(context)
)
new_nodes_data = llm_response.get("new_nodes", [])
logger.info(f"Extracted new nodes: {new_nodes_data}")
# Convert the extracted data into EntityNode objects
new_nodes = []
for node_data in new_nodes_data:
new_node = EntityNode(
name=node_data["name"],
labels=node_data["labels"],
summary=node_data["summary"],
created_at=datetime.now(),
)
new_nodes.append(new_node)
logger.info(f"Created new node: {new_node.name} (UUID: {new_node.uuid})")
return new_nodes
async def dedupe_extracted_nodes(
llm_client: LLMClient,
extracted_nodes: list[EntityNode],
existing_nodes: list[EntityNode],
) -> list[EntityNode]:
# build node map
node_map = {}
for node in existing_nodes:
node_map[node.name] = node
for node in extracted_nodes:
if node.name in node_map.keys():
continue
node_map[node.name] = node
# Prepare context for LLM
existing_nodes_context = [
{"name": node.name, "summary": node.summary} for node in existing_nodes
]
extracted_nodes_context = [
{"name": node.name, "summary": node.summary} for node in extracted_nodes
]
context = {
"existing_nodes": existing_nodes_context,
"extracted_nodes": extracted_nodes_context,
}
llm_response = await llm_client.generate_response(
prompt_library.dedupe_nodes.v1(context)
)
new_nodes_data = llm_response.get("new_nodes", [])
logger.info(f"Deduplicated nodes: {new_nodes_data}")
# Get full node data
nodes = []
for node_data in new_nodes_data:
node = node_map[node_data["name"]]
nodes.append(node)
return nodes

View file

View file

@ -0,0 +1,274 @@
import asyncio
import logging
from datetime import datetime
from neo4j import AsyncDriver
from core.edges import EntityEdge
from core.nodes import EntityNode
logger = logging.getLogger(__name__)
async def bfs(node_ids: list[str], driver: AsyncDriver):
    """Expand one hop out from the given nodes and collect a fact context.

    Args:
        node_ids: uuids of the starting nodes.
        driver: Neo4j async driver.

    Returns:
        A dict keyed by node uuid. Each entry holds the node's name, summary,
        and the facts carried by its outgoing relationships; pure target
        nodes appear with an empty fact list.
    """
    records, _, _ = await driver.execute_query(
        """
        MATCH (n WHERE n.uuid in $node_ids)-[r]->(m)
        RETURN
            n.uuid AS source_node_uuid,
            n.name AS source_name,
            n.summary AS source_summary,
            m.uuid AS target_node_uuid,
            m.name AS target_name,
            m.summary AS target_summary,
            r.uuid AS uuid,
            r.created_at AS created_at,
            r.name AS name,
            r.fact AS fact,
            r.fact_embedding AS fact_embedding,
            r.episodes AS episodes,
            r.expired_at AS expired_at,
            r.valid_at AS valid_at,
            r.invalid_at AS invalid_at
        """,
        node_ids=node_ids,
    )

    context: dict[str, dict] = {}
    for record in records:
        # Facts accumulate on the source side of each relationship.
        source = context.setdefault(
            record["source_node_uuid"],
            {
                "name": record["source_name"],
                "summary": record["source_summary"],
                "facts": [],
            },
        )
        source["facts"].append(record["fact"])

        # Ensure the target node is present even if it has no outgoing edges.
        context.setdefault(
            record["target_node_uuid"],
            {
                "name": record["target_name"],
                "summary": record["target_summary"],
                "facts": [],
            },
        )

    logger.info(f"bfs search returned context: {context}")
    return context
async def edge_similarity_search(
    search_vector: list[float], driver: AsyncDriver
) -> list[EntityEdge]:
    """Vector similarity search over embedded relationship facts.

    Queries the "fact_embedding" relationship vector index and returns up to
    ten RELATES_TO edges, best score first.

    Args:
        search_vector: embedding to search with.
        driver: Neo4j async driver.
    """
    # vector similarity search over embedded facts
    records, _, _ = await driver.execute_query(
        """
        CALL db.index.vector.queryRelationships("fact_embedding", 5, $search_vector)
        YIELD relationship AS r, score
        MATCH (n)-[r:RELATES_TO]->(m)
        RETURN
            r.uuid AS uuid,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            r.created_at AS created_at,
            r.name AS name,
            r.fact AS fact,
            r.fact_embedding AS fact_embedding,
            r.episodes AS episodes,
            r.expired_at AS expired_at,
            r.valid_at AS valid_at,
            r.invalid_at AS invalid_at
        ORDER BY score DESC LIMIT 10
        """,
        search_vector=search_vector,
    )

    edges: list[EntityEdge] = []
    now = datetime.now()
    for record in records:
        # NOTE(review): the query returns the stored temporal fields but they
        # are discarded in favor of `now` — presumably placeholders for search
        # results; TODO confirm and hydrate from the record if not.
        edge = EntityEdge(
            uuid=record["uuid"],
            source_node_uuid=record["source_node_uuid"],
            target_node_uuid=record["target_node_uuid"],
            fact=record["fact"],
            name=record["name"],
            episodes=record["episodes"],
            fact_embedding=record["fact_embedding"],
            created_at=now,
            expired_at=now,
            valid_at=now,
            # Fixed: keyword was misspelled `invalid_At`; pydantic silently
            # ignored the extra field, leaving invalid_at unset.
            invalid_at=now,
        )
        edges.append(edge)

    logger.info(f"similarity search results. RESULT: {[edge.uuid for edge in edges]}")
    return edges
async def entity_similarity_search(
    search_vector: list[float], driver: AsyncDriver
) -> list[EntityNode]:
    """Vector similarity search over embedded entity names.

    Queries the "name_embedding" node vector index and returns matching
    nodes, best score first.

    Args:
        search_vector: embedding to search with.
        driver: Neo4j async driver.
    """
    # vector similarity search over entity names
    records, _, _ = await driver.execute_query(
        """
        CALL db.index.vector.queryNodes("name_embedding", 5, $search_vector)
        YIELD node AS n, score
        RETURN
            n.uuid As uuid,
            n.name AS name,
            n.created_at AS created_at,
            n.summary AS summary
        ORDER BY score DESC
        """,
        search_vector=search_vector,
    )

    # Hoisted out of the loop so every result shares one timestamp (and the
    # clock is read once), matching the edge search helpers.
    now = datetime.now()
    nodes: list[EntityNode] = [
        EntityNode(
            uuid=record["uuid"],
            name=record["name"],
            labels=[],
            # NOTE(review): the stored created_at is returned by the query but
            # discarded here — TODO confirm whether it should be used instead.
            created_at=now,
            summary=record["summary"],
        )
        for record in records
    ]

    logger.info(f"name semantic search results. RESULT: {nodes}")
    return nodes
async def entity_fulltext_search(query: str, driver: AsyncDriver) -> list[EntityNode]:
    """Fulltext (BM25) search over entity names and summaries.

    Appends Lucene's fuzzy-match operator to the query and returns at most
    ten matching nodes, best score first.

    Args:
        query: free-text search string.
        driver: Neo4j async driver.
    """
    # Trailing "~" enables Lucene fuzzy matching on the query terms.
    fuzzy_query = query + "~"

    records, _, _ = await driver.execute_query(
        """
        CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
        RETURN
            node.uuid As uuid,
            node.name AS name,
            node.created_at AS created_at,
            node.summary AS summary
        ORDER BY score DESC
        LIMIT 10
        """,
        query=fuzzy_query,
    )

    nodes: list[EntityNode] = [
        EntityNode(
            uuid=record["uuid"],
            name=record["name"],
            labels=[],
            created_at=datetime.now(),
            summary=record["summary"],
        )
        for record in records
    ]

    logger.info(f"fulltext search results. QUERY:{query}. RESULT: {nodes}")
    return nodes
async def edge_fulltext_search(query: str, driver: AsyncDriver) -> list[EntityEdge]:
    """Fulltext (BM25) search over relationship names and facts.

    Appends Lucene's fuzzy-match operator to the query and returns up to ten
    matching Entity-to-Entity edges, best score first.

    Args:
        query: free-text search string.
        driver: Neo4j async driver.
    """
    # fulltext search over facts
    fuzzy_query = query + "~"

    records, _, _ = await driver.execute_query(
        """
        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
        YIELD relationship AS r, score
        MATCH (n:Entity)-[r]->(m:Entity)
        RETURN
            r.uuid AS uuid,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            r.created_at AS created_at,
            r.name AS name,
            r.fact AS fact,
            r.fact_embedding AS fact_embedding,
            r.episodes AS episodes,
            r.expired_at AS expired_at,
            r.valid_at AS valid_at,
            r.invalid_at AS invalid_at
        ORDER BY score DESC LIMIT 10
        """,
        query=fuzzy_query,
    )

    edges: list[EntityEdge] = []
    now = datetime.now()
    for record in records:
        # NOTE(review): the stored temporal fields returned by the query are
        # discarded in favor of `now` — presumably placeholders; TODO confirm.
        edge = EntityEdge(
            uuid=record["uuid"],
            source_node_uuid=record["source_node_uuid"],
            target_node_uuid=record["target_node_uuid"],
            fact=record["fact"],
            name=record["name"],
            episodes=record["episodes"],
            fact_embedding=record["fact_embedding"],
            created_at=now,
            expired_at=now,
            valid_at=now,
            # Fixed: keyword was misspelled `invalid_At`; pydantic silently
            # ignored the extra field, leaving invalid_at unset.
            invalid_at=now,
        )
        edges.append(edge)

    logger.info(
        f"similarity search results. QUERY:{query}. RESULT: {[edge.uuid for edge in edges]}"
    )
    return edges
async def get_relevant_nodes(
    nodes: list[EntityNode],
    driver: AsyncDriver,
) -> list[EntityNode]:
    """Find graph nodes relevant to the given nodes.

    Runs a fulltext search on each node's name and a vector search on each
    node's name embedding concurrently, then deduplicates the combined
    results by uuid.

    Args:
        nodes: nodes to find neighbors for (name_embedding must be set).
        driver: Neo4j async driver.
    """
    relevant_nodes: dict[str, EntityNode] = {}

    # Fan out all searches at once; order of results does not matter since
    # we deduplicate by uuid below.
    results = await asyncio.gather(
        *[entity_fulltext_search(node.name, driver) for node in nodes],
        *[entity_similarity_search(node.name_embedding, driver) for node in nodes],
    )

    for result in results:
        for node in result:
            relevant_nodes[node.uuid] = node

    logger.info(f"Found relevant nodes: {relevant_nodes.keys()}")
    # Fixed: return a real list to honor the annotated return type — a dict
    # view is not a list and surprises callers that index or mutate it.
    return list(relevant_nodes.values())
async def get_relevant_edges(
    edges: list[EntityEdge],
    driver: AsyncDriver,
) -> list[EntityEdge]:
    """Find graph edges relevant to the given edges.

    Runs a vector search on each edge's fact embedding and a fulltext search
    on each edge's fact concurrently, then deduplicates the combined results
    by uuid.

    Args:
        edges: edges to find neighbors for (fact_embedding must be set).
        driver: Neo4j async driver.
    """
    relevant_edges: dict[str, EntityEdge] = {}

    # Fan out all searches at once; order of results does not matter since
    # we deduplicate by uuid below.
    results = await asyncio.gather(
        *[edge_similarity_search(edge.fact_embedding, driver) for edge in edges],
        *[edge_fulltext_search(edge.fact, driver) for edge in edges],
    )

    for result in results:
        for edge in result:
            relevant_edges[edge.uuid] = edge

    # Fixed: log message said "relevant nodes" for edge results.
    logger.info(f"Found relevant edges: {relevant_edges.keys()}")
    # Fixed: return a real list to honor the annotated return type.
    return list(relevant_edges.values())

25
core/utils/utils.py Normal file
View file

@ -0,0 +1,25 @@
import logging
from neo4j import AsyncDriver
from core.edges import EpisodicEdge, EntityEdge, Edge
from core.nodes import EntityNode, EpisodicNode, Node
logger = logging.getLogger(__name__)
def build_episodic_edges(
    entity_nodes: list[EntityNode], episode: EpisodicNode
) -> list[EpisodicEdge]:
    """Create MENTIONS edges linking an episode to each extracted entity node.

    Args:
        entity_nodes: nodes mentioned by the episode.
        episode: the episode that mentions them.

    Returns:
        One EpisodicEdge per entity node, stamped with the episode's
        created_at.
    """
    # Fixed: EpisodicEdge stores node *uuids* (str fields), not node objects;
    # the original passed `episode` and `node` directly, which fails pydantic
    # validation on the str-typed uuid fields.
    return [
        EpisodicEdge(
            source_node_uuid=episode.uuid,
            target_node_uuid=node.uuid,
            created_at=episode.created_at,
        )
        for node in entity_nodes
    ]

View file

@ -1,3 +1,5 @@
import logging
import sys
import os import os
import pytest import pytest
@ -9,9 +11,11 @@ from openai import OpenAI
from core.edges import EpisodicEdge, EntityEdge from core.edges import EpisodicEdge, EntityEdge
from core.graphiti import Graphiti from core.graphiti import Graphiti
from core.llm_client.config import EMBEDDING_DIM
from core.nodes import EpisodicNode, EntityNode from core.nodes import EpisodicNode, EntityNode
from datetime import datetime from datetime import datetime
pytest_plugins = ("pytest_asyncio",) pytest_plugins = ("pytest_asyncio",)
load_dotenv() load_dotenv()
@ -21,10 +25,59 @@ NEO4j_USER = os.getenv("NEO4J_USER")
NEO4j_PASSWORD = os.getenv("NEO4J_PASSWORD") NEO4j_PASSWORD = os.getenv("NEO4J_PASSWORD")
def setup_logging():
    """Configure the root logger to emit INFO records to stdout.

    Idempotent: repeated calls reuse the existing stdout handler instead of
    attaching another one (the original added a fresh handler per call,
    which duplicated every log line in test runs).

    Returns:
        The configured root logger.
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)  # Set the logging level to INFO

    # Fixed: bail out if a stdout stream handler is already attached, so
    # calling this from several tests does not multiply log output.
    for handler in logger.handlers:
        if isinstance(handler, logging.StreamHandler) and handler.stream is sys.stdout:
            return logger

    # Create console handler and set level to INFO
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.INFO)

    # Attach the shared timestamped format.
    console_handler.setFormatter(
        logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    )

    logger.addHandler(console_handler)
    return logger
def format_context(context):
    """Render a search-context dict as human-readable text.

    Each entry becomes a "UUID:" header followed by indented name, summary,
    and fact bullet lines; entries are separated by a blank line and the
    final result has surrounding whitespace stripped.

    Args:
        context: mapping of uuid -> {"name", "summary", "facts"} dicts.
    """
    parts = []
    for uuid, data in context.items():
        parts.append(f"UUID: {uuid}\n")
        parts.append(f"  Name: {data['name']}\n")
        parts.append(f"  Summary: {data['summary']}\n")
        parts.append("  Facts:\n")
        parts.extend(f"    - {fact}\n" for fact in data["facts"])
        parts.append("\n")
    # Single join instead of repeated string concatenation.
    return "".join(parts).strip()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_graphiti_init(): async def test_graphiti_init():
logger = setup_logging()
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, None) graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, None)
await graphiti.build_indices() await graphiti.build_indices()
context = await graphiti.search("Freakenomics guest")
logger.info("QUERY: Freakenomics guest" + "RESULT:" + format_context(context))
context = await graphiti.search("tania tetlow")
logger.info("QUERY: Tania Tetlow" + "RESULT:" + format_context(context))
context = await graphiti.search("issues with higher ed")
logger.info("QUERY: issues with higher ed" + "RESULT:" + format_context(context))
graphiti.close() graphiti.close()
@ -57,16 +110,16 @@ async def test_graph_integration():
bob_node = EntityNode(name="Bob", labels=[], created_at=now, summary="Bob summary") bob_node = EntityNode(name="Bob", labels=[], created_at=now, summary="Bob summary")
episodic_edge_1 = EpisodicEdge( episodic_edge_1 = EpisodicEdge(
source_node=episode, target_node=alice_node, created_at=now source_node_uuid=episode, target_node_uuid=alice_node, created_at=now
) )
episodic_edge_2 = EpisodicEdge( episodic_edge_2 = EpisodicEdge(
source_node=episode, target_node=bob_node, created_at=now source_node_uuid=episode, target_node_uuid=bob_node, created_at=now
) )
entity_edge = EntityEdge( entity_edge = EntityEdge(
source_node=alice_node, source_node_uuid=alice_node.uuid,
target_node=bob_node, target_node_uuid=bob_node.uuid,
created_at=now, created_at=now,
name="likes", name="likes",
fact="Alice likes Bob", fact="Alice likes Bob",