Msc benchmark update (#173)
* eval update * I sped it up * make format * search updates * updates * cleanup * make format * remove unused imports * poetry lock
This commit is contained in:
parent
ec2e51c5ec
commit
c8ff5be8ce
12 changed files with 878 additions and 264 deletions
|
|
@ -58,13 +58,21 @@ def setup_logging():
|
||||||
|
|
||||||
async def evaluate_qa(graphiti: Graphiti, group_id: str, query: str, answer: str):
|
async def evaluate_qa(graphiti: Graphiti, group_id: str, query: str, answer: str):
|
||||||
search_start = time()
|
search_start = time()
|
||||||
results = await graphiti._search(query, COMBINED_HYBRID_SEARCH_RRF, group_ids=[str(group_id)])
|
results = await graphiti._search(
|
||||||
|
query,
|
||||||
|
COMBINED_HYBRID_SEARCH_RRF,
|
||||||
|
group_ids=[str(group_id)],
|
||||||
|
)
|
||||||
search_end = time()
|
search_end = time()
|
||||||
search_duration = search_end - search_start
|
search_duration = search_end - search_start
|
||||||
|
|
||||||
facts = [edge.fact for edge in results.edges]
|
facts = [edge.fact for edge in results.edges]
|
||||||
entity_summaries = [node.name + ': ' + node.summary for node in results.nodes]
|
entity_summaries = [node.name + ': ' + node.summary for node in results.nodes]
|
||||||
context = {'facts': facts, 'entity_summaries': entity_summaries, 'query': 'Bob: ' + query}
|
context = {
|
||||||
|
'facts': facts,
|
||||||
|
'entity_summaries': entity_summaries,
|
||||||
|
'query': 'Bob: ' + query,
|
||||||
|
}
|
||||||
|
|
||||||
llm_response = await graphiti.llm_client.generate_response(
|
llm_response = await graphiti.llm_client.generate_response(
|
||||||
prompt_library.eval.qa_prompt(context)
|
prompt_library.eval.qa_prompt(context)
|
||||||
|
|
@ -96,7 +104,14 @@ async def main():
|
||||||
setup_logging()
|
setup_logging()
|
||||||
graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
|
graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
|
||||||
|
|
||||||
fields = ['Group id', 'Question', 'Answer', 'Response', 'Score', 'Search Duration (ms)']
|
fields = [
|
||||||
|
'Group id',
|
||||||
|
'Question',
|
||||||
|
'Answer',
|
||||||
|
'Response',
|
||||||
|
'Score',
|
||||||
|
'Search Duration (ms)',
|
||||||
|
]
|
||||||
with open('../data/msc_eval.csv', 'w', newline='') as file:
|
with open('../data/msc_eval.csv', 'w', newline='') as file:
|
||||||
writer = csv.DictWriter(file, fieldnames=fields)
|
writer = csv.DictWriter(file, fieldnames=fields)
|
||||||
writer.writeheader()
|
writer.writeheader()
|
||||||
|
|
|
||||||
|
|
@ -71,7 +71,7 @@ async def main():
|
||||||
graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
|
graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
|
||||||
msc_messages = parse_msc_messages()
|
msc_messages = parse_msc_messages()
|
||||||
i = 0
|
i = 0
|
||||||
while i <= 490:
|
while i < len(msc_messages):
|
||||||
msc_message_slice = msc_messages[i : i + 10]
|
msc_message_slice = msc_messages[i : i + 10]
|
||||||
group_ids = range(len(msc_messages))[i : i + 10]
|
group_ids = range(len(msc_messages))[i : i + 10]
|
||||||
|
|
||||||
|
|
@ -84,8 +84,5 @@ async def main():
|
||||||
|
|
||||||
i += 10
|
i += 10
|
||||||
|
|
||||||
# build communities
|
|
||||||
# await client.build_communities()
|
|
||||||
|
|
||||||
|
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|
|
||||||
|
|
@ -188,9 +188,9 @@ class EntityEdge(Edge):
|
||||||
MATCH (source:Entity {uuid: $source_uuid})
|
MATCH (source:Entity {uuid: $source_uuid})
|
||||||
MATCH (target:Entity {uuid: $target_uuid})
|
MATCH (target:Entity {uuid: $target_uuid})
|
||||||
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
|
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
|
||||||
SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, fact_embedding: $fact_embedding,
|
SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, episodes: $episodes,
|
||||||
episodes: $episodes, created_at: $created_at, expired_at: $expired_at,
|
created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at}
|
||||||
valid_at: $valid_at, invalid_at: $invalid_at}
|
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
|
||||||
RETURN r.uuid AS uuid""",
|
RETURN r.uuid AS uuid""",
|
||||||
source_uuid=self.source_node_uuid,
|
source_uuid=self.source_node_uuid,
|
||||||
target_uuid=self.target_node_uuid,
|
target_uuid=self.target_node_uuid,
|
||||||
|
|
|
||||||
|
|
@ -700,18 +700,17 @@ class Graphiti:
|
||||||
).nodes
|
).nodes
|
||||||
return nodes
|
return nodes
|
||||||
|
|
||||||
|
async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
|
||||||
|
episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
|
||||||
|
|
||||||
async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
|
edges_list = await asyncio.gather(
|
||||||
episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
|
*[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
|
||||||
|
)
|
||||||
|
|
||||||
edges_list = await asyncio.gather(
|
edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]
|
||||||
*[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
|
|
||||||
)
|
|
||||||
|
|
||||||
edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]
|
nodes = await get_mentioned_nodes(self.driver, episodes)
|
||||||
|
|
||||||
nodes = await get_mentioned_nodes(self.driver, episodes)
|
communities = await get_communities_by_nodes(self.driver, nodes)
|
||||||
|
|
||||||
communities = await get_communities_by_nodes(self.driver, nodes)
|
return SearchResults(edges=edges, nodes=nodes, communities=communities)
|
||||||
|
|
||||||
return SearchResults(edges=edges, nodes=nodes, communities=communities)
|
|
||||||
|
|
|
||||||
|
|
@ -225,7 +225,8 @@ class EntityNode(Node):
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MERGE (n:Entity {uuid: $uuid})
|
MERGE (n:Entity {uuid: $uuid})
|
||||||
SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
|
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
|
||||||
|
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
|
||||||
RETURN n.uuid AS uuid""",
|
RETURN n.uuid AS uuid""",
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
name=self.name,
|
name=self.name,
|
||||||
|
|
@ -308,7 +309,8 @@ class CommunityNode(Node):
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MERGE (n:Community {uuid: $uuid})
|
MERGE (n:Community {uuid: $uuid})
|
||||||
SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
|
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
|
||||||
|
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
|
||||||
RETURN n.uuid AS uuid""",
|
RETURN n.uuid AS uuid""",
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
name=self.name,
|
name=self.name,
|
||||||
|
|
|
||||||
|
|
@ -23,11 +23,33 @@ from .models import Message, PromptFunction, PromptVersion
|
||||||
class Prompt(Protocol):
|
class Prompt(Protocol):
|
||||||
qa_prompt: PromptVersion
|
qa_prompt: PromptVersion
|
||||||
eval_prompt: PromptVersion
|
eval_prompt: PromptVersion
|
||||||
|
query_expansion: PromptVersion
|
||||||
|
|
||||||
|
|
||||||
class Versions(TypedDict):
|
class Versions(TypedDict):
|
||||||
qa_prompt: PromptFunction
|
qa_prompt: PromptFunction
|
||||||
eval_prompt: PromptFunction
|
eval_prompt: PromptFunction
|
||||||
|
query_expansion: PromptFunction
|
||||||
|
|
||||||
|
|
||||||
|
def query_expansion(context: dict[str, Any]) -> list[Message]:
|
||||||
|
sys_prompt = """You are an expert at rephrasing questions into queries used in a database retrieval system"""
|
||||||
|
|
||||||
|
user_prompt = f"""
|
||||||
|
Bob is asking Alice a question, are you able to rephrase the question into a simpler one about Alice in the third person
|
||||||
|
that maintains the relevant context?
|
||||||
|
<QUESTION>
|
||||||
|
{json.dumps(context['query'])}
|
||||||
|
</QUESTION>
|
||||||
|
respond with a JSON object in the following format:
|
||||||
|
{{
|
||||||
|
"query": "query optimized for database search"
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
Message(role='system', content=sys_prompt),
|
||||||
|
Message(role='user', content=user_prompt),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def qa_prompt(context: dict[str, Any]) -> list[Message]:
|
def qa_prompt(context: dict[str, Any]) -> list[Message]:
|
||||||
|
|
@ -38,7 +60,7 @@ def qa_prompt(context: dict[str, Any]) -> list[Message]:
|
||||||
You are given the following entity summaries and facts to help you determine the answer to your question.
|
You are given the following entity summaries and facts to help you determine the answer to your question.
|
||||||
<ENTITY_SUMMARIES>
|
<ENTITY_SUMMARIES>
|
||||||
{json.dumps(context['entity_summaries'])}
|
{json.dumps(context['entity_summaries'])}
|
||||||
</ENTITY_SUMMARIES
|
</ENTITY_SUMMARIES>
|
||||||
<FACTS>
|
<FACTS>
|
||||||
{json.dumps(context['facts'])}
|
{json.dumps(context['facts'])}
|
||||||
</FACTS>
|
</FACTS>
|
||||||
|
|
@ -87,4 +109,8 @@ def eval_prompt(context: dict[str, Any]) -> list[Message]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
versions: Versions = {'qa_prompt': qa_prompt, 'eval_prompt': eval_prompt}
|
versions: Versions = {
|
||||||
|
'qa_prompt': qa_prompt,
|
||||||
|
'eval_prompt': eval_prompt,
|
||||||
|
'query_expansion': query_expansion,
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -113,8 +113,9 @@ def v2(context: dict[str, Any]) -> list[Message]:
|
||||||
2. Each edge should represent a clear relationship between two DISTINCT nodes.
|
2. Each edge should represent a clear relationship between two DISTINCT nodes.
|
||||||
3. The relation_type should be a concise, all-caps description of the relationship (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
|
3. The relation_type should be a concise, all-caps description of the relationship (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
|
||||||
4. Provide a more detailed fact describing the relationship.
|
4. Provide a more detailed fact describing the relationship.
|
||||||
5. Consider temporal aspects of relationships when relevant.
|
5. The fact should include any specific relevant information, including numeric information
|
||||||
6. Avoid using the same node as the source and target of a relationship
|
6. Consider temporal aspects of relationships when relevant.
|
||||||
|
7. Avoid using the same node as the source and target of a relationship
|
||||||
|
|
||||||
Respond with a JSON object in the following format:
|
Respond with a JSON object in the following format:
|
||||||
{{
|
{{
|
||||||
|
|
|
||||||
|
|
@ -29,13 +29,10 @@ from graphiti_core.search.search_config import (
|
||||||
DEFAULT_SEARCH_LIMIT,
|
DEFAULT_SEARCH_LIMIT,
|
||||||
CommunityReranker,
|
CommunityReranker,
|
||||||
CommunitySearchConfig,
|
CommunitySearchConfig,
|
||||||
CommunitySearchMethod,
|
|
||||||
EdgeReranker,
|
EdgeReranker,
|
||||||
EdgeSearchConfig,
|
EdgeSearchConfig,
|
||||||
EdgeSearchMethod,
|
|
||||||
NodeReranker,
|
NodeReranker,
|
||||||
NodeSearchConfig,
|
NodeSearchConfig,
|
||||||
NodeSearchMethod,
|
|
||||||
SearchConfig,
|
SearchConfig,
|
||||||
SearchResults,
|
SearchResults,
|
||||||
)
|
)
|
||||||
|
|
@ -120,22 +117,16 @@ async def edge_search(
|
||||||
if config is None:
|
if config is None:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
search_results: list[list[EntityEdge]] = []
|
search_results: list[list[EntityEdge]] = list(
|
||||||
|
await asyncio.gather(
|
||||||
if EdgeSearchMethod.bm25 in config.search_methods:
|
*[
|
||||||
text_search = await edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit)
|
edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit),
|
||||||
search_results.append(text_search)
|
edge_similarity_search(
|
||||||
|
driver, await embedder.create(input=[query]), None, None, group_ids, 2 * limit
|
||||||
if EdgeSearchMethod.cosine_similarity in config.search_methods:
|
),
|
||||||
search_vector = await embedder.create(input=[query])
|
]
|
||||||
|
|
||||||
similarity_search = await edge_similarity_search(
|
|
||||||
driver, search_vector, None, None, group_ids, 2 * limit
|
|
||||||
)
|
)
|
||||||
search_results.append(similarity_search)
|
)
|
||||||
|
|
||||||
if len(search_results) > 1 and config.reranker is None:
|
|
||||||
raise SearchRerankerError('Multiple edge searches enabled without a reranker')
|
|
||||||
|
|
||||||
edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
|
edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
|
||||||
|
|
||||||
|
|
@ -184,22 +175,16 @@ async def node_search(
|
||||||
if config is None:
|
if config is None:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
search_results: list[list[EntityNode]] = []
|
search_results: list[list[EntityNode]] = list(
|
||||||
|
await asyncio.gather(
|
||||||
if NodeSearchMethod.bm25 in config.search_methods:
|
*[
|
||||||
text_search = await node_fulltext_search(driver, query, group_ids, 2 * limit)
|
node_fulltext_search(driver, query, group_ids, 2 * limit),
|
||||||
search_results.append(text_search)
|
node_similarity_search(
|
||||||
|
driver, await embedder.create(input=[query]), group_ids, 2 * limit
|
||||||
if NodeSearchMethod.cosine_similarity in config.search_methods:
|
),
|
||||||
search_vector = await embedder.create(input=[query])
|
]
|
||||||
|
|
||||||
similarity_search = await node_similarity_search(
|
|
||||||
driver, search_vector, group_ids, 2 * limit
|
|
||||||
)
|
)
|
||||||
search_results.append(similarity_search)
|
)
|
||||||
|
|
||||||
if len(search_results) > 1 and config.reranker is None:
|
|
||||||
raise SearchRerankerError('Multiple node searches enabled without a reranker')
|
|
||||||
|
|
||||||
search_result_uuids = [[node.uuid for node in result] for result in search_results]
|
search_result_uuids = [[node.uuid for node in result] for result in search_results]
|
||||||
node_uuid_map = {node.uuid: node for result in search_results for node in result}
|
node_uuid_map = {node.uuid: node for result in search_results for node in result}
|
||||||
|
|
@ -232,22 +217,16 @@ async def community_search(
|
||||||
if config is None:
|
if config is None:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
search_results: list[list[CommunityNode]] = []
|
search_results: list[list[CommunityNode]] = list(
|
||||||
|
await asyncio.gather(
|
||||||
if CommunitySearchMethod.bm25 in config.search_methods:
|
*[
|
||||||
text_search = await community_fulltext_search(driver, query, group_ids, 2 * limit)
|
community_fulltext_search(driver, query, group_ids, 2 * limit),
|
||||||
search_results.append(text_search)
|
community_similarity_search(
|
||||||
|
driver, await embedder.create(input=[query]), group_ids, 2 * limit
|
||||||
if CommunitySearchMethod.cosine_similarity in config.search_methods:
|
),
|
||||||
search_vector = await embedder.create(input=[query])
|
]
|
||||||
|
|
||||||
similarity_search = await community_similarity_search(
|
|
||||||
driver, search_vector, group_ids, 2 * limit
|
|
||||||
)
|
)
|
||||||
search_results.append(similarity_search)
|
)
|
||||||
|
|
||||||
if len(search_results) > 1 and config.reranker is None:
|
|
||||||
raise SearchRerankerError('Multiple node searches enabled without a reranker')
|
|
||||||
|
|
||||||
search_result_uuids = [[community.uuid for community in result] for result in search_results]
|
search_result_uuids = [[community.uuid for community in result] for result in search_results]
|
||||||
community_uuid_map = {
|
community_uuid_map = {
|
||||||
|
|
|
||||||
|
|
@ -57,17 +57,17 @@ class CommunityReranker(Enum):
|
||||||
|
|
||||||
class EdgeSearchConfig(BaseModel):
|
class EdgeSearchConfig(BaseModel):
|
||||||
search_methods: list[EdgeSearchMethod]
|
search_methods: list[EdgeSearchMethod]
|
||||||
reranker: EdgeReranker | None
|
reranker: EdgeReranker = Field(default=EdgeReranker.rrf)
|
||||||
|
|
||||||
|
|
||||||
class NodeSearchConfig(BaseModel):
|
class NodeSearchConfig(BaseModel):
|
||||||
search_methods: list[NodeSearchMethod]
|
search_methods: list[NodeSearchMethod]
|
||||||
reranker: NodeReranker | None
|
reranker: NodeReranker = Field(default=NodeReranker.rrf)
|
||||||
|
|
||||||
|
|
||||||
class CommunitySearchConfig(BaseModel):
|
class CommunitySearchConfig(BaseModel):
|
||||||
search_methods: list[CommunitySearchMethod]
|
search_methods: list[CommunitySearchMethod]
|
||||||
reranker: CommunityReranker | None
|
reranker: CommunityReranker = Field(default=CommunityReranker.rrf)
|
||||||
|
|
||||||
|
|
||||||
class SearchConfig(BaseModel):
|
class SearchConfig(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -52,6 +52,21 @@ def fulltext_query(query: str, group_ids: list[str] | None = None):
|
||||||
return full_query
|
return full_query
|
||||||
|
|
||||||
|
|
||||||
|
async def get_episodes_by_mentions(
|
||||||
|
driver: AsyncDriver,
|
||||||
|
nodes: list[EntityNode],
|
||||||
|
edges: list[EntityEdge],
|
||||||
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||||
|
) -> list[EpisodicNode]:
|
||||||
|
episode_uuids: list[str] = []
|
||||||
|
for edge in edges:
|
||||||
|
episode_uuids.extend(edge.episodes)
|
||||||
|
|
||||||
|
episodes = await EpisodicNode.get_by_uuids(driver, episode_uuids[:limit])
|
||||||
|
|
||||||
|
return episodes
|
||||||
|
|
||||||
|
|
||||||
async def get_mentioned_nodes(
|
async def get_mentioned_nodes(
|
||||||
driver: AsyncDriver, episodes: list[EpisodicNode]
|
driver: AsyncDriver, episodes: list[EpisodicNode]
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
|
|
@ -113,9 +128,6 @@ async def edge_fulltext_search(
|
||||||
CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query)
|
CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query)
|
||||||
YIELD relationship AS rel, score
|
YIELD relationship AS rel, score
|
||||||
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
|
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
|
||||||
WHERE ($source_uuid IS NULL OR n.uuid = $source_uuid)
|
|
||||||
AND ($target_uuid IS NULL OR m.uuid = $target_uuid)
|
|
||||||
AND ($group_ids IS NULL OR n.group_id IN $group_ids)
|
|
||||||
RETURN
|
RETURN
|
||||||
r.uuid AS uuid,
|
r.uuid AS uuid,
|
||||||
r.group_id AS group_id,
|
r.group_id AS group_id,
|
||||||
|
|
@ -156,12 +168,14 @@ async def edge_similarity_search(
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
# vector similarity search over embedded facts
|
# vector similarity search over embedded facts
|
||||||
query = Query("""
|
query = Query("""
|
||||||
|
CYPHER runtime = parallel parallelRuntimeSupport=all
|
||||||
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
|
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
|
||||||
WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
|
WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
|
||||||
AND ($source_uuid IS NULL OR n.uuid = $source_uuid)
|
AND ($source_uuid IS NULL OR n.uuid = $source_uuid)
|
||||||
AND ($target_uuid IS NULL OR m.uuid = $target_uuid)
|
AND ($target_uuid IS NULL OR m.uuid = $target_uuid)
|
||||||
|
WITH n, r, m, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
|
||||||
|
WHERE score > 0.6
|
||||||
RETURN
|
RETURN
|
||||||
vector.similarity.cosine(r.fact_embedding, $search_vector) AS score,
|
|
||||||
r.uuid AS uuid,
|
r.uuid AS uuid,
|
||||||
r.group_id AS group_id,
|
r.group_id AS group_id,
|
||||||
n.uuid AS source_node_uuid,
|
n.uuid AS source_node_uuid,
|
||||||
|
|
@ -205,7 +219,6 @@ async def node_fulltext_search(
|
||||||
"""
|
"""
|
||||||
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
|
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
|
||||||
YIELD node AS n, score
|
YIELD node AS n, score
|
||||||
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
|
|
||||||
RETURN
|
RETURN
|
||||||
n.uuid AS uuid,
|
n.uuid AS uuid,
|
||||||
n.group_id AS group_id,
|
n.group_id AS group_id,
|
||||||
|
|
@ -234,10 +247,12 @@ async def node_similarity_search(
|
||||||
# vector similarity search over entity names
|
# vector similarity search over entity names
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
|
CYPHER runtime = parallel parallelRuntimeSupport=all
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
|
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
|
||||||
|
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
|
||||||
|
WHERE score > 0.6
|
||||||
RETURN
|
RETURN
|
||||||
vector.similarity.cosine(n.name_embedding, $search_vector) AS score,
|
|
||||||
n.uuid As uuid,
|
n.uuid As uuid,
|
||||||
n.group_id AS group_id,
|
n.group_id AS group_id,
|
||||||
n.name AS name,
|
n.name AS name,
|
||||||
|
|
@ -269,8 +284,6 @@ async def community_fulltext_search(
|
||||||
"""
|
"""
|
||||||
CALL db.index.fulltext.queryNodes("community_name", $query)
|
CALL db.index.fulltext.queryNodes("community_name", $query)
|
||||||
YIELD node AS comm, score
|
YIELD node AS comm, score
|
||||||
MATCH (comm:Community)
|
|
||||||
WHERE $group_ids IS NULL OR comm.group_id in $group_ids
|
|
||||||
RETURN
|
RETURN
|
||||||
comm.uuid AS uuid,
|
comm.uuid AS uuid,
|
||||||
comm.group_id AS group_id,
|
comm.group_id AS group_id,
|
||||||
|
|
@ -299,10 +312,12 @@ async def community_similarity_search(
|
||||||
# vector similarity search over entity names
|
# vector similarity search over entity names
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
|
CYPHER runtime = parallel parallelRuntimeSupport=all
|
||||||
MATCH (comm:Community)
|
MATCH (comm:Community)
|
||||||
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
|
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
|
||||||
|
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
|
||||||
|
WHERE score > 0.6
|
||||||
RETURN
|
RETURN
|
||||||
vector.similarity.cosine(comm.name_embedding, $search_vector) AS score,
|
|
||||||
comm.uuid As uuid,
|
comm.uuid As uuid,
|
||||||
comm.group_id AS group_id,
|
comm.group_id AS group_id,
|
||||||
comm.name AS name,
|
comm.name AS name,
|
||||||
|
|
|
||||||
935
poetry.lock
generated
935
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -20,6 +20,7 @@ diskcache = "^5.6.3"
|
||||||
openai = "^1.50.2"
|
openai = "^1.50.2"
|
||||||
tenacity = "<9.0.0"
|
tenacity = "<9.0.0"
|
||||||
numpy = ">=1.0.0"
|
numpy = ">=1.0.0"
|
||||||
|
voyageai = "^0.2.3"
|
||||||
|
|
||||||
[tool.poetry.dev-dependencies]
|
[tool.poetry.dev-dependencies]
|
||||||
pytest = "^8.3.3"
|
pytest = "^8.3.3"
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue