feat (metadata filter): added metadata filtering
Added working (but not yet fully tested) metadata filtering on chunks at query time. It is fully implemented only for Postgres (with pgvector) and Neo4j.
This commit is contained in:
parent
40afb0441a
commit
d0fba28e1f
5 changed files with 263 additions and 90 deletions
|
|
@ -4,11 +4,11 @@ This module contains all query-related routes for the LightRAG API.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import textwrap
|
|
||||||
from typing import Any, Dict, List, Literal, Optional
|
from typing import Any, Dict, List, Literal, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from lightrag.base import QueryParam, MetadataFilter
|
from lightrag.base import QueryParam
|
||||||
|
from lightrag.types import MetadataFilter
|
||||||
from lightrag.api.utils_api import get_combined_auth_dependency
|
from lightrag.api.utils_api import get_combined_auth_dependency
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
|
@ -24,8 +24,8 @@ class QueryRequest(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata_filter: MetadataFilter | None = Field(
|
metadata_filter: MetadataFilter | None = Field(
|
||||||
default=textwrap.dedent('{"operator": "AND","operands": [{"MetadataFilter": {}},"string"]},'),
|
default=None,
|
||||||
description="Optional dictionary of metadata key-value pairs to filter nodes",
|
description="Optional metadata filter for nodes and edges. Can be a MetadataFilter object or a dict that will be converted to MetadataFilter.",
|
||||||
)
|
)
|
||||||
|
|
||||||
mode: Literal["local", "global", "hybrid", "naive", "mix", "bypass"] = Field(
|
mode: Literal["local", "global", "hybrid", "naive", "mix", "bypass"] = Field(
|
||||||
|
|
@ -118,6 +118,16 @@ class QueryRequest(BaseModel):
|
||||||
)
|
)
|
||||||
return conversation_history
|
return conversation_history
|
||||||
|
|
||||||
|
@field_validator("metadata_filter", mode="before")
@classmethod
def metadata_filter_convert(cls, v):
    """Coerce a raw dict payload into a MetadataFilter before validation.

    Args:
        v: The incoming ``metadata_filter`` value (``None``, a dict, or
            anything else supplied by the caller).

    Returns:
        ``MetadataFilter.from_dict(v)`` when ``v`` is a dict; otherwise ``v``
        unchanged (which includes ``None``).
    """
    # None is not a dict, so it falls through and is returned as-is.
    return MetadataFilter.from_dict(v) if isinstance(v, dict) else v
|
||||||
|
|
||||||
def to_query_params(self, is_stream: bool) -> "QueryParam":
|
def to_query_params(self, is_stream: bool) -> "QueryParam":
|
||||||
"""Converts a QueryRequest instance into a QueryParam instance."""
|
"""Converts a QueryRequest instance into a QueryParam instance."""
|
||||||
# Use Pydantic's `.model_dump(exclude_none=True)` to remove None values automatically
|
# Use Pydantic's `.model_dump(exclude_none=True)` to remove None values automatically
|
||||||
|
|
@ -126,6 +136,11 @@ class QueryRequest(BaseModel):
|
||||||
# Ensure `mode` and `stream` are set explicitly
|
# Ensure `mode` and `stream` are set explicitly
|
||||||
param = QueryParam(**request_data)
|
param = QueryParam(**request_data)
|
||||||
param.stream = is_stream
|
param.stream = is_stream
|
||||||
|
|
||||||
|
# Ensure metadata_filter remains as MetadataFilter object if it exists
|
||||||
|
if self.metadata_filter:
|
||||||
|
param.metadata_filter = self.metadata_filter
|
||||||
|
|
||||||
return param
|
return param
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -175,10 +190,6 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
||||||
try:
|
try:
|
||||||
param = request.to_query_params(False)
|
param = request.to_query_params(False)
|
||||||
|
|
||||||
# Inject metadata_filter into param if present
|
|
||||||
if request.metadata_filter:
|
|
||||||
setattr(param, "metadata_filter", request.metadata_filter)
|
|
||||||
|
|
||||||
response = await rag.aquery(request.query, param=param)
|
response = await rag.aquery(request.query, param=param)
|
||||||
|
|
||||||
# Get reference list if requested
|
# Get reference list if requested
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ from typing import (
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
from .utils import EmbeddingFunc
|
from .utils import EmbeddingFunc
|
||||||
from .types import KnowledgeGraph
|
from .types import KnowledgeGraph, MetadataFilter
|
||||||
from .constants import (
|
from .constants import (
|
||||||
GRAPH_FIELD_SEP,
|
GRAPH_FIELD_SEP,
|
||||||
DEFAULT_TOP_K,
|
DEFAULT_TOP_K,
|
||||||
|
|
@ -40,42 +40,7 @@ from .constants import (
|
||||||
load_dotenv(dotenv_path=".env", override=False)
|
load_dotenv(dotenv_path=".env", override=False)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MetadataFilter:
|
|
||||||
"""
|
|
||||||
Represents a logical expression for metadata filtering.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
operator: "AND", "OR", or "NOT"
|
|
||||||
operands: List of either simple key-value pairs or nested MetadataFilter objects
|
|
||||||
"""
|
|
||||||
operator: str
|
|
||||||
operands: List[Union[Dict[str, Any], 'MetadataFilter']] = None
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if self.operands is None:
|
|
||||||
self.operands = []
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
|
||||||
"""Convert to dictionary representation."""
|
|
||||||
return {
|
|
||||||
"operator": self.operator,
|
|
||||||
"operands": [
|
|
||||||
operand.to_dict() if isinstance(operand, MetadataFilter) else operand
|
|
||||||
for operand in self.operands
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, data: Dict[str, Any]) -> 'MetadataFilter':
|
|
||||||
"""Create from dictionary representation."""
|
|
||||||
operands = []
|
|
||||||
for operand in data.get("operands", []):
|
|
||||||
if isinstance(operand, dict) and "operator" in operand:
|
|
||||||
operands.append(cls.from_dict(operand))
|
|
||||||
else:
|
|
||||||
operands.append(operand)
|
|
||||||
return cls(operator=data.get("operator", "AND"), operands=operands)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -266,7 +231,7 @@ class BaseVectorStorage(StorageNameSpace, ABC):
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def query(
|
async def query(
|
||||||
self, query: str, top_k: int, query_embedding: list[float] = None
|
self, query: str, top_k: int, query_embedding: list[float] = None, metadata_filter: MetadataFilter | None = None
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""Query the vector storage and retrieve top_k results.
|
"""Query the vector storage and retrieve top_k results.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ import configparser
|
||||||
import ssl
|
import ssl
|
||||||
import itertools
|
import itertools
|
||||||
|
|
||||||
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge, MetadataFilter
|
||||||
|
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
retry,
|
retry,
|
||||||
|
|
@ -931,6 +931,13 @@ class PostgreSQLDB:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"PostgreSQL, Failed to create vector index, type: {self.vector_index_type}, Got: {e}"
|
f"PostgreSQL, Failed to create vector index, type: {self.vector_index_type}, Got: {e}"
|
||||||
)
|
)
|
||||||
|
# Compatibility check - add metadata columns to LIGHTRAG_DOC_CHUNKS and LIGHTRAG_VDB_CHUNKS
|
||||||
|
try:
|
||||||
|
await self.add_metadata_to_tables()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"PostgreSQL, Failed to add metadata columns to existing tables: {e}")
|
||||||
|
|
||||||
|
|
||||||
# After all tables are created, attempt to migrate timestamp fields
|
# After all tables are created, attempt to migrate timestamp fields
|
||||||
try:
|
try:
|
||||||
await self._migrate_timestamp_columns()
|
await self._migrate_timestamp_columns()
|
||||||
|
|
@ -1010,6 +1017,54 @@ class PostgreSQLDB:
|
||||||
f"PostgreSQL, Failed to create full entities/relations tables: {e}"
|
f"PostgreSQL, Failed to create full entities/relations tables: {e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def add_metadata_to_tables(self):
    """Ensure the chunk tables carry a ``metadata`` JSONB column.

    For LIGHTRAG_DOC_CHUNKS and LIGHTRAG_VDB_CHUNKS, look the column up in
    ``information_schema.columns`` and issue an ``ALTER TABLE ... ADD COLUMN``
    when the lookup returns nothing. Each table is handled inside its own
    try/except, so a failure on one table is logged as a warning and the
    other table is still processed.
    """
    column_lookup_sql = """
        SELECT column_name
        FROM information_schema.columns
        WHERE table_name = $1
        AND column_name = 'metadata'
        """

    for table_name in ("LIGHTRAG_DOC_CHUNKS", "LIGHTRAG_VDB_CHUNKS"):
        try:
            # The lookup is done with the lower-cased table name.
            existing_column = await self.query(
                column_lookup_sql, {"table_name": table_name.lower()}
            )

            if existing_column:
                logger.debug(
                    f"metadata column already exists in {table_name} table"
                )
                continue

            logger.info(f"Adding metadata column to {table_name} table")
            alter_table_sql = f"""
                ALTER TABLE {table_name}
                ADD COLUMN metadata JSONB NULL DEFAULT '{{}}'::jsonb
                """
            await self.execute(alter_table_sql)
            logger.info(
                f"Successfully added metadata column to {table_name} table"
            )
        except Exception as e:
            # Best-effort compatibility step: report and move on to the next table.
            logger.warning(
                f"Failed to add metadata column to {table_name}: {e}"
            )
|
||||||
|
|
||||||
async def _migrate_create_full_entities_relations_tables(self):
|
async def _migrate_create_full_entities_relations_tables(self):
|
||||||
"""Create LIGHTRAG_FULL_ENTITIES and LIGHTRAG_FULL_RELATIONS tables if they don't exist"""
|
"""Create LIGHTRAG_FULL_ENTITIES and LIGHTRAG_FULL_RELATIONS tables if they don't exist"""
|
||||||
tables_to_check = [
|
tables_to_check = [
|
||||||
|
|
@ -1742,6 +1797,7 @@ class PGKVStorage(BaseKVStorage):
|
||||||
"llm_cache_list": json.dumps(v.get("llm_cache_list", [])),
|
"llm_cache_list": json.dumps(v.get("llm_cache_list", [])),
|
||||||
"create_time": current_time,
|
"create_time": current_time,
|
||||||
"update_time": current_time,
|
"update_time": current_time,
|
||||||
|
"metadata": json.dumps(v.get("metadata", {})),
|
||||||
}
|
}
|
||||||
await self.db.execute(upsert_sql, _data)
|
await self.db.execute(upsert_sql, _data)
|
||||||
elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
|
elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
|
||||||
|
|
@ -1908,6 +1964,7 @@ class PGVectorStorage(BaseVectorStorage):
|
||||||
"file_path": item["file_path"],
|
"file_path": item["file_path"],
|
||||||
"create_time": current_time,
|
"create_time": current_time,
|
||||||
"update_time": current_time,
|
"update_time": current_time,
|
||||||
|
"metadata": json.dumps(item.get("metadata", {})),
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
|
|
@ -2002,9 +2059,71 @@ class PGVectorStorage(BaseVectorStorage):
|
||||||
|
|
||||||
await self.db.execute(upsert_sql, data)
|
await self.db.execute(upsert_sql, data)
|
||||||
|
|
||||||
|
############# Metadata building function #################
|
||||||
|
@staticmethod
def build_metadata_filter_clause(metadata_filter):
    """Translate a metadata filter into a SQL fragment for a WHERE clause.

    Args:
        metadata_filter: ``None``, a plain ``{key: value}`` dict, or a filter
            object exposing ``operator`` ("AND" / "OR" / "NOT",
            case-insensitive) and ``operands`` (dicts and/or nested filter
            objects). Filter objects are recognised by those two attributes
            rather than by class, so any compatible implementation works.

    Returns:
        An empty string when there is nothing to filter on; otherwise a
        string of the form ``" AND (<conditions>)"`` that can be appended to
        an existing WHERE clause on a table with a JSONB ``metadata`` column.

    Notes:
        * The whole expression is wrapped in parentheses so that a top-level
          OR cannot bypass the conditions that precede it in the WHERE clause.
        * Keys and values are embedded as SQL string literals with single
          quotes doubled; both are escaped because both come from the caller.
        * A list/tuple value means "any of these"; an empty one matches
          nothing (``FALSE``). An empty dict operand adds no constraint.
        * A ``None`` value matches rows where the key is missing or JSON null.
        * Operands that are neither dicts nor filter objects are ignored, and
          an unrecognised operator yields no conditions (unchanged behaviour).
    """

    def quote(text) -> str:
        # Double single quotes so the text is safe inside a SQL string literal.
        return str(text).replace("'", "''")

    def single_condition(key, value):
        field = quote(key)
        if value is None:
            # ->> yields SQL NULL for both a missing key and a JSON null.
            return f"metadata->>'{field}' IS NULL"
        if isinstance(value, (dict, list, bool, int, float)):
            # Structured and numeric/boolean values are compared as JSONB.
            return f"metadata->'{field}' = '{quote(json.dumps(value))}'::jsonb"
        # Everything else is compared as text.
        return f"metadata->>'{field}' = '{quote(value)}'"

    def dict_conditions(filter_dict):
        conditions = []
        for key, value in filter_dict.items():
            if isinstance(value, (list, tuple)):
                if not value:
                    # OR over zero alternatives is false; avoids emitting "()".
                    conditions.append("FALSE")
                    continue
                alternatives = [single_condition(key, item) for item in value]
                conditions.append("(" + " OR ".join(alternatives) + ")")
            else:
                conditions.append(single_condition(key, value))
        return conditions

    def recurse(node):
        if isinstance(node, dict):
            return dict_conditions(node)

        operator = getattr(node, "operator", None)
        operands = getattr(node, "operands", None)
        if operator is None or operands is None:
            # Not a dict and not a filter object: nothing to contribute.
            return []

        groups = []
        for operand in operands:
            parts = recurse(operand)
            if parts:  # skip empty operands instead of emitting "()"
                groups.append("(" + " AND ".join(parts) + ")")
        if not groups:
            return []

        op = str(operator).upper()
        if op == "AND":
            return [" AND ".join(groups)]
        if op == "OR":
            return [" OR ".join(groups)]
        if op == "NOT":
            if len(groups) == 1:
                return [f"NOT {groups[0]}"]
            # NOT over several operands negates their conjunction.
            return [f"NOT ({' AND '.join(groups)})"]
        return []

    if metadata_filter is None:
        return ""

    conditions = recurse(metadata_filter)
    if not conditions:
        return ""
    return " AND (" + " AND ".join(conditions) + ")"
|
||||||
|
|
||||||
#################### query method ###############
|
#################### query method ###############
|
||||||
async def query(
|
async def query(
|
||||||
self, query: str, top_k: int, query_embedding: list[float] = None
|
self, query: str, top_k: int, query_embedding: list[float] = None, metadata_filter: MetadataFilter | None = None
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
if query_embedding is not None:
|
if query_embedding is not None:
|
||||||
embedding = query_embedding
|
embedding = query_embedding
|
||||||
|
|
@ -2015,8 +2134,8 @@ class PGVectorStorage(BaseVectorStorage):
|
||||||
embedding = embeddings[0]
|
embedding = embeddings[0]
|
||||||
|
|
||||||
embedding_string = ",".join(map(str, embedding))
|
embedding_string = ",".join(map(str, embedding))
|
||||||
|
metadata_filter_clause = self.build_metadata_filter_clause(metadata_filter)
|
||||||
sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
|
sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string, metadata_filter_clause=metadata_filter_clause)
|
||||||
params = {
|
params = {
|
||||||
"workspace": self.workspace,
|
"workspace": self.workspace,
|
||||||
"closer_than_threshold": 1 - self.cosine_better_than_threshold,
|
"closer_than_threshold": 1 - self.cosine_better_than_threshold,
|
||||||
|
|
@ -4480,6 +4599,7 @@ TABLES = {
|
||||||
content TEXT,
|
content TEXT,
|
||||||
file_path TEXT NULL,
|
file_path TEXT NULL,
|
||||||
llm_cache_list JSONB NULL DEFAULT '[]'::jsonb,
|
llm_cache_list JSONB NULL DEFAULT '[]'::jsonb,
|
||||||
|
metadata JSONB NULL DEFAULT '{}'::jsonb,
|
||||||
create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
|
create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
|
||||||
update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
|
update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
|
||||||
CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id)
|
CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id)
|
||||||
|
|
@ -4495,6 +4615,7 @@ TABLES = {
|
||||||
content TEXT,
|
content TEXT,
|
||||||
content_vector VECTOR({os.environ.get("EMBEDDING_DIM", 1024)}),
|
content_vector VECTOR({os.environ.get("EMBEDDING_DIM", 1024)}),
|
||||||
file_path TEXT NULL,
|
file_path TEXT NULL,
|
||||||
|
metadata JSONB NULL DEFAULT '{{}}'::jsonb,
|
||||||
create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
|
create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
|
||||||
update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
|
update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
|
||||||
CONSTRAINT LIGHTRAG_VDB_CHUNKS_PK PRIMARY KEY (workspace, id)
|
CONSTRAINT LIGHTRAG_VDB_CHUNKS_PK PRIMARY KEY (workspace, id)
|
||||||
|
|
@ -4656,8 +4777,8 @@ SQL_TEMPLATES = {
|
||||||
""",
|
""",
|
||||||
"upsert_text_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens,
|
"upsert_text_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens,
|
||||||
chunk_order_index, full_doc_id, content, file_path, llm_cache_list,
|
chunk_order_index, full_doc_id, content, file_path, llm_cache_list,
|
||||||
create_time, update_time)
|
create_time, update_time, metadata)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||||
ON CONFLICT (workspace,id) DO UPDATE
|
ON CONFLICT (workspace,id) DO UPDATE
|
||||||
SET tokens=EXCLUDED.tokens,
|
SET tokens=EXCLUDED.tokens,
|
||||||
chunk_order_index=EXCLUDED.chunk_order_index,
|
chunk_order_index=EXCLUDED.chunk_order_index,
|
||||||
|
|
@ -4665,7 +4786,8 @@ SQL_TEMPLATES = {
|
||||||
content = EXCLUDED.content,
|
content = EXCLUDED.content,
|
||||||
file_path=EXCLUDED.file_path,
|
file_path=EXCLUDED.file_path,
|
||||||
llm_cache_list=EXCLUDED.llm_cache_list,
|
llm_cache_list=EXCLUDED.llm_cache_list,
|
||||||
update_time = EXCLUDED.update_time
|
update_time = EXCLUDED.update_time,
|
||||||
|
metadata = EXCLUDED.metadata
|
||||||
""",
|
""",
|
||||||
"upsert_full_entities": """INSERT INTO LIGHTRAG_FULL_ENTITIES (workspace, id, entity_names, count,
|
"upsert_full_entities": """INSERT INTO LIGHTRAG_FULL_ENTITIES (workspace, id, entity_names, count,
|
||||||
create_time, update_time)
|
create_time, update_time)
|
||||||
|
|
@ -4686,8 +4808,8 @@ SQL_TEMPLATES = {
|
||||||
# SQL for VectorStorage
|
# SQL for VectorStorage
|
||||||
"upsert_chunk": """INSERT INTO LIGHTRAG_VDB_CHUNKS (workspace, id, tokens,
|
"upsert_chunk": """INSERT INTO LIGHTRAG_VDB_CHUNKS (workspace, id, tokens,
|
||||||
chunk_order_index, full_doc_id, content, content_vector, file_path,
|
chunk_order_index, full_doc_id, content, content_vector, file_path,
|
||||||
create_time, update_time)
|
create_time, update_time, metadata)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||||
ON CONFLICT (workspace,id) DO UPDATE
|
ON CONFLICT (workspace,id) DO UPDATE
|
||||||
SET tokens=EXCLUDED.tokens,
|
SET tokens=EXCLUDED.tokens,
|
||||||
chunk_order_index=EXCLUDED.chunk_order_index,
|
chunk_order_index=EXCLUDED.chunk_order_index,
|
||||||
|
|
@ -4695,7 +4817,8 @@ SQL_TEMPLATES = {
|
||||||
content = EXCLUDED.content,
|
content = EXCLUDED.content,
|
||||||
content_vector=EXCLUDED.content_vector,
|
content_vector=EXCLUDED.content_vector,
|
||||||
file_path=EXCLUDED.file_path,
|
file_path=EXCLUDED.file_path,
|
||||||
update_time = EXCLUDED.update_time
|
update_time = EXCLUDED.update_time,
|
||||||
|
metadata = EXCLUDED.metadata
|
||||||
""",
|
""",
|
||||||
"upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
|
"upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
|
||||||
content_vector, chunk_ids, file_path, create_time, update_time)
|
content_vector, chunk_ids, file_path, create_time, update_time)
|
||||||
|
|
@ -4721,32 +4844,69 @@ SQL_TEMPLATES = {
|
||||||
update_time = EXCLUDED.update_time
|
update_time = EXCLUDED.update_time
|
||||||
""",
|
""",
|
||||||
"relationships": """
|
"relationships": """
|
||||||
|
WITH relevant_chunks AS (
|
||||||
|
SELECT id as chunk_id
|
||||||
|
FROM LIGHTRAG_VDB_CHUNKS
|
||||||
|
WHERE ($4::varchar[] IS NULL OR full_doc_id = ANY ($4::varchar[]))
|
||||||
|
{metadata_filter_clause}
|
||||||
|
),
|
||||||
|
rc AS (
|
||||||
|
SELECT array_agg(chunk_id) AS chunk_arr
|
||||||
|
FROM relevant_chunks
|
||||||
|
)
|
||||||
SELECT r.source_id AS src_id,
|
SELECT r.source_id AS src_id,
|
||||||
r.target_id AS tgt_id,
|
r.target_id AS tgt_id,
|
||||||
EXTRACT(EPOCH FROM r.create_time)::BIGINT AS created_at
|
EXTRACT(EPOCH FROM r.create_time)::BIGINT AS created_at
|
||||||
FROM LIGHTRAG_VDB_RELATION r
|
FROM LIGHTRAG_VDB_RELATION r
|
||||||
|
JOIN rc ON TRUE
|
||||||
WHERE r.workspace = $1
|
WHERE r.workspace = $1
|
||||||
AND r.content_vector <=> '[{embedding_string}]'::vector < $2
|
AND r.content_vector <=> '[{embedding_string}]'::vector < $2
|
||||||
|
AND r.chunk_ids && (rc.chunk_arr::varchar[])
|
||||||
ORDER BY r.content_vector <=> '[{embedding_string}]'::vector
|
ORDER BY r.content_vector <=> '[{embedding_string}]'::vector
|
||||||
LIMIT $3;
|
LIMIT $3;
|
||||||
""",
|
""",
|
||||||
"entities": """
|
"entities": """
|
||||||
|
WITH relevant_chunks AS (
|
||||||
|
SELECT id as chunk_id
|
||||||
|
FROM LIGHTRAG_VDB_CHUNKS
|
||||||
|
WHERE ($4::varchar[] IS NULL OR full_doc_id = ANY ($4::varchar[]))
|
||||||
|
{metadata_filter_clause}
|
||||||
|
),
|
||||||
|
rc AS (
|
||||||
|
SELECT array_agg(chunk_id) AS chunk_arr
|
||||||
|
FROM relevant_chunks
|
||||||
|
)
|
||||||
SELECT e.entity_name,
|
SELECT e.entity_name,
|
||||||
EXTRACT(EPOCH FROM e.create_time)::BIGINT AS created_at
|
EXTRACT(EPOCH FROM e.create_time)::BIGINT AS created_at
|
||||||
FROM LIGHTRAG_VDB_ENTITY e
|
FROM LIGHTRAG_VDB_ENTITY e
|
||||||
|
JOIN rc ON TRUE
|
||||||
WHERE e.workspace = $1
|
WHERE e.workspace = $1
|
||||||
AND e.content_vector <=> '[{embedding_string}]'::vector < $2
|
AND e.content_vector <=> '[{embedding_string}]'::vector < $2
|
||||||
|
AND e.chunk_ids && (rc.chunk_arr::varchar[])
|
||||||
ORDER BY e.content_vector <=> '[{embedding_string}]'::vector
|
ORDER BY e.content_vector <=> '[{embedding_string}]'::vector
|
||||||
LIMIT $3;
|
LIMIT $3;
|
||||||
""",
|
""",
|
||||||
"chunks": """
|
"chunks": """
|
||||||
|
WITH relevant_chunks AS (
|
||||||
|
SELECT id as chunk_id
|
||||||
|
FROM LIGHTRAG_VDB_CHUNKS
|
||||||
|
WHERE ($4::varchar[] IS NULL OR full_doc_id = ANY ($4::varchar[]))
|
||||||
|
{metadata_filter_clause}
|
||||||
|
),
|
||||||
|
rc AS (
|
||||||
|
SELECT array_agg(chunk_id) AS chunk_arr
|
||||||
|
FROM relevant_chunks
|
||||||
|
)
|
||||||
SELECT c.id,
|
SELECT c.id,
|
||||||
c.content,
|
c.content,
|
||||||
c.file_path,
|
c.file_path,
|
||||||
EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at
|
EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at,
|
||||||
|
c.metadata
|
||||||
FROM LIGHTRAG_VDB_CHUNKS c
|
FROM LIGHTRAG_VDB_CHUNKS c
|
||||||
|
JOIN rc ON TRUE
|
||||||
WHERE c.workspace = $1
|
WHERE c.workspace = $1
|
||||||
AND c.content_vector <=> '[{embedding_string}]'::vector < $2
|
AND c.content_vector <=> '[{embedding_string}]'::vector < $2
|
||||||
|
AND c.id = ANY (rc.chunk_arr)
|
||||||
ORDER BY c.content_vector <=> '[{embedding_string}]'::vector
|
ORDER BY c.content_vector <=> '[{embedding_string}]'::vector
|
||||||
LIMIT $3;
|
LIMIT $3;
|
||||||
""",
|
""",
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,6 @@ from .base import (
|
||||||
QueryParam,
|
QueryParam,
|
||||||
QueryResult,
|
QueryResult,
|
||||||
QueryContextResult,
|
QueryContextResult,
|
||||||
MetadataFilter
|
|
||||||
)
|
)
|
||||||
from .prompt import PROMPTS
|
from .prompt import PROMPTS
|
||||||
from .constants import (
|
from .constants import (
|
||||||
|
|
@ -2409,6 +2408,7 @@ async def kg_query(
|
||||||
query_param.ll_keywords or [],
|
query_param.ll_keywords or [],
|
||||||
query_param.user_prompt or "",
|
query_param.user_prompt or "",
|
||||||
query_param.enable_rerank,
|
query_param.enable_rerank,
|
||||||
|
query_param.metadata_filter,
|
||||||
)
|
)
|
||||||
cached_result = await handle_cache(
|
cached_result = await handle_cache(
|
||||||
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
|
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
|
||||||
|
|
@ -2720,7 +2720,7 @@ async def _get_vector_context(
|
||||||
cosine_threshold = chunks_vdb.cosine_better_than_threshold
|
cosine_threshold = chunks_vdb.cosine_better_than_threshold
|
||||||
|
|
||||||
results = await chunks_vdb.query(
|
results = await chunks_vdb.query(
|
||||||
query, top_k=search_top_k, query_embedding=query_embedding
|
query, top_k=search_top_k, query_embedding=query_embedding, metadata_filter=query_param.metadata_filter
|
||||||
)
|
)
|
||||||
if not results:
|
if not results:
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
@ -2737,6 +2737,7 @@ async def _get_vector_context(
|
||||||
"file_path": result.get("file_path", "unknown_source"),
|
"file_path": result.get("file_path", "unknown_source"),
|
||||||
"source_type": "vector", # Mark the source type
|
"source_type": "vector", # Mark the source type
|
||||||
"chunk_id": result.get("id"), # Add chunk_id for deduplication
|
"chunk_id": result.get("id"), # Add chunk_id for deduplication
|
||||||
|
"metadata": result.get("metadata")
|
||||||
}
|
}
|
||||||
valid_chunks.append(chunk_with_metadata)
|
valid_chunks.append(chunk_with_metadata)
|
||||||
|
|
||||||
|
|
@ -3561,7 +3562,8 @@ async def _get_node_data(
|
||||||
f"Query nodes: {query} (top_k:{query_param.top_k}, cosine:{entities_vdb.cosine_better_than_threshold})"
|
f"Query nodes: {query} (top_k:{query_param.top_k}, cosine:{entities_vdb.cosine_better_than_threshold})"
|
||||||
)
|
)
|
||||||
|
|
||||||
results = await entities_vdb.query(query, top_k=query_param.top_k)
|
|
||||||
|
results = await entities_vdb.query(query, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter)
|
||||||
|
|
||||||
if not len(results):
|
if not len(results):
|
||||||
return [], []
|
return [], []
|
||||||
|
|
@ -3570,21 +3572,13 @@ async def _get_node_data(
|
||||||
node_ids = [r["entity_name"] for r in results]
|
node_ids = [r["entity_name"] for r in results]
|
||||||
|
|
||||||
|
|
||||||
# TODO update method to take in the metadata_filter dataclass
|
# Extract all entity IDs from your results list
|
||||||
node_kg_ids = []
|
node_ids = [r["entity_name"] for r in results]
|
||||||
if hasattr(knowledge_graph_inst, "get_nodes_by_metadata_filter"):
|
|
||||||
node_kg_ids = await asyncio.gather(
|
|
||||||
knowledge_graph_inst.get_nodes_by_metadata_filter(QueryParam.metadata_filter)
|
|
||||||
)
|
|
||||||
|
|
||||||
filtered_node_ids = (
|
|
||||||
[nid for nid in node_ids if nid in node_kg_ids] if node_kg_ids else node_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
# Call the batch node retrieval and degree functions concurrently.
|
# Call the batch node retrieval and degree functions concurrently.
|
||||||
nodes_dict, degrees_dict = await asyncio.gather(
|
nodes_dict, degrees_dict = await asyncio.gather(
|
||||||
knowledge_graph_inst.get_nodes_batch(filtered_node_ids),
|
knowledge_graph_inst.get_nodes_batch(node_ids),
|
||||||
knowledge_graph_inst.node_degrees_batch(filtered_node_ids),
|
knowledge_graph_inst.node_degrees_batch(node_ids),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Now, if you need the node data and degree in order:
|
# Now, if you need the node data and degree in order:
|
||||||
|
|
@ -3849,7 +3843,7 @@ async def _get_edge_data(
|
||||||
f"Query edges: {keywords} (top_k:{query_param.top_k}, cosine:{relationships_vdb.cosine_better_than_threshold})"
|
f"Query edges: {keywords} (top_k:{query_param.top_k}, cosine:{relationships_vdb.cosine_better_than_threshold})"
|
||||||
)
|
)
|
||||||
|
|
||||||
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
|
results = await relationships_vdb.query(keywords, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter)
|
||||||
|
|
||||||
if not len(results):
|
if not len(results):
|
||||||
return [], []
|
return [], []
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field, validator
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional, List, Union, Dict
|
||||||
|
|
||||||
|
|
||||||
class GPTKeywordExtractionFormat(BaseModel):
|
class GPTKeywordExtractionFormat(BaseModel):
|
||||||
|
|
@ -27,3 +27,46 @@ class KnowledgeGraph(BaseModel):
|
||||||
nodes: list[KnowledgeGraphNode] = []
|
nodes: list[KnowledgeGraphNode] = []
|
||||||
edges: list[KnowledgeGraphEdge] = []
|
edges: list[KnowledgeGraphEdge] = []
|
||||||
is_truncated: bool = False
|
is_truncated: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataFilter(BaseModel):
    """
    Represents a logical expression for metadata filtering.

    Args:
        operator: "AND", "OR", or "NOT" (case-insensitive on input; stored upper-case)
        operands: List of either simple key-value pairs or nested MetadataFilter objects
    """

    operator: str = Field(..., description="Logical operator: AND, OR, or NOT")
    operands: List[Union[Dict[str, Any], 'MetadataFilter']] = Field(default_factory=list, description="List of operands for filtering")

    @validator('operator')
    def validate_operator(cls, v):
        """Normalise the operator to upper case and reject unknown values."""
        normalized = v.upper()
        if normalized not in ["AND", "OR", "NOT"]:
            raise ValueError('operator must be one of: "AND", "OR", "NOT"')
        return normalized

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary representation.

        Nested filters are serialised through their own ``to_dict`` so the
        output has the same shape at every depth and round-trips through
        ``from_dict``.
        """
        return {
            "operator": self.operator,
            "operands": [
                operand.to_dict() if isinstance(operand, MetadataFilter) else operand
                for operand in self.operands
            ],
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'MetadataFilter':
        """Create from dictionary representation.

        A dict operand that contains an "operator" key is treated as a nested
        filter and converted recursively; any other operand is kept as-is.
        A missing "operator" defaults to "AND" and missing "operands" to an
        empty list.
        """
        operands = [
            cls.from_dict(operand)
            if isinstance(operand, dict) and "operator" in operand
            else operand
            for operand in data.get("operands", [])
        ]
        return cls(operator=data.get("operator", "AND"), operands=operands)

    class Config:
        """Pydantic configuration."""

        # Re-run validators when a field is assigned after construction.
        validate_assignment = True
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue