feat(metadata): add metadata filter to query

Added a metadata filter dataclass for serializing and deserializing
complex filters to/from JSON dicts, and added node filtering based on metadata.
This commit is contained in:
Giulio Grassia 2025-09-19 15:43:48 +02:00
parent 7be24a3c60
commit 0c721fa7f1
6 changed files with 143 additions and 33 deletions

2
.gitignore vendored
View file

@ -65,6 +65,8 @@ LightRAG.pdf
download_models_hf.py download_models_hf.py
lightrag-dev/ lightrag-dev/
gui/ gui/
/md
/uv.lock
# unit-test files # unit-test files
test_* test_*

View file

@ -4,10 +4,11 @@ This module contains all query-related routes for the LightRAG API.
import json import json
import logging import logging
import textwrap
from typing import Any, Dict, List, Literal, Optional from typing import Any, Dict, List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from lightrag.base import QueryParam from lightrag.base import QueryParam, MetadataFilter
from lightrag.api.utils_api import get_combined_auth_dependency from lightrag.api.utils_api import get_combined_auth_dependency
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
@ -22,8 +23,8 @@ class QueryRequest(BaseModel):
description="The query text", description="The query text",
) )
metadata_filter: dict[str, str] | None = Field( metadata_filter: MetadataFilter | None = Field(
default=None, default=textwrap.dedent('{"operator": "AND","operands": [{"MetadataFilter": {}},"string"]},'),
description="Optional dictionary of metadata key-value pairs to filter nodes", description="Optional dictionary of metadata key-value pairs to filter nodes",
) )
@ -79,7 +80,7 @@ class QueryRequest(BaseModel):
) )
conversation_history: Optional[List[Dict[str, Any]]] = Field( conversation_history: Optional[List[Dict[str, Any]]] = Field(
default=None, default=[],
description="Stores past conversation history to maintain context. Format: [{'role': 'user/assistant', 'content': 'message'}].", description="Stores past conversation history to maintain context. Format: [{'role': 'user/assistant', 'content': 'message'}].",
) )

View file

@ -15,6 +15,7 @@ from typing import (
Dict, Dict,
List, List,
AsyncIterator, AsyncIterator,
Union,
) )
from .utils import EmbeddingFunc from .utils import EmbeddingFunc
from .types import KnowledgeGraph from .types import KnowledgeGraph
@ -39,6 +40,46 @@ from .constants import (
load_dotenv(dotenv_path=".env", override=False) load_dotenv(dotenv_path=".env", override=False)
@dataclass
class MetadataFilter:
    """
    Represents a logical expression for metadata filtering.

    A filter is a tree: each node carries a boolean ``operator`` and a list
    of ``operands`` that are either plain ``{key: value}`` dicts (leaf
    conditions) or nested ``MetadataFilter`` instances.

    Args:
        operator: "AND", "OR", or "NOT"
        operands: List of either simple key-value pairs or nested
            MetadataFilter objects. ``None`` is accepted and normalized
            to an empty list.
    """

    operator: str
    # Fixed annotation: the default is None, so the declared type must admit
    # None as well; __post_init__ normalizes it to an empty list.
    operands: Union[List[Union[Dict[str, Any], "MetadataFilter"]], None] = None

    def __post_init__(self) -> None:
        # Normalize the None sentinel so the rest of the class can always
        # iterate self.operands safely.
        if self.operands is None:
            self.operands = []

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a JSON-serializable dictionary representation."""
        return {
            "operator": self.operator,
            "operands": [
                operand.to_dict() if isinstance(operand, MetadataFilter) else operand
                for operand in self.operands
            ],
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MetadataFilter":
        """Create a MetadataFilter from its dictionary representation.

        Any operand dict containing an "operator" key is treated as a nested
        filter and parsed recursively; other dicts are kept as leaf
        key-value conditions. A missing operator defaults to "AND".
        """
        operands: List[Union[Dict[str, Any], "MetadataFilter"]] = []
        for operand in data.get("operands", []):
            if isinstance(operand, dict) and "operator" in operand:
                operands.append(cls.from_dict(operand))
            else:
                operands.append(operand)
        return cls(operator=data.get("operator", "AND"), operands=operands)
class OllamaServerInfos: class OllamaServerInfos:
def __init__(self, name=None, tag=None): def __init__(self, name=None, tag=None):
self._lightrag_name = name or os.getenv( self._lightrag_name = name or os.getenv(
@ -162,15 +203,15 @@ class QueryParam:
Default is True to enable reranking when rerank model is available. Default is True to enable reranking when rerank model is available.
""" """
metadata_filter: MetadataFilter | None = None
"""Metadata for filtering nodes and edges, allowing for more precise querying."""
include_references: bool = False include_references: bool = False
"""If True, includes reference list in the response for supported endpoints. """If True, includes reference list in the response for supported endpoints.
This parameter controls whether the API response includes a references field This parameter controls whether the API response includes a references field
containing citation information for the retrieved content. containing citation information for the retrieved content.
""" """
metadata_filter: dict | None = None
"""Metadata for filtering nodes and edges, allowing for more precise querying."""
@dataclass @dataclass
class StorageNameSpace(ABC): class StorageNameSpace(ABC):
@ -447,11 +488,11 @@ class BaseGraphStorage(StorageNameSpace, ABC):
or None if the node doesn't exist or None if the node doesn't exist
""" """
async def get_nodes_by_metadata_filter(self, metadata_filter: str) -> list[str]: async def get_nodes_by_metadata_filter(self, metadata_filter: MetadataFilter | None) -> list[str]:
result = [] """Get node IDs that match the given metadata filter with logical expressions."""
text_query = f"MATCH (n:`{self.workspace_label}` {metadata_filter})" # Default implementation - subclasses should override this method
result = await self.query(text_query) # This is a placeholder that will be overridden by specific implementations
return result raise NotImplementedError("Subclasses must implement get_nodes_by_metadata_filter")
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
"""Get nodes as a batch using UNWIND """Get nodes as a batch using UNWIND

View file

@ -2,10 +2,11 @@ import os
import re import re
import json import json
from dataclasses import dataclass from dataclasses import dataclass
from typing import final from typing import final, Any
import configparser import configparser
from tenacity import ( from tenacity import (
retry, retry,
stop_after_attempt, stop_after_attempt,
@ -15,7 +16,7 @@ from tenacity import (
import logging import logging
from ..utils import logger from ..utils import logger
from ..base import BaseGraphStorage from ..base import BaseGraphStorage, MetadataFilter
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP from ..constants import GRAPH_FIELD_SEP
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
@ -425,23 +426,88 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() # Ensure results are consumed even on error await result.consume() # Ensure results are consumed even on error
raise raise
async def get_nodes_by_metadata_filter(self, query: str) -> list[str]: async def get_nodes_by_metadata_filter(self, metadata_filter: MetadataFilter | None) -> list[str]:
"""Get nodes by filtering query, return nodes id""" """Get node IDs that match the given metadata filter with logical expressions."""
workspace_label = self._get_workspace_label() workspace_label = self._get_workspace_label()
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" # Build metadata conditions
) as session: params = {}
try: condition, params = self._build_metadata_conditions(metadata_filter, params)
query = f"MATCH (n:`{workspace_label}`) {query}"
result = await session.run(query) if not condition:
debug_results = [record async for record in result] # If no condition, return empty list for safety
debug_ids = [r["entity_id"] for r in debug_results] return []
await result.consume()
return debug_ids # Build the query
except Exception as e: query = f"""
logger.error(f"[{self.workspace}] Error getting node for: {str(e)}") MATCH (n:`{workspace_label}`)
raise WHERE {condition}
RETURN n.entity_id AS entity_id
"""
async with self._driver.session(database=self._DATABASE) as session:
result = await session.run(query, params)
return [record["entity_id"] async for record in result]
def _build_metadata_conditions(
self,
metadata_filter: MetadataFilter | None,
params: dict[str, Any],
node_var: str = "n"
) -> tuple[str, dict[str, Any]]:
"""
Build Cypher WHERE conditions from a MetadataFilter.
Args:
metadata_filter: The MetadataFilter object
params: Dictionary to collect parameters for the query
node_var: The variable name for the node in the Cypher query
Returns:
Tuple of (condition_string, updated_params)
"""
if metadata_filter is None:
return "", params
conditions = []
for operand in metadata_filter.operands:
if isinstance(operand, MetadataFilter):
# Recursive call for nested filters
sub_condition, params = self._build_metadata_conditions(operand, params, node_var)
if sub_condition:
conditions.append(f"({sub_condition})")
else:
# Simple key-value pair
for key, value in operand.items():
prop_name = f"meta_{key}" # Using our prefix
param_name = f"{prop_name}_{len(params)}"
if value is None:
# Check for existence of the key
conditions.append(f"{node_var}.{prop_name} IS NOT NULL")
else:
# Check for specific value
conditions.append(f"{node_var}.{prop_name} = ${param_name}")
params[param_name] = value
if not conditions:
return "", params
# Join conditions with the operator
if metadata_filter.operator == "AND":
condition = " AND ".join(conditions)
elif metadata_filter.operator == "OR":
condition = " OR ".join(conditions)
elif metadata_filter.operator == "NOT":
if len(conditions) == 1:
condition = f"NOT ({conditions[0]})"
else:
condition = f"NOT ({' AND '.join(conditions)})"
else:
raise ValueError(f"Unknown operator: {metadata_filter.operator}")
return condition, params
async def get_node(self, node_id: str) -> dict[str, str] | None: async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node by its label identifier, return only node properties """Get node by its label identifier, return only node properties

View file

@ -41,6 +41,7 @@ from .base import (
QueryParam, QueryParam,
QueryResult, QueryResult,
QueryContextResult, QueryContextResult,
MetadataFilter
) )
from .prompt import PROMPTS from .prompt import PROMPTS
from .constants import ( from .constants import (
@ -3568,14 +3569,12 @@ async def _get_node_data(
# Extract all entity IDs from your results list # Extract all entity IDs from your results list
node_ids = [r["entity_name"] for r in results] node_ids = [r["entity_name"] for r in results]
# HARDCODED QUERY FOR DEBUGGING
filter_query = "WHERE (n.class) is not null AND n.class = 'bando' RETURN (n.entity_id) AS entity_id, n.class AS class_value"
# TODO update method to take in the metadata_filter dataclass # TODO update method to take in the metadata_filter dataclass
node_kg_ids = [] node_kg_ids = []
if hasattr(knowledge_graph_inst, "get_nodes_by_metadata_filter"): if hasattr(knowledge_graph_inst, "get_nodes_by_metadata_filter"):
node_kg_ids = await asyncio.gather( node_kg_ids = await asyncio.gather(
knowledge_graph_inst.get_nodes_by_metadata_filter(filter_query) knowledge_graph_inst.get_nodes_by_metadata_filter(QueryParam.metadata_filter)
) )
filtered_node_ids = ( filtered_node_ids = (

View file

@ -27,6 +27,7 @@ dependencies = [
"future", "future",
"json_repair", "json_repair",
"nano-vectordb", "nano-vectordb",
"neo4j>=5.28.2",
"networkx", "networkx",
"numpy", "numpy",
"pandas>=2.0.0", "pandas>=2.0.0",