feat(metadata-filter): add metadata filtering

Added functioning (needs testing) metadata filtering on chunks for
query. Fully implemented only on Postgres with pgvector and Neo4j.
This commit is contained in:
Giulio Grassia 2025-09-24 16:57:44 +02:00
parent 40afb0441a
commit d0fba28e1f
5 changed files with 263 additions and 90 deletions

View file

@ -4,11 +4,11 @@ This module contains all query-related routes for the LightRAG API.
import json
import logging
import textwrap
from typing import Any, Dict, List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException
from lightrag.base import QueryParam, MetadataFilter
from lightrag.base import QueryParam
from lightrag.types import MetadataFilter
from lightrag.api.utils_api import get_combined_auth_dependency
from pydantic import BaseModel, Field, field_validator
@ -24,8 +24,8 @@ class QueryRequest(BaseModel):
)
metadata_filter: MetadataFilter | None = Field(
default=textwrap.dedent('{"operator": "AND","operands": [{"MetadataFilter": {}},"string"]},'),
description="Optional dictionary of metadata key-value pairs to filter nodes",
default=None,
description="Optional metadata filter for nodes and edges. Can be a MetadataFilter object or a dict that will be converted to MetadataFilter.",
)
mode: Literal["local", "global", "hybrid", "naive", "mix", "bypass"] = Field(
@ -118,6 +118,16 @@ class QueryRequest(BaseModel):
)
return conversation_history
@field_validator("metadata_filter", mode="before")
@classmethod
def metadata_filter_convert(cls, v):
    """Coerce incoming payloads into MetadataFilter instances.

    Plain dicts (the JSON wire format) are converted through
    MetadataFilter.from_dict; None and already-built MetadataFilter
    objects pass through unchanged.
    """
    if isinstance(v, dict):
        return MetadataFilter.from_dict(v)
    return v
def to_query_params(self, is_stream: bool) -> "QueryParam":
"""Converts a QueryRequest instance into a QueryParam instance."""
# Use Pydantic's `.model_dump(exclude_none=True)` to remove None values automatically
@ -126,6 +136,11 @@ class QueryRequest(BaseModel):
# Ensure `mode` and `stream` are set explicitly
param = QueryParam(**request_data)
param.stream = is_stream
# Ensure metadata_filter remains as MetadataFilter object if it exists
if self.metadata_filter:
param.metadata_filter = self.metadata_filter
return param
@ -175,10 +190,6 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
try:
param = request.to_query_params(False)
# Inject metadata_filter into param if present
if request.metadata_filter:
setattr(param, "metadata_filter", request.metadata_filter)
response = await rag.aquery(request.query, param=param)
# Get reference list if requested

View file

@ -18,7 +18,7 @@ from typing import (
Union,
)
from .utils import EmbeddingFunc
from .types import KnowledgeGraph
from .types import KnowledgeGraph, MetadataFilter
from .constants import (
GRAPH_FIELD_SEP,
DEFAULT_TOP_K,
@ -40,42 +40,7 @@ from .constants import (
load_dotenv(dotenv_path=".env", override=False)
@dataclass
class MetadataFilter:
"""
Represents a logical expression for metadata filtering.
Args:
operator: "AND", "OR", or "NOT"
operands: List of either simple key-value pairs or nested MetadataFilter objects
"""
operator: str
operands: List[Union[Dict[str, Any], 'MetadataFilter']] = None
def __post_init__(self):
if self.operands is None:
self.operands = []
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary representation."""
return {
"operator": self.operator,
"operands": [
operand.to_dict() if isinstance(operand, MetadataFilter) else operand
for operand in self.operands
]
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'MetadataFilter':
"""Create from dictionary representation."""
operands = []
for operand in data.get("operands", []):
if isinstance(operand, dict) and "operator" in operand:
operands.append(cls.from_dict(operand))
else:
operands.append(operand)
return cls(operator=data.get("operator", "AND"), operands=operands)
@ -266,7 +231,7 @@ class BaseVectorStorage(StorageNameSpace, ABC):
@abstractmethod
async def query(
self, query: str, top_k: int, query_embedding: list[float] = None
self, query: str, top_k: int, query_embedding: list[float] = None, metadata_filter: MetadataFilter | None = None
) -> list[dict[str, Any]]:
"""Query the vector storage and retrieve top_k results.

View file

@ -11,7 +11,7 @@ import configparser
import ssl
import itertools
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge, MetadataFilter
from tenacity import (
retry,
@ -931,6 +931,13 @@ class PostgreSQLDB:
logger.error(
f"PostgreSQL, Failed to create vector index, type: {self.vector_index_type}, Got: {e}"
)
# Compatibility check - add metadata columns to LIGHTRAG_DOC_CHUNKS and LIGHTRAG_VDB_CHUNKS
try:
await self.add_metadata_to_tables()
except Exception as e:
logger.error(f"PostgreSQL, Failed to add metadata columns to existing tables: {e}")
# After all tables are created, attempt to migrate timestamp fields
try:
await self._migrate_timestamp_columns()
@ -1010,6 +1017,54 @@ class PostgreSQLDB:
f"PostgreSQL, Failed to create full entities/relations tables: {e}"
)
async def add_metadata_to_tables(self):
    """Backfill a ``metadata`` JSONB column on legacy chunk tables.

    Older deployments created LIGHTRAG_DOC_CHUNKS / LIGHTRAG_VDB_CHUNKS
    without a metadata column; this adds it (defaulting to '{}') so
    metadata filtering works without a manual migration. Each table is
    handled best-effort: failures are logged as warnings and do not
    abort startup.
    """
    tables_to_check = [
        {
            "name": "LIGHTRAG_DOC_CHUNKS",
            "description": "Document chunks storage table",
        },
        {
            "name": "LIGHTRAG_VDB_CHUNKS",
            "description": "Vector database chunks storage table",
        },
    ]
    for entry in tables_to_check:
        table_name = entry["name"]
        try:
            # information_schema stores unquoted identifiers lower-cased,
            # hence table_name.lower() in the lookup.
            existing = await self.query(
                """
                SELECT column_name
                FROM information_schema.columns
                WHERE table_name = $1
                AND column_name = 'metadata'
                """,
                {"table_name": table_name.lower()},
            )
            if existing:
                logger.debug(
                    f"metadata column already exists in {table_name} table"
                )
                continue
            logger.info(f"Adding metadata column to {table_name} table")
            # table_name comes from the hardcoded list above, so the
            # f-string DDL is not exposed to external input.
            await self.execute(
                f"""
                ALTER TABLE {table_name}
                ADD COLUMN metadata JSONB NULL DEFAULT '{{}}'::jsonb
                """
            )
            logger.info(
                f"Successfully added metadata column to {table_name} table"
            )
        except Exception as e:
            logger.warning(
                f"Failed to add metadata column to {table_name}: {e}"
            )
async def _migrate_create_full_entities_relations_tables(self):
"""Create LIGHTRAG_FULL_ENTITIES and LIGHTRAG_FULL_RELATIONS tables if they don't exist"""
tables_to_check = [
@ -1742,6 +1797,7 @@ class PGKVStorage(BaseKVStorage):
"llm_cache_list": json.dumps(v.get("llm_cache_list", [])),
"create_time": current_time,
"update_time": current_time,
"metadata": json.dumps(v.get("metadata", {})),
}
await self.db.execute(upsert_sql, _data)
elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
@ -1908,6 +1964,7 @@ class PGVectorStorage(BaseVectorStorage):
"file_path": item["file_path"],
"create_time": current_time,
"update_time": current_time,
"metadata": json.dumps(item.get("metadata", {})),
}
except Exception as e:
logger.error(
@ -2002,9 +2059,71 @@ class PGVectorStorage(BaseVectorStorage):
await self.db.execute(upsert_sql, data)
############# Metadata building function #################
@staticmethod
def build_metadata_filter_clause(metadata_filter):
def escape_str(val: str) -> str:
return str(val).replace("'", "''") # escape single quotes
def build_single_condition(key, value):
if isinstance(value, (dict, list)):
json_value = json.dumps(value).replace("'", "''")
return f"metadata->'{key}' = '{json_value}'::jsonb"
elif isinstance(value, (int, float, bool)):
return f"metadata->'{key}' = '{json.dumps(value)}'::jsonb"
else: # string
return f"metadata->>'{key}' = '{escape_str(value)}'"
def build_conditions(filter_dict):
conditions = []
for key, value in filter_dict.items():
if isinstance(value, (list, tuple)):
conds = [build_single_condition(key, v) for v in value]
conditions.append("(" + " OR ".join(conds) + ")")
else:
conditions.append(build_single_condition(key, value))
return conditions
def recurse(filter_obj):
if isinstance(filter_obj, dict):
return build_conditions(filter_obj)
if isinstance(filter_obj, MetadataFilter):
sub_conditions = []
for operand in filter_obj.operands:
if isinstance(operand, dict):
sub_conditions.append("(" + " AND ".join(build_conditions(operand)) + ")")
elif isinstance(operand, MetadataFilter):
nested = recurse(operand)
if nested:
sub_conditions.append("(" + " AND ".join(nested) + ")")
if not sub_conditions:
return []
op = filter_obj.operator.upper()
if op == "AND":
return [" AND ".join(sub_conditions)]
elif op == "OR":
return [" OR ".join(sub_conditions)]
elif op == "NOT":
if len(sub_conditions) == 1:
return [f"NOT {sub_conditions[0]}"]
else:
return [f"NOT ({' AND '.join(sub_conditions)})"]
return []
conditions = recurse(metadata_filter)
clause = ""
if conditions:
clause = " AND " + " AND ".join(conditions)
return clause
#################### query method ###############
async def query(
self, query: str, top_k: int, query_embedding: list[float] = None
self, query: str, top_k: int, query_embedding: list[float] = None, metadata_filter: MetadataFilter | None = None
) -> list[dict[str, Any]]:
if query_embedding is not None:
embedding = query_embedding
@ -2015,8 +2134,8 @@ class PGVectorStorage(BaseVectorStorage):
embedding = embeddings[0]
embedding_string = ",".join(map(str, embedding))
sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
metadata_filter_clause = self.build_metadata_filter_clause(metadata_filter)
sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string, metadata_filter_clause=metadata_filter_clause)
params = {
"workspace": self.workspace,
"closer_than_threshold": 1 - self.cosine_better_than_threshold,
@ -4480,6 +4599,7 @@ TABLES = {
content TEXT,
file_path TEXT NULL,
llm_cache_list JSONB NULL DEFAULT '[]'::jsonb,
metadata JSONB NULL DEFAULT '{}'::jsonb,
create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id)
@ -4495,6 +4615,7 @@ TABLES = {
content TEXT,
content_vector VECTOR({os.environ.get("EMBEDDING_DIM", 1024)}),
file_path TEXT NULL,
metadata JSONB NULL DEFAULT '{{}}'::jsonb,
create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT LIGHTRAG_VDB_CHUNKS_PK PRIMARY KEY (workspace, id)
@ -4656,8 +4777,8 @@ SQL_TEMPLATES = {
""",
"upsert_text_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens,
chunk_order_index, full_doc_id, content, file_path, llm_cache_list,
create_time, update_time)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
create_time, update_time, metadata)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
ON CONFLICT (workspace,id) DO UPDATE
SET tokens=EXCLUDED.tokens,
chunk_order_index=EXCLUDED.chunk_order_index,
@ -4665,7 +4786,8 @@ SQL_TEMPLATES = {
content = EXCLUDED.content,
file_path=EXCLUDED.file_path,
llm_cache_list=EXCLUDED.llm_cache_list,
update_time = EXCLUDED.update_time
update_time = EXCLUDED.update_time,
metadata = EXCLUDED.metadata
""",
"upsert_full_entities": """INSERT INTO LIGHTRAG_FULL_ENTITIES (workspace, id, entity_names, count,
create_time, update_time)
@ -4686,8 +4808,8 @@ SQL_TEMPLATES = {
# SQL for VectorStorage
"upsert_chunk": """INSERT INTO LIGHTRAG_VDB_CHUNKS (workspace, id, tokens,
chunk_order_index, full_doc_id, content, content_vector, file_path,
create_time, update_time)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
create_time, update_time, metadata)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
ON CONFLICT (workspace,id) DO UPDATE
SET tokens=EXCLUDED.tokens,
chunk_order_index=EXCLUDED.chunk_order_index,
@ -4695,7 +4817,8 @@ SQL_TEMPLATES = {
content = EXCLUDED.content,
content_vector=EXCLUDED.content_vector,
file_path=EXCLUDED.file_path,
update_time = EXCLUDED.update_time
update_time = EXCLUDED.update_time,
metadata = EXCLUDED.metadata
""",
"upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
content_vector, chunk_ids, file_path, create_time, update_time)
@ -4721,34 +4844,71 @@ SQL_TEMPLATES = {
update_time = EXCLUDED.update_time
""",
"relationships": """
SELECT r.source_id AS src_id,
r.target_id AS tgt_id,
EXTRACT(EPOCH FROM r.create_time)::BIGINT AS created_at
FROM LIGHTRAG_VDB_RELATION r
WHERE r.workspace = $1
AND r.content_vector <=> '[{embedding_string}]'::vector < $2
ORDER BY r.content_vector <=> '[{embedding_string}]'::vector
LIMIT $3;
WITH relevant_chunks AS (
SELECT id as chunk_id
FROM LIGHTRAG_VDB_CHUNKS
WHERE ($4::varchar[] IS NULL OR full_doc_id = ANY ($4::varchar[]))
{metadata_filter_clause}
),
rc AS (
SELECT array_agg(chunk_id) AS chunk_arr
FROM relevant_chunks
)
SELECT r.source_id AS src_id,
r.target_id AS tgt_id,
EXTRACT(EPOCH FROM r.create_time)::BIGINT AS created_at
FROM LIGHTRAG_VDB_RELATION r
JOIN rc ON TRUE
WHERE r.workspace = $1
AND r.content_vector <=> '[{embedding_string}]'::vector < $2
AND r.chunk_ids && (rc.chunk_arr::varchar[])
ORDER BY r.content_vector <=> '[{embedding_string}]'::vector
LIMIT $3;
""",
"entities": """
WITH relevant_chunks AS (
SELECT id as chunk_id
FROM LIGHTRAG_VDB_CHUNKS
WHERE ($4::varchar[] IS NULL OR full_doc_id = ANY ($4::varchar[]))
{metadata_filter_clause}
),
rc AS (
SELECT array_agg(chunk_id) AS chunk_arr
FROM relevant_chunks
)
SELECT e.entity_name,
EXTRACT(EPOCH FROM e.create_time)::BIGINT AS created_at
FROM LIGHTRAG_VDB_ENTITY e
JOIN rc ON TRUE
WHERE e.workspace = $1
AND e.content_vector <=> '[{embedding_string}]'::vector < $2
AND e.chunk_ids && (rc.chunk_arr::varchar[])
ORDER BY e.content_vector <=> '[{embedding_string}]'::vector
LIMIT $3;
""",
"chunks": """
SELECT c.id,
c.content,
c.file_path,
EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at
FROM LIGHTRAG_VDB_CHUNKS c
WHERE c.workspace = $1
AND c.content_vector <=> '[{embedding_string}]'::vector < $2
ORDER BY c.content_vector <=> '[{embedding_string}]'::vector
LIMIT $3;
WITH relevant_chunks AS (
SELECT id as chunk_id
FROM LIGHTRAG_VDB_CHUNKS
WHERE ($4::varchar[] IS NULL OR full_doc_id = ANY ($4::varchar[]))
{metadata_filter_clause}
),
rc AS (
SELECT array_agg(chunk_id) AS chunk_arr
FROM relevant_chunks
)
SELECT c.id,
c.content,
c.file_path,
EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at,
c.metadata
FROM LIGHTRAG_VDB_CHUNKS c
JOIN rc ON TRUE
WHERE c.workspace = $1
AND c.content_vector <=> '[{embedding_string}]'::vector < $2
AND c.id = ANY (rc.chunk_arr)
ORDER BY c.content_vector <=> '[{embedding_string}]'::vector
LIMIT $3;
""",
# DROP tables
"drop_specifiy_table_workspace": """

View file

@ -41,7 +41,6 @@ from .base import (
QueryParam,
QueryResult,
QueryContextResult,
MetadataFilter
)
from .prompt import PROMPTS
from .constants import (
@ -2409,6 +2408,7 @@ async def kg_query(
query_param.ll_keywords or [],
query_param.user_prompt or "",
query_param.enable_rerank,
query_param.metadata_filter,
)
cached_result = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@ -2720,7 +2720,7 @@ async def _get_vector_context(
cosine_threshold = chunks_vdb.cosine_better_than_threshold
results = await chunks_vdb.query(
query, top_k=search_top_k, query_embedding=query_embedding
query, top_k=search_top_k, query_embedding=query_embedding, metadata_filter=query_param.metadata_filter
)
if not results:
logger.info(
@ -2737,6 +2737,7 @@ async def _get_vector_context(
"file_path": result.get("file_path", "unknown_source"),
"source_type": "vector", # Mark the source type
"chunk_id": result.get("id"), # Add chunk_id for deduplication
"metadata": result.get("metadata")
}
valid_chunks.append(chunk_with_metadata)
@ -3561,7 +3562,8 @@ async def _get_node_data(
f"Query nodes: {query} (top_k:{query_param.top_k}, cosine:{entities_vdb.cosine_better_than_threshold})"
)
results = await entities_vdb.query(query, top_k=query_param.top_k)
results = await entities_vdb.query(query, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter)
if not len(results):
return [], []
@ -3570,21 +3572,13 @@ async def _get_node_data(
node_ids = [r["entity_name"] for r in results]
# TODO update method to take in the metadata_filter dataclass
node_kg_ids = []
if hasattr(knowledge_graph_inst, "get_nodes_by_metadata_filter"):
node_kg_ids = await asyncio.gather(
knowledge_graph_inst.get_nodes_by_metadata_filter(QueryParam.metadata_filter)
)
filtered_node_ids = (
[nid for nid in node_ids if nid in node_kg_ids] if node_kg_ids else node_ids
)
# Extract all entity IDs from your results list
node_ids = [r["entity_name"] for r in results]
# Call the batch node retrieval and degree functions concurrently.
nodes_dict, degrees_dict = await asyncio.gather(
knowledge_graph_inst.get_nodes_batch(filtered_node_ids),
knowledge_graph_inst.node_degrees_batch(filtered_node_ids),
knowledge_graph_inst.get_nodes_batch(node_ids),
knowledge_graph_inst.node_degrees_batch(node_ids),
)
# Now, if you need the node data and degree in order:
@ -3849,7 +3843,7 @@ async def _get_edge_data(
f"Query edges: {keywords} (top_k:{query_param.top_k}, cosine:{relationships_vdb.cosine_better_than_threshold})"
)
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
results = await relationships_vdb.query(keywords, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter)
if not len(results):
return [], []

View file

@ -1,7 +1,7 @@
from __future__ import annotations
from pydantic import BaseModel
from typing import Any, Optional
from pydantic import BaseModel, Field, validator
from typing import Any, Optional, List, Union, Dict
class GPTKeywordExtractionFormat(BaseModel):
@ -27,3 +27,46 @@ class KnowledgeGraph(BaseModel):
nodes: list[KnowledgeGraphNode] = []
edges: list[KnowledgeGraphEdge] = []
is_truncated: bool = False
class MetadataFilter(BaseModel):
    """Logical expression tree for metadata filtering.

    Attributes:
        operator: "AND", "OR", or "NOT".
        operands: list of plain key/value dicts (leaf conditions) or nested
            MetadataFilter nodes.
    """

    operator: str = Field(..., description="Logical operator: AND, OR, or NOT")
    operands: List[Union[Dict[str, Any], 'MetadataFilter']] = Field(
        default_factory=list, description="List of operands for filtering"
    )

    # NOTE(review): @validator is pydantic-v1 API; the API layer already uses
    # v2's field_validator — consider migrating this model as well.
    @validator('operator')
    def validate_operator(cls, v):
        """Reject any operator outside the supported AND/OR/NOT set."""
        if v not in ["AND", "OR", "NOT"]:
            raise ValueError('operator must be one of: "AND", "OR", "NOT"')
        return v

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary representation (inverse of from_dict)."""
        return {
            "operator": self.operator,
            "operands": [
                # Recurse via to_dict (not pydantic v1's deprecated .dict())
                # so nested filters serialize exactly as from_dict expects.
                operand.to_dict() if isinstance(operand, MetadataFilter) else operand
                for operand in self.operands
            ],
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'MetadataFilter':
        """Create a filter tree from its dictionary representation.

        Dict operands containing an "operator" key are treated as nested
        filters; anything else is kept as a leaf condition. Defaults to the
        "AND" operator when none is given.
        """
        operands: List[Union[Dict[str, Any], 'MetadataFilter']] = []
        for operand in data.get("operands", []):
            if isinstance(operand, dict) and "operator" in operand:
                operands.append(cls.from_dict(operand))
            else:
                operands.append(operand)
        return cls(operator=data.get("operator", "AND"), operands=operands)

    class Config:
        """Pydantic configuration."""

        # Re-run validators when fields are assigned after construction.
        validate_assignment = True