Use sessions search (#197)
* use sessions for searches * correct DB name * fix typo
This commit is contained in:
parent
1290d0fecb
commit
50d2308c93
5 changed files with 122 additions and 102 deletions
|
|
@ -54,7 +54,7 @@ class Edge(BaseModel, ABC):
|
||||||
DELETE e
|
DELETE e
|
||||||
""",
|
""",
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f'Deleted Edge: {self.uuid}')
|
logger.debug(f'Deleted Edge: {self.uuid}')
|
||||||
|
|
@ -82,7 +82,7 @@ class EpisodicEdge(Edge):
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
group_id=self.group_id,
|
group_id=self.group_id,
|
||||||
created_at=self.created_at,
|
created_at=self.created_at,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f'Saved edge to neo4j: {self.uuid}')
|
logger.debug(f'Saved edge to neo4j: {self.uuid}')
|
||||||
|
|
@ -102,7 +102,7 @@ class EpisodicEdge(Edge):
|
||||||
e.created_at AS created_at
|
e.created_at AS created_at
|
||||||
""",
|
""",
|
||||||
uuid=uuid,
|
uuid=uuid,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
edges = [get_episodic_edge_from_record(record) for record in records]
|
edges = [get_episodic_edge_from_record(record) for record in records]
|
||||||
|
|
@ -125,7 +125,7 @@ class EpisodicEdge(Edge):
|
||||||
e.created_at AS created_at
|
e.created_at AS created_at
|
||||||
""",
|
""",
|
||||||
uuids=uuids,
|
uuids=uuids,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
edges = [get_episodic_edge_from_record(record) for record in records]
|
edges = [get_episodic_edge_from_record(record) for record in records]
|
||||||
|
|
@ -148,7 +148,7 @@ class EpisodicEdge(Edge):
|
||||||
e.created_at AS created_at
|
e.created_at AS created_at
|
||||||
""",
|
""",
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
edges = [get_episodic_edge_from_record(record) for record in records]
|
edges = [get_episodic_edge_from_record(record) for record in records]
|
||||||
|
|
@ -202,7 +202,7 @@ class EntityEdge(Edge):
|
||||||
expired_at=self.expired_at,
|
expired_at=self.expired_at,
|
||||||
valid_at=self.valid_at,
|
valid_at=self.valid_at,
|
||||||
invalid_at=self.invalid_at,
|
invalid_at=self.invalid_at,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f'Saved edge to neo4j: {self.uuid}')
|
logger.debug(f'Saved edge to neo4j: {self.uuid}')
|
||||||
|
|
@ -229,7 +229,7 @@ class EntityEdge(Edge):
|
||||||
e.invalid_at AS invalid_at
|
e.invalid_at AS invalid_at
|
||||||
""",
|
""",
|
||||||
uuid=uuid,
|
uuid=uuid,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
edges = [get_entity_edge_from_record(record) for record in records]
|
edges = [get_entity_edge_from_record(record) for record in records]
|
||||||
|
|
@ -259,7 +259,7 @@ class EntityEdge(Edge):
|
||||||
e.invalid_at AS invalid_at
|
e.invalid_at AS invalid_at
|
||||||
""",
|
""",
|
||||||
uuids=uuids,
|
uuids=uuids,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
edges = [get_entity_edge_from_record(record) for record in records]
|
edges = [get_entity_edge_from_record(record) for record in records]
|
||||||
|
|
@ -289,7 +289,7 @@ class EntityEdge(Edge):
|
||||||
e.invalid_at AS invalid_at
|
e.invalid_at AS invalid_at
|
||||||
""",
|
""",
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
edges = [get_entity_edge_from_record(record) for record in records]
|
edges = [get_entity_edge_from_record(record) for record in records]
|
||||||
|
|
@ -308,7 +308,7 @@ class CommunityEdge(Edge):
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
group_id=self.group_id,
|
group_id=self.group_id,
|
||||||
created_at=self.created_at,
|
created_at=self.created_at,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f'Saved edge to neo4j: {self.uuid}')
|
logger.debug(f'Saved edge to neo4j: {self.uuid}')
|
||||||
|
|
@ -328,7 +328,7 @@ class CommunityEdge(Edge):
|
||||||
e.created_at AS created_at
|
e.created_at AS created_at
|
||||||
""",
|
""",
|
||||||
uuid=uuid,
|
uuid=uuid,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
edges = [get_community_edge_from_record(record) for record in records]
|
edges = [get_community_edge_from_record(record) for record in records]
|
||||||
|
|
@ -349,7 +349,7 @@ class CommunityEdge(Edge):
|
||||||
e.created_at AS created_at
|
e.created_at AS created_at
|
||||||
""",
|
""",
|
||||||
uuids=uuids,
|
uuids=uuids,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
edges = [get_community_edge_from_record(record) for record in records]
|
edges = [get_community_edge_from_record(record) for record in records]
|
||||||
|
|
@ -370,7 +370,7 @@ class CommunityEdge(Edge):
|
||||||
e.created_at AS created_at
|
e.created_at AS created_at
|
||||||
""",
|
""",
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
edges = [get_community_edge_from_record(record) for record in records]
|
edges = [get_community_edge_from_record(record) for record in records]
|
||||||
|
|
|
||||||
|
|
@ -90,7 +90,7 @@ class Node(BaseModel, ABC):
|
||||||
DETACH DELETE n
|
DETACH DELETE n
|
||||||
""",
|
""",
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f'Deleted Node: {self.uuid}')
|
logger.debug(f'Deleted Node: {self.uuid}')
|
||||||
|
|
@ -136,7 +136,7 @@ class EpisodicNode(Node):
|
||||||
created_at=self.created_at,
|
created_at=self.created_at,
|
||||||
valid_at=self.valid_at,
|
valid_at=self.valid_at,
|
||||||
source=self.source.value,
|
source=self.source.value,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f'Saved Node to neo4j: {self.uuid}')
|
logger.debug(f'Saved Node to neo4j: {self.uuid}')
|
||||||
|
|
@ -158,7 +158,7 @@ class EpisodicNode(Node):
|
||||||
e.source AS source
|
e.source AS source
|
||||||
""",
|
""",
|
||||||
uuid=uuid,
|
uuid=uuid,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
episodes = [get_episodic_node_from_record(record) for record in records]
|
episodes = [get_episodic_node_from_record(record) for record in records]
|
||||||
|
|
@ -184,7 +184,7 @@ class EpisodicNode(Node):
|
||||||
e.source AS source
|
e.source AS source
|
||||||
""",
|
""",
|
||||||
uuids=uuids,
|
uuids=uuids,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
episodes = [get_episodic_node_from_record(record) for record in records]
|
episodes = [get_episodic_node_from_record(record) for record in records]
|
||||||
|
|
@ -207,7 +207,7 @@ class EpisodicNode(Node):
|
||||||
e.source AS source
|
e.source AS source
|
||||||
""",
|
""",
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
episodes = [get_episodic_node_from_record(record) for record in records]
|
episodes = [get_episodic_node_from_record(record) for record in records]
|
||||||
|
|
@ -237,7 +237,7 @@ class EntityNode(Node):
|
||||||
summary=self.summary,
|
summary=self.summary,
|
||||||
name_embedding=self.name_embedding,
|
name_embedding=self.name_embedding,
|
||||||
created_at=self.created_at,
|
created_at=self.created_at,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f'Saved Node to neo4j: {self.uuid}')
|
logger.debug(f'Saved Node to neo4j: {self.uuid}')
|
||||||
|
|
@ -258,7 +258,7 @@ class EntityNode(Node):
|
||||||
n.summary AS summary
|
n.summary AS summary
|
||||||
""",
|
""",
|
||||||
uuid=uuid,
|
uuid=uuid,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
nodes = [get_entity_node_from_record(record) for record in records]
|
nodes = [get_entity_node_from_record(record) for record in records]
|
||||||
|
|
@ -282,7 +282,7 @@ class EntityNode(Node):
|
||||||
n.summary AS summary
|
n.summary AS summary
|
||||||
""",
|
""",
|
||||||
uuids=uuids,
|
uuids=uuids,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
nodes = [get_entity_node_from_record(record) for record in records]
|
nodes = [get_entity_node_from_record(record) for record in records]
|
||||||
|
|
@ -303,7 +303,7 @@ class EntityNode(Node):
|
||||||
n.summary AS summary
|
n.summary AS summary
|
||||||
""",
|
""",
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
nodes = [get_entity_node_from_record(record) for record in records]
|
nodes = [get_entity_node_from_record(record) for record in records]
|
||||||
|
|
@ -324,7 +324,7 @@ class CommunityNode(Node):
|
||||||
summary=self.summary,
|
summary=self.summary,
|
||||||
name_embedding=self.name_embedding,
|
name_embedding=self.name_embedding,
|
||||||
created_at=self.created_at,
|
created_at=self.created_at,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f'Saved Node to neo4j: {self.uuid}')
|
logger.debug(f'Saved Node to neo4j: {self.uuid}')
|
||||||
|
|
@ -354,7 +354,7 @@ class CommunityNode(Node):
|
||||||
n.summary AS summary
|
n.summary AS summary
|
||||||
""",
|
""",
|
||||||
uuid=uuid,
|
uuid=uuid,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
nodes = [get_community_node_from_record(record) for record in records]
|
nodes = [get_community_node_from_record(record) for record in records]
|
||||||
|
|
@ -378,7 +378,7 @@ class CommunityNode(Node):
|
||||||
n.summary AS summary
|
n.summary AS summary
|
||||||
""",
|
""",
|
||||||
uuids=uuids,
|
uuids=uuids,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
communities = [get_community_node_from_record(record) for record in records]
|
communities = [get_community_node_from_record(record) for record in records]
|
||||||
|
|
@ -399,7 +399,7 @@ class CommunityNode(Node):
|
||||||
n.summary AS summary
|
n.summary AS summary
|
||||||
""",
|
""",
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
communities = [get_community_node_from_record(record) for record in records]
|
communities = [get_community_node_from_record(record) for record in records]
|
||||||
|
|
|
||||||
|
|
@ -79,20 +79,21 @@ async def get_mentioned_nodes(
|
||||||
driver: AsyncDriver, episodes: list[EpisodicNode]
|
driver: AsyncDriver, episodes: list[EpisodicNode]
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
episode_uuids = [episode.uuid for episode in episodes]
|
episode_uuids = [episode.uuid for episode in episodes]
|
||||||
records, _, _ = await driver.execute_query(
|
async with driver.session(database=DEFAULT_DATABASE) as session:
|
||||||
"""
|
result = await session.run(
|
||||||
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
|
"""
|
||||||
RETURN DISTINCT
|
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
|
||||||
n.uuid As uuid,
|
RETURN DISTINCT
|
||||||
n.group_id AS group_id,
|
n.uuid As uuid,
|
||||||
n.name AS name,
|
n.group_id AS group_id,
|
||||||
n.name_embedding AS name_embedding
|
n.name AS name,
|
||||||
n.created_at AS created_at,
|
n.name_embedding AS name_embedding,
|
||||||
n.summary AS summary
|
n.created_at AS created_at,
|
||||||
""",
|
n.summary AS summary
|
||||||
uuids=episode_uuids,
|
""",
|
||||||
_database=DEFAULT_DATABASE,
|
{'uuids': episode_uuids},
|
||||||
)
|
)
|
||||||
|
records = [record async for record in result]
|
||||||
|
|
||||||
nodes = [get_entity_node_from_record(record) for record in records]
|
nodes = [get_entity_node_from_record(record) for record in records]
|
||||||
|
|
||||||
|
|
@ -103,8 +104,9 @@ async def get_communities_by_nodes(
|
||||||
driver: AsyncDriver, nodes: list[EntityNode]
|
driver: AsyncDriver, nodes: list[EntityNode]
|
||||||
) -> list[CommunityNode]:
|
) -> list[CommunityNode]:
|
||||||
node_uuids = [node.uuid for node in nodes]
|
node_uuids = [node.uuid for node in nodes]
|
||||||
records, _, _ = await driver.execute_query(
|
async with driver.session(database=DEFAULT_DATABASE) as session:
|
||||||
"""
|
result = await session.run(
|
||||||
|
"""
|
||||||
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
|
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
|
||||||
RETURN DISTINCT
|
RETURN DISTINCT
|
||||||
c.uuid As uuid,
|
c.uuid As uuid,
|
||||||
|
|
@ -114,9 +116,9 @@ async def get_communities_by_nodes(
|
||||||
c.created_at AS created_at,
|
c.created_at AS created_at,
|
||||||
c.summary AS summary
|
c.summary AS summary
|
||||||
""",
|
""",
|
||||||
uuids=node_uuids,
|
{'uuids': node_uuids},
|
||||||
_database=DEFAULT_DATABASE,
|
)
|
||||||
)
|
records = [record async for record in result]
|
||||||
|
|
||||||
communities = [get_community_node_from_record(record) for record in records]
|
communities = [get_community_node_from_record(record) for record in records]
|
||||||
|
|
||||||
|
|
@ -156,15 +158,18 @@ async def edge_fulltext_search(
|
||||||
ORDER BY score DESC LIMIT $limit
|
ORDER BY score DESC LIMIT $limit
|
||||||
""")
|
""")
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
async with driver.session(database=DEFAULT_DATABASE) as session:
|
||||||
cypher_query,
|
result = await session.run(
|
||||||
query=fuzzy_query,
|
cypher_query,
|
||||||
source_uuid=source_node_uuid,
|
{
|
||||||
target_uuid=target_node_uuid,
|
'query': fuzzy_query,
|
||||||
group_ids=group_ids,
|
'source_uuid': source_node_uuid,
|
||||||
limit=limit,
|
'target_uuid': target_node_uuid,
|
||||||
_database=DEFAULT_DATABASE,
|
'group_ids': group_ids,
|
||||||
)
|
'limit': limit,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
records = [record async for record in result]
|
||||||
|
|
||||||
edges = [get_entity_edge_from_record(record) for record in records]
|
edges = [get_entity_edge_from_record(record) for record in records]
|
||||||
|
|
||||||
|
|
@ -206,16 +211,19 @@ async def edge_similarity_search(
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
""")
|
""")
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
async with driver.session(database=DEFAULT_DATABASE) as session:
|
||||||
query,
|
result = await session.run(
|
||||||
search_vector=search_vector,
|
query,
|
||||||
source_uuid=source_node_uuid,
|
{
|
||||||
target_uuid=target_node_uuid,
|
'search_vector': search_vector,
|
||||||
group_ids=group_ids,
|
'source_uuid': source_node_uuid,
|
||||||
limit=limit,
|
'target_uuid': target_node_uuid,
|
||||||
min_score=min_score,
|
'group_ids': group_ids,
|
||||||
_database=DEFAULT_DATABASE,
|
'limit': limit,
|
||||||
)
|
'min_score': min_score,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
records = [record async for record in result]
|
||||||
|
|
||||||
edges = [get_entity_edge_from_record(record) for record in records]
|
edges = [get_entity_edge_from_record(record) for record in records]
|
||||||
|
|
||||||
|
|
@ -233,8 +241,9 @@ async def node_fulltext_search(
|
||||||
if fuzzy_query == '':
|
if fuzzy_query == '':
|
||||||
return []
|
return []
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
async with driver.session(database=DEFAULT_DATABASE) as session:
|
||||||
"""
|
result = await session.run(
|
||||||
|
"""
|
||||||
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
|
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
|
||||||
YIELD node AS n, score
|
YIELD node AS n, score
|
||||||
RETURN
|
RETURN
|
||||||
|
|
@ -247,11 +256,13 @@ async def node_fulltext_search(
|
||||||
ORDER BY score DESC
|
ORDER BY score DESC
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
""",
|
""",
|
||||||
query=fuzzy_query,
|
{
|
||||||
group_ids=group_ids,
|
'query': fuzzy_query,
|
||||||
limit=limit,
|
'group_ids': group_ids,
|
||||||
_database=DEFAULT_DATABASE,
|
'limit': limit,
|
||||||
)
|
},
|
||||||
|
)
|
||||||
|
records = [record async for record in result]
|
||||||
nodes = [get_entity_node_from_record(record) for record in records]
|
nodes = [get_entity_node_from_record(record) for record in records]
|
||||||
|
|
||||||
return nodes
|
return nodes
|
||||||
|
|
@ -265,8 +276,9 @@ async def node_similarity_search(
|
||||||
min_score: float = DEFAULT_MIN_SCORE,
|
min_score: float = DEFAULT_MIN_SCORE,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
# vector similarity search over entity names
|
# vector similarity search over entity names
|
||||||
records, _, _ = await driver.execute_query(
|
async with driver.session(database=DEFAULT_DATABASE) as session:
|
||||||
"""
|
result = await session.run(
|
||||||
|
"""
|
||||||
CYPHER runtime = parallel parallelRuntimeSupport=all
|
CYPHER runtime = parallel parallelRuntimeSupport=all
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
|
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
|
||||||
|
|
@ -282,12 +294,14 @@ async def node_similarity_search(
|
||||||
ORDER BY score DESC
|
ORDER BY score DESC
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
""",
|
""",
|
||||||
search_vector=search_vector,
|
{
|
||||||
group_ids=group_ids,
|
'search_vector': search_vector,
|
||||||
limit=limit,
|
'group_ids': group_ids,
|
||||||
min_score=min_score,
|
'limit': limit,
|
||||||
_database=DEFAULT_DATABASE,
|
'min_score': min_score,
|
||||||
)
|
},
|
||||||
|
)
|
||||||
|
records = [record async for record in result]
|
||||||
nodes = [get_entity_node_from_record(record) for record in records]
|
nodes = [get_entity_node_from_record(record) for record in records]
|
||||||
|
|
||||||
return nodes
|
return nodes
|
||||||
|
|
@ -304,8 +318,9 @@ async def community_fulltext_search(
|
||||||
if fuzzy_query == '':
|
if fuzzy_query == '':
|
||||||
return []
|
return []
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
async with driver.session(database=DEFAULT_DATABASE) as session:
|
||||||
"""
|
result = await session.run(
|
||||||
|
"""
|
||||||
CALL db.index.fulltext.queryNodes("community_name", $query)
|
CALL db.index.fulltext.queryNodes("community_name", $query)
|
||||||
YIELD node AS comm, score
|
YIELD node AS comm, score
|
||||||
RETURN
|
RETURN
|
||||||
|
|
@ -318,11 +333,13 @@ async def community_fulltext_search(
|
||||||
ORDER BY score DESC
|
ORDER BY score DESC
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
""",
|
""",
|
||||||
query=fuzzy_query,
|
{
|
||||||
group_ids=group_ids,
|
'query': fuzzy_query,
|
||||||
limit=limit,
|
'group_ids': group_ids,
|
||||||
_database=DEFAULT_DATABASE,
|
'limit': limit,
|
||||||
)
|
},
|
||||||
|
)
|
||||||
|
records = [record async for record in result]
|
||||||
communities = [get_community_node_from_record(record) for record in records]
|
communities = [get_community_node_from_record(record) for record in records]
|
||||||
|
|
||||||
return communities
|
return communities
|
||||||
|
|
@ -336,8 +353,9 @@ async def community_similarity_search(
|
||||||
min_score=DEFAULT_MIN_SCORE,
|
min_score=DEFAULT_MIN_SCORE,
|
||||||
) -> list[CommunityNode]:
|
) -> list[CommunityNode]:
|
||||||
# vector similarity search over entity names
|
# vector similarity search over entity names
|
||||||
records, _, _ = await driver.execute_query(
|
async with driver.session(database=DEFAULT_DATABASE) as session:
|
||||||
"""
|
result = await session.run(
|
||||||
|
"""
|
||||||
CYPHER runtime = parallel parallelRuntimeSupport=all
|
CYPHER runtime = parallel parallelRuntimeSupport=all
|
||||||
MATCH (comm:Community)
|
MATCH (comm:Community)
|
||||||
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
|
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
|
||||||
|
|
@ -353,12 +371,14 @@ async def community_similarity_search(
|
||||||
ORDER BY score DESC
|
ORDER BY score DESC
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
""",
|
""",
|
||||||
search_vector=search_vector,
|
{
|
||||||
group_ids=group_ids,
|
'search_vector': search_vector,
|
||||||
limit=limit,
|
'group_ids': group_ids,
|
||||||
min_score=min_score,
|
'limit': limit,
|
||||||
_database=DEFAULT_DATABASE,
|
'min_score': min_score,
|
||||||
)
|
},
|
||||||
|
)
|
||||||
|
records = [record async for record in result]
|
||||||
communities = [get_community_node_from_record(record) for record in records]
|
communities = [get_community_node_from_record(record) for record in records]
|
||||||
|
|
||||||
return communities
|
return communities
|
||||||
|
|
@ -549,7 +569,7 @@ async def node_distance_reranker(
|
||||||
query,
|
query,
|
||||||
node_uuid=uuid,
|
node_uuid=uuid,
|
||||||
center_uuid=center_node_uuid,
|
center_uuid=center_node_uuid,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
for uuid in filtered_uuids
|
for uuid in filtered_uuids
|
||||||
]
|
]
|
||||||
|
|
@ -586,7 +606,7 @@ async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[s
|
||||||
driver.execute_query(
|
driver.execute_query(
|
||||||
query,
|
query,
|
||||||
node_uuid=uuid,
|
node_uuid=uuid,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
for uuid in sorted_uuids
|
for uuid in sorted_uuids
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -40,7 +40,7 @@ async def get_community_clusters(
|
||||||
RETURN
|
RETURN
|
||||||
collect(DISTINCT n.group_id) AS group_ids
|
collect(DISTINCT n.group_id) AS group_ids
|
||||||
""",
|
""",
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
group_ids = group_id_values[0]['group_ids']
|
group_ids = group_id_values[0]['group_ids']
|
||||||
|
|
@ -59,7 +59,7 @@ async def get_community_clusters(
|
||||||
""",
|
""",
|
||||||
uuid=node.uuid,
|
uuid=node.uuid,
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
projection[node.uuid] = [
|
projection[node.uuid] = [
|
||||||
|
|
@ -223,7 +223,7 @@ async def remove_communities(driver: AsyncDriver):
|
||||||
MATCH (c:Community)
|
MATCH (c:Community)
|
||||||
DETACH DELETE c
|
DETACH DELETE c
|
||||||
""",
|
""",
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -243,7 +243,7 @@ async def determine_entity_community(
|
||||||
c.summary AS summary
|
c.summary AS summary
|
||||||
""",
|
""",
|
||||||
entity_uuid=entity.uuid,
|
entity_uuid=entity.uuid,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(records) > 0:
|
if len(records) > 0:
|
||||||
|
|
@ -262,7 +262,7 @@ async def determine_entity_community(
|
||||||
c.summary AS summary
|
c.summary AS summary
|
||||||
""",
|
""",
|
||||||
entity_uuid=entity.uuid,
|
entity_uuid=entity.uuid,
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
communities: list[CommunityNode] = [
|
communities: list[CommunityNode] = [
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,7 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo
|
||||||
"""
|
"""
|
||||||
SHOW INDEXES YIELD name
|
SHOW INDEXES YIELD name
|
||||||
""",
|
""",
|
||||||
_database=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
index_names = [record['name'] for record in records]
|
index_names = [record['name'] for record in records]
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue