Use sessions search (#197)

* use sessions for searches

* correct DB name

* fix typo
This commit is contained in:
Preston Rasmussen 2024-10-22 10:01:56 -04:00 committed by GitHub
parent 1290d0fecb
commit 50d2308c93
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 122 additions and 102 deletions

View file

@ -54,7 +54,7 @@ class Edge(BaseModel, ABC):
DELETE e
""",
uuid=self.uuid,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
logger.debug(f'Deleted Edge: {self.uuid}')
@ -82,7 +82,7 @@ class EpisodicEdge(Edge):
uuid=self.uuid,
group_id=self.group_id,
created_at=self.created_at,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
logger.debug(f'Saved edge to neo4j: {self.uuid}')
@ -102,7 +102,7 @@ class EpisodicEdge(Edge):
e.created_at AS created_at
""",
uuid=uuid,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
edges = [get_episodic_edge_from_record(record) for record in records]
@ -125,7 +125,7 @@ class EpisodicEdge(Edge):
e.created_at AS created_at
""",
uuids=uuids,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
edges = [get_episodic_edge_from_record(record) for record in records]
@ -148,7 +148,7 @@ class EpisodicEdge(Edge):
e.created_at AS created_at
""",
group_ids=group_ids,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
edges = [get_episodic_edge_from_record(record) for record in records]
@ -202,7 +202,7 @@ class EntityEdge(Edge):
expired_at=self.expired_at,
valid_at=self.valid_at,
invalid_at=self.invalid_at,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
logger.debug(f'Saved edge to neo4j: {self.uuid}')
@ -229,7 +229,7 @@ class EntityEdge(Edge):
e.invalid_at AS invalid_at
""",
uuid=uuid,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
edges = [get_entity_edge_from_record(record) for record in records]
@ -259,7 +259,7 @@ class EntityEdge(Edge):
e.invalid_at AS invalid_at
""",
uuids=uuids,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
edges = [get_entity_edge_from_record(record) for record in records]
@ -289,7 +289,7 @@ class EntityEdge(Edge):
e.invalid_at AS invalid_at
""",
group_ids=group_ids,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
edges = [get_entity_edge_from_record(record) for record in records]
@ -308,7 +308,7 @@ class CommunityEdge(Edge):
uuid=self.uuid,
group_id=self.group_id,
created_at=self.created_at,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
logger.debug(f'Saved edge to neo4j: {self.uuid}')
@ -328,7 +328,7 @@ class CommunityEdge(Edge):
e.created_at AS created_at
""",
uuid=uuid,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
edges = [get_community_edge_from_record(record) for record in records]
@ -349,7 +349,7 @@ class CommunityEdge(Edge):
e.created_at AS created_at
""",
uuids=uuids,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
edges = [get_community_edge_from_record(record) for record in records]
@ -370,7 +370,7 @@ class CommunityEdge(Edge):
e.created_at AS created_at
""",
group_ids=group_ids,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
edges = [get_community_edge_from_record(record) for record in records]

View file

@ -90,7 +90,7 @@ class Node(BaseModel, ABC):
DETACH DELETE n
""",
uuid=self.uuid,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
logger.debug(f'Deleted Node: {self.uuid}')
@ -136,7 +136,7 @@ class EpisodicNode(Node):
created_at=self.created_at,
valid_at=self.valid_at,
source=self.source.value,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
logger.debug(f'Saved Node to neo4j: {self.uuid}')
@ -158,7 +158,7 @@ class EpisodicNode(Node):
e.source AS source
""",
uuid=uuid,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
episodes = [get_episodic_node_from_record(record) for record in records]
@ -184,7 +184,7 @@ class EpisodicNode(Node):
e.source AS source
""",
uuids=uuids,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
episodes = [get_episodic_node_from_record(record) for record in records]
@ -207,7 +207,7 @@ class EpisodicNode(Node):
e.source AS source
""",
group_ids=group_ids,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
episodes = [get_episodic_node_from_record(record) for record in records]
@ -237,7 +237,7 @@ class EntityNode(Node):
summary=self.summary,
name_embedding=self.name_embedding,
created_at=self.created_at,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
logger.debug(f'Saved Node to neo4j: {self.uuid}')
@ -258,7 +258,7 @@ class EntityNode(Node):
n.summary AS summary
""",
uuid=uuid,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
nodes = [get_entity_node_from_record(record) for record in records]
@ -282,7 +282,7 @@ class EntityNode(Node):
n.summary AS summary
""",
uuids=uuids,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
nodes = [get_entity_node_from_record(record) for record in records]
@ -303,7 +303,7 @@ class EntityNode(Node):
n.summary AS summary
""",
group_ids=group_ids,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
nodes = [get_entity_node_from_record(record) for record in records]
@ -324,7 +324,7 @@ class CommunityNode(Node):
summary=self.summary,
name_embedding=self.name_embedding,
created_at=self.created_at,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
logger.debug(f'Saved Node to neo4j: {self.uuid}')
@ -354,7 +354,7 @@ class CommunityNode(Node):
n.summary AS summary
""",
uuid=uuid,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
nodes = [get_community_node_from_record(record) for record in records]
@ -378,7 +378,7 @@ class CommunityNode(Node):
n.summary AS summary
""",
uuids=uuids,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
communities = [get_community_node_from_record(record) for record in records]
@ -399,7 +399,7 @@ class CommunityNode(Node):
n.summary AS summary
""",
group_ids=group_ids,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
communities = [get_community_node_from_record(record) for record in records]

View file

@ -79,20 +79,21 @@ async def get_mentioned_nodes(
driver: AsyncDriver, episodes: list[EpisodicNode]
) -> list[EntityNode]:
episode_uuids = [episode.uuid for episode in episodes]
records, _, _ = await driver.execute_query(
"""
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
RETURN DISTINCT
n.uuid As uuid,
n.group_id AS group_id,
n.name AS name,
n.name_embedding AS name_embedding
n.created_at AS created_at,
n.summary AS summary
""",
uuids=episode_uuids,
_database=DEFAULT_DATABASE,
)
async with driver.session(database=DEFAULT_DATABASE) as session:
result = await session.run(
"""
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
RETURN DISTINCT
n.uuid As uuid,
n.group_id AS group_id,
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
n.summary AS summary
""",
{'uuids': episode_uuids},
)
records = [record async for record in result]
nodes = [get_entity_node_from_record(record) for record in records]
@ -103,8 +104,9 @@ async def get_communities_by_nodes(
driver: AsyncDriver, nodes: list[EntityNode]
) -> list[CommunityNode]:
node_uuids = [node.uuid for node in nodes]
records, _, _ = await driver.execute_query(
"""
async with driver.session(database=DEFAULT_DATABASE) as session:
result = await session.run(
"""
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
RETURN DISTINCT
c.uuid As uuid,
@ -114,9 +116,9 @@ async def get_communities_by_nodes(
c.created_at AS created_at,
c.summary AS summary
""",
uuids=node_uuids,
_database=DEFAULT_DATABASE,
)
{'uuids': node_uuids},
)
records = [record async for record in result]
communities = [get_community_node_from_record(record) for record in records]
@ -156,15 +158,18 @@ async def edge_fulltext_search(
ORDER BY score DESC LIMIT $limit
""")
records, _, _ = await driver.execute_query(
cypher_query,
query=fuzzy_query,
source_uuid=source_node_uuid,
target_uuid=target_node_uuid,
group_ids=group_ids,
limit=limit,
_database=DEFAULT_DATABASE,
)
async with driver.session(database=DEFAULT_DATABASE) as session:
result = await session.run(
cypher_query,
{
'query': fuzzy_query,
'source_uuid': source_node_uuid,
'target_uuid': target_node_uuid,
'group_ids': group_ids,
'limit': limit,
},
)
records = [record async for record in result]
edges = [get_entity_edge_from_record(record) for record in records]
@ -206,16 +211,19 @@ async def edge_similarity_search(
LIMIT $limit
""")
records, _, _ = await driver.execute_query(
query,
search_vector=search_vector,
source_uuid=source_node_uuid,
target_uuid=target_node_uuid,
group_ids=group_ids,
limit=limit,
min_score=min_score,
_database=DEFAULT_DATABASE,
)
async with driver.session(database=DEFAULT_DATABASE) as session:
result = await session.run(
query,
{
'search_vector': search_vector,
'source_uuid': source_node_uuid,
'target_uuid': target_node_uuid,
'group_ids': group_ids,
'limit': limit,
'min_score': min_score,
},
)
records = [record async for record in result]
edges = [get_entity_edge_from_record(record) for record in records]
@ -233,8 +241,9 @@ async def node_fulltext_search(
if fuzzy_query == '':
return []
records, _, _ = await driver.execute_query(
"""
async with driver.session(database=DEFAULT_DATABASE) as session:
result = await session.run(
"""
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
YIELD node AS n, score
RETURN
@ -247,11 +256,13 @@ async def node_fulltext_search(
ORDER BY score DESC
LIMIT $limit
""",
query=fuzzy_query,
group_ids=group_ids,
limit=limit,
_database=DEFAULT_DATABASE,
)
{
'query': fuzzy_query,
'group_ids': group_ids,
'limit': limit,
},
)
records = [record async for record in result]
nodes = [get_entity_node_from_record(record) for record in records]
return nodes
@ -265,8 +276,9 @@ async def node_similarity_search(
min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityNode]:
# vector similarity search over entity names
records, _, _ = await driver.execute_query(
"""
async with driver.session(database=DEFAULT_DATABASE) as session:
result = await session.run(
"""
CYPHER runtime = parallel parallelRuntimeSupport=all
MATCH (n:Entity)
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
@ -282,12 +294,14 @@ async def node_similarity_search(
ORDER BY score DESC
LIMIT $limit
""",
search_vector=search_vector,
group_ids=group_ids,
limit=limit,
min_score=min_score,
_database=DEFAULT_DATABASE,
)
{
'search_vector': search_vector,
'group_ids': group_ids,
'limit': limit,
'min_score': min_score,
},
)
records = [record async for record in result]
nodes = [get_entity_node_from_record(record) for record in records]
return nodes
@ -304,8 +318,9 @@ async def community_fulltext_search(
if fuzzy_query == '':
return []
records, _, _ = await driver.execute_query(
"""
async with driver.session(database=DEFAULT_DATABASE) as session:
result = await session.run(
"""
CALL db.index.fulltext.queryNodes("community_name", $query)
YIELD node AS comm, score
RETURN
@ -318,11 +333,13 @@ async def community_fulltext_search(
ORDER BY score DESC
LIMIT $limit
""",
query=fuzzy_query,
group_ids=group_ids,
limit=limit,
_database=DEFAULT_DATABASE,
)
{
'query': fuzzy_query,
'group_ids': group_ids,
'limit': limit,
},
)
records = [record async for record in result]
communities = [get_community_node_from_record(record) for record in records]
return communities
@ -336,8 +353,9 @@ async def community_similarity_search(
min_score=DEFAULT_MIN_SCORE,
) -> list[CommunityNode]:
# vector similarity search over entity names
records, _, _ = await driver.execute_query(
"""
async with driver.session(database=DEFAULT_DATABASE) as session:
result = await session.run(
"""
CYPHER runtime = parallel parallelRuntimeSupport=all
MATCH (comm:Community)
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
@ -353,12 +371,14 @@ async def community_similarity_search(
ORDER BY score DESC
LIMIT $limit
""",
search_vector=search_vector,
group_ids=group_ids,
limit=limit,
min_score=min_score,
_database=DEFAULT_DATABASE,
)
{
'search_vector': search_vector,
'group_ids': group_ids,
'limit': limit,
'min_score': min_score,
},
)
records = [record async for record in result]
communities = [get_community_node_from_record(record) for record in records]
return communities
@ -549,7 +569,7 @@ async def node_distance_reranker(
query,
node_uuid=uuid,
center_uuid=center_node_uuid,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
for uuid in filtered_uuids
]
@ -586,7 +606,7 @@ async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[s
driver.execute_query(
query,
node_uuid=uuid,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
for uuid in sorted_uuids
]

View file

@ -40,7 +40,7 @@ async def get_community_clusters(
RETURN
collect(DISTINCT n.group_id) AS group_ids
""",
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
group_ids = group_id_values[0]['group_ids']
@ -59,7 +59,7 @@ async def get_community_clusters(
""",
uuid=node.uuid,
group_id=group_id,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
projection[node.uuid] = [
@ -223,7 +223,7 @@ async def remove_communities(driver: AsyncDriver):
MATCH (c:Community)
DETACH DELETE c
""",
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
@ -243,7 +243,7 @@ async def determine_entity_community(
c.summary AS summary
""",
entity_uuid=entity.uuid,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
if len(records) > 0:
@ -262,7 +262,7 @@ async def determine_entity_community(
c.summary AS summary
""",
entity_uuid=entity.uuid,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
communities: list[CommunityNode] = [

View file

@ -35,7 +35,7 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo
"""
SHOW INDEXES YIELD name
""",
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)
index_names = [record['name'] for record in records]
await asyncio.gather(