feat (metadata): added metadata filter in query
Added metadata filter dataclass for serializing and deserializing complex filter to json dict, added node filtering based on metadata
This commit is contained in:
parent
7be24a3c60
commit
0c721fa7f1
6 changed files with 143 additions and 33 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -65,6 +65,8 @@ LightRAG.pdf
|
||||||
download_models_hf.py
|
download_models_hf.py
|
||||||
lightrag-dev/
|
lightrag-dev/
|
||||||
gui/
|
gui/
|
||||||
|
/md
|
||||||
|
/uv.lock
|
||||||
|
|
||||||
# unit-test files
|
# unit-test files
|
||||||
test_*
|
test_*
|
||||||
|
|
|
||||||
|
|
@ -4,10 +4,11 @@ This module contains all query-related routes for the LightRAG API.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import textwrap
|
||||||
from typing import Any, Dict, List, Literal, Optional
|
from typing import Any, Dict, List, Literal, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from lightrag.base import QueryParam
|
from lightrag.base import QueryParam, MetadataFilter
|
||||||
from lightrag.api.utils_api import get_combined_auth_dependency
|
from lightrag.api.utils_api import get_combined_auth_dependency
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
|
@ -22,8 +23,8 @@ class QueryRequest(BaseModel):
|
||||||
description="The query text",
|
description="The query text",
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata_filter: dict[str, str] | None = Field(
|
metadata_filter: MetadataFilter | None = Field(
|
||||||
default=None,
|
default=textwrap.dedent('{"operator": "AND","operands": [{"MetadataFilter": {}},"string"]},'),
|
||||||
description="Optional dictionary of metadata key-value pairs to filter nodes",
|
description="Optional dictionary of metadata key-value pairs to filter nodes",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -79,7 +80,7 @@ class QueryRequest(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
conversation_history: Optional[List[Dict[str, Any]]] = Field(
|
conversation_history: Optional[List[Dict[str, Any]]] = Field(
|
||||||
default=None,
|
default=[],
|
||||||
description="Stores past conversation history to maintain context. Format: [{'role': 'user/assistant', 'content': 'message'}].",
|
description="Stores past conversation history to maintain context. Format: [{'role': 'user/assistant', 'content': 'message'}].",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,7 @@ from typing import (
|
||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
AsyncIterator,
|
AsyncIterator,
|
||||||
|
Union,
|
||||||
)
|
)
|
||||||
from .utils import EmbeddingFunc
|
from .utils import EmbeddingFunc
|
||||||
from .types import KnowledgeGraph
|
from .types import KnowledgeGraph
|
||||||
|
|
@ -39,6 +40,46 @@ from .constants import (
|
||||||
load_dotenv(dotenv_path=".env", override=False)
|
load_dotenv(dotenv_path=".env", override=False)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MetadataFilter:
|
||||||
|
"""
|
||||||
|
Represents a logical expression for metadata filtering.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
operator: "AND", "OR", or "NOT"
|
||||||
|
operands: List of either simple key-value pairs or nested MetadataFilter objects
|
||||||
|
"""
|
||||||
|
operator: str
|
||||||
|
operands: List[Union[Dict[str, Any], 'MetadataFilter']] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.operands is None:
|
||||||
|
self.operands = []
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary representation."""
|
||||||
|
return {
|
||||||
|
"operator": self.operator,
|
||||||
|
"operands": [
|
||||||
|
operand.to_dict() if isinstance(operand, MetadataFilter) else operand
|
||||||
|
for operand in self.operands
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> 'MetadataFilter':
|
||||||
|
"""Create from dictionary representation."""
|
||||||
|
operands = []
|
||||||
|
for operand in data.get("operands", []):
|
||||||
|
if isinstance(operand, dict) and "operator" in operand:
|
||||||
|
operands.append(cls.from_dict(operand))
|
||||||
|
else:
|
||||||
|
operands.append(operand)
|
||||||
|
return cls(operator=data.get("operator", "AND"), operands=operands)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class OllamaServerInfos:
|
class OllamaServerInfos:
|
||||||
def __init__(self, name=None, tag=None):
|
def __init__(self, name=None, tag=None):
|
||||||
self._lightrag_name = name or os.getenv(
|
self._lightrag_name = name or os.getenv(
|
||||||
|
|
@ -162,15 +203,15 @@ class QueryParam:
|
||||||
Default is True to enable reranking when rerank model is available.
|
Default is True to enable reranking when rerank model is available.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
metadata_filter: MetadataFilter | None = None
|
||||||
|
"""Metadata for filtering nodes and edges, allowing for more precise querying."""
|
||||||
|
|
||||||
include_references: bool = False
|
include_references: bool = False
|
||||||
"""If True, includes reference list in the response for supported endpoints.
|
"""If True, includes reference list in the response for supported endpoints.
|
||||||
This parameter controls whether the API response includes a references field
|
This parameter controls whether the API response includes a references field
|
||||||
containing citation information for the retrieved content.
|
containing citation information for the retrieved content.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
metadata_filter: dict | None = None
|
|
||||||
"""Metadata for filtering nodes and edges, allowing for more precise querying."""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class StorageNameSpace(ABC):
|
class StorageNameSpace(ABC):
|
||||||
|
|
@ -447,11 +488,11 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
||||||
or None if the node doesn't exist
|
or None if the node doesn't exist
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def get_nodes_by_metadata_filter(self, metadata_filter: str) -> list[str]:
|
async def get_nodes_by_metadata_filter(self, metadata_filter: MetadataFilter | None) -> list[str]:
|
||||||
result = []
|
"""Get node IDs that match the given metadata filter with logical expressions."""
|
||||||
text_query = f"MATCH (n:`{self.workspace_label}` {metadata_filter})"
|
# Default implementation - subclasses should override this method
|
||||||
result = await self.query(text_query)
|
# This is a placeholder that will be overridden by specific implementations
|
||||||
return result
|
raise NotImplementedError("Subclasses must implement get_nodes_by_metadata_filter")
|
||||||
|
|
||||||
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
||||||
"""Get nodes as a batch using UNWIND
|
"""Get nodes as a batch using UNWIND
|
||||||
|
|
|
||||||
|
|
@ -2,10 +2,11 @@ import os
|
||||||
import re
|
import re
|
||||||
import json
|
import json
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import final
|
from typing import final, Any
|
||||||
import configparser
|
import configparser
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
retry,
|
retry,
|
||||||
stop_after_attempt,
|
stop_after_attempt,
|
||||||
|
|
@ -15,7 +16,7 @@ from tenacity import (
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from ..utils import logger
|
from ..utils import logger
|
||||||
from ..base import BaseGraphStorage
|
from ..base import BaseGraphStorage, MetadataFilter
|
||||||
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||||
from ..constants import GRAPH_FIELD_SEP
|
from ..constants import GRAPH_FIELD_SEP
|
||||||
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
|
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
|
||||||
|
|
@ -425,23 +426,88 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
await result.consume() # Ensure results are consumed even on error
|
await result.consume() # Ensure results are consumed even on error
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_nodes_by_metadata_filter(self, query: str) -> list[str]:
|
async def get_nodes_by_metadata_filter(self, metadata_filter: MetadataFilter | None) -> list[str]:
|
||||||
"""Get nodes by filtering query, return nodes id"""
|
"""Get node IDs that match the given metadata filter with logical expressions."""
|
||||||
|
|
||||||
workspace_label = self._get_workspace_label()
|
workspace_label = self._get_workspace_label()
|
||||||
async with self._driver.session(
|
|
||||||
database=self._DATABASE, default_access_mode="READ"
|
# Build metadata conditions
|
||||||
) as session:
|
params = {}
|
||||||
try:
|
condition, params = self._build_metadata_conditions(metadata_filter, params)
|
||||||
query = f"MATCH (n:`{workspace_label}`) {query}"
|
|
||||||
result = await session.run(query)
|
if not condition:
|
||||||
debug_results = [record async for record in result]
|
# If no condition, return empty list for safety
|
||||||
debug_ids = [r["entity_id"] for r in debug_results]
|
return []
|
||||||
await result.consume()
|
|
||||||
return debug_ids
|
# Build the query
|
||||||
except Exception as e:
|
query = f"""
|
||||||
logger.error(f"[{self.workspace}] Error getting node for: {str(e)}")
|
MATCH (n:`{workspace_label}`)
|
||||||
raise
|
WHERE {condition}
|
||||||
|
RETURN n.entity_id AS entity_id
|
||||||
|
"""
|
||||||
|
|
||||||
|
async with self._driver.session(database=self._DATABASE) as session:
|
||||||
|
result = await session.run(query, params)
|
||||||
|
return [record["entity_id"] async for record in result]
|
||||||
|
|
||||||
|
def _build_metadata_conditions(
|
||||||
|
self,
|
||||||
|
metadata_filter: MetadataFilter | None,
|
||||||
|
params: dict[str, Any],
|
||||||
|
node_var: str = "n"
|
||||||
|
) -> tuple[str, dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Build Cypher WHERE conditions from a MetadataFilter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metadata_filter: The MetadataFilter object
|
||||||
|
params: Dictionary to collect parameters for the query
|
||||||
|
node_var: The variable name for the node in the Cypher query
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (condition_string, updated_params)
|
||||||
|
"""
|
||||||
|
if metadata_filter is None:
|
||||||
|
return "", params
|
||||||
|
|
||||||
|
conditions = []
|
||||||
|
|
||||||
|
for operand in metadata_filter.operands:
|
||||||
|
if isinstance(operand, MetadataFilter):
|
||||||
|
# Recursive call for nested filters
|
||||||
|
sub_condition, params = self._build_metadata_conditions(operand, params, node_var)
|
||||||
|
if sub_condition:
|
||||||
|
conditions.append(f"({sub_condition})")
|
||||||
|
else:
|
||||||
|
# Simple key-value pair
|
||||||
|
for key, value in operand.items():
|
||||||
|
prop_name = f"meta_{key}" # Using our prefix
|
||||||
|
param_name = f"{prop_name}_{len(params)}"
|
||||||
|
|
||||||
|
if value is None:
|
||||||
|
# Check for existence of the key
|
||||||
|
conditions.append(f"{node_var}.{prop_name} IS NOT NULL")
|
||||||
|
else:
|
||||||
|
# Check for specific value
|
||||||
|
conditions.append(f"{node_var}.{prop_name} = ${param_name}")
|
||||||
|
params[param_name] = value
|
||||||
|
|
||||||
|
if not conditions:
|
||||||
|
return "", params
|
||||||
|
|
||||||
|
# Join conditions with the operator
|
||||||
|
if metadata_filter.operator == "AND":
|
||||||
|
condition = " AND ".join(conditions)
|
||||||
|
elif metadata_filter.operator == "OR":
|
||||||
|
condition = " OR ".join(conditions)
|
||||||
|
elif metadata_filter.operator == "NOT":
|
||||||
|
if len(conditions) == 1:
|
||||||
|
condition = f"NOT ({conditions[0]})"
|
||||||
|
else:
|
||||||
|
condition = f"NOT ({' AND '.join(conditions)})"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown operator: {metadata_filter.operator}")
|
||||||
|
|
||||||
|
return condition, params
|
||||||
|
|
||||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||||
"""Get node by its label identifier, return only node properties
|
"""Get node by its label identifier, return only node properties
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,7 @@ from .base import (
|
||||||
QueryParam,
|
QueryParam,
|
||||||
QueryResult,
|
QueryResult,
|
||||||
QueryContextResult,
|
QueryContextResult,
|
||||||
|
MetadataFilter
|
||||||
)
|
)
|
||||||
from .prompt import PROMPTS
|
from .prompt import PROMPTS
|
||||||
from .constants import (
|
from .constants import (
|
||||||
|
|
@ -3568,14 +3569,12 @@ async def _get_node_data(
|
||||||
# Extract all entity IDs from your results list
|
# Extract all entity IDs from your results list
|
||||||
node_ids = [r["entity_name"] for r in results]
|
node_ids = [r["entity_name"] for r in results]
|
||||||
|
|
||||||
# HARDCODED QUERY FOR DEBUGGING
|
|
||||||
filter_query = "WHERE (n.class) is not null AND n.class = 'bando' RETURN (n.entity_id) AS entity_id, n.class AS class_value"
|
|
||||||
|
|
||||||
# TODO update method to take in the metadata_filter dataclass
|
# TODO update method to take in the metadata_filter dataclass
|
||||||
node_kg_ids = []
|
node_kg_ids = []
|
||||||
if hasattr(knowledge_graph_inst, "get_nodes_by_metadata_filter"):
|
if hasattr(knowledge_graph_inst, "get_nodes_by_metadata_filter"):
|
||||||
node_kg_ids = await asyncio.gather(
|
node_kg_ids = await asyncio.gather(
|
||||||
knowledge_graph_inst.get_nodes_by_metadata_filter(filter_query)
|
knowledge_graph_inst.get_nodes_by_metadata_filter(QueryParam.metadata_filter)
|
||||||
)
|
)
|
||||||
|
|
||||||
filtered_node_ids = (
|
filtered_node_ids = (
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@ dependencies = [
|
||||||
"future",
|
"future",
|
||||||
"json_repair",
|
"json_repair",
|
||||||
"nano-vectordb",
|
"nano-vectordb",
|
||||||
|
"neo4j>=5.28.2",
|
||||||
"networkx",
|
"networkx",
|
||||||
"numpy",
|
"numpy",
|
||||||
"pandas>=2.0.0",
|
"pandas>=2.0.0",
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue