Use sessions search (#197)
* use sessions for searches * correct DB name * fix typo
This commit is contained in:
parent
1290d0fecb
commit
50d2308c93
5 changed files with 122 additions and 102 deletions
|
|
@ -54,7 +54,7 @@ class Edge(BaseModel, ABC):
|
|||
DELETE e
|
||||
""",
|
||||
uuid=self.uuid,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
logger.debug(f'Deleted Edge: {self.uuid}')
|
||||
|
|
@ -82,7 +82,7 @@ class EpisodicEdge(Edge):
|
|||
uuid=self.uuid,
|
||||
group_id=self.group_id,
|
||||
created_at=self.created_at,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
logger.debug(f'Saved edge to neo4j: {self.uuid}')
|
||||
|
|
@ -102,7 +102,7 @@ class EpisodicEdge(Edge):
|
|||
e.created_at AS created_at
|
||||
""",
|
||||
uuid=uuid,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
edges = [get_episodic_edge_from_record(record) for record in records]
|
||||
|
|
@ -125,7 +125,7 @@ class EpisodicEdge(Edge):
|
|||
e.created_at AS created_at
|
||||
""",
|
||||
uuids=uuids,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
edges = [get_episodic_edge_from_record(record) for record in records]
|
||||
|
|
@ -148,7 +148,7 @@ class EpisodicEdge(Edge):
|
|||
e.created_at AS created_at
|
||||
""",
|
||||
group_ids=group_ids,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
edges = [get_episodic_edge_from_record(record) for record in records]
|
||||
|
|
@ -202,7 +202,7 @@ class EntityEdge(Edge):
|
|||
expired_at=self.expired_at,
|
||||
valid_at=self.valid_at,
|
||||
invalid_at=self.invalid_at,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
logger.debug(f'Saved edge to neo4j: {self.uuid}')
|
||||
|
|
@ -229,7 +229,7 @@ class EntityEdge(Edge):
|
|||
e.invalid_at AS invalid_at
|
||||
""",
|
||||
uuid=uuid,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
edges = [get_entity_edge_from_record(record) for record in records]
|
||||
|
|
@ -259,7 +259,7 @@ class EntityEdge(Edge):
|
|||
e.invalid_at AS invalid_at
|
||||
""",
|
||||
uuids=uuids,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
edges = [get_entity_edge_from_record(record) for record in records]
|
||||
|
|
@ -289,7 +289,7 @@ class EntityEdge(Edge):
|
|||
e.invalid_at AS invalid_at
|
||||
""",
|
||||
group_ids=group_ids,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
edges = [get_entity_edge_from_record(record) for record in records]
|
||||
|
|
@ -308,7 +308,7 @@ class CommunityEdge(Edge):
|
|||
uuid=self.uuid,
|
||||
group_id=self.group_id,
|
||||
created_at=self.created_at,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
logger.debug(f'Saved edge to neo4j: {self.uuid}')
|
||||
|
|
@ -328,7 +328,7 @@ class CommunityEdge(Edge):
|
|||
e.created_at AS created_at
|
||||
""",
|
||||
uuid=uuid,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
edges = [get_community_edge_from_record(record) for record in records]
|
||||
|
|
@ -349,7 +349,7 @@ class CommunityEdge(Edge):
|
|||
e.created_at AS created_at
|
||||
""",
|
||||
uuids=uuids,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
edges = [get_community_edge_from_record(record) for record in records]
|
||||
|
|
@ -370,7 +370,7 @@ class CommunityEdge(Edge):
|
|||
e.created_at AS created_at
|
||||
""",
|
||||
group_ids=group_ids,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
edges = [get_community_edge_from_record(record) for record in records]
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ class Node(BaseModel, ABC):
|
|||
DETACH DELETE n
|
||||
""",
|
||||
uuid=self.uuid,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
logger.debug(f'Deleted Node: {self.uuid}')
|
||||
|
|
@ -136,7 +136,7 @@ class EpisodicNode(Node):
|
|||
created_at=self.created_at,
|
||||
valid_at=self.valid_at,
|
||||
source=self.source.value,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
logger.debug(f'Saved Node to neo4j: {self.uuid}')
|
||||
|
|
@ -158,7 +158,7 @@ class EpisodicNode(Node):
|
|||
e.source AS source
|
||||
""",
|
||||
uuid=uuid,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
episodes = [get_episodic_node_from_record(record) for record in records]
|
||||
|
|
@ -184,7 +184,7 @@ class EpisodicNode(Node):
|
|||
e.source AS source
|
||||
""",
|
||||
uuids=uuids,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
episodes = [get_episodic_node_from_record(record) for record in records]
|
||||
|
|
@ -207,7 +207,7 @@ class EpisodicNode(Node):
|
|||
e.source AS source
|
||||
""",
|
||||
group_ids=group_ids,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
episodes = [get_episodic_node_from_record(record) for record in records]
|
||||
|
|
@ -237,7 +237,7 @@ class EntityNode(Node):
|
|||
summary=self.summary,
|
||||
name_embedding=self.name_embedding,
|
||||
created_at=self.created_at,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
logger.debug(f'Saved Node to neo4j: {self.uuid}')
|
||||
|
|
@ -258,7 +258,7 @@ class EntityNode(Node):
|
|||
n.summary AS summary
|
||||
""",
|
||||
uuid=uuid,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
nodes = [get_entity_node_from_record(record) for record in records]
|
||||
|
|
@ -282,7 +282,7 @@ class EntityNode(Node):
|
|||
n.summary AS summary
|
||||
""",
|
||||
uuids=uuids,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
nodes = [get_entity_node_from_record(record) for record in records]
|
||||
|
|
@ -303,7 +303,7 @@ class EntityNode(Node):
|
|||
n.summary AS summary
|
||||
""",
|
||||
group_ids=group_ids,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
nodes = [get_entity_node_from_record(record) for record in records]
|
||||
|
|
@ -324,7 +324,7 @@ class CommunityNode(Node):
|
|||
summary=self.summary,
|
||||
name_embedding=self.name_embedding,
|
||||
created_at=self.created_at,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
logger.debug(f'Saved Node to neo4j: {self.uuid}')
|
||||
|
|
@ -354,7 +354,7 @@ class CommunityNode(Node):
|
|||
n.summary AS summary
|
||||
""",
|
||||
uuid=uuid,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
nodes = [get_community_node_from_record(record) for record in records]
|
||||
|
|
@ -378,7 +378,7 @@ class CommunityNode(Node):
|
|||
n.summary AS summary
|
||||
""",
|
||||
uuids=uuids,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
communities = [get_community_node_from_record(record) for record in records]
|
||||
|
|
@ -399,7 +399,7 @@ class CommunityNode(Node):
|
|||
n.summary AS summary
|
||||
""",
|
||||
group_ids=group_ids,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
communities = [get_community_node_from_record(record) for record in records]
|
||||
|
|
|
|||
|
|
@ -79,20 +79,21 @@ async def get_mentioned_nodes(
|
|||
driver: AsyncDriver, episodes: list[EpisodicNode]
|
||||
) -> list[EntityNode]:
|
||||
episode_uuids = [episode.uuid for episode in episodes]
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
|
||||
RETURN DISTINCT
|
||||
n.uuid As uuid,
|
||||
n.group_id AS group_id,
|
||||
n.name AS name,
|
||||
n.name_embedding AS name_embedding
|
||||
n.created_at AS created_at,
|
||||
n.summary AS summary
|
||||
""",
|
||||
uuids=episode_uuids,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
async with driver.session(database=DEFAULT_DATABASE) as session:
|
||||
result = await session.run(
|
||||
"""
|
||||
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
|
||||
RETURN DISTINCT
|
||||
n.uuid As uuid,
|
||||
n.group_id AS group_id,
|
||||
n.name AS name,
|
||||
n.name_embedding AS name_embedding,
|
||||
n.created_at AS created_at,
|
||||
n.summary AS summary
|
||||
""",
|
||||
{'uuids': episode_uuids},
|
||||
)
|
||||
records = [record async for record in result]
|
||||
|
||||
nodes = [get_entity_node_from_record(record) for record in records]
|
||||
|
||||
|
|
@ -103,8 +104,9 @@ async def get_communities_by_nodes(
|
|||
driver: AsyncDriver, nodes: list[EntityNode]
|
||||
) -> list[CommunityNode]:
|
||||
node_uuids = [node.uuid for node in nodes]
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
async with driver.session(database=DEFAULT_DATABASE) as session:
|
||||
result = await session.run(
|
||||
"""
|
||||
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
|
||||
RETURN DISTINCT
|
||||
c.uuid As uuid,
|
||||
|
|
@ -114,9 +116,9 @@ async def get_communities_by_nodes(
|
|||
c.created_at AS created_at,
|
||||
c.summary AS summary
|
||||
""",
|
||||
uuids=node_uuids,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
{'uuids': node_uuids},
|
||||
)
|
||||
records = [record async for record in result]
|
||||
|
||||
communities = [get_community_node_from_record(record) for record in records]
|
||||
|
||||
|
|
@ -156,15 +158,18 @@ async def edge_fulltext_search(
|
|||
ORDER BY score DESC LIMIT $limit
|
||||
""")
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
cypher_query,
|
||||
query=fuzzy_query,
|
||||
source_uuid=source_node_uuid,
|
||||
target_uuid=target_node_uuid,
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
async with driver.session(database=DEFAULT_DATABASE) as session:
|
||||
result = await session.run(
|
||||
cypher_query,
|
||||
{
|
||||
'query': fuzzy_query,
|
||||
'source_uuid': source_node_uuid,
|
||||
'target_uuid': target_node_uuid,
|
||||
'group_ids': group_ids,
|
||||
'limit': limit,
|
||||
},
|
||||
)
|
||||
records = [record async for record in result]
|
||||
|
||||
edges = [get_entity_edge_from_record(record) for record in records]
|
||||
|
||||
|
|
@ -206,16 +211,19 @@ async def edge_similarity_search(
|
|||
LIMIT $limit
|
||||
""")
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
query,
|
||||
search_vector=search_vector,
|
||||
source_uuid=source_node_uuid,
|
||||
target_uuid=target_node_uuid,
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
min_score=min_score,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
async with driver.session(database=DEFAULT_DATABASE) as session:
|
||||
result = await session.run(
|
||||
query,
|
||||
{
|
||||
'search_vector': search_vector,
|
||||
'source_uuid': source_node_uuid,
|
||||
'target_uuid': target_node_uuid,
|
||||
'group_ids': group_ids,
|
||||
'limit': limit,
|
||||
'min_score': min_score,
|
||||
},
|
||||
)
|
||||
records = [record async for record in result]
|
||||
|
||||
edges = [get_entity_edge_from_record(record) for record in records]
|
||||
|
||||
|
|
@ -233,8 +241,9 @@ async def node_fulltext_search(
|
|||
if fuzzy_query == '':
|
||||
return []
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
async with driver.session(database=DEFAULT_DATABASE) as session:
|
||||
result = await session.run(
|
||||
"""
|
||||
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
|
||||
YIELD node AS n, score
|
||||
RETURN
|
||||
|
|
@ -247,11 +256,13 @@ async def node_fulltext_search(
|
|||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
""",
|
||||
query=fuzzy_query,
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
{
|
||||
'query': fuzzy_query,
|
||||
'group_ids': group_ids,
|
||||
'limit': limit,
|
||||
},
|
||||
)
|
||||
records = [record async for record in result]
|
||||
nodes = [get_entity_node_from_record(record) for record in records]
|
||||
|
||||
return nodes
|
||||
|
|
@ -265,8 +276,9 @@ async def node_similarity_search(
|
|||
min_score: float = DEFAULT_MIN_SCORE,
|
||||
) -> list[EntityNode]:
|
||||
# vector similarity search over entity names
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
async with driver.session(database=DEFAULT_DATABASE) as session:
|
||||
result = await session.run(
|
||||
"""
|
||||
CYPHER runtime = parallel parallelRuntimeSupport=all
|
||||
MATCH (n:Entity)
|
||||
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
|
||||
|
|
@ -282,12 +294,14 @@ async def node_similarity_search(
|
|||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
""",
|
||||
search_vector=search_vector,
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
min_score=min_score,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
{
|
||||
'search_vector': search_vector,
|
||||
'group_ids': group_ids,
|
||||
'limit': limit,
|
||||
'min_score': min_score,
|
||||
},
|
||||
)
|
||||
records = [record async for record in result]
|
||||
nodes = [get_entity_node_from_record(record) for record in records]
|
||||
|
||||
return nodes
|
||||
|
|
@ -304,8 +318,9 @@ async def community_fulltext_search(
|
|||
if fuzzy_query == '':
|
||||
return []
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
async with driver.session(database=DEFAULT_DATABASE) as session:
|
||||
result = await session.run(
|
||||
"""
|
||||
CALL db.index.fulltext.queryNodes("community_name", $query)
|
||||
YIELD node AS comm, score
|
||||
RETURN
|
||||
|
|
@ -318,11 +333,13 @@ async def community_fulltext_search(
|
|||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
""",
|
||||
query=fuzzy_query,
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
{
|
||||
'query': fuzzy_query,
|
||||
'group_ids': group_ids,
|
||||
'limit': limit,
|
||||
},
|
||||
)
|
||||
records = [record async for record in result]
|
||||
communities = [get_community_node_from_record(record) for record in records]
|
||||
|
||||
return communities
|
||||
|
|
@ -336,8 +353,9 @@ async def community_similarity_search(
|
|||
min_score=DEFAULT_MIN_SCORE,
|
||||
) -> list[CommunityNode]:
|
||||
# vector similarity search over entity names
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
async with driver.session(database=DEFAULT_DATABASE) as session:
|
||||
result = await session.run(
|
||||
"""
|
||||
CYPHER runtime = parallel parallelRuntimeSupport=all
|
||||
MATCH (comm:Community)
|
||||
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
|
||||
|
|
@ -353,12 +371,14 @@ async def community_similarity_search(
|
|||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
""",
|
||||
search_vector=search_vector,
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
min_score=min_score,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
{
|
||||
'search_vector': search_vector,
|
||||
'group_ids': group_ids,
|
||||
'limit': limit,
|
||||
'min_score': min_score,
|
||||
},
|
||||
)
|
||||
records = [record async for record in result]
|
||||
communities = [get_community_node_from_record(record) for record in records]
|
||||
|
||||
return communities
|
||||
|
|
@ -549,7 +569,7 @@ async def node_distance_reranker(
|
|||
query,
|
||||
node_uuid=uuid,
|
||||
center_uuid=center_node_uuid,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
for uuid in filtered_uuids
|
||||
]
|
||||
|
|
@ -586,7 +606,7 @@ async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[s
|
|||
driver.execute_query(
|
||||
query,
|
||||
node_uuid=uuid,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
for uuid in sorted_uuids
|
||||
]
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ async def get_community_clusters(
|
|||
RETURN
|
||||
collect(DISTINCT n.group_id) AS group_ids
|
||||
""",
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
group_ids = group_id_values[0]['group_ids']
|
||||
|
|
@ -59,7 +59,7 @@ async def get_community_clusters(
|
|||
""",
|
||||
uuid=node.uuid,
|
||||
group_id=group_id,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
projection[node.uuid] = [
|
||||
|
|
@ -223,7 +223,7 @@ async def remove_communities(driver: AsyncDriver):
|
|||
MATCH (c:Community)
|
||||
DETACH DELETE c
|
||||
""",
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -243,7 +243,7 @@ async def determine_entity_community(
|
|||
c.summary AS summary
|
||||
""",
|
||||
entity_uuid=entity.uuid,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
if len(records) > 0:
|
||||
|
|
@ -262,7 +262,7 @@ async def determine_entity_community(
|
|||
c.summary AS summary
|
||||
""",
|
||||
entity_uuid=entity.uuid,
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
communities: list[CommunityNode] = [
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo
|
|||
"""
|
||||
SHOW INDEXES YIELD name
|
||||
""",
|
||||
_database=DEFAULT_DATABASE,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
index_names = [record['name'] for record in records]
|
||||
await asyncio.gather(
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue