dont create extra search embeddings (#861)
* dont create extra search embeddings * updates * add missing conditionals * fix * float 0 * null check * more nullchecks * bump version
This commit is contained in:
parent
cbf783654b
commit
fa9c1696b8
3 changed files with 31 additions and 11 deletions
|
|
@ -21,6 +21,7 @@ from time import time
|
|||
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
||||
from graphiti_core.driver.driver import GraphDriver
|
||||
from graphiti_core.edges import EntityEdge
|
||||
from graphiti_core.embedder.client import EMBEDDING_DIM
|
||||
from graphiti_core.errors import SearchRerankerError
|
||||
from graphiti_core.graphiti_types import GraphitiClients
|
||||
from graphiti_core.helpers import semaphore_gather
|
||||
|
|
@ -29,6 +30,7 @@ from graphiti_core.search.search_config import (
|
|||
DEFAULT_SEARCH_LIMIT,
|
||||
CommunityReranker,
|
||||
CommunitySearchConfig,
|
||||
CommunitySearchMethod,
|
||||
EdgeReranker,
|
||||
EdgeSearchConfig,
|
||||
EdgeSearchMethod,
|
||||
|
|
@ -81,11 +83,29 @@ async def search(
|
|||
|
||||
if query.strip() == '':
|
||||
return SearchResults()
|
||||
query_vector = (
|
||||
query_vector
|
||||
if query_vector is not None
|
||||
else await embedder.create(input_data=[query.replace('\n', ' ')])
|
||||
)
|
||||
|
||||
if (
|
||||
config.edge_config
|
||||
and EdgeSearchMethod.cosine_similarity in config.edge_config.search_methods
|
||||
or config.edge_config
|
||||
and EdgeReranker.mmr == config.edge_config.reranker
|
||||
or config.node_config
|
||||
and NodeSearchMethod.cosine_similarity in config.node_config.search_methods
|
||||
or config.node_config
|
||||
and NodeReranker.mmr == config.node_config.reranker
|
||||
or (
|
||||
config.community_config
|
||||
and CommunitySearchMethod.cosine_similarity in config.community_config.search_methods
|
||||
)
|
||||
or (config.community_config and CommunityReranker.mmr == config.community_config.reranker)
|
||||
):
|
||||
search_vector = (
|
||||
query_vector
|
||||
if query_vector is not None
|
||||
else await embedder.create(input_data=[query.replace('\n', ' ')])
|
||||
)
|
||||
else:
|
||||
search_vector = [0.0] * EMBEDDING_DIM
|
||||
|
||||
# if group_ids is empty, set it to None
|
||||
group_ids = group_ids if group_ids and group_ids != [''] else None
|
||||
|
|
@ -99,7 +119,7 @@ async def search(
|
|||
driver,
|
||||
cross_encoder,
|
||||
query,
|
||||
query_vector,
|
||||
search_vector,
|
||||
group_ids,
|
||||
config.edge_config,
|
||||
search_filter,
|
||||
|
|
@ -112,7 +132,7 @@ async def search(
|
|||
driver,
|
||||
cross_encoder,
|
||||
query,
|
||||
query_vector,
|
||||
search_vector,
|
||||
group_ids,
|
||||
config.node_config,
|
||||
search_filter,
|
||||
|
|
@ -125,7 +145,7 @@ async def search(
|
|||
driver,
|
||||
cross_encoder,
|
||||
query,
|
||||
query_vector,
|
||||
search_vector,
|
||||
group_ids,
|
||||
config.episode_config,
|
||||
search_filter,
|
||||
|
|
@ -136,7 +156,7 @@ async def search(
|
|||
driver,
|
||||
cross_encoder,
|
||||
query,
|
||||
query_vector,
|
||||
search_vector,
|
||||
group_ids,
|
||||
config.community_config,
|
||||
config.limit,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
[project]
|
||||
name = "graphiti-core"
|
||||
description = "A temporal graph building library"
|
||||
version = "0.19.0pre1"
|
||||
version = "0.19.0pre2"
|
||||
authors = [
|
||||
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
||||
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
||||
|
|
|
|||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -783,7 +783,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "graphiti-core"
|
||||
version = "0.20.0rc1"
|
||||
version = "0.19.0rc2"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "diskcache" },
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue