feat(metadata): add metadata filter to query
Added a MetadataFilter dataclass for serializing and deserializing complex filters to/from JSON dicts, and added node filtering based on metadata.
This commit is contained in:
parent
7be24a3c60
commit
0c721fa7f1
6 changed files with 143 additions and 33 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -65,6 +65,8 @@ LightRAG.pdf
|
|||
download_models_hf.py
|
||||
lightrag-dev/
|
||||
gui/
|
||||
/md
|
||||
/uv.lock
|
||||
|
||||
# unit-test files
|
||||
test_*
|
||||
|
|
|
|||
|
|
@ -4,10 +4,11 @@ This module contains all query-related routes for the LightRAG API.
|
|||
|
||||
import json
|
||||
import logging
|
||||
import textwrap
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from lightrag.base import QueryParam
|
||||
from lightrag.base import QueryParam, MetadataFilter
|
||||
from lightrag.api.utils_api import get_combined_auth_dependency
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
|
@ -22,8 +23,8 @@ class QueryRequest(BaseModel):
|
|||
description="The query text",
|
||||
)
|
||||
|
||||
metadata_filter: dict[str, str] | None = Field(
|
||||
default=None,
|
||||
metadata_filter: MetadataFilter | None = Field(
|
||||
default=textwrap.dedent('{"operator": "AND","operands": [{"MetadataFilter": {}},"string"]},'),
|
||||
description="Optional dictionary of metadata key-value pairs to filter nodes",
|
||||
)
|
||||
|
||||
|
|
@ -79,7 +80,7 @@ class QueryRequest(BaseModel):
|
|||
)
|
||||
|
||||
conversation_history: Optional[List[Dict[str, Any]]] = Field(
|
||||
default=None,
|
||||
default=[],
|
||||
description="Stores past conversation history to maintain context. Format: [{'role': 'user/assistant', 'content': 'message'}].",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from typing import (
|
|||
Dict,
|
||||
List,
|
||||
AsyncIterator,
|
||||
Union,
|
||||
)
|
||||
from .utils import EmbeddingFunc
|
||||
from .types import KnowledgeGraph
|
||||
|
|
@ -39,6 +40,46 @@ from .constants import (
|
|||
load_dotenv(dotenv_path=".env", override=False)
|
||||
|
||||
|
||||
@dataclass
class MetadataFilter:
    """Logical expression tree for metadata-based node filtering.

    A filter applies a boolean operator to a list of operands; each operand
    is either a plain ``{key: value}`` mapping or a nested ``MetadataFilter``,
    allowing arbitrarily deep logical expressions.

    Args:
        operator: "AND", "OR", or "NOT"
        operands: List of either simple key-value pairs or nested
            MetadataFilter objects. ``None`` is accepted and normalized
            to an empty list.
    """
    operator: str
    # None is used as a sentinel (normalized in __post_init__) to avoid the
    # shared-mutable-default pitfall; the annotation reflects that.
    operands: Union[List[Union[Dict[str, Any], "MetadataFilter"]], None] = None

    def __post_init__(self) -> None:
        # Normalize the None sentinel so downstream code can always iterate.
        if self.operands is None:
            self.operands = []

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a JSON-serializable dictionary representation."""
        return {
            "operator": self.operator,
            "operands": [
                operand.to_dict() if isinstance(operand, MetadataFilter) else operand
                for operand in self.operands
            ],
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MetadataFilter":
        """Create from dictionary representation.

        Dict operands that contain an "operator" key are treated as nested
        filters and deserialized recursively; any other operand is kept as a
        plain key-value mapping. A missing "operator" defaults to "AND".
        """
        operands: List[Union[Dict[str, Any], "MetadataFilter"]] = []
        for operand in data.get("operands", []):
            if isinstance(operand, dict) and "operator" in operand:
                operands.append(cls.from_dict(operand))
            else:
                operands.append(operand)
        return cls(operator=data.get("operator", "AND"), operands=operands)
class OllamaServerInfos:
|
||||
def __init__(self, name=None, tag=None):
|
||||
self._lightrag_name = name or os.getenv(
|
||||
|
|
@ -162,15 +203,15 @@ class QueryParam:
|
|||
Default is True to enable reranking when rerank model is available.
|
||||
"""
|
||||
|
||||
metadata_filter: MetadataFilter | None = None
|
||||
"""Metadata for filtering nodes and edges, allowing for more precise querying."""
|
||||
|
||||
include_references: bool = False
|
||||
"""If True, includes reference list in the response for supported endpoints.
|
||||
This parameter controls whether the API response includes a references field
|
||||
containing citation information for the retrieved content.
|
||||
"""
|
||||
|
||||
metadata_filter: dict | None = None
|
||||
"""Metadata for filtering nodes and edges, allowing for more precise querying."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class StorageNameSpace(ABC):
|
||||
|
|
@ -447,11 +488,11 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|||
or None if the node doesn't exist
|
||||
"""
|
||||
|
||||
async def get_nodes_by_metadata_filter(self, metadata_filter: MetadataFilter | None) -> list[str]:
    """Return IDs of nodes matching a metadata filter expression.

    Base-class placeholder: storage backends that support metadata
    filtering must override this with a concrete implementation.

    Args:
        metadata_filter: Logical filter expression, or None.

    Raises:
        NotImplementedError: Always, in this default implementation.
    """
    raise NotImplementedError("Subclasses must implement get_nodes_by_metadata_filter")
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
||||
"""Get nodes as a batch using UNWIND
|
||||
|
|
|
|||
|
|
@ -2,10 +2,11 @@ import os
|
|||
import re
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import final
|
||||
from typing import final, Any
|
||||
import configparser
|
||||
|
||||
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
|
|
@ -15,7 +16,7 @@ from tenacity import (
|
|||
|
||||
import logging
|
||||
from ..utils import logger
|
||||
from ..base import BaseGraphStorage
|
||||
from ..base import BaseGraphStorage, MetadataFilter
|
||||
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||
from ..constants import GRAPH_FIELD_SEP
|
||||
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
|
||||
|
|
@ -425,23 +426,88 @@ class Neo4JStorage(BaseGraphStorage):
|
|||
await result.consume() # Ensure results are consumed even on error
|
||||
raise
|
||||
|
||||
async def get_nodes_by_metadata_filter(self, metadata_filter: MetadataFilter | None) -> list[str]:
    """Get node IDs that match the given metadata filter with logical expressions.

    Args:
        metadata_filter: Logical filter expression over node metadata
            properties, or None.

    Returns:
        List of ``entity_id`` values for matching nodes; empty list when the
        filter produces no condition (e.g. filter is None or has no operands).
    """
    workspace_label = self._get_workspace_label()

    # Build metadata conditions; values travel as query parameters.
    params: dict[str, Any] = {}
    condition, params = self._build_metadata_conditions(metadata_filter, params)

    if not condition:
        # No condition means "match nothing" — return empty for safety
        # rather than matching every node in the workspace.
        return []

    query = f"""
        MATCH (n:`{workspace_label}`)
        WHERE {condition}
        RETURN n.entity_id AS entity_id
    """

    # Read-only access mode, consistent with the other read queries in
    # this storage class.
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        result = await session.run(query, params)
        node_ids = [record["entity_id"] async for record in result]
        # Ensure results are fully consumed so server resources are released.
        await result.consume()
        return node_ids
def _build_metadata_conditions(
    self,
    metadata_filter: MetadataFilter | None,
    params: dict[str, Any],
    node_var: str = "n",
) -> tuple[str, dict[str, Any]]:
    """
    Build Cypher WHERE conditions from a MetadataFilter.

    Args:
        metadata_filter: The MetadataFilter object
        params: Dictionary to collect parameters for the query
        node_var: The variable name for the node in the Cypher query

    Returns:
        Tuple of (condition_string, updated_params)

    Raises:
        ValueError: If the filter operator is unknown, or a metadata key is
            not a plain identifier. Property names cannot be parameterized
            in Cypher, so keys are validated before interpolation to guard
            against injection from API-supplied filters.
    """
    if metadata_filter is None:
        return "", params

    conditions = []

    for operand in metadata_filter.operands:
        if isinstance(operand, MetadataFilter):
            # Recursive call for nested filters
            sub_condition, params = self._build_metadata_conditions(
                operand, params, node_var
            )
            if sub_condition:
                conditions.append(f"({sub_condition})")
        else:
            # Simple key-value pair
            for key, value in operand.items():
                # Keys are interpolated into the query text as property
                # names; restrict them to word characters to prevent
                # Cypher injection.
                if not re.fullmatch(r"\w+", str(key)):
                    raise ValueError(f"Invalid metadata key: {key!r}")
                prop_name = f"meta_{key}"  # Using our prefix
                # len(params) makes parameter names unique across operands.
                param_name = f"{prop_name}_{len(params)}"

                if value is None:
                    # Check for existence of the key
                    conditions.append(f"{node_var}.{prop_name} IS NOT NULL")
                else:
                    # Check for specific value; value goes in as a bound
                    # parameter, never interpolated into the query text.
                    conditions.append(f"{node_var}.{prop_name} = ${param_name}")
                    params[param_name] = value

    if not conditions:
        return "", params

    # Join conditions with the operator
    if metadata_filter.operator == "AND":
        condition = " AND ".join(conditions)
    elif metadata_filter.operator == "OR":
        condition = " OR ".join(conditions)
    elif metadata_filter.operator == "NOT":
        # NOT over multiple operands negates their conjunction.
        if len(conditions) == 1:
            condition = f"NOT ({conditions[0]})"
        else:
            condition = f"NOT ({' AND '.join(conditions)})"
    else:
        raise ValueError(f"Unknown operator: {metadata_filter.operator}")

    return condition, params
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||
"""Get node by its label identifier, return only node properties
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ from .base import (
|
|||
QueryParam,
|
||||
QueryResult,
|
||||
QueryContextResult,
|
||||
MetadataFilter
|
||||
)
|
||||
from .prompt import PROMPTS
|
||||
from .constants import (
|
||||
|
|
@ -3568,14 +3569,12 @@ async def _get_node_data(
|
|||
# Extract all entity IDs from your results list
|
||||
node_ids = [r["entity_name"] for r in results]
|
||||
|
||||
# HARDCODED QUERY FOR DEBUGGING
|
||||
filter_query = "WHERE (n.class) is not null AND n.class = 'bando' RETURN (n.entity_id) AS entity_id, n.class AS class_value"
|
||||
|
||||
# TODO update method to take in the metadata_filter dataclass
|
||||
node_kg_ids = []
|
||||
if hasattr(knowledge_graph_inst, "get_nodes_by_metadata_filter"):
|
||||
node_kg_ids = await asyncio.gather(
|
||||
knowledge_graph_inst.get_nodes_by_metadata_filter(filter_query)
|
||||
knowledge_graph_inst.get_nodes_by_metadata_filter(QueryParam.metadata_filter)
|
||||
)
|
||||
|
||||
filtered_node_ids = (
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ dependencies = [
|
|||
"future",
|
||||
"json_repair",
|
||||
"nano-vectordb",
|
||||
"neo4j>=5.28.2",
|
||||
"networkx",
|
||||
"numpy",
|
||||
"pandas>=2.0.0",
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue