Msc benchmark update (#173)

* eval update

* I sped it up

* make format

* search updates

* updates

* cleanup

* make format

* remove unused imports

* poetry lock
This commit is contained in:
Preston Rasmussen 2024-10-03 15:39:35 -04:00 committed by GitHub
parent ec2e51c5ec
commit c8ff5be8ce
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 878 additions and 264 deletions

View file

@ -58,13 +58,21 @@ def setup_logging():
async def evaluate_qa(graphiti: Graphiti, group_id: str, query: str, answer: str): async def evaluate_qa(graphiti: Graphiti, group_id: str, query: str, answer: str):
search_start = time() search_start = time()
results = await graphiti._search(query, COMBINED_HYBRID_SEARCH_RRF, group_ids=[str(group_id)]) results = await graphiti._search(
query,
COMBINED_HYBRID_SEARCH_RRF,
group_ids=[str(group_id)],
)
search_end = time() search_end = time()
search_duration = search_end - search_start search_duration = search_end - search_start
facts = [edge.fact for edge in results.edges] facts = [edge.fact for edge in results.edges]
entity_summaries = [node.name + ': ' + node.summary for node in results.nodes] entity_summaries = [node.name + ': ' + node.summary for node in results.nodes]
context = {'facts': facts, 'entity_summaries': entity_summaries, 'query': 'Bob: ' + query} context = {
'facts': facts,
'entity_summaries': entity_summaries,
'query': 'Bob: ' + query,
}
llm_response = await graphiti.llm_client.generate_response( llm_response = await graphiti.llm_client.generate_response(
prompt_library.eval.qa_prompt(context) prompt_library.eval.qa_prompt(context)
@ -96,7 +104,14 @@ async def main():
setup_logging() setup_logging()
graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password) graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
fields = ['Group id', 'Question', 'Answer', 'Response', 'Score', 'Search Duration (ms)'] fields = [
'Group id',
'Question',
'Answer',
'Response',
'Score',
'Search Duration (ms)',
]
with open('../data/msc_eval.csv', 'w', newline='') as file: with open('../data/msc_eval.csv', 'w', newline='') as file:
writer = csv.DictWriter(file, fieldnames=fields) writer = csv.DictWriter(file, fieldnames=fields)
writer.writeheader() writer.writeheader()

View file

@ -71,7 +71,7 @@ async def main():
graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password) graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
msc_messages = parse_msc_messages() msc_messages = parse_msc_messages()
i = 0 i = 0
while i <= 490: while i < len(msc_messages):
msc_message_slice = msc_messages[i : i + 10] msc_message_slice = msc_messages[i : i + 10]
group_ids = range(len(msc_messages))[i : i + 10] group_ids = range(len(msc_messages))[i : i + 10]
@ -84,8 +84,5 @@ async def main():
i += 10 i += 10
# build communities
# await client.build_communities()
asyncio.run(main()) asyncio.run(main())

View file

@ -188,9 +188,9 @@ class EntityEdge(Edge):
MATCH (source:Entity {uuid: $source_uuid}) MATCH (source:Entity {uuid: $source_uuid})
MATCH (target:Entity {uuid: $target_uuid}) MATCH (target:Entity {uuid: $target_uuid})
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target) MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, fact_embedding: $fact_embedding, SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, episodes: $episodes,
episodes: $episodes, created_at: $created_at, expired_at: $expired_at, created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at}
valid_at: $valid_at, invalid_at: $invalid_at} WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
RETURN r.uuid AS uuid""", RETURN r.uuid AS uuid""",
source_uuid=self.source_node_uuid, source_uuid=self.source_node_uuid,
target_uuid=self.target_node_uuid, target_uuid=self.target_node_uuid,

View file

@ -700,18 +700,17 @@ class Graphiti:
).nodes ).nodes
return nodes return nodes
async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults: edges_list = await asyncio.gather(
episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids) *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
)
edges_list = await asyncio.gather( edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]
*[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
)
edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst] nodes = await get_mentioned_nodes(self.driver, episodes)
nodes = await get_mentioned_nodes(self.driver, episodes) communities = await get_communities_by_nodes(self.driver, nodes)
communities = await get_communities_by_nodes(self.driver, nodes) return SearchResults(edges=edges, nodes=nodes, communities=communities)
return SearchResults(edges=edges, nodes=nodes, communities=communities)

View file

@ -225,7 +225,8 @@ class EntityNode(Node):
result = await driver.execute_query( result = await driver.execute_query(
""" """
MERGE (n:Entity {uuid: $uuid}) MERGE (n:Entity {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at} SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
RETURN n.uuid AS uuid""", RETURN n.uuid AS uuid""",
uuid=self.uuid, uuid=self.uuid,
name=self.name, name=self.name,
@ -308,7 +309,8 @@ class CommunityNode(Node):
result = await driver.execute_query( result = await driver.execute_query(
""" """
MERGE (n:Community {uuid: $uuid}) MERGE (n:Community {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at} SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
RETURN n.uuid AS uuid""", RETURN n.uuid AS uuid""",
uuid=self.uuid, uuid=self.uuid,
name=self.name, name=self.name,

View file

@ -23,11 +23,33 @@ from .models import Message, PromptFunction, PromptVersion
class Prompt(Protocol): class Prompt(Protocol):
qa_prompt: PromptVersion qa_prompt: PromptVersion
eval_prompt: PromptVersion eval_prompt: PromptVersion
query_expansion: PromptVersion
class Versions(TypedDict): class Versions(TypedDict):
qa_prompt: PromptFunction qa_prompt: PromptFunction
eval_prompt: PromptFunction eval_prompt: PromptFunction
query_expansion: PromptFunction
def query_expansion(context: dict[str, Any]) -> list[Message]:
sys_prompt = """You are an expert at rephrasing questions into queries used in a database retrieval system"""
user_prompt = f"""
Bob is asking Alice a question, are you able to rephrase the question into a simpler one about Alice in the third person
that maintains the relevant context?
<QUESTION>
{json.dumps(context['query'])}
</QUESTION>
respond with a JSON object in the following format:
{{
"query": "query optimized for database search"
}}
"""
return [
Message(role='system', content=sys_prompt),
Message(role='user', content=user_prompt),
]
def qa_prompt(context: dict[str, Any]) -> list[Message]: def qa_prompt(context: dict[str, Any]) -> list[Message]:
@ -38,7 +60,7 @@ def qa_prompt(context: dict[str, Any]) -> list[Message]:
You are given the following entity summaries and facts to help you determine the answer to your question. You are given the following entity summaries and facts to help you determine the answer to your question.
<ENTITY_SUMMARIES> <ENTITY_SUMMARIES>
{json.dumps(context['entity_summaries'])} {json.dumps(context['entity_summaries'])}
</ENTITY_SUMMARIES </ENTITY_SUMMARIES>
<FACTS> <FACTS>
{json.dumps(context['facts'])} {json.dumps(context['facts'])}
</FACTS> </FACTS>
@ -87,4 +109,8 @@ def eval_prompt(context: dict[str, Any]) -> list[Message]:
] ]
versions: Versions = {'qa_prompt': qa_prompt, 'eval_prompt': eval_prompt} versions: Versions = {
'qa_prompt': qa_prompt,
'eval_prompt': eval_prompt,
'query_expansion': query_expansion,
}

View file

@ -113,8 +113,9 @@ def v2(context: dict[str, Any]) -> list[Message]:
2. Each edge should represent a clear relationship between two DISTINCT nodes. 2. Each edge should represent a clear relationship between two DISTINCT nodes.
3. The relation_type should be a concise, all-caps description of the relationship (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR). 3. The relation_type should be a concise, all-caps description of the relationship (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
4. Provide a more detailed fact describing the relationship. 4. Provide a more detailed fact describing the relationship.
5. Consider temporal aspects of relationships when relevant. 5. The fact should include any specific relevant information, including numeric information
6. Avoid using the same node as the source and target of a relationship 6. Consider temporal aspects of relationships when relevant.
7. Avoid using the same node as the source and target of a relationship
Respond with a JSON object in the following format: Respond with a JSON object in the following format:
{{ {{

View file

@ -29,13 +29,10 @@ from graphiti_core.search.search_config import (
DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_LIMIT,
CommunityReranker, CommunityReranker,
CommunitySearchConfig, CommunitySearchConfig,
CommunitySearchMethod,
EdgeReranker, EdgeReranker,
EdgeSearchConfig, EdgeSearchConfig,
EdgeSearchMethod,
NodeReranker, NodeReranker,
NodeSearchConfig, NodeSearchConfig,
NodeSearchMethod,
SearchConfig, SearchConfig,
SearchResults, SearchResults,
) )
@ -120,22 +117,16 @@ async def edge_search(
if config is None: if config is None:
return [] return []
search_results: list[list[EntityEdge]] = [] search_results: list[list[EntityEdge]] = list(
await asyncio.gather(
if EdgeSearchMethod.bm25 in config.search_methods: *[
text_search = await edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit) edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit),
search_results.append(text_search) edge_similarity_search(
driver, await embedder.create(input=[query]), None, None, group_ids, 2 * limit
if EdgeSearchMethod.cosine_similarity in config.search_methods: ),
search_vector = await embedder.create(input=[query]) ]
similarity_search = await edge_similarity_search(
driver, search_vector, None, None, group_ids, 2 * limit
) )
search_results.append(similarity_search) )
if len(search_results) > 1 and config.reranker is None:
raise SearchRerankerError('Multiple edge searches enabled without a reranker')
edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result} edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
@ -184,22 +175,16 @@ async def node_search(
if config is None: if config is None:
return [] return []
search_results: list[list[EntityNode]] = [] search_results: list[list[EntityNode]] = list(
await asyncio.gather(
if NodeSearchMethod.bm25 in config.search_methods: *[
text_search = await node_fulltext_search(driver, query, group_ids, 2 * limit) node_fulltext_search(driver, query, group_ids, 2 * limit),
search_results.append(text_search) node_similarity_search(
driver, await embedder.create(input=[query]), group_ids, 2 * limit
if NodeSearchMethod.cosine_similarity in config.search_methods: ),
search_vector = await embedder.create(input=[query]) ]
similarity_search = await node_similarity_search(
driver, search_vector, group_ids, 2 * limit
) )
search_results.append(similarity_search) )
if len(search_results) > 1 and config.reranker is None:
raise SearchRerankerError('Multiple node searches enabled without a reranker')
search_result_uuids = [[node.uuid for node in result] for result in search_results] search_result_uuids = [[node.uuid for node in result] for result in search_results]
node_uuid_map = {node.uuid: node for result in search_results for node in result} node_uuid_map = {node.uuid: node for result in search_results for node in result}
@ -232,22 +217,16 @@ async def community_search(
if config is None: if config is None:
return [] return []
search_results: list[list[CommunityNode]] = [] search_results: list[list[CommunityNode]] = list(
await asyncio.gather(
if CommunitySearchMethod.bm25 in config.search_methods: *[
text_search = await community_fulltext_search(driver, query, group_ids, 2 * limit) community_fulltext_search(driver, query, group_ids, 2 * limit),
search_results.append(text_search) community_similarity_search(
driver, await embedder.create(input=[query]), group_ids, 2 * limit
if CommunitySearchMethod.cosine_similarity in config.search_methods: ),
search_vector = await embedder.create(input=[query]) ]
similarity_search = await community_similarity_search(
driver, search_vector, group_ids, 2 * limit
) )
search_results.append(similarity_search) )
if len(search_results) > 1 and config.reranker is None:
raise SearchRerankerError('Multiple node searches enabled without a reranker')
search_result_uuids = [[community.uuid for community in result] for result in search_results] search_result_uuids = [[community.uuid for community in result] for result in search_results]
community_uuid_map = { community_uuid_map = {

View file

@ -57,17 +57,17 @@ class CommunityReranker(Enum):
class EdgeSearchConfig(BaseModel): class EdgeSearchConfig(BaseModel):
search_methods: list[EdgeSearchMethod] search_methods: list[EdgeSearchMethod]
reranker: EdgeReranker | None reranker: EdgeReranker = Field(default=EdgeReranker.rrf)
class NodeSearchConfig(BaseModel): class NodeSearchConfig(BaseModel):
search_methods: list[NodeSearchMethod] search_methods: list[NodeSearchMethod]
reranker: NodeReranker | None reranker: NodeReranker = Field(default=NodeReranker.rrf)
class CommunitySearchConfig(BaseModel): class CommunitySearchConfig(BaseModel):
search_methods: list[CommunitySearchMethod] search_methods: list[CommunitySearchMethod]
reranker: CommunityReranker | None reranker: CommunityReranker = Field(default=CommunityReranker.rrf)
class SearchConfig(BaseModel): class SearchConfig(BaseModel):

View file

@ -52,6 +52,21 @@ def fulltext_query(query: str, group_ids: list[str] | None = None):
return full_query return full_query
async def get_episodes_by_mentions(
driver: AsyncDriver,
nodes: list[EntityNode],
edges: list[EntityEdge],
limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EpisodicNode]:
episode_uuids: list[str] = []
for edge in edges:
episode_uuids.extend(edge.episodes)
episodes = await EpisodicNode.get_by_uuids(driver, episode_uuids[:limit])
return episodes
async def get_mentioned_nodes( async def get_mentioned_nodes(
driver: AsyncDriver, episodes: list[EpisodicNode] driver: AsyncDriver, episodes: list[EpisodicNode]
) -> list[EntityNode]: ) -> list[EntityNode]:
@ -113,9 +128,6 @@ async def edge_fulltext_search(
CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query) CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query)
YIELD relationship AS rel, score YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity) MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
WHERE ($source_uuid IS NULL OR n.uuid = $source_uuid)
AND ($target_uuid IS NULL OR m.uuid = $target_uuid)
AND ($group_ids IS NULL OR n.group_id IN $group_ids)
RETURN RETURN
r.uuid AS uuid, r.uuid AS uuid,
r.group_id AS group_id, r.group_id AS group_id,
@ -156,12 +168,14 @@ async def edge_similarity_search(
) -> list[EntityEdge]: ) -> list[EntityEdge]:
# vector similarity search over embedded facts # vector similarity search over embedded facts
query = Query(""" query = Query("""
CYPHER runtime = parallel parallelRuntimeSupport=all
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity) MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
WHERE ($group_ids IS NULL OR r.group_id IN $group_ids) WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
AND ($source_uuid IS NULL OR n.uuid = $source_uuid) AND ($source_uuid IS NULL OR n.uuid = $source_uuid)
AND ($target_uuid IS NULL OR m.uuid = $target_uuid) AND ($target_uuid IS NULL OR m.uuid = $target_uuid)
WITH n, r, m, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
WHERE score > 0.6
RETURN RETURN
vector.similarity.cosine(r.fact_embedding, $search_vector) AS score,
r.uuid AS uuid, r.uuid AS uuid,
r.group_id AS group_id, r.group_id AS group_id,
n.uuid AS source_node_uuid, n.uuid AS source_node_uuid,
@ -205,7 +219,6 @@ async def node_fulltext_search(
""" """
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query) CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
YIELD node AS n, score YIELD node AS n, score
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
RETURN RETURN
n.uuid AS uuid, n.uuid AS uuid,
n.group_id AS group_id, n.group_id AS group_id,
@ -234,10 +247,12 @@ async def node_similarity_search(
# vector similarity search over entity names # vector similarity search over entity names
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
CYPHER runtime = parallel parallelRuntimeSupport=all
MATCH (n:Entity) MATCH (n:Entity)
WHERE $group_ids IS NULL OR n.group_id IN $group_ids WHERE $group_ids IS NULL OR n.group_id IN $group_ids
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
WHERE score > 0.6
RETURN RETURN
vector.similarity.cosine(n.name_embedding, $search_vector) AS score,
n.uuid As uuid, n.uuid As uuid,
n.group_id AS group_id, n.group_id AS group_id,
n.name AS name, n.name AS name,
@ -269,8 +284,6 @@ async def community_fulltext_search(
""" """
CALL db.index.fulltext.queryNodes("community_name", $query) CALL db.index.fulltext.queryNodes("community_name", $query)
YIELD node AS comm, score YIELD node AS comm, score
MATCH (comm:Community)
WHERE $group_ids IS NULL OR comm.group_id in $group_ids
RETURN RETURN
comm.uuid AS uuid, comm.uuid AS uuid,
comm.group_id AS group_id, comm.group_id AS group_id,
@ -299,10 +312,12 @@ async def community_similarity_search(
# vector similarity search over entity names # vector similarity search over entity names
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
CYPHER runtime = parallel parallelRuntimeSupport=all
MATCH (comm:Community) MATCH (comm:Community)
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids) WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
WHERE score > 0.6
RETURN RETURN
vector.similarity.cosine(comm.name_embedding, $search_vector) AS score,
comm.uuid As uuid, comm.uuid As uuid,
comm.group_id AS group_id, comm.group_id AS group_id,
comm.name AS name, comm.name AS name,

935
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -20,6 +20,7 @@ diskcache = "^5.6.3"
openai = "^1.50.2" openai = "^1.50.2"
tenacity = "<9.0.0" tenacity = "<9.0.0"
numpy = ">=1.0.0" numpy = ">=1.0.0"
voyageai = "^0.2.3"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
pytest = "^8.3.3" pytest = "^8.3.3"