Add episode refactor (#399)

* partial refactor

* get relevant nodes refactor

* load edges updates

* refactor triplets

* not there yet

* node search update

* working refactor

* updates

* mypy

* mypy
Preston Rasmussen 2025-04-26 00:24:23 -04:00 committed by GitHub
parent 15efa37da1
commit a26b25dc06
13 changed files with 380 additions and 302 deletions

.env.example

@@ -1,4 +1,10 @@
OPENAI_API_KEY=
NEO4J_URI=
NEO4J_PORT=
NEO4J_USER=
NEO4J_PASSWORD=
DEFAULT_DATABASE=
USE_PARALLEL_RUNTIME=
SEMAPHORE_LIMIT=
GITHUB_SHA=
MAX_REFLEXION_ITERATIONS=

graphiti_core/edges.py

@@ -37,6 +37,21 @@ from graphiti_core.nodes import Node
logger = logging.getLogger(__name__)
ENTITY_EDGE_RETURN: LiteralString = """
RETURN
e.uuid AS uuid,
startNode(e).uuid AS source_node_uuid,
endNode(e).uuid AS target_node_uuid,
e.created_at AS created_at,
e.name AS name,
e.group_id AS group_id,
e.fact AS fact,
e.fact_embedding AS fact_embedding,
e.episodes AS episodes,
e.expired_at AS expired_at,
e.valid_at AS valid_at,
e.invalid_at AS invalid_at"""
class Edge(BaseModel, ABC):
uuid: str = Field(default_factory=lambda: str(uuid4()))
@@ -234,20 +249,8 @@ class EntityEdge(Edge):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
RETURN
e.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.created_at AS created_at,
e.name AS name,
e.group_id AS group_id,
e.fact AS fact,
e.fact_embedding AS fact_embedding,
e.episodes AS episodes,
e.expired_at AS expired_at,
e.valid_at AS valid_at,
e.invalid_at AS invalid_at
""",
"""
+ ENTITY_EDGE_RETURN,
uuid=uuid,
database_=DEFAULT_DATABASE,
routing_='r',
@@ -268,20 +271,8 @@ class EntityEdge(Edge):
"""
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.uuid IN $uuids
RETURN
e.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.created_at AS created_at,
e.name AS name,
e.group_id AS group_id,
e.fact AS fact,
e.fact_embedding AS fact_embedding,
e.episodes AS episodes,
e.expired_at AS expired_at,
e.valid_at AS valid_at,
e.invalid_at AS invalid_at
""",
"""
+ ENTITY_EDGE_RETURN,
uuids=uuids,
database_=DEFAULT_DATABASE,
routing_='r',
@@ -308,20 +299,8 @@ class EntityEdge(Edge):
WHERE e.group_id IN $group_ids
"""
+ cursor_query
+ ENTITY_EDGE_RETURN
+ """
RETURN
e.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.created_at AS created_at,
e.name AS name,
e.group_id AS group_id,
e.fact AS fact,
e.fact_embedding AS fact_embedding,
e.episodes AS episodes,
e.expired_at AS expired_at,
e.valid_at AS valid_at,
e.invalid_at AS invalid_at
ORDER BY e.uuid DESC
"""
+ limit_query,
@@ -340,22 +319,12 @@ class EntityEdge(Edge):
@classmethod
async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
query: LiteralString = """
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
RETURN DISTINCT
e.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.created_at AS created_at,
e.name AS name,
e.group_id AS group_id,
e.fact AS fact,
e.fact_embedding AS fact_embedding,
e.episodes AS episodes,
e.expired_at AS expired_at,
e.valid_at AS valid_at,
e.invalid_at AS invalid_at
"""
query: LiteralString = (
"""
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
"""
+ ENTITY_EDGE_RETURN
)
records, _, _ = await driver.execute_query(
query, node_uuid=node_uuid, database_=DEFAULT_DATABASE, routing_='r'
)
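The four hunks above replace hand-written RETURN clauses with the shared ENTITY_EDGE_RETURN constant. Because the shared clause projects startNode(e).uuid and endNode(e).uuid rather than n.uuid and m.uuid, it also reports the stored edge direction in the undirected match of get_by_node_uuid. A minimal sketch of the composition pattern, mirroring the hunks above:

# Any MATCH that binds the relationship as `e` can append the shared projection,
# keeping EntityEdge fields consistent across query sites.
query: LiteralString = (
    """
    MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
    WHERE e.uuid IN $uuids
    """
    + ENTITY_EDGE_RETURN
)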

graphiti_core/graphiti.py

@@ -27,6 +27,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
from graphiti_core.llm_client import LLMClient, OpenAIClient
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
@@ -42,7 +43,6 @@ from graphiti_core.search.search_utils import (
RELEVANT_SCHEMA_LIMIT,
get_mentioned_nodes,
get_relevant_edges,
get_relevant_nodes,
)
from graphiti_core.utils.bulk_utils import (
RawEpisode,
@@ -150,6 +150,13 @@ class Graphiti:
else:
self.cross_encoder = OpenAIRerankerClient()
self.clients = GraphitiClients(
driver=self.driver,
llm_client=self.llm_client,
embedder=self.embedder,
cross_encoder=self.cross_encoder,
)
async def close(self):
"""
Close the connection to the Neo4j database.
@@ -222,6 +229,7 @@ class Graphiti:
reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN,
group_ids: list[str] | None = None,
source: EpisodeType | None = None,
) -> list[EpisodicNode]:
"""
Retrieve the last n episodic nodes from the graph.
@@ -248,7 +256,7 @@
The actual retrieval is performed by the `retrieve_episodes` function
from the `graphiti_core.utils` module.
"""
return await retrieve_episodes(self.driver, reference_time, last_n, group_ids)
return await retrieve_episodes(self.driver, reference_time, last_n, group_ids, source)
async def add_episode(
self,
@@ -314,15 +322,16 @@
"""
try:
start = time()
entity_edges: list[EntityEdge] = []
now = utc_now()
validate_entity_types(entity_types)
previous_episodes = (
await self.retrieve_episodes(
reference_time, last_n=RELEVANT_SCHEMA_LIMIT, group_ids=[group_id]
reference_time,
last_n=RELEVANT_SCHEMA_LIMIT,
group_ids=[group_id],
source=source,
)
if previous_episode_uuids is None
else await EpisodicNode.get_by_uuids(self.driver, previous_episode_uuids)
@@ -346,132 +355,35 @@
# Extract entities as nodes
extracted_nodes = await extract_nodes(
self.llm_client, episode, previous_episodes, entity_types
)
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
# Calculate Embeddings
await semaphore_gather(
*[node.generate_name_embedding(self.embedder) for node in extracted_nodes]
self.clients, episode, previous_episodes, entity_types
)
# Find relevant nodes already in the graph
existing_nodes_lists: list[list[EntityNode]] = list(
await semaphore_gather(
*[
get_relevant_nodes(self.driver, SearchFilters(), [node])
for node in extracted_nodes
]
)
)
# Resolve extracted nodes with nodes already in the graph and extract facts
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
(mentioned_nodes, uuid_map), extracted_edges = await semaphore_gather(
# Extract edges and resolve nodes
(nodes, uuid_map), extracted_edges = await semaphore_gather(
resolve_extracted_nodes(
self.llm_client,
self.clients,
extracted_nodes,
existing_nodes_lists,
episode,
previous_episodes,
entity_types,
),
extract_edges(
self.llm_client, episode, extracted_nodes, previous_episodes, group_id
),
extract_edges(self.clients, episode, extracted_nodes, previous_episodes, group_id),
)
logger.debug(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
nodes = mentioned_nodes
extracted_edges_with_resolved_pointers = resolve_edge_pointers(
extracted_edges, uuid_map
)
# calculate embeddings
await semaphore_gather(
*[
edge.generate_embedding(self.embedder)
for edge in extracted_edges_with_resolved_pointers
]
)
# Resolve extracted edges with related edges already in the graph
related_edges_list: list[list[EntityEdge]] = list(
await semaphore_gather(
*[
get_relevant_edges(
self.driver,
[edge],
edge.source_node_uuid,
edge.target_node_uuid,
RELEVANT_SCHEMA_LIMIT,
)
for edge in extracted_edges_with_resolved_pointers
]
)
)
logger.debug(
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_list for e in edges_lst]}'
)
logger.debug(
f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges_with_resolved_pointers]}'
)
existing_source_edges_list: list[list[EntityEdge]] = list(
await semaphore_gather(
*[
get_relevant_edges(
self.driver,
[edge],
edge.source_node_uuid,
None,
RELEVANT_SCHEMA_LIMIT,
)
for edge in extracted_edges_with_resolved_pointers
]
)
)
existing_target_edges_list: list[list[EntityEdge]] = list(
await semaphore_gather(
*[
get_relevant_edges(
self.driver,
[edge],
None,
edge.target_node_uuid,
RELEVANT_SCHEMA_LIMIT,
)
for edge in extracted_edges_with_resolved_pointers
]
)
)
existing_edges_list: list[list[EntityEdge]] = [
source_lst + target_lst
for source_lst, target_lst in zip(
existing_source_edges_list, existing_target_edges_list, strict=False
)
]
resolved_edges, invalidated_edges = await resolve_extracted_edges(
self.llm_client,
self.clients,
extracted_edges_with_resolved_pointers,
related_edges_list,
existing_edges_list,
episode,
previous_episodes,
)
entity_edges.extend(resolved_edges + invalidated_edges)
entity_edges = resolved_edges + invalidated_edges
logger.debug(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now)
logger.debug(f'Built episodic edges: {episodic_edges}')
episodic_edges = build_episodic_edges(nodes, episode, now)
episode.entity_edges = [edge.uuid for edge in entity_edges]
@@ -565,7 +477,7 @@ class Graphiti:
extracted_nodes,
extracted_edges,
episodic_edges,
) = await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs)
) = await extract_nodes_and_edges_bulk(self.clients, episode_pairs)
# Generate embeddings
await semaphore_gather(
@@ -684,9 +596,7 @@
edges = (
await search(
self.driver,
self.embedder,
self.cross_encoder,
self.clients,
query,
group_ids,
search_config,
@@ -728,9 +638,7 @@
"""
return await search(
self.driver,
self.embedder,
self.cross_encoder,
self.clients,
query,
group_ids,
config,
@@ -761,26 +669,17 @@
await edge.generate_embedding(self.embedder)
resolved_nodes, uuid_map = await resolve_extracted_nodes(
self.llm_client,
self.clients,
[source_node, target_node],
[
await get_relevant_nodes(self.driver, SearchFilters(), [source_node]),
await get_relevant_nodes(self.driver, SearchFilters(), [target_node]),
],
)
updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
related_edges = await get_relevant_edges(
self.driver,
[updated_edge],
source_node_uuid=resolved_nodes[0].uuid,
target_node_uuid=resolved_nodes[1].uuid,
)
related_edges = await get_relevant_edges(self.driver, [updated_edge], SearchFilters(), 0.8)
resolved_edge = await dedupe_extracted_edge(self.llm_client, updated_edge, related_edges)
resolved_edge = await dedupe_extracted_edge(self.llm_client, updated_edge, related_edges[0])
contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges)
contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges[0])
invalidated_edges = resolve_edge_contradictions(resolved_edge, contradicting_edges)
await add_nodes_and_edges_bulk(
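Taken together, the graphiti.py hunks collapse add_episode into a short pipeline: extract_nodes now embeds names internally, resolve_extracted_nodes fetches its own candidate nodes, and resolve_extracted_edges fetches both related edges and invalidation candidates. A hedged sketch of the resulting flow, condensed from the hunks above (one local variable renamed for brevity; not the complete method):

extracted_nodes = await extract_nodes(self.clients, episode, previous_episodes, entity_types)
(nodes, uuid_map), extracted_edges = await semaphore_gather(
    resolve_extracted_nodes(
        self.clients, extracted_nodes, episode, previous_episodes, entity_types
    ),
    extract_edges(self.clients, episode, extracted_nodes, previous_episodes, group_id),
)
edges_with_pointers = resolve_edge_pointers(extracted_edges, uuid_map)
resolved_edges, invalidated_edges = await resolve_extracted_edges(
    self.clients, edges_with_pointers, episode, previous_episodes
)
entity_edges = resolved_edges + invalidated_edges
episodic_edges = build_episodic_edges(nodes, episode, now)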

graphiti_core/graphiti_types.py

@@ -0,0 +1,31 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from neo4j import AsyncDriver
from pydantic import BaseModel, ConfigDict
from graphiti_core.cross_encoder import CrossEncoderClient
from graphiti_core.embedder import EmbedderClient
from graphiti_core.llm_client import LLMClient
class GraphitiClients(BaseModel):
driver: AsyncDriver
llm_client: LLMClient
embedder: EmbedderClient
cross_encoder: CrossEncoderClient
model_config = ConfigDict(arbitrary_types_allowed=True)
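GraphitiClients is a plain Pydantic container (arbitrary_types_allowed is required because AsyncDriver and the client classes are not Pydantic models), so helper functions now take one argument instead of four. Construction mirrors the Graphiti.__init__ hunk above:

clients = GraphitiClients(
    driver=driver,  # neo4j.AsyncDriver
    llm_client=llm_client,  # LLMClient
    embedder=embedder,  # EmbedderClient
    cross_encoder=cross_encoder,  # CrossEncoderClient
)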

graphiti_core/helpers.py

@@ -22,15 +22,20 @@ from datetime import datetime
import numpy as np
from dotenv import load_dotenv
from neo4j import time as neo4j_time
from typing_extensions import LiteralString
load_dotenv()
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 2))
MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 1))
DEFAULT_PAGE_LIMIT = 20
RUNTIME_QUERY: LiteralString = (
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
)
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
return neo_date.to_native() if neo_date else None
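The parallel-runtime prefix is now built once as RUNTIME_QUERY and prepended at each call site, replacing the per-function runtime_query locals removed in search_utils.py below. Two caveats worth noting: bool(os.getenv('USE_PARALLEL_RUNTIME', False)) is truthy for any non-empty string, including 'false', and the MAX_REFLEXION_ITERATIONS default drops from 2 to 1, so set that env var to restore the previous behavior. A sketch of the intended composition (the query text is an arbitrary example):

query: LiteralString = (
    RUNTIME_QUERY  # '' unless USE_PARALLEL_RUNTIME is set, so this is a no-op by default
    + """
    MATCH (n:Entity)
    RETURN n.uuid AS uuid
    """
)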

graphiti_core/models/edges/edge_db_queries.py

@@ -47,7 +47,7 @@ ENTITY_EDGE_SAVE_BULK = """
SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at}
WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding)
RETURN r.uuid AS uuid
RETURN edge.uuid AS uuid
"""
COMMUNITY_EDGE_SAVE = """

graphiti_core/search/search.py

@@ -22,8 +22,8 @@ from neo4j import AsyncDriver
from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.edges import EntityEdge
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import SearchRerankerError
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import semaphore_gather
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
from graphiti_core.search.search_config import (
@@ -62,17 +62,21 @@ logger = logging.getLogger(__name__)
async def search(
driver: AsyncDriver,
embedder: EmbedderClient,
cross_encoder: CrossEncoderClient,
clients: GraphitiClients,
query: str,
group_ids: list[str] | None,
config: SearchConfig,
search_filter: SearchFilters,
center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
query_vector: list[float] | None = None,
) -> SearchResults:
start = time()
driver = clients.driver
embedder = clients.embedder
cross_encoder = clients.cross_encoder
if query.strip() == '':
return SearchResults(
edges=[],
@@ -80,7 +84,11 @@ async def search(
episodes=[],
communities=[],
)
query_vector = await embedder.create(input_data=[query.replace('\n', ' ')])
query_vector = (
query_vector
if query_vector is not None
else await embedder.create(input_data=[query.replace('\n', ' ')])
)
# if group_ids is empty, set it to None
group_ids = group_ids if group_ids else None
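search() now receives the bundled clients and an optional precomputed query_vector, so a caller that has already embedded the query, for example when fanning one query out over several configs, can skip repeated embedding calls. A hedged usage sketch (config_a and config_b are hypothetical SearchConfig instances):

query_vector = await clients.embedder.create(input_data=[query.replace('\n', ' ')])
results_a = await search(clients, query, group_ids, config_a, SearchFilters(), query_vector=query_vector)
results_b = await search(clients, query, group_ids, config_b, SearchFilters(), query_vector=query_vector)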

graphiti_core/search/search_utils.py

@@ -26,7 +26,7 @@ from typing_extensions import LiteralString
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
from graphiti_core.helpers import (
DEFAULT_DATABASE,
USE_PARALLEL_RUNTIME,
RUNTIME_QUERY,
lucene_sanitize,
normalize_l2,
semaphore_gather,
@@ -207,10 +207,6 @@ async def edge_similarity_search(
min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityEdge]:
# vector similarity search over embedded facts
runtime_query: LiteralString = (
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
)
query_params: dict[str, Any] = {}
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
@@ -230,9 +226,10 @@
group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
query: LiteralString = (
"""
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
"""
RUNTIME_QUERY
+ """
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
"""
+ group_filter_query
+ filter_query
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
@@ -256,7 +253,7 @@
)
records, _, _ = await driver.execute_query(
runtime_query + query,
query,
query_params,
search_vector=search_vector,
source_uuid=source_node_uuid,
@@ -344,10 +341,10 @@ async def node_fulltext_search(
query = (
"""
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
YIELD node AS n, score
WHERE n:Entity
"""
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
YIELD node AS n, score
WHERE n:Entity
"""
+ filter_query
+ ENTITY_NODE_RETURN
+ """
@@ -378,10 +375,6 @@ async def node_similarity_search(
min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityNode]:
# vector similarity search over entity names
runtime_query: LiteralString = (
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
)
query_params: dict[str, Any] = {}
group_filter_query: LiteralString = ''
@@ -393,7 +386,7 @@
query_params.update(filter_params)
records, _, _ = await driver.execute_query(
runtime_query
RUNTIME_QUERY
+ """
MATCH (n:Entity)
"""
@@ -542,10 +535,6 @@ async def community_similarity_search(
min_score=DEFAULT_MIN_SCORE,
) -> list[CommunityNode]:
# vector similarity search over entity names
runtime_query: LiteralString = (
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
)
query_params: dict[str, Any] = {}
group_filter_query: LiteralString = ''
@@ -554,7 +543,7 @@
query_params['group_ids'] = group_ids
records, _, _ = await driver.execute_query(
runtime_query
RUNTIME_QUERY
+ """
MATCH (comm:Community)
"""
@@ -660,86 +649,204 @@ async def hybrid_node_search(
async def get_relevant_nodes(
    driver: AsyncDriver,
    search_filter: SearchFilters,
    nodes: list[EntityNode],
) -> list[EntityNode]:
    """
    Retrieve relevant nodes based on the provided list of EntityNodes.

    This method performs a hybrid search using both the names and embeddings
    of the input nodes to find relevant nodes in the graph database.

    Parameters
    ----------
    nodes : list[EntityNode]
        A list of EntityNode objects to use as the basis for the search.
    driver : AsyncDriver
        The Neo4j driver instance for database operations.

    Returns
    -------
    list[EntityNode]
        A list of EntityNode objects that are deemed relevant based on the input nodes.

    Notes
    -----
    This method uses the hybrid_node_search function to perform the search,
    which combines fulltext search and vector similarity search.
    It extracts the names and name embeddings (if available) from the input nodes
    to use as search criteria.
    """
    relevant_nodes = await hybrid_node_search(
        [node.name for node in nodes],
        [node.name_embedding for node in nodes if node.name_embedding is not None],
        driver,
        search_filter,
        [node.group_id for node in nodes],
    )
    return relevant_nodes
async def get_relevant_nodes(
    driver: AsyncDriver,
    nodes: list[EntityNode],
    search_filter: SearchFilters,
    min_score: float = DEFAULT_MIN_SCORE,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[list[EntityNode]]:
    if len(nodes) == 0:
        return []

    group_id = nodes[0].group_id

    # vector similarity search over entity names
    query_params: dict[str, Any] = {}

    filter_query, filter_params = node_search_filter_query_constructor(search_filter)
    query_params.update(filter_params)

    query = (
        RUNTIME_QUERY
        + """UNWIND $nodes AS node
        MATCH (n:Entity {group_id: $group_id})
        """
        + filter_query
        + """
        WITH node, n, vector.similarity.cosine(n.name_embedding, node.name_embedding) AS score
        WHERE score > $min_score
        WITH node, n, score
        ORDER BY score DESC
        RETURN node.uuid AS search_node_uuid,
            collect({
                uuid: n.uuid,
                name: n.name,
                name_embedding: n.name_embedding,
                group_id: n.group_id,
                created_at: n.created_at,
                summary: n.summary,
                labels: labels(n),
                attributes: properties(n)
            })[..$limit] AS matches
        """
    )

    results, _, _ = await driver.execute_query(
        query,
        query_params,
        nodes=[
            {'uuid': node.uuid, 'name': node.name, 'name_embedding': node.name_embedding}
            for node in nodes
        ],
        group_id=group_id,
        limit=limit,
        min_score=min_score,
        database_=DEFAULT_DATABASE,
        routing_='r',
    )

    relevant_nodes_dict: dict[str, list[EntityNode]] = {
        result['search_node_uuid']: [
            get_entity_node_from_record(record) for record in result['matches']
        ]
        for result in results
    }

    relevant_nodes = [relevant_nodes_dict.get(node.uuid, []) for node in nodes]

    return relevant_nodes
async def get_relevant_edges(
    driver: AsyncDriver,
    edges: list[EntityEdge],
    source_node_uuid: str | None,
    target_node_uuid: str | None,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
    start = time()
    relevant_edges: list[EntityEdge] = []
    relevant_edge_uuids = set()

    results = await semaphore_gather(
        *[
            edge_similarity_search(
                driver,
                edge.fact_embedding,
                source_node_uuid,
                target_node_uuid,
                SearchFilters(),
                [edge.group_id],
                limit,
            )
            for edge in edges
            if edge.fact_embedding is not None
        ]
    )

    for result in results:
        for edge in result:
            if edge.uuid in relevant_edge_uuids:
                continue

            relevant_edge_uuids.add(edge.uuid)
            relevant_edges.append(edge)

    end = time()
    logger.debug(f'Found relevant edges: {relevant_edge_uuids} in {(end - start) * 1000} ms')

    return relevant_edges
async def get_relevant_edges(
    driver: AsyncDriver,
    edges: list[EntityEdge],
    search_filter: SearchFilters,
    min_score: float = DEFAULT_MIN_SCORE,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[list[EntityEdge]]:
    if len(edges) == 0:
        return []

    query_params: dict[str, Any] = {}
    filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
    query_params.update(filter_params)

    query = (
        RUNTIME_QUERY
        + """UNWIND $edges AS edge
        MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
        """
        + filter_query
        + """
        WITH e, edge, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score
        WHERE score > $min_score
        WITH edge, e, score
        ORDER BY score DESC
        RETURN edge.uuid AS search_edge_uuid,
            collect({
                uuid: e.uuid,
                source_node_uuid: startNode(e).uuid,
                target_node_uuid: endNode(e).uuid,
                created_at: e.created_at,
                name: e.name,
                group_id: e.group_id,
                fact: e.fact,
                fact_embedding: e.fact_embedding,
                episodes: e.episodes,
                expired_at: e.expired_at,
                valid_at: e.valid_at,
                invalid_at: e.invalid_at
            })[..$limit] AS matches
        """
    )

    results, _, _ = await driver.execute_query(
        query,
        query_params,
        edges=[edge.model_dump() for edge in edges],
        limit=limit,
        min_score=min_score,
        database_=DEFAULT_DATABASE,
        routing_='r',
    )

    relevant_edges_dict: dict[str, list[EntityEdge]] = {
        result['search_edge_uuid']: [
            get_entity_edge_from_record(record) for record in result['matches']
        ]
        for result in results
    }

    relevant_edges = [relevant_edges_dict.get(edge.uuid, []) for edge in edges]

    return relevant_edges
async def get_edge_invalidation_candidates(
driver: AsyncDriver,
edges: list[EntityEdge],
search_filter: SearchFilters,
min_score: float = DEFAULT_MIN_SCORE,
limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[list[EntityEdge]]:
if len(edges) == 0:
return []
query_params: dict[str, Any] = {}
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
query_params.update(filter_params)
query = (
RUNTIME_QUERY
+ """UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
+ filter_query
+ """
WITH edge, e, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score
WHERE score > $min_score
WITH edge, e, score
ORDER BY score DESC
RETURN edge.uuid AS search_edge_uuid,
collect({
uuid: e.uuid,
source_node_uuid: startNode(e).uuid,
target_node_uuid: endNode(e).uuid,
created_at: e.created_at,
name: e.name,
group_id: e.group_id,
fact: e.fact,
fact_embedding: e.fact_embedding,
episodes: e.episodes,
expired_at: e.expired_at,
valid_at: e.valid_at,
invalid_at: e.invalid_at
})[..$limit] AS matches
"""
)
results, _, _ = await driver.execute_query(
query,
query_params,
edges=[edge.model_dump() for edge in edges],
limit=limit,
min_score=min_score,
database_=DEFAULT_DATABASE,
routing_='r',
)
invalidation_edges_dict: dict[str, list[EntityEdge]] = {
result['search_edge_uuid']: [
get_entity_edge_from_record(record) for record in result['matches']
]
for result in results
}
invalidation_edges = [invalidation_edges_dict.get(edge.uuid, []) for edge in edges]
return invalidation_edges
# takes in a list of rankings of uuids
def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[str]:
scores: dict[str, float] = defaultdict(float)
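The net effect in this file: get_relevant_nodes and get_relevant_edges now issue a single UNWIND-batched query each instead of one query per candidate, and each returns a list of match lists aligned index-for-index with its input. A hedged sketch of the new calling convention (the 0.8 min_score mirrors the call sites introduced elsewhere in this commit):

# existing_nodes_lists[i] holds the graph candidates for extracted_nodes[i];
# an empty list means nothing scored above min_score for that node.
existing_nodes_lists = await get_relevant_nodes(driver, extracted_nodes, SearchFilters(), 0.8)
for node, candidates in zip(extracted_nodes, existing_nodes_lists, strict=True):
    logger.debug(f'{node.name}: {len(candidates)} candidate duplicates')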

graphiti_core/utils/bulk_utils.py

@@ -26,6 +26,7 @@ from pydantic import BaseModel
from typing_extensions import Any
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
from graphiti_core.llm_client import LLMClient
from graphiti_core.models.edges.edge_db_queries import (
@@ -128,16 +129,18 @@ async def add_nodes_and_edges_bulk_tx(
await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
await tx.run(EPISODIC_EDGE_SAVE_BULK, episodic_edges=[dict(edge) for edge in episodic_edges])
await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[dict(edge) for edge in entity_edges])
await tx.run(
EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
)
await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[edge.model_dump() for edge in entity_edges])
async def extract_nodes_and_edges_bulk(
llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
clients: GraphitiClients, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
extracted_nodes_bulk = await semaphore_gather(
*[
extract_nodes(llm_client, episode, previous_episodes)
extract_nodes(clients, episode, previous_episodes)
for episode, previous_episodes in episode_tuples
]
)
@@ -150,7 +153,7 @@ async def extract_nodes_and_edges_bulk(
extracted_edges_bulk = await semaphore_gather(
*[
extract_edges(
llm_client,
clients,
episode,
extracted_nodes_bulk[i],
previous_episodes_list[i],
@@ -189,7 +192,7 @@ async def dedupe_nodes_bulk(
existing_nodes_chunks: list[list[EntityNode]] = list(
await semaphore_gather(
*[get_relevant_nodes(driver, SearchFilters(), node_chunk) for node_chunk in node_chunks]
*[get_relevant_nodes(driver, node_chunk, SearchFilters()) for node_chunk in node_chunks]
)
)
@@ -223,7 +226,7 @@ async def dedupe_edges_bulk(
relevant_edges_chunks: list[list[EntityEdge]] = list(
await semaphore_gather(
*[get_relevant_edges(driver, edge_chunk, None, None) for edge_chunk in edge_chunks]
*[get_relevant_edges(driver, edge_chunk, SearchFilters()) for edge_chunk in edge_chunks]
)
)

graphiti_core/utils/maintenance/edge_operations.py

@@ -19,12 +19,15 @@ from datetime import datetime
from time import time
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate, UniqueFacts
from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.maintenance.temporal_operations import (
extract_edge_dates,
@@ -39,7 +42,7 @@ def build_episodic_edges(
episode: EpisodicNode,
created_at: datetime,
) -> list[EpisodicEdge]:
edges: list[EpisodicEdge] = [
episodic_edges: list[EpisodicEdge] = [
EpisodicEdge(
source_node_uuid=episode.uuid,
target_node_uuid=node.uuid,
@@ -49,7 +52,9 @@
for node in entity_nodes
]
return edges
logger.debug(f'Built episodic edges: {episodic_edges}')
return episodic_edges
def build_community_edges(
@@ -71,7 +76,7 @@
async def extract_edges(
llm_client: LLMClient,
clients: GraphitiClients,
episode: EpisodicNode,
nodes: list[EntityNode],
previous_episodes: list[EpisodicNode],
@@ -79,7 +84,9 @@
) -> list[EntityEdge]:
start = time()
EXTRACT_EDGES_MAX_TOKENS = 16384
extract_edges_max_tokens = 16384
llm_client = clients.llm_client
embedder = clients.embedder
node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
@@ -97,7 +104,7 @@
llm_response = await llm_client.generate_response(
prompt_library.extract_edges.edge(context),
response_model=ExtractedEdges,
max_tokens=EXTRACT_EDGES_MAX_TOKENS,
max_tokens=extract_edges_max_tokens,
)
edges_data = llm_response.get('edges', [])
@@ -145,6 +152,11 @@
f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})'
)
# calculate embeddings
await semaphore_gather(*[edge.generate_embedding(embedder) for edge in edges])
logger.debug(f'Extracted edges: {[(e.name, e.uuid) for e in edges]}')
return edges
@@ -193,13 +205,26 @@ async def dedupe_extracted_edges(
async def resolve_extracted_edges(
llm_client: LLMClient,
clients: GraphitiClients,
extracted_edges: list[EntityEdge],
related_edges_lists: list[list[EntityEdge]],
existing_edges_lists: list[list[EntityEdge]],
current_episode: EpisodicNode,
previous_episodes: list[EpisodicNode],
) -> tuple[list[EntityEdge], list[EntityEdge]]:
driver = clients.driver
llm_client = clients.llm_client
related_edges_lists: list[list[EntityEdge]] = await get_relevant_edges(
driver, extracted_edges, SearchFilters(), 0.8
)
logger.debug(
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
)
edge_invalidation_candidates: list[list[EntityEdge]] = await get_edge_invalidation_candidates(
driver, extracted_edges, SearchFilters()
)
# resolve edges with related edges in the graph, extract temporal information, and find invalidation candidates
results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
await semaphore_gather(
@@ -213,7 +238,7 @@
previous_episodes,
)
for extracted_edge, related_edges, existing_edges in zip(
extracted_edges, related_edges_lists, existing_edges_lists, strict=False
extracted_edges, related_edges_lists, edge_invalidation_candidates, strict=False
)
]
)
@@ -228,6 +253,8 @@
resolved_edges.append(resolved_edge)
invalidated_edges.extend(invalidated_edge_chunk)
logger.debug(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
return resolved_edges, invalidated_edges
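resolve_extracted_edges now gathers its own context: related edges via get_relevant_edges (with a 0.8 similarity floor) and invalidation candidates via get_edge_invalidation_candidates, so callers no longer precompute related_edges_lists or existing_edges_lists. The new call shape, as used in graphiti.py:

resolved_edges, invalidated_edges = await resolve_extracted_edges(
    clients,
    extracted_edges_with_resolved_pointers,
    episode,
    previous_episodes,
)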

graphiti_core/utils/maintenance/graph_data_operations.py

@@ -117,6 +117,7 @@ async def retrieve_episodes(
reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN,
group_ids: list[str] | None = None,
source: EpisodeType | None = None,
) -> list[EpisodicNode]:
"""
Retrieve the last n episodic nodes from the graph.
@@ -132,13 +133,17 @@
Returns:
list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes.
"""
group_id_filter: LiteralString = 'AND e.group_id IN $group_ids' if group_ids and len(group_ids) > 0 else ''
group_id_filter: LiteralString = (
'AND e.group_id IN $group_ids' if group_ids and len(group_ids) > 0 else ''
)
source_filter: LiteralString = 'AND e.source = $source' if source is not None else ''
query: LiteralString = (
"""
MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
"""
MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
"""
+ group_id_filter
+ source_filter
+ """
RETURN e.content AS content,
e.created_at AS created_at,
@@ -156,6 +161,7 @@
result = await driver.execute_query(
query,
reference_time=reference_time,
source=source,
num_episodes=last_n,
group_ids=group_ids,
database_=DEFAULT_DATABASE,
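The new source filter lets add_episode build its previous-episode window from the same modality as the incoming episode, so a message episode is contextualized by earlier messages rather than, say, JSON documents. A hedged usage sketch (the group id is an example value):

episodes = await retrieve_episodes(
    driver,
    reference_time,
    last_n=3,
    group_ids=['example-group'],
    source=EpisodeType.message,
)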

graphiti_core/utils/maintenance/node_operations.py

@@ -22,6 +22,7 @@ from typing import Any
import pydantic
from pydantic import BaseModel
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
@@ -29,6 +30,8 @@ from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
from graphiti_core.prompts.extract_nodes import EntityClassification, ExtractedNodes, MissedEntities
from graphiti_core.prompts.summarize_nodes import Summary
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import get_relevant_nodes
from graphiti_core.utils.datetime_utils import utc_now
logger = logging.getLogger(__name__)
@@ -116,12 +119,14 @@ async def extract_nodes_reflexion(
async def extract_nodes(
llm_client: LLMClient,
clients: GraphitiClients,
episode: EpisodicNode,
previous_episodes: list[EpisodicNode],
entity_types: dict[str, BaseModel] | None = None,
) -> list[EntityNode]:
start = time()
llm_client = clients.llm_client
embedder = clients.embedder
extracted_node_names: list[str] = []
custom_prompt = ''
entities_missed = True
@@ -138,7 +143,6 @@
elif episode.source == EpisodeType.json:
extracted_node_names = await extract_json_nodes(llm_client, episode, custom_prompt)
reflexion_iterations += 1
if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
missing_entities = await extract_nodes_reflexion(
llm_client, episode, previous_episodes, extracted_node_names
@@ -149,6 +153,7 @@
custom_prompt = 'The following entities were missed in a previous extraction: '
for entity in missing_entities:
custom_prompt += f'\n{entity},'
reflexion_iterations += 1
node_classification_context = {
'episode_content': episode.content,
@@ -184,7 +189,7 @@
end = time()
logger.debug(f'Extracted new nodes: {extracted_node_names} in {(end - start) * 1000} ms')
# Convert the extracted data into EntityNode objects
new_nodes = []
extracted_nodes = []
for name in extracted_node_names:
entity_type = node_classifications.get(name)
if entity_types is not None and entity_type not in entity_types:
@@ -203,10 +208,13 @@
summary='',
created_at=utc_now(),
)
new_nodes.append(new_node)
extracted_nodes.append(new_node)
logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
return new_nodes
await semaphore_gather(*[node.generate_name_embedding(embedder) for node in extracted_nodes])
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
return extracted_nodes
async def dedupe_extracted_nodes(
@@ -260,13 +268,20 @@
async def resolve_extracted_nodes(
llm_client: LLMClient,
clients: GraphitiClients,
extracted_nodes: list[EntityNode],
existing_nodes_lists: list[list[EntityNode]],
episode: EpisodicNode | None = None,
previous_episodes: list[EpisodicNode] | None = None,
entity_types: dict[str, BaseModel] | None = None,
) -> tuple[list[EntityNode], dict[str, str]]:
llm_client = clients.llm_client
driver = clients.driver
# Find relevant nodes already in the graph
existing_nodes_lists: list[list[EntityNode]] = await get_relevant_nodes(
driver, extracted_nodes, SearchFilters(), 0.8
)
uuid_map: dict[str, str] = {}
resolved_nodes: list[EntityNode] = []
results: list[tuple[EntityNode, dict[str, str]]] = list(
@@ -291,6 +306,8 @@
uuid_map.update(result[1])
resolved_nodes.append(result[0])
logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')
return resolved_nodes, uuid_map
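Symmetrically to the edge path, resolve_extracted_nodes now fetches its own duplicate candidates through get_relevant_nodes (0.8 similarity floor), so the existing_nodes_lists parameter disappears and extract_nodes handles name embeddings itself. The new call shape, as used in graphiti.py:

resolved_nodes, uuid_map = await resolve_extracted_nodes(
    clients,
    extracted_nodes,
    episode,
    previous_episodes,
    entity_types,
)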

pyproject.toml

@@ -1,7 +1,7 @@
[project]
name = "graphiti-core"
description = "A temporal graph building library"
version = "0.10.5"
version = "0.10.6"
authors = [
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },