feat (metadata): Add custom metadata support for node properties and querying
- Implement custom metadata insertion as node properties during file upload. - Add basic metadata filtering functionality to the query API. --NOTE: While base.py has been modified, the base implementation is incomplete and untested. Only the Neo4j backend has been properly implemented and tested. WIP: the query API is temporarily mocked for debugging; a full implementation with complex AND/OR filtering capabilities is in development. # Conflicts: # lightrag/base.py # lightrag/lightrag.py # lightrag/operate.py
This commit is contained in:
parent
058ce83dba
commit
7be24a3c60
6 changed files with 281 additions and 42 deletions
|
|
@ -11,11 +11,13 @@ import pipmaster as pm
|
|||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any, Literal
|
||||
import json
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
BackgroundTasks,
|
||||
Depends,
|
||||
File,
|
||||
Form,
|
||||
HTTPException,
|
||||
UploadFile,
|
||||
)
|
||||
|
|
@ -833,7 +835,7 @@ def get_unique_filename_in_enqueued(target_dir: Path, original_name: str) -> str
|
|||
|
||||
|
||||
async def pipeline_enqueue_file(
|
||||
rag: LightRAG, file_path: Path, track_id: str = None
|
||||
rag: LightRAG, file_path: Path, track_id: str = None, metadata: dict | None = None
|
||||
) -> tuple[bool, str]:
|
||||
"""Add a file to the queue for processing
|
||||
|
||||
|
|
@ -841,6 +843,7 @@ async def pipeline_enqueue_file(
|
|||
rag: LightRAG instance
|
||||
file_path: Path to the saved file
|
||||
track_id: Optional tracking ID, if not provided will be generated
|
||||
metadata: Optional metadata to associate with the document
|
||||
Returns:
|
||||
tuple: (success: bool, track_id: str)
|
||||
"""
|
||||
|
|
@ -1212,8 +1215,12 @@ async def pipeline_enqueue_file(
|
|||
return False, track_id
|
||||
|
||||
try:
|
||||
# Pass metadata to apipeline_enqueue_documents
|
||||
await rag.apipeline_enqueue_documents(
|
||||
content, file_paths=file_path.name, track_id=track_id
|
||||
content,
|
||||
file_paths=file_path.name,
|
||||
track_id=track_id,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
|
|
@ -1695,19 +1702,21 @@ def create_document_routes(
|
|||
"/upload", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
|
||||
)
|
||||
async def upload_to_input_dir(
|
||||
background_tasks: BackgroundTasks, file: UploadFile = File(...)
|
||||
background_tasks: BackgroundTasks,
|
||||
file: UploadFile = File(...),
|
||||
metadata: Optional[str] = Form(None),
|
||||
):
|
||||
"""
|
||||
Upload a file to the input directory and index it.
|
||||
Upload a file to the input directory and index it with optional metadata.
|
||||
|
||||
This API endpoint accepts a file through an HTTP POST request, checks if the
|
||||
uploaded file is of a supported type, saves it in the specified input directory,
|
||||
indexes it for retrieval, and returns a success status with relevant details.
|
||||
|
||||
Metadata can be provided to associate custom data with the uploaded document.
|
||||
Args:
|
||||
background_tasks: FastAPI BackgroundTasks for async processing
|
||||
file (UploadFile): The file to be uploaded. It must have an allowed extension.
|
||||
|
||||
metadata (dict, optional): Custom metadata to associate with the document.
|
||||
Returns:
|
||||
InsertResponse: A response object containing the upload status and a message.
|
||||
status can be "success", "duplicated", or error is thrown.
|
||||
|
|
@ -1750,9 +1759,30 @@ def create_document_routes(
|
|||
|
||||
track_id = generate_track_id("upload")
|
||||
|
||||
# Add to background tasks and get track_id
|
||||
background_tasks.add_task(pipeline_index_file, rag, file_path, track_id)
|
||||
# Parse metadata if provided
|
||||
parsed_metadata = None
|
||||
if metadata:
|
||||
try:
|
||||
parsed_metadata = json.loads(metadata)
|
||||
if not isinstance(parsed_metadata, dict):
|
||||
raise ValueError(
|
||||
"Metadata must be a valid JSON dictionary string."
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Metadata must be a valid JSON string."
|
||||
)
|
||||
except ValueError as ve:
|
||||
raise HTTPException(status_code=400, detail=str(ve))
|
||||
|
||||
# Add to background tasks with metadata
|
||||
background_tasks.add_task(
|
||||
pipeline_index_file_with_metadata,
|
||||
rag,
|
||||
file_path,
|
||||
track_id,
|
||||
parsed_metadata,
|
||||
)
|
||||
return InsertResponse(
|
||||
status="success",
|
||||
message=f"File '{safe_filename}' uploaded successfully. Processing will continue in background.",
|
||||
|
|
@ -1764,6 +1794,35 @@ def create_document_routes(
|
|||
logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# New function to handle metadata during indexing
async def pipeline_index_file_with_metadata(
    rag: LightRAG, file_path: Path, track_id: str, metadata: dict | None
) -> tuple[bool, str]:
    """
    Index a file with metadata by leveraging the existing pipeline.

    Args:
        rag: LightRAG instance
        file_path: Path to the file to index
        track_id: Tracking ID for the document
        metadata: Optional metadata dictionary to associate with the document

    Returns:
        tuple[bool, str]: Success status and track ID
    """
    # Enqueue the file, forwarding metadata so it is stored with the document.
    # BUG FIX: metadata was previously accepted here but never passed on,
    # so it was silently dropped before reaching the pipeline.
    success, returned_track_id = await pipeline_enqueue_file(
        rag, file_path, track_id, metadata
    )

    if success:
        logger.info(f"Successfully enqueued file with metadata: {metadata}")
        # Trigger pipeline processing only when enqueueing succeeded;
        # after a failed enqueue there is nothing new to process.
        await rag.apipeline_process_enqueue_documents(metadata=metadata)
    else:
        logger.error("Failed to enqueue file with metadata")

    return success, returned_track_id
|
||||
|
||||
@router.post(
|
||||
"/text", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -22,6 +22,11 @@ class QueryRequest(BaseModel):
|
|||
description="The query text",
|
||||
)
|
||||
|
||||
metadata_filter: dict[str, str] | None = Field(
|
||||
default=None,
|
||||
description="Optional dictionary of metadata key-value pairs to filter nodes",
|
||||
)
|
||||
|
||||
mode: Literal["local", "global", "hybrid", "naive", "mix", "bypass"] = Field(
|
||||
default="mix",
|
||||
description="Query mode",
|
||||
|
|
@ -168,6 +173,11 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
|||
"""
|
||||
try:
|
||||
param = request.to_query_params(False)
|
||||
|
||||
# Inject metadata_filter into param if present
|
||||
if request.metadata_filter:
|
||||
setattr(param, "metadata_filter", request.metadata_filter)
|
||||
|
||||
response = await rag.aquery(request.query, param=param)
|
||||
|
||||
# Get reference list if requested
|
||||
|
|
|
|||
|
|
@ -168,6 +168,9 @@ class QueryParam:
|
|||
containing citation information for the retrieved content.
|
||||
"""
|
||||
|
||||
metadata_filter: dict | None = None
|
||||
"""Metadata for filtering nodes and edges, allowing for more precise querying."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class StorageNameSpace(ABC):
|
||||
|
|
@ -444,6 +447,12 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|||
or None if the node doesn't exist
|
||||
"""
|
||||
|
||||
async def get_nodes_by_metadata_filter(self, metadata_filter: str) -> list[str]:
|
||||
result = []
|
||||
text_query = f"MATCH (n:`{self.workspace_label}` {metadata_filter})"
|
||||
result = await self.query(text_query)
|
||||
return result
|
||||
|
||||
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
||||
"""Get nodes as a batch using UNWIND
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
import re
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import final
|
||||
import configparser
|
||||
|
|
@ -424,6 +425,24 @@ class Neo4JStorage(BaseGraphStorage):
|
|||
await result.consume() # Ensure results are consumed even on error
|
||||
raise
|
||||
|
||||
async def get_nodes_by_metadata_filter(self, query: str) -> list[str]:
|
||||
"""Get nodes by filtering query, return nodes id"""
|
||||
|
||||
workspace_label = self._get_workspace_label()
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
try:
|
||||
query = f"MATCH (n:`{workspace_label}`) {query}"
|
||||
result = await session.run(query)
|
||||
debug_results = [record async for record in result]
|
||||
debug_ids = [r["entity_id"] for r in debug_results]
|
||||
await result.consume()
|
||||
return debug_ids
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.workspace}] Error getting node for: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||
"""Get node by its label identifier, return only node properties
|
||||
|
||||
|
|
@ -962,7 +981,11 @@ class Neo4JStorage(BaseGraphStorage):
|
|||
)
|
||||
),
|
||||
)
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||
async def upsert_node(
|
||||
self,
|
||||
node_id: str,
|
||||
node_data: dict[str, str],
|
||||
) -> None:
|
||||
"""
|
||||
Upsert a node in the Neo4j database.
|
||||
|
||||
|
|
@ -971,8 +994,25 @@ class Neo4JStorage(BaseGraphStorage):
|
|||
node_data: Dictionary of node properties
|
||||
"""
|
||||
workspace_label = self._get_workspace_label()
|
||||
properties = node_data
|
||||
entity_type = properties["entity_type"]
|
||||
properties = node_data.copy()
|
||||
|
||||
metadata = properties.pop("metadata", None)
|
||||
for key, value in metadata.items():
|
||||
neo4j_key = key
|
||||
|
||||
# Handle complex data types by converting them to strings
|
||||
if isinstance(value, (dict, list)):
|
||||
try:
|
||||
properties[neo4j_key] = json.dumps(value)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to serialize metadata field {key} for node {node_id}: {e}"
|
||||
)
|
||||
properties[neo4j_key] = str(value)
|
||||
else:
|
||||
properties[neo4j_key] = value
|
||||
|
||||
entity_type = properties.get("entity_type", "Unknown")
|
||||
if "entity_id" not in properties:
|
||||
raise ValueError("Neo4j: node properties must contain an 'entity_id' field")
|
||||
|
||||
|
|
@ -992,7 +1032,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||
|
||||
await session.execute_write(execute_upsert)
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.workspace}] Error during upsert: {str(e)}")
|
||||
logger.error(f"[{self.workspace}] Error during node upsert: {str(e)}")
|
||||
raise
|
||||
|
||||
@retry(
|
||||
|
|
@ -1026,12 +1066,23 @@ class Neo4JStorage(BaseGraphStorage):
|
|||
Raises:
|
||||
ValueError: If either source or target node does not exist or is not unique
|
||||
"""
|
||||
edge_properties = edge_data
|
||||
workspace_label = self._get_workspace_label()
|
||||
|
||||
# Extract metadata if present and handle it properly
|
||||
metadata = edge_properties.pop("metadata", None)
|
||||
if metadata and isinstance(metadata, dict):
|
||||
# Serialize metadata to JSON string
|
||||
try:
|
||||
edge_properties["metadata"] = json.dumps(metadata)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to serialize metadata: {e}")
|
||||
edge_properties["metadata"] = str(metadata)
|
||||
|
||||
try:
|
||||
edge_properties = edge_data
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
|
||||
async def execute_upsert(tx: AsyncManagedTransaction):
|
||||
workspace_label = self._get_workspace_label()
|
||||
query = f"""
|
||||
MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})
|
||||
WITH source
|
||||
|
|
|
|||
|
|
@ -928,7 +928,8 @@ class LightRAG:
|
|||
|
||||
await self.apipeline_enqueue_documents(input, ids, file_paths, track_id)
|
||||
await self.apipeline_process_enqueue_documents(
|
||||
split_by_character, split_by_character_only
|
||||
split_by_character,
|
||||
split_by_character_only,
|
||||
)
|
||||
|
||||
return track_id
|
||||
|
|
@ -1011,6 +1012,7 @@ class LightRAG:
|
|||
ids: list[str] | None = None,
|
||||
file_paths: str | list[str] | None = None,
|
||||
track_id: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Pipeline for Processing Documents
|
||||
|
|
@ -1025,6 +1027,7 @@ class LightRAG:
|
|||
ids: list of unique document IDs, if not provided, MD5 hash IDs will be generated
|
||||
file_paths: list of file paths corresponding to each document, used for citation
|
||||
track_id: tracking ID for monitoring processing status, if not provided, will be generated with "enqueue" prefix
|
||||
metadata: Optional metadata to associate with the documents
|
||||
|
||||
Returns:
|
||||
str: tracking ID for monitoring processing status
|
||||
|
|
@ -1038,6 +1041,8 @@ class LightRAG:
|
|||
ids = [ids]
|
||||
if isinstance(file_paths, str):
|
||||
file_paths = [file_paths]
|
||||
if isinstance(metadata, dict):
|
||||
metadata = [metadata]
|
||||
|
||||
# If file_paths is provided, ensure it matches the number of documents
|
||||
if file_paths is not None:
|
||||
|
|
@ -1102,6 +1107,7 @@ class LightRAG:
|
|||
"file_path"
|
||||
], # Store file path in document status
|
||||
"track_id": track_id, # Store track_id in document status
|
||||
"metadata": metadata, # added provided custom metadata
|
||||
}
|
||||
for id_, content_data in contents.items()
|
||||
}
|
||||
|
|
@ -1354,6 +1360,7 @@ class LightRAG:
|
|||
self,
|
||||
split_by_character: str | None = None,
|
||||
split_by_character_only: bool = False,
|
||||
metadata: dict | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Process pending documents by splitting them into chunks, processing
|
||||
|
|
@ -1475,6 +1482,7 @@ class LightRAG:
|
|||
pipeline_status: dict,
|
||||
pipeline_status_lock: asyncio.Lock,
|
||||
semaphore: asyncio.Semaphore,
|
||||
metadata: dict | None = None,
|
||||
) -> None:
|
||||
"""Process single document"""
|
||||
file_extraction_stage_ok = False
|
||||
|
|
@ -1530,6 +1538,7 @@ class LightRAG:
|
|||
"full_doc_id": doc_id,
|
||||
"file_path": file_path, # Add file path to each chunk
|
||||
"llm_cache_list": [], # Initialize empty LLM cache list for each chunk
|
||||
"metadata": metadata,
|
||||
}
|
||||
for dp in self.chunking_func(
|
||||
self.tokenizer,
|
||||
|
|
@ -1549,6 +1558,8 @@ class LightRAG:
|
|||
|
||||
# Process document in two stages
|
||||
# Stage 1: Process text chunks and docs (parallel execution)
|
||||
metadata["processing_start_time"] = processing_start_time
|
||||
|
||||
doc_status_task = asyncio.create_task(
|
||||
self.doc_status.upsert(
|
||||
{
|
||||
|
|
@ -1566,9 +1577,7 @@ class LightRAG:
|
|||
).isoformat(),
|
||||
"file_path": file_path,
|
||||
"track_id": status_doc.track_id, # Preserve existing track_id
|
||||
"metadata": {
|
||||
"processing_start_time": processing_start_time
|
||||
},
|
||||
"metadata": metadata,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
|
@ -1593,8 +1602,11 @@ class LightRAG:
|
|||
|
||||
# Stage 2: Process entity relation graph (after text_chunks are saved)
|
||||
entity_relation_task = asyncio.create_task(
|
||||
self._process_extract_entities(
|
||||
chunks, pipeline_status, pipeline_status_lock
|
||||
self._process_entity_relation_graph(
|
||||
chunks,
|
||||
metadata,
|
||||
pipeline_status,
|
||||
pipeline_status_lock,
|
||||
)
|
||||
)
|
||||
await entity_relation_task
|
||||
|
|
@ -1628,6 +1640,8 @@ class LightRAG:
|
|||
processing_end_time = int(time.time())
|
||||
|
||||
# Update document status to failed
|
||||
metadata["processing_start_time"] = processing_start_time
|
||||
metadata["processing_end_time"] = processing_end_time
|
||||
await self.doc_status.upsert(
|
||||
{
|
||||
doc_id: {
|
||||
|
|
@ -1641,10 +1655,7 @@ class LightRAG:
|
|||
).isoformat(),
|
||||
"file_path": file_path,
|
||||
"track_id": status_doc.track_id, # Preserve existing track_id
|
||||
"metadata": {
|
||||
"processing_start_time": processing_start_time,
|
||||
"processing_end_time": processing_end_time,
|
||||
},
|
||||
"metadata": metadata,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
|
@ -1669,10 +1680,15 @@ class LightRAG:
|
|||
current_file_number=current_file_number,
|
||||
total_files=total_files,
|
||||
file_path=file_path,
|
||||
metadata=metadata, # NEW: Pass metadata to merge function
|
||||
)
|
||||
|
||||
# Record processing end time
|
||||
processing_end_time = int(time.time())
|
||||
metadata["processing_start_time"] = (
|
||||
processing_start_time
|
||||
)
|
||||
metadata["processing_end_time"] = processing_end_time
|
||||
|
||||
await self.doc_status.upsert(
|
||||
{
|
||||
|
|
@ -1688,10 +1704,7 @@ class LightRAG:
|
|||
).isoformat(),
|
||||
"file_path": file_path,
|
||||
"track_id": status_doc.track_id, # Preserve existing track_id
|
||||
"metadata": {
|
||||
"processing_start_time": processing_start_time,
|
||||
"processing_end_time": processing_end_time,
|
||||
},
|
||||
"metadata": metadata,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
|
@ -1729,6 +1742,11 @@ class LightRAG:
|
|||
processing_end_time = int(time.time())
|
||||
|
||||
# Update document status to failed
|
||||
metadata["processing_start_time"] = (
|
||||
processing_start_time
|
||||
)
|
||||
metadata["processing_end_time"] = processing_end_time
|
||||
|
||||
await self.doc_status.upsert(
|
||||
{
|
||||
doc_id: {
|
||||
|
|
@ -1740,10 +1758,7 @@ class LightRAG:
|
|||
"updated_at": datetime.now().isoformat(),
|
||||
"file_path": file_path,
|
||||
"track_id": status_doc.track_id, # Preserve existing track_id
|
||||
"metadata": {
|
||||
"processing_start_time": processing_start_time,
|
||||
"processing_end_time": processing_end_time,
|
||||
},
|
||||
"metadata": metadata,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
|
@ -1760,6 +1775,7 @@ class LightRAG:
|
|||
pipeline_status,
|
||||
pipeline_status_lock,
|
||||
semaphore,
|
||||
metadata,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -1803,13 +1819,18 @@ class LightRAG:
|
|||
pipeline_status["latest_message"] = log_message
|
||||
pipeline_status["history_messages"].append(log_message)
|
||||
|
||||
async def _process_extract_entities(
|
||||
self, chunk: dict[str, Any], pipeline_status=None, pipeline_status_lock=None
|
||||
async def _process_entity_relation_graph(
|
||||
self,
|
||||
chunk: dict[str, Any],
|
||||
metadata: dict | None,
|
||||
pipeline_status=None,
|
||||
pipeline_status_lock=None,
|
||||
) -> list:
|
||||
try:
|
||||
chunk_results = await extract_entities(
|
||||
chunk,
|
||||
global_config=asdict(self),
|
||||
metadata=metadata, # Pass metadata here
|
||||
pipeline_status=pipeline_status,
|
||||
pipeline_status_lock=pipeline_status_lock,
|
||||
llm_response_cache=self.llm_response_cache,
|
||||
|
|
|
|||
|
|
@ -323,6 +323,7 @@ async def _handle_single_entity_extraction(
|
|||
chunk_key: str,
|
||||
timestamp: int,
|
||||
file_path: str = "unknown_source",
|
||||
metadata: dict[str, Any] | None = None,
|
||||
):
|
||||
if len(record_attributes) != 4 or "entity" not in record_attributes[0]:
|
||||
if len(record_attributes) > 1 and "entity" in record_attributes[0]:
|
||||
|
|
@ -376,6 +377,7 @@ async def _handle_single_entity_extraction(
|
|||
source_id=chunk_key,
|
||||
file_path=file_path,
|
||||
timestamp=timestamp,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
|
|
@ -395,6 +397,7 @@ async def _handle_single_relationship_extraction(
|
|||
chunk_key: str,
|
||||
timestamp: int,
|
||||
file_path: str = "unknown_source",
|
||||
metadata: dict[str, Any] | None = None,
|
||||
):
|
||||
if (
|
||||
len(record_attributes) != 5 or "relation" not in record_attributes[0]
|
||||
|
|
@ -458,6 +461,8 @@ async def _handle_single_relationship_extraction(
|
|||
source_id=edge_source_id,
|
||||
file_path=file_path,
|
||||
timestamp=timestamp,
|
||||
metadata=metadata,
|
||||
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
|
|
@ -862,6 +867,7 @@ async def _process_extraction_result(
|
|||
file_path: str = "unknown_source",
|
||||
tuple_delimiter: str = "<|#|>",
|
||||
completion_delimiter: str = "<|COMPLETE|>",
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> tuple[dict, dict]:
|
||||
"""Process a single extraction result (either initial or gleaning)
|
||||
Args:
|
||||
|
|
@ -943,7 +949,7 @@ async def _process_extraction_result(
|
|||
|
||||
# Try to parse as entity
|
||||
entity_data = await _handle_single_entity_extraction(
|
||||
record_attributes, chunk_key, timestamp, file_path
|
||||
record_attributes, chunk_key, timestamp, file_path, metadata
|
||||
)
|
||||
if entity_data is not None:
|
||||
maybe_nodes[entity_data["entity_name"]].append(entity_data)
|
||||
|
|
@ -951,7 +957,7 @@ async def _process_extraction_result(
|
|||
|
||||
# Try to parse as relationship
|
||||
relationship_data = await _handle_single_relationship_extraction(
|
||||
record_attributes, chunk_key, timestamp, file_path
|
||||
record_attributes, chunk_key, timestamp, file_path, metadata
|
||||
)
|
||||
if relationship_data is not None:
|
||||
maybe_edges[
|
||||
|
|
@ -1295,6 +1301,7 @@ async def _merge_nodes_then_upsert(
|
|||
pipeline_status: dict = None,
|
||||
pipeline_status_lock=None,
|
||||
llm_response_cache: BaseKVStorage | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
):
|
||||
"""Get existing nodes from knowledge graph use name,if exists, merge data, else create, then upsert."""
|
||||
already_entity_types = []
|
||||
|
|
@ -1382,6 +1389,7 @@ async def _merge_nodes_then_upsert(
|
|||
description=description,
|
||||
source_id=source_id,
|
||||
file_path=file_path,
|
||||
metadata=metadata, # Add metadata here
|
||||
created_at=int(time.time()),
|
||||
)
|
||||
await knowledge_graph_inst.upsert_node(
|
||||
|
|
@ -1402,6 +1410,7 @@ async def _merge_edges_then_upsert(
|
|||
pipeline_status_lock=None,
|
||||
llm_response_cache: BaseKVStorage | None = None,
|
||||
added_entities: list = None, # New parameter to track entities added during edge processing
|
||||
metadata: dict | None = None,
|
||||
):
|
||||
if src_id == tgt_id:
|
||||
return None
|
||||
|
|
@ -1534,6 +1543,7 @@ async def _merge_edges_then_upsert(
|
|||
"description": description,
|
||||
"entity_type": "UNKNOWN",
|
||||
"file_path": file_path,
|
||||
"metadata": metadata, # Add metadata here
|
||||
"created_at": int(time.time()),
|
||||
}
|
||||
await knowledge_graph_inst.upsert_node(need_insert_id, node_data=node_data)
|
||||
|
|
@ -1546,6 +1556,7 @@ async def _merge_edges_then_upsert(
|
|||
"description": description,
|
||||
"source_id": source_id,
|
||||
"file_path": file_path,
|
||||
"metadata": metadata, # Add metadata here
|
||||
"created_at": int(time.time()),
|
||||
}
|
||||
added_entities.append(entity_data)
|
||||
|
|
@ -1559,6 +1570,7 @@ async def _merge_edges_then_upsert(
|
|||
keywords=keywords,
|
||||
source_id=source_id,
|
||||
file_path=file_path,
|
||||
metadata=metadata, # Add metadata here
|
||||
created_at=int(time.time()),
|
||||
),
|
||||
)
|
||||
|
|
@ -1570,6 +1582,7 @@ async def _merge_edges_then_upsert(
|
|||
keywords=keywords,
|
||||
source_id=source_id,
|
||||
file_path=file_path,
|
||||
metadata=metadata, # Add metadata here
|
||||
created_at=int(time.time()),
|
||||
)
|
||||
|
||||
|
|
@ -1591,6 +1604,7 @@ async def merge_nodes_and_edges(
|
|||
current_file_number: int = 0,
|
||||
total_files: int = 0,
|
||||
file_path: str = "unknown_source",
|
||||
metadata: dict | None = None, # Added metadata parameter
|
||||
) -> None:
|
||||
"""Two-phase merge: process all entities first, then all relationships
|
||||
|
||||
|
|
@ -1614,6 +1628,7 @@ async def merge_nodes_and_edges(
|
|||
current_file_number: Current file number for logging
|
||||
total_files: Total files for logging
|
||||
file_path: File path for logging
|
||||
metadata: Document metadata to be attached to entities and relationships
|
||||
"""
|
||||
|
||||
# Collect all nodes and edges from all chunks
|
||||
|
|
@ -1667,6 +1682,7 @@ async def merge_nodes_and_edges(
|
|||
pipeline_status,
|
||||
pipeline_status_lock,
|
||||
llm_response_cache,
|
||||
metadata,
|
||||
)
|
||||
|
||||
# Vector database operation (equally critical, must succeed)
|
||||
|
|
@ -1682,6 +1698,7 @@ async def merge_nodes_and_edges(
|
|||
"file_path": entity_data.get(
|
||||
"file_path", "unknown_source"
|
||||
),
|
||||
"metadata": metadata,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1797,7 +1814,8 @@ async def merge_nodes_and_edges(
|
|||
pipeline_status,
|
||||
pipeline_status_lock,
|
||||
llm_response_cache,
|
||||
added_entities, # Pass list to collect added entities
|
||||
added_entities,
|
||||
metadata,
|
||||
)
|
||||
|
||||
if edge_data is None:
|
||||
|
|
@ -1818,6 +1836,7 @@ async def merge_nodes_and_edges(
|
|||
"file_path", "unknown_source"
|
||||
),
|
||||
"weight": edge_data.get("weight", 1.0),
|
||||
"metadata": metadata,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1967,13 +1986,14 @@ async def merge_nodes_and_edges(
|
|||
pipeline_status["latest_message"] = log_message
|
||||
pipeline_status["history_messages"].append(log_message)
|
||||
|
||||
# Update storage
|
||||
# Update storage with metadata
|
||||
if final_entity_names:
|
||||
await full_entities_storage.upsert(
|
||||
{
|
||||
doc_id: {
|
||||
"entity_names": list(final_entity_names),
|
||||
"count": len(final_entity_names),
|
||||
"metadata": metadata, # Add metadata here
|
||||
}
|
||||
}
|
||||
)
|
||||
|
|
@ -1986,6 +2006,7 @@ async def merge_nodes_and_edges(
|
|||
list(pair) for pair in final_relation_pairs
|
||||
],
|
||||
"count": len(final_relation_pairs),
|
||||
"metadata": metadata, # Add metadata here
|
||||
}
|
||||
}
|
||||
)
|
||||
|
|
@ -2010,6 +2031,7 @@ async def merge_nodes_and_edges(
|
|||
async def extract_entities(
|
||||
chunks: dict[str, TextChunkSchema],
|
||||
global_config: dict[str, str],
|
||||
metadata: dict[str, Any] | None = None,
|
||||
pipeline_status: dict = None,
|
||||
pipeline_status_lock=None,
|
||||
llm_response_cache: BaseKVStorage | None = None,
|
||||
|
|
@ -2047,6 +2069,55 @@ async def extract_entities(
|
|||
processed_chunks = 0
|
||||
total_chunks = len(ordered_chunks)
|
||||
|
||||
async def _process_extraction_result(
|
||||
result: str,
|
||||
chunk_key: str,
|
||||
file_path: str = "unknown_source",
|
||||
metadata: dict[str, Any] | None = None,
|
||||
):
|
||||
"""Process a single extraction result (either initial or gleaning)
|
||||
Args:
|
||||
result (str): The extraction result to process
|
||||
chunk_key (str): The chunk key for source tracking
|
||||
file_path (str): The file path for citation
|
||||
metadata (dict, optional): Additional metadata to include in extracted entities/relationships.
|
||||
Returns:
|
||||
tuple: (nodes_dict, edges_dict) containing the extracted entities and relationships
|
||||
"""
|
||||
maybe_nodes = defaultdict(list)
|
||||
maybe_edges = defaultdict(list)
|
||||
|
||||
records = split_string_by_multi_markers(
|
||||
result,
|
||||
[context_base["record_delimiter"], context_base["completion_delimiter"]],
|
||||
)
|
||||
|
||||
for record in records:
|
||||
record = re.search(r"\((.*)\)", record)
|
||||
if record is None:
|
||||
continue
|
||||
record = record.group(1)
|
||||
record_attributes = split_string_by_multi_markers(
|
||||
record, [context_base["tuple_delimiter"]]
|
||||
)
|
||||
|
||||
if_entities = await _handle_single_entity_extraction(
|
||||
record_attributes, chunk_key, file_path, metadata
|
||||
)
|
||||
if if_entities is not None:
|
||||
maybe_nodes[if_entities["entity_name"]].append(if_entities)
|
||||
continue
|
||||
|
||||
if_relation = await _handle_single_relationship_extraction(
|
||||
record_attributes, chunk_key, file_path, metadata
|
||||
)
|
||||
if if_relation is not None:
|
||||
maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
|
||||
if_relation
|
||||
)
|
||||
|
||||
return maybe_nodes, maybe_edges
|
||||
|
||||
async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
|
||||
"""Process a single chunk
|
||||
Args:
|
||||
|
|
@ -2090,12 +2161,13 @@ async def extract_entities(
|
|||
entity_extraction_user_prompt, final_result
|
||||
)
|
||||
|
||||
# Process initial extraction with file path
|
||||
# Process initial extraction with file path and metadata
|
||||
maybe_nodes, maybe_edges = await _process_extraction_result(
|
||||
final_result,
|
||||
chunk_key,
|
||||
timestamp,
|
||||
file_path,
|
||||
metadata,
|
||||
tuple_delimiter=context_base["tuple_delimiter"],
|
||||
completion_delimiter=context_base["completion_delimiter"],
|
||||
)
|
||||
|
|
@ -2113,12 +2185,14 @@ async def extract_entities(
|
|||
cache_keys_collector=cache_keys_collector,
|
||||
)
|
||||
|
||||
# Process gleaning result separately with file path
|
||||
|
||||
# Process gleaning result separately with file path and metadata
|
||||
glean_nodes, glean_edges = await _process_extraction_result(
|
||||
glean_result,
|
||||
chunk_key,
|
||||
timestamp,
|
||||
file_path,
|
||||
metadata=metadata,
|
||||
tuple_delimiter=context_base["tuple_delimiter"],
|
||||
completion_delimiter=context_base["completion_delimiter"],
|
||||
)
|
||||
|
|
@ -2264,6 +2338,7 @@ async def kg_query(
|
|||
hashing_kv: BaseKVStorage | None = None,
|
||||
system_prompt: str | None = None,
|
||||
chunks_vdb: BaseVectorStorage = None,
|
||||
metadata_filters: list | None = None,
|
||||
return_raw_data: Literal[False] = False,
|
||||
) -> str | AsyncIterator[str]: ...
|
||||
|
||||
|
|
@ -3493,10 +3568,24 @@ async def _get_node_data(
|
|||
# Extract all entity IDs from your results list
|
||||
node_ids = [r["entity_name"] for r in results]
|
||||
|
||||
# HARDCODED QUERY FOR DEBUGGING
|
||||
filter_query = "WHERE (n.class) is not null AND n.class = 'bando' RETURN (n.entity_id) AS entity_id, n.class AS class_value"
|
||||
|
||||
# TODO update method to take in the metadata_filter dataclass
|
||||
node_kg_ids = []
|
||||
if hasattr(knowledge_graph_inst, "get_nodes_by_metadata_filter"):
|
||||
node_kg_ids = await asyncio.gather(
|
||||
knowledge_graph_inst.get_nodes_by_metadata_filter(filter_query)
|
||||
)
|
||||
|
||||
filtered_node_ids = (
|
||||
[nid for nid in node_ids if nid in node_kg_ids] if node_kg_ids else node_ids
|
||||
)
|
||||
|
||||
# Call the batch node retrieval and degree functions concurrently.
|
||||
nodes_dict, degrees_dict = await asyncio.gather(
|
||||
knowledge_graph_inst.get_nodes_batch(node_ids),
|
||||
knowledge_graph_inst.node_degrees_batch(node_ids),
|
||||
knowledge_graph_inst.get_nodes_batch(filtered_node_ids),
|
||||
knowledge_graph_inst.node_degrees_batch(filtered_node_ids),
|
||||
)
|
||||
|
||||
# Now, if you need the node data and degree in order:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue