Update reranker limits (#203)

* update reranker limits
* update versions
* format
* update names
* fix: voyage linter

Co-authored-by: paulpaliychuk <pavlo.paliychuk.ca@gmail.com>
parent ceb60a3d33, commit 7bb0c78d5d
6 changed files with 1062 additions and 854 deletions
@@ -42,7 +42,9 @@ class OpenAIEmbedder(EmbedderClient):
         self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
 
     async def create(
-        self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
+        self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
     ) -> list[float]:
-        result = await self.client.embeddings.create(input=input, model=self.config.embedding_model)
+        result = await self.client.embeddings.create(
+            input=input_data, model=self.config.embedding_model
+        )
         return result.data[0].embedding[: self.config.embedding_dim]
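The rename from input to input_data avoids shadowing Python's builtin input, which linters such as ruff flag (the "update names" and linter items in the commit message point the same way). A hypothetical usage sketch — OpenAIEmbedderConfig and the module path are assumptions; only api_key, base_url, embedding_model, and embedding_dim appear in the diff itself:

import asyncio

from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig  # assumed path

async def main() -> None:
    embedder = OpenAIEmbedder(config=OpenAIEmbedderConfig(api_key='sk-...'))
    # the keyword changes from input= to input_data=
    vector = await embedder.create(input_data='temporal knowledge graphs')
    print(len(vector))  # truncated to config.embedding_dim

asyncio.run(main())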
@@ -41,7 +41,18 @@ class VoyageAIEmbedder(EmbedderClient):
         self.client = voyageai.AsyncClient(api_key=config.api_key)
 
     async def create(
-        self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
+        self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
     ) -> list[float]:
-        result = await self.client.embed(input, model=self.config.embedding_model)
+        if isinstance(input_data, str):
+            input_list = [input_data]
+        elif isinstance(input_data, List):
+            input_list = [str(i) for i in input_data if i]
+        else:
+            input_list = [str(i) for i in input_data if i is not None]
+
+        input_list = [i for i in input_list if i]
+        if len(input_list) == 0:
+            return []
+
+        result = await self.client.embed(input_list, model=self.config.embedding_model)
         return result.embeddings[0][: self.config.embedding_dim]
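The added branches normalize every accepted input shape into a non-empty list of strings before calling self.client.embed, and short-circuit to an empty embedding when nothing survives the filtering. A standalone sketch of just that normalization, with assumed example values:

from typing import Iterable, List

def normalize_embed_input(input_data: str | List[str] | Iterable[int]) -> list[str]:
    # Mirrors the guard added in this commit: coerce to strings, drop falsy items.
    if isinstance(input_data, str):
        input_list = [input_data]
    elif isinstance(input_data, List):
        input_list = [str(i) for i in input_data if i]
    else:
        input_list = [str(i) for i in input_data if i is not None]
    return [i for i in input_list if i]

print(normalize_embed_input('hello'))         # ['hello']
print(normalize_embed_input(['a', '', 'b']))  # ['a', 'b']
print(normalize_embed_input([]))              # [] -> create() returns [] early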
@@ -45,6 +45,7 @@ from graphiti_core.search.search_utils import (
     edge_similarity_search,
     episode_mentions_reranker,
     maximal_marginal_relevance,
+    node_bfs_search,
     node_distance_reranker,
     node_fulltext_search,
     node_similarity_search,
@@ -138,7 +139,7 @@ async def edge_search(
                 edge_similarity_search(
                     driver, query_vector, None, None, group_ids, 2 * limit, config.sim_min_score
                 ),
-                edge_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth),
+                edge_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth, 2 * limit),
             ]
         )
     )
@@ -160,7 +161,12 @@ async def edge_search(
             query_vector, search_result_uuids_and_vectors, config.mmr_lambda
         )
     elif config.reranker == EdgeReranker.cross_encoder:
-        fact_to_uuid_map = {edge.fact: edge.uuid for result in search_results for edge in result}
+        search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
+
+        rrf_result_uuids = rrf(search_result_uuids)
+        rrf_edges = [edge_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
+
+        fact_to_uuid_map = {edge.fact: edge.uuid for edge in rrf_edges}
         reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
         reranked_uuids = [fact_to_uuid_map[fact] for fact, _ in reranked_facts]
     elif config.reranker == EdgeReranker.node_distance:
@@ -212,6 +218,7 @@ async def node_search(
                 node_similarity_search(
                     driver, query_vector, group_ids, 2 * limit, config.sim_min_score
                 ),
+                node_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth, 2 * limit),
             ]
         )
     )
@@ -232,9 +239,12 @@ async def node_search(
             query_vector, search_result_uuids_and_vectors, config.mmr_lambda
         )
     elif config.reranker == NodeReranker.cross_encoder:
-        summary_to_uuid_map = {
-            node.summary: node.uuid for result in search_results for node in result
-        }
+        # use rrf as a preliminary reranker
+        rrf_result_uuids = rrf(search_result_uuids)
+        rrf_results = [node_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
+
+        summary_to_uuid_map = {node.summary: node.uuid for node in rrf_results}
+
         reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
         reranked_uuids = [summary_to_uuid_map[fact] for fact, _ in reranked_summaries]
     elif config.reranker == NodeReranker.episode_mentions:
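The recall stages now overfetch (2 * limit candidates per retriever), and both cross-encoder branches first fuse the candidate lists with rrf, truncate to limit, and only then run the comparatively expensive cross-encoder over that short list. The rrf helper itself is outside this diff; a minimal sketch of standard reciprocal rank fusion, with k = 60 as an assumed constant:

from collections import defaultdict

def rrf(results: list[list[str]], k: int = 60) -> list[str]:
    # Standard reciprocal rank fusion: an item's score is the sum of
    # 1 / (k + rank) over every result list it appears in.
    scores: dict[str, float] = defaultdict(float)
    for result in results:
        for rank, uuid in enumerate(result, start=1):
            scores[uuid] += 1.0 / (k + rank)
    return sorted(scores, key=lambda uuid: scores[uuid], reverse=True)

# Assumed example: two retrievers returning overlapping uuid lists.
print(rrf([['a', 'b', 'c'], ['b', 'a', 'd']]))  # 'a' and 'b' fused to the top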
@@ -233,6 +233,7 @@ async def edge_bfs_search(
     driver: AsyncDriver,
     bfs_origin_node_uuids: list[str] | None,
     bfs_max_depth: int,
+    limit: int,
 ) -> list[EntityEdge]:
     # vector similarity search over embedded facts
     if bfs_origin_node_uuids is None:
@@ -256,12 +257,14 @@ async def edge_bfs_search(
         r.expired_at AS expired_at,
         r.valid_at AS valid_at,
         r.invalid_at AS invalid_at
+        LIMIT $limit
         """)
 
     records, _, _ = await driver.execute_query(
         query,
         bfs_origin_node_uuids=bfs_origin_node_uuids,
         depth=bfs_max_depth,
+        limit=limit,
         database_=DEFAULT_DATABASE,
         routing_='r',
     )
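The new limit argument is passed straight through as a Cypher query parameter, so LIMIT $limit caps how many edges the BFS returns server-side instead of trimming in Python. A hypothetical, heavily trimmed sketch of the call shape — the Cypher pattern and example values here are assumptions, and the real query projects many more fields and also parameterizes the traversal depth:

import asyncio
from neo4j import AsyncGraphDatabase

QUERY = """
    MATCH (origin:Entity)-[:RELATES_TO*1..2]-(:Entity)-[r:RELATES_TO]-()
    WHERE origin.uuid IN $bfs_origin_node_uuids
    RETURN DISTINCT r.uuid AS uuid
    LIMIT $limit
"""

async def main() -> None:
    driver = AsyncGraphDatabase.driver('bolt://localhost:7687', auth=('neo4j', 'password'))
    records, _, _ = await driver.execute_query(
        QUERY,
        bfs_origin_node_uuids=['some-uuid'],  # assumed example value
        limit=100,                            # caps results server-side
        database_='neo4j',                    # DEFAULT_DATABASE in the source
        routing_='r',                         # read routing, as in the diff
    )
    print(len(records))
    await driver.close()

asyncio.run(main())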
@@ -348,6 +351,7 @@ async def node_bfs_search(
     driver: AsyncDriver,
     bfs_origin_node_uuids: list[str] | None,
     bfs_max_depth: int,
+    limit: int,
 ) -> list[EntityNode]:
     # vector similarity search over entity names
     if bfs_origin_node_uuids is None:
@@ -368,6 +372,7 @@ async def node_bfs_search(
         """,
         bfs_origin_node_uuids=bfs_origin_node_uuids,
         depth=bfs_max_depth,
+        limit=limit,
         database_=DEFAULT_DATABASE,
         routing_='r',
     )
@@ -690,4 +695,4 @@ def maximal_marginal_relevance(
 
     candidates_with_mmr.sort(reverse=True, key=lambda c: c[1])
 
-    return [candidate[0] for candidate in candidates_with_mmr]
+    return list(set([candidate[0] for candidate in candidates_with_mmr]))
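One caveat on the last hunk: list(set(...)) deduplicates the candidates but discards the MMR ordering the sort just computed, since Python sets are unordered. If the ranking should survive deduplication, an order-preserving variant is a one-liner; a sketch, with assumed example values:

def dedupe_preserving_order(items: list[str]) -> list[str]:
    # dict preserves insertion order (Python 3.7+), so this keeps the
    # first occurrence of each item in its original rank position.
    return list(dict.fromkeys(items))

print(dedupe_preserving_order(['b', 'a', 'b', 'c']))  # ['b', 'a', 'c']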
poetry.lock (generated) — 1862 changed lines; file diff suppressed because it is too large.
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "graphiti-core"
-version = "0.3.17"
+version = "0.3.18"
 description = "A temporal graph building library"
 authors = [
     "Paul Paliychuk <paul@getzep.com>",
@@ -17,7 +17,7 @@ python = "^3.10"
 pydantic = "^2.8.2"
 neo4j = "^5.23.0"
 diskcache = "^5.6.3"
-openai = "^1.50.2"
+openai = "^1.52.2"
 tenacity = "<9.0.0"
 numpy = ">=1.0.0"
 
@@ -26,7 +26,7 @@ pytest = "^8.3.3"
 python-dotenv = "^1.0.1"
 pytest-asyncio = "^0.24.0"
 pytest-xdist = "^3.6.1"
-ruff = "^0.6.9"
+ruff = "^0.7.1"
 
 [tool.poetry.group.dev.dependencies]
 pydantic = "^2.8.2"