Update reranker limits (#203)

* update reranker limits

* update versions

* format

* update names

* fix: voyage linter

---------

Co-authored-by: paulpaliychuk <pavlo.paliychuk.ca@gmail.com>
Preston Rasmussen 2024-10-28 14:50:16 -04:00 committed by GitHub
parent ceb60a3d33
commit 7bb0c78d5d
6 changed files with 1062 additions and 854 deletions


@@ -42,7 +42,9 @@ class OpenAIEmbedder(EmbedderClient):
         self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)

     async def create(
-        self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
+        self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
     ) -> list[float]:
-        result = await self.client.embeddings.create(input=input, model=self.config.embedding_model)
+        result = await self.client.embeddings.create(
+            input=input_data, model=self.config.embedding_model
+        )
         return result.data[0].embedding[: self.config.embedding_dim]

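The parameter rename from `input` (which shadows the Python builtin and trips the linter) to `input_data` changes how keyword callers invoke the embedder. A minimal usage sketch, assuming `OpenAIEmbedderConfig` from the same module with its default model settings; the import path and API key are placeholders:

```python
import asyncio

from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig


async def main() -> None:
    # Placeholder key; config defaults are assumed to cover model and embedding_dim.
    embedder = OpenAIEmbedder(config=OpenAIEmbedderConfig(api_key="sk-..."))
    # Keyword callers must now use `input_data`, matching the updated signature.
    vector = await embedder.create(input_data="temporal knowledge graphs")
    print(len(vector))  # truncated to config.embedding_dim


asyncio.run(main())
```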

@@ -41,7 +41,18 @@ class VoyageAIEmbedder(EmbedderClient):
         self.client = voyageai.AsyncClient(api_key=config.api_key)

     async def create(
-        self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
+        self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
     ) -> list[float]:
-        result = await self.client.embed(input, model=self.config.embedding_model)
+        if isinstance(input_data, str):
+            input_list = [input_data]
+        elif isinstance(input_data, List):
+            input_list = [str(i) for i in input_data if i]
+        else:
+            input_list = [str(i) for i in input_data if i is not None]
+
+        input_list = [i for i in input_list if i]
+        if len(input_list) == 0:
+            return []
+
+        result = await self.client.embed(input_list, model=self.config.embedding_model)
         return result.embeddings[0][: self.config.embedding_dim]

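The new branching coerces the union-typed input into a flat list of non-empty strings before calling `voyageai`'s `embed`, and bails out early on empty input instead of hitting the API. A standalone sketch of the same coercion (the helper name is hypothetical, for illustration only):

```python
from typing import Iterable, List


def normalize_embed_input(
    input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]],
) -> list[str]:
    # Mirror of the logic added to VoyageAIEmbedder.create above:
    # always produce a flat list of non-empty strings.
    if isinstance(input_data, str):
        input_list = [input_data]
    elif isinstance(input_data, list):
        input_list = [str(i) for i in input_data if i]
    else:
        input_list = [str(i) for i in input_data if i is not None]
    return [i for i in input_list if i]


print(normalize_embed_input("hello world"))   # ['hello world']
print(normalize_embed_input(["a", "", "b"]))  # ['a', 'b']
print(normalize_embed_input([]))              # [] -> create() returns [] without calling the API
```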

@@ -45,6 +45,7 @@ from graphiti_core.search.search_utils import (
     edge_similarity_search,
     episode_mentions_reranker,
     maximal_marginal_relevance,
+    node_bfs_search,
     node_distance_reranker,
     node_fulltext_search,
     node_similarity_search,
@@ -138,7 +139,7 @@ async def edge_search(
                 edge_similarity_search(
                     driver, query_vector, None, None, group_ids, 2 * limit, config.sim_min_score
                 ),
-                edge_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth),
+                edge_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth, 2 * limit),
             ]
         )
     )
@@ -160,7 +161,12 @@ async def edge_search(
             query_vector, search_result_uuids_and_vectors, config.mmr_lambda
         )
     elif config.reranker == EdgeReranker.cross_encoder:
-        fact_to_uuid_map = {edge.fact: edge.uuid for result in search_results for edge in result}
+        search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
+        rrf_result_uuids = rrf(search_result_uuids)
+        rrf_edges = [edge_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
+
+        fact_to_uuid_map = {edge.fact: edge.uuid for edge in rrf_edges}
+
         reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
         reranked_uuids = [fact_to_uuid_map[fact] for fact, _ in reranked_facts]
     elif config.reranker == EdgeReranker.node_distance:
@@ -212,6 +218,7 @@ async def node_search(
                 node_similarity_search(
                     driver, query_vector, group_ids, 2 * limit, config.sim_min_score
                 ),
+                node_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth, 2 * limit),
             ]
         )
     )
@@ -232,9 +239,12 @@ async def node_search(
             query_vector, search_result_uuids_and_vectors, config.mmr_lambda
         )
     elif config.reranker == NodeReranker.cross_encoder:
-        summary_to_uuid_map = {
-            node.summary: node.uuid for result in search_results for node in result
-        }
+        # use rrf as a preliminary reranker
+        rrf_result_uuids = rrf(search_result_uuids)
+        rrf_results = [node_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
+
+        summary_to_uuid_map = {node.summary: node.uuid for node in rrf_results}
+
        reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
         reranked_uuids = [summary_to_uuid_map[fact] for fact, _ in reranked_summaries]
     elif config.reranker == NodeReranker.episode_mentions:

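Both cross-encoder branches above now run reciprocal rank fusion (`rrf`) as a cheap preliminary pass and only send the top `limit` candidates to the comparatively expensive cross-encoder. A minimal sketch of the technique, assuming the same shape as graphiti's `rrf` helper (ranked lists of UUIDs in, one fused ranking out); the function name and scoring constant here are illustrative, not the library's exact implementation:

```python
from collections import defaultdict


def rrf_sketch(results: list[list[str]], rank_const: int = 1) -> list[str]:
    # Reciprocal rank fusion: each ranked list contributes 1 / (rank_const + rank)
    # to every UUID it contains, so items ranked highly by several searches win.
    scores: dict[str, float] = defaultdict(float)
    for ranked_uuids in results:
        for rank, uuid in enumerate(ranked_uuids):
            scores[uuid] += 1.0 / (rank + rank_const)
    return sorted(scores, key=lambda uuid: scores[uuid], reverse=True)


# Fulltext, similarity, and BFS hits are fused, truncated to `limit`,
# and only then handed to the cross-encoder for final reranking.
print(rrf_sketch([["a", "b", "c"], ["b", "c", "d"]]))  # ['b', 'a', 'c', 'd']
```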

@@ -233,6 +233,7 @@ async def edge_bfs_search(
     driver: AsyncDriver,
     bfs_origin_node_uuids: list[str] | None,
     bfs_max_depth: int,
+    limit: int,
 ) -> list[EntityEdge]:
     # vector similarity search over embedded facts
     if bfs_origin_node_uuids is None:
@@ -256,12 +257,14 @@ async def edge_bfs_search(
            r.expired_at AS expired_at,
            r.valid_at AS valid_at,
            r.invalid_at AS invalid_at
+        LIMIT $limit
     """)

     records, _, _ = await driver.execute_query(
         query,
         bfs_origin_node_uuids=bfs_origin_node_uuids,
         depth=bfs_max_depth,
+        limit=limit,
         database_=DEFAULT_DATABASE,
         routing_='r',
     )
@@ -348,6 +351,7 @@ async def node_bfs_search(
     driver: AsyncDriver,
     bfs_origin_node_uuids: list[str] | None,
     bfs_max_depth: int,
+    limit: int,
 ) -> list[EntityNode]:
     # vector similarity search over entity names
     if bfs_origin_node_uuids is None:
@@ -368,6 +372,7 @@ async def node_bfs_search(
        """,
         bfs_origin_node_uuids=bfs_origin_node_uuids,
         depth=bfs_max_depth,
+        limit=limit,
         database_=DEFAULT_DATABASE,
         routing_='r',
     )
@@ -690,4 +695,4 @@ def maximal_marginal_relevance(
     candidates_with_mmr.sort(reverse=True, key=lambda c: c[1])

-    return [candidate[0] for candidate in candidates_with_mmr]
+    return list(set([candidate[0] for candidate in candidates_with_mmr]))

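The BFS helpers now accept a `limit` argument and forward it to Cypher as the `$limit` parameter; the callers in search.py pass `2 * limit` so the rerankers still have a wider candidate pool to work with. A minimal sketch of the parameterized-LIMIT pattern with the async Neo4j driver; the URI, credentials, and the query itself are placeholders, not graphiti's actual BFS query:

```python
from neo4j import AsyncGraphDatabase


async def capped_read(uri: str, user: str, password: str, limit: int) -> list[str]:
    # Cypher accepts LIMIT as a query parameter, so the cap is decided per call
    # (e.g. 2 * limit before reranking) instead of being baked into the query text.
    driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
    try:
        records, _, _ = await driver.execute_query(
            "MATCH (n:Entity) RETURN n.uuid AS uuid LIMIT $limit",
            limit=limit,
            routing_="r",
        )
        return [record["uuid"] for record in records]
    finally:
        await driver.close()
```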
poetry.lock (generated): 1,862 lines changed; diff suppressed because it is too large.


@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "graphiti-core"
-version = "0.3.17"
+version = "0.3.18"
 description = "A temporal graph building library"
 authors = [
     "Paul Paliychuk <paul@getzep.com>",
@@ -17,7 +17,7 @@ python = "^3.10"
 pydantic = "^2.8.2"
 neo4j = "^5.23.0"
 diskcache = "^5.6.3"
-openai = "^1.50.2"
+openai = "^1.52.2"
 tenacity = "<9.0.0"
 numpy = ">=1.0.0"
@@ -26,7 +26,7 @@ pytest = "^8.3.3"
 python-dotenv = "^1.0.1"
 pytest-asyncio = "^0.24.0"
 pytest-xdist = "^3.6.1"
-ruff = "^0.6.9"
+ruff = "^0.7.1"

 [tool.poetry.group.dev.dependencies]
 pydantic = "^2.8.2"