feat (metadata filter): added metadata filtering

Added metadata filtering on chunks for queries (functional, but still needs testing).
Fully implemented only on PostgreSQL with pgvector and on Neo4j.
This commit is contained in:
Giulio Grassia 2025-09-24 16:57:44 +02:00
parent 40afb0441a
commit d0fba28e1f
5 changed files with 263 additions and 90 deletions

View file

@ -4,11 +4,11 @@ This module contains all query-related routes for the LightRAG API.
import json import json
import logging import logging
import textwrap
from typing import Any, Dict, List, Literal, Optional from typing import Any, Dict, List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from lightrag.base import QueryParam, MetadataFilter from lightrag.base import QueryParam
from lightrag.types import MetadataFilter
from lightrag.api.utils_api import get_combined_auth_dependency from lightrag.api.utils_api import get_combined_auth_dependency
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
@ -24,8 +24,8 @@ class QueryRequest(BaseModel):
) )
metadata_filter: MetadataFilter | None = Field( metadata_filter: MetadataFilter | None = Field(
default=textwrap.dedent('{"operator": "AND","operands": [{"MetadataFilter": {}},"string"]},'), default=None,
description="Optional dictionary of metadata key-value pairs to filter nodes", description="Optional metadata filter for nodes and edges. Can be a MetadataFilter object or a dict that will be converted to MetadataFilter.",
) )
mode: Literal["local", "global", "hybrid", "naive", "mix", "bypass"] = Field( mode: Literal["local", "global", "hybrid", "naive", "mix", "bypass"] = Field(
@ -118,6 +118,16 @@ class QueryRequest(BaseModel):
) )
return conversation_history return conversation_history
@field_validator("metadata_filter", mode="before")
@classmethod
def metadata_filter_convert(cls, v):
    """Coerce a raw dict payload into a MetadataFilter before field validation.

    None and already-constructed MetadataFilter instances pass through untouched.
    """
    if isinstance(v, dict):
        return MetadataFilter.from_dict(v)
    return v
def to_query_params(self, is_stream: bool) -> "QueryParam": def to_query_params(self, is_stream: bool) -> "QueryParam":
"""Converts a QueryRequest instance into a QueryParam instance.""" """Converts a QueryRequest instance into a QueryParam instance."""
# Use Pydantic's `.model_dump(exclude_none=True)` to remove None values automatically # Use Pydantic's `.model_dump(exclude_none=True)` to remove None values automatically
@ -126,6 +136,11 @@ class QueryRequest(BaseModel):
# Ensure `mode` and `stream` are set explicitly # Ensure `mode` and `stream` are set explicitly
param = QueryParam(**request_data) param = QueryParam(**request_data)
param.stream = is_stream param.stream = is_stream
# Ensure metadata_filter remains as MetadataFilter object if it exists
if self.metadata_filter:
param.metadata_filter = self.metadata_filter
return param return param
@ -175,10 +190,6 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
try: try:
param = request.to_query_params(False) param = request.to_query_params(False)
# Inject metadata_filter into param if present
if request.metadata_filter:
setattr(param, "metadata_filter", request.metadata_filter)
response = await rag.aquery(request.query, param=param) response = await rag.aquery(request.query, param=param)
# Get reference list if requested # Get reference list if requested

View file

@ -18,7 +18,7 @@ from typing import (
Union, Union,
) )
from .utils import EmbeddingFunc from .utils import EmbeddingFunc
from .types import KnowledgeGraph from .types import KnowledgeGraph, MetadataFilter
from .constants import ( from .constants import (
GRAPH_FIELD_SEP, GRAPH_FIELD_SEP,
DEFAULT_TOP_K, DEFAULT_TOP_K,
@ -40,42 +40,7 @@ from .constants import (
load_dotenv(dotenv_path=".env", override=False) load_dotenv(dotenv_path=".env", override=False)
@dataclass
class MetadataFilter:
"""
Represents a logical expression for metadata filtering.
Args:
operator: "AND", "OR", or "NOT"
operands: List of either simple key-value pairs or nested MetadataFilter objects
"""
operator: str
operands: List[Union[Dict[str, Any], 'MetadataFilter']] = None
def __post_init__(self):
if self.operands is None:
self.operands = []
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary representation."""
return {
"operator": self.operator,
"operands": [
operand.to_dict() if isinstance(operand, MetadataFilter) else operand
for operand in self.operands
]
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'MetadataFilter':
"""Create from dictionary representation."""
operands = []
for operand in data.get("operands", []):
if isinstance(operand, dict) and "operator" in operand:
operands.append(cls.from_dict(operand))
else:
operands.append(operand)
return cls(operator=data.get("operator", "AND"), operands=operands)
@ -266,7 +231,7 @@ class BaseVectorStorage(StorageNameSpace, ABC):
@abstractmethod @abstractmethod
async def query( async def query(
self, query: str, top_k: int, query_embedding: list[float] = None self, query: str, top_k: int, query_embedding: list[float] = None, metadata_filter: MetadataFilter | None = None
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Query the vector storage and retrieve top_k results. """Query the vector storage and retrieve top_k results.

View file

@ -11,7 +11,7 @@ import configparser
import ssl import ssl
import itertools import itertools
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge, MetadataFilter
from tenacity import ( from tenacity import (
retry, retry,
@ -931,6 +931,13 @@ class PostgreSQLDB:
logger.error( logger.error(
f"PostgreSQL, Failed to create vector index, type: {self.vector_index_type}, Got: {e}" f"PostgreSQL, Failed to create vector index, type: {self.vector_index_type}, Got: {e}"
) )
# Compatibility check - add metadata columns to LIGHTRAG_DOC_CHUNKS and LIGHTRAG_VDB_CHUNKS
try:
await self.add_metadata_to_tables()
except Exception as e:
logger.error(f"PostgreSQL, Failed to add metadata columns to existing tables: {e}")
# After all tables are created, attempt to migrate timestamp fields # After all tables are created, attempt to migrate timestamp fields
try: try:
await self._migrate_timestamp_columns() await self._migrate_timestamp_columns()
@ -1010,6 +1017,54 @@ class PostgreSQLDB:
f"PostgreSQL, Failed to create full entities/relations tables: {e}" f"PostgreSQL, Failed to create full entities/relations tables: {e}"
) )
async def add_metadata_to_tables(self):
    """Add metadata columns to LIGHTRAG_DOC_CHUNKS and LIGHTRAG_VDB_CHUNKS tables if they don't exist"""
    # Compatibility migration: older deployments created these tables without
    # the JSONB ``metadata`` column that metadata filtering relies on.
    tables_to_check = [
        {
            "name": "LIGHTRAG_DOC_CHUNKS",
            "description": "Document chunks storage table",
        },
        {
            "name": "LIGHTRAG_VDB_CHUNKS",
            "description": "Vector database chunks storage table",
        },
    ]
    for table_info in tables_to_check:
        table_name = table_info["name"]
        try:
            # Check if metadata column exists
            # information_schema stores unquoted identifiers lower-cased,
            # hence the .lower() on the table name below.
            check_column_sql = """
            SELECT column_name
            FROM information_schema.columns
            WHERE table_name = $1
            AND column_name = 'metadata'
            """
            # NOTE(review): assumes self.query maps the dict values onto the
            # positional $1 placeholder - confirm against PostgreSQLDB.query.
            column_info = await self.query(
                check_column_sql, {"table_name": table_name.lower()}
            )
            if not column_info:
                logger.info(f"Adding metadata column to {table_name} table")
                # Default to an empty JSONB object so existing rows stay valid.
                add_column_sql = f"""
                ALTER TABLE {table_name}
                ADD COLUMN metadata JSONB NULL DEFAULT '{{}}'::jsonb
                """
                await self.execute(add_column_sql)
                logger.info(
                    f"Successfully added metadata column to {table_name} table"
                )
            else:
                logger.debug(
                    f"metadata column already exists in {table_name} table"
                )
        except Exception as e:
            # Best-effort migration: log and continue with the next table
            # rather than aborting startup.
            logger.warning(
                f"Failed to add metadata column to {table_name}: {e}"
            )
async def _migrate_create_full_entities_relations_tables(self): async def _migrate_create_full_entities_relations_tables(self):
"""Create LIGHTRAG_FULL_ENTITIES and LIGHTRAG_FULL_RELATIONS tables if they don't exist""" """Create LIGHTRAG_FULL_ENTITIES and LIGHTRAG_FULL_RELATIONS tables if they don't exist"""
tables_to_check = [ tables_to_check = [
@ -1742,6 +1797,7 @@ class PGKVStorage(BaseKVStorage):
"llm_cache_list": json.dumps(v.get("llm_cache_list", [])), "llm_cache_list": json.dumps(v.get("llm_cache_list", [])),
"create_time": current_time, "create_time": current_time,
"update_time": current_time, "update_time": current_time,
"metadata": json.dumps(v.get("metadata", {})),
} }
await self.db.execute(upsert_sql, _data) await self.db.execute(upsert_sql, _data)
elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS): elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
@ -1908,6 +1964,7 @@ class PGVectorStorage(BaseVectorStorage):
"file_path": item["file_path"], "file_path": item["file_path"],
"create_time": current_time, "create_time": current_time,
"update_time": current_time, "update_time": current_time,
"metadata": json.dumps(item.get("metadata", {})),
} }
except Exception as e: except Exception as e:
logger.error( logger.error(
@ -2002,9 +2059,71 @@ class PGVectorStorage(BaseVectorStorage):
await self.db.execute(upsert_sql, data) await self.db.execute(upsert_sql, data)
############# Metadata building function #################
@staticmethod
def build_metadata_filter_clause(metadata_filter):
def escape_str(val: str) -> str:
return str(val).replace("'", "''") # escape single quotes
def build_single_condition(key, value):
if isinstance(value, (dict, list)):
json_value = json.dumps(value).replace("'", "''")
return f"metadata->'{key}' = '{json_value}'::jsonb"
elif isinstance(value, (int, float, bool)):
return f"metadata->'{key}' = '{json.dumps(value)}'::jsonb"
else: # string
return f"metadata->>'{key}' = '{escape_str(value)}'"
def build_conditions(filter_dict):
conditions = []
for key, value in filter_dict.items():
if isinstance(value, (list, tuple)):
conds = [build_single_condition(key, v) for v in value]
conditions.append("(" + " OR ".join(conds) + ")")
else:
conditions.append(build_single_condition(key, value))
return conditions
def recurse(filter_obj):
if isinstance(filter_obj, dict):
return build_conditions(filter_obj)
if isinstance(filter_obj, MetadataFilter):
sub_conditions = []
for operand in filter_obj.operands:
if isinstance(operand, dict):
sub_conditions.append("(" + " AND ".join(build_conditions(operand)) + ")")
elif isinstance(operand, MetadataFilter):
nested = recurse(operand)
if nested:
sub_conditions.append("(" + " AND ".join(nested) + ")")
if not sub_conditions:
return []
op = filter_obj.operator.upper()
if op == "AND":
return [" AND ".join(sub_conditions)]
elif op == "OR":
return [" OR ".join(sub_conditions)]
elif op == "NOT":
if len(sub_conditions) == 1:
return [f"NOT {sub_conditions[0]}"]
else:
return [f"NOT ({' AND '.join(sub_conditions)})"]
return []
conditions = recurse(metadata_filter)
clause = ""
if conditions:
clause = " AND " + " AND ".join(conditions)
return clause
#################### query method ############### #################### query method ###############
async def query( async def query(
self, query: str, top_k: int, query_embedding: list[float] = None self, query: str, top_k: int, query_embedding: list[float] = None, metadata_filter: MetadataFilter | None = None
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
if query_embedding is not None: if query_embedding is not None:
embedding = query_embedding embedding = query_embedding
@ -2015,8 +2134,8 @@ class PGVectorStorage(BaseVectorStorage):
embedding = embeddings[0] embedding = embeddings[0]
embedding_string = ",".join(map(str, embedding)) embedding_string = ",".join(map(str, embedding))
metadata_filter_clause = self.build_metadata_filter_clause(metadata_filter)
sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string) sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string, metadata_filter_clause=metadata_filter_clause)
params = { params = {
"workspace": self.workspace, "workspace": self.workspace,
"closer_than_threshold": 1 - self.cosine_better_than_threshold, "closer_than_threshold": 1 - self.cosine_better_than_threshold,
@ -4480,6 +4599,7 @@ TABLES = {
content TEXT, content TEXT,
file_path TEXT NULL, file_path TEXT NULL,
llm_cache_list JSONB NULL DEFAULT '[]'::jsonb, llm_cache_list JSONB NULL DEFAULT '[]'::jsonb,
metadata JSONB NULL DEFAULT '{}'::jsonb,
create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id) CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id)
@ -4495,6 +4615,7 @@ TABLES = {
content TEXT, content TEXT,
content_vector VECTOR({os.environ.get("EMBEDDING_DIM", 1024)}), content_vector VECTOR({os.environ.get("EMBEDDING_DIM", 1024)}),
file_path TEXT NULL, file_path TEXT NULL,
metadata JSONB NULL DEFAULT '{{}}'::jsonb,
create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT LIGHTRAG_VDB_CHUNKS_PK PRIMARY KEY (workspace, id) CONSTRAINT LIGHTRAG_VDB_CHUNKS_PK PRIMARY KEY (workspace, id)
@ -4656,8 +4777,8 @@ SQL_TEMPLATES = {
""", """,
"upsert_text_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens, "upsert_text_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens,
chunk_order_index, full_doc_id, content, file_path, llm_cache_list, chunk_order_index, full_doc_id, content, file_path, llm_cache_list,
create_time, update_time) create_time, update_time, metadata)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
ON CONFLICT (workspace,id) DO UPDATE ON CONFLICT (workspace,id) DO UPDATE
SET tokens=EXCLUDED.tokens, SET tokens=EXCLUDED.tokens,
chunk_order_index=EXCLUDED.chunk_order_index, chunk_order_index=EXCLUDED.chunk_order_index,
@ -4665,7 +4786,8 @@ SQL_TEMPLATES = {
content = EXCLUDED.content, content = EXCLUDED.content,
file_path=EXCLUDED.file_path, file_path=EXCLUDED.file_path,
llm_cache_list=EXCLUDED.llm_cache_list, llm_cache_list=EXCLUDED.llm_cache_list,
update_time = EXCLUDED.update_time update_time = EXCLUDED.update_time,
metadata = EXCLUDED.metadata
""", """,
"upsert_full_entities": """INSERT INTO LIGHTRAG_FULL_ENTITIES (workspace, id, entity_names, count, "upsert_full_entities": """INSERT INTO LIGHTRAG_FULL_ENTITIES (workspace, id, entity_names, count,
create_time, update_time) create_time, update_time)
@ -4686,8 +4808,8 @@ SQL_TEMPLATES = {
# SQL for VectorStorage # SQL for VectorStorage
"upsert_chunk": """INSERT INTO LIGHTRAG_VDB_CHUNKS (workspace, id, tokens, "upsert_chunk": """INSERT INTO LIGHTRAG_VDB_CHUNKS (workspace, id, tokens,
chunk_order_index, full_doc_id, content, content_vector, file_path, chunk_order_index, full_doc_id, content, content_vector, file_path,
create_time, update_time) create_time, update_time, metadata)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
ON CONFLICT (workspace,id) DO UPDATE ON CONFLICT (workspace,id) DO UPDATE
SET tokens=EXCLUDED.tokens, SET tokens=EXCLUDED.tokens,
chunk_order_index=EXCLUDED.chunk_order_index, chunk_order_index=EXCLUDED.chunk_order_index,
@ -4695,7 +4817,8 @@ SQL_TEMPLATES = {
content = EXCLUDED.content, content = EXCLUDED.content,
content_vector=EXCLUDED.content_vector, content_vector=EXCLUDED.content_vector,
file_path=EXCLUDED.file_path, file_path=EXCLUDED.file_path,
update_time = EXCLUDED.update_time update_time = EXCLUDED.update_time,
metadata = EXCLUDED.metadata
""", """,
"upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
content_vector, chunk_ids, file_path, create_time, update_time) content_vector, chunk_ids, file_path, create_time, update_time)
@ -4721,32 +4844,69 @@ SQL_TEMPLATES = {
update_time = EXCLUDED.update_time update_time = EXCLUDED.update_time
""", """,
"relationships": """ "relationships": """
WITH relevant_chunks AS (
SELECT id as chunk_id
FROM LIGHTRAG_VDB_CHUNKS
WHERE ($4::varchar[] IS NULL OR full_doc_id = ANY ($4::varchar[]))
{metadata_filter_clause}
),
rc AS (
SELECT array_agg(chunk_id) AS chunk_arr
FROM relevant_chunks
)
SELECT r.source_id AS src_id, SELECT r.source_id AS src_id,
r.target_id AS tgt_id, r.target_id AS tgt_id,
EXTRACT(EPOCH FROM r.create_time)::BIGINT AS created_at EXTRACT(EPOCH FROM r.create_time)::BIGINT AS created_at
FROM LIGHTRAG_VDB_RELATION r FROM LIGHTRAG_VDB_RELATION r
JOIN rc ON TRUE
WHERE r.workspace = $1 WHERE r.workspace = $1
AND r.content_vector <=> '[{embedding_string}]'::vector < $2 AND r.content_vector <=> '[{embedding_string}]'::vector < $2
AND r.chunk_ids && (rc.chunk_arr::varchar[])
ORDER BY r.content_vector <=> '[{embedding_string}]'::vector ORDER BY r.content_vector <=> '[{embedding_string}]'::vector
LIMIT $3; LIMIT $3;
""", """,
"entities": """ "entities": """
WITH relevant_chunks AS (
SELECT id as chunk_id
FROM LIGHTRAG_VDB_CHUNKS
WHERE ($4::varchar[] IS NULL OR full_doc_id = ANY ($4::varchar[]))
{metadata_filter_clause}
),
rc AS (
SELECT array_agg(chunk_id) AS chunk_arr
FROM relevant_chunks
)
SELECT e.entity_name, SELECT e.entity_name,
EXTRACT(EPOCH FROM e.create_time)::BIGINT AS created_at EXTRACT(EPOCH FROM e.create_time)::BIGINT AS created_at
FROM LIGHTRAG_VDB_ENTITY e FROM LIGHTRAG_VDB_ENTITY e
JOIN rc ON TRUE
WHERE e.workspace = $1 WHERE e.workspace = $1
AND e.content_vector <=> '[{embedding_string}]'::vector < $2 AND e.content_vector <=> '[{embedding_string}]'::vector < $2
AND e.chunk_ids && (rc.chunk_arr::varchar[])
ORDER BY e.content_vector <=> '[{embedding_string}]'::vector ORDER BY e.content_vector <=> '[{embedding_string}]'::vector
LIMIT $3; LIMIT $3;
""", """,
"chunks": """ "chunks": """
WITH relevant_chunks AS (
SELECT id as chunk_id
FROM LIGHTRAG_VDB_CHUNKS
WHERE ($4::varchar[] IS NULL OR full_doc_id = ANY ($4::varchar[]))
{metadata_filter_clause}
),
rc AS (
SELECT array_agg(chunk_id) AS chunk_arr
FROM relevant_chunks
)
SELECT c.id, SELECT c.id,
c.content, c.content,
c.file_path, c.file_path,
EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at,
c.metadata
FROM LIGHTRAG_VDB_CHUNKS c FROM LIGHTRAG_VDB_CHUNKS c
JOIN rc ON TRUE
WHERE c.workspace = $1 WHERE c.workspace = $1
AND c.content_vector <=> '[{embedding_string}]'::vector < $2 AND c.content_vector <=> '[{embedding_string}]'::vector < $2
AND c.id = ANY (rc.chunk_arr)
ORDER BY c.content_vector <=> '[{embedding_string}]'::vector ORDER BY c.content_vector <=> '[{embedding_string}]'::vector
LIMIT $3; LIMIT $3;
""", """,

View file

@ -41,7 +41,6 @@ from .base import (
QueryParam, QueryParam,
QueryResult, QueryResult,
QueryContextResult, QueryContextResult,
MetadataFilter
) )
from .prompt import PROMPTS from .prompt import PROMPTS
from .constants import ( from .constants import (
@ -2409,6 +2408,7 @@ async def kg_query(
query_param.ll_keywords or [], query_param.ll_keywords or [],
query_param.user_prompt or "", query_param.user_prompt or "",
query_param.enable_rerank, query_param.enable_rerank,
query_param.metadata_filter,
) )
cached_result = await handle_cache( cached_result = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="query" hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@ -2720,7 +2720,7 @@ async def _get_vector_context(
cosine_threshold = chunks_vdb.cosine_better_than_threshold cosine_threshold = chunks_vdb.cosine_better_than_threshold
results = await chunks_vdb.query( results = await chunks_vdb.query(
query, top_k=search_top_k, query_embedding=query_embedding query, top_k=search_top_k, query_embedding=query_embedding, metadata_filter=query_param.metadata_filter
) )
if not results: if not results:
logger.info( logger.info(
@ -2737,6 +2737,7 @@ async def _get_vector_context(
"file_path": result.get("file_path", "unknown_source"), "file_path": result.get("file_path", "unknown_source"),
"source_type": "vector", # Mark the source type "source_type": "vector", # Mark the source type
"chunk_id": result.get("id"), # Add chunk_id for deduplication "chunk_id": result.get("id"), # Add chunk_id for deduplication
"metadata": result.get("metadata")
} }
valid_chunks.append(chunk_with_metadata) valid_chunks.append(chunk_with_metadata)
@ -3561,7 +3562,8 @@ async def _get_node_data(
f"Query nodes: {query} (top_k:{query_param.top_k}, cosine:{entities_vdb.cosine_better_than_threshold})" f"Query nodes: {query} (top_k:{query_param.top_k}, cosine:{entities_vdb.cosine_better_than_threshold})"
) )
results = await entities_vdb.query(query, top_k=query_param.top_k)
results = await entities_vdb.query(query, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter)
if not len(results): if not len(results):
return [], [] return [], []
@ -3570,21 +3572,13 @@ async def _get_node_data(
node_ids = [r["entity_name"] for r in results] node_ids = [r["entity_name"] for r in results]
# TODO update method to take in the metadata_filter dataclass # Extract all entity IDs from your results list
node_kg_ids = [] node_ids = [r["entity_name"] for r in results]
if hasattr(knowledge_graph_inst, "get_nodes_by_metadata_filter"):
node_kg_ids = await asyncio.gather(
knowledge_graph_inst.get_nodes_by_metadata_filter(QueryParam.metadata_filter)
)
filtered_node_ids = (
[nid for nid in node_ids if nid in node_kg_ids] if node_kg_ids else node_ids
)
# Call the batch node retrieval and degree functions concurrently. # Call the batch node retrieval and degree functions concurrently.
nodes_dict, degrees_dict = await asyncio.gather( nodes_dict, degrees_dict = await asyncio.gather(
knowledge_graph_inst.get_nodes_batch(filtered_node_ids), knowledge_graph_inst.get_nodes_batch(node_ids),
knowledge_graph_inst.node_degrees_batch(filtered_node_ids), knowledge_graph_inst.node_degrees_batch(node_ids),
) )
# Now, if you need the node data and degree in order: # Now, if you need the node data and degree in order:
@ -3849,7 +3843,7 @@ async def _get_edge_data(
f"Query edges: {keywords} (top_k:{query_param.top_k}, cosine:{relationships_vdb.cosine_better_than_threshold})" f"Query edges: {keywords} (top_k:{query_param.top_k}, cosine:{relationships_vdb.cosine_better_than_threshold})"
) )
results = await relationships_vdb.query(keywords, top_k=query_param.top_k) results = await relationships_vdb.query(keywords, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter)
if not len(results): if not len(results):
return [], [] return [], []

View file

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from pydantic import BaseModel from pydantic import BaseModel, Field, validator
from typing import Any, Optional from typing import Any, Optional, List, Union, Dict
class GPTKeywordExtractionFormat(BaseModel): class GPTKeywordExtractionFormat(BaseModel):
@ -27,3 +27,46 @@ class KnowledgeGraph(BaseModel):
nodes: list[KnowledgeGraphNode] = [] nodes: list[KnowledgeGraphNode] = []
edges: list[KnowledgeGraphEdge] = [] edges: list[KnowledgeGraphEdge] = []
is_truncated: bool = False is_truncated: bool = False
class MetadataFilter(BaseModel):
    """
    Represents a logical expression for metadata filtering.

    Args:
        operator: "AND", "OR", or "NOT"
        operands: List of either simple key-value pairs or nested MetadataFilter objects
    """

    operator: str = Field(..., description="Logical operator: AND, OR, or NOT")
    operands: List[Union[Dict[str, Any], 'MetadataFilter']] = Field(default_factory=list, description="List of operands for filtering")

    # NOTE(review): pydantic v1-style @validator; the API module uses the v2
    # @field_validator elsewhere - consider migrating for consistency.
    @validator('operator')
    def validate_operator(cls, v):
        # Reject anything other than the three supported logical operators.
        if v not in ["AND", "OR", "NOT"]:
            raise ValueError('operator must be one of: "AND", "OR", "NOT"')
        return v

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary representation."""
        # Nested MetadataFilter operands are serialized via pydantic's .dict();
        # plain dict operands are passed through unchanged.
        return {
            "operator": self.operator,
            "operands": [
                operand.dict() if isinstance(operand, MetadataFilter) else operand
                for operand in self.operands
            ]
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'MetadataFilter':
        """Create from dictionary representation."""
        operands = []
        for operand in data.get("operands", []):
            # A dict containing an "operator" key is treated as a nested filter;
            # anything else is a leaf key/value condition.
            # NOTE(review): a metadata key literally named "operator" would be
            # misparsed as a nested filter - confirm this cannot occur.
            if isinstance(operand, dict) and "operator" in operand:
                operands.append(cls.from_dict(operand))
            else:
                operands.append(operand)
        # Missing "operator" defaults to "AND".
        return cls(operator=data.get("operator", "AND"), operands=operands)

    class Config:
        """Pydantic configuration."""
        # Re-validate fields on assignment, not only at construction time.
        validate_assignment = True