Update Maintenance LLM Queries and Partial Schema Retrieval (#6)
* search updates
* add search_utils
* updates
* graph maintenance updates
* revert extract_new_nodes
* revert extract_new_edges
* parallelize node searching
* add edge fulltext search
* search optimizations
This commit is contained in:
parent
ad552b527e
commit
4db3906049
16 changed files with 953 additions and 119 deletions
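For orientation before the file-by-file diff: the reworked ingestion flow introduced here chains the new maintenance and search helpers together. The sketch below is editorial and illustrative only, condensed from the new add_episode body and the helper signatures shown in this diff; the function name process_episode and its argument list are assumptions, not part of the commit.

import asyncio
from datetime import datetime

from core.utils import build_episodic_edges
from core.utils.maintenance.node_operations import extract_nodes, dedupe_extracted_nodes
from core.utils.maintenance.edge_operations import extract_edges, dedupe_extracted_edges
from core.utils.search.search_utils import get_relevant_nodes, get_relevant_edges


async def process_episode(llm_client, driver, embedder, episode, previous_episodes):
    # 1. Ask the LLM for candidate entity nodes mentioned in the episode.
    extracted_nodes = await extract_nodes(llm_client, episode, previous_episodes)

    # 2. Embed node names in parallel, then dedupe against similar nodes already in the graph.
    await asyncio.gather(*[n.generate_name_embedding(embedder) for n in extracted_nodes])
    existing_nodes = await get_relevant_nodes(extracted_nodes, driver)
    nodes = await dedupe_extracted_nodes(llm_client, extracted_nodes, existing_nodes)

    # 3. Extract candidate edges between those nodes and dedupe them the same way.
    extracted_edges = await extract_edges(llm_client, episode, nodes, previous_episodes)
    await asyncio.gather(*[e.generate_embedding(embedder) for e in extracted_edges])
    existing_edges = await get_relevant_edges(extracted_edges, driver)
    entity_edges = await dedupe_extracted_edges(llm_client, extracted_edges, existing_edges)

    # 4. Link the episode to every entity it mentions, then persist everything concurrently.
    episodic_edges = build_episodic_edges(nodes, episode, datetime.now())
    await episode.save(driver)
    await asyncio.gather(*[n.save(driver) for n in nodes])
    await asyncio.gather(*[e.save(driver) for e in episodic_edges + entity_edges])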
core/edges.py

@@ -5,15 +5,16 @@ from neo4j import AsyncDriver
from uuid import uuid4
import logging

from core.llm_client.config import EMBEDDING_DIM
from core.nodes import Node

logger = logging.getLogger(__name__)


class Edge(BaseModel, ABC):
    uuid: str = Field(default_factory=lambda: str(uuid4()))
    source_node: Node
    target_node: Node
    uuid: str = Field(default_factory=lambda: uuid4().hex)
    source_node_uuid: str
    target_node_uuid: str
    created_at: datetime

    @abstractmethod
@@ -30,11 +31,6 @@ class Edge(BaseModel, ABC):

class EpisodicEdge(Edge):
    async def save(self, driver: AsyncDriver):
        if self.uuid is None:
            uuid = uuid4()
            logger.info(f"Created uuid: {uuid} for episodic edge")
            self.uuid = str(uuid)

        result = await driver.execute_query(
            """
            MATCH (episode:Episodic {uuid: $episode_uuid})
@@ -42,8 +38,8 @@ class EpisodicEdge(Edge):
            MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
            SET r = {uuid: $uuid, created_at: $created_at}
            RETURN r.uuid AS uuid""",
            episode_uuid=self.source_node.uuid,
            entity_uuid=self.target_node.uuid,
            episode_uuid=self.source_node_uuid,
            entity_uuid=self.target_node_uuid,
            uuid=self.uuid,
            created_at=self.created_at,
        )
@@ -79,10 +75,10 @@ class EntityEdge(Edge):
        default=None, description="datetime of when the fact stopped being true"
    )

    def generate_embedding(self, embedder, model="text-embedding-3-large"):
    async def generate_embedding(self, embedder, model="text-embedding-3-small"):
        text = self.fact.replace("\n", " ")
        embedding = embedder.create(input=[text], model=model).data[0].embedding
        self.fact_embedding = embedding
        embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
        self.fact_embedding = embedding[:EMBEDDING_DIM]

        return embedding

@@ -96,8 +92,8 @@ class EntityEdge(Edge):
            episodes: $episodes, created_at: $created_at, expired_at: $expired_at,
            valid_at: $valid_at, invalid_at: $invalid_at}
            RETURN r.uuid AS uuid""",
            source_uuid=self.source_node.uuid,
            target_uuid=self.target_node.uuid,
            source_uuid=self.source_node_uuid,
            target_uuid=self.target_node_uuid,
            uuid=self.uuid,
            name=self.name,
            fact=self.fact,
core/graphiti.py (147 lines changed)

@@ -1,12 +1,14 @@
import asyncio
from datetime import datetime
import logging
from typing import Callable, LiteralString, Tuple
from typing import Callable, LiteralString
from neo4j import AsyncGraphDatabase
from dotenv import load_dotenv
import os

from core.llm_client.config import EMBEDDING_DIM
from core.nodes import EntityNode, EpisodicNode, Node
from core.edges import EntityEdge, Edge
from core.edges import EntityEdge, Edge, EpisodicEdge
from core.utils import (
    build_episodic_edges,
    retrieve_relevant_schema,
@@ -16,6 +18,15 @@ from core.utils import (
    retrieve_episodes,
)
from core.llm_client import LLMClient, OpenAIClient, LLMConfig
from core.utils.maintenance.edge_operations import extract_edges, dedupe_extracted_edges
from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
from core.utils.search.search_utils import (
    edge_similarity_search,
    entity_fulltext_search,
    bfs,
    get_relevant_nodes,
    get_relevant_edges,
)

logger = logging.getLogger(__name__)

@@ -34,7 +45,7 @@ class Graphiti:
        self.llm_client = OpenAIClient(
            LLMConfig(
                api_key=os.getenv("OPENAI_API_KEY"),
                model="gpt-4o",
                model="gpt-4o-mini",
                base_url="https://api.openai.com/v1",
            )
        )
@@ -75,8 +86,12 @@ class Graphiti:
    ):
        """Process an episode and update the graph"""
        try:
            nodes: list[Node] = []
            edges: list[Edge] = []
            nodes: list[EntityNode] = []
            entity_edges: list[EntityEdge] = []
            episodic_edges: list[EpisodicEdge] = []
            embedder = self.llm_client.client.embeddings
            now = datetime.now()

            previous_episodes = await self.retrieve_episodes(last_n=3)
            episode = EpisodicNode(
                name=name,
@@ -84,38 +99,65 @@ class Graphiti:
                source="messages",
                content=episode_body,
                source_description=source_description,
                created_at=datetime.now(),
                created_at=now,
                valid_at=reference_time,
            )
            # await episode.save(self.driver)
            relevant_schema = await self.retrieve_relevant_schema(episode.content)
            new_nodes = await extract_new_nodes(
                self.llm_client, episode, relevant_schema, previous_episodes
            # relevant_schema = await self.retrieve_relevant_schema(episode.content)

            extracted_nodes = await extract_nodes(
                self.llm_client, episode, previous_episodes
            )

            # Calculate Embeddings

            await asyncio.gather(
                *[node.generate_name_embedding(embedder) for node in extracted_nodes]
            )

            existing_nodes = await get_relevant_nodes(extracted_nodes, self.driver)

            new_nodes = await dedupe_extracted_nodes(
                self.llm_client, extracted_nodes, existing_nodes
            )
            nodes.extend(new_nodes)
            new_edges, affected_nodes = await extract_new_edges(
                self.llm_client, episode, new_nodes, relevant_schema, previous_episodes

            extracted_edges = await extract_edges(
                self.llm_client, episode, new_nodes, previous_episodes
            )
            edges.extend(new_edges)
            episodic_edges = build_episodic_edges(
                # There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them
                list(set(nodes + affected_nodes)),
                episode,
                datetime.now(),

            await asyncio.gather(
                *[edge.generate_embedding(embedder) for edge in extracted_edges]
            )

            existing_edges = await get_relevant_edges(extracted_edges, self.driver)

            new_edges = await dedupe_extracted_edges(
                self.llm_client, extracted_edges, existing_edges
            )

            entity_edges.extend(new_edges)
            episodic_edges.extend(
                build_episodic_edges(
                    # There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them
                    nodes,
                    episode,
                    now,
                )
            )
            # Important to append the episode to the nodes at the end so that self referencing episodic edges are not built
            nodes.append(episode)
            logger.info(f"Built episodic edges: {episodic_edges}")
            edges.extend(episodic_edges)

            # invalidated_edges = await self.invalidate_edges(
            #     episode, new_nodes, new_edges, relevant_schema, previous_episodes
            # )

            # edges.extend(invalidated_edges)

            # Future optimization would be using batch operations to save nodes and edges
            await episode.save(self.driver)
            await asyncio.gather(*[node.save(self.driver) for node in nodes])
            await asyncio.gather(*[edge.save(self.driver) for edge in edges])
            await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
            await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])
            # for node in nodes:
            #     if isinstance(node, EntityNode):
            #         await node.update_summary(self.driver)
@@ -129,6 +171,10 @@ class Graphiti:

    async def build_indices(self):
        index_queries: list[LiteralString] = [
            "CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)",
            "CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)",
            "CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.uuid)",
            "CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[r:MENTIONS]-() ON (r.uuid)",
            "CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)",
            "CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)",
            "CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)",
@@ -143,13 +189,19 @@ class Graphiti:
        for query in index_queries:
            await self.driver.execute_query(query)

        # Add the entity indices
        # Add the semantic indices
        await self.driver.execute_query(
            """
            CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]
            """
        )

        await self.driver.execute_query(
            """
            CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON EACH [r.name, r.fact]
            """
        )

        await self.driver.execute_query(
            """
            CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
@@ -161,29 +213,40 @@ class Graphiti:
            """
        )

    async def search(
        self, query: str, config
    ) -> (list)[tuple[EntityNode, list[EntityEdge]]]:
        (vec_nodes, vec_edges) = similarity_search(query, embedder)
        (text_nodes, text_edges) = fulltext_search(query)
        await self.driver.execute_query(
            """
            CREATE VECTOR INDEX name_embedding IF NOT EXISTS
            FOR (n:Entity) ON (n.name_embedding)
            OPTIONS {indexConfig: {
                `vector.dimensions`: 1024,
                `vector.similarity_function`: 'cosine'
            }}
            """
        )

        nodes = vec_nodes.extend(text_nodes)
        edges = vec_edges.extend(text_edges)
    async def search(self, query: str) -> list[tuple[EntityNode, list[EntityEdge]]]:
        text = query.replace("\n", " ")
        search_vector = (
            (
                await self.llm_client.client.embeddings.create(
                    input=[text], model="text-embedding-3-small"
                )
            )
            .data[0]
            .embedding[:EMBEDDING_DIM]
        )

        results = bfs(nodes, edges, k=1)
        edges = await edge_similarity_search(search_vector, self.driver)
        nodes = await entity_fulltext_search(query, self.driver)

        episode_ids = ["Mode of episode ids"]
        node_ids = [node.uuid for node in nodes]

        episodes = get_episodes(episode_ids[:episode_count])
        for edge in edges:
            node_ids.append(edge.source_node_uuid)
            node_ids.append(edge.target_node_uuid)

        return [(node, edges)], episodes
        node_ids = list(dict.fromkeys(node_ids))

    # Invalidate edges that are no longer valid
    async def invalidate_edges(
        self,
        episode: EpisodicNode,
        new_nodes: list[EntityNode],
        new_edges: list[EntityEdge],
        relevant_schema: dict[str, any],
        previous_episodes: list[EpisodicNode],
    ): ...
        context = await bfs(node_ids, self.driver)

        return context
core/llm_client/config.py

@@ -1,3 +1,6 @@
EMBEDDING_DIM = 1024


class LLMConfig:
    """
    Configuration class for the Language Learning Model (LLM).
@@ -10,7 +13,7 @@ class LLMConfig:
    def __init__(
        self,
        api_key: str,
        model: str = "gpt-4o",
        model: str = "gpt-4o-mini",
        base_url: str = "https://api.openai.com",
    ):
        """
@@ -21,7 +24,7 @@ class LLMConfig:
                This is required for making authorized requests.

            model (str, optional): The specific LLM model to use for generating responses.
                Defaults to "gpt-4o", which appears to be a custom model name.
                Defaults to "gpt-4o-mini", which appears to be a custom model name.
                Common values might include "gpt-3.5-turbo" or "gpt-4".

            base_url (str, optional): The base URL of the LLM API service.
core/nodes.py

@@ -8,11 +8,13 @@ from pydantic import BaseModel, Field
from neo4j import AsyncDriver
import logging

from core.llm_client.config import EMBEDDING_DIM

logger = logging.getLogger(__name__)


class Node(BaseModel, ABC):
    uuid: str = Field(default_factory=lambda: str(uuid4()))
    uuid: str = Field(default_factory=lambda: uuid4().hex)
    name: str
    labels: list[str] = Field(default_factory=list)
    created_at: datetime
@@ -66,21 +68,32 @@ class EpisodicNode(Node):


class EntityNode(Node):
    name_embedding: list[float] | None = Field(
        default=None, description="embedding of the name"
    )
    summary: str = Field(description="regional summary of surrounding edges")

    async def update_summary(self, driver: AsyncDriver): ...

    async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ...

    async def generate_name_embedding(self, embedder, model="text-embedding-3-small"):
        text = self.name.replace("\n", " ")
        embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
        self.name_embedding = embedding[:EMBEDDING_DIM]

        return embedding

    async def save(self, driver: AsyncDriver):
        result = await driver.execute_query(
            """
            MERGE (n:Entity {uuid: $uuid})
            SET n = {uuid: $uuid, name: $name, summary: $summary, created_at: $created_at}
            SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, summary: $summary, created_at: $created_at}
            RETURN n.uuid AS uuid""",
            uuid=self.uuid,
            name=self.name,
            summary=self.summary,
            name_embedding=self.name_embedding,
            created_at=self.created_at,
        )
core/prompts/dedupe_edges.py (new file, 56 lines)

@@ -0,0 +1,56 @@
import json
from typing import TypedDict, Protocol

from .models import Message, PromptVersion, PromptFunction


class Prompt(Protocol):
    v1: PromptVersion


class Versions(TypedDict):
    v1: PromptFunction


def v1(context: dict[str, any]) -> list[Message]:
    return [
        Message(
            role="system",
            content="You are a helpful assistant that de-duplicates relationship from edge lists.",
        ),
        Message(
            role="user",
            content=f"""
        Given the following context, deduplicate edges from a list of new edges given a list of existing edges:

        Existing Edges:
        {json.dumps(context['existing_edges'], indent=2)}

        New Edges:
        {json.dumps(context['extracted_edges'], indent=2)}

        Task:
        1. start with the list of edges from New Edges
        2. If any edge in New Edges is a duplicate of an edge in Existing Edges, replace the new edge with the existing
        edge in the list
        3. Respond with the resulting list of edges

        Guidelines:
        1. Use both the name and fact of edges to determine if they are duplicates,
        duplicate edges may have different names

        Respond with a JSON object in the following format:
        {{
            "new_edges": [
                {{
                    "name": "Unique identifier for the edge",
                    "fact": "one sentence description of the fact"
                }}
            ]
        }}
        """,
        ),
    ]


versions: Versions = {"v1": v1}
core/prompts/dedupe_nodes.py (new file, 56 lines)

@@ -0,0 +1,56 @@
import json
from typing import TypedDict, Protocol

from .models import Message, PromptVersion, PromptFunction


class Prompt(Protocol):
    v1: PromptVersion


class Versions(TypedDict):
    v1: PromptFunction


def v1(context: dict[str, any]) -> list[Message]:
    return [
        Message(
            role="system",
            content="You are a helpful assistant that de-duplicates nodes from node lists.",
        ),
        Message(
            role="user",
            content=f"""
        Given the following context, deduplicate nodes from a list of new nodes given a list of existing nodes:

        Existing Nodes:
        {json.dumps(context['existing_nodes'], indent=2)}

        New Nodes:
        {json.dumps(context['extracted_nodes'], indent=2)}

        Task:
        1. start with the list of nodes from New Nodes
        2. If any node in New Nodes is a duplicate of a node in Existing Nodes, replace the new node with the existing
        node in the list
        3. Respond with the resulting list of nodes

        Guidelines:
        1. Use both the name and summary of nodes to determine if they are duplicates,
        duplicate nodes may have different names

        Respond with a JSON object in the following format:
        {{
            "new_nodes": [
                {{
                    "name": "Unique identifier for the node",
                    "summary": "Brief summary of the node's role or significance"
                }}
            ]
        }}
        """,
        ),
    ]


versions: Versions = {"v1": v1}
core/prompts/extract_edges.py

@@ -6,10 +6,12 @@ from .models import Message, PromptVersion, PromptFunction

class Prompt(Protocol):
    v1: PromptVersion
    v2: PromptVersion


class Versions(TypedDict):
    v1: PromptFunction
    v2: PromptFunction


def v1(context: dict[str, any]) -> list[Message]:
@@ -68,6 +70,108 @@ def v1(context: dict[str, any]) -> list[Message]:
    ]


versions: Versions = {
    "v1": v1,
}
def v1(context: dict[str, any]) -> list[Message]:
    return [
        Message(
            role="system",
            content="You are a helpful assistant that extracts graph edges from provided context.",
        ),
        Message(
            role="user",
            content=f"""
        Given the following context, extract new semantic edges (relationships) that need to be added to the knowledge graph:

        Current Graph Structure:
        {context['relevant_schema']}

        New Nodes:
        {json.dumps(context['new_nodes'], indent=2)}

        New Episode:
        Content: {context['episode_content']}
        Timestamp: {context['episode_timestamp']}

        Previous Episodes:
        {json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)}

        Extract new semantic edges based on the content of the current episode, considering the existing graph structure, new nodes, and context from previous episodes.

        Guidelines:
        1. Create edges only between semantic nodes (not episodic nodes like messages).
        2. Each edge should represent a clear relationship between two semantic nodes.
        3. The relation_type should be a concise, all-caps description of the relationship (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
        4. Provide a more detailed fact describing the relationship.
        5. If a relationship seems to update an existing one, create a new edge with the updated information.
        6. Consider temporal aspects of relationships when relevant.
        7. Do not create edges involving episodic nodes (like Message 1 or Message 2).
        8. Use existing nodes from the current graph structure when appropriate.

        Respond with a JSON object in the following format:
        {{
            "new_edges": [
                {{
                    "relation_type": "RELATION_TYPE_IN_CAPS",
                    "source_node": "Name of the source semantic node",
                    "target_node": "Name of the target semantic node",
                    "fact": "Detailed description of the relationship",
                    "valid_at": "YYYY-MM-DDTHH:MM:SSZ or null if not explicitly mentioned",
                    "invalid_at": "YYYY-MM-DDTHH:MM:SSZ or null if ongoing or not explicitly mentioned"
                }}
            ]
        }}

        If no new edges need to be added, return an empty list for "new_edges".
        """,
        ),
    ]


def v2(context: dict[str, any]) -> list[Message]:
    return [
        Message(
            role="system",
            content="You are a helpful assistant that extracts graph edges from provided context.",
        ),
        Message(
            role="user",
            content=f"""
        Given the following context, extract new edges (relationships) that need to be added to the knowledge graph:
        Nodes:
        {json.dumps(context['nodes'], indent=2)}

        New Episode:
        Content: {context['episode_content']}

        Previous Episodes:
        {json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)}

        Extract new entity edges based on the content of the current episode, the given nodes, and context from previous episodes.

        Guidelines:
        1. Create edges only between the provided nodes.
        2. Each edge should represent a clear relationship between two nodes.
        3. The relation_type should be a concise, all-caps description of the relationship (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
        4. Provide a more detailed fact describing the relationship.
        5. Consider temporal aspects of relationships when relevant.

        Respond with a JSON object in the following format:
        {{
            "edges": [
                {{
                    "relation_type": "RELATION_TYPE_IN_CAPS",
                    "source_node_uuid": "uuid of the source entity node",
                    "target_node_uuid": "uuid of the target entity node",
                    "fact": "Detailed description of the relationship",
                    "valid_at": "YYYY-MM-DDTHH:MM:SSZ or null if not explicitly mentioned",
                    "invalid_at": "YYYY-MM-DDTHH:MM:SSZ or null if ongoing or not explicitly mentioned"
                }}
            ]
        }}

        If no new edges need to be added, return an empty list for "new_edges".
        """,
        ),
    ]


versions: Versions = {"v1": v1, "v2": v2}
core/prompts/extract_nodes.py

@@ -6,10 +6,12 @@ from .models import Message, PromptVersion, PromptFunction

class Prompt(Protocol):
    v1: PromptVersion
    v2: PromptVersion


class Versions(TypedDict):
    v1: PromptFunction
    v2: PromptFunction


def v1(context: dict[str, any]) -> list[Message]:
@@ -60,6 +62,45 @@ def v1(context: dict[str, any]) -> list[Message]:
    ]


versions: Versions = {
    "v1": v1,
}
def v2(context: dict[str, any]) -> list[Message]:
    return [
        Message(
            role="system",
            content="You are a helpful assistant that extracts graph nodes from provided context.",
        ),
        Message(
            role="user",
            content=f"""
        Given the following context, extract new entity nodes that need to be added to the knowledge graph:

        Previous Episodes:
        {json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)}

        New Episode:
        Content: {context["episode_content"]}

        Extract new entity nodes based on the content of the current episode, while considering the context from previous episodes.

        Guidelines:
        1. Focus on entities, concepts, or actors that are central to the current episode.
        2. Avoid creating nodes for relationships or actions (these will be handled as edges later).
        3. Provide a brief but informative summary for each node.

        Respond with a JSON object in the following format:
        {{
            "new_nodes": [
                {{
                    "name": "Unique identifier for the node",
                    "labels": ["Entity", "OptionalAdditionalLabel"],
                    "summary": "Brief summary of the node's role or significance"
                }}
            ]
        }}

        If no new nodes need to be added, return an empty list for "new_nodes".
        """,
        ),
    ]


versions: Versions = {"v1": v1, "v2": v2}
@@ -8,21 +8,37 @@ from .extract_nodes import (
    versions as extract_nodes_versions,
)

from .dedupe_nodes import (
    Prompt as DedupeNodesPrompt,
    Versions as DedupeNodesVersions,
    versions as dedupe_nodes_versions,
)

from .extract_edges import (
    Prompt as ExtractEdgesPrompt,
    Versions as ExtractEdgesVersions,
    versions as extract_edges_versions,
)

from .dedupe_edges import (
    Prompt as DedupeEdgesPrompt,
    Versions as DedupeEdgesVersions,
    versions as dedupe_edges_versions,
)


class PromptLibrary(Protocol):
    extract_nodes: ExtractNodesPrompt
    dedupe_nodes: DedupeNodesPrompt
    extract_edges: ExtractEdgesPrompt
    dedupe_edges: DedupeEdgesPrompt


class PromptLibraryImpl(TypedDict):
    extract_nodes: ExtractNodesVersions
    dedupe_nodes: DedupeNodesVersions
    extract_edges: ExtractEdgesVersions
    dedupe_edges: DedupeEdgesVersions


class VersionWrapper:
@@ -47,7 +63,9 @@ class PromptLibraryWrapper:

PROMPT_LIBRARY_IMPL: PromptLibraryImpl = {
    "extract_nodes": extract_nodes_versions,
    "dedupe_nodes": dedupe_nodes_versions,
    "extract_edges": extract_edges_versions,
    "dedupe_edges": dedupe_edges_versions,
}

prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL)
Deleted file (45 lines):

@@ -1,45 +0,0 @@
from typing import Tuple

from core.edges import EpisodicEdge, EntityEdge, Edge
from core.nodes import EntityNode, EpisodicNode, Node


async def bfs(
    nodes: list[Node], edges: list[Edge], k: int
) -> Tuple[list[EntityNode], list[EntityEdge]]: ...


# Breadth first search over nodes and edges with desired depth


async def similarity_search(
    query: str, embedder
) -> Tuple[list[EntityNode], list[EntityEdge]]: ...


# vector similarity search over embedded facts


async def fulltext_search(
    query: str,
) -> Tuple[list[EntityNode], list[EntityEdge]]: ...


# fulltext search over names and summary


def build_episodic_edges(
    entity_nodes: list[EntityNode], episode: EpisodicNode
) -> list[EpisodicEdge]:
    edges: list[EpisodicEdge] = []

    for node in entity_nodes:
        edges.append(
            EpisodicEdge(
                source_node=episode,
                target_node=node,
                created_at=episode.created_at,
            )
        )

    return edges
core/utils/maintenance/edge_operations.py

@@ -13,15 +13,17 @@ logger = logging.getLogger(__name__)


def build_episodic_edges(
    semantic_nodes: List[EntityNode],
    entity_nodes: List[EntityNode],
    episode: EpisodicNode,
    transaction_from: datetime,
) -> List[EpisodicEdge]:
    edges: List[EpisodicEdge] = []

    for node in semantic_nodes:
    for node in entity_nodes:
        edge = EpisodicEdge(
            source_node=episode, target_node=node, created_at=transaction_from
            source_node_uuid=episode.uuid,
            target_node_uuid=node.uuid,
            created_at=transaction_from,
        )
        edges.append(edge)

@@ -132,3 +134,94 @@ async def extract_new_edges(
        affected_nodes.add(edge.source_node)
        affected_nodes.add(edge.target_node)
    return new_edges, list(affected_nodes)


async def extract_edges(
    llm_client: LLMClient,
    episode: EpisodicNode,
    nodes: list[EntityNode],
    previous_episodes: list[EpisodicNode],
) -> list[EntityEdge]:
    # Prepare context for LLM
    context = {
        "episode_content": episode.content,
        "episode_timestamp": (
            episode.valid_at.isoformat() if episode.valid_at else None
        ),
        "nodes": [
            {"uuid": node.uuid, "name": node.name, "summary": node.summary}
            for node in nodes
        ],
        "previous_episodes": [
            {
                "content": ep.content,
                "timestamp": ep.valid_at.isoformat() if ep.valid_at else None,
            }
            for ep in previous_episodes
        ],
    }

    llm_response = await llm_client.generate_response(
        prompt_library.extract_edges.v2(context)
    )
    edges_data = llm_response.get("edges", [])
    logger.info(f"Extracted new edges: {edges_data}")

    # Convert the extracted data into EntityEdge objects
    edges = []
    for edge_data in edges_data:
        edge = EntityEdge(
            source_node_uuid=edge_data["source_node_uuid"],
            target_node_uuid=edge_data["target_node_uuid"],
            name=edge_data["relation_type"],
            fact=edge_data["fact"],
            episodes=[episode.uuid],
            created_at=datetime.now(),
            valid_at=edge_data["valid_at"],
            invalid_at=edge_data["invalid_at"],
        )
        edges.append(edge)
        logger.info(
            f"Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})"
        )

    return edges


async def dedupe_extracted_edges(
    llm_client: LLMClient,
    extracted_edges: list[EntityEdge],
    existing_edges: list[EntityEdge],
) -> list[EntityEdge]:
    # Create edge map
    edge_map = {}
    for edge in existing_edges:
        edge_map[edge.name] = edge
    for edge in extracted_edges:
        if edge.name in edge_map.keys():
            continue
        edge_map[edge.name] = edge

    # Prepare context for LLM
    context = {
        "extracted_edges": [
            {"name": edge.name, "fact": edge.fact} for edge in extracted_edges
        ],
        "existing_edges": [
            {"name": edge.name, "fact": edge.fact} for edge in extracted_edges
        ],
    }

    llm_response = await llm_client.generate_response(
        prompt_library.dedupe_edges.v1(context)
    )
    new_edges_data = llm_response.get("new_edges", [])
    logger.info(f"Extracted new edges: {new_edges_data}")

    # Get full edge data
    edges = []
    for edge_data in new_edges_data:
        edge = edge_map[edge_data["name"]]
        edges.append(edge)

    return edges
core/utils/maintenance/node_operations.py

@@ -61,3 +61,87 @@ async def extract_new_nodes(
            logger.info(f"Node {node_data['name']} already exists, skipping creation.")

    return new_nodes


async def extract_nodes(
    llm_client: LLMClient,
    episode: EpisodicNode,
    previous_episodes: list[EpisodicNode],
) -> list[EntityNode]:
    # Prepare context for LLM
    context = {
        "episode_content": episode.content,
        "episode_timestamp": (
            episode.valid_at.isoformat() if episode.valid_at else None
        ),
        "previous_episodes": [
            {
                "content": ep.content,
                "timestamp": ep.valid_at.isoformat() if ep.valid_at else None,
            }
            for ep in previous_episodes
        ],
    }

    llm_response = await llm_client.generate_response(
        prompt_library.extract_nodes.v2(context)
    )
    new_nodes_data = llm_response.get("new_nodes", [])
    logger.info(f"Extracted new nodes: {new_nodes_data}")
    # Convert the extracted data into EntityNode objects
    new_nodes = []
    for node_data in new_nodes_data:
        new_node = EntityNode(
            name=node_data["name"],
            labels=node_data["labels"],
            summary=node_data["summary"],
            created_at=datetime.now(),
        )
        new_nodes.append(new_node)
        logger.info(f"Created new node: {new_node.name} (UUID: {new_node.uuid})")

    return new_nodes


async def dedupe_extracted_nodes(
    llm_client: LLMClient,
    extracted_nodes: list[EntityNode],
    existing_nodes: list[EntityNode],
) -> list[EntityNode]:
    # build node map
    node_map = {}
    for node in existing_nodes:
        node_map[node.name] = node
    for node in extracted_nodes:
        if node.name in node_map.keys():
            continue
        node_map[node.name] = node

    # Prepare context for LLM
    existing_nodes_context = [
        {"name": node.name, "summary": node.summary} for node in existing_nodes
    ]

    extracted_nodes_context = [
        {"name": node.name, "summary": node.summary} for node in extracted_nodes
    ]

    context = {
        "existing_nodes": existing_nodes_context,
        "extracted_nodes": extracted_nodes_context,
    }

    llm_response = await llm_client.generate_response(
        prompt_library.dedupe_nodes.v1(context)
    )

    new_nodes_data = llm_response.get("new_nodes", [])
    logger.info(f"Deduplicated nodes: {new_nodes_data}")

    # Get full node data
    nodes = []
    for node_data in new_nodes_data:
        node = node_map[node_data["name"]]
        nodes.append(node)

    return nodes
core/utils/search/__init__.py (new file, empty)

core/utils/search/search_utils.py (new file, 274 lines)

@@ -0,0 +1,274 @@
import asyncio
import logging
from datetime import datetime

from neo4j import AsyncDriver

from core.edges import EntityEdge
from core.nodes import EntityNode

logger = logging.getLogger(__name__)


async def bfs(node_ids: list[str], driver: AsyncDriver):
    records, _, _ = await driver.execute_query(
        """
        MATCH (n WHERE n.uuid in $node_ids)-[r]->(m)
        RETURN
            n.uuid AS source_node_uuid,
            n.name AS source_name,
            n.summary AS source_summary,
            m.uuid AS target_node_uuid,
            m.name AS target_name,
            m.summary AS target_summary,
            r.uuid AS uuid,
            r.created_at AS created_at,
            r.name AS name,
            r.fact AS fact,
            r.fact_embedding AS fact_embedding,
            r.episodes AS episodes,
            r.expired_at AS expired_at,
            r.valid_at AS valid_at,
            r.invalid_at AS invalid_at

        """,
        node_ids=node_ids,
    )

    context = {}

    for record in records:
        n_uuid = record["source_node_uuid"]
        if n_uuid in context.keys():
            context[n_uuid]["facts"].append(record["fact"])
        else:
            context[n_uuid] = {
                "name": record["source_name"],
                "summary": record["source_summary"],
                "facts": [record["fact"]],
            }

        m_uuid = record["target_node_uuid"]
        if m_uuid not in context:
            context[m_uuid] = {
                "name": record["target_name"],
                "summary": record["target_summary"],
                "facts": [],
            }
    logger.info(f"bfs search returned context: {context}")
    return context


async def edge_similarity_search(
    search_vector: list[float], driver: AsyncDriver
) -> list[EntityEdge]:
    # vector similarity search over embedded facts
    records, _, _ = await driver.execute_query(
        """
        CALL db.index.vector.queryRelationships("fact_embedding", 5, $search_vector)
        YIELD relationship AS r, score
        MATCH (n)-[r:RELATES_TO]->(m)
        RETURN
            r.uuid AS uuid,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            r.created_at AS created_at,
            r.name AS name,
            r.fact AS fact,
            r.fact_embedding AS fact_embedding,
            r.episodes AS episodes,
            r.expired_at AS expired_at,
            r.valid_at AS valid_at,
            r.invalid_at AS invalid_at
        ORDER BY score DESC LIMIT 10
        """,
        search_vector=search_vector,
    )

    edges: list[EntityEdge] = []

    now = datetime.now()

    for record in records:
        edge = EntityEdge(
            uuid=record["uuid"],
            source_node_uuid=record["source_node_uuid"],
            target_node_uuid=record["target_node_uuid"],
            fact=record["fact"],
            name=record["name"],
            episodes=record["episodes"],
            fact_embedding=record["fact_embedding"],
            created_at=now,
            expired_at=now,
            valid_at=now,
            invalid_At=now,
        )

        edges.append(edge)

    logger.info(f"similarity search results. RESULT: {[edge.uuid for edge in edges]}")

    return edges


async def entity_similarity_search(
    search_vector: list[float], driver: AsyncDriver
) -> list[EntityNode]:
    # vector similarity search over entity names
    records, _, _ = await driver.execute_query(
        """
        CALL db.index.vector.queryNodes("name_embedding", 5, $search_vector)
        YIELD node AS n, score
        RETURN
            n.uuid As uuid,
            n.name AS name,
            n.created_at AS created_at,
            n.summary AS summary
        ORDER BY score DESC
        """,
        search_vector=search_vector,
    )
    nodes: list[EntityNode] = []

    for record in records:
        nodes.append(
            EntityNode(
                uuid=record["uuid"],
                name=record["name"],
                labels=[],
                created_at=datetime.now(),
                summary=record["summary"],
            )
        )

    logger.info(f"name semantic search results. RESULT: {nodes}")

    return nodes


async def entity_fulltext_search(query: str, driver: AsyncDriver) -> list[EntityNode]:
    # BM25 search to get top nodes
    fuzzy_query = query + "~"
    records, _, _ = await driver.execute_query(
        """
        CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
        RETURN
            node.uuid As uuid,
            node.name AS name,
            node.created_at AS created_at,
            node.summary AS summary
        ORDER BY score DESC
        LIMIT 10
        """,
        query=fuzzy_query,
    )
    nodes: list[EntityNode] = []

    for record in records:
        nodes.append(
            EntityNode(
                uuid=record["uuid"],
                name=record["name"],
                labels=[],
                created_at=datetime.now(),
                summary=record["summary"],
            )
        )

    logger.info(f"fulltext search results. QUERY:{query}. RESULT: {nodes}")

    return nodes


async def edge_fulltext_search(query: str, driver: AsyncDriver) -> list[EntityEdge]:
    # fulltext search over facts
    fuzzy_query = query + "~"

    records, _, _ = await driver.execute_query(
        """
        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
        YIELD relationship AS r, score
        MATCH (n:Entity)-[r]->(m:Entity)
        RETURN
            r.uuid AS uuid,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            r.created_at AS created_at,
            r.name AS name,
            r.fact AS fact,
            r.fact_embedding AS fact_embedding,
            r.episodes AS episodes,
            r.expired_at AS expired_at,
            r.valid_at AS valid_at,
            r.invalid_at AS invalid_at
        ORDER BY score DESC LIMIT 10
        """,
        query=fuzzy_query,
    )

    edges: list[EntityEdge] = []

    now = datetime.now()

    for record in records:
        edge = EntityEdge(
            uuid=record["uuid"],
            source_node_uuid=record["source_node_uuid"],
            target_node_uuid=record["target_node_uuid"],
            fact=record["fact"],
            name=record["name"],
            episodes=record["episodes"],
            fact_embedding=record["fact_embedding"],
            created_at=now,
            expired_at=now,
            valid_at=now,
            invalid_At=now,
        )

        edges.append(edge)

    logger.info(
        f"similarity search results. QUERY:{query}. RESULT: {[edge.uuid for edge in edges]}"
    )

    return edges


async def get_relevant_nodes(
    nodes: list[EntityNode],
    driver: AsyncDriver,
) -> list[EntityNode]:
    relevant_nodes: dict[str, EntityNode] = {}

    results = await asyncio.gather(
        *[entity_fulltext_search(node.name, driver) for node in nodes],
        *[entity_similarity_search(node.name_embedding, driver) for node in nodes],
    )

    for result in results:
        for node in result:
            relevant_nodes[node.uuid] = node

    logger.info(f"Found relevant nodes: {relevant_nodes.keys()}")

    return relevant_nodes.values()


async def get_relevant_edges(
    edges: list[EntityEdge],
    driver: AsyncDriver,
) -> list[EntityEdge]:
    relevant_edges: dict[str, EntityEdge] = {}

    results = await asyncio.gather(
        *[edge_similarity_search(edge.fact_embedding, driver) for edge in edges],
        *[edge_fulltext_search(edge.fact, driver) for edge in edges],
    )

    for result in results:
        for edge in result:
            relevant_edges[edge.uuid] = edge

    logger.info(f"Found relevant nodes: {relevant_edges.keys()}")

    return relevant_edges.values()
core/utils/utils.py (new file, 25 lines)

@@ -0,0 +1,25 @@
import logging

from neo4j import AsyncDriver

from core.edges import EpisodicEdge, EntityEdge, Edge
from core.nodes import EntityNode, EpisodicNode, Node

logger = logging.getLogger(__name__)


def build_episodic_edges(
    entity_nodes: list[EntityNode], episode: EpisodicNode
) -> list[EpisodicEdge]:
    edges: list[EpisodicEdge] = []

    for node in entity_nodes:
        edges.append(
            EpisodicEdge(
                source_node_uuid=episode,
                target_node_uuid=node,
                created_at=episode.created_at,
            )
        )

    return edges
@@ -1,3 +1,5 @@
import logging
import sys
import os

import pytest
@@ -9,9 +11,11 @@ from openai import OpenAI

from core.edges import EpisodicEdge, EntityEdge
from core.graphiti import Graphiti
from core.llm_client.config import EMBEDDING_DIM
from core.nodes import EpisodicNode, EntityNode
from datetime import datetime


pytest_plugins = ("pytest_asyncio",)

load_dotenv()
@@ -21,10 +25,59 @@ NEO4j_USER = os.getenv("NEO4J_USER")
NEO4j_PASSWORD = os.getenv("NEO4J_PASSWORD")


def setup_logging():
    # Create a logger
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)  # Set the logging level to INFO

    # Create console handler and set level to INFO
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.INFO)

    # Create formatter
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )

    # Add formatter to console handler
    console_handler.setFormatter(formatter)

    # Add console handler to logger
    logger.addHandler(console_handler)

    return logger


def format_context(context):
    formatted_string = ""
    for uuid, data in context.items():
        formatted_string += f"UUID: {uuid}\n"
        formatted_string += f"  Name: {data['name']}\n"
        formatted_string += f"  Summary: {data['summary']}\n"
        formatted_string += "  Facts:\n"
        for fact in data["facts"]:
            formatted_string += f"    - {fact}\n"
        formatted_string += "\n"
    return formatted_string.strip()


@pytest.mark.asyncio
async def test_graphiti_init():
    logger = setup_logging()
    graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, None)
    await graphiti.build_indices()

    context = await graphiti.search("Freakenomics guest")

    logger.info("QUERY: Freakenomics guest" + "RESULT:" + format_context(context))

    context = await graphiti.search("tania tetlow")

    logger.info("QUERY: Tania Tetlow" + "RESULT:" + format_context(context))

    context = await graphiti.search("issues with higher ed")

    logger.info("QUERY: issues with higher ed" + "RESULT:" + format_context(context))
    graphiti.close()

@@ -57,16 +110,16 @@ async def test_graph_integration():
    bob_node = EntityNode(name="Bob", labels=[], created_at=now, summary="Bob summary")

    episodic_edge_1 = EpisodicEdge(
        source_node=episode, target_node=alice_node, created_at=now
        source_node_uuid=episode, target_node_uuid=alice_node, created_at=now
    )

    episodic_edge_2 = EpisodicEdge(
        source_node=episode, target_node=bob_node, created_at=now
        source_node_uuid=episode, target_node_uuid=bob_node, created_at=now
    )

    entity_edge = EntityEdge(
        source_node=alice_node,
        target_node=bob_node,
        source_node_uuid=alice_node.uuid,
        target_node_uuid=bob_node.uuid,
        created_at=now,
        name="likes",
        fact="Alice likes Bob",
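For reference, the integration test above exercises the new index-building and hybrid search path roughly as follows. This is a minimal usage sketch, not code from the commit; the connection values are placeholders and the queries are the ones assumed in the test.

import asyncio

from core.graphiti import Graphiti


async def main():
    # Placeholder connection details; the test reads NEO4J_URI / NEO4J_USER / NEO4J_PASSWORD from the environment.
    graphiti = Graphiti("bolt://localhost:7687", "neo4j", "password", None)
    await graphiti.build_indices()  # create the fulltext and vector indexes once

    # search() embeds the query, runs edge similarity + entity fulltext search, then BFS over the matched nodes.
    context = await graphiti.search("Freakenomics guest")
    for uuid, data in context.items():
        print(uuid, data["name"], data["facts"])

    graphiti.close()


asyncio.run(main())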