Update Maintenance LLM Queries and Partial Schema Retrieval (#6)

* search updates

* add search_utils

* updates

* graph maintenance updates

* revert extract_new_nodes

* revert extract_new_edges

* parallelize node searching

* add edge fulltext search

* search optimizations
Preston Rasmussen 2024-08-18 13:22:31 -04:00 committed by GitHub
parent ad552b527e
commit 4db3906049
16 changed files with 953 additions and 119 deletions

View file

@@ -5,15 +5,16 @@ from neo4j import AsyncDriver
from uuid import uuid4
import logging
from core.llm_client.config import EMBEDDING_DIM
from core.nodes import Node
logger = logging.getLogger(__name__)
class Edge(BaseModel, ABC):
uuid: str = Field(default_factory=lambda: str(uuid4()))
source_node: Node
target_node: Node
uuid: str = Field(default_factory=lambda: uuid4().hex)
source_node_uuid: str
target_node_uuid: str
created_at: datetime
@abstractmethod
@@ -30,11 +31,6 @@ class Edge(BaseModel, ABC):
class EpisodicEdge(Edge):
async def save(self, driver: AsyncDriver):
if self.uuid is None:
uuid = uuid4()
logger.info(f"Created uuid: {uuid} for episodic edge")
self.uuid = str(uuid)
result = await driver.execute_query(
"""
MATCH (episode:Episodic {uuid: $episode_uuid})
@@ -42,8 +38,8 @@ class EpisodicEdge(Edge):
MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
SET r = {uuid: $uuid, created_at: $created_at}
RETURN r.uuid AS uuid""",
episode_uuid=self.source_node.uuid,
entity_uuid=self.target_node.uuid,
episode_uuid=self.source_node_uuid,
entity_uuid=self.target_node_uuid,
uuid=self.uuid,
created_at=self.created_at,
)
@@ -79,10 +75,10 @@ class EntityEdge(Edge):
default=None, description="datetime of when the fact stopped being true"
)
def generate_embedding(self, embedder, model="text-embedding-3-large"):
async def generate_embedding(self, embedder, model="text-embedding-3-small"):
text = self.fact.replace("\n", " ")
embedding = embedder.create(input=[text], model=model).data[0].embedding
self.fact_embedding = embedding
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
self.fact_embedding = embedding[:EMBEDDING_DIM]
return embedding
@@ -96,8 +92,8 @@ class EntityEdge(Edge):
episodes: $episodes, created_at: $created_at, expired_at: $expired_at,
valid_at: $valid_at, invalid_at: $invalid_at}
RETURN r.uuid AS uuid""",
source_uuid=self.source_node.uuid,
target_uuid=self.target_node.uuid,
source_uuid=self.source_node_uuid,
target_uuid=self.target_node_uuid,
uuid=self.uuid,
name=self.name,
fact=self.fact,

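The reworked generate_embedding above awaits the embeddings call and slices the result to EMBEDDING_DIM. A minimal standalone sketch of the same pattern, assuming OpenAI's AsyncOpenAI client (the diff itself only requires that embedder.create be awaitable):

from openai import AsyncOpenAI

EMBEDDING_DIM = 1024  # mirrors core.llm_client.config

async def embed_fact(fact: str) -> list[float]:
    # newlines tend to degrade embedding quality, so replace them first
    text = fact.replace("\n", " ")
    embedder = AsyncOpenAI().embeddings  # assumes OPENAI_API_KEY is set in the env
    result = await embedder.create(input=[text], model="text-embedding-3-small")
    # truncate so the stored vector matches the 1024-dim Neo4j vector indexes
    return result.data[0].embedding[:EMBEDDING_DIM]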
View file

@@ -1,12 +1,14 @@
import asyncio
from datetime import datetime
import logging
from typing import Callable, LiteralString, Tuple
from typing import Callable, LiteralString
from neo4j import AsyncGraphDatabase
from dotenv import load_dotenv
import os
from core.llm_client.config import EMBEDDING_DIM
from core.nodes import EntityNode, EpisodicNode, Node
from core.edges import EntityEdge, Edge
from core.edges import EntityEdge, Edge, EpisodicEdge
from core.utils import (
build_episodic_edges,
retrieve_relevant_schema,
@@ -16,6 +18,15 @@ from core.utils import (
retrieve_episodes,
)
from core.llm_client import LLMClient, OpenAIClient, LLMConfig
from core.utils.maintenance.edge_operations import extract_edges, dedupe_extracted_edges
from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
from core.utils.search.search_utils import (
edge_similarity_search,
entity_fulltext_search,
bfs,
get_relevant_nodes,
get_relevant_edges,
)
logger = logging.getLogger(__name__)
@@ -34,7 +45,7 @@ class Graphiti:
self.llm_client = OpenAIClient(
LLMConfig(
api_key=os.getenv("OPENAI_API_KEY"),
model="gpt-4o",
model="gpt-4o-mini",
base_url="https://api.openai.com/v1",
)
)
@@ -75,8 +86,12 @@
):
"""Process an episode and update the graph"""
try:
nodes: list[Node] = []
edges: list[Edge] = []
nodes: list[EntityNode] = []
entity_edges: list[EntityEdge] = []
episodic_edges: list[EpisodicEdge] = []
embedder = self.llm_client.client.embeddings
now = datetime.now()
previous_episodes = await self.retrieve_episodes(last_n=3)
episode = EpisodicNode(
name=name,
@@ -84,38 +99,65 @@
source="messages",
content=episode_body,
source_description=source_description,
created_at=datetime.now(),
created_at=now,
valid_at=reference_time,
)
# await episode.save(self.driver)
relevant_schema = await self.retrieve_relevant_schema(episode.content)
new_nodes = await extract_new_nodes(
self.llm_client, episode, relevant_schema, previous_episodes
# relevant_schema = await self.retrieve_relevant_schema(episode.content)
extracted_nodes = await extract_nodes(
self.llm_client, episode, previous_episodes
)
# Calculate Embeddings
await asyncio.gather(
*[node.generate_name_embedding(embedder) for node in extracted_nodes]
)
existing_nodes = await get_relevant_nodes(extracted_nodes, self.driver)
new_nodes = await dedupe_extracted_nodes(
self.llm_client, extracted_nodes, existing_nodes
)
nodes.extend(new_nodes)
new_edges, affected_nodes = await extract_new_edges(
self.llm_client, episode, new_nodes, relevant_schema, previous_episodes
extracted_edges = await extract_edges(
self.llm_client, episode, new_nodes, previous_episodes
)
edges.extend(new_edges)
episodic_edges = build_episodic_edges(
# There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them
list(set(nodes + affected_nodes)),
episode,
datetime.now(),
await asyncio.gather(
*[edge.generate_embedding(embedder) for edge in extracted_edges]
)
existing_edges = await get_relevant_edges(extracted_edges, self.driver)
new_edges = await dedupe_extracted_edges(
self.llm_client, extracted_edges, existing_edges
)
entity_edges.extend(new_edges)
episodic_edges.extend(
build_episodic_edges(
# nodes have already been deduplicated against the existing graph above
nodes,
episode,
now,
)
)
# Important to append the episode to the nodes at the end so that self-referencing episodic edges are not built
nodes.append(episode)
logger.info(f"Built episodic edges: {episodic_edges}")
edges.extend(episodic_edges)
# invalidated_edges = await self.invalidate_edges(
# episode, new_nodes, new_edges, relevant_schema, previous_episodes
# )
# edges.extend(invalidated_edges)
# Future optimization would be using batch operations to save nodes and edges
await episode.save(self.driver)
await asyncio.gather(*[node.save(self.driver) for node in nodes])
await asyncio.gather(*[edge.save(self.driver) for edge in edges])
await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])
# for node in nodes:
# if isinstance(node, EntityNode):
# await node.update_summary(self.driver)
@@ -129,6 +171,10 @@ class Graphiti:
async def build_indices(self):
index_queries: list[LiteralString] = [
"CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)",
"CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)",
"CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.uuid)",
"CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[r:MENTIONS]-() ON (r.uuid)",
"CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)",
"CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)",
"CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)",
@@ -143,13 +189,19 @@
for query in index_queries:
await self.driver.execute_query(query)
# Add the entity indices
# Add the semantic indices
await self.driver.execute_query(
"""
CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]
"""
)
await self.driver.execute_query(
"""
CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON EACH [r.name, r.fact]
"""
)
await self.driver.execute_query(
"""
CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
@@ -161,29 +213,40 @@
"""
)
async def search(
self, query: str, config
) -> (list)[tuple[EntityNode, list[EntityEdge]]]:
(vec_nodes, vec_edges) = similarity_search(query, embedder)
(text_nodes, text_edges) = fulltext_search(query)
await self.driver.execute_query(
"""
CREATE VECTOR INDEX name_embedding IF NOT EXISTS
FOR (n:Entity) ON (n.name_embedding)
OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
"""
)
nodes = vec_nodes.extend(text_nodes)
edges = vec_edges.extend(text_edges)
async def search(self, query: str) -> list[tuple[EntityNode, list[EntityEdge]]]:
text = query.replace("\n", " ")
search_vector = (
(
await self.llm_client.client.embeddings.create(
input=[text], model="text-embedding-3-small"
)
)
.data[0]
.embedding[:EMBEDDING_DIM]
)
results = bfs(nodes, edges, k=1)
edges = await edge_similarity_search(search_vector, self.driver)
nodes = await entity_fulltext_search(query, self.driver)
episode_ids = ["Mode of episode ids"]
node_ids = [node.uuid for node in nodes]
episodes = get_episodes(episode_ids[:episode_count])
for edge in edges:
node_ids.append(edge.source_node_uuid)
node_ids.append(edge.target_node_uuid)
return [(node, edges)], episodes
node_ids = list(dict.fromkeys(node_ids))
# Invalidate edges that are no longer valid
async def invalidate_edges(
self,
episode: EpisodicNode,
new_nodes: list[EntityNode],
new_edges: list[EntityEdge],
relevant_schema: dict[str, any],
previous_episodes: list[EpisodicNode],
): ...
context = await bfs(node_ids, self.driver)
return context

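Pulling the pieces together: the new search path embeds the query, gathers candidate edges by vector similarity and candidate nodes by fulltext search, merges the node uuids (dict.fromkeys is an order-preserving de-duplication), and expands one hop with bfs. A condensed usage sketch modeled on the integration test at the end of this diff (run inside an async context; the connection values are placeholders):

graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password, None)  # None -> default OpenAIClient
await graphiti.build_indices()  # the fulltext/vector indexes must exist before searching
context = await graphiti.search("Freakonomics guest")
# context maps node uuid -> {"name": ..., "summary": ..., "facts": [...]}
graphiti.close()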
View file

@@ -1,3 +1,6 @@
EMBEDDING_DIM = 1024
class LLMConfig:
"""
Configuration class for the Large Language Model (LLM).
@@ -10,7 +13,7 @@ class LLMConfig:
def __init__(
self,
api_key: str,
model: str = "gpt-4o",
model: str = "gpt-4o-mini",
base_url: str = "https://api.openai.com",
):
"""
@@ -21,7 +24,7 @@
This is required for making authorized requests.
model (str, optional): The specific LLM model to use for generating responses.
Defaults to "gpt-4o", which appears to be a custom model name.
Defaults to "gpt-4o-mini", which appears to be a custom model name.
Common values might include "gpt-3.5-turbo" or "gpt-4".
base_url (str, optional): The base URL of the LLM API service.

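EMBEDDING_DIM is deliberately co-located with LLMConfig: embeddings are sliced to this length, and every vector index created in this commit declares 1024 dimensions, so the constant and the index definitions must stay in sync. A small instantiation sketch (the LLMConfig import path is assumed from graphiti.py above):

import os

from core.llm_client import LLMConfig
from core.llm_client.config import EMBEDDING_DIM

config = LLMConfig(
    api_key=os.getenv("OPENAI_API_KEY"),
    model="gpt-4o-mini",
    base_url="https://api.openai.com/v1",
)
assert EMBEDDING_DIM == 1024  # must match `vector.dimensions` in build_indices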
View file

@@ -8,11 +8,13 @@ from pydantic import BaseModel, Field
from neo4j import AsyncDriver
import logging
from core.llm_client.config import EMBEDDING_DIM
logger = logging.getLogger(__name__)
class Node(BaseModel, ABC):
uuid: str = Field(default_factory=lambda: str(uuid4()))
uuid: str = Field(default_factory=lambda: uuid4().hex)
name: str
labels: list[str] = Field(default_factory=list)
created_at: datetime
@@ -66,21 +68,32 @@ class EpisodicNode(Node):
class EntityNode(Node):
name_embedding: list[float] | None = Field(
default=None, description="embedding of the name"
)
summary: str = Field(description="regional summary of surrounding edges")
async def update_summary(self, driver: AsyncDriver): ...
async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ...
async def generate_name_embedding(self, embedder, model="text-embedding-3-small"):
text = self.name.replace("\n", " ")
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
self.name_embedding = embedding[:EMBEDDING_DIM]
return embedding
async def save(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MERGE (n:Entity {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, summary: $summary, created_at: $created_at}
SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, summary: $summary, created_at: $created_at}
RETURN n.uuid AS uuid""",
uuid=self.uuid,
name=self.name,
summary=self.summary,
name_embedding=self.name_embedding,
created_at=self.created_at,
)

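Note the uuid default factory change from str(uuid4()) to uuid4().hex: both encode the same 128-bit identifier, but hex renders it as 32 characters without hyphens. For example:

from uuid import uuid4

u = uuid4()
print(str(u))  # e.g. '0b7f3c52-6d1a-4f7e-9a0d-2c5e8b1f4a6d' (36 chars, hyphenated)
print(u.hex)   # e.g. '0b7f3c526d1a4f7e9a0d2c5e8b1f4a6d' (32 hex chars)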
View file

@@ -0,0 +1,56 @@
import json
from typing import TypedDict, Protocol
from .models import Message, PromptVersion, PromptFunction
class Prompt(Protocol):
v1: PromptVersion
class Versions(TypedDict):
v1: PromptFunction
def v1(context: dict[str, any]) -> list[Message]:
return [
Message(
role="system",
content="You are a helpful assistant that de-duplicates relationship from edge lists.",
),
Message(
role="user",
content=f"""
Given the following context, deduplicate edges from a list of new edges given a list of existing edges:
Existing Edges:
{json.dumps(context['existing_edges'], indent=2)}
New Edges:
{json.dumps(context['extracted_edges'], indent=2)}
Task:
1. Start with the list of edges from New Edges
2. If any edge in New Edges is a duplicate of an edge in Existing Edges, replace the new edge with the existing
edge in the list
3. Respond with the resulting list of edges
Guidelines:
1. Use both the name and fact of edges to determine if they are duplicates;
duplicate edges may have different names
Respond with a JSON object in the following format:
{{
"new_edges": [
{{
"name": "Unique identifier for the edge",
"fact": "one sentence description of the fact"
}}
]
}}
""",
),
]
versions: Versions = {"v1": v1}

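A sketch of how this prompt is exercised by dedupe_extracted_edges later in this commit, with purely illustrative edge data:

context = {
    "existing_edges": [{"name": "WORKS_FOR", "fact": "Alice works for Acme"}],
    "extracted_edges": [{"name": "EMPLOYED_BY", "fact": "Alice is employed by Acme"}],
}
messages = versions["v1"](context)  # [system message, user message]
# the model should reply with {"new_edges": [...]} in which the duplicate
# extracted edge has been replaced by its existing counterpart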
View file

@@ -0,0 +1,56 @@
import json
from typing import TypedDict, Protocol
from .models import Message, PromptVersion, PromptFunction
class Prompt(Protocol):
v1: PromptVersion
class Versions(TypedDict):
v1: PromptFunction
def v1(context: dict[str, any]) -> list[Message]:
return [
Message(
role="system",
content="You are a helpful assistant that de-duplicates nodes from node lists.",
),
Message(
role="user",
content=f"""
Given the following context, deduplicate nodes from a list of new nodes given a list of existing nodes:
Existing Nodes:
{json.dumps(context['existing_nodes'], indent=2)}
New Nodes:
{json.dumps(context['extracted_nodes'], indent=2)}
Task:
1. Start with the list of nodes from New Nodes
2. If any node in New Nodes is a duplicate of a node in Existing Nodes, replace the new node with the existing
node in the list
3. Respond with the resulting list of nodes
Guidelines:
1. Use both the name and summary of nodes to determine if they are duplicates;
duplicate nodes may have different names
Respond with a JSON object in the following format:
{{
"new_nodes": [
{{
"name": "Unique identifier for the node",
"summary": "Brief summary of the node's role or significance"
}}
]
}}
""",
),
]
versions: Versions = {"v1": v1}

View file

@@ -6,10 +6,12 @@ from .models import Message, PromptVersion, PromptFunction
class Prompt(Protocol):
v1: PromptVersion
v2: PromptVersion
class Versions(TypedDict):
v1: PromptFunction
v2: PromptFunction
def v1(context: dict[str, any]) -> list[Message]:
@@ -68,6 +70,108 @@ def v1(context: dict[str, any]) -> list[Message]:
]
versions: Versions = {
"v1": v1,
}
def v1(context: dict[str, any]) -> list[Message]:
return [
Message(
role="system",
content="You are a helpful assistant that extracts graph edges from provided context.",
),
Message(
role="user",
content=f"""
Given the following context, extract new semantic edges (relationships) that need to be added to the knowledge graph:
Current Graph Structure:
{context['relevant_schema']}
New Nodes:
{json.dumps(context['new_nodes'], indent=2)}
New Episode:
Content: {context['episode_content']}
Timestamp: {context['episode_timestamp']}
Previous Episodes:
{json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)}
Extract new semantic edges based on the content of the current episode, considering the existing graph structure, new nodes, and context from previous episodes.
Guidelines:
1. Create edges only between semantic nodes (not episodic nodes like messages).
2. Each edge should represent a clear relationship between two semantic nodes.
3. The relation_type should be a concise, all-caps description of the relationship (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
4. Provide a more detailed fact describing the relationship.
5. If a relationship seems to update an existing one, create a new edge with the updated information.
6. Consider temporal aspects of relationships when relevant.
7. Do not create edges involving episodic nodes (like Message 1 or Message 2).
8. Use existing nodes from the current graph structure when appropriate.
Respond with a JSON object in the following format:
{{
"new_edges": [
{{
"relation_type": "RELATION_TYPE_IN_CAPS",
"source_node": "Name of the source semantic node",
"target_node": "Name of the target semantic node",
"fact": "Detailed description of the relationship",
"valid_at": "YYYY-MM-DDTHH:MM:SSZ or null if not explicitly mentioned",
"invalid_at": "YYYY-MM-DDTHH:MM:SSZ or null if ongoing or not explicitly mentioned"
}}
]
}}
If no new edges need to be added, return an empty list for "new_edges".
""",
),
]
def v2(context: dict[str, any]) -> list[Message]:
return [
Message(
role="system",
content="You are a helpful assistant that extracts graph edges from provided context.",
),
Message(
role="user",
content=f"""
Given the following context, extract new edges (relationships) that need to be added to the knowledge graph:
Nodes:
{json.dumps(context['nodes'], indent=2)}
New Episode:
Content: {context['episode_content']}
Previous Episodes:
{json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)}
Extract new entity edges based on the content of the current episode, the given nodes, and context from previous episodes.
Guidelines:
1. Create edges only between the provided nodes.
2. Each edge should represent a clear relationship between two nodes.
3. The relation_type should be a concise, all-caps description of the relationship (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
4. Provide a more detailed fact describing the relationship.
5. Consider temporal aspects of relationships when relevant.
Respond with a JSON object in the following format:
{{
"edges": [
{{
"relation_type": "RELATION_TYPE_IN_CAPS",
"source_node_uuid": "uuid of the source entity node",
"target_node_uuid": "uuid of the target entity node",
"fact": "Detailed description of the relationship",
"valid_at": "YYYY-MM-DDTHH:MM:SSZ or null if not explicitly mentioned",
"invalid_at": "YYYY-MM-DDTHH:MM:SSZ or null if ongoing or not explicitly mentioned"
}}
]
}}
If no new edges need to be added, return an empty list for "edges".
""",
),
]
versions: Versions = {"v1": v1, "v2": v2}

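For reference, a well-formed v2 response parsed into Python would look like the following (all values illustrative); extract_edges in edge_operations.py below reads the "edges" key:

llm_response = {
    "edges": [
        {
            "relation_type": "LIKES",
            "source_node_uuid": "0b7f3c526d1a4f7e9a0d2c5e8b1f4a6d",  # made-up uuid
            "target_node_uuid": "2c5e8b1f4a6d0b7f3c526d1a4f7e9a0d",  # made-up uuid
            "fact": "Alice likes Bob",
            "valid_at": None,  # null in the raw JSON
            "invalid_at": None,
        }
    ]
}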
View file

@@ -6,10 +6,12 @@ from .models import Message, PromptVersion, PromptFunction
class Prompt(Protocol):
v1: PromptVersion
v2: PromptVersion
class Versions(TypedDict):
v1: PromptFunction
v2: PromptFunction
def v1(context: dict[str, any]) -> list[Message]:
@@ -60,6 +62,45 @@ def v1(context: dict[str, any]) -> list[Message]:
]
versions: Versions = {
"v1": v1,
}
def v2(context: dict[str, any]) -> list[Message]:
return [
Message(
role="system",
content="You are a helpful assistant that extracts graph nodes from provided context.",
),
Message(
role="user",
content=f"""
Given the following context, extract new entity nodes that need to be added to the knowledge graph:
Previous Episodes:
{json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)}
New Episode:
Content: {context["episode_content"]}
Extract new entity nodes based on the content of the current episode, while considering the context from previous episodes.
Guidelines:
1. Focus on entities, concepts, or actors that are central to the current episode.
2. Avoid creating nodes for relationships or actions (these will be handled as edges later).
3. Provide a brief but informative summary for each node.
Respond with a JSON object in the following format:
{{
"new_nodes": [
{{
"name": "Unique identifier for the node",
"labels": ["Entity", "OptionalAdditionalLabel"],
"summary": "Brief summary of the node's role or significance"
}}
]
}}
If no new nodes need to be added, return an empty list for "new_nodes".
""",
),
]
versions: Versions = {"v1": v1, "v2": v2}

View file

@@ -8,21 +8,37 @@ from .extract_nodes import (
versions as extract_nodes_versions,
)
from .dedupe_nodes import (
Prompt as DedupeNodesPrompt,
Versions as DedupeNodesVersions,
versions as dedupe_nodes_versions,
)
from .extract_edges import (
Prompt as ExtractEdgesPrompt,
Versions as ExtractEdgesVersions,
versions as extract_edges_versions,
)
from .dedupe_edges import (
Prompt as DedupeEdgesPrompt,
Versions as DedupeEdgesVersions,
versions as dedupe_edges_versions,
)
class PromptLibrary(Protocol):
extract_nodes: ExtractNodesPrompt
dedupe_nodes: DedupeNodesPrompt
extract_edges: ExtractEdgesPrompt
dedupe_edges: DedupeEdgesPrompt
class PromptLibraryImpl(TypedDict):
extract_nodes: ExtractNodesVersions
dedupe_nodes: DedupeNodesVersions
extract_edges: ExtractEdgesVersions
dedupe_edges: DedupeEdgesVersions
class VersionWrapper:
@@ -47,7 +63,9 @@ class PromptLibraryWrapper:
PROMPT_LIBRARY_IMPL: PromptLibraryImpl = {
"extract_nodes": extract_nodes_versions,
"dedupe_nodes": dedupe_nodes_versions,
"extract_edges": extract_edges_versions,
"dedupe_edges": dedupe_edges_versions,
}
prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL)

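The Protocol/TypedDict pair plus the wrapper classes give call sites attribute-style access to the registered prompt versions, which is how the maintenance modules below consume them:

messages = prompt_library.dedupe_edges.v1(context)  # as in edge_operations below
llm_response = await llm_client.generate_response(messages)
new_edges_data = llm_response.get("new_edges", [])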
View file

@@ -1,45 +0,0 @@
from typing import Tuple
from core.edges import EpisodicEdge, EntityEdge, Edge
from core.nodes import EntityNode, EpisodicNode, Node
async def bfs(
nodes: list[Node], edges: list[Edge], k: int
) -> Tuple[list[EntityNode], list[EntityEdge]]: ...
# Breadth first search over nodes and edges with desired depth
async def similarity_search(
query: str, embedder
) -> Tuple[list[EntityNode], list[EntityEdge]]: ...
# vector similarity search over embedded facts
async def fulltext_search(
query: str,
) -> Tuple[list[EntityNode], list[EntityEdge]]: ...
# fulltext search over names and summary
def build_episodic_edges(
entity_nodes: list[EntityNode], episode: EpisodicNode
) -> list[EpisodicEdge]:
edges: list[EpisodicEdge] = []
for node in entity_nodes:
edges.append(
EpisodicEdge(
source_node=episode,
target_node=node,
created_at=episode.created_at,
)
)
return edges

View file

@@ -13,15 +13,17 @@ logger = logging.getLogger(__name__)
def build_episodic_edges(
semantic_nodes: List[EntityNode],
entity_nodes: List[EntityNode],
episode: EpisodicNode,
transaction_from: datetime,
) -> List[EpisodicEdge]:
edges: List[EpisodicEdge] = []
for node in semantic_nodes:
for node in entity_nodes:
edge = EpisodicEdge(
source_node=episode, target_node=node, created_at=transaction_from
source_node_uuid=episode.uuid,
target_node_uuid=node.uuid,
created_at=transaction_from,
)
edges.append(edge)
@@ -132,3 +134,94 @@ async def extract_new_edges(
affected_nodes.add(edge.source_node)
affected_nodes.add(edge.target_node)
return new_edges, list(affected_nodes)
async def extract_edges(
llm_client: LLMClient,
episode: EpisodicNode,
nodes: list[EntityNode],
previous_episodes: list[EpisodicNode],
) -> list[EntityEdge]:
# Prepare context for LLM
context = {
"episode_content": episode.content,
"episode_timestamp": (
episode.valid_at.isoformat() if episode.valid_at else None
),
"nodes": [
{"uuid": node.uuid, "name": node.name, "summary": node.summary}
for node in nodes
],
"previous_episodes": [
{
"content": ep.content,
"timestamp": ep.valid_at.isoformat() if ep.valid_at else None,
}
for ep in previous_episodes
],
}
llm_response = await llm_client.generate_response(
prompt_library.extract_edges.v2(context)
)
edges_data = llm_response.get("edges", [])
logger.info(f"Extracted new edges: {edges_data}")
# Convert the extracted data into EntityEdge objects
edges = []
for edge_data in edges_data:
edge = EntityEdge(
source_node_uuid=edge_data["source_node_uuid"],
target_node_uuid=edge_data["target_node_uuid"],
name=edge_data["relation_type"],
fact=edge_data["fact"],
episodes=[episode.uuid],
created_at=datetime.now(),
valid_at=edge_data["valid_at"],
invalid_at=edge_data["invalid_at"],
)
edges.append(edge)
logger.info(
f"Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})"
)
return edges
async def dedupe_extracted_edges(
llm_client: LLMClient,
extracted_edges: list[EntityEdge],
existing_edges: list[EntityEdge],
) -> list[EntityEdge]:
# Create edge map
edge_map = {}
for edge in existing_edges:
edge_map[edge.name] = edge
for edge in extracted_edges:
if edge.name in edge_map.keys():
continue
edge_map[edge.name] = edge
# Prepare context for LLM
context = {
"extracted_edges": [
{"name": edge.name, "fact": edge.fact} for edge in extracted_edges
],
"existing_edges": [
{"name": edge.name, "fact": edge.fact} for edge in extracted_edges
],
}
llm_response = await llm_client.generate_response(
prompt_library.dedupe_edges.v1(context)
)
new_edges_data = llm_response.get("new_edges", [])
logger.info(f"Extracted new edges: {new_edges_data}")
# Get full edge data
edges = []
for edge_data in new_edges_data:
edge = edge_map[edge_data["name"]]
edges.append(edge)
return edges

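As an aside, the name-keyed map built at the top of dedupe_extracted_edges can be written more compactly; this behaviorally equivalent sketch keeps the same caveat that distinct edges sharing a name collide:

edge_map = {edge.name: edge for edge in existing_edges}
for edge in extracted_edges:
    edge_map.setdefault(edge.name, edge)  # keep the existing edge on a name clash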
View file

@@ -61,3 +61,87 @@ async def extract_new_nodes(
logger.info(f"Node {node_data['name']} already exists, skipping creation.")
return new_nodes
async def extract_nodes(
llm_client: LLMClient,
episode: EpisodicNode,
previous_episodes: list[EpisodicNode],
) -> list[EntityNode]:
# Prepare context for LLM
context = {
"episode_content": episode.content,
"episode_timestamp": (
episode.valid_at.isoformat() if episode.valid_at else None
),
"previous_episodes": [
{
"content": ep.content,
"timestamp": ep.valid_at.isoformat() if ep.valid_at else None,
}
for ep in previous_episodes
],
}
llm_response = await llm_client.generate_response(
prompt_library.extract_nodes.v2(context)
)
new_nodes_data = llm_response.get("new_nodes", [])
logger.info(f"Extracted new nodes: {new_nodes_data}")
# Convert the extracted data into EntityNode objects
new_nodes = []
for node_data in new_nodes_data:
new_node = EntityNode(
name=node_data["name"],
labels=node_data["labels"],
summary=node_data["summary"],
created_at=datetime.now(),
)
new_nodes.append(new_node)
logger.info(f"Created new node: {new_node.name} (UUID: {new_node.uuid})")
return new_nodes
async def dedupe_extracted_nodes(
llm_client: LLMClient,
extracted_nodes: list[EntityNode],
existing_nodes: list[EntityNode],
) -> list[EntityNode]:
# build node map
node_map = {}
for node in existing_nodes:
node_map[node.name] = node
for node in extracted_nodes:
if node.name in node_map.keys():
continue
node_map[node.name] = node
# Prepare context for LLM
existing_nodes_context = [
{"name": node.name, "summary": node.summary} for node in existing_nodes
]
extracted_nodes_context = [
{"name": node.name, "summary": node.summary} for node in extracted_nodes
]
context = {
"existing_nodes": existing_nodes_context,
"extracted_nodes": extracted_nodes_context,
}
llm_response = await llm_client.generate_response(
prompt_library.dedupe_nodes.v1(context)
)
new_nodes_data = llm_response.get("new_nodes", [])
logger.info(f"Deduplicated nodes: {new_nodes_data}")
# Get full node data
nodes = []
for node_data in new_nodes_data:
node = node_map[node_data["name"]]
nodes.append(node)
return nodes

View file

@@ -0,0 +1,274 @@
import asyncio
import logging
from datetime import datetime
from neo4j import AsyncDriver
from core.edges import EntityEdge
from core.nodes import EntityNode
logger = logging.getLogger(__name__)
async def bfs(node_ids: list[str], driver: AsyncDriver):
records, _, _ = await driver.execute_query(
"""
MATCH (n WHERE n.uuid in $node_ids)-[r]->(m)
RETURN
n.uuid AS source_node_uuid,
n.name AS source_name,
n.summary AS source_summary,
m.uuid AS target_node_uuid,
m.name AS target_name,
m.summary AS target_summary,
r.uuid AS uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
""",
node_ids=node_ids,
)
context = {}
for record in records:
n_uuid = record["source_node_uuid"]
if n_uuid in context.keys():
context[n_uuid]["facts"].append(record["fact"])
else:
context[n_uuid] = {
"name": record["source_name"],
"summary": record["source_summary"],
"facts": [record["fact"]],
}
m_uuid = record["target_node_uuid"]
if m_uuid not in context:
context[m_uuid] = {
"name": record["target_name"],
"summary": record["target_summary"],
"facts": [],
}
logger.info(f"bfs search returned context: {context}")
return context
async def edge_similarity_search(
search_vector: list[float], driver: AsyncDriver
) -> list[EntityEdge]:
# vector similarity search over embedded facts
records, _, _ = await driver.execute_query(
"""
CALL db.index.vector.queryRelationships("fact_embedding", 5, $search_vector)
YIELD relationship AS r, score
MATCH (n)-[r:RELATES_TO]->(m)
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC LIMIT 10
""",
search_vector=search_vector,
)
edges: list[EntityEdge] = []
now = datetime.now()
for record in records:
edge = EntityEdge(
uuid=record["uuid"],
source_node_uuid=record["source_node_uuid"],
target_node_uuid=record["target_node_uuid"],
fact=record["fact"],
name=record["name"],
episodes=record["episodes"],
fact_embedding=record["fact_embedding"],
created_at=now,
expired_at=now,
valid_at=now,
invalid_at=now,
)
edges.append(edge)
logger.info(f"similarity search results. RESULT: {[edge.uuid for edge in edges]}")
return edges
async def entity_similarity_search(
search_vector: list[float], driver: AsyncDriver
) -> list[EntityNode]:
# vector similarity search over entity names
records, _, _ = await driver.execute_query(
"""
CALL db.index.vector.queryNodes("name_embedding", 5, $search_vector)
YIELD node AS n, score
RETURN
n.uuid As uuid,
n.name AS name,
n.created_at AS created_at,
n.summary AS summary
ORDER BY score DESC
""",
search_vector=search_vector,
)
nodes: list[EntityNode] = []
for record in records:
nodes.append(
EntityNode(
uuid=record["uuid"],
name=record["name"],
labels=[],
created_at=datetime.now(),
summary=record["summary"],
)
)
logger.info(f"name semantic search results. RESULT: {nodes}")
return nodes
async def entity_fulltext_search(query: str, driver: AsyncDriver) -> list[EntityNode]:
# BM25 search to get top nodes
fuzzy_query = query + "~"
records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
RETURN
node.uuid As uuid,
node.name AS name,
node.created_at AS created_at,
node.summary AS summary
ORDER BY score DESC
LIMIT 10
""",
query=fuzzy_query,
)
nodes: list[EntityNode] = []
for record in records:
nodes.append(
EntityNode(
uuid=record["uuid"],
name=record["name"],
labels=[],
created_at=datetime.now(),
summary=record["summary"],
)
)
logger.info(f"fulltext search results. QUERY:{query}. RESULT: {nodes}")
return nodes
async def edge_fulltext_search(query: str, driver: AsyncDriver) -> list[EntityEdge]:
# fulltext search over facts
fuzzy_query = query + "~"
records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS r, score
MATCH (n:Entity)-[r]->(m:Entity)
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC LIMIT 10
""",
query=fuzzy_query,
)
edges: list[EntityEdge] = []
now = datetime.now()
for record in records:
edge = EntityEdge(
uuid=record["uuid"],
source_node_uuid=record["source_node_uuid"],
target_node_uuid=record["target_node_uuid"],
fact=record["fact"],
name=record["name"],
episodes=record["episodes"],
fact_embedding=record["fact_embedding"],
created_at=now,
expired_at=now,
valid_at=now,
invalid_at=now,
)
edges.append(edge)
logger.info(
f"similarity search results. QUERY:{query}. RESULT: {[edge.uuid for edge in edges]}"
)
return edges
async def get_relevant_nodes(
nodes: list[EntityNode],
driver: AsyncDriver,
) -> list[EntityNode]:
relevant_nodes: dict[str, EntityNode] = {}
results = await asyncio.gather(
*[entity_fulltext_search(node.name, driver) for node in nodes],
*[entity_similarity_search(node.name_embedding, driver) for node in nodes],
)
for result in results:
for node in result:
relevant_nodes[node.uuid] = node
logger.info(f"Found relevant nodes: {relevant_nodes.keys()}")
return relevant_nodes.values()
async def get_relevant_edges(
edges: list[EntityEdge],
driver: AsyncDriver,
) -> list[EntityEdge]:
relevant_edges: dict[str, EntityEdge] = {}
results = await asyncio.gather(
*[edge_similarity_search(edge.fact_embedding, driver) for edge in edges],
*[edge_fulltext_search(edge.fact, driver) for edge in edges],
)
for result in results:
for edge in result:
relevant_edges[edge.uuid] = edge
logger.info(f"Found relevant nodes: {relevant_edges.keys()}")
return relevant_edges.values()

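The trailing "~" appended to queries in the fulltext helpers is Lucene fuzzy-match syntax (it applies to the term it follows), so slightly misspelled names still match. A minimal standalone call, assuming a connected AsyncDriver and the name_and_summary index from build_indices:

records, _, _ = await driver.execute_query(
    """
    CALL db.index.fulltext.queryNodes("name_and_summary", $query)
    YIELD node, score
    RETURN node.name AS name ORDER BY score DESC LIMIT 5
    """,
    query="Tania Tetlow" + "~",
)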
core/utils/utils.py (new file, 25 additions)
View file

@@ -0,0 +1,25 @@
import logging
from neo4j import AsyncDriver
from core.edges import EpisodicEdge, EntityEdge, Edge
from core.nodes import EntityNode, EpisodicNode, Node
logger = logging.getLogger(__name__)
def build_episodic_edges(
entity_nodes: list[EntityNode], episode: EpisodicNode
) -> list[EpisodicEdge]:
edges: list[EpisodicEdge] = []
for node in entity_nodes:
edges.append(
EpisodicEdge(
source_node_uuid=episode.uuid,
target_node_uuid=node.uuid,
created_at=episode.created_at,
)
)
return edges

View file

@@ -1,3 +1,5 @@
import logging
import sys
import os
import pytest
@@ -9,9 +11,11 @@ from openai import OpenAI
from core.edges import EpisodicEdge, EntityEdge
from core.graphiti import Graphiti
from core.llm_client.config import EMBEDDING_DIM
from core.nodes import EpisodicNode, EntityNode
from datetime import datetime
pytest_plugins = ("pytest_asyncio",)
load_dotenv()
@@ -21,10 +25,59 @@ NEO4j_USER = os.getenv("NEO4J_USER")
NEO4j_PASSWORD = os.getenv("NEO4J_PASSWORD")
def setup_logging():
# Create a logger
logger = logging.getLogger()
logger.setLevel(logging.INFO) # Set the logging level to INFO
# Create console handler and set level to INFO
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
# Create formatter
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
# Add formatter to console handler
console_handler.setFormatter(formatter)
# Add console handler to logger
logger.addHandler(console_handler)
return logger
def format_context(context):
formatted_string = ""
for uuid, data in context.items():
formatted_string += f"UUID: {uuid}\n"
formatted_string += f" Name: {data['name']}\n"
formatted_string += f" Summary: {data['summary']}\n"
formatted_string += " Facts:\n"
for fact in data["facts"]:
formatted_string += f" - {fact}\n"
formatted_string += "\n"
return formatted_string.strip()
@pytest.mark.asyncio
async def test_graphiti_init():
logger = setup_logging()
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, None)
await graphiti.build_indices()
context = await graphiti.search("Freakenomics guest")
logger.info("QUERY: Freakenomics guest" + "RESULT:" + format_context(context))
context = await graphiti.search("tania tetlow")
logger.info("QUERY: Tania Tetlow" + "RESULT:" + format_context(context))
context = await graphiti.search("issues with higher ed")
logger.info("QUERY: issues with higher ed" + "RESULT:" + format_context(context))
graphiti.close()
@@ -57,16 +110,16 @@ async def test_graph_integration():
bob_node = EntityNode(name="Bob", labels=[], created_at=now, summary="Bob summary")
episodic_edge_1 = EpisodicEdge(
source_node=episode, target_node=alice_node, created_at=now
source_node_uuid=episode.uuid, target_node_uuid=alice_node.uuid, created_at=now
)
episodic_edge_2 = EpisodicEdge(
source_node=episode, target_node=bob_node, created_at=now
source_node_uuid=episode.uuid, target_node_uuid=bob_node.uuid, created_at=now
)
entity_edge = EntityEdge(
source_node=alice_node,
target_node=bob_node,
source_node_uuid=alice_node.uuid,
target_node_uuid=bob_node.uuid,
created_at=now,
name="likes",
fact="Alice likes Bob",