From a26b25dc06dd53e70a465515a25b41c8c7de7f2f Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Sat, 26 Apr 2025 00:24:23 -0400 Subject: [PATCH] Add episode refactor (#399) * partial refactor * get relevant nodes refactor * load edges updates * refactor triplets * not there yet * node search update * working refactor * updates * mypy * mypy --- .env.example | 6 + graphiti_core/edges.py | 83 ++---- graphiti_core/graphiti.py | 159 ++--------- graphiti_core/graphiti_types.py | 31 ++ graphiti_core/helpers.py | 7 +- graphiti_core/models/edges/edge_db_queries.py | 2 +- graphiti_core/search/search.py | 18 +- graphiti_core/search/search_utils.py | 269 ++++++++++++------ graphiti_core/utils/bulk_utils.py | 17 +- .../utils/maintenance/edge_operations.py | 45 ++- .../maintenance/graph_data_operations.py | 12 +- .../utils/maintenance/node_operations.py | 31 +- pyproject.toml | 2 +- 13 files changed, 380 insertions(+), 302 deletions(-) create mode 100644 graphiti_core/graphiti_types.py diff --git a/.env.example b/.env.example index 110a33ef..910972b3 100644 --- a/.env.example +++ b/.env.example @@ -1,4 +1,10 @@ OPENAI_API_KEY= NEO4J_URI= +NEO4J_PORT= NEO4J_USER= NEO4J_PASSWORD= +DEFAULT_DATABASE= +USE_PARALLEL_RUNTIME= +SEMAPHORE_LIMIT= +GITHUB_SHA= +MAX_REFLEXION_ITERATIONS= \ No newline at end of file diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index 7e7c1868..108c143e 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -37,6 +37,21 @@ from graphiti_core.nodes import Node logger = logging.getLogger(__name__) +ENTITY_EDGE_RETURN: LiteralString = """ + RETURN + e.uuid AS uuid, + startNode(e).uuid AS source_node_uuid, + endNode(e).uuid AS target_node_uuid, + e.created_at AS created_at, + e.name AS name, + e.group_id AS group_id, + e.fact AS fact, + e.fact_embedding AS fact_embedding, + e.episodes AS episodes, + e.expired_at AS expired_at, + e.valid_at AS valid_at, + e.invalid_at AS invalid_at""" + class Edge(BaseModel, ABC): uuid: str = Field(default_factory=lambda: str(uuid4())) @@ -234,20 +249,8 @@ class EntityEdge(Edge): records, _, _ = await driver.execute_query( """ MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity) - RETURN - e.uuid AS uuid, - n.uuid AS source_node_uuid, - m.uuid AS target_node_uuid, - e.created_at AS created_at, - e.name AS name, - e.group_id AS group_id, - e.fact AS fact, - e.fact_embedding AS fact_embedding, - e.episodes AS episodes, - e.expired_at AS expired_at, - e.valid_at AS valid_at, - e.invalid_at AS invalid_at - """, + """ + + ENTITY_EDGE_RETURN, uuid=uuid, database_=DEFAULT_DATABASE, routing_='r', @@ -268,20 +271,8 @@ class EntityEdge(Edge): """ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) WHERE e.uuid IN $uuids - RETURN - e.uuid AS uuid, - n.uuid AS source_node_uuid, - m.uuid AS target_node_uuid, - e.created_at AS created_at, - e.name AS name, - e.group_id AS group_id, - e.fact AS fact, - e.fact_embedding AS fact_embedding, - e.episodes AS episodes, - e.expired_at AS expired_at, - e.valid_at AS valid_at, - e.invalid_at AS invalid_at - """, + """ + + ENTITY_EDGE_RETURN, uuids=uuids, database_=DEFAULT_DATABASE, routing_='r', @@ -308,20 +299,8 @@ class EntityEdge(Edge): WHERE e.group_id IN $group_ids """ + cursor_query + + ENTITY_EDGE_RETURN + """ - RETURN - e.uuid AS uuid, - n.uuid AS source_node_uuid, - m.uuid AS target_node_uuid, - e.created_at AS created_at, - e.name AS name, - e.group_id AS group_id, - e.fact AS fact, - e.fact_embedding AS fact_embedding, - e.episodes AS 
episodes, - e.expired_at AS expired_at, - e.valid_at AS valid_at, - e.invalid_at AS invalid_at ORDER BY e.uuid DESC """ + limit_query, @@ -340,22 +319,12 @@ class EntityEdge(Edge): @classmethod async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str): - query: LiteralString = """ - MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity) - RETURN DISTINCT - e.uuid AS uuid, - n.uuid AS source_node_uuid, - m.uuid AS target_node_uuid, - e.created_at AS created_at, - e.name AS name, - e.group_id AS group_id, - e.fact AS fact, - e.fact_embedding AS fact_embedding, - e.episodes AS episodes, - e.expired_at AS expired_at, - e.valid_at AS valid_at, - e.invalid_at AS invalid_at - """ + query: LiteralString = ( + """ + MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity) + """ + + ENTITY_EDGE_RETURN + ) records, _, _ = await driver.execute_query( query, node_uuid=node_uuid, database_=DEFAULT_DATABASE, routing_='r' ) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 913badcf..3dcb6910 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -27,6 +27,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient from graphiti_core.edges import EntityEdge, EpisodicEdge from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder +from graphiti_core.graphiti_types import GraphitiClients from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather from graphiti_core.llm_client import LLMClient, OpenAIClient from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode @@ -42,7 +43,6 @@ from graphiti_core.search.search_utils import ( RELEVANT_SCHEMA_LIMIT, get_mentioned_nodes, get_relevant_edges, - get_relevant_nodes, ) from graphiti_core.utils.bulk_utils import ( RawEpisode, @@ -150,6 +150,13 @@ class Graphiti: else: self.cross_encoder = OpenAIRerankerClient() + self.clients = GraphitiClients( + driver=self.driver, + llm_client=self.llm_client, + embedder=self.embedder, + cross_encoder=self.cross_encoder, + ) + async def close(self): """ Close the connection to the Neo4j database. @@ -222,6 +229,7 @@ class Graphiti: reference_time: datetime, last_n: int = EPISODE_WINDOW_LEN, group_ids: list[str] | None = None, + source: EpisodeType | None = None, ) -> list[EpisodicNode]: """ Retrieve the last n episodic nodes from the graph. @@ -248,7 +256,7 @@ class Graphiti: The actual retrieval is performed by the `retrieve_episodes` function from the `graphiti_core.utils` module. 
""" - return await retrieve_episodes(self.driver, reference_time, last_n, group_ids) + return await retrieve_episodes(self.driver, reference_time, last_n, group_ids, source) async def add_episode( self, @@ -314,15 +322,16 @@ class Graphiti: """ try: start = time() - - entity_edges: list[EntityEdge] = [] now = utc_now() validate_entity_types(entity_types) previous_episodes = ( await self.retrieve_episodes( - reference_time, last_n=RELEVANT_SCHEMA_LIMIT, group_ids=[group_id] + reference_time, + last_n=RELEVANT_SCHEMA_LIMIT, + group_ids=[group_id], + source=source, ) if previous_episode_uuids is None else await EpisodicNode.get_by_uuids(self.driver, previous_episode_uuids) @@ -346,132 +355,35 @@ class Graphiti: # Extract entities as nodes extracted_nodes = await extract_nodes( - self.llm_client, episode, previous_episodes, entity_types - ) - logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}') - - # Calculate Embeddings - - await semaphore_gather( - *[node.generate_name_embedding(self.embedder) for node in extracted_nodes] + self.clients, episode, previous_episodes, entity_types ) - # Find relevant nodes already in the graph - existing_nodes_lists: list[list[EntityNode]] = list( - await semaphore_gather( - *[ - get_relevant_nodes(self.driver, SearchFilters(), [node]) - for node in extracted_nodes - ] - ) - ) - - # Resolve extracted nodes with nodes already in the graph and extract facts - logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}') - - (mentioned_nodes, uuid_map), extracted_edges = await semaphore_gather( + # Extract edges and resolve nodes + (nodes, uuid_map), extracted_edges = await semaphore_gather( resolve_extracted_nodes( - self.llm_client, + self.clients, extracted_nodes, - existing_nodes_lists, episode, previous_episodes, entity_types, ), - extract_edges( - self.llm_client, episode, extracted_nodes, previous_episodes, group_id - ), + extract_edges(self.clients, episode, extracted_nodes, previous_episodes, group_id), ) - logger.debug(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}') - nodes = mentioned_nodes extracted_edges_with_resolved_pointers = resolve_edge_pointers( extracted_edges, uuid_map ) - # calculate embeddings - await semaphore_gather( - *[ - edge.generate_embedding(self.embedder) - for edge in extracted_edges_with_resolved_pointers - ] - ) - - # Resolve extracted edges with related edges already in the graph - related_edges_list: list[list[EntityEdge]] = list( - await semaphore_gather( - *[ - get_relevant_edges( - self.driver, - [edge], - edge.source_node_uuid, - edge.target_node_uuid, - RELEVANT_SCHEMA_LIMIT, - ) - for edge in extracted_edges_with_resolved_pointers - ] - ) - ) - logger.debug( - f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_list for e in edges_lst]}' - ) - logger.debug( - f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges_with_resolved_pointers]}' - ) - - existing_source_edges_list: list[list[EntityEdge]] = list( - await semaphore_gather( - *[ - get_relevant_edges( - self.driver, - [edge], - edge.source_node_uuid, - None, - RELEVANT_SCHEMA_LIMIT, - ) - for edge in extracted_edges_with_resolved_pointers - ] - ) - ) - - existing_target_edges_list: list[list[EntityEdge]] = list( - await semaphore_gather( - *[ - get_relevant_edges( - self.driver, - [edge], - None, - edge.target_node_uuid, - RELEVANT_SCHEMA_LIMIT, - ) - for edge in extracted_edges_with_resolved_pointers - ] - ) - ) - - existing_edges_list: list[list[EntityEdge]] = [ - 
source_lst + target_lst - for source_lst, target_lst in zip( - existing_source_edges_list, existing_target_edges_list, strict=False - ) - ] - resolved_edges, invalidated_edges = await resolve_extracted_edges( - self.llm_client, + self.clients, extracted_edges_with_resolved_pointers, - related_edges_list, - existing_edges_list, episode, previous_episodes, ) - entity_edges.extend(resolved_edges + invalidated_edges) + entity_edges = resolved_edges + invalidated_edges - logger.debug(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}') - - episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now) - - logger.debug(f'Built episodic edges: {episodic_edges}') + episodic_edges = build_episodic_edges(nodes, episode, now) episode.entity_edges = [edge.uuid for edge in entity_edges] @@ -565,7 +477,7 @@ class Graphiti: extracted_nodes, extracted_edges, episodic_edges, - ) = await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs) + ) = await extract_nodes_and_edges_bulk(self.clients, episode_pairs) # Generate embeddings await semaphore_gather( @@ -684,9 +596,7 @@ class Graphiti: edges = ( await search( - self.driver, - self.embedder, - self.cross_encoder, + self.clients, query, group_ids, search_config, @@ -728,9 +638,7 @@ class Graphiti: """ return await search( - self.driver, - self.embedder, - self.cross_encoder, + self.clients, query, group_ids, config, @@ -761,26 +669,17 @@ class Graphiti: await edge.generate_embedding(self.embedder) resolved_nodes, uuid_map = await resolve_extracted_nodes( - self.llm_client, + self.clients, [source_node, target_node], - [ - await get_relevant_nodes(self.driver, SearchFilters(), [source_node]), - await get_relevant_nodes(self.driver, SearchFilters(), [target_node]), - ], ) updated_edge = resolve_edge_pointers([edge], uuid_map)[0] - related_edges = await get_relevant_edges( - self.driver, - [updated_edge], - source_node_uuid=resolved_nodes[0].uuid, - target_node_uuid=resolved_nodes[1].uuid, - ) + related_edges = await get_relevant_edges(self.driver, [updated_edge], SearchFilters(), 0.8) - resolved_edge = await dedupe_extracted_edge(self.llm_client, updated_edge, related_edges) + resolved_edge = await dedupe_extracted_edge(self.llm_client, updated_edge, related_edges[0]) - contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges) + contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges[0]) invalidated_edges = resolve_edge_contradictions(resolved_edge, contradicting_edges) await add_nodes_and_edges_bulk( diff --git a/graphiti_core/graphiti_types.py b/graphiti_core/graphiti_types.py new file mode 100644 index 00000000..c765ee63 --- /dev/null +++ b/graphiti_core/graphiti_types.py @@ -0,0 +1,31 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from neo4j import AsyncDriver +from pydantic import BaseModel, ConfigDict + +from graphiti_core.cross_encoder import CrossEncoderClient +from graphiti_core.embedder import EmbedderClient +from graphiti_core.llm_client import LLMClient + + +class GraphitiClients(BaseModel): + driver: AsyncDriver + llm_client: LLMClient + embedder: EmbedderClient + cross_encoder: CrossEncoderClient + + model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/graphiti_core/helpers.py b/graphiti_core/helpers.py index 2c2abadc..43810200 100644 --- a/graphiti_core/helpers.py +++ b/graphiti_core/helpers.py @@ -22,15 +22,20 @@ from datetime import datetime import numpy as np from dotenv import load_dotenv from neo4j import time as neo4j_time +from typing_extensions import LiteralString load_dotenv() DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None) USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False)) SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20)) -MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 2)) +MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 1)) DEFAULT_PAGE_LIMIT = 20 +RUNTIME_QUERY: LiteralString = ( + 'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else '' +) + def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None: return neo_date.to_native() if neo_date else None diff --git a/graphiti_core/models/edges/edge_db_queries.py b/graphiti_core/models/edges/edge_db_queries.py index 49567731..ba687219 100644 --- a/graphiti_core/models/edges/edge_db_queries.py +++ b/graphiti_core/models/edges/edge_db_queries.py @@ -47,7 +47,7 @@ ENTITY_EDGE_SAVE_BULK = """ SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes, created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at} WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding) - RETURN r.uuid AS uuid + RETURN edge.uuid AS uuid """ COMMUNITY_EDGE_SAVE = """ diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py index ed50b9b2..579da576 100644 --- a/graphiti_core/search/search.py +++ b/graphiti_core/search/search.py @@ -22,8 +22,8 @@ from neo4j import AsyncDriver from graphiti_core.cross_encoder.client import CrossEncoderClient from graphiti_core.edges import EntityEdge -from graphiti_core.embedder import EmbedderClient from graphiti_core.errors import SearchRerankerError +from graphiti_core.graphiti_types import GraphitiClients from graphiti_core.helpers import semaphore_gather from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode from graphiti_core.search.search_config import ( @@ -62,17 +62,21 @@ logger = logging.getLogger(__name__) async def search( - driver: AsyncDriver, - embedder: EmbedderClient, - cross_encoder: CrossEncoderClient, + clients: GraphitiClients, query: str, group_ids: list[str] | None, config: SearchConfig, search_filter: SearchFilters, center_node_uuid: str | None = None, bfs_origin_node_uuids: list[str] | None = None, + query_vector: list[float] | None = None, ) -> SearchResults: start = time() + + driver = clients.driver + embedder = clients.embedder + cross_encoder = clients.cross_encoder + if query.strip() == '': return SearchResults( edges=[], @@ -80,7 +84,11 @@ async def search( episodes=[], communities=[], ) - query_vector = await embedder.create(input_data=[query.replace('\n', ' ')]) + query_vector = ( + 
query_vector + if query_vector is not None + else await embedder.create(input_data=[query.replace('\n', ' ')]) + ) # if group_ids is empty, set it to None group_ids = group_ids if group_ids else None diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 9a8079eb..56f87d97 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -26,7 +26,7 @@ from typing_extensions import LiteralString from graphiti_core.edges import EntityEdge, get_entity_edge_from_record from graphiti_core.helpers import ( DEFAULT_DATABASE, - USE_PARALLEL_RUNTIME, + RUNTIME_QUERY, lucene_sanitize, normalize_l2, semaphore_gather, @@ -207,10 +207,6 @@ async def edge_similarity_search( min_score: float = DEFAULT_MIN_SCORE, ) -> list[EntityEdge]: # vector similarity search over embedded facts - runtime_query: LiteralString = ( - 'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else '' - ) - query_params: dict[str, Any] = {} filter_query, filter_params = edge_search_filter_query_constructor(search_filter) @@ -230,9 +226,10 @@ async def edge_similarity_search( group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])' query: LiteralString = ( - """ - MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity) - """ + RUNTIME_QUERY + + """ + MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity) + """ + group_filter_query + filter_query + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score @@ -256,7 +253,7 @@ async def edge_similarity_search( ) records, _, _ = await driver.execute_query( - runtime_query + query, + query, query_params, search_vector=search_vector, source_uuid=source_node_uuid, @@ -344,10 +341,10 @@ async def node_fulltext_search( query = ( """ - CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit}) - YIELD node AS n, score - WHERE n:Entity - """ + CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit}) + YIELD node AS n, score + WHERE n:Entity + """ + filter_query + ENTITY_NODE_RETURN + """ @@ -378,10 +375,6 @@ async def node_similarity_search( min_score: float = DEFAULT_MIN_SCORE, ) -> list[EntityNode]: # vector similarity search over entity names - runtime_query: LiteralString = ( - 'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else '' - ) - query_params: dict[str, Any] = {} group_filter_query: LiteralString = '' @@ -393,7 +386,7 @@ async def node_similarity_search( query_params.update(filter_params) records, _, _ = await driver.execute_query( - runtime_query + RUNTIME_QUERY + """ MATCH (n:Entity) """ @@ -542,10 +535,6 @@ async def community_similarity_search( min_score=DEFAULT_MIN_SCORE, ) -> list[CommunityNode]: # vector similarity search over entity names - runtime_query: LiteralString = ( - 'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else '' - ) - query_params: dict[str, Any] = {} group_filter_query: LiteralString = '' @@ -554,7 +543,7 @@ async def community_similarity_search( query_params['group_ids'] = group_ids records, _, _ = await driver.execute_query( - runtime_query + RUNTIME_QUERY + """ MATCH (comm:Community) """ @@ -660,86 +649,204 @@ async def hybrid_node_search( async def get_relevant_nodes( driver: AsyncDriver, - search_filter: SearchFilters, nodes: list[EntityNode], -) -> list[EntityNode]: - """ - Retrieve relevant nodes based on the provided list of EntityNodes. 
+ search_filter: SearchFilters, + min_score: float = DEFAULT_MIN_SCORE, + limit: int = RELEVANT_SCHEMA_LIMIT, +) -> list[list[EntityNode]]: + if len(nodes) == 0: + return [] - This method performs a hybrid search using both the names and embeddings - of the input nodes to find relevant nodes in the graph database. + group_id = nodes[0].group_id - Parameters - ---------- - nodes : list[EntityNode] - A list of EntityNode objects to use as the basis for the search. - driver : AsyncDriver - The Neo4j driver instance for database operations. + # vector similarity search over entity names + query_params: dict[str, Any] = {} - Returns - ------- - list[EntityNode] - A list of EntityNode objects that are deemed relevant based on the input nodes. + filter_query, filter_params = node_search_filter_query_constructor(search_filter) + query_params.update(filter_params) - Notes - ----- - This method uses the hybrid_node_search function to perform the search, - which combines fulltext search and vector similarity search. - It extracts the names and name embeddings (if available) from the input nodes - to use as search criteria. - """ - relevant_nodes = await hybrid_node_search( - [node.name for node in nodes], - [node.name_embedding for node in nodes if node.name_embedding is not None], - driver, - search_filter, - [node.group_id for node in nodes], + query = ( + RUNTIME_QUERY + + """UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + + filter_query + + """ + WITH node, n, vector.similarity.cosine(n.name_embedding, node.name_embedding) AS score + WHERE score > $min_score + WITH node, n, score + ORDER BY score DESC + RETURN node.uuid AS search_node_uuid, + collect({ + uuid: n.uuid, + name: n.name, + name_embedding: n.name_embedding, + group_id: n.group_id, + created_at: n.created_at, + summary: n.summary, + labels: labels(n), + attributes: properties(n) + })[..$limit] AS matches + """ ) + results, _, _ = await driver.execute_query( + query, + query_params, + nodes=[ + {'uuid': node.uuid, 'name': node.name, 'name_embedding': node.name_embedding} + for node in nodes + ], + group_id=group_id, + limit=limit, + min_score=min_score, + database_=DEFAULT_DATABASE, + routing_='r', + ) + + relevant_nodes_dict: dict[str, list[EntityNode]] = { + result['search_node_uuid']: [ + get_entity_node_from_record(record) for record in result['matches'] + ] + for result in results + } + + relevant_nodes = [relevant_nodes_dict.get(node.uuid, []) for node in nodes] + return relevant_nodes async def get_relevant_edges( driver: AsyncDriver, edges: list[EntityEdge], - source_node_uuid: str | None, - target_node_uuid: str | None, + search_filter: SearchFilters, + min_score: float = DEFAULT_MIN_SCORE, limit: int = RELEVANT_SCHEMA_LIMIT, -) -> list[EntityEdge]: - start = time() - relevant_edges: list[EntityEdge] = [] - relevant_edge_uuids = set() +) -> list[list[EntityEdge]]: + if len(edges) == 0: + return [] - results = await semaphore_gather( - *[ - edge_similarity_search( - driver, - edge.fact_embedding, - source_node_uuid, - target_node_uuid, - SearchFilters(), - [edge.group_id], - limit, - ) - for edge in edges - if edge.fact_embedding is not None - ] + query_params: dict[str, Any] = {} + + filter_query, filter_params = edge_search_filter_query_constructor(search_filter) + query_params.update(filter_params) + + query = ( + RUNTIME_QUERY + + """UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + + 
filter_query + + """ + WITH e, edge, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score + WHERE score > $min_score + WITH edge, e, score + ORDER BY score DESC + RETURN edge.uuid AS search_edge_uuid, + collect({ + uuid: e.uuid, + source_node_uuid: startNode(e).uuid, + target_node_uuid: endNode(e).uuid, + created_at: e.created_at, + name: e.name, + group_id: e.group_id, + fact: e.fact, + fact_embedding: e.fact_embedding, + episodes: e.episodes, + expired_at: e.expired_at, + valid_at: e.valid_at, + invalid_at: e.invalid_at + })[..$limit] AS matches + """ ) - for result in results: - for edge in result: - if edge.uuid in relevant_edge_uuids: - continue + results, _, _ = await driver.execute_query( + query, + query_params, + edges=[edge.model_dump() for edge in edges], + limit=limit, + min_score=min_score, + database_=DEFAULT_DATABASE, + routing_='r', + ) + relevant_edges_dict: dict[str, list[EntityEdge]] = { + result['search_edge_uuid']: [ + get_entity_edge_from_record(record) for record in result['matches'] + ] + for result in results + } - relevant_edge_uuids.add(edge.uuid) - relevant_edges.append(edge) - - end = time() - logger.debug(f'Found relevant edges: {relevant_edge_uuids} in {(end - start) * 1000} ms') + relevant_edges = [relevant_edges_dict.get(edge.uuid, []) for edge in edges] return relevant_edges +async def get_edge_invalidation_candidates( + driver: AsyncDriver, + edges: list[EntityEdge], + search_filter: SearchFilters, + min_score: float = DEFAULT_MIN_SCORE, + limit: int = RELEVANT_SCHEMA_LIMIT, +) -> list[list[EntityEdge]]: + if len(edges) == 0: + return [] + + query_params: dict[str, Any] = {} + + filter_query, filter_params = edge_search_filter_query_constructor(search_filter) + query_params.update(filter_params) + + query = ( + RUNTIME_QUERY + + """UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + + filter_query + + """ + WITH edge, e, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score + WHERE score > $min_score + WITH edge, e, score + ORDER BY score DESC + RETURN edge.uuid AS search_edge_uuid, + collect({ + uuid: e.uuid, + source_node_uuid: startNode(e).uuid, + target_node_uuid: endNode(e).uuid, + created_at: e.created_at, + name: e.name, + group_id: e.group_id, + fact: e.fact, + fact_embedding: e.fact_embedding, + episodes: e.episodes, + expired_at: e.expired_at, + valid_at: e.valid_at, + invalid_at: e.invalid_at + })[..$limit] AS matches + """ + ) + + results, _, _ = await driver.execute_query( + query, + query_params, + edges=[edge.model_dump() for edge in edges], + limit=limit, + min_score=min_score, + database_=DEFAULT_DATABASE, + routing_='r', + ) + invalidation_edges_dict: dict[str, list[EntityEdge]] = { + result['search_edge_uuid']: [ + get_entity_edge_from_record(record) for record in result['matches'] + ] + for result in results + } + + invalidation_edges = [invalidation_edges_dict.get(edge.uuid, []) for edge in edges] + + return invalidation_edges + + # takes in a list of rankings of uuids def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[str]: scores: dict[str, float] = defaultdict(float) diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index 86cd8311..aafb54b5 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -26,6 +26,7 @@ from pydantic 
import BaseModel from typing_extensions import Any from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge +from graphiti_core.graphiti_types import GraphitiClients from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather from graphiti_core.llm_client import LLMClient from graphiti_core.models.edges.edge_db_queries import ( @@ -128,16 +129,18 @@ async def add_nodes_and_edges_bulk_tx( await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes) await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes) - await tx.run(EPISODIC_EDGE_SAVE_BULK, episodic_edges=[dict(edge) for edge in episodic_edges]) - await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[dict(edge) for edge in entity_edges]) + await tx.run( + EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges] + ) + await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[edge.model_dump() for edge in entity_edges]) async def extract_nodes_and_edges_bulk( - llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] + clients: GraphitiClients, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] ) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]: extracted_nodes_bulk = await semaphore_gather( *[ - extract_nodes(llm_client, episode, previous_episodes) + extract_nodes(clients, episode, previous_episodes) for episode, previous_episodes in episode_tuples ] ) @@ -150,7 +153,7 @@ async def extract_nodes_and_edges_bulk( extracted_edges_bulk = await semaphore_gather( *[ extract_edges( - llm_client, + clients, episode, extracted_nodes_bulk[i], previous_episodes_list[i], @@ -189,7 +192,7 @@ async def dedupe_nodes_bulk( existing_nodes_chunks: list[list[EntityNode]] = list( await semaphore_gather( - *[get_relevant_nodes(driver, SearchFilters(), node_chunk) for node_chunk in node_chunks] + *[get_relevant_nodes(driver, node_chunk, SearchFilters()) for node_chunk in node_chunks] ) ) @@ -223,7 +226,7 @@ async def dedupe_edges_bulk( relevant_edges_chunks: list[list[EntityEdge]] = list( await semaphore_gather( - *[get_relevant_edges(driver, edge_chunk, None, None) for edge_chunk in edge_chunks] + *[get_relevant_edges(driver, edge_chunk, SearchFilters()) for edge_chunk in edge_chunks] ) ) diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py index d3c6d0d2..97cd1c21 100644 --- a/graphiti_core/utils/maintenance/edge_operations.py +++ b/graphiti_core/utils/maintenance/edge_operations.py @@ -19,12 +19,15 @@ from datetime import datetime from time import time from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge +from graphiti_core.graphiti_types import GraphitiClients from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather from graphiti_core.llm_client import LLMClient from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode from graphiti_core.prompts import prompt_library from graphiti_core.prompts.dedupe_edges import EdgeDuplicate, UniqueFacts from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts +from graphiti_core.search.search_filters import SearchFilters +from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges from graphiti_core.utils.datetime_utils import utc_now from graphiti_core.utils.maintenance.temporal_operations import ( extract_edge_dates, @@ -39,7 +42,7 @@ def build_episodic_edges( episode: EpisodicNode, created_at: datetime, ) -> list[EpisodicEdge]: - edges: list[EpisodicEdge] = [ + 
episodic_edges: list[EpisodicEdge] = [ EpisodicEdge( source_node_uuid=episode.uuid, target_node_uuid=node.uuid, @@ -49,7 +52,9 @@ def build_episodic_edges( for node in entity_nodes ] - return edges + logger.debug(f'Built episodic edges: {episodic_edges}') + + return episodic_edges def build_community_edges( @@ -71,7 +76,7 @@ def build_community_edges( async def extract_edges( - llm_client: LLMClient, + clients: GraphitiClients, episode: EpisodicNode, nodes: list[EntityNode], previous_episodes: list[EpisodicNode], @@ -79,7 +84,9 @@ async def extract_edges( ) -> list[EntityEdge]: start = time() - EXTRACT_EDGES_MAX_TOKENS = 16384 + extract_edges_max_tokens = 16384 + llm_client = clients.llm_client + embedder = clients.embedder node_uuids_by_name_map = {node.name: node.uuid for node in nodes} @@ -97,7 +104,7 @@ async def extract_edges( llm_response = await llm_client.generate_response( prompt_library.extract_edges.edge(context), response_model=ExtractedEdges, - max_tokens=EXTRACT_EDGES_MAX_TOKENS, + max_tokens=extract_edges_max_tokens, ) edges_data = llm_response.get('edges', []) @@ -145,6 +152,11 @@ async def extract_edges( f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})' ) + # calculate embeddings + await semaphore_gather(*[edge.generate_embedding(embedder) for edge in edges]) + + logger.debug(f'Extracted edges: {[(e.name, e.uuid) for e in edges]}') + return edges @@ -193,13 +205,26 @@ async def dedupe_extracted_edges( async def resolve_extracted_edges( - llm_client: LLMClient, + clients: GraphitiClients, extracted_edges: list[EntityEdge], - related_edges_lists: list[list[EntityEdge]], - existing_edges_lists: list[list[EntityEdge]], current_episode: EpisodicNode, previous_episodes: list[EpisodicNode], ) -> tuple[list[EntityEdge], list[EntityEdge]]: + driver = clients.driver + llm_client = clients.llm_client + + related_edges_lists: list[list[EntityEdge]] = await get_relevant_edges( + driver, extracted_edges, SearchFilters(), 0.8 + ) + + logger.debug( + f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}' + ) + + edge_invalidation_candidates: list[list[EntityEdge]] = await get_edge_invalidation_candidates( + driver, extracted_edges, SearchFilters() + ) + # resolve edges with related edges in the graph, extract temporal information, and find invalidation candidates results: list[tuple[EntityEdge, list[EntityEdge]]] = list( await semaphore_gather( @@ -213,7 +238,7 @@ async def resolve_extracted_edges( previous_episodes, ) for extracted_edge, related_edges, existing_edges in zip( - extracted_edges, related_edges_lists, existing_edges_lists, strict=False + extracted_edges, related_edges_lists, edge_invalidation_candidates, strict=False ) ] ) @@ -228,6 +253,8 @@ async def resolve_extracted_edges( resolved_edges.append(resolved_edge) invalidated_edges.extend(invalidated_edge_chunk) + logger.debug(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}') + return resolved_edges, invalidated_edges diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py index 32e64a30..dd71e6e3 100644 --- a/graphiti_core/utils/maintenance/graph_data_operations.py +++ b/graphiti_core/utils/maintenance/graph_data_operations.py @@ -117,6 +117,7 @@ async def retrieve_episodes( reference_time: datetime, last_n: int = EPISODE_WINDOW_LEN, group_ids: list[str] | None = None, + source: EpisodeType | None = None, ) -> list[EpisodicNode]: """ 
Retrieve the last n episodic nodes from the graph. @@ -132,13 +133,17 @@ async def retrieve_episodes( Returns: list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes. """ - group_id_filter: LiteralString = 'AND e.group_id IN $group_ids' if group_ids and len(group_ids) > 0 else '' + group_id_filter: LiteralString = ( + 'AND e.group_id IN $group_ids' if group_ids and len(group_ids) > 0 else '' + ) + source_filter: LiteralString = 'AND e.source = $source' if source is not None else '' query: LiteralString = ( """ - MATCH (e:Episodic) WHERE e.valid_at <= $reference_time - """ + MATCH (e:Episodic) WHERE e.valid_at <= $reference_time + """ + group_id_filter + + source_filter + """ RETURN e.content AS content, e.created_at AS created_at, @@ -156,6 +161,7 @@ async def retrieve_episodes( result = await driver.execute_query( query, reference_time=reference_time, + source=source, num_episodes=last_n, group_ids=group_ids, database_=DEFAULT_DATABASE, diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py index e3700b91..69233cc8 100644 --- a/graphiti_core/utils/maintenance/node_operations.py +++ b/graphiti_core/utils/maintenance/node_operations.py @@ -22,6 +22,7 @@ from typing import Any import pydantic from pydantic import BaseModel +from graphiti_core.graphiti_types import GraphitiClients from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather from graphiti_core.llm_client import LLMClient from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode @@ -29,6 +30,8 @@ from graphiti_core.prompts import prompt_library from graphiti_core.prompts.dedupe_nodes import NodeDuplicate from graphiti_core.prompts.extract_nodes import EntityClassification, ExtractedNodes, MissedEntities from graphiti_core.prompts.summarize_nodes import Summary +from graphiti_core.search.search_filters import SearchFilters +from graphiti_core.search.search_utils import get_relevant_nodes from graphiti_core.utils.datetime_utils import utc_now logger = logging.getLogger(__name__) @@ -116,12 +119,14 @@ async def extract_nodes_reflexion( async def extract_nodes( - llm_client: LLMClient, + clients: GraphitiClients, episode: EpisodicNode, previous_episodes: list[EpisodicNode], entity_types: dict[str, BaseModel] | None = None, ) -> list[EntityNode]: start = time() + llm_client = clients.llm_client + embedder = clients.embedder extracted_node_names: list[str] = [] custom_prompt = '' entities_missed = True @@ -138,7 +143,6 @@ async def extract_nodes( elif episode.source == EpisodeType.json: extracted_node_names = await extract_json_nodes(llm_client, episode, custom_prompt) - reflexion_iterations += 1 if reflexion_iterations < MAX_REFLEXION_ITERATIONS: missing_entities = await extract_nodes_reflexion( llm_client, episode, previous_episodes, extracted_node_names @@ -149,6 +153,7 @@ async def extract_nodes( custom_prompt = 'The following entities were missed in a previous extraction: ' for entity in missing_entities: custom_prompt += f'\n{entity},' + reflexion_iterations += 1 node_classification_context = { 'episode_content': episode.content, @@ -184,7 +189,7 @@ async def extract_nodes( end = time() logger.debug(f'Extracted new nodes: {extracted_node_names} in {(end - start) * 1000} ms') # Convert the extracted data into EntityNode objects - new_nodes = [] + extracted_nodes = [] for name in extracted_node_names: entity_type = node_classifications.get(name) if entity_types is not None and entity_type not in entity_types: 
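
Reviewer note on the extract_nodes hunks above: the function now receives the GraphitiClients bundle and generates name embeddings for the nodes it creates before returning, so call sites no longer run a separate embedding gather. A minimal sketch of the new call shape, assuming `clients`, `episode`, and `previous_episodes` are constructed elsewhere (they are placeholders, not values from this patch):

    # hypothetical call site for the refactored extract_nodes
    extracted_nodes = await extract_nodes(clients, episode, previous_episodes)
    # per the diff, name embeddings are already populated internally via
    # semaphore_gather over node.generate_name_embedding(embedder),
    # so no embedding pass is needed here
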
@@ -203,10 +208,13 @@ async def extract_nodes( summary='', created_at=utc_now(), ) - new_nodes.append(new_node) + extracted_nodes.append(new_node) logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})') - return new_nodes + await semaphore_gather(*[node.generate_name_embedding(embedder) for node in extracted_nodes]) + + logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}') + return extracted_nodes async def dedupe_extracted_nodes( @@ -260,13 +268,20 @@ async def dedupe_extracted_nodes( async def resolve_extracted_nodes( - llm_client: LLMClient, + clients: GraphitiClients, extracted_nodes: list[EntityNode], - existing_nodes_lists: list[list[EntityNode]], episode: EpisodicNode | None = None, previous_episodes: list[EpisodicNode] | None = None, entity_types: dict[str, BaseModel] | None = None, ) -> tuple[list[EntityNode], dict[str, str]]: + llm_client = clients.llm_client + driver = clients.driver + + # Find relevant nodes already in the graph + existing_nodes_lists: list[list[EntityNode]] = await get_relevant_nodes( + driver, extracted_nodes, SearchFilters(), 0.8 + ) + uuid_map: dict[str, str] = {} resolved_nodes: list[EntityNode] = [] results: list[tuple[EntityNode, dict[str, str]]] = list( @@ -291,6 +306,8 @@ async def resolve_extracted_nodes( uuid_map.update(result[1]) resolved_nodes.append(result[0]) + logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}') + return resolved_nodes, uuid_map diff --git a/pyproject.toml b/pyproject.toml index 87b68b5c..e0ef3414 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "graphiti-core" description = "A temporal graph building library" -version = "0.10.5" +version = "0.10.6" authors = [ { "name" = "Paul Paliychuk", "email" = "paul@getzep.com" }, { "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
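
End-of-patch note: the central refactor threads a single GraphitiClients bundle (driver, LLM client, embedder, cross-encoder) through extraction, resolution, and search instead of passing each client individually. A minimal usage sketch, assuming the same constructors the patch imports; the connection URI and credentials are placeholders, not values from this patch:

    from neo4j import AsyncGraphDatabase

    from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
    from graphiti_core.embedder import OpenAIEmbedder
    from graphiti_core.graphiti_types import GraphitiClients
    from graphiti_core.llm_client import OpenAIClient

    # placeholder connection details
    driver = AsyncGraphDatabase.driver('bolt://localhost:7687', auth=('neo4j', 'password'))

    # one bundle, mirroring what Graphiti.__init__ now builds
    clients = GraphitiClients(
        driver=driver,
        llm_client=OpenAIClient(),
        embedder=OpenAIEmbedder(),
        cross_encoder=OpenAIRerankerClient(),
    )

    # downstream calls in this patch take the bundle instead of
    # individual clients, e.g.:
    #   await search(clients, query, group_ids, config, search_filter)
    #   await extract_edges(clients, episode, nodes, previous_episodes, group_id)

Because GraphitiClients is a pydantic BaseModel holding non-pydantic types like AsyncDriver, the patch sets model_config = ConfigDict(arbitrary_types_allowed=True); the model is used purely as a typed container, not for validation of the clients themselves.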