Merge pull request #2027 from Matt23-star/main

Refactor: PostgreSQL
This commit is contained in:
Daniel.y 2025-09-09 15:12:35 +08:00 committed by GitHub
commit f064b950fc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 93 additions and 117 deletions

View file

@ -164,9 +164,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
logger.error(f"Error during ChromaDB upsert: {str(e)}") logger.error(f"Error during ChromaDB upsert: {str(e)}")
raise raise
async def query( async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
try: try:
embedding = await self.embedding_func( embedding = await self.embedding_func(
[query], _priority=5 [query], _priority=5

View file

@ -787,13 +787,13 @@ class PostgreSQLDB:
FROM information_schema.columns FROM information_schema.columns
WHERE table_name = $1 AND column_name = $2 WHERE table_name = $1 AND column_name = $2
""" """
params = {
"table_name": migration["table"].lower(),
"column_name": migration["column"],
}
column_info = await self.query( column_info = await self.query(
check_column_sql, check_column_sql,
{ list(params.values()),
"table_name": migration["table"].lower(),
"column_name": migration["column"],
},
) )
if not column_info: if not column_info:
@ -1035,10 +1035,8 @@ class PostgreSQLDB:
WHERE table_name = $1 WHERE table_name = $1
AND table_schema = 'public' AND table_schema = 'public'
""" """
params = {"table_name": table_name.lower()}
table_exists = await self.query( table_exists = await self.query(check_table_sql, list(params.values()))
check_table_sql, {"table_name": table_name.lower()}
)
if not table_exists: if not table_exists:
logger.info(f"Creating table {table_name}") logger.info(f"Creating table {table_name}")
@ -1121,7 +1119,8 @@ class PostgreSQLDB:
AND indexname = $1 AND indexname = $1
""" """
existing = await self.query(check_sql, {"indexname": index["name"]}) params = {"indexname": index["name"]}
existing = await self.query(check_sql, list(params.values()))
if not existing: if not existing:
logger.info(f"Creating pagination index: {index['description']}") logger.info(f"Creating pagination index: {index['description']}")
@ -1217,7 +1216,7 @@ class PostgreSQLDB:
async def query( async def query(
self, self,
sql: str, sql: str,
params: dict[str, Any] | None = None, params: list[Any] | None = None,
multirows: bool = False, multirows: bool = False,
with_age: bool = False, with_age: bool = False,
graph_name: str | None = None, graph_name: str | None = None,
@ -1230,7 +1229,7 @@ class PostgreSQLDB:
try: try:
if params: if params:
rows = await connection.fetch(sql, *params.values()) rows = await connection.fetch(sql, *params)
else: else:
rows = await connection.fetch(sql) rows = await connection.fetch(sql)
@ -1446,7 +1445,7 @@ class PGKVStorage(BaseKVStorage):
params = {"workspace": self.workspace} params = {"workspace": self.workspace}
try: try:
results = await self.db.query(sql, params, multirows=True) results = await self.db.query(sql, list(params.values()), multirows=True)
# Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
@ -1540,7 +1539,7 @@ class PGKVStorage(BaseKVStorage):
"""Get data by id.""" """Get data by id."""
sql = SQL_TEMPLATES["get_by_id_" + self.namespace] sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
params = {"workspace": self.workspace, "id": id} params = {"workspace": self.workspace, "id": id}
response = await self.db.query(sql, params) response = await self.db.query(sql, list(params.values()))
if response and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): if response and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
# Parse llm_cache_list JSON string back to list # Parse llm_cache_list JSON string back to list
@ -1620,7 +1619,7 @@ class PGKVStorage(BaseKVStorage):
ids=",".join([f"'{id}'" for id in ids]) ids=",".join([f"'{id}'" for id in ids])
) )
params = {"workspace": self.workspace} params = {"workspace": self.workspace}
results = await self.db.query(sql, params, multirows=True) results = await self.db.query(sql, list(params.values()), multirows=True)
if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
# Parse llm_cache_list JSON string back to list for each result # Parse llm_cache_list JSON string back to list for each result
@ -1708,7 +1707,7 @@ class PGKVStorage(BaseKVStorage):
) )
params = {"workspace": self.workspace} params = {"workspace": self.workspace}
try: try:
res = await self.db.query(sql, params, multirows=True) res = await self.db.query(sql, list(params.values()), multirows=True)
if res: if res:
exist_keys = [key["id"] for key in res] exist_keys = [key["id"] for key in res]
else: else:
@ -2023,7 +2022,7 @@ class PGVectorStorage(BaseVectorStorage):
"closer_than_threshold": 1 - self.cosine_better_than_threshold, "closer_than_threshold": 1 - self.cosine_better_than_threshold,
"top_k": top_k, "top_k": top_k,
} }
results = await self.db.query(sql, params=params, multirows=True) results = await self.db.query(sql, params=list(params.values()), multirows=True)
return results return results
async def index_done_callback(self) -> None: async def index_done_callback(self) -> None:
@ -2120,7 +2119,7 @@ class PGVectorStorage(BaseVectorStorage):
params = {"workspace": self.workspace, "id": id} params = {"workspace": self.workspace, "id": id}
try: try:
result = await self.db.query(query, params) result = await self.db.query(query, list(params.values()))
if result: if result:
return dict(result) return dict(result)
return None return None
@ -2154,7 +2153,7 @@ class PGVectorStorage(BaseVectorStorage):
params = {"workspace": self.workspace} params = {"workspace": self.workspace}
try: try:
results = await self.db.query(query, params, multirows=True) results = await self.db.query(query, list(params.values()), multirows=True)
return [dict(record) for record in results] return [dict(record) for record in results]
except Exception as e: except Exception as e:
logger.error( logger.error(
@ -2187,7 +2186,7 @@ class PGVectorStorage(BaseVectorStorage):
params = {"workspace": self.workspace} params = {"workspace": self.workspace}
try: try:
results = await self.db.query(query, params, multirows=True) results = await self.db.query(query, list(params.values()), multirows=True)
vectors_dict = {} vectors_dict = {}
for result in results: for result in results:
@ -2274,7 +2273,7 @@ class PGDocStatusStorage(DocStatusStorage):
) )
params = {"workspace": self.workspace} params = {"workspace": self.workspace}
try: try:
res = await self.db.query(sql, params, multirows=True) res = await self.db.query(sql, list(params.values()), multirows=True)
if res: if res:
exist_keys = [key["id"] for key in res] exist_keys = [key["id"] for key in res]
else: else:
@ -2292,7 +2291,7 @@ class PGDocStatusStorage(DocStatusStorage):
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2" sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
params = {"workspace": self.workspace, "id": id} params = {"workspace": self.workspace, "id": id}
result = await self.db.query(sql, params, True) result = await self.db.query(sql, list(params.values()), True)
if result is None or result == []: if result is None or result == []:
return None return None
else: else:
@ -2338,7 +2337,7 @@ class PGDocStatusStorage(DocStatusStorage):
sql = "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id = ANY($2)" sql = "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id = ANY($2)"
params = {"workspace": self.workspace, "ids": ids} params = {"workspace": self.workspace, "ids": ids}
results = await self.db.query(sql, params, True) results = await self.db.query(sql, list(params.values()), True)
if not results: if not results:
return [] return []
@ -2389,7 +2388,8 @@ class PGDocStatusStorage(DocStatusStorage):
FROM LIGHTRAG_DOC_STATUS FROM LIGHTRAG_DOC_STATUS
where workspace=$1 GROUP BY STATUS where workspace=$1 GROUP BY STATUS
""" """
result = await self.db.query(sql, {"workspace": self.workspace}, True) params = {"workspace": self.workspace}
result = await self.db.query(sql, list(params.values()), True)
counts = {} counts = {}
for doc in result: for doc in result:
counts[doc["status"]] = doc["count"] counts[doc["status"]] = doc["count"]
@ -2401,7 +2401,7 @@ class PGDocStatusStorage(DocStatusStorage):
"""all documents with a specific status""" """all documents with a specific status"""
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2" sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
params = {"workspace": self.workspace, "status": status.value} params = {"workspace": self.workspace, "status": status.value}
result = await self.db.query(sql, params, True) result = await self.db.query(sql, list(params.values()), True)
docs_by_status = {} docs_by_status = {}
for element in result: for element in result:
@ -2455,7 +2455,7 @@ class PGDocStatusStorage(DocStatusStorage):
"""Get all documents with a specific track_id""" """Get all documents with a specific track_id"""
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and track_id=$2" sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and track_id=$2"
params = {"workspace": self.workspace, "track_id": track_id} params = {"workspace": self.workspace, "track_id": track_id}
result = await self.db.query(sql, params, True) result = await self.db.query(sql, list(params.values()), True)
docs_by_track_id = {} docs_by_track_id = {}
for element in result: for element in result:
@ -2555,7 +2555,7 @@ class PGDocStatusStorage(DocStatusStorage):
# Query for total count # Query for total count
count_sql = f"SELECT COUNT(*) as total FROM LIGHTRAG_DOC_STATUS {where_clause}" count_sql = f"SELECT COUNT(*) as total FROM LIGHTRAG_DOC_STATUS {where_clause}"
count_result = await self.db.query(count_sql, params) count_result = await self.db.query(count_sql, list(params.values()))
total_count = count_result["total"] if count_result else 0 total_count = count_result["total"] if count_result else 0
# Query for paginated data # Query for paginated data
@ -2568,7 +2568,7 @@ class PGDocStatusStorage(DocStatusStorage):
params["limit"] = page_size params["limit"] = page_size
params["offset"] = offset params["offset"] = offset
result = await self.db.query(data_sql, params, True) result = await self.db.query(data_sql, list(params.values()), True)
# Convert to (doc_id, DocProcessingStatus) tuples # Convert to (doc_id, DocProcessingStatus) tuples
documents = [] documents = []
@ -2625,7 +2625,7 @@ class PGDocStatusStorage(DocStatusStorage):
GROUP BY status GROUP BY status
""" """
params = {"workspace": self.workspace} params = {"workspace": self.workspace}
result = await self.db.query(sql, params, True) result = await self.db.query(sql, list(params.values()), True)
counts = {} counts = {}
total_count = 0 total_count = 0
@ -3071,7 +3071,7 @@ class PGGraphStorage(BaseGraphStorage):
if readonly: if readonly:
data = await self.db.query( data = await self.db.query(
query, query,
params, list(params.values()) if params else None,
multirows=True, multirows=True,
with_age=True, with_age=True,
graph_name=self.graph_name, graph_name=self.graph_name,
@ -3102,114 +3102,92 @@ class PGGraphStorage(BaseGraphStorage):
return result return result
async def has_node(self, node_id: str) -> bool: async def has_node(self, node_id: str) -> bool:
entity_name_label = self._normalize_node_id(node_id) query = f"""
SELECT EXISTS (
SELECT 1
FROM {self.graph_name}.base
WHERE ag_catalog.agtype_access_operator(
VARIADIC ARRAY[properties, '"entity_id"'::agtype]
) = (to_json($1::text)::text)::agtype
LIMIT 1
) AS node_exists;
"""
query = """SELECT * FROM cypher('%s', $$ params = {"node_id": node_id}
MATCH (n:base {entity_id: "%s"}) row = (await self._query(query, params=params))[0]
RETURN count(n) > 0 AS node_exists return bool(row["node_exists"])
$$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
single_result = (await self._query(query))[0]
return single_result["node_exists"]
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
src_label = self._normalize_node_id(source_node_id) query = f"""
tgt_label = self._normalize_node_id(target_node_id) WITH a AS (
SELECT id AS vid
query = """SELECT * FROM cypher('%s', $$ FROM {self.graph_name}.base
MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"}) WHERE ag_catalog.agtype_access_operator(
RETURN COUNT(r) > 0 AS edge_exists VARIADIC ARRAY[properties, '"entity_id"'::agtype]
$$) AS (edge_exists bool)""" % ( ) = (to_json($1::text)::text)::agtype
self.graph_name, ),
src_label, b AS (
tgt_label, SELECT id AS vid
) FROM {self.graph_name}.base
WHERE ag_catalog.agtype_access_operator(
single_result = (await self._query(query))[0] VARIADIC ARRAY[properties, '"entity_id"'::agtype]
) = (to_json($2::text)::text)::agtype
return single_result["edge_exists"] )
SELECT EXISTS (
SELECT 1
FROM {self.graph_name}."DIRECTED" d
JOIN a ON d.start_id = a.vid
JOIN b ON d.end_id = b.vid
LIMIT 1
)
OR EXISTS (
SELECT 1
FROM {self.graph_name}."DIRECTED" d
JOIN a ON d.end_id = a.vid
JOIN b ON d.start_id = b.vid
LIMIT 1
) AS edge_exists;
"""
params = {
"source_node_id": source_node_id,
"target_node_id": target_node_id,
}
row = (await self._query(query, params=params))[0]
return bool(row["edge_exists"])
async def get_node(self, node_id: str) -> dict[str, str] | None: async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node by its label identifier, return only node properties""" """Get node by its label identifier, return only node properties"""
label = self._normalize_node_id(node_id) label = self._normalize_node_id(node_id)
query = """SELECT * FROM cypher('%s', $$
MATCH (n:base {entity_id: "%s"})
RETURN n
$$) AS (n agtype)""" % (self.graph_name, label)
record = await self._query(query)
if record:
node = record[0]
node_dict = node["n"]["properties"]
# Process string result, parse it to JSON dictionary result = await self.get_nodes_batch(node_ids=[label])
if isinstance(node_dict, str): if result and node_id in result:
try: return result[node_id]
node_dict = json.loads(node_dict)
except json.JSONDecodeError:
logger.warning(
f"[{self.workspace}] Failed to parse node string: {node_dict}"
)
return node_dict
return None return None
async def node_degree(self, node_id: str) -> int: async def node_degree(self, node_id: str) -> int:
label = self._normalize_node_id(node_id) label = self._normalize_node_id(node_id)
query = """SELECT * FROM cypher('%s', $$ result = await self.node_degrees_batch(node_ids=[label])
MATCH (n:base {entity_id: "%s"})-[r]-() if result and node_id in result:
RETURN count(r) AS total_edge_count return result[node_id]
$$) AS (total_edge_count integer)""" % (self.graph_name, label)
record = (await self._query(query))[0]
if record:
edge_count = int(record["total_edge_count"])
return edge_count
async def edge_degree(self, src_id: str, tgt_id: str) -> int: async def edge_degree(self, src_id: str, tgt_id: str) -> int:
src_degree = await self.node_degree(src_id) result = await self.edge_degrees_batch(edges=[(src_id, tgt_id)])
trg_degree = await self.node_degree(tgt_id) if result and (src_id, tgt_id) in result:
return result[(src_id, tgt_id)]
# Convert None to 0 for addition
src_degree = 0 if src_degree is None else src_degree
trg_degree = 0 if trg_degree is None else trg_degree
degrees = int(src_degree) + int(trg_degree)
return degrees
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None: ) -> dict[str, str] | None:
"""Get edge properties between two nodes""" """Get edge properties between two nodes"""
src_label = self._normalize_node_id(source_node_id) src_label = self._normalize_node_id(source_node_id)
tgt_label = self._normalize_node_id(target_node_id) tgt_label = self._normalize_node_id(target_node_id)
query = """SELECT * FROM cypher('%s', $$ result = await self.get_edges_batch([{"src": src_label, "tgt": tgt_label}])
MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"}) if result and (src_label, tgt_label) in result:
RETURN properties(r) as edge_properties return result[(src_label, tgt_label)]
LIMIT 1 return None
$$) AS (edge_properties agtype)""" % (
self.graph_name,
src_label,
tgt_label,
)
record = await self._query(query)
if record and record[0] and record[0]["edge_properties"]:
result = record[0]["edge_properties"]
# Process string result, parse it to JSON dictionary
if isinstance(result, str):
try:
result = json.loads(result)
except json.JSONDecodeError:
logger.warning(
f"[{self.workspace}] Failed to parse edge string: {result}"
)
return result
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
""" """