diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index e392a0a5..72c16d7a 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -46,7 +46,6 @@ ENTITY_EDGE_RETURN: LiteralString = """ e.name AS name, e.group_id AS group_id, e.fact AS fact, - e.fact_embedding AS fact_embedding, e.episodes AS episodes, e.expired_at AS expired_at, e.valid_at AS valid_at, @@ -222,6 +221,20 @@ class EntityEdge(Edge): return self.fact_embedding + async def load_fact_embedding(self, driver: AsyncDriver): + query: LiteralString = """ + MATCH (n:Entity)-[e:MENTIONS {uuid: $uuid}]->(m:Entity) + RETURN e.fact_embedding AS fact_embedding + """ + records, _, _ = await driver.execute_query( + query, uuid=self.uuid, database_=DEFAULT_DATABASE, routing_='r' + ) + + if len(records) == 0: + raise EdgeNotFoundError(self.uuid) + + self.fact_embedding = records[0]['fact_embedding'] + async def save(self, driver: AsyncDriver): result = await driver.execute_query( ENTITY_EDGE_SAVE, @@ -321,8 +334,8 @@ class EntityEdge(Edge): async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str): query: LiteralString = ( """ - MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity) - """ + MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity) + """ + ENTITY_EDGE_RETURN ) records, _, _ = await driver.execute_query( @@ -452,7 +465,6 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge: name=record['name'], group_id=record['group_id'], episodes=record['episodes'], - fact_embedding=record['fact_embedding'], created_at=record['created_at'].to_native(), expired_at=parse_db_date(record['expired_at']), valid_at=parse_db_date(record['valid_at']), @@ -471,6 +483,8 @@ def get_community_edge_from_record(record: Any): async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]): + if len(edges) == 0: + return fact_embeddings = await embedder.create_batch([edge.fact for edge in edges]) for edge, fact_embedding in zip(edges, fact_embeddings, strict=True): edge.fact_embedding = fact_embedding diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 3e53584c..58ec41ce 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -42,7 +42,6 @@ ENTITY_NODE_RETURN: LiteralString = """ RETURN n.uuid As uuid, n.name AS name, - n.name_embedding AS name_embedding, n.group_id AS group_id, n.created_at AS created_at, n.summary AS summary, @@ -305,6 +304,20 @@ class EntityNode(Node): return self.name_embedding + async def load_name_embedding(self, driver: AsyncDriver): + query: LiteralString = """ + MATCH (n:Entity {uuid: $uuid}) + RETURN n.name_embedding AS name_embedding + """ + records, _, _ = await driver.execute_query( + query, uuid=self.uuid, database_=DEFAULT_DATABASE, routing_='r' + ) + + if len(records) == 0: + raise NodeNotFoundError(self.uuid) + + self.name_embedding = records[0]['name_embedding'] + async def save(self, driver: AsyncDriver): entity_data: dict[str, Any] = { 'uuid': self.uuid, @@ -332,8 +345,8 @@ class EntityNode(Node): async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): query = ( """ - MATCH (n:Entity {uuid: $uuid}) - """ + MATCH (n:Entity {uuid: $uuid}) + """ + ENTITY_NODE_RETURN ) records, _, _ = await driver.execute_query( @@ -428,6 +441,20 @@ class CommunityNode(Node): return self.name_embedding + async def load_name_embedding(self, driver: AsyncDriver): + query: LiteralString = """ + MATCH (c:Community {uuid: $uuid}) + RETURN c.name_embedding AS name_embedding + """ + records, _, _ = await driver.execute_query( + query, uuid=self.uuid, database_=DEFAULT_DATABASE, routing_='r' + ) + + if len(records) == 0: + raise NodeNotFoundError(self.uuid) + + self.name_embedding = records[0]['name_embedding'] + @classmethod async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): records, _, _ = await driver.execute_query( @@ -436,7 +463,6 @@ class CommunityNode(Node): RETURN n.uuid As uuid, n.name AS name, - n.name_embedding AS name_embedding, n.group_id AS group_id, n.created_at AS created_at, n.summary AS summary @@ -461,7 +487,6 @@ class CommunityNode(Node): RETURN n.uuid As uuid, n.name AS name, - n.name_embedding AS name_embedding, n.group_id AS group_id, n.created_at AS created_at, n.summary AS summary @@ -495,7 +520,6 @@ class CommunityNode(Node): RETURN n.uuid As uuid, n.name AS name, - n.name_embedding AS name_embedding, n.group_id AS group_id, n.created_at AS created_at, n.summary AS summary @@ -534,7 +558,6 @@ def get_entity_node_from_record(record: Any) -> EntityNode: uuid=record['uuid'], name=record['name'], group_id=record['group_id'], - name_embedding=record['name_embedding'], labels=record['labels'], created_at=record['created_at'].to_native(), summary=record['summary'], diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py index 579da576..7f8fc2d1 100644 --- a/graphiti_core/search/search.py +++ b/graphiti_core/search/search.py @@ -209,6 +209,9 @@ async def edge_search( reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score) elif config.reranker == EdgeReranker.mmr: + await semaphore_gather( + *[edge.load_fact_embedding(driver) for result in search_results for edge in result] + ) search_result_uuids_and_vectors = [ (edge.uuid, edge.fact_embedding if edge.fact_embedding is not None else [0.0] * 1024) for result in search_results @@ -308,6 +311,9 @@ async def node_search( if config.reranker == NodeReranker.rrf: reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score) elif config.reranker == NodeReranker.mmr: + await semaphore_gather( + *[node.load_name_embedding(driver) for result in search_results for node in result] + ) search_result_uuids_and_vectors = [ (node.uuid, node.name_embedding if node.name_embedding is not None else [0.0] * 1024) for result in search_results @@ -431,6 +437,13 @@ async def community_search( if config.reranker == CommunityReranker.rrf: reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score) elif config.reranker == CommunityReranker.mmr: + await semaphore_gather( + *[ + community.load_name_embedding(driver) + for result in search_results + for community in result + ] + ) search_result_uuids_and_vectors = [ ( community.uuid, diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 81a713ca..8619c6c8 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -101,7 +101,6 @@ async def get_mentioned_nodes( n.uuid As uuid, n.group_id AS group_id, n.name AS name, - n.name_embedding AS name_embedding, n.created_at AS created_at, n.summary AS summary, labels(n) AS labels, @@ -128,7 +127,6 @@ async def get_communities_by_nodes( c.uuid As uuid, c.group_id AS group_id, c.name AS name, - c.name_embedding AS name_embedding c.created_at AS created_at, c.summary AS summary """, @@ -172,7 +170,6 @@ async def edge_fulltext_search( r.created_at AS created_at, r.name AS name, r.fact AS fact, - r.fact_embedding AS fact_embedding, r.episodes AS episodes, r.expired_at AS expired_at, r.valid_at AS valid_at, @@ -242,7 +239,6 @@ async def edge_similarity_search( r.created_at AS created_at, r.name AS name, r.fact AS fact, - r.fact_embedding AS fact_embedding, r.episodes AS episodes, r.expired_at AS expired_at, r.valid_at AS valid_at, @@ -301,7 +297,6 @@ async def edge_bfs_search( r.created_at AS created_at, r.name AS name, r.fact AS fact, - r.fact_embedding AS fact_embedding, r.episodes AS episodes, r.expired_at AS expired_at, r.valid_at AS valid_at, @@ -341,10 +336,10 @@ async def node_fulltext_search( query = ( """ - CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit}) - YIELD node AS n, score - WHERE n:Entity - """ + CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit}) + YIELD node AS n, score + WHERE n:Entity + """ + filter_query + ENTITY_NODE_RETURN + """ @@ -510,7 +505,6 @@ async def community_fulltext_search( comm.uuid AS uuid, comm.group_id AS group_id, comm.name AS name, - comm.name_embedding AS name_embedding, comm.created_at AS created_at, comm.summary AS summary ORDER BY score DESC @@ -555,7 +549,6 @@ async def community_similarity_search( comm.uuid As uuid, comm.group_id AS group_id, comm.name AS name, - comm.name_embedding AS name_embedding, comm.created_at AS created_at, comm.summary AS summary ORDER BY score DESC diff --git a/graphiti_core/utils/maintenance/community_operations.py b/graphiti_core/utils/maintenance/community_operations.py index 15bab425..f35e4c39 100644 --- a/graphiti_core/utils/maintenance/community_operations.py +++ b/graphiti_core/utils/maintenance/community_operations.py @@ -239,7 +239,6 @@ async def determine_entity_community( RETURN c.uuid As uuid, c.name AS name, - c.name_embedding AS name_embedding, c.group_id AS group_id, c.created_at AS created_at, c.summary AS summary @@ -258,7 +257,6 @@ async def determine_entity_community( RETURN c.uuid As uuid, c.name AS name, - c.name_embedding AS name_embedding, c.group_id AS group_id, c.created_at AS created_at, c.summary AS summary diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py index daf2b0d1..2fd6e9b7 100644 --- a/graphiti_core/utils/maintenance/edge_operations.py +++ b/graphiti_core/utils/maintenance/edge_operations.py @@ -91,7 +91,6 @@ async def extract_edges( extract_edges_max_tokens = 16384 llm_client = clients.llm_client - embedder = clients.embedder node_uuids_by_name_map = {node.name: node.uuid for node in nodes} @@ -184,8 +183,6 @@ async def extract_edges( f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})' ) - await create_entity_edge_embeddings(embedder, edges) - logger.debug(f'Extracted edges: {[(e.name, e.uuid) for e in edges]}') return edges @@ -241,6 +238,9 @@ async def resolve_extracted_edges( ) -> tuple[list[EntityEdge], list[EntityEdge]]: driver = clients.driver llm_client = clients.llm_client + embedder = clients.embedder + + await create_entity_edge_embeddings(embedder, extracted_edges) search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather( get_relevant_edges(driver, extracted_edges, SearchFilters()), diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py index b9d181da..9c27523e 100644 --- a/graphiti_core/utils/maintenance/node_operations.py +++ b/graphiti_core/utils/maintenance/node_operations.py @@ -72,7 +72,6 @@ async def extract_nodes( ) -> list[EntityNode]: start = time() llm_client = clients.llm_client - embedder = clients.embedder llm_response = {} custom_prompt = '' entities_missed = True @@ -165,8 +164,6 @@ async def extract_nodes( extracted_nodes.append(new_node) logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})') - await create_entity_node_embeddings(embedder, extracted_nodes) - logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}') return extracted_nodes @@ -235,7 +232,6 @@ async def resolve_extracted_nodes( search( clients=clients, query=node.name, - query_vector=node.name_embedding, group_ids=[node.group_id], search_filter=SearchFilters(), config=NODE_HYBRID_SEARCH_RRF,