feat(metadata): Add custom metadata support for node properties and querying

- Implement custom metadata insertion as node properties during file upload.
- Add basic metadata filtering functionality to the query API.

--NOTE: While base.py has been modified, the base implementation is incomplete and untested. Only the Neo4j backend has been properly implemented and tested.

WIP: Query API is temporarily mocked for debugging. Full implementation with complex AND/OR filtering capabilities is in development.

# Conflicts:
#	lightrag/base.py
#	lightrag/lightrag.py
#	lightrag/operate.py
This commit is contained in:
Giulio Grassia 2025-09-02 15:46:59 +02:00
parent 058ce83dba
commit 7be24a3c60
6 changed files with 281 additions and 42 deletions

View file

@ -11,11 +11,13 @@ import pipmaster as pm
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Any, Literal from typing import Dict, List, Optional, Any, Literal
import json
from fastapi import ( from fastapi import (
APIRouter, APIRouter,
BackgroundTasks, BackgroundTasks,
Depends, Depends,
File, File,
Form,
HTTPException, HTTPException,
UploadFile, UploadFile,
) )
@ -833,7 +835,7 @@ def get_unique_filename_in_enqueued(target_dir: Path, original_name: str) -> str
async def pipeline_enqueue_file( async def pipeline_enqueue_file(
rag: LightRAG, file_path: Path, track_id: str = None rag: LightRAG, file_path: Path, track_id: str = None, metadata: dict | None = None
) -> tuple[bool, str]: ) -> tuple[bool, str]:
"""Add a file to the queue for processing """Add a file to the queue for processing
@ -841,6 +843,7 @@ async def pipeline_enqueue_file(
rag: LightRAG instance rag: LightRAG instance
file_path: Path to the saved file file_path: Path to the saved file
track_id: Optional tracking ID, if not provided will be generated track_id: Optional tracking ID, if not provided will be generated
metadata: Optional metadata to associate with the document
Returns: Returns:
tuple: (success: bool, track_id: str) tuple: (success: bool, track_id: str)
""" """
@ -1212,8 +1215,12 @@ async def pipeline_enqueue_file(
return False, track_id return False, track_id
try: try:
# Pass metadata to apipeline_enqueue_documents
await rag.apipeline_enqueue_documents( await rag.apipeline_enqueue_documents(
content, file_paths=file_path.name, track_id=track_id content,
file_paths=file_path.name,
track_id=track_id,
metadata=metadata,
) )
logger.info( logger.info(
@ -1695,19 +1702,21 @@ def create_document_routes(
"/upload", response_model=InsertResponse, dependencies=[Depends(combined_auth)] "/upload", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
) )
async def upload_to_input_dir( async def upload_to_input_dir(
background_tasks: BackgroundTasks, file: UploadFile = File(...) background_tasks: BackgroundTasks,
file: UploadFile = File(...),
metadata: Optional[str] = Form(None),
): ):
""" """
Upload a file to the input directory and index it. Upload a file to the input directory and index it with optional metadata.
This API endpoint accepts a file through an HTTP POST request, checks if the This API endpoint accepts a file through an HTTP POST request, checks if the
uploaded file is of a supported type, saves it in the specified input directory, uploaded file is of a supported type, saves it in the specified input directory,
indexes it for retrieval, and returns a success status with relevant details. indexes it for retrieval, and returns a success status with relevant details.
Metadata can be provided to associate custom data with the uploaded document.
Args: Args:
background_tasks: FastAPI BackgroundTasks for async processing background_tasks: FastAPI BackgroundTasks for async processing
file (UploadFile): The file to be uploaded. It must have an allowed extension. file (UploadFile): The file to be uploaded. It must have an allowed extension.
metadata (dict, optional): Custom metadata to associate with the document.
Returns: Returns:
InsertResponse: A response object containing the upload status and a message. InsertResponse: A response object containing the upload status and a message.
status can be "success", "duplicated", or error is thrown. status can be "success", "duplicated", or error is thrown.
@ -1750,9 +1759,30 @@ def create_document_routes(
track_id = generate_track_id("upload") track_id = generate_track_id("upload")
# Add to background tasks and get track_id # Parse metadata if provided
background_tasks.add_task(pipeline_index_file, rag, file_path, track_id) parsed_metadata = None
if metadata:
try:
parsed_metadata = json.loads(metadata)
if not isinstance(parsed_metadata, dict):
raise ValueError(
"Metadata must be a valid JSON dictionary string."
)
except json.JSONDecodeError:
raise HTTPException(
status_code=400, detail="Metadata must be a valid JSON string."
)
except ValueError as ve:
raise HTTPException(status_code=400, detail=str(ve))
# Add to background tasks with metadata
background_tasks.add_task(
pipeline_index_file_with_metadata,
rag,
file_path,
track_id,
parsed_metadata,
)
return InsertResponse( return InsertResponse(
status="success", status="success",
message=f"File '{safe_filename}' uploaded successfully. Processing will continue in background.", message=f"File '{safe_filename}' uploaded successfully. Processing will continue in background.",
@ -1764,6 +1794,35 @@ def create_document_routes(
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
# New function to handle metadata during indexing
async def pipeline_index_file_with_metadata(
rag: LightRAG, file_path: Path, track_id: str, metadata: dict | None
) -> tuple[bool, str]:
"""
Index a file with metadata by leveraging the existing pipeline.
Args:
rag: LightRAG instance
file_path: Path to the file to index
track_id: Tracking ID for the document
metadata: Optional metadata dictionary to associate with the document
Returns:
tuple[bool, str]: Success status and track ID
"""
# Use the existing pipeline to enqueue the file
success, returned_track_id = await pipeline_enqueue_file(
rag, file_path, track_id
)
if success:
logger.info(f"Successfully enqueued file with metadata: {metadata}")
else:
logger.error("Failed to enqueue file with metadata")
# Trigger the pipeline processing
await rag.apipeline_process_enqueue_documents(metadata=metadata)
return success, returned_track_id
@router.post( @router.post(
"/text", response_model=InsertResponse, dependencies=[Depends(combined_auth)] "/text", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
) )

View file

@ -22,6 +22,11 @@ class QueryRequest(BaseModel):
description="The query text", description="The query text",
) )
metadata_filter: dict[str, str] | None = Field(
default=None,
description="Optional dictionary of metadata key-value pairs to filter nodes",
)
mode: Literal["local", "global", "hybrid", "naive", "mix", "bypass"] = Field( mode: Literal["local", "global", "hybrid", "naive", "mix", "bypass"] = Field(
default="mix", default="mix",
description="Query mode", description="Query mode",
@ -168,6 +173,11 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
""" """
try: try:
param = request.to_query_params(False) param = request.to_query_params(False)
# Inject metadata_filter into param if present
if request.metadata_filter:
setattr(param, "metadata_filter", request.metadata_filter)
response = await rag.aquery(request.query, param=param) response = await rag.aquery(request.query, param=param)
# Get reference list if requested # Get reference list if requested

View file

@ -168,6 +168,9 @@ class QueryParam:
containing citation information for the retrieved content. containing citation information for the retrieved content.
""" """
metadata_filter: dict | None = None
"""Metadata for filtering nodes and edges, allowing for more precise querying."""
@dataclass @dataclass
class StorageNameSpace(ABC): class StorageNameSpace(ABC):
@ -444,6 +447,12 @@ class BaseGraphStorage(StorageNameSpace, ABC):
or None if the node doesn't exist or None if the node doesn't exist
""" """
async def get_nodes_by_metadata_filter(self, metadata_filter: str) -> list[str]:
result = []
text_query = f"MATCH (n:`{self.workspace_label}` {metadata_filter})"
result = await self.query(text_query)
return result
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
"""Get nodes as a batch using UNWIND """Get nodes as a batch using UNWIND

View file

@ -1,5 +1,6 @@
import os import os
import re import re
import json
from dataclasses import dataclass from dataclasses import dataclass
from typing import final from typing import final
import configparser import configparser
@ -424,6 +425,24 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() # Ensure results are consumed even on error await result.consume() # Ensure results are consumed even on error
raise raise
async def get_nodes_by_metadata_filter(self, query: str) -> list[str]:
"""Get nodes by filtering query, return nodes id"""
workspace_label = self._get_workspace_label()
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
try:
query = f"MATCH (n:`{workspace_label}`) {query}"
result = await session.run(query)
debug_results = [record async for record in result]
debug_ids = [r["entity_id"] for r in debug_results]
await result.consume()
return debug_ids
except Exception as e:
logger.error(f"[{self.workspace}] Error getting node for: {str(e)}")
raise
async def get_node(self, node_id: str) -> dict[str, str] | None: async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node by its label identifier, return only node properties """Get node by its label identifier, return only node properties
@ -962,7 +981,11 @@ class Neo4JStorage(BaseGraphStorage):
) )
), ),
) )
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: async def upsert_node(
self,
node_id: str,
node_data: dict[str, str],
) -> None:
""" """
Upsert a node in the Neo4j database. Upsert a node in the Neo4j database.
@ -971,8 +994,25 @@ class Neo4JStorage(BaseGraphStorage):
node_data: Dictionary of node properties node_data: Dictionary of node properties
""" """
workspace_label = self._get_workspace_label() workspace_label = self._get_workspace_label()
properties = node_data properties = node_data.copy()
entity_type = properties["entity_type"]
metadata = properties.pop("metadata", None)
for key, value in metadata.items():
neo4j_key = key
# Handle complex data types by converting them to strings
if isinstance(value, (dict, list)):
try:
properties[neo4j_key] = json.dumps(value)
except Exception as e:
logger.warning(
f"Failed to serialize metadata field {key} for node {node_id}: {e}"
)
properties[neo4j_key] = str(value)
else:
properties[neo4j_key] = value
entity_type = properties.get("entity_type", "Unknown")
if "entity_id" not in properties: if "entity_id" not in properties:
raise ValueError("Neo4j: node properties must contain an 'entity_id' field") raise ValueError("Neo4j: node properties must contain an 'entity_id' field")
@ -992,7 +1032,7 @@ class Neo4JStorage(BaseGraphStorage):
await session.execute_write(execute_upsert) await session.execute_write(execute_upsert)
except Exception as e: except Exception as e:
logger.error(f"[{self.workspace}] Error during upsert: {str(e)}") logger.error(f"[{self.workspace}] Error during node upsert: {str(e)}")
raise raise
@retry( @retry(
@ -1026,12 +1066,23 @@ class Neo4JStorage(BaseGraphStorage):
Raises: Raises:
ValueError: If either source or target node does not exist or is not unique ValueError: If either source or target node does not exist or is not unique
""" """
edge_properties = edge_data
workspace_label = self._get_workspace_label()
# Extract metadata if present and handle it properly
metadata = edge_properties.pop("metadata", None)
if metadata and isinstance(metadata, dict):
# Serialize metadata to JSON string
try:
edge_properties["metadata"] = json.dumps(metadata)
except Exception as e:
logger.warning(f"Failed to serialize metadata: {e}")
edge_properties["metadata"] = str(metadata)
try: try:
edge_properties = edge_data
async with self._driver.session(database=self._DATABASE) as session: async with self._driver.session(database=self._DATABASE) as session:
async def execute_upsert(tx: AsyncManagedTransaction): async def execute_upsert(tx: AsyncManagedTransaction):
workspace_label = self._get_workspace_label()
query = f""" query = f"""
MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}}) MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})
WITH source WITH source

View file

@ -928,7 +928,8 @@ class LightRAG:
await self.apipeline_enqueue_documents(input, ids, file_paths, track_id) await self.apipeline_enqueue_documents(input, ids, file_paths, track_id)
await self.apipeline_process_enqueue_documents( await self.apipeline_process_enqueue_documents(
split_by_character, split_by_character_only split_by_character,
split_by_character_only,
) )
return track_id return track_id
@ -1011,6 +1012,7 @@ class LightRAG:
ids: list[str] | None = None, ids: list[str] | None = None,
file_paths: str | list[str] | None = None, file_paths: str | list[str] | None = None,
track_id: str | None = None, track_id: str | None = None,
metadata: dict | None = None,
) -> str: ) -> str:
""" """
Pipeline for Processing Documents Pipeline for Processing Documents
@ -1025,6 +1027,7 @@ class LightRAG:
ids: list of unique document IDs, if not provided, MD5 hash IDs will be generated ids: list of unique document IDs, if not provided, MD5 hash IDs will be generated
file_paths: list of file paths corresponding to each document, used for citation file_paths: list of file paths corresponding to each document, used for citation
track_id: tracking ID for monitoring processing status, if not provided, will be generated with "enqueue" prefix track_id: tracking ID for monitoring processing status, if not provided, will be generated with "enqueue" prefix
metadata: Optional metadata to associate with the documents
Returns: Returns:
str: tracking ID for monitoring processing status str: tracking ID for monitoring processing status
@ -1038,6 +1041,8 @@ class LightRAG:
ids = [ids] ids = [ids]
if isinstance(file_paths, str): if isinstance(file_paths, str):
file_paths = [file_paths] file_paths = [file_paths]
if isinstance(metadata, dict):
metadata = [metadata]
# If file_paths is provided, ensure it matches the number of documents # If file_paths is provided, ensure it matches the number of documents
if file_paths is not None: if file_paths is not None:
@ -1102,6 +1107,7 @@ class LightRAG:
"file_path" "file_path"
], # Store file path in document status ], # Store file path in document status
"track_id": track_id, # Store track_id in document status "track_id": track_id, # Store track_id in document status
"metadata": metadata, # added provided custom metadata
} }
for id_, content_data in contents.items() for id_, content_data in contents.items()
} }
@ -1354,6 +1360,7 @@ class LightRAG:
self, self,
split_by_character: str | None = None, split_by_character: str | None = None,
split_by_character_only: bool = False, split_by_character_only: bool = False,
metadata: dict | None = None,
) -> None: ) -> None:
""" """
Process pending documents by splitting them into chunks, processing Process pending documents by splitting them into chunks, processing
@ -1475,6 +1482,7 @@ class LightRAG:
pipeline_status: dict, pipeline_status: dict,
pipeline_status_lock: asyncio.Lock, pipeline_status_lock: asyncio.Lock,
semaphore: asyncio.Semaphore, semaphore: asyncio.Semaphore,
metadata: dict | None = None,
) -> None: ) -> None:
"""Process single document""" """Process single document"""
file_extraction_stage_ok = False file_extraction_stage_ok = False
@ -1530,6 +1538,7 @@ class LightRAG:
"full_doc_id": doc_id, "full_doc_id": doc_id,
"file_path": file_path, # Add file path to each chunk "file_path": file_path, # Add file path to each chunk
"llm_cache_list": [], # Initialize empty LLM cache list for each chunk "llm_cache_list": [], # Initialize empty LLM cache list for each chunk
"metadata": metadata,
} }
for dp in self.chunking_func( for dp in self.chunking_func(
self.tokenizer, self.tokenizer,
@ -1549,6 +1558,8 @@ class LightRAG:
# Process document in two stages # Process document in two stages
# Stage 1: Process text chunks and docs (parallel execution) # Stage 1: Process text chunks and docs (parallel execution)
metadata["processing_start_time"] = processing_start_time
doc_status_task = asyncio.create_task( doc_status_task = asyncio.create_task(
self.doc_status.upsert( self.doc_status.upsert(
{ {
@ -1566,9 +1577,7 @@ class LightRAG:
).isoformat(), ).isoformat(),
"file_path": file_path, "file_path": file_path,
"track_id": status_doc.track_id, # Preserve existing track_id "track_id": status_doc.track_id, # Preserve existing track_id
"metadata": { "metadata": metadata,
"processing_start_time": processing_start_time
},
} }
} }
) )
@ -1593,8 +1602,11 @@ class LightRAG:
# Stage 2: Process entity relation graph (after text_chunks are saved) # Stage 2: Process entity relation graph (after text_chunks are saved)
entity_relation_task = asyncio.create_task( entity_relation_task = asyncio.create_task(
self._process_extract_entities( self._process_entity_relation_graph(
chunks, pipeline_status, pipeline_status_lock chunks,
metadata,
pipeline_status,
pipeline_status_lock,
) )
) )
await entity_relation_task await entity_relation_task
@ -1628,6 +1640,8 @@ class LightRAG:
processing_end_time = int(time.time()) processing_end_time = int(time.time())
# Update document status to failed # Update document status to failed
metadata["processing_start_time"] = processing_start_time
metadata["processing_end_time"] = processing_end_time
await self.doc_status.upsert( await self.doc_status.upsert(
{ {
doc_id: { doc_id: {
@ -1641,10 +1655,7 @@ class LightRAG:
).isoformat(), ).isoformat(),
"file_path": file_path, "file_path": file_path,
"track_id": status_doc.track_id, # Preserve existing track_id "track_id": status_doc.track_id, # Preserve existing track_id
"metadata": { "metadata": metadata,
"processing_start_time": processing_start_time,
"processing_end_time": processing_end_time,
},
} }
} }
) )
@ -1669,10 +1680,15 @@ class LightRAG:
current_file_number=current_file_number, current_file_number=current_file_number,
total_files=total_files, total_files=total_files,
file_path=file_path, file_path=file_path,
metadata=metadata, # NEW: Pass metadata to merge function
) )
# Record processing end time # Record processing end time
processing_end_time = int(time.time()) processing_end_time = int(time.time())
metadata["processing_start_time"] = (
processing_start_time
)
metadata["processing_end_time"] = processing_end_time
await self.doc_status.upsert( await self.doc_status.upsert(
{ {
@ -1688,10 +1704,7 @@ class LightRAG:
).isoformat(), ).isoformat(),
"file_path": file_path, "file_path": file_path,
"track_id": status_doc.track_id, # Preserve existing track_id "track_id": status_doc.track_id, # Preserve existing track_id
"metadata": { "metadata": metadata,
"processing_start_time": processing_start_time,
"processing_end_time": processing_end_time,
},
} }
} }
) )
@ -1729,6 +1742,11 @@ class LightRAG:
processing_end_time = int(time.time()) processing_end_time = int(time.time())
# Update document status to failed # Update document status to failed
metadata["processing_start_time"] = (
processing_start_time
)
metadata["processing_end_time"] = processing_end_time
await self.doc_status.upsert( await self.doc_status.upsert(
{ {
doc_id: { doc_id: {
@ -1740,10 +1758,7 @@ class LightRAG:
"updated_at": datetime.now().isoformat(), "updated_at": datetime.now().isoformat(),
"file_path": file_path, "file_path": file_path,
"track_id": status_doc.track_id, # Preserve existing track_id "track_id": status_doc.track_id, # Preserve existing track_id
"metadata": { "metadata": metadata,
"processing_start_time": processing_start_time,
"processing_end_time": processing_end_time,
},
} }
} }
) )
@ -1760,6 +1775,7 @@ class LightRAG:
pipeline_status, pipeline_status,
pipeline_status_lock, pipeline_status_lock,
semaphore, semaphore,
metadata,
) )
) )
@ -1803,13 +1819,18 @@ class LightRAG:
pipeline_status["latest_message"] = log_message pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message) pipeline_status["history_messages"].append(log_message)
async def _process_extract_entities( async def _process_entity_relation_graph(
self, chunk: dict[str, Any], pipeline_status=None, pipeline_status_lock=None self,
chunk: dict[str, Any],
metadata: dict | None,
pipeline_status=None,
pipeline_status_lock=None,
) -> list: ) -> list:
try: try:
chunk_results = await extract_entities( chunk_results = await extract_entities(
chunk, chunk,
global_config=asdict(self), global_config=asdict(self),
metadata=metadata, # Pass metadata here
pipeline_status=pipeline_status, pipeline_status=pipeline_status,
pipeline_status_lock=pipeline_status_lock, pipeline_status_lock=pipeline_status_lock,
llm_response_cache=self.llm_response_cache, llm_response_cache=self.llm_response_cache,

View file

@ -323,6 +323,7 @@ async def _handle_single_entity_extraction(
chunk_key: str, chunk_key: str,
timestamp: int, timestamp: int,
file_path: str = "unknown_source", file_path: str = "unknown_source",
metadata: dict[str, Any] | None = None,
): ):
if len(record_attributes) != 4 or "entity" not in record_attributes[0]: if len(record_attributes) != 4 or "entity" not in record_attributes[0]:
if len(record_attributes) > 1 and "entity" in record_attributes[0]: if len(record_attributes) > 1 and "entity" in record_attributes[0]:
@ -376,6 +377,7 @@ async def _handle_single_entity_extraction(
source_id=chunk_key, source_id=chunk_key,
file_path=file_path, file_path=file_path,
timestamp=timestamp, timestamp=timestamp,
metadata=metadata,
) )
except ValueError as e: except ValueError as e:
@ -395,6 +397,7 @@ async def _handle_single_relationship_extraction(
chunk_key: str, chunk_key: str,
timestamp: int, timestamp: int,
file_path: str = "unknown_source", file_path: str = "unknown_source",
metadata: dict[str, Any] | None = None,
): ):
if ( if (
len(record_attributes) != 5 or "relation" not in record_attributes[0] len(record_attributes) != 5 or "relation" not in record_attributes[0]
@ -458,6 +461,8 @@ async def _handle_single_relationship_extraction(
source_id=edge_source_id, source_id=edge_source_id,
file_path=file_path, file_path=file_path,
timestamp=timestamp, timestamp=timestamp,
metadata=metadata,
) )
except ValueError as e: except ValueError as e:
@ -862,6 +867,7 @@ async def _process_extraction_result(
file_path: str = "unknown_source", file_path: str = "unknown_source",
tuple_delimiter: str = "<|#|>", tuple_delimiter: str = "<|#|>",
completion_delimiter: str = "<|COMPLETE|>", completion_delimiter: str = "<|COMPLETE|>",
metadata: dict[str, Any] | None = None,
) -> tuple[dict, dict]: ) -> tuple[dict, dict]:
"""Process a single extraction result (either initial or gleaning) """Process a single extraction result (either initial or gleaning)
Args: Args:
@ -943,7 +949,7 @@ async def _process_extraction_result(
# Try to parse as entity # Try to parse as entity
entity_data = await _handle_single_entity_extraction( entity_data = await _handle_single_entity_extraction(
record_attributes, chunk_key, timestamp, file_path record_attributes, chunk_key, timestamp, file_path, metadata
) )
if entity_data is not None: if entity_data is not None:
maybe_nodes[entity_data["entity_name"]].append(entity_data) maybe_nodes[entity_data["entity_name"]].append(entity_data)
@ -951,7 +957,7 @@ async def _process_extraction_result(
# Try to parse as relationship # Try to parse as relationship
relationship_data = await _handle_single_relationship_extraction( relationship_data = await _handle_single_relationship_extraction(
record_attributes, chunk_key, timestamp, file_path record_attributes, chunk_key, timestamp, file_path, metadata
) )
if relationship_data is not None: if relationship_data is not None:
maybe_edges[ maybe_edges[
@ -1295,6 +1301,7 @@ async def _merge_nodes_then_upsert(
pipeline_status: dict = None, pipeline_status: dict = None,
pipeline_status_lock=None, pipeline_status_lock=None,
llm_response_cache: BaseKVStorage | None = None, llm_response_cache: BaseKVStorage | None = None,
metadata: dict[str, Any] | None = None,
): ):
"""Get existing nodes from knowledge graph use name,if exists, merge data, else create, then upsert.""" """Get existing nodes from knowledge graph use name,if exists, merge data, else create, then upsert."""
already_entity_types = [] already_entity_types = []
@ -1382,6 +1389,7 @@ async def _merge_nodes_then_upsert(
description=description, description=description,
source_id=source_id, source_id=source_id,
file_path=file_path, file_path=file_path,
metadata=metadata, # Add metadata here
created_at=int(time.time()), created_at=int(time.time()),
) )
await knowledge_graph_inst.upsert_node( await knowledge_graph_inst.upsert_node(
@ -1402,6 +1410,7 @@ async def _merge_edges_then_upsert(
pipeline_status_lock=None, pipeline_status_lock=None,
llm_response_cache: BaseKVStorage | None = None, llm_response_cache: BaseKVStorage | None = None,
added_entities: list = None, # New parameter to track entities added during edge processing added_entities: list = None, # New parameter to track entities added during edge processing
metadata: dict | None = None,
): ):
if src_id == tgt_id: if src_id == tgt_id:
return None return None
@ -1534,6 +1543,7 @@ async def _merge_edges_then_upsert(
"description": description, "description": description,
"entity_type": "UNKNOWN", "entity_type": "UNKNOWN",
"file_path": file_path, "file_path": file_path,
"metadata": metadata, # Add metadata here
"created_at": int(time.time()), "created_at": int(time.time()),
} }
await knowledge_graph_inst.upsert_node(need_insert_id, node_data=node_data) await knowledge_graph_inst.upsert_node(need_insert_id, node_data=node_data)
@ -1546,6 +1556,7 @@ async def _merge_edges_then_upsert(
"description": description, "description": description,
"source_id": source_id, "source_id": source_id,
"file_path": file_path, "file_path": file_path,
"metadata": metadata, # Add metadata here
"created_at": int(time.time()), "created_at": int(time.time()),
} }
added_entities.append(entity_data) added_entities.append(entity_data)
@ -1559,6 +1570,7 @@ async def _merge_edges_then_upsert(
keywords=keywords, keywords=keywords,
source_id=source_id, source_id=source_id,
file_path=file_path, file_path=file_path,
metadata=metadata, # Add metadata here
created_at=int(time.time()), created_at=int(time.time()),
), ),
) )
@ -1570,6 +1582,7 @@ async def _merge_edges_then_upsert(
keywords=keywords, keywords=keywords,
source_id=source_id, source_id=source_id,
file_path=file_path, file_path=file_path,
metadata=metadata, # Add metadata here
created_at=int(time.time()), created_at=int(time.time()),
) )
@ -1591,6 +1604,7 @@ async def merge_nodes_and_edges(
current_file_number: int = 0, current_file_number: int = 0,
total_files: int = 0, total_files: int = 0,
file_path: str = "unknown_source", file_path: str = "unknown_source",
metadata: dict | None = None, # Added metadata parameter
) -> None: ) -> None:
"""Two-phase merge: process all entities first, then all relationships """Two-phase merge: process all entities first, then all relationships
@ -1614,6 +1628,7 @@ async def merge_nodes_and_edges(
current_file_number: Current file number for logging current_file_number: Current file number for logging
total_files: Total files for logging total_files: Total files for logging
file_path: File path for logging file_path: File path for logging
metadata: Document metadata to be attached to entities and relationships
""" """
# Collect all nodes and edges from all chunks # Collect all nodes and edges from all chunks
@ -1667,6 +1682,7 @@ async def merge_nodes_and_edges(
pipeline_status, pipeline_status,
pipeline_status_lock, pipeline_status_lock,
llm_response_cache, llm_response_cache,
metadata,
) )
# Vector database operation (equally critical, must succeed) # Vector database operation (equally critical, must succeed)
@ -1682,6 +1698,7 @@ async def merge_nodes_and_edges(
"file_path": entity_data.get( "file_path": entity_data.get(
"file_path", "unknown_source" "file_path", "unknown_source"
), ),
"metadata": metadata,
} }
} }
@ -1797,7 +1814,8 @@ async def merge_nodes_and_edges(
pipeline_status, pipeline_status,
pipeline_status_lock, pipeline_status_lock,
llm_response_cache, llm_response_cache,
added_entities, # Pass list to collect added entities added_entities,
metadata,
) )
if edge_data is None: if edge_data is None:
@ -1818,6 +1836,7 @@ async def merge_nodes_and_edges(
"file_path", "unknown_source" "file_path", "unknown_source"
), ),
"weight": edge_data.get("weight", 1.0), "weight": edge_data.get("weight", 1.0),
"metadata": metadata,
} }
} }
@ -1967,13 +1986,14 @@ async def merge_nodes_and_edges(
pipeline_status["latest_message"] = log_message pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message) pipeline_status["history_messages"].append(log_message)
# Update storage # Update storage with metadata
if final_entity_names: if final_entity_names:
await full_entities_storage.upsert( await full_entities_storage.upsert(
{ {
doc_id: { doc_id: {
"entity_names": list(final_entity_names), "entity_names": list(final_entity_names),
"count": len(final_entity_names), "count": len(final_entity_names),
"metadata": metadata, # Add metadata here
} }
} }
) )
@ -1986,6 +2006,7 @@ async def merge_nodes_and_edges(
list(pair) for pair in final_relation_pairs list(pair) for pair in final_relation_pairs
], ],
"count": len(final_relation_pairs), "count": len(final_relation_pairs),
"metadata": metadata, # Add metadata here
} }
} }
) )
@ -2010,6 +2031,7 @@ async def merge_nodes_and_edges(
async def extract_entities( async def extract_entities(
chunks: dict[str, TextChunkSchema], chunks: dict[str, TextChunkSchema],
global_config: dict[str, str], global_config: dict[str, str],
metadata: dict[str, Any] | None = None,
pipeline_status: dict = None, pipeline_status: dict = None,
pipeline_status_lock=None, pipeline_status_lock=None,
llm_response_cache: BaseKVStorage | None = None, llm_response_cache: BaseKVStorage | None = None,
@ -2047,6 +2069,55 @@ async def extract_entities(
processed_chunks = 0 processed_chunks = 0
total_chunks = len(ordered_chunks) total_chunks = len(ordered_chunks)
async def _process_extraction_result(
result: str,
chunk_key: str,
file_path: str = "unknown_source",
metadata: dict[str, Any] | None = None,
):
"""Process a single extraction result (either initial or gleaning)
Args:
result (str): The extraction result to process
chunk_key (str): The chunk key for source tracking
file_path (str): The file path for citation
metadata (dict, optional): Additional metadata to include in extracted entities/relationships.
Returns:
tuple: (nodes_dict, edges_dict) containing the extracted entities and relationships
"""
maybe_nodes = defaultdict(list)
maybe_edges = defaultdict(list)
records = split_string_by_multi_markers(
result,
[context_base["record_delimiter"], context_base["completion_delimiter"]],
)
for record in records:
record = re.search(r"\((.*)\)", record)
if record is None:
continue
record = record.group(1)
record_attributes = split_string_by_multi_markers(
record, [context_base["tuple_delimiter"]]
)
if_entities = await _handle_single_entity_extraction(
record_attributes, chunk_key, file_path, metadata
)
if if_entities is not None:
maybe_nodes[if_entities["entity_name"]].append(if_entities)
continue
if_relation = await _handle_single_relationship_extraction(
record_attributes, chunk_key, file_path, metadata
)
if if_relation is not None:
maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
if_relation
)
return maybe_nodes, maybe_edges
async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]): async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
"""Process a single chunk """Process a single chunk
Args: Args:
@ -2090,12 +2161,13 @@ async def extract_entities(
entity_extraction_user_prompt, final_result entity_extraction_user_prompt, final_result
) )
# Process initial extraction with file path # Process initial extraction with file path and metadata
maybe_nodes, maybe_edges = await _process_extraction_result( maybe_nodes, maybe_edges = await _process_extraction_result(
final_result, final_result,
chunk_key, chunk_key,
timestamp, timestamp,
file_path, file_path,
metadata,
tuple_delimiter=context_base["tuple_delimiter"], tuple_delimiter=context_base["tuple_delimiter"],
completion_delimiter=context_base["completion_delimiter"], completion_delimiter=context_base["completion_delimiter"],
) )
@ -2113,12 +2185,14 @@ async def extract_entities(
cache_keys_collector=cache_keys_collector, cache_keys_collector=cache_keys_collector,
) )
# Process gleaning result separately with file path
# Process gleaning result separately with file path and metadata
glean_nodes, glean_edges = await _process_extraction_result( glean_nodes, glean_edges = await _process_extraction_result(
glean_result, glean_result,
chunk_key, chunk_key,
timestamp, timestamp,
file_path, file_path,
metadata=metadata,
tuple_delimiter=context_base["tuple_delimiter"], tuple_delimiter=context_base["tuple_delimiter"],
completion_delimiter=context_base["completion_delimiter"], completion_delimiter=context_base["completion_delimiter"],
) )
@ -2264,6 +2338,7 @@ async def kg_query(
hashing_kv: BaseKVStorage | None = None, hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None, system_prompt: str | None = None,
chunks_vdb: BaseVectorStorage = None, chunks_vdb: BaseVectorStorage = None,
metadata_filters: list | None = None,
return_raw_data: Literal[False] = False, return_raw_data: Literal[False] = False,
) -> str | AsyncIterator[str]: ... ) -> str | AsyncIterator[str]: ...
@ -3493,10 +3568,24 @@ async def _get_node_data(
# Extract all entity IDs from your results list # Extract all entity IDs from your results list
node_ids = [r["entity_name"] for r in results] node_ids = [r["entity_name"] for r in results]
# HARDCODED QUERY FOR DEBUGGING
filter_query = "WHERE (n.class) is not null AND n.class = 'bando' RETURN (n.entity_id) AS entity_id, n.class AS class_value"
# TODO update method to take in the metadata_filter dataclass
node_kg_ids = []
if hasattr(knowledge_graph_inst, "get_nodes_by_metadata_filter"):
node_kg_ids = await asyncio.gather(
knowledge_graph_inst.get_nodes_by_metadata_filter(filter_query)
)
filtered_node_ids = (
[nid for nid in node_ids if nid in node_kg_ids] if node_kg_ids else node_ids
)
# Call the batch node retrieval and degree functions concurrently. # Call the batch node retrieval and degree functions concurrently.
nodes_dict, degrees_dict = await asyncio.gather( nodes_dict, degrees_dict = await asyncio.gather(
knowledge_graph_inst.get_nodes_batch(node_ids), knowledge_graph_inst.get_nodes_batch(filtered_node_ids),
knowledge_graph_inst.node_degrees_batch(node_ids), knowledge_graph_inst.node_degrees_batch(filtered_node_ids),
) )
# Now, if you need the node data and degree in order: # Now, if you need the node data and degree in order: