Size optimizations (#456)
* memory optimizations for vectors * debugged * unused import * Update graphiti_core/edges.py Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --------- Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
parent
34b1cb5f58
commit
1f2f1eeab5
7 changed files with 68 additions and 31 deletions
|
|
@ -46,7 +46,6 @@ ENTITY_EDGE_RETURN: LiteralString = """
|
||||||
e.name AS name,
|
e.name AS name,
|
||||||
e.group_id AS group_id,
|
e.group_id AS group_id,
|
||||||
e.fact AS fact,
|
e.fact AS fact,
|
||||||
e.fact_embedding AS fact_embedding,
|
|
||||||
e.episodes AS episodes,
|
e.episodes AS episodes,
|
||||||
e.expired_at AS expired_at,
|
e.expired_at AS expired_at,
|
||||||
e.valid_at AS valid_at,
|
e.valid_at AS valid_at,
|
||||||
|
|
@ -222,6 +221,20 @@ class EntityEdge(Edge):
|
||||||
|
|
||||||
return self.fact_embedding
|
return self.fact_embedding
|
||||||
|
|
||||||
|
async def load_fact_embedding(self, driver: AsyncDriver):
|
||||||
|
query: LiteralString = """
|
||||||
|
MATCH (n:Entity)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
|
||||||
|
RETURN e.fact_embedding AS fact_embedding
|
||||||
|
"""
|
||||||
|
records, _, _ = await driver.execute_query(
|
||||||
|
query, uuid=self.uuid, database_=DEFAULT_DATABASE, routing_='r'
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(records) == 0:
|
||||||
|
raise EdgeNotFoundError(self.uuid)
|
||||||
|
|
||||||
|
self.fact_embedding = records[0]['fact_embedding']
|
||||||
|
|
||||||
async def save(self, driver: AsyncDriver):
|
async def save(self, driver: AsyncDriver):
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
ENTITY_EDGE_SAVE,
|
ENTITY_EDGE_SAVE,
|
||||||
|
|
@ -321,8 +334,8 @@ class EntityEdge(Edge):
|
||||||
async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
|
async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
|
||||||
query: LiteralString = (
|
query: LiteralString = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
||||||
"""
|
"""
|
||||||
+ ENTITY_EDGE_RETURN
|
+ ENTITY_EDGE_RETURN
|
||||||
)
|
)
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
|
|
@ -452,7 +465,6 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
|
||||||
name=record['name'],
|
name=record['name'],
|
||||||
group_id=record['group_id'],
|
group_id=record['group_id'],
|
||||||
episodes=record['episodes'],
|
episodes=record['episodes'],
|
||||||
fact_embedding=record['fact_embedding'],
|
|
||||||
created_at=record['created_at'].to_native(),
|
created_at=record['created_at'].to_native(),
|
||||||
expired_at=parse_db_date(record['expired_at']),
|
expired_at=parse_db_date(record['expired_at']),
|
||||||
valid_at=parse_db_date(record['valid_at']),
|
valid_at=parse_db_date(record['valid_at']),
|
||||||
|
|
@ -471,6 +483,8 @@ def get_community_edge_from_record(record: Any):
|
||||||
|
|
||||||
|
|
||||||
async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
|
async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
|
||||||
|
if len(edges) == 0:
|
||||||
|
return
|
||||||
fact_embeddings = await embedder.create_batch([edge.fact for edge in edges])
|
fact_embeddings = await embedder.create_batch([edge.fact for edge in edges])
|
||||||
for edge, fact_embedding in zip(edges, fact_embeddings, strict=True):
|
for edge, fact_embedding in zip(edges, fact_embeddings, strict=True):
|
||||||
edge.fact_embedding = fact_embedding
|
edge.fact_embedding = fact_embedding
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,6 @@ ENTITY_NODE_RETURN: LiteralString = """
|
||||||
RETURN
|
RETURN
|
||||||
n.uuid As uuid,
|
n.uuid As uuid,
|
||||||
n.name AS name,
|
n.name AS name,
|
||||||
n.name_embedding AS name_embedding,
|
|
||||||
n.group_id AS group_id,
|
n.group_id AS group_id,
|
||||||
n.created_at AS created_at,
|
n.created_at AS created_at,
|
||||||
n.summary AS summary,
|
n.summary AS summary,
|
||||||
|
|
@ -305,6 +304,20 @@ class EntityNode(Node):
|
||||||
|
|
||||||
return self.name_embedding
|
return self.name_embedding
|
||||||
|
|
||||||
|
async def load_name_embedding(self, driver: AsyncDriver):
|
||||||
|
query: LiteralString = """
|
||||||
|
MATCH (n:Entity {uuid: $uuid})
|
||||||
|
RETURN n.name_embedding AS name_embedding
|
||||||
|
"""
|
||||||
|
records, _, _ = await driver.execute_query(
|
||||||
|
query, uuid=self.uuid, database_=DEFAULT_DATABASE, routing_='r'
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(records) == 0:
|
||||||
|
raise NodeNotFoundError(self.uuid)
|
||||||
|
|
||||||
|
self.name_embedding = records[0]['name_embedding']
|
||||||
|
|
||||||
async def save(self, driver: AsyncDriver):
|
async def save(self, driver: AsyncDriver):
|
||||||
entity_data: dict[str, Any] = {
|
entity_data: dict[str, Any] = {
|
||||||
'uuid': self.uuid,
|
'uuid': self.uuid,
|
||||||
|
|
@ -332,8 +345,8 @@ class EntityNode(Node):
|
||||||
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity {uuid: $uuid})
|
MATCH (n:Entity {uuid: $uuid})
|
||||||
"""
|
"""
|
||||||
+ ENTITY_NODE_RETURN
|
+ ENTITY_NODE_RETURN
|
||||||
)
|
)
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
|
|
@ -428,6 +441,20 @@ class CommunityNode(Node):
|
||||||
|
|
||||||
return self.name_embedding
|
return self.name_embedding
|
||||||
|
|
||||||
|
async def load_name_embedding(self, driver: AsyncDriver):
|
||||||
|
query: LiteralString = """
|
||||||
|
MATCH (c:Community {uuid: $uuid})
|
||||||
|
RETURN c.name_embedding AS name_embedding
|
||||||
|
"""
|
||||||
|
records, _, _ = await driver.execute_query(
|
||||||
|
query, uuid=self.uuid, database_=DEFAULT_DATABASE, routing_='r'
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(records) == 0:
|
||||||
|
raise NodeNotFoundError(self.uuid)
|
||||||
|
|
||||||
|
self.name_embedding = records[0]['name_embedding']
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
|
|
@ -436,7 +463,6 @@ class CommunityNode(Node):
|
||||||
RETURN
|
RETURN
|
||||||
n.uuid As uuid,
|
n.uuid As uuid,
|
||||||
n.name AS name,
|
n.name AS name,
|
||||||
n.name_embedding AS name_embedding,
|
|
||||||
n.group_id AS group_id,
|
n.group_id AS group_id,
|
||||||
n.created_at AS created_at,
|
n.created_at AS created_at,
|
||||||
n.summary AS summary
|
n.summary AS summary
|
||||||
|
|
@ -461,7 +487,6 @@ class CommunityNode(Node):
|
||||||
RETURN
|
RETURN
|
||||||
n.uuid As uuid,
|
n.uuid As uuid,
|
||||||
n.name AS name,
|
n.name AS name,
|
||||||
n.name_embedding AS name_embedding,
|
|
||||||
n.group_id AS group_id,
|
n.group_id AS group_id,
|
||||||
n.created_at AS created_at,
|
n.created_at AS created_at,
|
||||||
n.summary AS summary
|
n.summary AS summary
|
||||||
|
|
@ -495,7 +520,6 @@ class CommunityNode(Node):
|
||||||
RETURN
|
RETURN
|
||||||
n.uuid As uuid,
|
n.uuid As uuid,
|
||||||
n.name AS name,
|
n.name AS name,
|
||||||
n.name_embedding AS name_embedding,
|
|
||||||
n.group_id AS group_id,
|
n.group_id AS group_id,
|
||||||
n.created_at AS created_at,
|
n.created_at AS created_at,
|
||||||
n.summary AS summary
|
n.summary AS summary
|
||||||
|
|
@ -534,7 +558,6 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
|
||||||
uuid=record['uuid'],
|
uuid=record['uuid'],
|
||||||
name=record['name'],
|
name=record['name'],
|
||||||
group_id=record['group_id'],
|
group_id=record['group_id'],
|
||||||
name_embedding=record['name_embedding'],
|
|
||||||
labels=record['labels'],
|
labels=record['labels'],
|
||||||
created_at=record['created_at'].to_native(),
|
created_at=record['created_at'].to_native(),
|
||||||
summary=record['summary'],
|
summary=record['summary'],
|
||||||
|
|
|
||||||
|
|
@ -209,6 +209,9 @@ async def edge_search(
|
||||||
|
|
||||||
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
||||||
elif config.reranker == EdgeReranker.mmr:
|
elif config.reranker == EdgeReranker.mmr:
|
||||||
|
await semaphore_gather(
|
||||||
|
*[edge.load_fact_embedding(driver) for result in search_results for edge in result]
|
||||||
|
)
|
||||||
search_result_uuids_and_vectors = [
|
search_result_uuids_and_vectors = [
|
||||||
(edge.uuid, edge.fact_embedding if edge.fact_embedding is not None else [0.0] * 1024)
|
(edge.uuid, edge.fact_embedding if edge.fact_embedding is not None else [0.0] * 1024)
|
||||||
for result in search_results
|
for result in search_results
|
||||||
|
|
@ -308,6 +311,9 @@ async def node_search(
|
||||||
if config.reranker == NodeReranker.rrf:
|
if config.reranker == NodeReranker.rrf:
|
||||||
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
||||||
elif config.reranker == NodeReranker.mmr:
|
elif config.reranker == NodeReranker.mmr:
|
||||||
|
await semaphore_gather(
|
||||||
|
*[node.load_name_embedding(driver) for result in search_results for node in result]
|
||||||
|
)
|
||||||
search_result_uuids_and_vectors = [
|
search_result_uuids_and_vectors = [
|
||||||
(node.uuid, node.name_embedding if node.name_embedding is not None else [0.0] * 1024)
|
(node.uuid, node.name_embedding if node.name_embedding is not None else [0.0] * 1024)
|
||||||
for result in search_results
|
for result in search_results
|
||||||
|
|
@ -431,6 +437,13 @@ async def community_search(
|
||||||
if config.reranker == CommunityReranker.rrf:
|
if config.reranker == CommunityReranker.rrf:
|
||||||
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
||||||
elif config.reranker == CommunityReranker.mmr:
|
elif config.reranker == CommunityReranker.mmr:
|
||||||
|
await semaphore_gather(
|
||||||
|
*[
|
||||||
|
community.load_name_embedding(driver)
|
||||||
|
for result in search_results
|
||||||
|
for community in result
|
||||||
|
]
|
||||||
|
)
|
||||||
search_result_uuids_and_vectors = [
|
search_result_uuids_and_vectors = [
|
||||||
(
|
(
|
||||||
community.uuid,
|
community.uuid,
|
||||||
|
|
|
||||||
|
|
@ -101,7 +101,6 @@ async def get_mentioned_nodes(
|
||||||
n.uuid As uuid,
|
n.uuid As uuid,
|
||||||
n.group_id AS group_id,
|
n.group_id AS group_id,
|
||||||
n.name AS name,
|
n.name AS name,
|
||||||
n.name_embedding AS name_embedding,
|
|
||||||
n.created_at AS created_at,
|
n.created_at AS created_at,
|
||||||
n.summary AS summary,
|
n.summary AS summary,
|
||||||
labels(n) AS labels,
|
labels(n) AS labels,
|
||||||
|
|
@ -128,7 +127,6 @@ async def get_communities_by_nodes(
|
||||||
c.uuid As uuid,
|
c.uuid As uuid,
|
||||||
c.group_id AS group_id,
|
c.group_id AS group_id,
|
||||||
c.name AS name,
|
c.name AS name,
|
||||||
c.name_embedding AS name_embedding
|
|
||||||
c.created_at AS created_at,
|
c.created_at AS created_at,
|
||||||
c.summary AS summary
|
c.summary AS summary
|
||||||
""",
|
""",
|
||||||
|
|
@ -172,7 +170,6 @@ async def edge_fulltext_search(
|
||||||
r.created_at AS created_at,
|
r.created_at AS created_at,
|
||||||
r.name AS name,
|
r.name AS name,
|
||||||
r.fact AS fact,
|
r.fact AS fact,
|
||||||
r.fact_embedding AS fact_embedding,
|
|
||||||
r.episodes AS episodes,
|
r.episodes AS episodes,
|
||||||
r.expired_at AS expired_at,
|
r.expired_at AS expired_at,
|
||||||
r.valid_at AS valid_at,
|
r.valid_at AS valid_at,
|
||||||
|
|
@ -242,7 +239,6 @@ async def edge_similarity_search(
|
||||||
r.created_at AS created_at,
|
r.created_at AS created_at,
|
||||||
r.name AS name,
|
r.name AS name,
|
||||||
r.fact AS fact,
|
r.fact AS fact,
|
||||||
r.fact_embedding AS fact_embedding,
|
|
||||||
r.episodes AS episodes,
|
r.episodes AS episodes,
|
||||||
r.expired_at AS expired_at,
|
r.expired_at AS expired_at,
|
||||||
r.valid_at AS valid_at,
|
r.valid_at AS valid_at,
|
||||||
|
|
@ -301,7 +297,6 @@ async def edge_bfs_search(
|
||||||
r.created_at AS created_at,
|
r.created_at AS created_at,
|
||||||
r.name AS name,
|
r.name AS name,
|
||||||
r.fact AS fact,
|
r.fact AS fact,
|
||||||
r.fact_embedding AS fact_embedding,
|
|
||||||
r.episodes AS episodes,
|
r.episodes AS episodes,
|
||||||
r.expired_at AS expired_at,
|
r.expired_at AS expired_at,
|
||||||
r.valid_at AS valid_at,
|
r.valid_at AS valid_at,
|
||||||
|
|
@ -341,10 +336,10 @@ async def node_fulltext_search(
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
|
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
|
||||||
YIELD node AS n, score
|
YIELD node AS n, score
|
||||||
WHERE n:Entity
|
WHERE n:Entity
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ ENTITY_NODE_RETURN
|
+ ENTITY_NODE_RETURN
|
||||||
+ """
|
+ """
|
||||||
|
|
@ -510,7 +505,6 @@ async def community_fulltext_search(
|
||||||
comm.uuid AS uuid,
|
comm.uuid AS uuid,
|
||||||
comm.group_id AS group_id,
|
comm.group_id AS group_id,
|
||||||
comm.name AS name,
|
comm.name AS name,
|
||||||
comm.name_embedding AS name_embedding,
|
|
||||||
comm.created_at AS created_at,
|
comm.created_at AS created_at,
|
||||||
comm.summary AS summary
|
comm.summary AS summary
|
||||||
ORDER BY score DESC
|
ORDER BY score DESC
|
||||||
|
|
@ -555,7 +549,6 @@ async def community_similarity_search(
|
||||||
comm.uuid As uuid,
|
comm.uuid As uuid,
|
||||||
comm.group_id AS group_id,
|
comm.group_id AS group_id,
|
||||||
comm.name AS name,
|
comm.name AS name,
|
||||||
comm.name_embedding AS name_embedding,
|
|
||||||
comm.created_at AS created_at,
|
comm.created_at AS created_at,
|
||||||
comm.summary AS summary
|
comm.summary AS summary
|
||||||
ORDER BY score DESC
|
ORDER BY score DESC
|
||||||
|
|
|
||||||
|
|
@ -239,7 +239,6 @@ async def determine_entity_community(
|
||||||
RETURN
|
RETURN
|
||||||
c.uuid As uuid,
|
c.uuid As uuid,
|
||||||
c.name AS name,
|
c.name AS name,
|
||||||
c.name_embedding AS name_embedding,
|
|
||||||
c.group_id AS group_id,
|
c.group_id AS group_id,
|
||||||
c.created_at AS created_at,
|
c.created_at AS created_at,
|
||||||
c.summary AS summary
|
c.summary AS summary
|
||||||
|
|
@ -258,7 +257,6 @@ async def determine_entity_community(
|
||||||
RETURN
|
RETURN
|
||||||
c.uuid As uuid,
|
c.uuid As uuid,
|
||||||
c.name AS name,
|
c.name AS name,
|
||||||
c.name_embedding AS name_embedding,
|
|
||||||
c.group_id AS group_id,
|
c.group_id AS group_id,
|
||||||
c.created_at AS created_at,
|
c.created_at AS created_at,
|
||||||
c.summary AS summary
|
c.summary AS summary
|
||||||
|
|
|
||||||
|
|
@ -91,7 +91,6 @@ async def extract_edges(
|
||||||
|
|
||||||
extract_edges_max_tokens = 16384
|
extract_edges_max_tokens = 16384
|
||||||
llm_client = clients.llm_client
|
llm_client = clients.llm_client
|
||||||
embedder = clients.embedder
|
|
||||||
|
|
||||||
node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
|
node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
|
||||||
|
|
||||||
|
|
@ -184,8 +183,6 @@ async def extract_edges(
|
||||||
f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})'
|
f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})'
|
||||||
)
|
)
|
||||||
|
|
||||||
await create_entity_edge_embeddings(embedder, edges)
|
|
||||||
|
|
||||||
logger.debug(f'Extracted edges: {[(e.name, e.uuid) for e in edges]}')
|
logger.debug(f'Extracted edges: {[(e.name, e.uuid) for e in edges]}')
|
||||||
|
|
||||||
return edges
|
return edges
|
||||||
|
|
@ -241,6 +238,9 @@ async def resolve_extracted_edges(
|
||||||
) -> tuple[list[EntityEdge], list[EntityEdge]]:
|
) -> tuple[list[EntityEdge], list[EntityEdge]]:
|
||||||
driver = clients.driver
|
driver = clients.driver
|
||||||
llm_client = clients.llm_client
|
llm_client = clients.llm_client
|
||||||
|
embedder = clients.embedder
|
||||||
|
|
||||||
|
await create_entity_edge_embeddings(embedder, extracted_edges)
|
||||||
|
|
||||||
search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather(
|
search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather(
|
||||||
get_relevant_edges(driver, extracted_edges, SearchFilters()),
|
get_relevant_edges(driver, extracted_edges, SearchFilters()),
|
||||||
|
|
|
||||||
|
|
@ -72,7 +72,6 @@ async def extract_nodes(
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
start = time()
|
start = time()
|
||||||
llm_client = clients.llm_client
|
llm_client = clients.llm_client
|
||||||
embedder = clients.embedder
|
|
||||||
llm_response = {}
|
llm_response = {}
|
||||||
custom_prompt = ''
|
custom_prompt = ''
|
||||||
entities_missed = True
|
entities_missed = True
|
||||||
|
|
@ -165,8 +164,6 @@ async def extract_nodes(
|
||||||
extracted_nodes.append(new_node)
|
extracted_nodes.append(new_node)
|
||||||
logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
|
logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
|
||||||
|
|
||||||
await create_entity_node_embeddings(embedder, extracted_nodes)
|
|
||||||
|
|
||||||
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
||||||
return extracted_nodes
|
return extracted_nodes
|
||||||
|
|
||||||
|
|
@ -235,7 +232,6 @@ async def resolve_extracted_nodes(
|
||||||
search(
|
search(
|
||||||
clients=clients,
|
clients=clients,
|
||||||
query=node.name,
|
query=node.name,
|
||||||
query_vector=node.name_embedding,
|
|
||||||
group_ids=[node.group_id],
|
group_ids=[node.group_id],
|
||||||
search_filter=SearchFilters(),
|
search_filter=SearchFilters(),
|
||||||
config=NODE_HYBRID_SEARCH_RRF,
|
config=NODE_HYBRID_SEARCH_RRF,
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue