dont create extra search embeddings (#861)
* dont create extra search embeddings * updates * add missing conditionals * fix * float 0 * null check * more nullchecks * bump version
This commit is contained in:
parent
cbf783654b
commit
fa9c1696b8
3 changed files with 31 additions and 11 deletions
|
|
@ -21,6 +21,7 @@ from time import time
|
||||||
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
||||||
from graphiti_core.driver.driver import GraphDriver
|
from graphiti_core.driver.driver import GraphDriver
|
||||||
from graphiti_core.edges import EntityEdge
|
from graphiti_core.edges import EntityEdge
|
||||||
|
from graphiti_core.embedder.client import EMBEDDING_DIM
|
||||||
from graphiti_core.errors import SearchRerankerError
|
from graphiti_core.errors import SearchRerankerError
|
||||||
from graphiti_core.graphiti_types import GraphitiClients
|
from graphiti_core.graphiti_types import GraphitiClients
|
||||||
from graphiti_core.helpers import semaphore_gather
|
from graphiti_core.helpers import semaphore_gather
|
||||||
|
|
@ -29,6 +30,7 @@ from graphiti_core.search.search_config import (
|
||||||
DEFAULT_SEARCH_LIMIT,
|
DEFAULT_SEARCH_LIMIT,
|
||||||
CommunityReranker,
|
CommunityReranker,
|
||||||
CommunitySearchConfig,
|
CommunitySearchConfig,
|
||||||
|
CommunitySearchMethod,
|
||||||
EdgeReranker,
|
EdgeReranker,
|
||||||
EdgeSearchConfig,
|
EdgeSearchConfig,
|
||||||
EdgeSearchMethod,
|
EdgeSearchMethod,
|
||||||
|
|
@ -81,11 +83,29 @@ async def search(
|
||||||
|
|
||||||
if query.strip() == '':
|
if query.strip() == '':
|
||||||
return SearchResults()
|
return SearchResults()
|
||||||
query_vector = (
|
|
||||||
query_vector
|
if (
|
||||||
if query_vector is not None
|
config.edge_config
|
||||||
else await embedder.create(input_data=[query.replace('\n', ' ')])
|
and EdgeSearchMethod.cosine_similarity in config.edge_config.search_methods
|
||||||
)
|
or config.edge_config
|
||||||
|
and EdgeReranker.mmr == config.edge_config.reranker
|
||||||
|
or config.node_config
|
||||||
|
and NodeSearchMethod.cosine_similarity in config.node_config.search_methods
|
||||||
|
or config.node_config
|
||||||
|
and NodeReranker.mmr == config.node_config.reranker
|
||||||
|
or (
|
||||||
|
config.community_config
|
||||||
|
and CommunitySearchMethod.cosine_similarity in config.community_config.search_methods
|
||||||
|
)
|
||||||
|
or (config.community_config and CommunityReranker.mmr == config.community_config.reranker)
|
||||||
|
):
|
||||||
|
search_vector = (
|
||||||
|
query_vector
|
||||||
|
if query_vector is not None
|
||||||
|
else await embedder.create(input_data=[query.replace('\n', ' ')])
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
search_vector = [0.0] * EMBEDDING_DIM
|
||||||
|
|
||||||
# if group_ids is empty, set it to None
|
# if group_ids is empty, set it to None
|
||||||
group_ids = group_ids if group_ids and group_ids != [''] else None
|
group_ids = group_ids if group_ids and group_ids != [''] else None
|
||||||
|
|
@ -99,7 +119,7 @@ async def search(
|
||||||
driver,
|
driver,
|
||||||
cross_encoder,
|
cross_encoder,
|
||||||
query,
|
query,
|
||||||
query_vector,
|
search_vector,
|
||||||
group_ids,
|
group_ids,
|
||||||
config.edge_config,
|
config.edge_config,
|
||||||
search_filter,
|
search_filter,
|
||||||
|
|
@ -112,7 +132,7 @@ async def search(
|
||||||
driver,
|
driver,
|
||||||
cross_encoder,
|
cross_encoder,
|
||||||
query,
|
query,
|
||||||
query_vector,
|
search_vector,
|
||||||
group_ids,
|
group_ids,
|
||||||
config.node_config,
|
config.node_config,
|
||||||
search_filter,
|
search_filter,
|
||||||
|
|
@ -125,7 +145,7 @@ async def search(
|
||||||
driver,
|
driver,
|
||||||
cross_encoder,
|
cross_encoder,
|
||||||
query,
|
query,
|
||||||
query_vector,
|
search_vector,
|
||||||
group_ids,
|
group_ids,
|
||||||
config.episode_config,
|
config.episode_config,
|
||||||
search_filter,
|
search_filter,
|
||||||
|
|
@ -136,7 +156,7 @@ async def search(
|
||||||
driver,
|
driver,
|
||||||
cross_encoder,
|
cross_encoder,
|
||||||
query,
|
query,
|
||||||
query_vector,
|
search_vector,
|
||||||
group_ids,
|
group_ids,
|
||||||
config.community_config,
|
config.community_config,
|
||||||
config.limit,
|
config.limit,
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
[project]
|
[project]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
description = "A temporal graph building library"
|
description = "A temporal graph building library"
|
||||||
version = "0.19.0pre1"
|
version = "0.19.0pre2"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
||||||
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
||||||
|
|
|
||||||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -783,7 +783,7 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
version = "0.20.0rc1"
|
version = "0.19.0rc2"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "diskcache" },
|
{ name = "diskcache" },
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue