feat(metadata): add metadata filter to query

Added a MetadataFilter dataclass for serializing complex filters to and
deserializing them from JSON dicts, and added node filtering based on metadata.
This commit is contained in:
Giulio Grassia 2025-09-19 15:43:48 +02:00
parent 7be24a3c60
commit 0c721fa7f1
6 changed files with 143 additions and 33 deletions

2
.gitignore vendored
View file

@ -65,6 +65,8 @@ LightRAG.pdf
download_models_hf.py
lightrag-dev/
gui/
/md
/uv.lock
# unit-test files
test_*

View file

@ -4,10 +4,11 @@ This module contains all query-related routes for the LightRAG API.
import json
import logging
import textwrap
from typing import Any, Dict, List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException
from lightrag.base import QueryParam
from lightrag.base import QueryParam, MetadataFilter
from lightrag.api.utils_api import get_combined_auth_dependency
from pydantic import BaseModel, Field, field_validator
@ -22,8 +23,8 @@ class QueryRequest(BaseModel):
description="The query text",
)
metadata_filter: dict[str, str] | None = Field(
default=None,
metadata_filter: MetadataFilter | None = Field(
default=textwrap.dedent('{"operator": "AND","operands": [{"MetadataFilter": {}},"string"]},'),
description="Optional dictionary of metadata key-value pairs to filter nodes",
)
@ -79,7 +80,7 @@ class QueryRequest(BaseModel):
)
conversation_history: Optional[List[Dict[str, Any]]] = Field(
default=None,
default=[],
description="Stores past conversation history to maintain context. Format: [{'role': 'user/assistant', 'content': 'message'}].",
)

View file

@ -15,6 +15,7 @@ from typing import (
Dict,
List,
AsyncIterator,
Union,
)
from .utils import EmbeddingFunc
from .types import KnowledgeGraph
@ -39,6 +40,46 @@ from .constants import (
load_dotenv(dotenv_path=".env", override=False)
@dataclass
class MetadataFilter:
    """
    Represents a logical expression for metadata filtering.

    Filters form a tree: each node has a boolean ``operator`` and a list of
    ``operands`` that are either simple ``{key: value}`` dicts or nested
    MetadataFilter objects.

    Args:
        operator: "AND", "OR", or "NOT"
        operands: List of either simple key-value pairs or nested
            MetadataFilter objects. Defaults to an empty list.
    """

    operator: str
    # default_factory gives each instance its own list and fixes the field's
    # type (the original declared `= None`, contradicting the annotation).
    operands: List[Union[Dict[str, Any], "MetadataFilter"]] = field(default_factory=list)

    def __post_init__(self) -> None:
        # Backward compatibility: callers that explicitly pass operands=None
        # still get an empty list.
        if self.operands is None:
            self.operands = []

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a JSON-serializable dictionary representation."""
        return {
            "operator": self.operator,
            "operands": [
                operand.to_dict() if isinstance(operand, MetadataFilter) else operand
                for operand in self.operands
            ],
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MetadataFilter":
        """Create a MetadataFilter from its dictionary representation.

        Any operand dict containing an "operator" key is treated as a nested
        filter and parsed recursively; all other operands are kept verbatim.
        Missing "operator" at the top level defaults to "AND".
        """
        operands: List[Union[Dict[str, Any], "MetadataFilter"]] = []
        for operand in data.get("operands", []):
            if isinstance(operand, dict) and "operator" in operand:
                operands.append(cls.from_dict(operand))
            else:
                operands.append(operand)
        return cls(operator=data.get("operator", "AND"), operands=operands)
class OllamaServerInfos:
def __init__(self, name=None, tag=None):
self._lightrag_name = name or os.getenv(
@ -162,15 +203,15 @@ class QueryParam:
Default is True to enable reranking when rerank model is available.
"""
metadata_filter: MetadataFilter | None = None
"""Metadata for filtering nodes and edges, allowing for more precise querying."""
include_references: bool = False
"""If True, includes reference list in the response for supported endpoints.
This parameter controls whether the API response includes a references field
containing citation information for the retrieved content.
"""
metadata_filter: dict | None = None
"""Metadata for filtering nodes and edges, allowing for more precise querying."""
@dataclass
class StorageNameSpace(ABC):
@ -447,11 +488,11 @@ class BaseGraphStorage(StorageNameSpace, ABC):
or None if the node doesn't exist
"""
async def get_nodes_by_metadata_filter(self, metadata_filter: str) -> list[str]:
result = []
text_query = f"MATCH (n:`{self.workspace_label}` {metadata_filter})"
result = await self.query(text_query)
return result
async def get_nodes_by_metadata_filter(self, metadata_filter: MetadataFilter | None) -> list[str]:
    """Get node IDs that match the given metadata filter with logical expressions.

    Args:
        metadata_filter: Tree of AND/OR/NOT conditions over node metadata,
            or None.

    Returns:
        List of entity IDs for matching nodes.

    Raises:
        NotImplementedError: Always; storage backends must override this.
    """
    # Intentionally unimplemented in the base class: each graph storage
    # backend translates the filter into its own query language.
    raise NotImplementedError("Subclasses must implement get_nodes_by_metadata_filter")
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
"""Get nodes as a batch using UNWIND

View file

@ -2,10 +2,11 @@ import os
import re
import json
from dataclasses import dataclass
from typing import final
from typing import final, Any
import configparser
from tenacity import (
retry,
stop_after_attempt,
@ -15,7 +16,7 @@ from tenacity import (
import logging
from ..utils import logger
from ..base import BaseGraphStorage
from ..base import BaseGraphStorage, MetadataFilter
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
@ -425,23 +426,88 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() # Ensure results are consumed even on error
raise
async def get_nodes_by_metadata_filter(self, query: str) -> list[str]:
"""Get nodes by filtering query, return nodes id"""
async def get_nodes_by_metadata_filter(self, metadata_filter: MetadataFilter | None) -> list[str]:
    """Get node IDs that match the given metadata filter with logical expressions.

    Args:
        metadata_filter: Tree of AND/OR/NOT conditions over node metadata
            properties, or None.

    Returns:
        List of entity_id values for matching nodes. An empty or None filter
        yields an empty list (never "match everything") for safety.
    """
    workspace_label = self._get_workspace_label()

    # Translate the filter tree into a Cypher WHERE clause plus parameters.
    condition, params = self._build_metadata_conditions(metadata_filter, {})
    if not condition:
        # No usable condition: return nothing rather than all nodes.
        return []

    query = f"""
    MATCH (n:`{workspace_label}`)
    WHERE {condition}
    RETURN n.entity_id AS entity_id
    """

    # READ access mode lets the driver route this query to a read replica,
    # consistent with the other read-only queries in this class.
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        result = await session.run(query, params)
        entity_ids = [record["entity_id"] async for record in result]
        await result.consume()  # release result resources, as other queries here do
        return entity_ids
def _build_metadata_conditions(
    self,
    metadata_filter: MetadataFilter | None,
    params: dict[str, Any],
    node_var: str = "n",
) -> tuple[str, dict[str, Any]]:
    """
    Build Cypher WHERE conditions from a MetadataFilter.

    Recursively walks the filter tree, producing one parenthesized
    sub-condition per nested filter and one comparison per key-value pair.
    Values are always passed as bound query parameters; keys are validated
    before interpolation because Cypher cannot parameterize property names.

    Args:
        metadata_filter: The MetadataFilter object (or None for no filter)
        params: Dictionary collecting query parameters; mutated in place
        node_var: The variable name for the node in the Cypher query

    Returns:
        Tuple of (condition_string, updated_params). The condition string is
        empty when the filter is None or produces no conditions.

    Raises:
        ValueError: If a metadata key is not a safe identifier, or the
            operator is not one of "AND", "OR", "NOT".
    """
    if metadata_filter is None:
        return "", params

    conditions = []
    for operand in metadata_filter.operands:
        if isinstance(operand, MetadataFilter):
            # Nested filter: recurse and parenthesize to preserve precedence.
            sub_condition, params = self._build_metadata_conditions(
                operand, params, node_var
            )
            if sub_condition:
                conditions.append(f"({sub_condition})")
        else:
            # Simple key-value pair(s).
            for key, value in operand.items():
                # Keys originate from client requests and are interpolated
                # into the query text, so restrict them to word characters
                # to prevent Cypher injection via crafted property names.
                if not re.fullmatch(r"\w+", key):
                    raise ValueError(f"Invalid metadata key: {key!r}")
                prop_name = f"meta_{key}"  # metadata properties use this prefix
                # len(params) makes parameter names unique across the tree.
                param_name = f"{prop_name}_{len(params)}"
                if value is None:
                    # None means "key exists", regardless of its value.
                    conditions.append(f"{node_var}.{prop_name} IS NOT NULL")
                else:
                    # Compare against a bound parameter, never inline the value.
                    conditions.append(f"{node_var}.{prop_name} = ${param_name}")
                    params[param_name] = value

    if not conditions:
        return "", params

    # Combine sub-conditions with the filter's boolean operator.
    if metadata_filter.operator == "AND":
        condition = " AND ".join(conditions)
    elif metadata_filter.operator == "OR":
        condition = " OR ".join(conditions)
    elif metadata_filter.operator == "NOT":
        # NOT over multiple operands negates their conjunction.
        if len(conditions) == 1:
            condition = f"NOT ({conditions[0]})"
        else:
            condition = f"NOT ({' AND '.join(conditions)})"
    else:
        raise ValueError(f"Unknown operator: {metadata_filter.operator}")

    return condition, params
async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node by its label identifier, return only node properties

View file

@ -41,6 +41,7 @@ from .base import (
QueryParam,
QueryResult,
QueryContextResult,
MetadataFilter
)
from .prompt import PROMPTS
from .constants import (
@ -3568,14 +3569,12 @@ async def _get_node_data(
# Extract all entity IDs from your results list
node_ids = [r["entity_name"] for r in results]
# HARDCODED QUERY FOR DEBUGGING
filter_query = "WHERE (n.class) is not null AND n.class = 'bando' RETURN (n.entity_id) AS entity_id, n.class AS class_value"
# TODO update method to take in the metadata_filter dataclass
node_kg_ids = []
if hasattr(knowledge_graph_inst, "get_nodes_by_metadata_filter"):
node_kg_ids = await asyncio.gather(
knowledge_graph_inst.get_nodes_by_metadata_filter(filter_query)
knowledge_graph_inst.get_nodes_by_metadata_filter(QueryParam.metadata_filter)
)
filtered_node_ids = (

View file

@ -27,6 +27,7 @@ dependencies = [
"future",
"json_repair",
"nano-vectordb",
"neo4j>=5.28.2",
"networkx",
"numpy",
"pandas>=2.0.0",