parent 2b6adb5279
commit e56a599a72
7 changed files with 212 additions and 88 deletions
@@ -86,4 +86,4 @@ async def main(use_bulk: bool = True):
         await client.add_episode_bulk(episodes)


-asyncio.run(main(False))
+asyncio.run(main(True))
@@ -180,9 +180,9 @@ class Graphiti:
         await build_indices_and_constraints(self.driver)

     async def retrieve_episodes(
         self,
         reference_time: datetime,
         last_n: int = EPISODE_WINDOW_LEN,
     ) -> list[EpisodicNode]:
         """
         Retrieve the last n episodic nodes from the graph.
@@ -210,14 +210,14 @@ class Graphiti:
         return await retrieve_episodes(self.driver, reference_time, last_n)

     async def add_episode(
         self,
         name: str,
         episode_body: str,
         source_description: str,
         reference_time: datetime,
         source: EpisodeType = EpisodeType.message,
         success_callback: Callable | None = None,
         error_callback: Callable | None = None,
     ):
         """
         Process an episode and update the graph.
@@ -321,11 +321,11 @@ class Graphiti:
             await asyncio.gather(
                 *[
                     get_relevant_edges(
-                        [edge],
                         self.driver,
-                        RELEVANT_SCHEMA_LIMIT,
+                        [edge],
                         edge.source_node_uuid,
                         edge.target_node_uuid,
+                        RELEVANT_SCHEMA_LIMIT,
                     )
                     for edge in extracted_edges
                 ]
@@ -422,8 +422,8 @@ class Graphiti:
             raise e

     async def add_episode_bulk(
         self,
         bulk_episodes: list[RawEpisode],
     ):
         """
         Process multiple episodes in bulk and update the graph.
@@ -587,18 +587,18 @@ class Graphiti:
         return edges

     async def _search(
         self,
         query: str,
         timestamp: datetime,
         config: SearchConfig,
         center_node_uuid: str | None = None,
     ):
         return await hybrid_search(
             self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
         )

     async def get_nodes_by_query(
         self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
     ) -> list[EntityNode]:
         """
         Retrieve nodes from the graph database based on a text query.
@@ -83,7 +83,7 @@ async def hybrid_search(
         nodes.extend(await get_mentioned_nodes(driver, episodes))

     if SearchMethod.bm25 in config.search_methods:
-        text_search = await edge_fulltext_search(driver, query, 2 * config.num_edges)
+        text_search = await edge_fulltext_search(driver, query, None, None, 2 * config.num_edges)
         search_results.append(text_search)

     if SearchMethod.cosine_similarity in config.search_methods:
@@ -95,7 +95,7 @@ async def hybrid_search(
         )

         similarity_search = await edge_similarity_search(
-            driver, search_vector, 2 * config.num_edges
+            driver, search_vector, None, None, 2 * config.num_edges
         )
         search_results.append(similarity_search)

@@ -1,11 +1,11 @@
 import asyncio
 import logging
 import re
-import typing
 from collections import defaultdict
 from time import time
+from typing import Any

-from neo4j import AsyncDriver
+from neo4j import AsyncDriver, Query

 from graphiti_core.edges import EntityEdge
 from graphiti_core.helpers import parse_db_date
@@ -66,12 +66,12 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):
             r.expired_at AS expired_at,
             r.valid_at AS valid_at,
             r.invalid_at AS invalid_at

         """,
         node_ids=node_ids,
     )

-    context: dict[str, typing.Any] = {}
+    context: dict[str, Any] = {}

     for record in records:
         n_uuid = record['source_node_uuid']
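The Query class added to the neo4j import wraps a Cypher string in an object that driver.execute_query() accepts directly, which is how the helpers below select one of several prepared statements at runtime. A minimal sketch of that pattern; the connection details and the statement itself are placeholders for illustration, not values from this commit:

import asyncio

from neo4j import AsyncGraphDatabase, Query


async def run_example() -> None:
    # placeholder URI and credentials; point these at a real Neo4j instance
    driver = AsyncGraphDatabase.driver('bolt://localhost:7687', auth=('neo4j', 'password'))
    # wrap the statement in a Query object, as the updated helpers do
    cypher = Query('MATCH (n:Entity) RETURN n.uuid AS uuid LIMIT $limit')
    records, _, _ = await driver.execute_query(cypher, limit=5)
    print([record['uuid'] for record in records])
    await driver.close()


asyncio.run(run_example())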
@@ -96,15 +96,14 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):


 async def edge_similarity_search(
     driver: AsyncDriver,
     search_vector: list[float],
-    limit: int = RELEVANT_SCHEMA_LIMIT,
-    source_node_uuid: str = '*',
-    target_node_uuid: str = '*',
+    source_node_uuid: str | None,
+    target_node_uuid: str | None,
+    limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityEdge]:
     # vector similarity search over embedded facts
-    records, _, _ = await driver.execute_query(
-        """
+    query = Query("""
     CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
     YIELD relationship AS rel, score
     MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
@@ -121,7 +120,68 @@ async def edge_similarity_search(
         r.valid_at AS valid_at,
         r.invalid_at AS invalid_at
     ORDER BY score DESC
-    """,
+    """)
+
+    if source_node_uuid is None and target_node_uuid is None:
+        query = Query("""
+        CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
+        YIELD relationship AS rel, score
+        MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
+        RETURN
+            r.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+        ORDER BY score DESC
+        """)
+    elif source_node_uuid is None:
+        query = Query("""
+        CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
+        YIELD relationship AS rel, score
+        MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
+        RETURN
+            r.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+        ORDER BY score DESC
+        """)
+    elif target_node_uuid is None:
+        query = Query("""
+        CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
+        YIELD relationship AS rel, score
+        MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
+        RETURN
+            r.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+        ORDER BY score DESC
+        """)
+
+    records, _, _ = await driver.execute_query(
+        query,
         search_vector=search_vector,
         source_uuid=source_node_uuid,
         target_uuid=target_node_uuid,
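With the reordering above, edge_similarity_search takes the optional source and target node UUID filters before the limit, and passing None for both leaves the relationship match unconstrained. A minimal usage sketch; the import path, connection details, embedding vector, and limit are assumptions for illustration, not values from this commit:

import asyncio

from neo4j import AsyncGraphDatabase

# assumed module path for the helper shown in this diff
from graphiti_core.search.search_utils import edge_similarity_search


async def demo() -> None:
    # placeholder credentials; substitute a real Neo4j instance
    driver = AsyncGraphDatabase.driver('bolt://localhost:7687', auth=('neo4j', 'password'))
    fact_embedding = [0.0] * 1024  # placeholder vector; use a real fact embedding
    # new argument order: driver, search_vector, source_node_uuid, target_node_uuid, limit
    edges = await edge_similarity_search(driver, fact_embedding, None, None, 10)
    print([edge.fact for edge in edges])
    await driver.close()


asyncio.run(demo())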
@@ -151,7 +211,7 @@ async def edge_similarity_search(


 async def entity_similarity_search(
     search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
 ) -> list[EntityNode]:
     # vector similarity search over entity names
     records, _, _ = await driver.execute_query(
@@ -161,6 +221,7 @@ async def entity_similarity_search(
     RETURN
         n.uuid As uuid,
         n.name AS name,
+        n.name_embeddings AS name_embedding,
         n.created_at AS created_at,
         n.summary AS summary
     ORDER BY score DESC
@@ -175,6 +236,7 @@ async def entity_similarity_search(
             EntityNode(
                 uuid=record['uuid'],
                 name=record['name'],
+                name_embedding=record['name_embedding'],
                 labels=['Entity'],
                 created_at=record['created_at'].to_native(),
                 summary=record['summary'],
@@ -185,7 +247,7 @@ async def entity_similarity_search(


 async def entity_fulltext_search(
     query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
 ) -> list[EntityNode]:
     # BM25 search to get top nodes
     fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
@@ -193,8 +255,9 @@ async def entity_fulltext_search(
         """
     CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
     RETURN
-        node.uuid As uuid,
+        node.uuid AS uuid,
         node.name AS name,
+        node.name_embeddings AS name_embedding,
         node.created_at AS created_at,
         node.summary AS summary
     ORDER BY score DESC
@@ -210,6 +273,7 @@ async def entity_fulltext_search(
             EntityNode(
                 uuid=record['uuid'],
                 name=record['name'],
+                name_embedding=record['name_embedding'],
                 labels=['Entity'],
                 created_at=record['created_at'].to_native(),
                 summary=record['summary'],
@@ -220,21 +284,18 @@ async def entity_fulltext_search(


 async def edge_fulltext_search(
     driver: AsyncDriver,
     query: str,
-    limit=RELEVANT_SCHEMA_LIMIT,
-    source_node_uuid: str = '*',
-    target_node_uuid: str = '*',
+    source_node_uuid: str | None,
+    target_node_uuid: str | None,
+    limit=RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityEdge]:
     # fulltext search over facts
-    fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
-
-    records, _, _ = await driver.execute_query(
-        """
+    cypher_query = Query("""
     CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
     YIELD relationship AS rel, score
     MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
     RETURN
         r.uuid AS uuid,
         n.uuid AS source_node_uuid,
         m.uuid AS target_node_uuid,
@@ -247,7 +308,70 @@ async def edge_fulltext_search(
         r.valid_at AS valid_at,
         r.invalid_at AS invalid_at
     ORDER BY score DESC LIMIT $limit
-    """,
+    """)
+
+    if source_node_uuid is None and target_node_uuid is None:
+        cypher_query = Query("""
+        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
+        YIELD relationship AS rel, score
+        MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
+        RETURN
+            r.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+        ORDER BY score DESC LIMIT $limit
+        """)
+    elif source_node_uuid is None:
+        cypher_query = Query("""
+        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
+        YIELD relationship AS rel, score
+        MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
+        RETURN
+            r.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+        ORDER BY score DESC LIMIT $limit
+        """)
+    elif target_node_uuid is None:
+        cypher_query = Query("""
+        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
+        YIELD relationship AS rel, score
+        MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
+        RETURN
+            r.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+        ORDER BY score DESC LIMIT $limit
+        """)
+
+    fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
+
+    records, _, _ = await driver.execute_query(
+        cypher_query,
         query=fuzzy_query,
         source_uuid=source_node_uuid,
         target_uuid=target_node_uuid,
@@ -277,16 +401,16 @@ async def edge_fulltext_search(


 async def hybrid_node_search(
     queries: list[str],
     embeddings: list[list[float]],
     driver: AsyncDriver,
     limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityNode]:
     """
     Perform a hybrid search for nodes using both text queries and embeddings.

     This method combines fulltext search and vector similarity search to find
-    relevant nodes in the graph database. It uses an rrf reranker.
+    relevant nodes in the graph database. It uses a rrf reranker.

     Parameters
     ----------
@@ -342,8 +466,8 @@ async def hybrid_node_search(


 async def get_relevant_nodes(
     nodes: list[EntityNode],
     driver: AsyncDriver,
 ) -> list[EntityNode]:
     """
     Retrieve relevant nodes based on the provided list of EntityNodes.
@@ -379,11 +503,11 @@ async def get_relevant_nodes(


 async def get_relevant_edges(
-    edges: list[EntityEdge],
-    driver: AsyncDriver,
-    limit: int = RELEVANT_SCHEMA_LIMIT,
-    source_node_uuid: str = '*',
-    target_node_uuid: str = '*',
+    driver: AsyncDriver,
+    edges: list[EntityEdge],
+    source_node_uuid: str | None,
+    target_node_uuid: str | None,
+    limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityEdge]:
     start = time()
     relevant_edges: list[EntityEdge] = []
@@ -392,13 +516,13 @@ async def get_relevant_edges(
     results = await asyncio.gather(
         *[
             edge_similarity_search(
-                driver, edge.fact_embedding, limit, source_node_uuid, target_node_uuid
+                driver, edge.fact_embedding, source_node_uuid, target_node_uuid, limit
             )
             for edge in edges
            if edge.fact_embedding is not None
         ],
         *[
-            edge_fulltext_search(driver, edge.fact, limit, source_node_uuid, target_node_uuid)
+            edge_fulltext_search(driver, edge.fact, source_node_uuid, target_node_uuid, limit)
             for edge in edges
         ],
     )
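get_relevant_edges now follows the same convention: driver first, then the candidate edges, then the optional source and target UUID filters, with the limit last. A small call-site sketch; the import path, wrapper name, and limit value are assumptions for illustration:

from neo4j import AsyncDriver

from graphiti_core.edges import EntityEdge
# assumed module path for the helper shown in this diff
from graphiti_core.search.search_utils import get_relevant_edges


async def find_existing_edges(driver: AsyncDriver, candidates: list[EntityEdge]) -> list[EntityEdge]:
    # None, None leaves both UUID filters open, matching the dedupe_edges_bulk call site below
    return await get_relevant_edges(driver, candidates, None, None, 10)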
@@ -433,14 +557,14 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:


 async def node_distance_reranker(
     driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
 ) -> list[str]:
     # use rrf as a preliminary ranker
     sorted_uuids = rrf(results)
     scores: dict[str, float] = {}

     for uuid in sorted_uuids:
-        # Find shortest path to center node
+        # Find the shortest path to center node
         records, _, _ = await driver.execute_query(
             """
         MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity)
@@ -455,8 +579,8 @@ async def node_distance_reranker(

     for record in records:
         if (
             record['source_uuid'] == center_node_uuid
             or record['target_uuid'] == center_node_uuid
         ):
             continue
         distance = record['score']
@@ -158,7 +158,7 @@ async def dedupe_edges_bulk(

     relevant_edges_chunks: list[list[EntityEdge]] = list(
         await asyncio.gather(
-            *[get_relevant_edges(edge_chunk, driver) for edge_chunk in edge_chunks]
+            *[get_relevant_edges(driver, edge_chunk, None, None) for edge_chunk in edge_chunks]
         )
     )

@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)


 async def extract_message_nodes(
     llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
 ) -> list[dict[str, Any]]:
     # Prepare context for LLM
     context = {
@@ -49,8 +49,8 @@ async def extract_message_nodes(


 async def extract_json_nodes(
     llm_client: LLMClient,
     episode: EpisodicNode,
 ) -> list[dict[str, Any]]:
     # Prepare context for LLM
     context = {
@@ -67,9 +67,9 @@ async def extract_json_nodes(


 async def extract_nodes(
     llm_client: LLMClient,
     episode: EpisodicNode,
     previous_episodes: list[EpisodicNode],
 ) -> list[EntityNode]:
     start = time()
     extracted_node_data: list[dict[str, Any]] = []
@@ -96,9 +96,9 @@ async def extract_nodes(


 async def dedupe_extracted_nodes(
     llm_client: LLMClient,
     extracted_nodes: list[EntityNode],
     existing_nodes: list[EntityNode],
 ) -> tuple[list[EntityNode], dict[str, str]]:
     start = time()

@@ -146,9 +146,9 @@ async def dedupe_extracted_nodes(


 async def resolve_extracted_nodes(
     llm_client: LLMClient,
     extracted_nodes: list[EntityNode],
     existing_nodes_lists: list[list[EntityNode]],
 ) -> tuple[list[EntityNode], dict[str, str]]:
     uuid_map: dict[str, str] = {}
     resolved_nodes: list[EntityNode] = []
@@ -169,7 +169,7 @@ async def resolve_extracted_nodes(


 async def resolve_extracted_node(
     llm_client: LLMClient, extracted_node: EntityNode, existing_nodes: list[EntityNode]
 ) -> tuple[EntityNode, dict[str, str]]:
     start = time()

@@ -214,8 +214,8 @@ async def resolve_extracted_node(


 async def dedupe_node_list(
     llm_client: LLMClient,
     nodes: list[EntityNode],
 ) -> tuple[list[EntityNode], dict[str, str]]:
     start = time()

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "graphiti-core"
-version = "0.2.0"
+version = "0.2.1"
 description = "A temporal graph building library"
 authors = [
     "Paul Paliychuk <paul@getzep.com>",