Msc benchmark update (#173)
* eval update * I sped it up * make format * search updates * updates * cleanup * make format * remove unused imports * poetry lock
This commit is contained in:
parent
ec2e51c5ec
commit
c8ff5be8ce
12 changed files with 878 additions and 264 deletions
|
|
@ -58,13 +58,21 @@ def setup_logging():
|
|||
|
||||
async def evaluate_qa(graphiti: Graphiti, group_id: str, query: str, answer: str):
|
||||
search_start = time()
|
||||
results = await graphiti._search(query, COMBINED_HYBRID_SEARCH_RRF, group_ids=[str(group_id)])
|
||||
results = await graphiti._search(
|
||||
query,
|
||||
COMBINED_HYBRID_SEARCH_RRF,
|
||||
group_ids=[str(group_id)],
|
||||
)
|
||||
search_end = time()
|
||||
search_duration = search_end - search_start
|
||||
|
||||
facts = [edge.fact for edge in results.edges]
|
||||
entity_summaries = [node.name + ': ' + node.summary for node in results.nodes]
|
||||
context = {'facts': facts, 'entity_summaries': entity_summaries, 'query': 'Bob: ' + query}
|
||||
context = {
|
||||
'facts': facts,
|
||||
'entity_summaries': entity_summaries,
|
||||
'query': 'Bob: ' + query,
|
||||
}
|
||||
|
||||
llm_response = await graphiti.llm_client.generate_response(
|
||||
prompt_library.eval.qa_prompt(context)
|
||||
|
|
@ -96,7 +104,14 @@ async def main():
|
|||
setup_logging()
|
||||
graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
|
||||
|
||||
fields = ['Group id', 'Question', 'Answer', 'Response', 'Score', 'Search Duration (ms)']
|
||||
fields = [
|
||||
'Group id',
|
||||
'Question',
|
||||
'Answer',
|
||||
'Response',
|
||||
'Score',
|
||||
'Search Duration (ms)',
|
||||
]
|
||||
with open('../data/msc_eval.csv', 'w', newline='') as file:
|
||||
writer = csv.DictWriter(file, fieldnames=fields)
|
||||
writer.writeheader()
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ async def main():
|
|||
graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
|
||||
msc_messages = parse_msc_messages()
|
||||
i = 0
|
||||
while i <= 490:
|
||||
while i < len(msc_messages):
|
||||
msc_message_slice = msc_messages[i : i + 10]
|
||||
group_ids = range(len(msc_messages))[i : i + 10]
|
||||
|
||||
|
|
@ -84,8 +84,5 @@ async def main():
|
|||
|
||||
i += 10
|
||||
|
||||
# build communities
|
||||
# await client.build_communities()
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
|
|
|||
|
|
@ -188,9 +188,9 @@ class EntityEdge(Edge):
|
|||
MATCH (source:Entity {uuid: $source_uuid})
|
||||
MATCH (target:Entity {uuid: $target_uuid})
|
||||
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
|
||||
SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, fact_embedding: $fact_embedding,
|
||||
episodes: $episodes, created_at: $created_at, expired_at: $expired_at,
|
||||
valid_at: $valid_at, invalid_at: $invalid_at}
|
||||
SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, episodes: $episodes,
|
||||
created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at}
|
||||
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
|
||||
RETURN r.uuid AS uuid""",
|
||||
source_uuid=self.source_node_uuid,
|
||||
target_uuid=self.target_node_uuid,
|
||||
|
|
|
|||
|
|
@ -700,18 +700,17 @@ class Graphiti:
|
|||
).nodes
|
||||
return nodes
|
||||
|
||||
async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
|
||||
episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
|
||||
|
||||
async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
|
||||
episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
|
||||
edges_list = await asyncio.gather(
|
||||
*[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
|
||||
)
|
||||
|
||||
edges_list = await asyncio.gather(
|
||||
*[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
|
||||
)
|
||||
edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]
|
||||
|
||||
edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]
|
||||
nodes = await get_mentioned_nodes(self.driver, episodes)
|
||||
|
||||
nodes = await get_mentioned_nodes(self.driver, episodes)
|
||||
communities = await get_communities_by_nodes(self.driver, nodes)
|
||||
|
||||
communities = await get_communities_by_nodes(self.driver, nodes)
|
||||
|
||||
return SearchResults(edges=edges, nodes=nodes, communities=communities)
|
||||
return SearchResults(edges=edges, nodes=nodes, communities=communities)
|
||||
|
|
|
|||
|
|
@ -225,7 +225,8 @@ class EntityNode(Node):
|
|||
result = await driver.execute_query(
|
||||
"""
|
||||
MERGE (n:Entity {uuid: $uuid})
|
||||
SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
|
||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
|
||||
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
|
||||
RETURN n.uuid AS uuid""",
|
||||
uuid=self.uuid,
|
||||
name=self.name,
|
||||
|
|
@ -308,7 +309,8 @@ class CommunityNode(Node):
|
|||
result = await driver.execute_query(
|
||||
"""
|
||||
MERGE (n:Community {uuid: $uuid})
|
||||
SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
|
||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
|
||||
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
|
||||
RETURN n.uuid AS uuid""",
|
||||
uuid=self.uuid,
|
||||
name=self.name,
|
||||
|
|
|
|||
|
|
@ -23,11 +23,33 @@ from .models import Message, PromptFunction, PromptVersion
|
|||
class Prompt(Protocol):
|
||||
qa_prompt: PromptVersion
|
||||
eval_prompt: PromptVersion
|
||||
query_expansion: PromptVersion
|
||||
|
||||
|
||||
class Versions(TypedDict):
|
||||
qa_prompt: PromptFunction
|
||||
eval_prompt: PromptFunction
|
||||
query_expansion: PromptFunction
|
||||
|
||||
|
||||
def query_expansion(context: dict[str, Any]) -> list[Message]:
|
||||
sys_prompt = """You are an expert at rephrasing questions into queries used in a database retrieval system"""
|
||||
|
||||
user_prompt = f"""
|
||||
Bob is asking Alice a question, are you able to rephrase the question into a simpler one about Alice in the third person
|
||||
that maintains the relevant context?
|
||||
<QUESTION>
|
||||
{json.dumps(context['query'])}
|
||||
</QUESTION>
|
||||
respond with a JSON object in the following format:
|
||||
{{
|
||||
"query": "query optimized for database search"
|
||||
}}
|
||||
"""
|
||||
return [
|
||||
Message(role='system', content=sys_prompt),
|
||||
Message(role='user', content=user_prompt),
|
||||
]
|
||||
|
||||
|
||||
def qa_prompt(context: dict[str, Any]) -> list[Message]:
|
||||
|
|
@ -38,7 +60,7 @@ def qa_prompt(context: dict[str, Any]) -> list[Message]:
|
|||
You are given the following entity summaries and facts to help you determine the answer to your question.
|
||||
<ENTITY_SUMMARIES>
|
||||
{json.dumps(context['entity_summaries'])}
|
||||
</ENTITY_SUMMARIES
|
||||
</ENTITY_SUMMARIES>
|
||||
<FACTS>
|
||||
{json.dumps(context['facts'])}
|
||||
</FACTS>
|
||||
|
|
@ -87,4 +109,8 @@ def eval_prompt(context: dict[str, Any]) -> list[Message]:
|
|||
]
|
||||
|
||||
|
||||
versions: Versions = {'qa_prompt': qa_prompt, 'eval_prompt': eval_prompt}
|
||||
versions: Versions = {
|
||||
'qa_prompt': qa_prompt,
|
||||
'eval_prompt': eval_prompt,
|
||||
'query_expansion': query_expansion,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -113,8 +113,9 @@ def v2(context: dict[str, Any]) -> list[Message]:
|
|||
2. Each edge should represent a clear relationship between two DISTINCT nodes.
|
||||
3. The relation_type should be a concise, all-caps description of the relationship (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
|
||||
4. Provide a more detailed fact describing the relationship.
|
||||
5. Consider temporal aspects of relationships when relevant.
|
||||
6. Avoid using the same node as the source and target of a relationship
|
||||
5. The fact should include any specific relevant information, including numeric information
|
||||
6. Consider temporal aspects of relationships when relevant.
|
||||
7. Avoid using the same node as the source and target of a relationship
|
||||
|
||||
Respond with a JSON object in the following format:
|
||||
{{
|
||||
|
|
|
|||
|
|
@ -29,13 +29,10 @@ from graphiti_core.search.search_config import (
|
|||
DEFAULT_SEARCH_LIMIT,
|
||||
CommunityReranker,
|
||||
CommunitySearchConfig,
|
||||
CommunitySearchMethod,
|
||||
EdgeReranker,
|
||||
EdgeSearchConfig,
|
||||
EdgeSearchMethod,
|
||||
NodeReranker,
|
||||
NodeSearchConfig,
|
||||
NodeSearchMethod,
|
||||
SearchConfig,
|
||||
SearchResults,
|
||||
)
|
||||
|
|
@ -120,22 +117,16 @@ async def edge_search(
|
|||
if config is None:
|
||||
return []
|
||||
|
||||
search_results: list[list[EntityEdge]] = []
|
||||
|
||||
if EdgeSearchMethod.bm25 in config.search_methods:
|
||||
text_search = await edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit)
|
||||
search_results.append(text_search)
|
||||
|
||||
if EdgeSearchMethod.cosine_similarity in config.search_methods:
|
||||
search_vector = await embedder.create(input=[query])
|
||||
|
||||
similarity_search = await edge_similarity_search(
|
||||
driver, search_vector, None, None, group_ids, 2 * limit
|
||||
search_results: list[list[EntityEdge]] = list(
|
||||
await asyncio.gather(
|
||||
*[
|
||||
edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit),
|
||||
edge_similarity_search(
|
||||
driver, await embedder.create(input=[query]), None, None, group_ids, 2 * limit
|
||||
),
|
||||
]
|
||||
)
|
||||
search_results.append(similarity_search)
|
||||
|
||||
if len(search_results) > 1 and config.reranker is None:
|
||||
raise SearchRerankerError('Multiple edge searches enabled without a reranker')
|
||||
)
|
||||
|
||||
edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
|
||||
|
||||
|
|
@ -184,22 +175,16 @@ async def node_search(
|
|||
if config is None:
|
||||
return []
|
||||
|
||||
search_results: list[list[EntityNode]] = []
|
||||
|
||||
if NodeSearchMethod.bm25 in config.search_methods:
|
||||
text_search = await node_fulltext_search(driver, query, group_ids, 2 * limit)
|
||||
search_results.append(text_search)
|
||||
|
||||
if NodeSearchMethod.cosine_similarity in config.search_methods:
|
||||
search_vector = await embedder.create(input=[query])
|
||||
|
||||
similarity_search = await node_similarity_search(
|
||||
driver, search_vector, group_ids, 2 * limit
|
||||
search_results: list[list[EntityNode]] = list(
|
||||
await asyncio.gather(
|
||||
*[
|
||||
node_fulltext_search(driver, query, group_ids, 2 * limit),
|
||||
node_similarity_search(
|
||||
driver, await embedder.create(input=[query]), group_ids, 2 * limit
|
||||
),
|
||||
]
|
||||
)
|
||||
search_results.append(similarity_search)
|
||||
|
||||
if len(search_results) > 1 and config.reranker is None:
|
||||
raise SearchRerankerError('Multiple node searches enabled without a reranker')
|
||||
)
|
||||
|
||||
search_result_uuids = [[node.uuid for node in result] for result in search_results]
|
||||
node_uuid_map = {node.uuid: node for result in search_results for node in result}
|
||||
|
|
@ -232,22 +217,16 @@ async def community_search(
|
|||
if config is None:
|
||||
return []
|
||||
|
||||
search_results: list[list[CommunityNode]] = []
|
||||
|
||||
if CommunitySearchMethod.bm25 in config.search_methods:
|
||||
text_search = await community_fulltext_search(driver, query, group_ids, 2 * limit)
|
||||
search_results.append(text_search)
|
||||
|
||||
if CommunitySearchMethod.cosine_similarity in config.search_methods:
|
||||
search_vector = await embedder.create(input=[query])
|
||||
|
||||
similarity_search = await community_similarity_search(
|
||||
driver, search_vector, group_ids, 2 * limit
|
||||
search_results: list[list[CommunityNode]] = list(
|
||||
await asyncio.gather(
|
||||
*[
|
||||
community_fulltext_search(driver, query, group_ids, 2 * limit),
|
||||
community_similarity_search(
|
||||
driver, await embedder.create(input=[query]), group_ids, 2 * limit
|
||||
),
|
||||
]
|
||||
)
|
||||
search_results.append(similarity_search)
|
||||
|
||||
if len(search_results) > 1 and config.reranker is None:
|
||||
raise SearchRerankerError('Multiple node searches enabled without a reranker')
|
||||
)
|
||||
|
||||
search_result_uuids = [[community.uuid for community in result] for result in search_results]
|
||||
community_uuid_map = {
|
||||
|
|
|
|||
|
|
@ -57,17 +57,17 @@ class CommunityReranker(Enum):
|
|||
|
||||
class EdgeSearchConfig(BaseModel):
|
||||
search_methods: list[EdgeSearchMethod]
|
||||
reranker: EdgeReranker | None
|
||||
reranker: EdgeReranker = Field(default=EdgeReranker.rrf)
|
||||
|
||||
|
||||
class NodeSearchConfig(BaseModel):
|
||||
search_methods: list[NodeSearchMethod]
|
||||
reranker: NodeReranker | None
|
||||
reranker: NodeReranker = Field(default=NodeReranker.rrf)
|
||||
|
||||
|
||||
class CommunitySearchConfig(BaseModel):
|
||||
search_methods: list[CommunitySearchMethod]
|
||||
reranker: CommunityReranker | None
|
||||
reranker: CommunityReranker = Field(default=CommunityReranker.rrf)
|
||||
|
||||
|
||||
class SearchConfig(BaseModel):
|
||||
|
|
|
|||
|
|
@ -52,6 +52,21 @@ def fulltext_query(query: str, group_ids: list[str] | None = None):
|
|||
return full_query
|
||||
|
||||
|
||||
async def get_episodes_by_mentions(
|
||||
driver: AsyncDriver,
|
||||
nodes: list[EntityNode],
|
||||
edges: list[EntityEdge],
|
||||
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[EpisodicNode]:
|
||||
episode_uuids: list[str] = []
|
||||
for edge in edges:
|
||||
episode_uuids.extend(edge.episodes)
|
||||
|
||||
episodes = await EpisodicNode.get_by_uuids(driver, episode_uuids[:limit])
|
||||
|
||||
return episodes
|
||||
|
||||
|
||||
async def get_mentioned_nodes(
|
||||
driver: AsyncDriver, episodes: list[EpisodicNode]
|
||||
) -> list[EntityNode]:
|
||||
|
|
@ -113,9 +128,6 @@ async def edge_fulltext_search(
|
|||
CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query)
|
||||
YIELD relationship AS rel, score
|
||||
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
|
||||
WHERE ($source_uuid IS NULL OR n.uuid = $source_uuid)
|
||||
AND ($target_uuid IS NULL OR m.uuid = $target_uuid)
|
||||
AND ($group_ids IS NULL OR n.group_id IN $group_ids)
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
|
|
@ -156,12 +168,14 @@ async def edge_similarity_search(
|
|||
) -> list[EntityEdge]:
|
||||
# vector similarity search over embedded facts
|
||||
query = Query("""
|
||||
CYPHER runtime = parallel parallelRuntimeSupport=all
|
||||
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
|
||||
WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
|
||||
AND ($source_uuid IS NULL OR n.uuid = $source_uuid)
|
||||
AND ($target_uuid IS NULL OR m.uuid = $target_uuid)
|
||||
WITH n, r, m, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
|
||||
WHERE score > 0.6
|
||||
RETURN
|
||||
vector.similarity.cosine(r.fact_embedding, $search_vector) AS score,
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
n.uuid AS source_node_uuid,
|
||||
|
|
@ -205,7 +219,6 @@ async def node_fulltext_search(
|
|||
"""
|
||||
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
|
||||
YIELD node AS n, score
|
||||
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
|
||||
RETURN
|
||||
n.uuid AS uuid,
|
||||
n.group_id AS group_id,
|
||||
|
|
@ -234,10 +247,12 @@ async def node_similarity_search(
|
|||
# vector similarity search over entity names
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
CYPHER runtime = parallel parallelRuntimeSupport=all
|
||||
MATCH (n:Entity)
|
||||
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
|
||||
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
|
||||
WHERE score > 0.6
|
||||
RETURN
|
||||
vector.similarity.cosine(n.name_embedding, $search_vector) AS score,
|
||||
n.uuid As uuid,
|
||||
n.group_id AS group_id,
|
||||
n.name AS name,
|
||||
|
|
@ -269,8 +284,6 @@ async def community_fulltext_search(
|
|||
"""
|
||||
CALL db.index.fulltext.queryNodes("community_name", $query)
|
||||
YIELD node AS comm, score
|
||||
MATCH (comm:Community)
|
||||
WHERE $group_ids IS NULL OR comm.group_id in $group_ids
|
||||
RETURN
|
||||
comm.uuid AS uuid,
|
||||
comm.group_id AS group_id,
|
||||
|
|
@ -299,10 +312,12 @@ async def community_similarity_search(
|
|||
# vector similarity search over entity names
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
CYPHER runtime = parallel parallelRuntimeSupport=all
|
||||
MATCH (comm:Community)
|
||||
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
|
||||
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
|
||||
WHERE score > 0.6
|
||||
RETURN
|
||||
vector.similarity.cosine(comm.name_embedding, $search_vector) AS score,
|
||||
comm.uuid As uuid,
|
||||
comm.group_id AS group_id,
|
||||
comm.name AS name,
|
||||
|
|
|
|||
935
poetry.lock
generated
935
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -20,6 +20,7 @@ diskcache = "^5.6.3"
|
|||
openai = "^1.50.2"
|
||||
tenacity = "<9.0.0"
|
||||
numpy = ">=1.0.0"
|
||||
voyageai = "^0.2.3"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
pytest = "^8.3.3"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue