Msc benchmark update (#173)

* eval update

* I sped it up

* make format

* search updates

* updates

* cleanup

* make format

* remove unused imports

* poetry lock
This commit is contained in:
Preston Rasmussen 2024-10-03 15:39:35 -04:00 committed by GitHub
parent ec2e51c5ec
commit c8ff5be8ce
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 878 additions and 264 deletions

View file

@ -58,13 +58,21 @@ def setup_logging():
async def evaluate_qa(graphiti: Graphiti, group_id: str, query: str, answer: str):
search_start = time()
results = await graphiti._search(query, COMBINED_HYBRID_SEARCH_RRF, group_ids=[str(group_id)])
results = await graphiti._search(
query,
COMBINED_HYBRID_SEARCH_RRF,
group_ids=[str(group_id)],
)
search_end = time()
search_duration = search_end - search_start
facts = [edge.fact for edge in results.edges]
entity_summaries = [node.name + ': ' + node.summary for node in results.nodes]
context = {'facts': facts, 'entity_summaries': entity_summaries, 'query': 'Bob: ' + query}
context = {
'facts': facts,
'entity_summaries': entity_summaries,
'query': 'Bob: ' + query,
}
llm_response = await graphiti.llm_client.generate_response(
prompt_library.eval.qa_prompt(context)
@ -96,7 +104,14 @@ async def main():
setup_logging()
graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
fields = ['Group id', 'Question', 'Answer', 'Response', 'Score', 'Search Duration (ms)']
fields = [
'Group id',
'Question',
'Answer',
'Response',
'Score',
'Search Duration (ms)',
]
with open('../data/msc_eval.csv', 'w', newline='') as file:
writer = csv.DictWriter(file, fieldnames=fields)
writer.writeheader()

View file

@ -71,7 +71,7 @@ async def main():
graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
msc_messages = parse_msc_messages()
i = 0
while i <= 490:
while i < len(msc_messages):
msc_message_slice = msc_messages[i : i + 10]
group_ids = range(len(msc_messages))[i : i + 10]
@ -84,8 +84,5 @@ async def main():
i += 10
# build communities
# await client.build_communities()
asyncio.run(main())

View file

@ -188,9 +188,9 @@ class EntityEdge(Edge):
MATCH (source:Entity {uuid: $source_uuid})
MATCH (target:Entity {uuid: $target_uuid})
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, fact_embedding: $fact_embedding,
episodes: $episodes, created_at: $created_at, expired_at: $expired_at,
valid_at: $valid_at, invalid_at: $invalid_at}
SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, episodes: $episodes,
created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at}
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
RETURN r.uuid AS uuid""",
source_uuid=self.source_node_uuid,
target_uuid=self.target_node_uuid,

View file

@ -700,18 +700,17 @@ class Graphiti:
).nodes
return nodes
async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
    """Collect everything mentioned by the given episodes.

    Loads the episodes by uuid, gathers the entity edges referenced by each
    episode concurrently, then resolves the mentioned entity nodes and the
    communities those nodes belong to.

    Args:
        episode_uuids: uuids of the EpisodicNode records to inspect.

    Returns:
        SearchResults bundling the mentioned edges, nodes, and communities.
    """
    episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)

    # Fetch the edge lists of all episodes concurrently instead of serially.
    edges_list = await asyncio.gather(
        *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
    )
    # Flatten the per-episode lists into a single edge list.
    edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]

    nodes = await get_mentioned_nodes(self.driver, episodes)
    communities = await get_communities_by_nodes(self.driver, nodes)

    return SearchResults(edges=edges, nodes=nodes, communities=communities)

View file

@ -225,7 +225,8 @@ class EntityNode(Node):
result = await driver.execute_query(
"""
MERGE (n:Entity {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
RETURN n.uuid AS uuid""",
uuid=self.uuid,
name=self.name,
@ -308,7 +309,8 @@ class CommunityNode(Node):
result = await driver.execute_query(
"""
MERGE (n:Community {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
RETURN n.uuid AS uuid""",
uuid=self.uuid,
name=self.name,

View file

@ -23,11 +23,33 @@ from .models import Message, PromptFunction, PromptVersion
class Prompt(Protocol):
    """Structural interface of this prompt module: one PromptVersion
    attribute per prompt exposed here."""

    qa_prompt: PromptVersion
    eval_prompt: PromptVersion
    query_expansion: PromptVersion
class Versions(TypedDict):
    """Maps each prompt name to its callable implementation; keys must
    mirror the attributes of the Prompt protocol."""

    qa_prompt: PromptFunction
    eval_prompt: PromptFunction
    query_expansion: PromptFunction
def query_expansion(context: dict[str, Any]) -> list[Message]:
    """Build messages asking the LLM to rephrase a question for retrieval.

    Args:
        context: must contain 'query' — the raw question Bob asked Alice.

    Returns:
        System + user messages instructing the model to answer with a JSON
        object of the form {"query": "..."} rephrased in the third person.
    """
    sys_prompt = """You are an expert at rephrasing questions into queries used in a database retrieval system"""

    user_prompt = f"""
Bob is asking Alice a question, are you able to rephrase the question into a simpler one about Alice in the third person
that maintains the relevant context?
<QUESTION>
{json.dumps(context['query'])}
</QUESTION>
respond with a JSON object in the following format:
{{
    "query": "query optimized for database search"
}}
"""
    return [
        Message(role='system', content=sys_prompt),
        Message(role='user', content=user_prompt),
    ]
def qa_prompt(context: dict[str, Any]) -> list[Message]:
@ -38,7 +60,7 @@ def qa_prompt(context: dict[str, Any]) -> list[Message]:
You are given the following entity summaries and facts to help you determine the answer to your question.
<ENTITY_SUMMARIES>
{json.dumps(context['entity_summaries'])}
</ENTITY_SUMMARIES
</ENTITY_SUMMARIES>
<FACTS>
{json.dumps(context['facts'])}
</FACTS>
@ -87,4 +109,8 @@ def eval_prompt(context: dict[str, Any]) -> list[Message]:
]
# Registry of the prompt functions this module exports, keyed by the
# attribute names declared on the Prompt protocol.
versions: Versions = {
    'qa_prompt': qa_prompt,
    'eval_prompt': eval_prompt,
    'query_expansion': query_expansion,
}

View file

@ -113,8 +113,9 @@ def v2(context: dict[str, Any]) -> list[Message]:
2. Each edge should represent a clear relationship between two DISTINCT nodes.
3. The relation_type should be a concise, all-caps description of the relationship (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
4. Provide a more detailed fact describing the relationship.
5. Consider temporal aspects of relationships when relevant.
6. Avoid using the same node as the source and target of a relationship
5. The fact should include any specific relevant information, including numeric information
6. Consider temporal aspects of relationships when relevant.
7. Avoid using the same node as the source and target of a relationship
Respond with a JSON object in the following format:
{{

View file

@ -29,13 +29,10 @@ from graphiti_core.search.search_config import (
DEFAULT_SEARCH_LIMIT,
CommunityReranker,
CommunitySearchConfig,
CommunitySearchMethod,
EdgeReranker,
EdgeSearchConfig,
EdgeSearchMethod,
NodeReranker,
NodeSearchConfig,
NodeSearchMethod,
SearchConfig,
SearchResults,
)
@ -120,22 +117,16 @@ async def edge_search(
if config is None:
return []
search_results: list[list[EntityEdge]] = []
if EdgeSearchMethod.bm25 in config.search_methods:
text_search = await edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit)
search_results.append(text_search)
if EdgeSearchMethod.cosine_similarity in config.search_methods:
search_vector = await embedder.create(input=[query])
similarity_search = await edge_similarity_search(
driver, search_vector, None, None, group_ids, 2 * limit
search_results: list[list[EntityEdge]] = list(
await asyncio.gather(
*[
edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit),
edge_similarity_search(
driver, await embedder.create(input=[query]), None, None, group_ids, 2 * limit
),
]
)
search_results.append(similarity_search)
if len(search_results) > 1 and config.reranker is None:
raise SearchRerankerError('Multiple edge searches enabled without a reranker')
)
edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
@ -184,22 +175,16 @@ async def node_search(
if config is None:
return []
search_results: list[list[EntityNode]] = []
if NodeSearchMethod.bm25 in config.search_methods:
text_search = await node_fulltext_search(driver, query, group_ids, 2 * limit)
search_results.append(text_search)
if NodeSearchMethod.cosine_similarity in config.search_methods:
search_vector = await embedder.create(input=[query])
similarity_search = await node_similarity_search(
driver, search_vector, group_ids, 2 * limit
search_results: list[list[EntityNode]] = list(
await asyncio.gather(
*[
node_fulltext_search(driver, query, group_ids, 2 * limit),
node_similarity_search(
driver, await embedder.create(input=[query]), group_ids, 2 * limit
),
]
)
search_results.append(similarity_search)
if len(search_results) > 1 and config.reranker is None:
raise SearchRerankerError('Multiple node searches enabled without a reranker')
)
search_result_uuids = [[node.uuid for node in result] for result in search_results]
node_uuid_map = {node.uuid: node for result in search_results for node in result}
@ -232,22 +217,16 @@ async def community_search(
if config is None:
return []
search_results: list[list[CommunityNode]] = []
if CommunitySearchMethod.bm25 in config.search_methods:
text_search = await community_fulltext_search(driver, query, group_ids, 2 * limit)
search_results.append(text_search)
if CommunitySearchMethod.cosine_similarity in config.search_methods:
search_vector = await embedder.create(input=[query])
similarity_search = await community_similarity_search(
driver, search_vector, group_ids, 2 * limit
search_results: list[list[CommunityNode]] = list(
await asyncio.gather(
*[
community_fulltext_search(driver, query, group_ids, 2 * limit),
community_similarity_search(
driver, await embedder.create(input=[query]), group_ids, 2 * limit
),
]
)
search_results.append(similarity_search)
if len(search_results) > 1 and config.reranker is None:
raise SearchRerankerError('Multiple node searches enabled without a reranker')
)
search_result_uuids = [[community.uuid for community in result] for result in search_results]
community_uuid_map = {

View file

@ -57,17 +57,17 @@ class CommunityReranker(Enum):
class EdgeSearchConfig(BaseModel):
    """Configuration for edge search: which search methods to run and how
    to combine their results."""

    search_methods: list[EdgeSearchMethod]
    # Reranker is required with an RRF default (previously Optional), so a
    # multi-method search always has a way to merge result lists.
    reranker: EdgeReranker = Field(default=EdgeReranker.rrf)
class NodeSearchConfig(BaseModel):
    """Configuration for entity-node search: which search methods to run
    and how to combine their results."""

    search_methods: list[NodeSearchMethod]
    # Reranker is required with an RRF default (previously Optional), so a
    # multi-method search always has a way to merge result lists.
    reranker: NodeReranker = Field(default=NodeReranker.rrf)
class CommunitySearchConfig(BaseModel):
    """Configuration for community search: which search methods to run and
    how to combine their results."""

    search_methods: list[CommunitySearchMethod]
    # Reranker is required with an RRF default (previously Optional), so a
    # multi-method search always has a way to merge result lists.
    reranker: CommunityReranker = Field(default=CommunityReranker.rrf)
class SearchConfig(BaseModel):

View file

@ -52,6 +52,21 @@ def fulltext_query(query: str, group_ids: list[str] | None = None):
return full_query
async def get_episodes_by_mentions(
    driver: AsyncDriver,
    nodes: list[EntityNode],
    edges: list[EntityEdge],
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EpisodicNode]:
    """Fetch the episodes referenced by the given edges.

    Args:
        driver: Neo4j async driver used for the lookup.
        nodes: currently unused — kept for interface stability.
        edges: edges whose `episodes` uuid lists are collected.
        limit: maximum number of episodes to fetch (default
            RELEVANT_SCHEMA_LIMIT).

    Returns:
        The EpisodicNode records for (at most `limit`) mentioned episodes.
    """
    # Gather every episode uuid referenced by any edge, preserving order.
    episode_uuids: list[str] = [uuid for edge in edges for uuid in edge.episodes]

    # Truncate before the db round-trip so we never request more than `limit`.
    episodes = await EpisodicNode.get_by_uuids(driver, episode_uuids[:limit])

    return episodes
async def get_mentioned_nodes(
driver: AsyncDriver, episodes: list[EpisodicNode]
) -> list[EntityNode]:
@ -113,9 +128,6 @@ async def edge_fulltext_search(
CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query)
YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
WHERE ($source_uuid IS NULL OR n.uuid = $source_uuid)
AND ($target_uuid IS NULL OR m.uuid = $target_uuid)
AND ($group_ids IS NULL OR n.group_id IN $group_ids)
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
@ -156,12 +168,14 @@ async def edge_similarity_search(
) -> list[EntityEdge]:
# vector similarity search over embedded facts
query = Query("""
CYPHER runtime = parallel parallelRuntimeSupport=all
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
AND ($source_uuid IS NULL OR n.uuid = $source_uuid)
AND ($target_uuid IS NULL OR m.uuid = $target_uuid)
WITH n, r, m, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
WHERE score > 0.6
RETURN
vector.similarity.cosine(r.fact_embedding, $search_vector) AS score,
r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid,
@ -205,7 +219,6 @@ async def node_fulltext_search(
"""
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
YIELD node AS n, score
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
RETURN
n.uuid AS uuid,
n.group_id AS group_id,
@ -234,10 +247,12 @@ async def node_similarity_search(
# vector similarity search over entity names
records, _, _ = await driver.execute_query(
"""
CYPHER runtime = parallel parallelRuntimeSupport=all
MATCH (n:Entity)
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
WHERE score > 0.6
RETURN
vector.similarity.cosine(n.name_embedding, $search_vector) AS score,
n.uuid As uuid,
n.group_id AS group_id,
n.name AS name,
@ -269,8 +284,6 @@ async def community_fulltext_search(
"""
CALL db.index.fulltext.queryNodes("community_name", $query)
YIELD node AS comm, score
MATCH (comm:Community)
WHERE $group_ids IS NULL OR comm.group_id in $group_ids
RETURN
comm.uuid AS uuid,
comm.group_id AS group_id,
@ -299,10 +312,12 @@ async def community_similarity_search(
# vector similarity search over entity names
records, _, _ = await driver.execute_query(
"""
CYPHER runtime = parallel parallelRuntimeSupport=all
MATCH (comm:Community)
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
WHERE score > 0.6
RETURN
vector.similarity.cosine(comm.name_embedding, $search_vector) AS score,
comm.uuid As uuid,
comm.group_id AS group_id,
comm.name AS name,

935
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -20,6 +20,7 @@ diskcache = "^5.6.3"
openai = "^1.50.2"
tenacity = "<9.0.0"
numpy = ">=1.0.0"
voyageai = "^0.2.3"
[tool.poetry.dev-dependencies]
pytest = "^8.3.3"