This commit is contained in:
prestonrasmussen 2025-09-12 12:08:13 -04:00
parent 06fccd6829
commit 836668e9ee
7 changed files with 72 additions and 115 deletions

View file

@ -25,7 +25,6 @@ from pydantic import BaseModel, Field
from transcript_parser import parse_podcast_messages from transcript_parser import parse_podcast_messages
from graphiti_core import Graphiti from graphiti_core import Graphiti
from graphiti_core.driver.neo4j_driver import Neo4jDriver
from graphiti_core.nodes import EpisodeType from graphiti_core.nodes import EpisodeType
from graphiti_core.utils.bulk_utils import RawEpisode from graphiti_core.utils.bulk_utils import RawEpisode
from graphiti_core.utils.maintenance.graph_data_operations import clear_data from graphiti_core.utils.maintenance.graph_data_operations import clear_data
@ -35,8 +34,6 @@ load_dotenv()
neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687' neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j' neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password' neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'
aoss_host = os.environ.get('AOSS_HOST') or None
aoss_port = os.environ.get('AOSS_PORT') or None
def setup_logging(): def setup_logging():

View file

@ -17,6 +17,7 @@ limitations under the License.
import asyncio import asyncio
import copy import copy
import logging import logging
import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Coroutine from collections.abc import Coroutine
from datetime import datetime from datetime import datetime
@ -38,10 +39,10 @@ logger = logging.getLogger(__name__)
DEFAULT_SIZE = 10 DEFAULT_SIZE = 10
EPISODE_INDEX_NAME = 'episodes-test' ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX', 'entities')
ENTTITY_INDEX_NAME = 'entities_test' EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX', 'episodes')
COMMUNITY_INDEX_NAME = 'communities-test' COMMUNITY_INDEX_NAME = os.environ.get('COMMUNITY_INDEX_NAME', 'communities')
ENTITY_EDGE_INDEX_NAME = 'entity_edges_test' ENTITY_EDGE_INDEX_NAME = os.environ.get('ENTITY_EDGE_INDEX_NAME', 'entity_edges')
class GraphProvider(Enum): class GraphProvider(Enum):

View file

@ -73,7 +73,7 @@ class Neo4jDriver(GraphDriver):
region = aws_region region = aws_region
service = aws_service service = aws_service
credentials = boto3.Session(profile_name=aws_profile_name).get_credentials() credentials = boto3.Session(profile_name=aws_profile_name).get_credentials()
auth = AWSV4SignerAuth(credentials, region, service) auth = AWSV4SignerAuth(credentials, region or '', service or '')
self.aoss_client = OpenSearch( self.aoss_client = OpenSearch(
hosts=[{'host': aoss_host, 'port': aoss_port}], hosts=[{'host': aoss_host, 'port': aoss_port}],

View file

@ -1035,7 +1035,7 @@ class Graphiti:
updated_edge = resolve_edge_pointers([edge], uuid_map)[0] updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
valid_uuids = await EntityEdge.get_between_nodes( valid_edges = await EntityEdge.get_between_nodes(
self.driver, edge.source_node_uuid, edge.target_node_uuid self.driver, edge.source_node_uuid, edge.target_node_uuid
) )
@ -1045,7 +1045,7 @@ class Graphiti:
updated_edge.fact, updated_edge.fact,
group_ids=[updated_edge.group_id], group_ids=[updated_edge.group_id],
config=EDGE_HYBRID_SEARCH_RRF, config=EDGE_HYBRID_SEARCH_RRF,
search_filter=SearchFilters(uuids=valid_uuids), search_filter=SearchFilters(uuids=[edge.uuid for edge in valid_edges]),
) )
).edges ).edges
existing_edges = ( existing_edges = (

View file

@ -215,11 +215,11 @@ async def edge_fulltext_search(
# Match the edge ids and return the values # Match the edge ids and return the values
query = ( query = (
""" """
UNWIND $ids as id UNWIND $ids as id
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids WHERE e.group_id IN $group_ids
AND id(e)=id AND id(e)=id
""" """
+ filter_query + filter_query
+ """ + """
AND id(e)=id AND id(e)=id
@ -353,8 +353,8 @@ async def edge_similarity_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
""" """
+ filter_query + filter_query
+ """ + """
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
@ -423,7 +423,11 @@ async def edge_similarity_search(
body={ body={
'query': { 'query': {
'knn': { 'knn': {
'fact_embedding': {'vector': list(map(float, search_vector)), 'k': limit} 'fact_embedding': {
'vector': list(map(float, search_vector)),
'k': limit,
'filter': {'bool': {'filter': filters}},
}
} }
} }
}, },
@ -633,11 +637,11 @@ async def node_fulltext_search(
# Match the node ids and return the values # Match the node ids and return the values
query = ( query = (
""" """
UNWIND $ids as i UNWIND $ids as i
MATCH (n:Entity) MATCH (n:Entity)
WHERE n.uuid=i.id WHERE n.uuid=i.id
RETURN RETURN
""" """
+ get_entity_node_return_query(driver.provider) + get_entity_node_return_query(driver.provider)
+ """ + """
ORDER BY i.score DESC ORDER BY i.score DESC
@ -747,8 +751,8 @@ async def node_similarity_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
MATCH (n:Entity) MATCH (n:Entity)
""" """
+ filter_query + filter_query
+ """ + """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -777,11 +781,11 @@ async def node_similarity_search(
# Match the node ids and return the values # Match the node ids and return the values
query = ( query = (
""" """
UNWIND $ids as i UNWIND $ids as i
MATCH (n:Entity) MATCH (n:Entity)
WHERE id(n)=i.id WHERE id(n)=i.id
RETURN RETURN
""" """
+ get_entity_node_return_query(driver.provider) + get_entity_node_return_query(driver.provider)
+ """ + """
ORDER BY i.score DESC ORDER BY i.score DESC
@ -810,7 +814,11 @@ async def node_similarity_search(
body={ body={
'query': { 'query': {
'knn': { 'knn': {
'name_embedding': {'vector': list(map(float, search_vector)), 'k': limit} 'name_embedding': {
'vector': list(map(float, search_vector)),
'k': limit,
'filter': {'bool': {'filter': filters}},
}
} }
} }
}, },
@ -829,8 +837,8 @@ async def node_similarity_search(
else: else:
query = ( query = (
""" """
MATCH (n:Entity) MATCH (n:Entity)
""" """
+ filter_query + filter_query
+ """ + """
WITH n, """ WITH n, """
@ -1162,8 +1170,8 @@ async def community_similarity_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
MATCH (n:Community) MATCH (n:Community)
""" """
+ group_filter_query + group_filter_query
+ """ + """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -1222,8 +1230,8 @@ async def community_similarity_search(
query = ( query = (
""" """
MATCH (c:Community) MATCH (c:Community)
""" """
+ group_filter_query + group_filter_query
+ """ + """
WITH c, WITH c,
@ -1365,9 +1373,9 @@ async def get_relevant_nodes(
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver. # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
query = ( query = (
""" """
UNWIND $nodes AS node UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id}) MATCH (n:Entity {group_id: $group_id})
""" """
+ filter_query + filter_query
+ """ + """
WITH node, n, """ WITH node, n, """
@ -1412,9 +1420,9 @@ async def get_relevant_nodes(
else: else:
query = ( query = (
""" """
UNWIND $nodes AS node UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id}) MATCH (n:Entity {group_id: $group_id})
""" """
+ filter_query + filter_query
+ """ + """
WITH node, n, """ WITH node, n, """
@ -1503,9 +1511,9 @@ async def get_relevant_edges(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge WITH e, edge
@ -1575,9 +1583,9 @@ async def get_relevant_edges(
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge, n, m, """ WITH e, edge, n, m, """
@ -1610,61 +1618,12 @@ async def get_relevant_edges(
}) AS matches }) AS matches
""" """
) )
elif driver.aoss_client:
# First get edge candidates
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
+ filter_query
+ """
RETURN
e.uuid AS search_edge_uuid,
collect({
uuid: e.uuid,
source_node_uuid: startNode(e).uuid,
target_node_uuid: endNode(e).uuid,
created_at: e.created_at,
name: e.name,
group_id: e.group_id,
fact: e.fact,
fact_embedding: e.fact_embedding,
episodes: e.episodes,
expired_at: e.expired_at,
valid_at: e.valid_at,
invalid_at: e.invalid_at,
attributes: properties(e)
}) AS matches
"""
)
results, _, _ = await driver.execute_query(
query,
edges=[edge.model_dump() for edge in edges],
limit=limit,
min_score=min_score,
routing_='r',
**filter_params,
)
relevant_edges_dict: dict[str, list[EntityEdge]] = {
result['search_edge_uuid']: [
get_entity_edge_from_record(record, driver.provider)
for record in result['matches']
]
for result in results
}
group_id = edges[0].group_id
# semaphore_gather(*[edge_similarity_search(driver, )])
else: else:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge, """ WITH e, edge, """
@ -1737,10 +1696,10 @@ async def get_edge_invalidation_candidates(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge WITH e, edge
@ -1810,10 +1769,10 @@ async def get_edge_invalidation_candidates(
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
""" """
+ filter_query + filter_query
+ """ + """
WITH edge, e, n, m, """ WITH edge, e, n, m, """
@ -1849,10 +1808,10 @@ async def get_edge_invalidation_candidates(
else: else:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
""" """
+ filter_query + filter_query
+ """ + """
WITH edge, e, """ WITH edge, e, """

View file

@ -276,7 +276,7 @@ async def resolve_extracted_edges(
config=EDGE_HYBRID_SEARCH_RRF, config=EDGE_HYBRID_SEARCH_RRF,
search_filter=SearchFilters(uuids=valid_uuids), search_filter=SearchFilters(uuids=valid_uuids),
) )
for extracted_edge, valid_uuids in zip(extracted_edges, valid_uuids_list) for extracted_edge, valid_uuids in zip(extracted_edges, valid_uuids_list, strict=True)
] ]
) )

2
uv.lock generated
View file

@ -783,7 +783,7 @@ wheels = [
[[package]] [[package]]
name = "graphiti-core" name = "graphiti-core"
version = "0.21.0rc1" version = "0.21.0rc2"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "diskcache" }, { name = "diskcache" },