From d0fba28e1f9da4392b6feff288f4f457b203e781 Mon Sep 17 00:00:00 2001 From: Giulio Grassia Date: Wed, 24 Sep 2025 16:57:44 +0200 Subject: [PATCH] feat (metadata filter): added metadata filtering Added functioning (needs testing) metadata filtering on chunks for query. Fully implemented only on Postgres with pgvector and Neo4j --- lightrag/api/routers/query_routes.py | 27 +++- lightrag/base.py | 39 +---- lightrag/kg/postgres_impl.py | 214 +++++++++++++++++++++++---- lightrag/operate.py | 26 ++-- lightrag/types.py | 47 +++++- 5 files changed, 263 insertions(+), 90 deletions(-) diff --git a/lightrag/api/routers/query_routes.py b/lightrag/api/routers/query_routes.py index e99294b5..a9f62215 100644 --- a/lightrag/api/routers/query_routes.py +++ b/lightrag/api/routers/query_routes.py @@ -4,11 +4,11 @@ This module contains all query-related routes for the LightRAG API. import json import logging -import textwrap from typing import Any, Dict, List, Literal, Optional from fastapi import APIRouter, Depends, HTTPException -from lightrag.base import QueryParam, MetadataFilter +from lightrag.base import QueryParam +from lightrag.types import MetadataFilter from lightrag.api.utils_api import get_combined_auth_dependency from pydantic import BaseModel, Field, field_validator @@ -24,8 +24,8 @@ class QueryRequest(BaseModel): ) metadata_filter: MetadataFilter | None = Field( - default=textwrap.dedent('{"operator": "AND","operands": [{"MetadataFilter": {}},"string"]},'), - description="Optional dictionary of metadata key-value pairs to filter nodes", + default=None, + description="Optional metadata filter for nodes and edges. 
@field_validator("metadata_filter", mode="before")
@classmethod
def metadata_filter_convert(cls, v):
    """Coerce a raw ``dict`` payload into a ``MetadataFilter`` before validation.

    Any other value -- including ``None`` and an already-built filter
    object -- is handed back untouched.
    """
    return MetadataFilter.from_dict(v) if isinstance(v, dict) else v
- Args: - operator: "AND", "OR", or "NOT" - operands: List of either simple key-value pairs or nested MetadataFilter objects - """ - operator: str - operands: List[Union[Dict[str, Any], 'MetadataFilter']] = None - - def __post_init__(self): - if self.operands is None: - self.operands = [] - - def to_dict(self) -> Dict[str, Any]: - """Convert to dictionary representation.""" - return { - "operator": self.operator, - "operands": [ - operand.to_dict() if isinstance(operand, MetadataFilter) else operand - for operand in self.operands - ] - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'MetadataFilter': - """Create from dictionary representation.""" - operands = [] - for operand in data.get("operands", []): - if isinstance(operand, dict) and "operator" in operand: - operands.append(cls.from_dict(operand)) - else: - operands.append(operand) - return cls(operator=data.get("operator", "AND"), operands=operands) @@ -266,7 +231,7 @@ class BaseVectorStorage(StorageNameSpace, ABC): @abstractmethod async def query( - self, query: str, top_k: int, query_embedding: list[float] = None + self, query: str, top_k: int, query_embedding: list[float] = None, metadata_filter: MetadataFilter | None = None ) -> list[dict[str, Any]]: """Query the vector storage and retrieve top_k results. 
async def add_metadata_to_tables(self):
    """Add a ``metadata`` JSONB column to the chunk tables when it is missing.

    Handles LIGHTRAG_DOC_CHUNKS and LIGHTRAG_VDB_CHUNKS. Each table is
    processed independently: a failure on one table is logged as a warning
    and does not prevent the other from being migrated.
    """
    check_column_sql = """
        SELECT column_name
        FROM information_schema.columns
        WHERE table_name = $1
        AND column_name = 'metadata'
    """

    for table_name in ("LIGHTRAG_DOC_CHUNKS", "LIGHTRAG_VDB_CHUNKS"):
        try:
            column_info = await self.query(
                check_column_sql, {"table_name": table_name.lower()}
            )
            if column_info:
                logger.debug(f"metadata column already exists in {table_name} table")
                continue

            logger.info(f"Adding metadata column to {table_name} table")
            # IF NOT EXISTS keeps the statement idempotent: the check above
            # and this ALTER are not atomic, so another process may add the
            # column in between.
            await self.execute(
                f"ALTER TABLE {table_name} "
                "ADD COLUMN IF NOT EXISTS metadata JSONB NULL DEFAULT '{}'::jsonb"
            )
            logger.info(f"Successfully added metadata column to {table_name} table")
        except Exception as e:
            # Best-effort migration: report the problem and move on.
            logger.warning(f"Failed to add metadata column to {table_name}: {e}")
self.execute(add_column_sql) + logger.info( + f"Successfully added metadata column to {table_name} table" + ) + else: + logger.debug( + f"metadata column already exists in {table_name} table" + ) + + except Exception as e: + logger.warning( + f"Failed to add metadata column to {table_name}: {e}" + ) + async def _migrate_create_full_entities_relations_tables(self): """Create LIGHTRAG_FULL_ENTITIES and LIGHTRAG_FULL_RELATIONS tables if they don't exist""" tables_to_check = [ @@ -1742,6 +1797,7 @@ class PGKVStorage(BaseKVStorage): "llm_cache_list": json.dumps(v.get("llm_cache_list", [])), "create_time": current_time, "update_time": current_time, + "metadata": json.dumps(v.get("metadata", {})), } await self.db.execute(upsert_sql, _data) elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS): @@ -1908,6 +1964,7 @@ class PGVectorStorage(BaseVectorStorage): "file_path": item["file_path"], "create_time": current_time, "update_time": current_time, + "metadata": json.dumps(item.get("metadata", {})), } except Exception as e: logger.error( @@ -2002,9 +2059,71 @@ class PGVectorStorage(BaseVectorStorage): await self.db.execute(upsert_sql, data) + ############# Metadata building function ################# + @staticmethod + def build_metadata_filter_clause(metadata_filter): + def escape_str(val: str) -> str: + return str(val).replace("'", "''") # escape single quotes + + def build_single_condition(key, value): + if isinstance(value, (dict, list)): + json_value = json.dumps(value).replace("'", "''") + return f"metadata->'{key}' = '{json_value}'::jsonb" + elif isinstance(value, (int, float, bool)): + return f"metadata->'{key}' = '{json.dumps(value)}'::jsonb" + else: # string + return f"metadata->>'{key}' = '{escape_str(value)}'" + + def build_conditions(filter_dict): + conditions = [] + for key, value in filter_dict.items(): + if isinstance(value, (list, tuple)): + conds = [build_single_condition(key, v) for v in value] + conditions.append("(" + " OR ".join(conds) + ")") + 
else: + conditions.append(build_single_condition(key, value)) + return conditions + + def recurse(filter_obj): + if isinstance(filter_obj, dict): + return build_conditions(filter_obj) + + if isinstance(filter_obj, MetadataFilter): + sub_conditions = [] + for operand in filter_obj.operands: + if isinstance(operand, dict): + sub_conditions.append("(" + " AND ".join(build_conditions(operand)) + ")") + elif isinstance(operand, MetadataFilter): + nested = recurse(operand) + if nested: + sub_conditions.append("(" + " AND ".join(nested) + ")") + + if not sub_conditions: + return [] + + op = filter_obj.operator.upper() + if op == "AND": + return [" AND ".join(sub_conditions)] + elif op == "OR": + return [" OR ".join(sub_conditions)] + elif op == "NOT": + if len(sub_conditions) == 1: + return [f"NOT {sub_conditions[0]}"] + else: + return [f"NOT ({' AND '.join(sub_conditions)})"] + + return [] + + conditions = recurse(metadata_filter) + clause = "" + if conditions: + clause = " AND " + " AND ".join(conditions) + + return clause + #################### query method ############### async def query( - self, query: str, top_k: int, query_embedding: list[float] = None + self, query: str, top_k: int, query_embedding: list[float] = None, metadata_filter: MetadataFilter | None = None ) -> list[dict[str, Any]]: if query_embedding is not None: embedding = query_embedding @@ -2015,8 +2134,8 @@ class PGVectorStorage(BaseVectorStorage): embedding = embeddings[0] embedding_string = ",".join(map(str, embedding)) - - sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string) + metadata_filter_clause = self.build_metadata_filter_clause(metadata_filter) + sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string, metadata_filter_clause=metadata_filter_clause) params = { "workspace": self.workspace, "closer_than_threshold": 1 - self.cosine_better_than_threshold, @@ -4480,6 +4599,7 @@ TABLES = { content TEXT, file_path TEXT NULL, llm_cache_list JSONB NULL 
DEFAULT '[]'::jsonb, + metadata JSONB NULL DEFAULT '{}'::jsonb, create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id) @@ -4495,6 +4615,7 @@ TABLES = { content TEXT, content_vector VECTOR({os.environ.get("EMBEDDING_DIM", 1024)}), file_path TEXT NULL, + metadata JSONB NULL DEFAULT '{{}}'::jsonb, create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, CONSTRAINT LIGHTRAG_VDB_CHUNKS_PK PRIMARY KEY (workspace, id) @@ -4656,8 +4777,8 @@ SQL_TEMPLATES = { """, "upsert_text_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens, chunk_order_index, full_doc_id, content, file_path, llm_cache_list, - create_time, update_time) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + create_time, update_time, metadata) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) ON CONFLICT (workspace,id) DO UPDATE SET tokens=EXCLUDED.tokens, chunk_order_index=EXCLUDED.chunk_order_index, @@ -4665,7 +4786,8 @@ SQL_TEMPLATES = { content = EXCLUDED.content, file_path=EXCLUDED.file_path, llm_cache_list=EXCLUDED.llm_cache_list, - update_time = EXCLUDED.update_time + update_time = EXCLUDED.update_time, + metadata = EXCLUDED.metadata """, "upsert_full_entities": """INSERT INTO LIGHTRAG_FULL_ENTITIES (workspace, id, entity_names, count, create_time, update_time) @@ -4686,8 +4808,8 @@ SQL_TEMPLATES = { # SQL for VectorStorage "upsert_chunk": """INSERT INTO LIGHTRAG_VDB_CHUNKS (workspace, id, tokens, chunk_order_index, full_doc_id, content, content_vector, file_path, - create_time, update_time) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + create_time, update_time, metadata) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) ON CONFLICT (workspace,id) DO UPDATE SET tokens=EXCLUDED.tokens, chunk_order_index=EXCLUDED.chunk_order_index, @@ -4695,7 +4817,8 @@ SQL_TEMPLATES = { content = EXCLUDED.content, 
content_vector=EXCLUDED.content_vector, file_path=EXCLUDED.file_path, - update_time = EXCLUDED.update_time + update_time = EXCLUDED.update_time, + metadata = EXCLUDED.metadata """, "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, content_vector, chunk_ids, file_path, create_time, update_time) @@ -4721,34 +4844,71 @@ SQL_TEMPLATES = { update_time = EXCLUDED.update_time """, "relationships": """ - SELECT r.source_id AS src_id, - r.target_id AS tgt_id, - EXTRACT(EPOCH FROM r.create_time)::BIGINT AS created_at - FROM LIGHTRAG_VDB_RELATION r - WHERE r.workspace = $1 - AND r.content_vector <=> '[{embedding_string}]'::vector < $2 - ORDER BY r.content_vector <=> '[{embedding_string}]'::vector - LIMIT $3; + WITH relevant_chunks AS ( + SELECT id as chunk_id + FROM LIGHTRAG_VDB_CHUNKS + WHERE ($4::varchar[] IS NULL OR full_doc_id = ANY ($4::varchar[])) + {metadata_filter_clause} + ), + rc AS ( + SELECT array_agg(chunk_id) AS chunk_arr + FROM relevant_chunks + ) + SELECT r.source_id AS src_id, + r.target_id AS tgt_id, + EXTRACT(EPOCH FROM r.create_time)::BIGINT AS created_at + FROM LIGHTRAG_VDB_RELATION r + JOIN rc ON TRUE + WHERE r.workspace = $1 + AND r.content_vector <=> '[{embedding_string}]'::vector < $2 + AND r.chunk_ids && (rc.chunk_arr::varchar[]) + ORDER BY r.content_vector <=> '[{embedding_string}]'::vector + LIMIT $3; """, "entities": """ + WITH relevant_chunks AS ( + SELECT id as chunk_id + FROM LIGHTRAG_VDB_CHUNKS + WHERE ($4::varchar[] IS NULL OR full_doc_id = ANY ($4::varchar[])) + {metadata_filter_clause} + ), + rc AS ( + SELECT array_agg(chunk_id) AS chunk_arr + FROM relevant_chunks + ) SELECT e.entity_name, EXTRACT(EPOCH FROM e.create_time)::BIGINT AS created_at FROM LIGHTRAG_VDB_ENTITY e + JOIN rc ON TRUE WHERE e.workspace = $1 AND e.content_vector <=> '[{embedding_string}]'::vector < $2 + AND e.chunk_ids && (rc.chunk_arr::varchar[]) ORDER BY e.content_vector <=> '[{embedding_string}]'::vector LIMIT $3; """, "chunks": 
""" - SELECT c.id, - c.content, - c.file_path, - EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at - FROM LIGHTRAG_VDB_CHUNKS c - WHERE c.workspace = $1 - AND c.content_vector <=> '[{embedding_string}]'::vector < $2 - ORDER BY c.content_vector <=> '[{embedding_string}]'::vector - LIMIT $3; + WITH relevant_chunks AS ( + SELECT id as chunk_id + FROM LIGHTRAG_VDB_CHUNKS + WHERE ($4::varchar[] IS NULL OR full_doc_id = ANY ($4::varchar[])) + {metadata_filter_clause} + ), + rc AS ( + SELECT array_agg(chunk_id) AS chunk_arr + FROM relevant_chunks + ) + SELECT c.id, + c.content, + c.file_path, + EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at, + c.metadata + FROM LIGHTRAG_VDB_CHUNKS c + JOIN rc ON TRUE + WHERE c.workspace = $1 + AND c.content_vector <=> '[{embedding_string}]'::vector < $2 + AND c.id = ANY (rc.chunk_arr) + ORDER BY c.content_vector <=> '[{embedding_string}]'::vector + LIMIT $3; """, # DROP tables "drop_specifiy_table_workspace": """ diff --git a/lightrag/operate.py b/lightrag/operate.py index b62f2135..c6c13602 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -41,7 +41,6 @@ from .base import ( QueryParam, QueryResult, QueryContextResult, - MetadataFilter ) from .prompt import PROMPTS from .constants import ( @@ -2409,6 +2408,7 @@ async def kg_query( query_param.ll_keywords or [], query_param.user_prompt or "", query_param.enable_rerank, + query_param.metadata_filter, ) cached_result = await handle_cache( hashing_kv, args_hash, query, query_param.mode, cache_type="query" @@ -2720,7 +2720,7 @@ async def _get_vector_context( cosine_threshold = chunks_vdb.cosine_better_than_threshold results = await chunks_vdb.query( - query, top_k=search_top_k, query_embedding=query_embedding + query, top_k=search_top_k, query_embedding=query_embedding, metadata_filter=query_param.metadata_filter ) if not results: logger.info( @@ -2737,6 +2737,7 @@ async def _get_vector_context( "file_path": result.get("file_path", "unknown_source"), "source_type": 
"vector", # Mark the source type "chunk_id": result.get("id"), # Add chunk_id for deduplication + "metadata": result.get("metadata") } valid_chunks.append(chunk_with_metadata) @@ -3561,7 +3562,8 @@ async def _get_node_data( f"Query nodes: {query} (top_k:{query_param.top_k}, cosine:{entities_vdb.cosine_better_than_threshold})" ) - results = await entities_vdb.query(query, top_k=query_param.top_k) + + results = await entities_vdb.query(query, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter) if not len(results): return [], [] @@ -3570,21 +3572,13 @@ async def _get_node_data( node_ids = [r["entity_name"] for r in results] - # TODO update method to take in the metadata_filter dataclass - node_kg_ids = [] - if hasattr(knowledge_graph_inst, "get_nodes_by_metadata_filter"): - node_kg_ids = await asyncio.gather( - knowledge_graph_inst.get_nodes_by_metadata_filter(QueryParam.metadata_filter) - ) - - filtered_node_ids = ( - [nid for nid in node_ids if nid in node_kg_ids] if node_kg_ids else node_ids - ) + # Extract all entity IDs from your results list + node_ids = [r["entity_name"] for r in results] # Call the batch node retrieval and degree functions concurrently. 
nodes_dict, degrees_dict = await asyncio.gather( - knowledge_graph_inst.get_nodes_batch(filtered_node_ids), - knowledge_graph_inst.node_degrees_batch(filtered_node_ids), + knowledge_graph_inst.get_nodes_batch(node_ids), + knowledge_graph_inst.node_degrees_batch(node_ids), ) # Now, if you need the node data and degree in order: @@ -3849,7 +3843,7 @@ async def _get_edge_data( f"Query edges: {keywords} (top_k:{query_param.top_k}, cosine:{relationships_vdb.cosine_better_than_threshold})" ) - results = await relationships_vdb.query(keywords, top_k=query_param.top_k) + results = await relationships_vdb.query(keywords, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter) if not len(results): return [], [] diff --git a/lightrag/types.py b/lightrag/types.py index a18f2d3c..e61910b6 100644 --- a/lightrag/types.py +++ b/lightrag/types.py @@ -1,7 +1,7 @@ from __future__ import annotations -from pydantic import BaseModel -from typing import Any, Optional +from pydantic import BaseModel, Field, validator +from typing import Any, Optional, List, Union, Dict class GPTKeywordExtractionFormat(BaseModel): @@ -27,3 +27,46 @@ class KnowledgeGraph(BaseModel): nodes: list[KnowledgeGraphNode] = [] edges: list[KnowledgeGraphEdge] = [] is_truncated: bool = False + + +class MetadataFilter(BaseModel): + """ + Represents a logical expression for metadata filtering. 
class MetadataFilter(BaseModel):
    """
    Represents a logical expression for metadata filtering.

    Args:
        operator: "AND", "OR", or "NOT" (accepted case-insensitively and
            stored upper-cased)
        operands: List of either simple key-value pairs or nested MetadataFilter objects
    """
    operator: str = Field(..., description="Logical operator: AND, OR, or NOT")
    operands: List[Union[Dict[str, Any], 'MetadataFilter']] = Field(default_factory=list, description="List of operands for filtering")

    @validator('operator')
    def validate_operator(cls, v):
        """Normalise the operator to upper case and reject unknown values."""
        normalized = v.upper()
        if normalized not in ("AND", "OR", "NOT"):
            raise ValueError('operator must be one of: "AND", "OR", "NOT"')
        return normalized

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary representation.

        Nested filters are converted recursively; plain dict operands are
        returned as they are.
        """
        return {
            "operator": self.operator,
            "operands": [
                operand.to_dict() if isinstance(operand, MetadataFilter) else operand
                for operand in self.operands
            ]
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'MetadataFilter':
        """Create from dictionary representation.

        A dict operand containing an "operator" key is treated as a nested
        filter and converted recursively; any other operand is kept as
        given. A missing "operator" defaults to "AND"; a missing or null
        "operands" is treated as an empty list.
        """
        operands = [
            cls.from_dict(operand)
            if isinstance(operand, dict) and "operator" in operand
            else operand
            for operand in data.get("operands") or []
        ]
        return cls(operator=data.get("operator", "AND"), operands=operands)

    class Config:
        """Pydantic configuration."""
        # Re-run field validation whenever an attribute is assigned.
        validate_assignment = True