fix (embedding): fixed query endpoint

This commit is contained in:
GGrassia 2025-11-28 15:38:33 +01:00
parent 3a2d3ddb9f
commit 49ce064a11

View file

@ -2831,9 +2831,18 @@ async def _perform_kg_search(
query_embedding = None query_embedding = None
if query and (kg_chunk_pick_method == "VECTOR" or chunks_vdb): if query and (kg_chunk_pick_method == "VECTOR" or chunks_vdb):
embedding_func_config = text_chunks_db.embedding_func embedding_func_config = text_chunks_db.embedding_func
if embedding_func_config and embedding_func_config.func: if embedding_func_config:
try: try:
query_embedding = await embedding_func_config.func([query]) # Handle both EmbeddingFunc objects and plain callable functions
from .utils import EmbeddingFunc
if isinstance(embedding_func_config, EmbeddingFunc):
embedding_func = embedding_func_config.func
else:
# It's a plain callable function
embedding_func = embedding_func_config
if embedding_func:
query_embedding = await embedding_func([query])
query_embedding = query_embedding[ query_embedding = query_embedding[
0 0
] # Extract first embedding from batch result ] # Extract first embedding from batch result
@ -3748,7 +3757,13 @@ async def _find_related_text_unit_from_entities(
kg_chunk_pick_method = "WEIGHT" kg_chunk_pick_method = "WEIGHT"
else: else:
try: try:
# Handle both EmbeddingFunc objects and plain callable functions
from .utils import EmbeddingFunc
if isinstance(embedding_func_config, EmbeddingFunc):
actual_embedding_func = embedding_func_config.func actual_embedding_func = embedding_func_config.func
else:
# It's a plain callable function
actual_embedding_func = embedding_func_config
selected_chunk_ids = None selected_chunk_ids = None
if actual_embedding_func: if actual_embedding_func:
@ -4043,7 +4058,13 @@ async def _find_related_text_unit_from_relations(
kg_chunk_pick_method = "WEIGHT" kg_chunk_pick_method = "WEIGHT"
else: else:
try: try:
# Handle both EmbeddingFunc objects and plain callable functions
from .utils import EmbeddingFunc
if isinstance(embedding_func_config, EmbeddingFunc):
actual_embedding_func = embedding_func_config.func actual_embedding_func = embedding_func_config.func
else:
# It's a plain callable function
actual_embedding_func = embedding_func_config
if actual_embedding_func: if actual_embedding_func:
selected_chunk_ids = await pick_by_vector_similarity( selected_chunk_ids = await pick_by_vector_similarity(