Size optimizations (#456)

* memory optimizations for vectors

* debugged

* unused import

* Update graphiti_core/edges.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

---------

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
Preston Rasmussen 2025-05-07 20:08:30 -04:00 committed by GitHub
parent 34b1cb5f58
commit 1f2f1eeab5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 68 additions and 31 deletions

View file

@ -46,7 +46,6 @@ ENTITY_EDGE_RETURN: LiteralString = """
e.name AS name,
e.group_id AS group_id,
e.fact AS fact,
e.fact_embedding AS fact_embedding,
e.episodes AS episodes,
e.expired_at AS expired_at,
e.valid_at AS valid_at,
@ -222,6 +221,20 @@ class EntityEdge(Edge):
return self.fact_embedding
async def load_fact_embedding(self, driver: AsyncDriver):
query: LiteralString = """
MATCH (n:Entity)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
RETURN e.fact_embedding AS fact_embedding
"""
records, _, _ = await driver.execute_query(
query, uuid=self.uuid, database_=DEFAULT_DATABASE, routing_='r'
)
if len(records) == 0:
raise EdgeNotFoundError(self.uuid)
self.fact_embedding = records[0]['fact_embedding']
async def save(self, driver: AsyncDriver):
result = await driver.execute_query(
ENTITY_EDGE_SAVE,
@ -321,8 +334,8 @@ class EntityEdge(Edge):
async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
query: LiteralString = (
"""
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
"""
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
"""
+ ENTITY_EDGE_RETURN
)
records, _, _ = await driver.execute_query(
@ -452,7 +465,6 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
name=record['name'],
group_id=record['group_id'],
episodes=record['episodes'],
fact_embedding=record['fact_embedding'],
created_at=record['created_at'].to_native(),
expired_at=parse_db_date(record['expired_at']),
valid_at=parse_db_date(record['valid_at']),
@ -471,6 +483,8 @@ def get_community_edge_from_record(record: Any):
async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
if len(edges) == 0:
return
fact_embeddings = await embedder.create_batch([edge.fact for edge in edges])
for edge, fact_embedding in zip(edges, fact_embeddings, strict=True):
edge.fact_embedding = fact_embedding

View file

@ -42,7 +42,6 @@ ENTITY_NODE_RETURN: LiteralString = """
RETURN
n.uuid As uuid,
n.name AS name,
n.name_embedding AS name_embedding,
n.group_id AS group_id,
n.created_at AS created_at,
n.summary AS summary,
@ -305,6 +304,20 @@ class EntityNode(Node):
return self.name_embedding
async def load_name_embedding(self, driver: AsyncDriver):
query: LiteralString = """
MATCH (n:Entity {uuid: $uuid})
RETURN n.name_embedding AS name_embedding
"""
records, _, _ = await driver.execute_query(
query, uuid=self.uuid, database_=DEFAULT_DATABASE, routing_='r'
)
if len(records) == 0:
raise NodeNotFoundError(self.uuid)
self.name_embedding = records[0]['name_embedding']
async def save(self, driver: AsyncDriver):
entity_data: dict[str, Any] = {
'uuid': self.uuid,
@ -332,8 +345,8 @@ class EntityNode(Node):
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
query = (
"""
MATCH (n:Entity {uuid: $uuid})
"""
MATCH (n:Entity {uuid: $uuid})
"""
+ ENTITY_NODE_RETURN
)
records, _, _ = await driver.execute_query(
@ -428,6 +441,20 @@ class CommunityNode(Node):
return self.name_embedding
async def load_name_embedding(self, driver: AsyncDriver):
query: LiteralString = """
MATCH (c:Community {uuid: $uuid})
RETURN c.name_embedding AS name_embedding
"""
records, _, _ = await driver.execute_query(
query, uuid=self.uuid, database_=DEFAULT_DATABASE, routing_='r'
)
if len(records) == 0:
raise NodeNotFoundError(self.uuid)
self.name_embedding = records[0]['name_embedding']
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
records, _, _ = await driver.execute_query(
@ -436,7 +463,6 @@ class CommunityNode(Node):
RETURN
n.uuid As uuid,
n.name AS name,
n.name_embedding AS name_embedding,
n.group_id AS group_id,
n.created_at AS created_at,
n.summary AS summary
@ -461,7 +487,6 @@ class CommunityNode(Node):
RETURN
n.uuid As uuid,
n.name AS name,
n.name_embedding AS name_embedding,
n.group_id AS group_id,
n.created_at AS created_at,
n.summary AS summary
@ -495,7 +520,6 @@ class CommunityNode(Node):
RETURN
n.uuid As uuid,
n.name AS name,
n.name_embedding AS name_embedding,
n.group_id AS group_id,
n.created_at AS created_at,
n.summary AS summary
@ -534,7 +558,6 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
uuid=record['uuid'],
name=record['name'],
group_id=record['group_id'],
name_embedding=record['name_embedding'],
labels=record['labels'],
created_at=record['created_at'].to_native(),
summary=record['summary'],

View file

@ -209,6 +209,9 @@ async def edge_search(
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
elif config.reranker == EdgeReranker.mmr:
await semaphore_gather(
*[edge.load_fact_embedding(driver) for result in search_results for edge in result]
)
search_result_uuids_and_vectors = [
(edge.uuid, edge.fact_embedding if edge.fact_embedding is not None else [0.0] * 1024)
for result in search_results
@ -308,6 +311,9 @@ async def node_search(
if config.reranker == NodeReranker.rrf:
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
elif config.reranker == NodeReranker.mmr:
await semaphore_gather(
*[node.load_name_embedding(driver) for result in search_results for node in result]
)
search_result_uuids_and_vectors = [
(node.uuid, node.name_embedding if node.name_embedding is not None else [0.0] * 1024)
for result in search_results
@ -431,6 +437,13 @@ async def community_search(
if config.reranker == CommunityReranker.rrf:
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
elif config.reranker == CommunityReranker.mmr:
await semaphore_gather(
*[
community.load_name_embedding(driver)
for result in search_results
for community in result
]
)
search_result_uuids_and_vectors = [
(
community.uuid,

View file

@ -101,7 +101,6 @@ async def get_mentioned_nodes(
n.uuid As uuid,
n.group_id AS group_id,
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
n.summary AS summary,
labels(n) AS labels,
@ -128,7 +127,6 @@ async def get_communities_by_nodes(
c.uuid As uuid,
c.group_id AS group_id,
c.name AS name,
c.name_embedding AS name_embedding
c.created_at AS created_at,
c.summary AS summary
""",
@ -172,7 +170,6 @@ async def edge_fulltext_search(
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
@ -242,7 +239,6 @@ async def edge_similarity_search(
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
@ -301,7 +297,6 @@ async def edge_bfs_search(
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
@ -341,10 +336,10 @@ async def node_fulltext_search(
query = (
"""
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
YIELD node AS n, score
WHERE n:Entity
"""
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
YIELD node AS n, score
WHERE n:Entity
"""
+ filter_query
+ ENTITY_NODE_RETURN
+ """
@ -510,7 +505,6 @@ async def community_fulltext_search(
comm.uuid AS uuid,
comm.group_id AS group_id,
comm.name AS name,
comm.name_embedding AS name_embedding,
comm.created_at AS created_at,
comm.summary AS summary
ORDER BY score DESC
@ -555,7 +549,6 @@ async def community_similarity_search(
comm.uuid As uuid,
comm.group_id AS group_id,
comm.name AS name,
comm.name_embedding AS name_embedding,
comm.created_at AS created_at,
comm.summary AS summary
ORDER BY score DESC

View file

@ -239,7 +239,6 @@ async def determine_entity_community(
RETURN
c.uuid As uuid,
c.name AS name,
c.name_embedding AS name_embedding,
c.group_id AS group_id,
c.created_at AS created_at,
c.summary AS summary
@ -258,7 +257,6 @@ async def determine_entity_community(
RETURN
c.uuid As uuid,
c.name AS name,
c.name_embedding AS name_embedding,
c.group_id AS group_id,
c.created_at AS created_at,
c.summary AS summary

View file

@ -91,7 +91,6 @@ async def extract_edges(
extract_edges_max_tokens = 16384
llm_client = clients.llm_client
embedder = clients.embedder
node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
@ -184,8 +183,6 @@ async def extract_edges(
f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})'
)
await create_entity_edge_embeddings(embedder, edges)
logger.debug(f'Extracted edges: {[(e.name, e.uuid) for e in edges]}')
return edges
@ -241,6 +238,9 @@ async def resolve_extracted_edges(
) -> tuple[list[EntityEdge], list[EntityEdge]]:
driver = clients.driver
llm_client = clients.llm_client
embedder = clients.embedder
await create_entity_edge_embeddings(embedder, extracted_edges)
search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather(
get_relevant_edges(driver, extracted_edges, SearchFilters()),

View file

@ -72,7 +72,6 @@ async def extract_nodes(
) -> list[EntityNode]:
start = time()
llm_client = clients.llm_client
embedder = clients.embedder
llm_response = {}
custom_prompt = ''
entities_missed = True
@ -165,8 +164,6 @@ async def extract_nodes(
extracted_nodes.append(new_node)
logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
await create_entity_node_embeddings(embedder, extracted_nodes)
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
return extracted_nodes
@ -235,7 +232,6 @@ async def resolve_extracted_nodes(
search(
clients=clients,
query=node.name,
query_vector=node.name_embedding,
group_ids=[node.group_id],
search_filter=SearchFilters(),
config=NODE_HYBRID_SEARCH_RRF,