feat(metadata-filter): add metadata filtering on query chunks
Adds working (still needs testing) metadata filtering on chunks at query time. Currently fully implemented only for PostgreSQL with pgvector and for Neo4j.
This commit is contained in:
parent
40afb0441a
commit
d0fba28e1f
5 changed files with 263 additions and 90 deletions
|
|
@ -4,11 +4,11 @@ This module contains all query-related routes for the LightRAG API.
|
|||
|
||||
import json
|
||||
import logging
|
||||
import textwrap
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from lightrag.base import QueryParam, MetadataFilter
|
||||
from lightrag.base import QueryParam
|
||||
from lightrag.types import MetadataFilter
|
||||
from lightrag.api.utils_api import get_combined_auth_dependency
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
|
@ -24,8 +24,8 @@ class QueryRequest(BaseModel):
|
|||
)
|
||||
|
||||
metadata_filter: MetadataFilter | None = Field(
|
||||
default=textwrap.dedent('{"operator": "AND","operands": [{"MetadataFilter": {}},"string"]},'),
|
||||
description="Optional dictionary of metadata key-value pairs to filter nodes",
|
||||
default=None,
|
||||
description="Optional metadata filter for nodes and edges. Can be a MetadataFilter object or a dict that will be converted to MetadataFilter.",
|
||||
)
|
||||
|
||||
mode: Literal["local", "global", "hybrid", "naive", "mix", "bypass"] = Field(
|
||||
|
|
@ -118,6 +118,16 @@ class QueryRequest(BaseModel):
|
|||
)
|
||||
return conversation_history
|
||||
|
||||
@field_validator("metadata_filter", mode="before")
@classmethod
def metadata_filter_convert(cls, v):
    """Coerce a raw dict into a MetadataFilter before field validation.

    None and already-built MetadataFilter instances pass through untouched.
    """
    if v is None:
        return None
    return MetadataFilter.from_dict(v) if isinstance(v, dict) else v
|
||||
|
||||
def to_query_params(self, is_stream: bool) -> "QueryParam":
|
||||
"""Converts a QueryRequest instance into a QueryParam instance."""
|
||||
# Use Pydantic's `.model_dump(exclude_none=True)` to remove None values automatically
|
||||
|
|
@ -126,6 +136,11 @@ class QueryRequest(BaseModel):
|
|||
# Ensure `mode` and `stream` are set explicitly
|
||||
param = QueryParam(**request_data)
|
||||
param.stream = is_stream
|
||||
|
||||
# Ensure metadata_filter remains as MetadataFilter object if it exists
|
||||
if self.metadata_filter:
|
||||
param.metadata_filter = self.metadata_filter
|
||||
|
||||
return param
|
||||
|
||||
|
||||
|
|
@ -175,10 +190,6 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
|||
try:
|
||||
param = request.to_query_params(False)
|
||||
|
||||
# Inject metadata_filter into param if present
|
||||
if request.metadata_filter:
|
||||
setattr(param, "metadata_filter", request.metadata_filter)
|
||||
|
||||
response = await rag.aquery(request.query, param=param)
|
||||
|
||||
# Get reference list if requested
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from typing import (
|
|||
Union,
|
||||
)
|
||||
from .utils import EmbeddingFunc
|
||||
from .types import KnowledgeGraph
|
||||
from .types import KnowledgeGraph, MetadataFilter
|
||||
from .constants import (
|
||||
GRAPH_FIELD_SEP,
|
||||
DEFAULT_TOP_K,
|
||||
|
|
@ -40,42 +40,7 @@ from .constants import (
|
|||
load_dotenv(dotenv_path=".env", override=False)
|
||||
|
||||
|
||||
from dataclasses import field  # local import: only `dataclass` was previously in scope


@dataclass
class MetadataFilter:
    """
    Represents a logical expression for metadata filtering.

    Args:
        operator: "AND", "OR", or "NOT"
        operands: List of either simple key-value pairs or nested MetadataFilter objects
    """
    operator: str
    # FIX: use default_factory instead of a None default plus __post_init__.
    # The previous annotation (List[...] = None) lied about the type; the
    # field is never meant to be None.
    operands: List[Union[Dict[str, Any], 'MetadataFilter']] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a plain-dict representation, recursing into nested filters."""
        return {
            "operator": self.operator,
            "operands": [
                operand.to_dict() if isinstance(operand, MetadataFilter) else operand
                for operand in self.operands
            ]
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'MetadataFilter':
        """Create a MetadataFilter from its dict representation.

        Any operand dict that itself contains an "operator" key is treated as
        a nested filter and converted recursively; all other operands are kept
        as-is. A missing "operator" at the top level defaults to "AND".
        """
        operands = []
        for operand in data.get("operands", []):
            if isinstance(operand, dict) and "operator" in operand:
                operands.append(cls.from_dict(operand))
            else:
                operands.append(operand)
        return cls(operator=data.get("operator", "AND"), operands=operands)
|
||||
|
||||
|
||||
|
||||
|
|
@ -266,7 +231,7 @@ class BaseVectorStorage(StorageNameSpace, ABC):
|
|||
|
||||
@abstractmethod
|
||||
async def query(
|
||||
self, query: str, top_k: int, query_embedding: list[float] = None
|
||||
self, query: str, top_k: int, query_embedding: list[float] = None, metadata_filter: MetadataFilter | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Query the vector storage and retrieve top_k results.
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import configparser
|
|||
import ssl
|
||||
import itertools
|
||||
|
||||
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge, MetadataFilter
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
|
|
@ -931,6 +931,13 @@ class PostgreSQLDB:
|
|||
logger.error(
|
||||
f"PostgreSQL, Failed to create vector index, type: {self.vector_index_type}, Got: {e}"
|
||||
)
|
||||
# Compatibility check - add metadata columns to LIGHTRAG_DOC_CHUNKS and LIGHTRAG_VDB_CHUNKS
|
||||
try:
|
||||
await self.add_metadata_to_tables()
|
||||
except Exception as e:
|
||||
logger.error(f"PostgreSQL, Failed to add metadata columns to existing tables: {e}")
|
||||
|
||||
|
||||
# After all tables are created, attempt to migrate timestamp fields
|
||||
try:
|
||||
await self._migrate_timestamp_columns()
|
||||
|
|
@ -1010,6 +1017,54 @@ class PostgreSQLDB:
|
|||
f"PostgreSQL, Failed to create full entities/relations tables: {e}"
|
||||
)
|
||||
|
||||
async def add_metadata_to_tables(self):
    """Add metadata columns to LIGHTRAG_DOC_CHUNKS and LIGHTRAG_VDB_CHUNKS tables if they don't exist.

    Backward-compatibility migration: databases created before the metadata
    feature lack the JSONB `metadata` column that chunk filtering relies on.
    Failures are logged as warnings and do not abort the remaining tables.
    """
    # Only the two chunk tables carry per-chunk metadata.
    tables_to_check = [
        {
            "name": "LIGHTRAG_DOC_CHUNKS",
            "description": "Document chunks storage table",
        },
        {
            "name": "LIGHTRAG_VDB_CHUNKS",
            "description": "Vector database chunks storage table",
        },
    ]

    for table_info in tables_to_check:
        table_name = table_info["name"]
        try:
            # Check if metadata column exists
            # NOTE(review): matches on table_name only, not table_schema —
            # could hit a same-named table in another schema; verify against
            # the deployment's search_path.
            check_column_sql = """
            SELECT column_name
            FROM information_schema.columns
            WHERE table_name = $1
            AND column_name = 'metadata'
            """

            # information_schema stores unquoted identifiers in lower case,
            # hence .lower().
            # NOTE(review): assumes self.query maps the dict value onto the
            # $1 placeholder — confirm against PostgreSQLDB.query.
            column_info = await self.query(
                check_column_sql, {"table_name": table_name.lower()}
            )

            if not column_info:
                logger.info(f"Adding metadata column to {table_name} table")
                # '{{}}' renders as the '{}' empty-JSON default in the f-string.
                add_column_sql = f"""
                ALTER TABLE {table_name}
                ADD COLUMN metadata JSONB NULL DEFAULT '{{}}'::jsonb
                """
                await self.execute(add_column_sql)
                logger.info(
                    f"Successfully added metadata column to {table_name} table"
                )
            else:
                logger.debug(
                    f"metadata column already exists in {table_name} table"
                )

        except Exception as e:
            # Best-effort migration: log and continue with the next table.
            logger.warning(
                f"Failed to add metadata column to {table_name}: {e}"
            )
|
||||
|
||||
async def _migrate_create_full_entities_relations_tables(self):
|
||||
"""Create LIGHTRAG_FULL_ENTITIES and LIGHTRAG_FULL_RELATIONS tables if they don't exist"""
|
||||
tables_to_check = [
|
||||
|
|
@ -1742,6 +1797,7 @@ class PGKVStorage(BaseKVStorage):
|
|||
"llm_cache_list": json.dumps(v.get("llm_cache_list", [])),
|
||||
"create_time": current_time,
|
||||
"update_time": current_time,
|
||||
"metadata": json.dumps(v.get("metadata", {})),
|
||||
}
|
||||
await self.db.execute(upsert_sql, _data)
|
||||
elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
|
||||
|
|
@ -1908,6 +1964,7 @@ class PGVectorStorage(BaseVectorStorage):
|
|||
"file_path": item["file_path"],
|
||||
"create_time": current_time,
|
||||
"update_time": current_time,
|
||||
"metadata": json.dumps(item.get("metadata", {})),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
|
|
@ -2002,9 +2059,71 @@ class PGVectorStorage(BaseVectorStorage):
|
|||
|
||||
await self.db.execute(upsert_sql, data)
|
||||
|
||||
############# Metadata building function #################
@staticmethod
def build_metadata_filter_clause(metadata_filter):
    """Render a MetadataFilter (or a plain key/value dict) as a SQL fragment.

    Returns either "" (no filtering) or a string beginning with " AND " that
    can be appended to a WHERE clause over the JSONB `metadata` column.

    NOTE(review): this builds SQL by string interpolation. Keys and string
    values are escaped for single quotes, but a parameterized approach
    would be safer if filter input is untrusted.
    """
    def escape_str(val: str) -> str:
        return str(val).replace("'", "''")  # escape single quotes

    def build_single_condition(key, value):
        # FIX: escape the key as well as the value — a key containing a
        # single quote previously broke out of the SQL literal.
        safe_key = escape_str(key)
        if isinstance(value, (dict, list)):
            # Compare whole JSON structures via the jsonb type.
            json_value = json.dumps(value).replace("'", "''")
            return f"metadata->'{safe_key}' = '{json_value}'::jsonb"
        elif isinstance(value, (int, float, bool)):
            # Numeric/boolean: compare as jsonb so 5 matches 5, not "5".
            return f"metadata->'{safe_key}' = '{json.dumps(value)}'::jsonb"
        else:  # string
            return f"metadata->>'{safe_key}' = '{escape_str(value)}'"

    def build_conditions(filter_dict):
        conditions = []
        for key, value in filter_dict.items():
            if isinstance(value, (list, tuple)):
                # A list of values for one key means "any of these" (OR).
                conds = [build_single_condition(key, v) for v in value]
                conditions.append("(" + " OR ".join(conds) + ")")
            else:
                conditions.append(build_single_condition(key, value))
        return conditions

    def recurse(filter_obj):
        # Plain dict: implicit AND over its key/value conditions.
        if isinstance(filter_obj, dict):
            return build_conditions(filter_obj)

        if isinstance(filter_obj, MetadataFilter):
            sub_conditions = []
            for operand in filter_obj.operands:
                if isinstance(operand, dict):
                    conds = build_conditions(operand)
                    # FIX: skip empty dict operands, which previously
                    # produced an invalid "()" in the SQL.
                    if conds:
                        sub_conditions.append("(" + " AND ".join(conds) + ")")
                elif isinstance(operand, MetadataFilter):
                    nested = recurse(operand)
                    if nested:
                        sub_conditions.append("(" + " AND ".join(nested) + ")")

            if not sub_conditions:
                return []

            op = filter_obj.operator.upper()
            if op == "AND":
                return [" AND ".join(sub_conditions)]
            elif op == "OR":
                return [" OR ".join(sub_conditions)]
            elif op == "NOT":
                if len(sub_conditions) == 1:
                    return [f"NOT {sub_conditions[0]}"]
                else:
                    return [f"NOT ({' AND '.join(sub_conditions)})"]

        # Unknown operator or unsupported filter object: no filtering.
        return []

    conditions = recurse(metadata_filter)
    clause = ""
    if conditions:
        clause = " AND " + " AND ".join(conditions)

    return clause
|
||||
|
||||
#################### query method ###############
|
||||
async def query(
|
||||
self, query: str, top_k: int, query_embedding: list[float] = None
|
||||
self, query: str, top_k: int, query_embedding: list[float] = None, metadata_filter: MetadataFilter | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
if query_embedding is not None:
|
||||
embedding = query_embedding
|
||||
|
|
@ -2015,8 +2134,8 @@ class PGVectorStorage(BaseVectorStorage):
|
|||
embedding = embeddings[0]
|
||||
|
||||
embedding_string = ",".join(map(str, embedding))
|
||||
|
||||
sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
|
||||
metadata_filter_clause = self.build_metadata_filter_clause(metadata_filter)
|
||||
sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string, metadata_filter_clause=metadata_filter_clause)
|
||||
params = {
|
||||
"workspace": self.workspace,
|
||||
"closer_than_threshold": 1 - self.cosine_better_than_threshold,
|
||||
|
|
@ -4480,6 +4599,7 @@ TABLES = {
|
|||
content TEXT,
|
||||
file_path TEXT NULL,
|
||||
llm_cache_list JSONB NULL DEFAULT '[]'::jsonb,
|
||||
metadata JSONB NULL DEFAULT '{}'::jsonb,
|
||||
create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
|
||||
update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
|
||||
CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id)
|
||||
|
|
@ -4495,6 +4615,7 @@ TABLES = {
|
|||
content TEXT,
|
||||
content_vector VECTOR({os.environ.get("EMBEDDING_DIM", 1024)}),
|
||||
file_path TEXT NULL,
|
||||
metadata JSONB NULL DEFAULT '{{}}'::jsonb,
|
||||
create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
|
||||
update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
|
||||
CONSTRAINT LIGHTRAG_VDB_CHUNKS_PK PRIMARY KEY (workspace, id)
|
||||
|
|
@ -4656,8 +4777,8 @@ SQL_TEMPLATES = {
|
|||
""",
|
||||
"upsert_text_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens,
|
||||
chunk_order_index, full_doc_id, content, file_path, llm_cache_list,
|
||||
create_time, update_time)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||
create_time, update_time, metadata)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
ON CONFLICT (workspace,id) DO UPDATE
|
||||
SET tokens=EXCLUDED.tokens,
|
||||
chunk_order_index=EXCLUDED.chunk_order_index,
|
||||
|
|
@ -4665,7 +4786,8 @@ SQL_TEMPLATES = {
|
|||
content = EXCLUDED.content,
|
||||
file_path=EXCLUDED.file_path,
|
||||
llm_cache_list=EXCLUDED.llm_cache_list,
|
||||
update_time = EXCLUDED.update_time
|
||||
update_time = EXCLUDED.update_time,
|
||||
metadata = EXCLUDED.metadata
|
||||
""",
|
||||
"upsert_full_entities": """INSERT INTO LIGHTRAG_FULL_ENTITIES (workspace, id, entity_names, count,
|
||||
create_time, update_time)
|
||||
|
|
@ -4686,8 +4808,8 @@ SQL_TEMPLATES = {
|
|||
# SQL for VectorStorage
|
||||
"upsert_chunk": """INSERT INTO LIGHTRAG_VDB_CHUNKS (workspace, id, tokens,
|
||||
chunk_order_index, full_doc_id, content, content_vector, file_path,
|
||||
create_time, update_time)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||
create_time, update_time, metadata)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
ON CONFLICT (workspace,id) DO UPDATE
|
||||
SET tokens=EXCLUDED.tokens,
|
||||
chunk_order_index=EXCLUDED.chunk_order_index,
|
||||
|
|
@ -4695,7 +4817,8 @@ SQL_TEMPLATES = {
|
|||
content = EXCLUDED.content,
|
||||
content_vector=EXCLUDED.content_vector,
|
||||
file_path=EXCLUDED.file_path,
|
||||
update_time = EXCLUDED.update_time
|
||||
update_time = EXCLUDED.update_time,
|
||||
metadata = EXCLUDED.metadata
|
||||
""",
|
||||
"upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
|
||||
content_vector, chunk_ids, file_path, create_time, update_time)
|
||||
|
|
@ -4721,34 +4844,71 @@ SQL_TEMPLATES = {
|
|||
update_time = EXCLUDED.update_time
|
||||
""",
|
||||
"relationships": """
|
||||
SELECT r.source_id AS src_id,
|
||||
r.target_id AS tgt_id,
|
||||
EXTRACT(EPOCH FROM r.create_time)::BIGINT AS created_at
|
||||
FROM LIGHTRAG_VDB_RELATION r
|
||||
WHERE r.workspace = $1
|
||||
AND r.content_vector <=> '[{embedding_string}]'::vector < $2
|
||||
ORDER BY r.content_vector <=> '[{embedding_string}]'::vector
|
||||
LIMIT $3;
|
||||
WITH relevant_chunks AS (
|
||||
SELECT id as chunk_id
|
||||
FROM LIGHTRAG_VDB_CHUNKS
|
||||
WHERE ($4::varchar[] IS NULL OR full_doc_id = ANY ($4::varchar[]))
|
||||
{metadata_filter_clause}
|
||||
),
|
||||
rc AS (
|
||||
SELECT array_agg(chunk_id) AS chunk_arr
|
||||
FROM relevant_chunks
|
||||
)
|
||||
SELECT r.source_id AS src_id,
|
||||
r.target_id AS tgt_id,
|
||||
EXTRACT(EPOCH FROM r.create_time)::BIGINT AS created_at
|
||||
FROM LIGHTRAG_VDB_RELATION r
|
||||
JOIN rc ON TRUE
|
||||
WHERE r.workspace = $1
|
||||
AND r.content_vector <=> '[{embedding_string}]'::vector < $2
|
||||
AND r.chunk_ids && (rc.chunk_arr::varchar[])
|
||||
ORDER BY r.content_vector <=> '[{embedding_string}]'::vector
|
||||
LIMIT $3;
|
||||
""",
|
||||
"entities": """
|
||||
WITH relevant_chunks AS (
|
||||
SELECT id as chunk_id
|
||||
FROM LIGHTRAG_VDB_CHUNKS
|
||||
WHERE ($4::varchar[] IS NULL OR full_doc_id = ANY ($4::varchar[]))
|
||||
{metadata_filter_clause}
|
||||
),
|
||||
rc AS (
|
||||
SELECT array_agg(chunk_id) AS chunk_arr
|
||||
FROM relevant_chunks
|
||||
)
|
||||
SELECT e.entity_name,
|
||||
EXTRACT(EPOCH FROM e.create_time)::BIGINT AS created_at
|
||||
FROM LIGHTRAG_VDB_ENTITY e
|
||||
JOIN rc ON TRUE
|
||||
WHERE e.workspace = $1
|
||||
AND e.content_vector <=> '[{embedding_string}]'::vector < $2
|
||||
AND e.chunk_ids && (rc.chunk_arr::varchar[])
|
||||
ORDER BY e.content_vector <=> '[{embedding_string}]'::vector
|
||||
LIMIT $3;
|
||||
""",
|
||||
"chunks": """
|
||||
SELECT c.id,
|
||||
c.content,
|
||||
c.file_path,
|
||||
EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at
|
||||
FROM LIGHTRAG_VDB_CHUNKS c
|
||||
WHERE c.workspace = $1
|
||||
AND c.content_vector <=> '[{embedding_string}]'::vector < $2
|
||||
ORDER BY c.content_vector <=> '[{embedding_string}]'::vector
|
||||
LIMIT $3;
|
||||
WITH relevant_chunks AS (
|
||||
SELECT id as chunk_id
|
||||
FROM LIGHTRAG_VDB_CHUNKS
|
||||
WHERE ($4::varchar[] IS NULL OR full_doc_id = ANY ($4::varchar[]))
|
||||
{metadata_filter_clause}
|
||||
),
|
||||
rc AS (
|
||||
SELECT array_agg(chunk_id) AS chunk_arr
|
||||
FROM relevant_chunks
|
||||
)
|
||||
SELECT c.id,
|
||||
c.content,
|
||||
c.file_path,
|
||||
EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at,
|
||||
c.metadata
|
||||
FROM LIGHTRAG_VDB_CHUNKS c
|
||||
JOIN rc ON TRUE
|
||||
WHERE c.workspace = $1
|
||||
AND c.content_vector <=> '[{embedding_string}]'::vector < $2
|
||||
AND c.id = ANY (rc.chunk_arr)
|
||||
ORDER BY c.content_vector <=> '[{embedding_string}]'::vector
|
||||
LIMIT $3;
|
||||
""",
|
||||
# DROP tables
|
||||
"drop_specifiy_table_workspace": """
|
||||
|
|
|
|||
|
|
@ -41,7 +41,6 @@ from .base import (
|
|||
QueryParam,
|
||||
QueryResult,
|
||||
QueryContextResult,
|
||||
MetadataFilter
|
||||
)
|
||||
from .prompt import PROMPTS
|
||||
from .constants import (
|
||||
|
|
@ -2409,6 +2408,7 @@ async def kg_query(
|
|||
query_param.ll_keywords or [],
|
||||
query_param.user_prompt or "",
|
||||
query_param.enable_rerank,
|
||||
query_param.metadata_filter,
|
||||
)
|
||||
cached_result = await handle_cache(
|
||||
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
|
||||
|
|
@ -2720,7 +2720,7 @@ async def _get_vector_context(
|
|||
cosine_threshold = chunks_vdb.cosine_better_than_threshold
|
||||
|
||||
results = await chunks_vdb.query(
|
||||
query, top_k=search_top_k, query_embedding=query_embedding
|
||||
query, top_k=search_top_k, query_embedding=query_embedding, metadata_filter=query_param.metadata_filter
|
||||
)
|
||||
if not results:
|
||||
logger.info(
|
||||
|
|
@ -2737,6 +2737,7 @@ async def _get_vector_context(
|
|||
"file_path": result.get("file_path", "unknown_source"),
|
||||
"source_type": "vector", # Mark the source type
|
||||
"chunk_id": result.get("id"), # Add chunk_id for deduplication
|
||||
"metadata": result.get("metadata")
|
||||
}
|
||||
valid_chunks.append(chunk_with_metadata)
|
||||
|
||||
|
|
@ -3561,7 +3562,8 @@ async def _get_node_data(
|
|||
f"Query nodes: {query} (top_k:{query_param.top_k}, cosine:{entities_vdb.cosine_better_than_threshold})"
|
||||
)
|
||||
|
||||
results = await entities_vdb.query(query, top_k=query_param.top_k)
|
||||
|
||||
results = await entities_vdb.query(query, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter)
|
||||
|
||||
if not len(results):
|
||||
return [], []
|
||||
|
|
@ -3570,21 +3572,13 @@ async def _get_node_data(
|
|||
node_ids = [r["entity_name"] for r in results]
|
||||
|
||||
|
||||
# TODO update method to take in the metadata_filter dataclass
|
||||
node_kg_ids = []
|
||||
if hasattr(knowledge_graph_inst, "get_nodes_by_metadata_filter"):
|
||||
node_kg_ids = await asyncio.gather(
|
||||
knowledge_graph_inst.get_nodes_by_metadata_filter(QueryParam.metadata_filter)
|
||||
)
|
||||
|
||||
filtered_node_ids = (
|
||||
[nid for nid in node_ids if nid in node_kg_ids] if node_kg_ids else node_ids
|
||||
)
|
||||
# Extract all entity IDs from your results list
|
||||
node_ids = [r["entity_name"] for r in results]
|
||||
|
||||
# Call the batch node retrieval and degree functions concurrently.
|
||||
nodes_dict, degrees_dict = await asyncio.gather(
|
||||
knowledge_graph_inst.get_nodes_batch(filtered_node_ids),
|
||||
knowledge_graph_inst.node_degrees_batch(filtered_node_ids),
|
||||
knowledge_graph_inst.get_nodes_batch(node_ids),
|
||||
knowledge_graph_inst.node_degrees_batch(node_ids),
|
||||
)
|
||||
|
||||
# Now, if you need the node data and degree in order:
|
||||
|
|
@ -3849,7 +3843,7 @@ async def _get_edge_data(
|
|||
f"Query edges: {keywords} (top_k:{query_param.top_k}, cosine:{relationships_vdb.cosine_better_than_threshold})"
|
||||
)
|
||||
|
||||
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
|
||||
results = await relationships_vdb.query(keywords, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter)
|
||||
|
||||
if not len(results):
|
||||
return [], []
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, Optional
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from typing import Any, Optional, List, Union, Dict
|
||||
|
||||
|
||||
class GPTKeywordExtractionFormat(BaseModel):
|
||||
|
|
@ -27,3 +27,46 @@ class KnowledgeGraph(BaseModel):
|
|||
nodes: list[KnowledgeGraphNode] = []
|
||||
edges: list[KnowledgeGraphEdge] = []
|
||||
is_truncated: bool = False
|
||||
|
||||
|
||||
class MetadataFilter(BaseModel):
    """
    Represents a logical expression for metadata filtering.

    Args:
        operator: "AND", "OR", or "NOT"
        operands: List of either simple key-value pairs or nested MetadataFilter objects
    """
    operator: str = Field(..., description="Logical operator: AND, OR, or NOT")
    operands: List[Union[Dict[str, Any], 'MetadataFilter']] = Field(default_factory=list, description="List of operands for filtering")

    @validator('operator')
    def validate_operator(cls, v):
        # Reject anything other than the three supported logical operators.
        if v not in ["AND", "OR", "NOT"]:
            raise ValueError('operator must be one of: "AND", "OR", "NOT"')
        return v

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary representation, recursing into nested filters."""
        return {
            "operator": self.operator,
            "operands": [
                # FIX: recurse via to_dict() instead of BaseModel.dict(),
                # which is deprecated in pydantic v2; this also keeps the
                # output identical to the dataclass variant's to_dict().
                operand.to_dict() if isinstance(operand, MetadataFilter) else operand
                for operand in self.operands
            ]
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'MetadataFilter':
        """Create from dictionary representation.

        Operand dicts containing an "operator" key are converted recursively
        into nested MetadataFilter instances; other operands are kept as-is.
        """
        operands = []
        for operand in data.get("operands", []):
            if isinstance(operand, dict) and "operator" in operand:
                operands.append(cls.from_dict(operand))
            else:
                operands.append(operand)
        return cls(operator=data.get("operator", "AND"), operands=operands)

    class Config:
        """Pydantic configuration."""
        # Re-run validators when fields are assigned after construction.
        validate_assignment = True
|
|
|
|||
Loading…
Add table
Reference in a new issue