feat(metadata): Add custom metadata support for node properties and querying

- Implement custom metadata insertion as node properties during file upload.
- Add basic metadata filtering to the query API.

NOTE: base.py has been modified, but the base implementation is incomplete and untested. Only the Neo4j backend is fully implemented and tested.

WIP: the query API filter is temporarily mocked (hardcoded) for debugging. Full support for complex AND/OR filtering is in development.

# Conflicts:
#	lightrag/base.py
#	lightrag/lightrag.py
#	lightrag/operate.py
Author: Giulio Grassia
Date:   2025-09-02 15:46:59 +02:00
Commit: 7be24a3c60 (parent: 058ce83dba)
6 changed files with 281 additions and 42 deletions

lightrag/api/routers/document_routes.py

@@ -11,11 +11,13 @@ import pipmaster as pm
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Optional, Any, Literal
import json
from fastapi import (
APIRouter,
BackgroundTasks,
Depends,
File,
Form,
HTTPException,
UploadFile,
)
@@ -833,7 +835,7 @@ def get_unique_filename_in_enqueued(target_dir: Path, original_name: str) -> str
async def pipeline_enqueue_file(
- rag: LightRAG, file_path: Path, track_id: str = None
+ rag: LightRAG, file_path: Path, track_id: str = None, metadata: dict | None = None
) -> tuple[bool, str]:
"""Add a file to the queue for processing
@@ -841,6 +843,7 @@ async def pipeline_enqueue_file(
rag: LightRAG instance
file_path: Path to the saved file
track_id: Optional tracking ID, if not provided will be generated
metadata: Optional metadata to associate with the document
Returns:
tuple: (success: bool, track_id: str)
"""
@@ -1212,8 +1215,12 @@ async def pipeline_enqueue_file(
return False, track_id
try:
# Pass metadata to apipeline_enqueue_documents
await rag.apipeline_enqueue_documents(
- content, file_paths=file_path.name, track_id=track_id
+ content,
+ file_paths=file_path.name,
+ track_id=track_id,
+ metadata=metadata,
)
logger.info(
@@ -1695,19 +1702,21 @@ def create_document_routes(
"/upload", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
)
async def upload_to_input_dir(
- background_tasks: BackgroundTasks, file: UploadFile = File(...)
+ background_tasks: BackgroundTasks,
+ file: UploadFile = File(...),
+ metadata: Optional[str] = Form(None),
):
"""
- Upload a file to the input directory and index it.
+ Upload a file to the input directory and index it with optional metadata.
This API endpoint accepts a file through an HTTP POST request, checks if the
uploaded file is of a supported type, saves it in the specified input directory,
indexes it for retrieval, and returns a success status with relevant details.
Metadata can be provided to associate custom data with the uploaded document.
Args:
background_tasks: FastAPI BackgroundTasks for async processing
file (UploadFile): The file to be uploaded. It must have an allowed extension.
metadata (str, optional): JSON string of custom metadata to associate with the document.
Returns:
InsertResponse: A response object containing the upload status and a message.
status is "success" or "duplicated"; otherwise an HTTPException is raised.
@@ -1750,9 +1759,30 @@ def create_document_routes(
track_id = generate_track_id("upload")
- # Add to background tasks and get track_id
- background_tasks.add_task(pipeline_index_file, rag, file_path, track_id)
# Parse metadata if provided
parsed_metadata = None
if metadata:
try:
parsed_metadata = json.loads(metadata)
if not isinstance(parsed_metadata, dict):
raise ValueError(
"Metadata must be a valid JSON dictionary string."
)
except json.JSONDecodeError:
raise HTTPException(
status_code=400, detail="Metadata must be a valid JSON string."
)
except ValueError as ve:
raise HTTPException(status_code=400, detail=str(ve))
# Add to background tasks with metadata
background_tasks.add_task(
pipeline_index_file_with_metadata,
rag,
file_path,
track_id,
parsed_metadata,
)
return InsertResponse(
status="success",
message=f"File '{safe_filename}' uploaded successfully. Processing will continue in background.",
@@ -1764,6 +1794,35 @@ def create_document_routes(
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
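For reference, a minimal client-side sketch of the new form field (the server address and the /documents route prefix are assumptions based on LightRAG's default API server setup, not part of this diff):

# Hedged sketch: URL and route prefix are assumed. The key point is that
# metadata travels as a JSON *string* in a multipart form field and is
# parsed server-side with json.loads.
import json
import requests

with open("report.pdf", "rb") as f:
    resp = requests.post(
        "http://localhost:9621/documents/upload",
        files={"file": f},
        data={"metadata": json.dumps({"class": "bando", "year": "2025"})},
    )
print(resp.json())  # InsertResponse: status and message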
# New function to handle metadata during indexing
async def pipeline_index_file_with_metadata(
rag: LightRAG, file_path: Path, track_id: str, metadata: dict | None
) -> tuple[bool, str]:
"""
Index a file with metadata by leveraging the existing pipeline.
Args:
rag: LightRAG instance
file_path: Path to the file to index
track_id: Tracking ID for the document
metadata: Optional metadata dictionary to associate with the document
Returns:
tuple[bool, str]: Success status and track ID
"""
# Use the existing pipeline to enqueue the file, forwarding metadata so
# apipeline_enqueue_documents can store it in the document status
success, returned_track_id = await pipeline_enqueue_file(
    rag, file_path, track_id, metadata=metadata
)
if success:
logger.info(f"Successfully enqueued file with metadata: {metadata}")
else:
logger.error("Failed to enqueue file with metadata")
# Trigger the pipeline processing
await rag.apipeline_process_enqueue_documents(metadata=metadata)
return success, returned_track_id
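With this wiring, metadata is captured at enqueue time (stored in each document's status entry by apipeline_enqueue_documents) and then threaded through chunking, entity/relationship extraction, and the merge phase, so the same dict ends up on chunks, graph nodes and edges, and their vector-store payloads (see the lightrag.py and operate.py changes below).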
@router.post(
"/text", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
)

lightrag/api/routers/query_routes.py

@@ -22,6 +22,11 @@ class QueryRequest(BaseModel):
description="The query text",
)
metadata_filter: dict[str, str] | None = Field(
default=None,
description="Optional dictionary of metadata key-value pairs to filter nodes",
)
mode: Literal["local", "global", "hybrid", "naive", "mix", "bypass"] = Field(
default="mix",
description="Query mode",
@@ -168,6 +173,11 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
"""
try:
param = request.to_query_params(False)
# Inject metadata_filter into the query params if provided (QueryParam
# now declares this field, so it can be assigned directly)
if request.metadata_filter:
    param.metadata_filter = request.metadata_filter
response = await rag.aquery(request.query, param=param)
# Get reference list if requested
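A matching request-body sketch for the new field (server address assumed as above; note that filtering is currently backed by the hardcoded debugging query shown in operate.py below, so the dict's semantics are still provisional):

# Hedged sketch of the new metadata_filter field on the /query endpoint.
import requests

resp = requests.post(
    "http://localhost:9621/query",
    json={
        "query": "Which calls for proposals are open?",
        "mode": "mix",
        # key-value pairs to match against node properties
        "metadata_filter": {"class": "bando"},
    },
)
print(resp.json())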

lightrag/base.py

@@ -168,6 +168,9 @@ class QueryParam:
containing citation information for the retrieved content.
"""
metadata_filter: dict | None = None
"""Metadata for filtering nodes and edges, allowing for more precise querying."""
@dataclass
class StorageNameSpace(ABC):
@@ -444,6 +447,12 @@ class BaseGraphStorage(StorageNameSpace, ABC):
or None if the node doesn't exist
"""
async def get_nodes_by_metadata_filter(self, metadata_filter: str) -> list[str]:
    """Get IDs of nodes matching a metadata filter fragment.

    NOTE: untested default implementation (see commit notes); it assumes the
    backend exposes a generic `query` helper, and only the Neo4j backend
    currently ships a tested override. The fragment follows the MATCH clause
    (e.g. a WHERE ... RETURN ... fragment), as in the Neo4j implementation.
    """
    text_query = f"MATCH (n:`{self.workspace_label}`) {metadata_filter}"
    return await self.query(text_query)
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
"""Get nodes as a batch using UNWIND

lightrag/kg/neo4j_impl.py

@@ -1,5 +1,6 @@
import os
import re
import json
from dataclasses import dataclass
from typing import final
import configparser
@@ -424,6 +425,24 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() # Ensure results are consumed even on error
raise
async def get_nodes_by_metadata_filter(self, query: str) -> list[str]:
    """Get nodes matching a metadata filtering query; returns their IDs."""
    workspace_label = self._get_workspace_label()
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        try:
            cypher = f"MATCH (n:`{workspace_label}`) {query}"
            result = await session.run(cypher)
            node_ids = [record["entity_id"] async for record in result]
            await result.consume()
            return node_ids
        except Exception as e:
            logger.error(
                f"[{self.workspace}] Error getting nodes by metadata filter: {str(e)}"
            )
            raise
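Note that the argument is a raw Cypher fragment appended after the MATCH clause, not the metadata_filter dict from QueryParam. A hypothetical call mirroring the hardcoded debugging query in operate.py below:

# `storage` is an assumed Neo4JStorage instance; run inside an async
# context. The fragment must RETURN the entity_id of matching nodes.
filter_query = (
    "WHERE n.class IS NOT NULL AND n.class = 'bando' "
    "RETURN n.entity_id AS entity_id"
)
node_ids = await storage.get_nodes_by_metadata_filter(filter_query)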
async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node by its label identifier, return only node properties
@@ -962,7 +981,11 @@ class Neo4JStorage(BaseGraphStorage):
)
),
)
- async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
+ async def upsert_node(
+     self,
+     node_id: str,
+     node_data: dict[str, str],
+ ) -> None:
"""
Upsert a node in the Neo4j database.
@@ -971,8 +994,25 @@
node_data: Dictionary of node properties
"""
workspace_label = self._get_workspace_label()
- properties = node_data
- entity_type = properties["entity_type"]
+ properties = node_data.copy()
+ metadata = properties.pop("metadata", None)
+ if metadata and isinstance(metadata, dict):
+     # Flatten metadata keys into top-level node properties; Neo4j
+     # property values must be primitives, so complex values are
+     # serialized to JSON strings (the guard also avoids iterating
+     # over a missing/None metadata field)
+     for key, value in metadata.items():
+         if isinstance(value, (dict, list)):
+             try:
+                 properties[key] = json.dumps(value)
+             except Exception as e:
+                 logger.warning(
+                     f"Failed to serialize metadata field {key} for node {node_id}: {e}"
+                 )
+                 properties[key] = str(value)
+         else:
+             properties[key] = value
+ entity_type = properties.get("entity_type", "Unknown")
if "entity_id" not in properties:
raise ValueError("Neo4j: node properties must contain an 'entity_id' field")
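To illustrate the flattening behavior with invented values:

# Invented example: metadata keys become top-level node properties, and
# non-primitive values are JSON-serialized so Neo4j accepts them.
node_data = {
    "entity_id": "ent-123",
    "entity_type": "organization",
    "metadata": {"class": "bando", "tags": ["pnrr", "2025"]},
}
# Properties written to Neo4j after flattening:
# {"entity_id": "ent-123", "entity_type": "organization",
#  "class": "bando", "tags": '["pnrr", "2025"]'}

Edges (below) take the opposite approach: the whole dict is serialized into a single metadata string property rather than flattened, which keeps edge properties compact but means edge metadata cannot be matched key-by-key in Cypher without parsing.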
@@ -992,7 +1032,7 @@
await session.execute_write(execute_upsert)
except Exception as e:
logger.error(f"[{self.workspace}] Error during upsert: {str(e)}")
logger.error(f"[{self.workspace}] Error during node upsert: {str(e)}")
raise
@retry(
@@ -1026,12 +1066,23 @@
Raises:
ValueError: If either source or target node does not exist or is not unique
"""
edge_properties = edge_data.copy()  # copy so popping "metadata" doesn't mutate the caller's dict
workspace_label = self._get_workspace_label()
# Extract metadata if present and handle it properly
metadata = edge_properties.pop("metadata", None)
if metadata and isinstance(metadata, dict):
# Serialize metadata to JSON string
try:
edge_properties["metadata"] = json.dumps(metadata)
except Exception as e:
logger.warning(f"Failed to serialize metadata: {e}")
edge_properties["metadata"] = str(metadata)
try:
- edge_properties = edge_data
async with self._driver.session(database=self._DATABASE) as session:
async def execute_upsert(tx: AsyncManagedTransaction):
workspace_label = self._get_workspace_label()
query = f"""
MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})
WITH source

lightrag/lightrag.py

@@ -928,7 +928,8 @@ class LightRAG:
await self.apipeline_enqueue_documents(input, ids, file_paths, track_id)
await self.apipeline_process_enqueue_documents(
- split_by_character, split_by_character_only
+ split_by_character,
+ split_by_character_only,
)
return track_id
@@ -1011,6 +1012,7 @@ class LightRAG:
ids: list[str] | None = None,
file_paths: str | list[str] | None = None,
track_id: str | None = None,
metadata: dict | None = None,
) -> str:
"""
Pipeline for Processing Documents
@@ -1025,6 +1027,7 @@ class LightRAG:
ids: list of unique document IDs, if not provided, MD5 hash IDs will be generated
file_paths: list of file paths corresponding to each document, used for citation
track_id: tracking ID for monitoring processing status, if not provided, will be generated with "enqueue" prefix
metadata: Optional metadata to associate with the documents
Returns:
str: tracking ID for monitoring processing status
@@ -1038,6 +1041,8 @@ class LightRAG:
ids = [ids]
if isinstance(file_paths, str):
file_paths = [file_paths]
# Note: a single metadata dict applies to every document in this batch and
# is stored as-is (wrapping it in a list here would put a list into each
# document's status entry below)
# If file_paths is provided, ensure it matches the number of documents
if file_paths is not None:
@@ -1102,6 +1107,7 @@ class LightRAG:
"file_path"
], # Store file path in document status
"track_id": track_id, # Store track_id in document status
"metadata": metadata, # added provided custom metadata
}
for id_, content_data in contents.items()
}
@@ -1354,6 +1360,7 @@ class LightRAG:
self,
split_by_character: str | None = None,
split_by_character_only: bool = False,
metadata: dict | None = None,
) -> None:
"""
Process pending documents by splitting them into chunks, processing
@@ -1475,6 +1482,7 @@ class LightRAG:
pipeline_status: dict,
pipeline_status_lock: asyncio.Lock,
semaphore: asyncio.Semaphore,
metadata: dict | None = None,
) -> None:
"""Process single document"""
file_extraction_stage_ok = False
# metadata defaults to None; normalize it so the timing fields attached
# below can be set safely
if metadata is None:
    metadata = {}
@@ -1530,6 +1538,7 @@
"full_doc_id": doc_id,
"file_path": file_path, # Add file path to each chunk
"llm_cache_list": [], # Initialize empty LLM cache list for each chunk
"metadata": metadata,
}
for dp in self.chunking_func(
self.tokenizer,
@@ -1549,6 +1558,8 @@
# Process document in two stages
# Stage 1: Process text chunks and docs (parallel execution)
metadata["processing_start_time"] = processing_start_time
doc_status_task = asyncio.create_task(
self.doc_status.upsert(
{
@@ -1566,9 +1577,7 @@
).isoformat(),
"file_path": file_path,
"track_id": status_doc.track_id, # Preserve existing track_id
"metadata": {
"processing_start_time": processing_start_time
},
"metadata": metadata,
}
}
)
@@ -1593,8 +1602,11 @@
# Stage 2: Process entity relation graph (after text_chunks are saved)
entity_relation_task = asyncio.create_task(
- self._process_extract_entities(
-     chunks, pipeline_status, pipeline_status_lock
+ self._process_entity_relation_graph(
+     chunks,
+     metadata,
+     pipeline_status,
+     pipeline_status_lock,
)
)
await entity_relation_task
@@ -1628,6 +1640,8 @@
processing_end_time = int(time.time())
# Update document status to failed
metadata["processing_start_time"] = processing_start_time
metadata["processing_end_time"] = processing_end_time
await self.doc_status.upsert(
{
doc_id: {
@@ -1641,10 +1655,7 @@
).isoformat(),
"file_path": file_path,
"track_id": status_doc.track_id, # Preserve existing track_id
"metadata": {
"processing_start_time": processing_start_time,
"processing_end_time": processing_end_time,
},
"metadata": metadata,
}
}
)
@@ -1669,10 +1680,15 @@
current_file_number=current_file_number,
total_files=total_files,
file_path=file_path,
metadata=metadata, # NEW: Pass metadata to merge function
)
# Record processing end time
processing_end_time = int(time.time())
metadata["processing_start_time"] = (
processing_start_time
)
metadata["processing_end_time"] = processing_end_time
await self.doc_status.upsert(
{
@@ -1688,10 +1704,7 @@
).isoformat(),
"file_path": file_path,
"track_id": status_doc.track_id, # Preserve existing track_id
"metadata": {
"processing_start_time": processing_start_time,
"processing_end_time": processing_end_time,
},
"metadata": metadata,
}
}
)
@@ -1729,6 +1742,11 @@
processing_end_time = int(time.time())
# Update document status to failed
metadata["processing_start_time"] = (
processing_start_time
)
metadata["processing_end_time"] = processing_end_time
await self.doc_status.upsert(
{
doc_id: {
@@ -1740,10 +1758,7 @@
"updated_at": datetime.now().isoformat(),
"file_path": file_path,
"track_id": status_doc.track_id, # Preserve existing track_id
"metadata": {
"processing_start_time": processing_start_time,
"processing_end_time": processing_end_time,
},
"metadata": metadata,
}
}
)
@@ -1760,6 +1775,7 @@
pipeline_status,
pipeline_status_lock,
semaphore,
metadata,
)
)
@@ -1803,13 +1819,18 @@
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
- async def _process_extract_entities(
-     self, chunk: dict[str, Any], pipeline_status=None, pipeline_status_lock=None
+ async def _process_entity_relation_graph(
+     self,
+     chunk: dict[str, Any],
+     metadata: dict | None,
+     pipeline_status=None,
+     pipeline_status_lock=None,
) -> list:
try:
chunk_results = await extract_entities(
chunk,
global_config=asdict(self),
metadata=metadata, # Pass metadata here
pipeline_status=pipeline_status,
pipeline_status_lock=pipeline_status_lock,
llm_response_cache=self.llm_response_cache,

lightrag/operate.py

@@ -323,6 +323,7 @@ async def _handle_single_entity_extraction(
chunk_key: str,
timestamp: int,
file_path: str = "unknown_source",
metadata: dict[str, Any] | None = None,
):
if len(record_attributes) != 4 or "entity" not in record_attributes[0]:
if len(record_attributes) > 1 and "entity" in record_attributes[0]:
@@ -376,6 +377,7 @@
source_id=chunk_key,
file_path=file_path,
timestamp=timestamp,
metadata=metadata,
)
except ValueError as e:
@@ -395,6 +397,7 @@ async def _handle_single_relationship_extraction(
chunk_key: str,
timestamp: int,
file_path: str = "unknown_source",
metadata: dict[str, Any] | None = None,
):
if (
len(record_attributes) != 5 or "relation" not in record_attributes[0]
@@ -458,6 +461,8 @@
source_id=edge_source_id,
file_path=file_path,
timestamp=timestamp,
metadata=metadata,
)
except ValueError as e:
@@ -862,6 +867,7 @@ async def _process_extraction_result(
file_path: str = "unknown_source",
tuple_delimiter: str = "<|#|>",
completion_delimiter: str = "<|COMPLETE|>",
metadata: dict[str, Any] | None = None,
) -> tuple[dict, dict]:
"""Process a single extraction result (either initial or gleaning)
Args:
@@ -943,7 +949,7 @@
# Try to parse as entity
entity_data = await _handle_single_entity_extraction(
- record_attributes, chunk_key, timestamp, file_path
+ record_attributes, chunk_key, timestamp, file_path, metadata
)
if entity_data is not None:
maybe_nodes[entity_data["entity_name"]].append(entity_data)
@@ -951,7 +957,7 @@
# Try to parse as relationship
relationship_data = await _handle_single_relationship_extraction(
- record_attributes, chunk_key, timestamp, file_path
+ record_attributes, chunk_key, timestamp, file_path, metadata
)
if relationship_data is not None:
maybe_edges[
@@ -1295,6 +1301,7 @@ async def _merge_nodes_then_upsert(
pipeline_status: dict = None,
pipeline_status_lock=None,
llm_response_cache: BaseKVStorage | None = None,
metadata: dict[str, Any] | None = None,
):
"""Get existing nodes from knowledge graph use name,if exists, merge data, else create, then upsert."""
already_entity_types = []
@@ -1382,6 +1389,7 @@
description=description,
source_id=source_id,
file_path=file_path,
metadata=metadata, # Add metadata here
created_at=int(time.time()),
)
await knowledge_graph_inst.upsert_node(
@@ -1402,6 +1410,7 @@ async def _merge_edges_then_upsert(
pipeline_status_lock=None,
llm_response_cache: BaseKVStorage | None = None,
added_entities: list = None, # New parameter to track entities added during edge processing
metadata: dict | None = None,
):
if src_id == tgt_id:
return None
@@ -1534,6 +1543,7 @@ async def _merge_edges_then_upsert(
"description": description,
"entity_type": "UNKNOWN",
"file_path": file_path,
"metadata": metadata, # Add metadata here
"created_at": int(time.time()),
}
await knowledge_graph_inst.upsert_node(need_insert_id, node_data=node_data)
@@ -1546,6 +1556,7 @@
"description": description,
"source_id": source_id,
"file_path": file_path,
"metadata": metadata, # Add metadata here
"created_at": int(time.time()),
}
added_entities.append(entity_data)
@@ -1559,6 +1570,7 @@
keywords=keywords,
source_id=source_id,
file_path=file_path,
metadata=metadata, # Add metadata here
created_at=int(time.time()),
),
)
@@ -1570,6 +1582,7 @@
keywords=keywords,
source_id=source_id,
file_path=file_path,
metadata=metadata, # Add metadata here
created_at=int(time.time()),
)
@@ -1591,6 +1604,7 @@ async def merge_nodes_and_edges(
current_file_number: int = 0,
total_files: int = 0,
file_path: str = "unknown_source",
metadata: dict | None = None, # Added metadata parameter
) -> None:
"""Two-phase merge: process all entities first, then all relationships
@@ -1614,6 +1628,7 @@ async def merge_nodes_and_edges(
current_file_number: Current file number for logging
total_files: Total files for logging
file_path: File path for logging
metadata: Document metadata to be attached to entities and relationships
"""
# Collect all nodes and edges from all chunks
@@ -1667,6 +1682,7 @@
pipeline_status,
pipeline_status_lock,
llm_response_cache,
metadata,
)
# Vector database operation (equally critical, must succeed)
@@ -1682,6 +1698,7 @@
"file_path": entity_data.get(
"file_path", "unknown_source"
),
"metadata": metadata,
}
}
@@ -1797,7 +1814,8 @@ async def merge_nodes_and_edges(
pipeline_status,
pipeline_status_lock,
llm_response_cache,
- added_entities,  # Pass list to collect added entities
+ added_entities,
+ metadata,
)
if edge_data is None:
@@ -1818,6 +1836,7 @@
"file_path", "unknown_source"
),
"weight": edge_data.get("weight", 1.0),
"metadata": metadata,
}
}
@@ -1967,13 +1986,14 @@ async def merge_nodes_and_edges(
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
- # Update storage
+ # Update storage with metadata
if final_entity_names:
await full_entities_storage.upsert(
{
doc_id: {
"entity_names": list(final_entity_names),
"count": len(final_entity_names),
"metadata": metadata, # Add metadata here
}
}
)
@@ -1986,6 +2006,7 @@
list(pair) for pair in final_relation_pairs
],
"count": len(final_relation_pairs),
"metadata": metadata, # Add metadata here
}
}
)
@@ -2010,6 +2031,7 @@
async def extract_entities(
chunks: dict[str, TextChunkSchema],
global_config: dict[str, str],
metadata: dict[str, Any] | None = None,
pipeline_status: dict = None,
pipeline_status_lock=None,
llm_response_cache: BaseKVStorage | None = None,
@@ -2047,6 +2069,55 @@
processed_chunks = 0
total_chunks = len(ordered_chunks)
    async def _process_extraction_result(
        result: str,
        chunk_key: str,
        timestamp: int,
        file_path: str = "unknown_source",
        metadata: dict[str, Any] | None = None,
        tuple_delimiter: str = "<|#|>",
        completion_delimiter: str = "<|COMPLETE|>",
    ):
        """Process a single extraction result (either initial or gleaning).

        Local variant of the module-level helper; it takes the timestamp and
        the delimiter overrides that the calls below pass in (without them,
        the handler calls would misalign their positional arguments).
        Args:
            result (str): The extraction result to process
            chunk_key (str): The chunk key for source tracking
            timestamp (int): Extraction timestamp forwarded to the handlers
            file_path (str): The file path for citation
            metadata (dict, optional): Additional metadata to include in extracted entities/relationships.
        Returns:
            tuple: (nodes_dict, edges_dict) containing the extracted entities and relationships
        """
        maybe_nodes = defaultdict(list)
        maybe_edges = defaultdict(list)
        records = split_string_by_multi_markers(
            result,
            [context_base["record_delimiter"], completion_delimiter],
        )
        for record in records:
            record = re.search(r"\((.*)\)", record)
            if record is None:
                continue
            record = record.group(1)
            record_attributes = split_string_by_multi_markers(
                record, [tuple_delimiter]
            )
            # Try to parse as an entity first, then as a relationship
            if_entities = await _handle_single_entity_extraction(
                record_attributes, chunk_key, timestamp, file_path, metadata
            )
            if if_entities is not None:
                maybe_nodes[if_entities["entity_name"]].append(if_entities)
                continue
            if_relation = await _handle_single_relationship_extraction(
                record_attributes, chunk_key, timestamp, file_path, metadata
            )
            if if_relation is not None:
                maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
                    if_relation
                )
        return maybe_nodes, maybe_edges
async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
"""Process a single chunk
Args:
@@ -2090,12 +2161,13 @@ async def extract_entities(
entity_extraction_user_prompt, final_result
)
- # Process initial extraction with file path
+ # Process initial extraction with file path and metadata
maybe_nodes, maybe_edges = await _process_extraction_result(
final_result,
chunk_key,
timestamp,
file_path,
metadata,
tuple_delimiter=context_base["tuple_delimiter"],
completion_delimiter=context_base["completion_delimiter"],
)
@@ -2113,12 +2185,14 @@
cache_keys_collector=cache_keys_collector,
)
- # Process gleaning result separately with file path
+ # Process gleaning result separately with file path and metadata
glean_nodes, glean_edges = await _process_extraction_result(
glean_result,
chunk_key,
timestamp,
file_path,
metadata=metadata,
tuple_delimiter=context_base["tuple_delimiter"],
completion_delimiter=context_base["completion_delimiter"],
)
@@ -2264,6 +2338,7 @@ async def kg_query(
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
chunks_vdb: BaseVectorStorage = None,
metadata_filters: list | None = None,
return_raw_data: Literal[False] = False,
) -> str | AsyncIterator[str]: ...
@@ -3493,10 +3568,24 @@ async def _get_node_data(
# Extract all entity IDs from your results list
node_ids = [r["entity_name"] for r in results]
# HARDCODED QUERY FOR DEBUGGING (the query API is temporarily mocked; see
# commit notes)
# TODO: build this fragment from the QueryParam.metadata_filter dict
filter_query = (
    "WHERE n.class IS NOT NULL AND n.class = 'bando' "
    "RETURN n.entity_id AS entity_id, n.class AS class_value"
)
node_kg_ids = []
if hasattr(knowledge_graph_inst, "get_nodes_by_metadata_filter"):
    # Await the coroutine directly: wrapping the single call in
    # asyncio.gather would return a nested list, so no entity_id would
    # ever match during the filtering below
    node_kg_ids = await knowledge_graph_inst.get_nodes_by_metadata_filter(
        filter_query
    )
filtered_node_ids = (
[nid for nid in node_ids if nid in node_kg_ids] if node_kg_ids else node_ids
)
# Call the batch node retrieval and degree functions concurrently.
nodes_dict, degrees_dict = await asyncio.gather(
- knowledge_graph_inst.get_nodes_batch(node_ids),
- knowledge_graph_inst.node_degrees_batch(node_ids),
+ knowledge_graph_inst.get_nodes_batch(filtered_node_ids),
+ knowledge_graph_inst.node_degrees_batch(filtered_node_ids),
)
# Now, if you need the node data and degree in order:
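The TODO above implies deriving the Cypher fragment from the QueryParam.metadata_filter dict. A hypothetical builder (not part of this commit) might look like this; interpolating keys and values directly is only acceptable for trusted input, and a real implementation should bind values as query parameters:

# Hypothetical helper sketching the TODO: derive the filter fragment from
# a metadata_filter dict. Keys and values are interpolated directly, so
# callers must keep them trusted.
def build_metadata_filter_query(metadata_filter: dict[str, str]) -> str:
    conditions = " AND ".join(
        f"n.{key} IS NOT NULL AND n.{key} = '{value}'"
        for key, value in metadata_filter.items()
    )
    return f"WHERE {conditions} RETURN n.entity_id AS entity_id"

# build_metadata_filter_query({"class": "bando"}) ==
# "WHERE n.class IS NOT NULL AND n.class = 'bando' RETURN n.entity_id AS entity_id"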