diff --git a/.gitignore b/.gitignore
index cb9f3049..d1857e85 100644
--- a/.gitignore
+++ b/.gitignore
@@ -65,6 +65,8 @@ LightRAG.pdf
 download_models_hf.py
 lightrag-dev/
 gui/
+/md
+/uv.lock
 
 # unit-test files
 test_*
diff --git a/lightrag/api/routers/query_routes.py b/lightrag/api/routers/query_routes.py
index 28bed617..e99294b5 100644
--- a/lightrag/api/routers/query_routes.py
+++ b/lightrag/api/routers/query_routes.py
@@ -4,10 +4,10 @@ This module contains all query-related routes for the LightRAG API.
 
 import json
 import logging
 from typing import Any, Dict, List, Literal, Optional
 
 from fastapi import APIRouter, Depends, HTTPException
-from lightrag.base import QueryParam
+from lightrag.base import QueryParam, MetadataFilter
 from lightrag.api.utils_api import get_combined_auth_dependency
 from pydantic import BaseModel, Field, field_validator
 
@@ -22,8 +22,8 @@ class QueryRequest(BaseModel):
         description="The query text",
     )
 
-    metadata_filter: dict[str, str] | None = Field(
-        default=None,
+    metadata_filter: MetadataFilter | None = Field(
+        default=None,
         description="Optional logical filter expression of metadata key-value pairs to filter nodes",
     )
 
diff --git a/lightrag/base.py b/lightrag/base.py
index 51d581be..e1c808a3 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -15,6 +15,7 @@ from typing import (
     Dict,
     List,
     AsyncIterator,
+    Union,
 )
 from .utils import EmbeddingFunc
 from .types import KnowledgeGraph
@@ -39,6 +40,46 @@ from .constants import (
 load_dotenv(dotenv_path=".env", override=False)
 
 
+@dataclass
+class MetadataFilter:
+    """
+    Represents a logical expression for metadata filtering.
+
+    Args:
+        operator: "AND", "OR", or "NOT"
+        operands: List of either simple key-value pairs or nested MetadataFilter objects
+    """
+
+    operator: str
+    operands: List[Union[Dict[str, Any], "MetadataFilter"]] = None
+
+    def __post_init__(self):
+        # Normalize the default: dataclasses cannot take a mutable default directly
+        if self.operands is None:
+            self.operands = []
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dictionary representation."""
+        return {
+            "operator": self.operator,
+            "operands": [
+                operand.to_dict() if isinstance(operand, MetadataFilter) else operand
+                for operand in self.operands
+            ],
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "MetadataFilter":
+        """Create from dictionary representation."""
+        operands = []
+        for operand in data.get("operands", []):
+            # A dict carrying an "operator" key is a nested filter expression
+            if isinstance(operand, dict) and "operator" in operand:
+                operands.append(cls.from_dict(operand))
+            else:
+                operands.append(operand)
+        return cls(operator=data.get("operator", "AND"), operands=operands)
+
+
 class OllamaServerInfos:
     def __init__(self, name=None, tag=None):
         self._lightrag_name = name or os.getenv(
@@ -162,15 +203,15 @@ class QueryParam:
     Default is True to enable reranking when rerank model is available.
     """
 
+    metadata_filter: MetadataFilter | None = None
+    """Metadata for filtering nodes and edges, allowing for more precise querying."""
+
     include_references: bool = False
     """If True, includes reference list in the response for supported endpoints.
 
     This parameter controls whether the API response includes a references field
     containing citation information for the retrieved content.
     """
 
-    metadata_filter: dict | None = None
-    """Metadata for filtering nodes and edges, allowing for more precise querying."""
-
 
 @dataclass
 class StorageNameSpace(ABC):
@@ -447,11 +488,11 @@ class BaseGraphStorage(StorageNameSpace, ABC):
             or None if the node doesn't exist
         """
 
-    async def get_nodes_by_metadata_filter(self, metadata_filter: str) -> list[str]:
-        result = []
-        text_query = f"MATCH (n:`{self.workspace_label}` {metadata_filter})"
-        result = await self.query(text_query)
-        return result
+    async def get_nodes_by_metadata_filter(self, metadata_filter: MetadataFilter | None) -> list[str]:
+        """Get node IDs that match the given metadata filter with logical expressions."""
+        # Default implementation - subclasses should override this method
+        # This is a placeholder that will be overridden by specific implementations
+        raise NotImplementedError("Subclasses must implement get_nodes_by_metadata_filter")
 
     async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
         """Get nodes as a batch using UNWIND
diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index 0996f4bc..fb727309 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -2,10 +2,11 @@ import os
 import re
 import json
 from dataclasses import dataclass
-from typing import final
+from typing import final, Any
 import configparser
 
+
 from tenacity import (
     retry,
     stop_after_attempt,
@@ -15,7 +16,7 @@ from tenacity import (
 import logging
 
 from ..utils import logger
-from ..base import BaseGraphStorage
+from ..base import BaseGraphStorage, MetadataFilter
 from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
 from ..constants import GRAPH_FIELD_SEP
 from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
@@ -425,23 +426,90 @@ class Neo4JStorage(BaseGraphStorage):
                 await result.consume()  # Ensure results are consumed even on error
                 raise
 
-    async def get_nodes_by_metadata_filter(self, query: str) -> list[str]:
-        """Get nodes by filtering query, return nodes id"""
-
+    async def get_nodes_by_metadata_filter(self, metadata_filter: MetadataFilter | None) -> list[str]:
+        """Get node IDs that match the given metadata filter with logical expressions."""
         workspace_label = self._get_workspace_label()
-        async with self._driver.session(
-            database=self._DATABASE, default_access_mode="READ"
-        ) as session:
-            try:
-                query = f"MATCH (n:`{workspace_label}`) {query}"
-                result = await session.run(query)
-                debug_results = [record async for record in result]
-                debug_ids = [r["entity_id"] for r in debug_results]
-                await result.consume()
-                return debug_ids
-            except Exception as e:
-                logger.error(f"[{self.workspace}] Error getting node for: {str(e)}")
-                raise
+
+        # Build metadata conditions
+        params = {}
+        condition, params = self._build_metadata_conditions(metadata_filter, params)
+
+        if not condition:
+            # If no condition, return empty list for safety
+            return []
+
+        # Build the query
+        query = f"""
+        MATCH (n:`{workspace_label}`)
+        WHERE {condition}
+        RETURN n.entity_id AS entity_id
+        """
+
+        async with self._driver.session(database=self._DATABASE) as session:
+            result = await session.run(query, params)
+            return [record["entity_id"] async for record in result]
+
+    def _build_metadata_conditions(
+        self,
+        metadata_filter: MetadataFilter | None,
+        params: dict[str, Any],
+        node_var: str = "n"
+    ) -> tuple[str, dict[str, Any]]:
+        """
+        Build Cypher WHERE conditions from a MetadataFilter.
+
+        Args:
+            metadata_filter: The MetadataFilter object
+            params: Dictionary to collect parameters for the query
+            node_var: The variable name for the node in the Cypher query
+
+        Returns:
+            Tuple of (condition_string, updated_params)
+        """
+        if metadata_filter is None:
+            return "", params
+
+        conditions = []
+
+        for operand in metadata_filter.operands:
+            if isinstance(operand, MetadataFilter):
+                # Recursive call for nested filters
+                sub_condition, params = self._build_metadata_conditions(operand, params, node_var)
+                if sub_condition:
+                    conditions.append(f"({sub_condition})")
+            else:
+                # Simple key-value pair
+                for key, value in operand.items():
+                    # Property names cannot be parameterized in Cypher, so reject
+                    # any key that could be used for injection.
+                    if not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", key):
+                        raise ValueError(f"Invalid metadata filter key: {key}")
+                    prop_name = f"meta_{key}"  # Using our prefix
+                    param_name = f"{prop_name}_{len(params)}"
+
+                    if value is None:
+                        # Check for existence of the key
+                        conditions.append(f"{node_var}.{prop_name} IS NOT NULL")
+                    else:
+                        # Check for specific value
+                        conditions.append(f"{node_var}.{prop_name} = ${param_name}")
+                        params[param_name] = value
+
+        if not conditions:
+            return "", params
+
+        # Join conditions with the operator
+        if metadata_filter.operator == "AND":
+            condition = " AND ".join(conditions)
+        elif metadata_filter.operator == "OR":
+            condition = " OR ".join(conditions)
+        elif metadata_filter.operator == "NOT":
+            if len(conditions) == 1:
+                condition = f"NOT ({conditions[0]})"
+            else:
+                condition = f"NOT ({' AND '.join(conditions)})"
+        else:
+            raise ValueError(f"Unknown operator: {metadata_filter.operator}")
+
+        return condition, params
 
     async def get_node(self, node_id: str) -> dict[str, str] | None:
         """Get node by its label identifier, return only node properties
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 04e2da85..b62f2135 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -41,6 +41,7 @@ from .base import (
     QueryParam,
     QueryResult,
     QueryContextResult,
+    MetadataFilter,
 )
 from .prompt import PROMPTS
 from .constants import (
@@ -3568,14 +3569,12 @@ async def _get_node_data(
     # Extract all entity IDs from your results list
     node_ids = [r["entity_name"] for r in results]
 
-    # HARDCODED QUERY FOR DEBUGGING
-    filter_query = "WHERE (n.class) is not null AND n.class = 'bando' RETURN (n.entity_id) AS entity_id, n.class AS class_value"
     # TODO update method to take in the metadata_filter dataclass
     node_kg_ids = []
     if hasattr(knowledge_graph_inst, "get_nodes_by_metadata_filter"):
         node_kg_ids = await asyncio.gather(
-            knowledge_graph_inst.get_nodes_by_metadata_filter(filter_query)
+            knowledge_graph_inst.get_nodes_by_metadata_filter(query_param.metadata_filter)
         )
 
     filtered_node_ids = (
diff --git a/pyproject.toml b/pyproject.toml
index e850ce2c..7e8d70f4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,6 +27,7 @@ dependencies = [
     "future",
     "json_repair",
     "nano-vectordb",
+    "neo4j>=5.28.2",
     "networkx",
     "numpy",
     "pandas>=2.0.0",