diff --git a/lightrag/operate.py b/lightrag/operate.py index 13262d82..13715e94 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -2831,13 +2831,22 @@ async def _perform_kg_search( query_embedding = None if query and (kg_chunk_pick_method == "VECTOR" or chunks_vdb): embedding_func_config = text_chunks_db.embedding_func - if embedding_func_config and embedding_func_config.func: + if embedding_func_config: try: - query_embedding = await embedding_func_config.func([query]) - query_embedding = query_embedding[ - 0 - ] # Extract first embedding from batch result - logger.debug("Pre-computed query embedding for all vector operations") + # Handle both EmbeddingFunc objects and plain callable functions + from .utils import EmbeddingFunc + if isinstance(embedding_func_config, EmbeddingFunc): + embedding_func = embedding_func_config.func + else: + # It's a plain callable function + embedding_func = embedding_func_config + + if embedding_func: + query_embedding = await embedding_func([query]) + query_embedding = query_embedding[ + 0 + ] # Extract first embedding from batch result + logger.debug("Pre-computed query embedding for all vector operations") except Exception as e: logger.warning(f"Failed to pre-compute query embedding: {e}") query_embedding = None @@ -3748,7 +3757,13 @@ async def _find_related_text_unit_from_entities( kg_chunk_pick_method = "WEIGHT" else: try: - actual_embedding_func = embedding_func_config.func + # Handle both EmbeddingFunc objects and plain callable functions + from .utils import EmbeddingFunc + if isinstance(embedding_func_config, EmbeddingFunc): + actual_embedding_func = embedding_func_config.func + else: + # It's a plain callable function + actual_embedding_func = embedding_func_config selected_chunk_ids = None if actual_embedding_func: @@ -4043,7 +4058,13 @@ async def _find_related_text_unit_from_relations( kg_chunk_pick_method = "WEIGHT" else: try: - actual_embedding_func = embedding_func_config.func + # Handle both EmbeddingFunc objects and plain callable functions + from .utils import EmbeddingFunc + if isinstance(embedding_func_config, EmbeddingFunc): + actual_embedding_func = embedding_func_config.func + else: + # It's a plain callable function + actual_embedding_func = embedding_func_config if actual_embedding_func: selected_chunk_ids = await pick_by_vector_similarity(