Use sessions for search (#197)

* use sessions for searches

* correct DB name

* fix typo
This commit is contained in:
Preston Rasmussen 2024-10-22 10:01:56 -04:00 committed by GitHub
parent 1290d0fecb
commit 50d2308c93
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 122 additions and 102 deletions

View file

@ -54,7 +54,7 @@ class Edge(BaseModel, ABC):
DELETE e DELETE e
""", """,
uuid=self.uuid, uuid=self.uuid,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
logger.debug(f'Deleted Edge: {self.uuid}') logger.debug(f'Deleted Edge: {self.uuid}')
@ -82,7 +82,7 @@ class EpisodicEdge(Edge):
uuid=self.uuid, uuid=self.uuid,
group_id=self.group_id, group_id=self.group_id,
created_at=self.created_at, created_at=self.created_at,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
logger.debug(f'Saved edge to neo4j: {self.uuid}') logger.debug(f'Saved edge to neo4j: {self.uuid}')
@ -102,7 +102,7 @@ class EpisodicEdge(Edge):
e.created_at AS created_at e.created_at AS created_at
""", """,
uuid=uuid, uuid=uuid,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
edges = [get_episodic_edge_from_record(record) for record in records] edges = [get_episodic_edge_from_record(record) for record in records]
@ -125,7 +125,7 @@ class EpisodicEdge(Edge):
e.created_at AS created_at e.created_at AS created_at
""", """,
uuids=uuids, uuids=uuids,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
edges = [get_episodic_edge_from_record(record) for record in records] edges = [get_episodic_edge_from_record(record) for record in records]
@ -148,7 +148,7 @@ class EpisodicEdge(Edge):
e.created_at AS created_at e.created_at AS created_at
""", """,
group_ids=group_ids, group_ids=group_ids,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
edges = [get_episodic_edge_from_record(record) for record in records] edges = [get_episodic_edge_from_record(record) for record in records]
@ -202,7 +202,7 @@ class EntityEdge(Edge):
expired_at=self.expired_at, expired_at=self.expired_at,
valid_at=self.valid_at, valid_at=self.valid_at,
invalid_at=self.invalid_at, invalid_at=self.invalid_at,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
logger.debug(f'Saved edge to neo4j: {self.uuid}') logger.debug(f'Saved edge to neo4j: {self.uuid}')
@ -229,7 +229,7 @@ class EntityEdge(Edge):
e.invalid_at AS invalid_at e.invalid_at AS invalid_at
""", """,
uuid=uuid, uuid=uuid,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
edges = [get_entity_edge_from_record(record) for record in records] edges = [get_entity_edge_from_record(record) for record in records]
@ -259,7 +259,7 @@ class EntityEdge(Edge):
e.invalid_at AS invalid_at e.invalid_at AS invalid_at
""", """,
uuids=uuids, uuids=uuids,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
edges = [get_entity_edge_from_record(record) for record in records] edges = [get_entity_edge_from_record(record) for record in records]
@ -289,7 +289,7 @@ class EntityEdge(Edge):
e.invalid_at AS invalid_at e.invalid_at AS invalid_at
""", """,
group_ids=group_ids, group_ids=group_ids,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
edges = [get_entity_edge_from_record(record) for record in records] edges = [get_entity_edge_from_record(record) for record in records]
@ -308,7 +308,7 @@ class CommunityEdge(Edge):
uuid=self.uuid, uuid=self.uuid,
group_id=self.group_id, group_id=self.group_id,
created_at=self.created_at, created_at=self.created_at,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
logger.debug(f'Saved edge to neo4j: {self.uuid}') logger.debug(f'Saved edge to neo4j: {self.uuid}')
@ -328,7 +328,7 @@ class CommunityEdge(Edge):
e.created_at AS created_at e.created_at AS created_at
""", """,
uuid=uuid, uuid=uuid,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
edges = [get_community_edge_from_record(record) for record in records] edges = [get_community_edge_from_record(record) for record in records]
@ -349,7 +349,7 @@ class CommunityEdge(Edge):
e.created_at AS created_at e.created_at AS created_at
""", """,
uuids=uuids, uuids=uuids,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
edges = [get_community_edge_from_record(record) for record in records] edges = [get_community_edge_from_record(record) for record in records]
@ -370,7 +370,7 @@ class CommunityEdge(Edge):
e.created_at AS created_at e.created_at AS created_at
""", """,
group_ids=group_ids, group_ids=group_ids,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
edges = [get_community_edge_from_record(record) for record in records] edges = [get_community_edge_from_record(record) for record in records]

View file

@ -90,7 +90,7 @@ class Node(BaseModel, ABC):
DETACH DELETE n DETACH DELETE n
""", """,
uuid=self.uuid, uuid=self.uuid,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
logger.debug(f'Deleted Node: {self.uuid}') logger.debug(f'Deleted Node: {self.uuid}')
@ -136,7 +136,7 @@ class EpisodicNode(Node):
created_at=self.created_at, created_at=self.created_at,
valid_at=self.valid_at, valid_at=self.valid_at,
source=self.source.value, source=self.source.value,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
logger.debug(f'Saved Node to neo4j: {self.uuid}') logger.debug(f'Saved Node to neo4j: {self.uuid}')
@ -158,7 +158,7 @@ class EpisodicNode(Node):
e.source AS source e.source AS source
""", """,
uuid=uuid, uuid=uuid,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
episodes = [get_episodic_node_from_record(record) for record in records] episodes = [get_episodic_node_from_record(record) for record in records]
@ -184,7 +184,7 @@ class EpisodicNode(Node):
e.source AS source e.source AS source
""", """,
uuids=uuids, uuids=uuids,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
episodes = [get_episodic_node_from_record(record) for record in records] episodes = [get_episodic_node_from_record(record) for record in records]
@ -207,7 +207,7 @@ class EpisodicNode(Node):
e.source AS source e.source AS source
""", """,
group_ids=group_ids, group_ids=group_ids,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
episodes = [get_episodic_node_from_record(record) for record in records] episodes = [get_episodic_node_from_record(record) for record in records]
@ -237,7 +237,7 @@ class EntityNode(Node):
summary=self.summary, summary=self.summary,
name_embedding=self.name_embedding, name_embedding=self.name_embedding,
created_at=self.created_at, created_at=self.created_at,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
logger.debug(f'Saved Node to neo4j: {self.uuid}') logger.debug(f'Saved Node to neo4j: {self.uuid}')
@ -258,7 +258,7 @@ class EntityNode(Node):
n.summary AS summary n.summary AS summary
""", """,
uuid=uuid, uuid=uuid,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
nodes = [get_entity_node_from_record(record) for record in records] nodes = [get_entity_node_from_record(record) for record in records]
@ -282,7 +282,7 @@ class EntityNode(Node):
n.summary AS summary n.summary AS summary
""", """,
uuids=uuids, uuids=uuids,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
nodes = [get_entity_node_from_record(record) for record in records] nodes = [get_entity_node_from_record(record) for record in records]
@ -303,7 +303,7 @@ class EntityNode(Node):
n.summary AS summary n.summary AS summary
""", """,
group_ids=group_ids, group_ids=group_ids,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
nodes = [get_entity_node_from_record(record) for record in records] nodes = [get_entity_node_from_record(record) for record in records]
@ -324,7 +324,7 @@ class CommunityNode(Node):
summary=self.summary, summary=self.summary,
name_embedding=self.name_embedding, name_embedding=self.name_embedding,
created_at=self.created_at, created_at=self.created_at,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
logger.debug(f'Saved Node to neo4j: {self.uuid}') logger.debug(f'Saved Node to neo4j: {self.uuid}')
@ -354,7 +354,7 @@ class CommunityNode(Node):
n.summary AS summary n.summary AS summary
""", """,
uuid=uuid, uuid=uuid,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
nodes = [get_community_node_from_record(record) for record in records] nodes = [get_community_node_from_record(record) for record in records]
@ -378,7 +378,7 @@ class CommunityNode(Node):
n.summary AS summary n.summary AS summary
""", """,
uuids=uuids, uuids=uuids,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
communities = [get_community_node_from_record(record) for record in records] communities = [get_community_node_from_record(record) for record in records]
@ -399,7 +399,7 @@ class CommunityNode(Node):
n.summary AS summary n.summary AS summary
""", """,
group_ids=group_ids, group_ids=group_ids,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
communities = [get_community_node_from_record(record) for record in records] communities = [get_community_node_from_record(record) for record in records]

View file

@ -79,20 +79,21 @@ async def get_mentioned_nodes(
driver: AsyncDriver, episodes: list[EpisodicNode] driver: AsyncDriver, episodes: list[EpisodicNode]
) -> list[EntityNode]: ) -> list[EntityNode]:
episode_uuids = [episode.uuid for episode in episodes] episode_uuids = [episode.uuid for episode in episodes]
records, _, _ = await driver.execute_query( async with driver.session(database=DEFAULT_DATABASE) as session:
""" result = await session.run(
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids """
RETURN DISTINCT MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
n.uuid As uuid, RETURN DISTINCT
n.group_id AS group_id, n.uuid As uuid,
n.name AS name, n.group_id AS group_id,
n.name_embedding AS name_embedding n.name AS name,
n.created_at AS created_at, n.name_embedding AS name_embedding,
n.summary AS summary n.created_at AS created_at,
""", n.summary AS summary
uuids=episode_uuids, """,
_database=DEFAULT_DATABASE, {'uuids': episode_uuids},
) )
records = [record async for record in result]
nodes = [get_entity_node_from_record(record) for record in records] nodes = [get_entity_node_from_record(record) for record in records]
@ -103,8 +104,9 @@ async def get_communities_by_nodes(
driver: AsyncDriver, nodes: list[EntityNode] driver: AsyncDriver, nodes: list[EntityNode]
) -> list[CommunityNode]: ) -> list[CommunityNode]:
node_uuids = [node.uuid for node in nodes] node_uuids = [node.uuid for node in nodes]
records, _, _ = await driver.execute_query( async with driver.session(database=DEFAULT_DATABASE) as session:
""" result = await session.run(
"""
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
RETURN DISTINCT RETURN DISTINCT
c.uuid As uuid, c.uuid As uuid,
@ -114,9 +116,9 @@ async def get_communities_by_nodes(
c.created_at AS created_at, c.created_at AS created_at,
c.summary AS summary c.summary AS summary
""", """,
uuids=node_uuids, {'uuids': node_uuids},
_database=DEFAULT_DATABASE, )
) records = [record async for record in result]
communities = [get_community_node_from_record(record) for record in records] communities = [get_community_node_from_record(record) for record in records]
@ -156,15 +158,18 @@ async def edge_fulltext_search(
ORDER BY score DESC LIMIT $limit ORDER BY score DESC LIMIT $limit
""") """)
records, _, _ = await driver.execute_query( async with driver.session(database=DEFAULT_DATABASE) as session:
cypher_query, result = await session.run(
query=fuzzy_query, cypher_query,
source_uuid=source_node_uuid, {
target_uuid=target_node_uuid, 'query': fuzzy_query,
group_ids=group_ids, 'source_uuid': source_node_uuid,
limit=limit, 'target_uuid': target_node_uuid,
_database=DEFAULT_DATABASE, 'group_ids': group_ids,
) 'limit': limit,
},
)
records = [record async for record in result]
edges = [get_entity_edge_from_record(record) for record in records] edges = [get_entity_edge_from_record(record) for record in records]
@ -206,16 +211,19 @@ async def edge_similarity_search(
LIMIT $limit LIMIT $limit
""") """)
records, _, _ = await driver.execute_query( async with driver.session(database=DEFAULT_DATABASE) as session:
query, result = await session.run(
search_vector=search_vector, query,
source_uuid=source_node_uuid, {
target_uuid=target_node_uuid, 'search_vector': search_vector,
group_ids=group_ids, 'source_uuid': source_node_uuid,
limit=limit, 'target_uuid': target_node_uuid,
min_score=min_score, 'group_ids': group_ids,
_database=DEFAULT_DATABASE, 'limit': limit,
) 'min_score': min_score,
},
)
records = [record async for record in result]
edges = [get_entity_edge_from_record(record) for record in records] edges = [get_entity_edge_from_record(record) for record in records]
@ -233,8 +241,9 @@ async def node_fulltext_search(
if fuzzy_query == '': if fuzzy_query == '':
return [] return []
records, _, _ = await driver.execute_query( async with driver.session(database=DEFAULT_DATABASE) as session:
""" result = await session.run(
"""
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query) CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
YIELD node AS n, score YIELD node AS n, score
RETURN RETURN
@ -247,11 +256,13 @@ async def node_fulltext_search(
ORDER BY score DESC ORDER BY score DESC
LIMIT $limit LIMIT $limit
""", """,
query=fuzzy_query, {
group_ids=group_ids, 'query': fuzzy_query,
limit=limit, 'group_ids': group_ids,
_database=DEFAULT_DATABASE, 'limit': limit,
) },
)
records = [record async for record in result]
nodes = [get_entity_node_from_record(record) for record in records] nodes = [get_entity_node_from_record(record) for record in records]
return nodes return nodes
@ -265,8 +276,9 @@ async def node_similarity_search(
min_score: float = DEFAULT_MIN_SCORE, min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityNode]: ) -> list[EntityNode]:
# vector similarity search over entity names # vector similarity search over entity names
records, _, _ = await driver.execute_query( async with driver.session(database=DEFAULT_DATABASE) as session:
""" result = await session.run(
"""
CYPHER runtime = parallel parallelRuntimeSupport=all CYPHER runtime = parallel parallelRuntimeSupport=all
MATCH (n:Entity) MATCH (n:Entity)
WHERE $group_ids IS NULL OR n.group_id IN $group_ids WHERE $group_ids IS NULL OR n.group_id IN $group_ids
@ -282,12 +294,14 @@ async def node_similarity_search(
ORDER BY score DESC ORDER BY score DESC
LIMIT $limit LIMIT $limit
""", """,
search_vector=search_vector, {
group_ids=group_ids, 'search_vector': search_vector,
limit=limit, 'group_ids': group_ids,
min_score=min_score, 'limit': limit,
_database=DEFAULT_DATABASE, 'min_score': min_score,
) },
)
records = [record async for record in result]
nodes = [get_entity_node_from_record(record) for record in records] nodes = [get_entity_node_from_record(record) for record in records]
return nodes return nodes
@ -304,8 +318,9 @@ async def community_fulltext_search(
if fuzzy_query == '': if fuzzy_query == '':
return [] return []
records, _, _ = await driver.execute_query( async with driver.session(database=DEFAULT_DATABASE) as session:
""" result = await session.run(
"""
CALL db.index.fulltext.queryNodes("community_name", $query) CALL db.index.fulltext.queryNodes("community_name", $query)
YIELD node AS comm, score YIELD node AS comm, score
RETURN RETURN
@ -318,11 +333,13 @@ async def community_fulltext_search(
ORDER BY score DESC ORDER BY score DESC
LIMIT $limit LIMIT $limit
""", """,
query=fuzzy_query, {
group_ids=group_ids, 'query': fuzzy_query,
limit=limit, 'group_ids': group_ids,
_database=DEFAULT_DATABASE, 'limit': limit,
) },
)
records = [record async for record in result]
communities = [get_community_node_from_record(record) for record in records] communities = [get_community_node_from_record(record) for record in records]
return communities return communities
@ -336,8 +353,9 @@ async def community_similarity_search(
min_score=DEFAULT_MIN_SCORE, min_score=DEFAULT_MIN_SCORE,
) -> list[CommunityNode]: ) -> list[CommunityNode]:
# vector similarity search over entity names # vector similarity search over entity names
records, _, _ = await driver.execute_query( async with driver.session(database=DEFAULT_DATABASE) as session:
""" result = await session.run(
"""
CYPHER runtime = parallel parallelRuntimeSupport=all CYPHER runtime = parallel parallelRuntimeSupport=all
MATCH (comm:Community) MATCH (comm:Community)
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids) WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
@ -353,12 +371,14 @@ async def community_similarity_search(
ORDER BY score DESC ORDER BY score DESC
LIMIT $limit LIMIT $limit
""", """,
search_vector=search_vector, {
group_ids=group_ids, 'search_vector': search_vector,
limit=limit, 'group_ids': group_ids,
min_score=min_score, 'limit': limit,
_database=DEFAULT_DATABASE, 'min_score': min_score,
) },
)
records = [record async for record in result]
communities = [get_community_node_from_record(record) for record in records] communities = [get_community_node_from_record(record) for record in records]
return communities return communities
@ -549,7 +569,7 @@ async def node_distance_reranker(
query, query,
node_uuid=uuid, node_uuid=uuid,
center_uuid=center_node_uuid, center_uuid=center_node_uuid,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
for uuid in filtered_uuids for uuid in filtered_uuids
] ]
@ -586,7 +606,7 @@ async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[s
driver.execute_query( driver.execute_query(
query, query,
node_uuid=uuid, node_uuid=uuid,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
for uuid in sorted_uuids for uuid in sorted_uuids
] ]

View file

@ -40,7 +40,7 @@ async def get_community_clusters(
RETURN RETURN
collect(DISTINCT n.group_id) AS group_ids collect(DISTINCT n.group_id) AS group_ids
""", """,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
group_ids = group_id_values[0]['group_ids'] group_ids = group_id_values[0]['group_ids']
@ -59,7 +59,7 @@ async def get_community_clusters(
""", """,
uuid=node.uuid, uuid=node.uuid,
group_id=group_id, group_id=group_id,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
projection[node.uuid] = [ projection[node.uuid] = [
@ -223,7 +223,7 @@ async def remove_communities(driver: AsyncDriver):
MATCH (c:Community) MATCH (c:Community)
DETACH DELETE c DETACH DELETE c
""", """,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
@ -243,7 +243,7 @@ async def determine_entity_community(
c.summary AS summary c.summary AS summary
""", """,
entity_uuid=entity.uuid, entity_uuid=entity.uuid,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
if len(records) > 0: if len(records) > 0:
@ -262,7 +262,7 @@ async def determine_entity_community(
c.summary AS summary c.summary AS summary
""", """,
entity_uuid=entity.uuid, entity_uuid=entity.uuid,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
communities: list[CommunityNode] = [ communities: list[CommunityNode] = [

View file

@ -35,7 +35,7 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo
""" """
SHOW INDEXES YIELD name SHOW INDEXES YIELD name
""", """,
_database=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
index_names = [record['name'] for record in records] index_names = [record['name'] for record in records]
await asyncio.gather( await asyncio.gather(