update
This commit is contained in:
parent
06fccd6829
commit
836668e9ee
7 changed files with 72 additions and 115 deletions
|
|
@ -25,7 +25,6 @@ from pydantic import BaseModel, Field
|
||||||
from transcript_parser import parse_podcast_messages
|
from transcript_parser import parse_podcast_messages
|
||||||
|
|
||||||
from graphiti_core import Graphiti
|
from graphiti_core import Graphiti
|
||||||
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
|
||||||
from graphiti_core.nodes import EpisodeType
|
from graphiti_core.nodes import EpisodeType
|
||||||
from graphiti_core.utils.bulk_utils import RawEpisode
|
from graphiti_core.utils.bulk_utils import RawEpisode
|
||||||
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
|
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
|
||||||
|
|
@ -35,8 +34,6 @@ load_dotenv()
|
||||||
neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
|
neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
|
||||||
neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
|
neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
|
||||||
neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'
|
neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'
|
||||||
aoss_host = os.environ.get('AOSS_HOST') or None
|
|
||||||
aoss_port = os.environ.get('AOSS_PORT') or None
|
|
||||||
|
|
||||||
|
|
||||||
def setup_logging():
|
def setup_logging():
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||||
import asyncio
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Coroutine
|
from collections.abc import Coroutine
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
@ -38,10 +39,10 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEFAULT_SIZE = 10
|
DEFAULT_SIZE = 10
|
||||||
|
|
||||||
EPISODE_INDEX_NAME = 'episodes-test'
|
ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX', 'entities')
|
||||||
ENTTITY_INDEX_NAME = 'entities_test'
|
EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX', 'episodes')
|
||||||
COMMUNITY_INDEX_NAME = 'communities-test'
|
COMMUNITY_INDEX_NAME = os.environ.get('COMMUNITY_INDEX_NAME', 'communities')
|
||||||
ENTITY_EDGE_INDEX_NAME = 'entity_edges_test'
|
ENTITY_EDGE_INDEX_NAME = os.environ.get('ENTITY_EDGE_INDEX_NAME', 'entity_edges')
|
||||||
|
|
||||||
|
|
||||||
class GraphProvider(Enum):
|
class GraphProvider(Enum):
|
||||||
|
|
|
||||||
|
|
@ -73,7 +73,7 @@ class Neo4jDriver(GraphDriver):
|
||||||
region = aws_region
|
region = aws_region
|
||||||
service = aws_service
|
service = aws_service
|
||||||
credentials = boto3.Session(profile_name=aws_profile_name).get_credentials()
|
credentials = boto3.Session(profile_name=aws_profile_name).get_credentials()
|
||||||
auth = AWSV4SignerAuth(credentials, region, service)
|
auth = AWSV4SignerAuth(credentials, region or '', service or '')
|
||||||
|
|
||||||
self.aoss_client = OpenSearch(
|
self.aoss_client = OpenSearch(
|
||||||
hosts=[{'host': aoss_host, 'port': aoss_port}],
|
hosts=[{'host': aoss_host, 'port': aoss_port}],
|
||||||
|
|
|
||||||
|
|
@ -1035,7 +1035,7 @@ class Graphiti:
|
||||||
|
|
||||||
updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
|
updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
|
||||||
|
|
||||||
valid_uuids = await EntityEdge.get_between_nodes(
|
valid_edges = await EntityEdge.get_between_nodes(
|
||||||
self.driver, edge.source_node_uuid, edge.target_node_uuid
|
self.driver, edge.source_node_uuid, edge.target_node_uuid
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1045,7 +1045,7 @@ class Graphiti:
|
||||||
updated_edge.fact,
|
updated_edge.fact,
|
||||||
group_ids=[updated_edge.group_id],
|
group_ids=[updated_edge.group_id],
|
||||||
config=EDGE_HYBRID_SEARCH_RRF,
|
config=EDGE_HYBRID_SEARCH_RRF,
|
||||||
search_filter=SearchFilters(uuids=valid_uuids),
|
search_filter=SearchFilters(uuids=[edge.uuid for edge in valid_edges]),
|
||||||
)
|
)
|
||||||
).edges
|
).edges
|
||||||
existing_edges = (
|
existing_edges = (
|
||||||
|
|
|
||||||
|
|
@ -215,11 +215,11 @@ async def edge_fulltext_search(
|
||||||
# Match the edge ids and return the values
|
# Match the edge ids and return the values
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $ids as id
|
UNWIND $ids as id
|
||||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||||
WHERE e.group_id IN $group_ids
|
WHERE e.group_id IN $group_ids
|
||||||
AND id(e)=id
|
AND id(e)=id
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
AND id(e)=id
|
AND id(e)=id
|
||||||
|
|
@ -353,8 +353,8 @@ async def edge_similarity_search(
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
|
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
|
||||||
|
|
@ -423,7 +423,11 @@ async def edge_similarity_search(
|
||||||
body={
|
body={
|
||||||
'query': {
|
'query': {
|
||||||
'knn': {
|
'knn': {
|
||||||
'fact_embedding': {'vector': list(map(float, search_vector)), 'k': limit}
|
'fact_embedding': {
|
||||||
|
'vector': list(map(float, search_vector)),
|
||||||
|
'k': limit,
|
||||||
|
'filter': {'bool': {'filter': filters}},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
@ -633,11 +637,11 @@ async def node_fulltext_search(
|
||||||
# Match the edge ides and return the values
|
# Match the edge ides and return the values
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $ids as i
|
UNWIND $ids as i
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
WHERE n.uuid=i.id
|
WHERE n.uuid=i.id
|
||||||
RETURN
|
RETURN
|
||||||
"""
|
"""
|
||||||
+ get_entity_node_return_query(driver.provider)
|
+ get_entity_node_return_query(driver.provider)
|
||||||
+ """
|
+ """
|
||||||
ORDER BY i.score DESC
|
ORDER BY i.score DESC
|
||||||
|
|
@ -747,8 +751,8 @@ async def node_similarity_search(
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
||||||
|
|
@ -777,11 +781,11 @@ async def node_similarity_search(
|
||||||
# Match the edge ides and return the values
|
# Match the edge ides and return the values
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $ids as i
|
UNWIND $ids as i
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
WHERE id(n)=i.id
|
WHERE id(n)=i.id
|
||||||
RETURN
|
RETURN
|
||||||
"""
|
"""
|
||||||
+ get_entity_node_return_query(driver.provider)
|
+ get_entity_node_return_query(driver.provider)
|
||||||
+ """
|
+ """
|
||||||
ORDER BY i.score DESC
|
ORDER BY i.score DESC
|
||||||
|
|
@ -810,7 +814,11 @@ async def node_similarity_search(
|
||||||
body={
|
body={
|
||||||
'query': {
|
'query': {
|
||||||
'knn': {
|
'knn': {
|
||||||
'name_embedding': {'vector': list(map(float, search_vector)), 'k': limit}
|
'name_embedding': {
|
||||||
|
'vector': list(map(float, search_vector)),
|
||||||
|
'k': limit,
|
||||||
|
'filter': {'bool': {'filter': filters}},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
@ -829,8 +837,8 @@ async def node_similarity_search(
|
||||||
else:
|
else:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH n, """
|
WITH n, """
|
||||||
|
|
@ -1162,8 +1170,8 @@ async def community_similarity_search(
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Community)
|
MATCH (n:Community)
|
||||||
"""
|
"""
|
||||||
+ group_filter_query
|
+ group_filter_query
|
||||||
+ """
|
+ """
|
||||||
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
||||||
|
|
@ -1222,8 +1230,8 @@ async def community_similarity_search(
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (c:Community)
|
MATCH (c:Community)
|
||||||
"""
|
"""
|
||||||
+ group_filter_query
|
+ group_filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH c,
|
WITH c,
|
||||||
|
|
@ -1365,9 +1373,9 @@ async def get_relevant_nodes(
|
||||||
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
|
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $nodes AS node
|
UNWIND $nodes AS node
|
||||||
MATCH (n:Entity {group_id: $group_id})
|
MATCH (n:Entity {group_id: $group_id})
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH node, n, """
|
WITH node, n, """
|
||||||
|
|
@ -1412,9 +1420,9 @@ async def get_relevant_nodes(
|
||||||
else:
|
else:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $nodes AS node
|
UNWIND $nodes AS node
|
||||||
MATCH (n:Entity {group_id: $group_id})
|
MATCH (n:Entity {group_id: $group_id})
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH node, n, """
|
WITH node, n, """
|
||||||
|
|
@ -1503,9 +1511,9 @@ async def get_relevant_edges(
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH e, edge
|
WITH e, edge
|
||||||
|
|
@ -1575,9 +1583,9 @@ async def get_relevant_edges(
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
|
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH e, edge, n, m, """
|
WITH e, edge, n, m, """
|
||||||
|
|
@ -1610,61 +1618,12 @@ async def get_relevant_edges(
|
||||||
}) AS matches
|
}) AS matches
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
elif driver.aoss_client:
|
|
||||||
# First get edge candidates
|
|
||||||
query = (
|
|
||||||
"""
|
|
||||||
UNWIND $edges AS edge
|
|
||||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
|
||||||
"""
|
|
||||||
+ filter_query
|
|
||||||
+ """
|
|
||||||
RETURN
|
|
||||||
e.uuid AS search_edge_uuid,
|
|
||||||
collect({
|
|
||||||
uuid: e.uuid,
|
|
||||||
source_node_uuid: startNode(e).uuid,
|
|
||||||
target_node_uuid: endNode(e).uuid,
|
|
||||||
created_at: e.created_at,
|
|
||||||
name: e.name,
|
|
||||||
group_id: e.group_id,
|
|
||||||
fact: e.fact,
|
|
||||||
fact_embedding: e.fact_embedding,
|
|
||||||
episodes: e.episodes,
|
|
||||||
expired_at: e.expired_at,
|
|
||||||
valid_at: e.valid_at,
|
|
||||||
invalid_at: e.invalid_at,
|
|
||||||
attributes: properties(e)
|
|
||||||
}) AS matches
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
results, _, _ = await driver.execute_query(
|
|
||||||
query,
|
|
||||||
edges=[edge.model_dump() for edge in edges],
|
|
||||||
limit=limit,
|
|
||||||
min_score=min_score,
|
|
||||||
routing_='r',
|
|
||||||
**filter_params,
|
|
||||||
)
|
|
||||||
|
|
||||||
relevant_edges_dict: dict[str, list[EntityEdge]] = {
|
|
||||||
result['search_edge_uuid']: [
|
|
||||||
get_entity_edge_from_record(record, driver.provider)
|
|
||||||
for record in result['matches']
|
|
||||||
]
|
|
||||||
for result in results
|
|
||||||
}
|
|
||||||
|
|
||||||
group_id = edges[0].group_id
|
|
||||||
# semaphore_gather(*[edge_similarity_search(driver, )])
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH e, edge, """
|
WITH e, edge, """
|
||||||
|
|
@ -1737,10 +1696,10 @@ async def get_edge_invalidation_candidates(
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH e, edge
|
WITH e, edge
|
||||||
|
|
@ -1810,10 +1769,10 @@ async def get_edge_invalidation_candidates(
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
|
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
|
||||||
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
|
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH edge, e, n, m, """
|
WITH edge, e, n, m, """
|
||||||
|
|
@ -1849,10 +1808,10 @@ async def get_edge_invalidation_candidates(
|
||||||
else:
|
else:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH edge, e, """
|
WITH edge, e, """
|
||||||
|
|
|
||||||
|
|
@ -276,7 +276,7 @@ async def resolve_extracted_edges(
|
||||||
config=EDGE_HYBRID_SEARCH_RRF,
|
config=EDGE_HYBRID_SEARCH_RRF,
|
||||||
search_filter=SearchFilters(uuids=valid_uuids),
|
search_filter=SearchFilters(uuids=valid_uuids),
|
||||||
)
|
)
|
||||||
for extracted_edge, valid_uuids in zip(extracted_edges, valid_uuids_list)
|
for extracted_edge, valid_uuids in zip(extracted_edges, valid_uuids_list, strict=True)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -783,7 +783,7 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
version = "0.21.0rc1"
|
version = "0.21.0rc2"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "diskcache" },
|
{ name = "diskcache" },
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue