diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index c7d9dd97..259f9b41 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -11,11 +11,13 @@ import pipmaster as pm from datetime import datetime, timezone from pathlib import Path from typing import Dict, List, Optional, Any, Literal +import json from fastapi import ( APIRouter, BackgroundTasks, Depends, File, + Form, HTTPException, UploadFile, ) @@ -833,7 +835,7 @@ def get_unique_filename_in_enqueued(target_dir: Path, original_name: str) -> str async def pipeline_enqueue_file( - rag: LightRAG, file_path: Path, track_id: str = None + rag: LightRAG, file_path: Path, track_id: str = None, metadata: dict | None = None ) -> tuple[bool, str]: """Add a file to the queue for processing @@ -841,6 +843,7 @@ async def pipeline_enqueue_file( rag: LightRAG instance file_path: Path to the saved file track_id: Optional tracking ID, if not provided will be generated + metadata: Optional metadata to associate with the document Returns: tuple: (success: bool, track_id: str) """ @@ -1212,8 +1215,12 @@ async def pipeline_enqueue_file( return False, track_id try: + # Pass metadata to apipeline_enqueue_documents await rag.apipeline_enqueue_documents( - content, file_paths=file_path.name, track_id=track_id + content, + file_paths=file_path.name, + track_id=track_id, + metadata=metadata, ) logger.info( @@ -1695,19 +1702,21 @@ def create_document_routes( "/upload", response_model=InsertResponse, dependencies=[Depends(combined_auth)] ) async def upload_to_input_dir( - background_tasks: BackgroundTasks, file: UploadFile = File(...) + background_tasks: BackgroundTasks, + file: UploadFile = File(...), + metadata: Optional[str] = Form(None), ): """ - Upload a file to the input directory and index it. + Upload a file to the input directory and index it with optional metadata. This API endpoint accepts a file through an HTTP POST request, checks if the uploaded file is of a supported type, saves it in the specified input directory, indexes it for retrieval, and returns a success status with relevant details. - + Metadata can be provided to associate custom data with the uploaded document. Args: background_tasks: FastAPI BackgroundTasks for async processing file (UploadFile): The file to be uploaded. It must have an allowed extension. - + metadata (dict, optional): Custom metadata to associate with the document. Returns: InsertResponse: A response object containing the upload status and a message. status can be "success", "duplicated", or error is thrown. @@ -1750,9 +1759,30 @@ def create_document_routes( track_id = generate_track_id("upload") - # Add to background tasks and get track_id - background_tasks.add_task(pipeline_index_file, rag, file_path, track_id) + # Parse metadata if provided + parsed_metadata = None + if metadata: + try: + parsed_metadata = json.loads(metadata) + if not isinstance(parsed_metadata, dict): + raise ValueError( + "Metadata must be a valid JSON dictionary string." + ) + except json.JSONDecodeError: + raise HTTPException( + status_code=400, detail="Metadata must be a valid JSON string." + ) + except ValueError as ve: + raise HTTPException(status_code=400, detail=str(ve)) + # Add to background tasks with metadata + background_tasks.add_task( + pipeline_index_file_with_metadata, + rag, + file_path, + track_id, + parsed_metadata, + ) return InsertResponse( status="success", message=f"File '{safe_filename}' uploaded successfully. Processing will continue in background.", @@ -1764,6 +1794,35 @@ def create_document_routes( logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) + # New function to handle metadata during indexing + async def pipeline_index_file_with_metadata( + rag: LightRAG, file_path: Path, track_id: str, metadata: dict | None + ) -> tuple[bool, str]: + """ + Index a file with metadata by leveraging the existing pipeline. + Args: + rag: LightRAG instance + file_path: Path to the file to index + track_id: Tracking ID for the document + metadata: Optional metadata dictionary to associate with the document + Returns: + tuple[bool, str]: Success status and track ID + """ + # Use the existing pipeline to enqueue the file + success, returned_track_id = await pipeline_enqueue_file( + rag, file_path, track_id + ) + + if success: + logger.info(f"Successfully enqueued file with metadata: {metadata}") + else: + logger.error("Failed to enqueue file with metadata") + + # Trigger the pipeline processing + await rag.apipeline_process_enqueue_documents(metadata=metadata) + + return success, returned_track_id + @router.post( "/text", response_model=InsertResponse, dependencies=[Depends(combined_auth)] ) diff --git a/lightrag/api/routers/query_routes.py b/lightrag/api/routers/query_routes.py index 83df2823..28bed617 100644 --- a/lightrag/api/routers/query_routes.py +++ b/lightrag/api/routers/query_routes.py @@ -22,6 +22,11 @@ class QueryRequest(BaseModel): description="The query text", ) + metadata_filter: dict[str, str] | None = Field( + default=None, + description="Optional dictionary of metadata key-value pairs to filter nodes", + ) + mode: Literal["local", "global", "hybrid", "naive", "mix", "bypass"] = Field( default="mix", description="Query mode", @@ -168,6 +173,11 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): """ try: param = request.to_query_params(False) + + # Inject metadata_filter into param if present + if request.metadata_filter: + setattr(param, "metadata_filter", request.metadata_filter) + response = await rag.aquery(request.query, param=param) # Get reference list if requested diff --git a/lightrag/base.py b/lightrag/base.py index a6420069..51d581be 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -168,6 +168,9 @@ class QueryParam: containing citation information for the retrieved content. """ + metadata_filter: dict | None = None + """Metadata for filtering nodes and edges, allowing for more precise querying.""" + @dataclass class StorageNameSpace(ABC): @@ -444,6 +447,12 @@ class BaseGraphStorage(StorageNameSpace, ABC): or None if the node doesn't exist """ + async def get_nodes_by_metadata_filter(self, metadata_filter: str) -> list[str]: + result = [] + text_query = f"MATCH (n:`{self.workspace_label}` {metadata_filter})" + result = await self.query(text_query) + return result + async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: """Get nodes as a batch using UNWIND diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 896e5973..0996f4bc 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -1,5 +1,6 @@ import os import re +import json from dataclasses import dataclass from typing import final import configparser @@ -424,6 +425,24 @@ class Neo4JStorage(BaseGraphStorage): await result.consume() # Ensure results are consumed even on error raise + async def get_nodes_by_metadata_filter(self, query: str) -> list[str]: + """Get nodes by filtering query, return nodes id""" + + workspace_label = self._get_workspace_label() + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + try: + query = f"MATCH (n:`{workspace_label}`) {query}" + result = await session.run(query) + debug_results = [record async for record in result] + debug_ids = [r["entity_id"] for r in debug_results] + await result.consume() + return debug_ids + except Exception as e: + logger.error(f"[{self.workspace}] Error getting node for: {str(e)}") + raise + async def get_node(self, node_id: str) -> dict[str, str] | None: """Get node by its label identifier, return only node properties @@ -962,7 +981,11 @@ class Neo4JStorage(BaseGraphStorage): ) ), ) - async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: + async def upsert_node( + self, + node_id: str, + node_data: dict[str, str], + ) -> None: """ Upsert a node in the Neo4j database. @@ -971,8 +994,25 @@ class Neo4JStorage(BaseGraphStorage): node_data: Dictionary of node properties """ workspace_label = self._get_workspace_label() - properties = node_data - entity_type = properties["entity_type"] + properties = node_data.copy() + + metadata = properties.pop("metadata", None) + for key, value in metadata.items(): + neo4j_key = key + + # Handle complex data types by converting them to strings + if isinstance(value, (dict, list)): + try: + properties[neo4j_key] = json.dumps(value) + except Exception as e: + logger.warning( + f"Failed to serialize metadata field {key} for node {node_id}: {e}" + ) + properties[neo4j_key] = str(value) + else: + properties[neo4j_key] = value + + entity_type = properties.get("entity_type", "Unknown") if "entity_id" not in properties: raise ValueError("Neo4j: node properties must contain an 'entity_id' field") @@ -992,7 +1032,7 @@ class Neo4JStorage(BaseGraphStorage): await session.execute_write(execute_upsert) except Exception as e: - logger.error(f"[{self.workspace}] Error during upsert: {str(e)}") + logger.error(f"[{self.workspace}] Error during node upsert: {str(e)}") raise @retry( @@ -1026,12 +1066,23 @@ class Neo4JStorage(BaseGraphStorage): Raises: ValueError: If either source or target node does not exist or is not unique """ + edge_properties = edge_data + workspace_label = self._get_workspace_label() + + # Extract metadata if present and handle it properly + metadata = edge_properties.pop("metadata", None) + if metadata and isinstance(metadata, dict): + # Serialize metadata to JSON string + try: + edge_properties["metadata"] = json.dumps(metadata) + except Exception as e: + logger.warning(f"Failed to serialize metadata: {e}") + edge_properties["metadata"] = str(metadata) + try: - edge_properties = edge_data async with self._driver.session(database=self._DATABASE) as session: async def execute_upsert(tx: AsyncManagedTransaction): - workspace_label = self._get_workspace_label() query = f""" MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}}) WITH source diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index afc0bc5f..d6945493 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -928,7 +928,8 @@ class LightRAG: await self.apipeline_enqueue_documents(input, ids, file_paths, track_id) await self.apipeline_process_enqueue_documents( - split_by_character, split_by_character_only + split_by_character, + split_by_character_only, ) return track_id @@ -1011,6 +1012,7 @@ class LightRAG: ids: list[str] | None = None, file_paths: str | list[str] | None = None, track_id: str | None = None, + metadata: dict | None = None, ) -> str: """ Pipeline for Processing Documents @@ -1025,6 +1027,7 @@ class LightRAG: ids: list of unique document IDs, if not provided, MD5 hash IDs will be generated file_paths: list of file paths corresponding to each document, used for citation track_id: tracking ID for monitoring processing status, if not provided, will be generated with "enqueue" prefix + metadata: Optional metadata to associate with the documents Returns: str: tracking ID for monitoring processing status @@ -1038,6 +1041,8 @@ class LightRAG: ids = [ids] if isinstance(file_paths, str): file_paths = [file_paths] + if isinstance(metadata, dict): + metadata = [metadata] # If file_paths is provided, ensure it matches the number of documents if file_paths is not None: @@ -1102,6 +1107,7 @@ class LightRAG: "file_path" ], # Store file path in document status "track_id": track_id, # Store track_id in document status + "metadata": metadata, # added provided custom metadata } for id_, content_data in contents.items() } @@ -1354,6 +1360,7 @@ class LightRAG: self, split_by_character: str | None = None, split_by_character_only: bool = False, + metadata: dict | None = None, ) -> None: """ Process pending documents by splitting them into chunks, processing @@ -1475,6 +1482,7 @@ class LightRAG: pipeline_status: dict, pipeline_status_lock: asyncio.Lock, semaphore: asyncio.Semaphore, + metadata: dict | None = None, ) -> None: """Process single document""" file_extraction_stage_ok = False @@ -1530,6 +1538,7 @@ class LightRAG: "full_doc_id": doc_id, "file_path": file_path, # Add file path to each chunk "llm_cache_list": [], # Initialize empty LLM cache list for each chunk + "metadata": metadata, } for dp in self.chunking_func( self.tokenizer, @@ -1549,6 +1558,8 @@ class LightRAG: # Process document in two stages # Stage 1: Process text chunks and docs (parallel execution) + metadata["processing_start_time"] = processing_start_time + doc_status_task = asyncio.create_task( self.doc_status.upsert( { @@ -1566,9 +1577,7 @@ class LightRAG: ).isoformat(), "file_path": file_path, "track_id": status_doc.track_id, # Preserve existing track_id - "metadata": { - "processing_start_time": processing_start_time - }, + "metadata": metadata, } } ) @@ -1593,8 +1602,11 @@ class LightRAG: # Stage 2: Process entity relation graph (after text_chunks are saved) entity_relation_task = asyncio.create_task( - self._process_extract_entities( - chunks, pipeline_status, pipeline_status_lock + self._process_entity_relation_graph( + chunks, + metadata, + pipeline_status, + pipeline_status_lock, ) ) await entity_relation_task @@ -1628,6 +1640,8 @@ class LightRAG: processing_end_time = int(time.time()) # Update document status to failed + metadata["processing_start_time"] = processing_start_time + metadata["processing_end_time"] = processing_end_time await self.doc_status.upsert( { doc_id: { @@ -1641,10 +1655,7 @@ class LightRAG: ).isoformat(), "file_path": file_path, "track_id": status_doc.track_id, # Preserve existing track_id - "metadata": { - "processing_start_time": processing_start_time, - "processing_end_time": processing_end_time, - }, + "metadata": metadata, } } ) @@ -1669,10 +1680,15 @@ class LightRAG: current_file_number=current_file_number, total_files=total_files, file_path=file_path, + metadata=metadata, # NEW: Pass metadata to merge function ) # Record processing end time processing_end_time = int(time.time()) + metadata["processing_start_time"] = ( + processing_start_time + ) + metadata["processing_end_time"] = processing_end_time await self.doc_status.upsert( { @@ -1688,10 +1704,7 @@ class LightRAG: ).isoformat(), "file_path": file_path, "track_id": status_doc.track_id, # Preserve existing track_id - "metadata": { - "processing_start_time": processing_start_time, - "processing_end_time": processing_end_time, - }, + "metadata": metadata, } } ) @@ -1729,6 +1742,11 @@ class LightRAG: processing_end_time = int(time.time()) # Update document status to failed + metadata["processing_start_time"] = ( + processing_start_time + ) + metadata["processing_end_time"] = processing_end_time + await self.doc_status.upsert( { doc_id: { @@ -1740,10 +1758,7 @@ class LightRAG: "updated_at": datetime.now().isoformat(), "file_path": file_path, "track_id": status_doc.track_id, # Preserve existing track_id - "metadata": { - "processing_start_time": processing_start_time, - "processing_end_time": processing_end_time, - }, + "metadata": metadata, } } ) @@ -1760,6 +1775,7 @@ class LightRAG: pipeline_status, pipeline_status_lock, semaphore, + metadata, ) ) @@ -1803,13 +1819,18 @@ class LightRAG: pipeline_status["latest_message"] = log_message pipeline_status["history_messages"].append(log_message) - async def _process_extract_entities( - self, chunk: dict[str, Any], pipeline_status=None, pipeline_status_lock=None + async def _process_entity_relation_graph( + self, + chunk: dict[str, Any], + metadata: dict | None, + pipeline_status=None, + pipeline_status_lock=None, ) -> list: try: chunk_results = await extract_entities( chunk, global_config=asdict(self), + metadata=metadata, # Pass metadata here pipeline_status=pipeline_status, pipeline_status_lock=pipeline_status_lock, llm_response_cache=self.llm_response_cache, diff --git a/lightrag/operate.py b/lightrag/operate.py index 685e86a8..04e2da85 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -323,6 +323,7 @@ async def _handle_single_entity_extraction( chunk_key: str, timestamp: int, file_path: str = "unknown_source", + metadata: dict[str, Any] | None = None, ): if len(record_attributes) != 4 or "entity" not in record_attributes[0]: if len(record_attributes) > 1 and "entity" in record_attributes[0]: @@ -376,6 +377,7 @@ async def _handle_single_entity_extraction( source_id=chunk_key, file_path=file_path, timestamp=timestamp, + metadata=metadata, ) except ValueError as e: @@ -395,6 +397,7 @@ async def _handle_single_relationship_extraction( chunk_key: str, timestamp: int, file_path: str = "unknown_source", + metadata: dict[str, Any] | None = None, ): if ( len(record_attributes) != 5 or "relation" not in record_attributes[0] @@ -458,6 +461,8 @@ async def _handle_single_relationship_extraction( source_id=edge_source_id, file_path=file_path, timestamp=timestamp, + metadata=metadata, + ) except ValueError as e: @@ -862,6 +867,7 @@ async def _process_extraction_result( file_path: str = "unknown_source", tuple_delimiter: str = "<|#|>", completion_delimiter: str = "<|COMPLETE|>", + metadata: dict[str, Any] | None = None, ) -> tuple[dict, dict]: """Process a single extraction result (either initial or gleaning) Args: @@ -943,7 +949,7 @@ async def _process_extraction_result( # Try to parse as entity entity_data = await _handle_single_entity_extraction( - record_attributes, chunk_key, timestamp, file_path + record_attributes, chunk_key, timestamp, file_path, metadata ) if entity_data is not None: maybe_nodes[entity_data["entity_name"]].append(entity_data) @@ -951,7 +957,7 @@ async def _process_extraction_result( # Try to parse as relationship relationship_data = await _handle_single_relationship_extraction( - record_attributes, chunk_key, timestamp, file_path + record_attributes, chunk_key, timestamp, file_path, metadata ) if relationship_data is not None: maybe_edges[ @@ -1295,6 +1301,7 @@ async def _merge_nodes_then_upsert( pipeline_status: dict = None, pipeline_status_lock=None, llm_response_cache: BaseKVStorage | None = None, + metadata: dict[str, Any] | None = None, ): """Get existing nodes from knowledge graph use name,if exists, merge data, else create, then upsert.""" already_entity_types = [] @@ -1382,6 +1389,7 @@ async def _merge_nodes_then_upsert( description=description, source_id=source_id, file_path=file_path, + metadata=metadata, # Add metadata here created_at=int(time.time()), ) await knowledge_graph_inst.upsert_node( @@ -1402,6 +1410,7 @@ async def _merge_edges_then_upsert( pipeline_status_lock=None, llm_response_cache: BaseKVStorage | None = None, added_entities: list = None, # New parameter to track entities added during edge processing + metadata: dict | None = None, ): if src_id == tgt_id: return None @@ -1534,6 +1543,7 @@ async def _merge_edges_then_upsert( "description": description, "entity_type": "UNKNOWN", "file_path": file_path, + "metadata": metadata, # Add metadata here "created_at": int(time.time()), } await knowledge_graph_inst.upsert_node(need_insert_id, node_data=node_data) @@ -1546,6 +1556,7 @@ async def _merge_edges_then_upsert( "description": description, "source_id": source_id, "file_path": file_path, + "metadata": metadata, # Add metadata here "created_at": int(time.time()), } added_entities.append(entity_data) @@ -1559,6 +1570,7 @@ async def _merge_edges_then_upsert( keywords=keywords, source_id=source_id, file_path=file_path, + metadata=metadata, # Add metadata here created_at=int(time.time()), ), ) @@ -1570,6 +1582,7 @@ async def _merge_edges_then_upsert( keywords=keywords, source_id=source_id, file_path=file_path, + metadata=metadata, # Add metadata here created_at=int(time.time()), ) @@ -1591,6 +1604,7 @@ async def merge_nodes_and_edges( current_file_number: int = 0, total_files: int = 0, file_path: str = "unknown_source", + metadata: dict | None = None, # Added metadata parameter ) -> None: """Two-phase merge: process all entities first, then all relationships @@ -1614,6 +1628,7 @@ async def merge_nodes_and_edges( current_file_number: Current file number for logging total_files: Total files for logging file_path: File path for logging + metadata: Document metadata to be attached to entities and relationships """ # Collect all nodes and edges from all chunks @@ -1667,6 +1682,7 @@ async def merge_nodes_and_edges( pipeline_status, pipeline_status_lock, llm_response_cache, + metadata, ) # Vector database operation (equally critical, must succeed) @@ -1682,6 +1698,7 @@ async def merge_nodes_and_edges( "file_path": entity_data.get( "file_path", "unknown_source" ), + "metadata": metadata, } } @@ -1797,7 +1814,8 @@ async def merge_nodes_and_edges( pipeline_status, pipeline_status_lock, llm_response_cache, - added_entities, # Pass list to collect added entities + added_entities, + metadata, ) if edge_data is None: @@ -1818,6 +1836,7 @@ async def merge_nodes_and_edges( "file_path", "unknown_source" ), "weight": edge_data.get("weight", 1.0), + "metadata": metadata, } } @@ -1967,13 +1986,14 @@ async def merge_nodes_and_edges( pipeline_status["latest_message"] = log_message pipeline_status["history_messages"].append(log_message) - # Update storage + # Update storage with metadata if final_entity_names: await full_entities_storage.upsert( { doc_id: { "entity_names": list(final_entity_names), "count": len(final_entity_names), + "metadata": metadata, # Add metadata here } } ) @@ -1986,6 +2006,7 @@ async def merge_nodes_and_edges( list(pair) for pair in final_relation_pairs ], "count": len(final_relation_pairs), + "metadata": metadata, # Add metadata here } } ) @@ -2010,6 +2031,7 @@ async def merge_nodes_and_edges( async def extract_entities( chunks: dict[str, TextChunkSchema], global_config: dict[str, str], + metadata: dict[str, Any] | None = None, pipeline_status: dict = None, pipeline_status_lock=None, llm_response_cache: BaseKVStorage | None = None, @@ -2047,6 +2069,55 @@ async def extract_entities( processed_chunks = 0 total_chunks = len(ordered_chunks) + async def _process_extraction_result( + result: str, + chunk_key: str, + file_path: str = "unknown_source", + metadata: dict[str, Any] | None = None, + ): + """Process a single extraction result (either initial or gleaning) + Args: + result (str): The extraction result to process + chunk_key (str): The chunk key for source tracking + file_path (str): The file path for citation + metadata (dict, optional): Additional metadata to include in extracted entities/relationships. + Returns: + tuple: (nodes_dict, edges_dict) containing the extracted entities and relationships + """ + maybe_nodes = defaultdict(list) + maybe_edges = defaultdict(list) + + records = split_string_by_multi_markers( + result, + [context_base["record_delimiter"], context_base["completion_delimiter"]], + ) + + for record in records: + record = re.search(r"\((.*)\)", record) + if record is None: + continue + record = record.group(1) + record_attributes = split_string_by_multi_markers( + record, [context_base["tuple_delimiter"]] + ) + + if_entities = await _handle_single_entity_extraction( + record_attributes, chunk_key, file_path, metadata + ) + if if_entities is not None: + maybe_nodes[if_entities["entity_name"]].append(if_entities) + continue + + if_relation = await _handle_single_relationship_extraction( + record_attributes, chunk_key, file_path, metadata + ) + if if_relation is not None: + maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append( + if_relation + ) + + return maybe_nodes, maybe_edges + async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]): """Process a single chunk Args: @@ -2090,12 +2161,13 @@ async def extract_entities( entity_extraction_user_prompt, final_result ) - # Process initial extraction with file path + # Process initial extraction with file path and metadata maybe_nodes, maybe_edges = await _process_extraction_result( final_result, chunk_key, timestamp, file_path, + metadata, tuple_delimiter=context_base["tuple_delimiter"], completion_delimiter=context_base["completion_delimiter"], ) @@ -2113,12 +2185,14 @@ async def extract_entities( cache_keys_collector=cache_keys_collector, ) - # Process gleaning result separately with file path + + # Process gleaning result separately with file path and metadata glean_nodes, glean_edges = await _process_extraction_result( glean_result, chunk_key, timestamp, file_path, + metadata=metadata, tuple_delimiter=context_base["tuple_delimiter"], completion_delimiter=context_base["completion_delimiter"], ) @@ -2264,6 +2338,7 @@ async def kg_query( hashing_kv: BaseKVStorage | None = None, system_prompt: str | None = None, chunks_vdb: BaseVectorStorage = None, + metadata_filters: list | None = None, return_raw_data: Literal[False] = False, ) -> str | AsyncIterator[str]: ... @@ -3493,10 +3568,24 @@ async def _get_node_data( # Extract all entity IDs from your results list node_ids = [r["entity_name"] for r in results] + # HARDCODED QUERY FOR DEBUGGING + filter_query = "WHERE (n.class) is not null AND n.class = 'bando' RETURN (n.entity_id) AS entity_id, n.class AS class_value" + + # TODO update method to take in the metadata_filter dataclass + node_kg_ids = [] + if hasattr(knowledge_graph_inst, "get_nodes_by_metadata_filter"): + node_kg_ids = await asyncio.gather( + knowledge_graph_inst.get_nodes_by_metadata_filter(filter_query) + ) + + filtered_node_ids = ( + [nid for nid in node_ids if nid in node_kg_ids] if node_kg_ids else node_ids + ) + # Call the batch node retrieval and degree functions concurrently. nodes_dict, degrees_dict = await asyncio.gather( - knowledge_graph_inst.get_nodes_batch(node_ids), - knowledge_graph_inst.node_degrees_batch(node_ids), + knowledge_graph_inst.get_nodes_batch(filtered_node_ids), + knowledge_graph_inst.node_degrees_batch(filtered_node_ids), ) # Now, if you need the node data and degree in order: