Size optimizations (#456)
* memory optimizations for vectors
* debugged
* unused import
* Update graphiti_core/edges.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

---------

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
parent
34b1cb5f58
commit
1f2f1eeab5
7 changed files with 68 additions and 31 deletions
|
|
@ -46,7 +46,6 @@ ENTITY_EDGE_RETURN: LiteralString = """
|
|||
e.name AS name,
|
||||
e.group_id AS group_id,
|
||||
e.fact AS fact,
|
||||
e.fact_embedding AS fact_embedding,
|
||||
e.episodes AS episodes,
|
||||
e.expired_at AS expired_at,
|
||||
e.valid_at AS valid_at,
|
||||
|
|
@ -222,6 +221,20 @@ class EntityEdge(Edge):
|
|||
|
||||
return self.fact_embedding
|
||||
|
||||
async def load_fact_embedding(self, driver: AsyncDriver):
    """Fetch this edge's ``fact_embedding`` from the database and store it on the instance.

    Args:
        driver: Async driver used to run the lookup query.

    Raises:
        EdgeNotFoundError: If the query returns no record for ``self.uuid``.
    """
    # Use RELATES_TO, the relationship type this class's other Entity-to-Entity
    # queries match on (see get_by_node_uuid), so the lookup targets the same edges.
    query: LiteralString = """
        MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
        RETURN e.fact_embedding AS fact_embedding
    """
    records, _, _ = await driver.execute_query(
        query, uuid=self.uuid, database_=DEFAULT_DATABASE, routing_='r'
    )

    if len(records) == 0:
        raise EdgeNotFoundError(self.uuid)

    # Only the first matching record is used.
    self.fact_embedding = records[0]['fact_embedding']
|
||||
|
||||
async def save(self, driver: AsyncDriver):
|
||||
result = await driver.execute_query(
|
||||
ENTITY_EDGE_SAVE,
|
||||
|
|
@ -321,8 +334,8 @@ class EntityEdge(Edge):
|
|||
async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
|
||||
query: LiteralString = (
|
||||
"""
|
||||
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
||||
"""
|
||||
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
||||
"""
|
||||
+ ENTITY_EDGE_RETURN
|
||||
)
|
||||
records, _, _ = await driver.execute_query(
|
||||
|
|
@ -452,7 +465,6 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
|
|||
name=record['name'],
|
||||
group_id=record['group_id'],
|
||||
episodes=record['episodes'],
|
||||
fact_embedding=record['fact_embedding'],
|
||||
created_at=record['created_at'].to_native(),
|
||||
expired_at=parse_db_date(record['expired_at']),
|
||||
valid_at=parse_db_date(record['valid_at']),
|
||||
|
|
@ -471,6 +483,8 @@ def get_community_edge_from_record(record: Any):
|
|||
|
||||
|
||||
async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
    """Embed the ``fact`` of every edge in one batch call and attach each vector to its edge.

    Returns immediately, without calling the embedder, when ``edges`` is empty.
    """
    if not edges:
        return

    facts = [item.fact for item in edges]
    vectors = await embedder.create_batch(facts)

    # strict=True raises if the number of returned vectors differs from the number of edges.
    for item, vector in zip(edges, vectors, strict=True):
        item.fact_embedding = vector
|
||||
|
|
|
|||
|
|
@ -42,7 +42,6 @@ ENTITY_NODE_RETURN: LiteralString = """
|
|||
RETURN
|
||||
n.uuid As uuid,
|
||||
n.name AS name,
|
||||
n.name_embedding AS name_embedding,
|
||||
n.group_id AS group_id,
|
||||
n.created_at AS created_at,
|
||||
n.summary AS summary,
|
||||
|
|
@ -305,6 +304,20 @@ class EntityNode(Node):
|
|||
|
||||
return self.name_embedding
|
||||
|
||||
async def load_name_embedding(self, driver: AsyncDriver):
    """Fetch this entity node's ``name_embedding`` from the database and store it on the instance.

    Raises:
        NodeNotFoundError: If no record is returned for ``self.uuid``.
    """
    cypher: LiteralString = """
        MATCH (n:Entity {uuid: $uuid})
        RETURN n.name_embedding AS name_embedding
    """
    rows, _, _ = await driver.execute_query(
        cypher,
        uuid=self.uuid,
        database_=DEFAULT_DATABASE,
        routing_='r',
    )

    if not rows:
        raise NodeNotFoundError(self.uuid)

    # Only the first returned record is consulted.
    first = rows[0]
    self.name_embedding = first['name_embedding']
|
||||
|
||||
async def save(self, driver: AsyncDriver):
|
||||
entity_data: dict[str, Any] = {
|
||||
'uuid': self.uuid,
|
||||
|
|
@ -332,8 +345,8 @@ class EntityNode(Node):
|
|||
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Entity {uuid: $uuid})
|
||||
"""
|
||||
MATCH (n:Entity {uuid: $uuid})
|
||||
"""
|
||||
+ ENTITY_NODE_RETURN
|
||||
)
|
||||
records, _, _ = await driver.execute_query(
|
||||
|
|
@ -428,6 +441,20 @@ class CommunityNode(Node):
|
|||
|
||||
return self.name_embedding
|
||||
|
||||
async def load_name_embedding(self, driver: AsyncDriver):
    """Fetch this community node's ``name_embedding`` from the database and store it on the instance.

    Raises:
        NodeNotFoundError: If no record is returned for ``self.uuid``.
    """
    cypher: LiteralString = """
        MATCH (c:Community {uuid: $uuid})
        RETURN c.name_embedding AS name_embedding
    """
    rows, _, _ = await driver.execute_query(
        cypher,
        uuid=self.uuid,
        database_=DEFAULT_DATABASE,
        routing_='r',
    )

    if not rows:
        raise NodeNotFoundError(self.uuid)

    # Only the first returned record is consulted.
    first = rows[0]
    self.name_embedding = first['name_embedding']
|
||||
|
||||
@classmethod
|
||||
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
||||
records, _, _ = await driver.execute_query(
|
||||
|
|
@ -436,7 +463,6 @@ class CommunityNode(Node):
|
|||
RETURN
|
||||
n.uuid As uuid,
|
||||
n.name AS name,
|
||||
n.name_embedding AS name_embedding,
|
||||
n.group_id AS group_id,
|
||||
n.created_at AS created_at,
|
||||
n.summary AS summary
|
||||
|
|
@ -461,7 +487,6 @@ class CommunityNode(Node):
|
|||
RETURN
|
||||
n.uuid As uuid,
|
||||
n.name AS name,
|
||||
n.name_embedding AS name_embedding,
|
||||
n.group_id AS group_id,
|
||||
n.created_at AS created_at,
|
||||
n.summary AS summary
|
||||
|
|
@ -495,7 +520,6 @@ class CommunityNode(Node):
|
|||
RETURN
|
||||
n.uuid As uuid,
|
||||
n.name AS name,
|
||||
n.name_embedding AS name_embedding,
|
||||
n.group_id AS group_id,
|
||||
n.created_at AS created_at,
|
||||
n.summary AS summary
|
||||
|
|
@ -534,7 +558,6 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
|
|||
uuid=record['uuid'],
|
||||
name=record['name'],
|
||||
group_id=record['group_id'],
|
||||
name_embedding=record['name_embedding'],
|
||||
labels=record['labels'],
|
||||
created_at=record['created_at'].to_native(),
|
||||
summary=record['summary'],
|
||||
|
|
|
|||
|
|
@ -209,6 +209,9 @@ async def edge_search(
|
|||
|
||||
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
||||
elif config.reranker == EdgeReranker.mmr:
|
||||
await semaphore_gather(
|
||||
*[edge.load_fact_embedding(driver) for result in search_results for edge in result]
|
||||
)
|
||||
search_result_uuids_and_vectors = [
|
||||
(edge.uuid, edge.fact_embedding if edge.fact_embedding is not None else [0.0] * 1024)
|
||||
for result in search_results
|
||||
|
|
@ -308,6 +311,9 @@ async def node_search(
|
|||
if config.reranker == NodeReranker.rrf:
|
||||
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
||||
elif config.reranker == NodeReranker.mmr:
|
||||
await semaphore_gather(
|
||||
*[node.load_name_embedding(driver) for result in search_results for node in result]
|
||||
)
|
||||
search_result_uuids_and_vectors = [
|
||||
(node.uuid, node.name_embedding if node.name_embedding is not None else [0.0] * 1024)
|
||||
for result in search_results
|
||||
|
|
@ -431,6 +437,13 @@ async def community_search(
|
|||
if config.reranker == CommunityReranker.rrf:
|
||||
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
||||
elif config.reranker == CommunityReranker.mmr:
|
||||
await semaphore_gather(
|
||||
*[
|
||||
community.load_name_embedding(driver)
|
||||
for result in search_results
|
||||
for community in result
|
||||
]
|
||||
)
|
||||
search_result_uuids_and_vectors = [
|
||||
(
|
||||
community.uuid,
|
||||
|
|
|
|||
|
|
@ -101,7 +101,6 @@ async def get_mentioned_nodes(
|
|||
n.uuid As uuid,
|
||||
n.group_id AS group_id,
|
||||
n.name AS name,
|
||||
n.name_embedding AS name_embedding,
|
||||
n.created_at AS created_at,
|
||||
n.summary AS summary,
|
||||
labels(n) AS labels,
|
||||
|
|
@ -128,7 +127,6 @@ async def get_communities_by_nodes(
|
|||
c.uuid As uuid,
|
||||
c.group_id AS group_id,
|
||||
c.name AS name,
|
||||
c.name_embedding AS name_embedding
|
||||
c.created_at AS created_at,
|
||||
c.summary AS summary
|
||||
""",
|
||||
|
|
@ -172,7 +170,6 @@ async def edge_fulltext_search(
|
|||
r.created_at AS created_at,
|
||||
r.name AS name,
|
||||
r.fact AS fact,
|
||||
r.fact_embedding AS fact_embedding,
|
||||
r.episodes AS episodes,
|
||||
r.expired_at AS expired_at,
|
||||
r.valid_at AS valid_at,
|
||||
|
|
@ -242,7 +239,6 @@ async def edge_similarity_search(
|
|||
r.created_at AS created_at,
|
||||
r.name AS name,
|
||||
r.fact AS fact,
|
||||
r.fact_embedding AS fact_embedding,
|
||||
r.episodes AS episodes,
|
||||
r.expired_at AS expired_at,
|
||||
r.valid_at AS valid_at,
|
||||
|
|
@ -301,7 +297,6 @@ async def edge_bfs_search(
|
|||
r.created_at AS created_at,
|
||||
r.name AS name,
|
||||
r.fact AS fact,
|
||||
r.fact_embedding AS fact_embedding,
|
||||
r.episodes AS episodes,
|
||||
r.expired_at AS expired_at,
|
||||
r.valid_at AS valid_at,
|
||||
|
|
@ -341,10 +336,10 @@ async def node_fulltext_search(
|
|||
|
||||
query = (
|
||||
"""
|
||||
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
|
||||
YIELD node AS n, score
|
||||
WHERE n:Entity
|
||||
"""
|
||||
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
|
||||
YIELD node AS n, score
|
||||
WHERE n:Entity
|
||||
"""
|
||||
+ filter_query
|
||||
+ ENTITY_NODE_RETURN
|
||||
+ """
|
||||
|
|
@ -510,7 +505,6 @@ async def community_fulltext_search(
|
|||
comm.uuid AS uuid,
|
||||
comm.group_id AS group_id,
|
||||
comm.name AS name,
|
||||
comm.name_embedding AS name_embedding,
|
||||
comm.created_at AS created_at,
|
||||
comm.summary AS summary
|
||||
ORDER BY score DESC
|
||||
|
|
@ -555,7 +549,6 @@ async def community_similarity_search(
|
|||
comm.uuid As uuid,
|
||||
comm.group_id AS group_id,
|
||||
comm.name AS name,
|
||||
comm.name_embedding AS name_embedding,
|
||||
comm.created_at AS created_at,
|
||||
comm.summary AS summary
|
||||
ORDER BY score DESC
|
||||
|
|
|
|||
|
|
@ -239,7 +239,6 @@ async def determine_entity_community(
|
|||
RETURN
|
||||
c.uuid As uuid,
|
||||
c.name AS name,
|
||||
c.name_embedding AS name_embedding,
|
||||
c.group_id AS group_id,
|
||||
c.created_at AS created_at,
|
||||
c.summary AS summary
|
||||
|
|
@ -258,7 +257,6 @@ async def determine_entity_community(
|
|||
RETURN
|
||||
c.uuid As uuid,
|
||||
c.name AS name,
|
||||
c.name_embedding AS name_embedding,
|
||||
c.group_id AS group_id,
|
||||
c.created_at AS created_at,
|
||||
c.summary AS summary
|
||||
|
|
|
|||
|
|
@ -91,7 +91,6 @@ async def extract_edges(
|
|||
|
||||
extract_edges_max_tokens = 16384
|
||||
llm_client = clients.llm_client
|
||||
embedder = clients.embedder
|
||||
|
||||
node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
|
||||
|
||||
|
|
@ -184,8 +183,6 @@ async def extract_edges(
|
|||
f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})'
|
||||
)
|
||||
|
||||
await create_entity_edge_embeddings(embedder, edges)
|
||||
|
||||
logger.debug(f'Extracted edges: {[(e.name, e.uuid) for e in edges]}')
|
||||
|
||||
return edges
|
||||
|
|
@ -241,6 +238,9 @@ async def resolve_extracted_edges(
|
|||
) -> tuple[list[EntityEdge], list[EntityEdge]]:
|
||||
driver = clients.driver
|
||||
llm_client = clients.llm_client
|
||||
embedder = clients.embedder
|
||||
|
||||
await create_entity_edge_embeddings(embedder, extracted_edges)
|
||||
|
||||
search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather(
|
||||
get_relevant_edges(driver, extracted_edges, SearchFilters()),
|
||||
|
|
|
|||
|
|
@ -72,7 +72,6 @@ async def extract_nodes(
|
|||
) -> list[EntityNode]:
|
||||
start = time()
|
||||
llm_client = clients.llm_client
|
||||
embedder = clients.embedder
|
||||
llm_response = {}
|
||||
custom_prompt = ''
|
||||
entities_missed = True
|
||||
|
|
@ -165,8 +164,6 @@ async def extract_nodes(
|
|||
extracted_nodes.append(new_node)
|
||||
logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
|
||||
|
||||
await create_entity_node_embeddings(embedder, extracted_nodes)
|
||||
|
||||
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
||||
return extracted_nodes
|
||||
|
||||
|
|
@ -235,7 +232,6 @@ async def resolve_extracted_nodes(
|
|||
search(
|
||||
clients=clients,
|
||||
query=node.name,
|
||||
query_vector=node.name_embedding,
|
||||
group_ids=[node.group_id],
|
||||
search_filter=SearchFilters(),
|
||||
config=NODE_HYBRID_SEARCH_RRF,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue