diff --git a/README-zh.md b/README-zh.md
index 7dd7e975..98edca47 100644
--- a/README-zh.md
+++ b/README-zh.md
@@ -30,7 +30,7 @@
-
+
diff --git a/README.md b/README.md
index 79479da9..dd8238f0 100644
--- a/README.md
+++ b/README.md
@@ -870,6 +870,41 @@ rag = LightRAG(
+
+<details>
+  <summary> <b>Using Memgraph for Storage</b> </summary>
+
+* Memgraph is a high-performance, in-memory graph database compatible with the Neo4j Bolt protocol.
+* You can run Memgraph locally using Docker for easy testing; see https://memgraph.com/download for installation options.
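+* For example, a throwaway local instance (a minimal sketch; check the download page above for the currently recommended image and tag): `docker run -p 7687:7687 memgraph/memgraph`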
+
+```python
+import os
+
+from lightrag import LightRAG
+from lightrag.llm.openai import gpt_4o_mini_complete
+from lightrag.utils import setup_logger
+from lightrag.kg.shared_storage import initialize_pipeline_status
+
+# Point LightRAG at your Memgraph instance (can also be set in the shell or a .env file)
+os.environ["MEMGRAPH_URI"] = "bolt://localhost:7687"
+
+WORKING_DIR = "./rag_storage"
+
+# Setup logger for LightRAG
+setup_logger("lightrag", level="INFO")
+
+# Initialize LightRAG with the Memgraph implementation, overriding the
+# default graph storage (NetworkX) via graph_storage="MemgraphStorage"
+async def initialize_rag():
+    rag = LightRAG(
+        working_dir=WORKING_DIR,
+        llm_model_func=gpt_4o_mini_complete,  # Use gpt_4o_mini_complete LLM model
+        graph_storage="MemgraphStorage",  # <-- override the default KG storage
+    )
+
+    # Initialize database connections
+    await rag.initialize_storages()
+    # Initialize pipeline status for document processing
+    await initialize_pipeline_status()
+
+    return rag
+```
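+
+A minimal sketch of running it end to end (assumes `initialize_rag()` from the snippet above; `ainsert`, `aquery`, and `finalize_storages` are LightRAG's standard async APIs):
+
+```python
+import asyncio
+
+async def main():
+    rag = await initialize_rag()
+    try:
+        await rag.ainsert("Memgraph stores LightRAG's knowledge graph.")
+        print(await rag.aquery("What does Memgraph store?"))
+    finally:
+        await rag.finalize_storages()
+
+asyncio.run(main())
+```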
+
+</details>
+
## Edit Entities and Relations
LightRAG now supports comprehensive knowledge graph management capabilities, allowing you to create, edit, and delete entities and relationships within your knowledge graph.
diff --git a/config.ini.example b/config.ini.example
index 63d9c2c0..94d300a1 100644
--- a/config.ini.example
+++ b/config.ini.example
@@ -21,3 +21,6 @@ password = your_password
database = your_database
workspace = default # optional, defaults to "default"
max_connections = 12
+
+[memgraph]
+uri = bolt://localhost:7687
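+# Optional keys read by MemgraphStorage (values shown are the implementation's fallbacks)
+username =
+password =
+database = memgraph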
diff --git a/env.example b/env.example
index 874ebf3c..d8a4dfd6 100644
--- a/env.example
+++ b/env.example
@@ -147,13 +147,14 @@ EMBEDDING_BINDING_HOST=http://localhost:11434
# LIGHTRAG_VECTOR_STORAGE=QdrantVectorDBStorage
### Graph Storage (Recommended for production deployment)
# LIGHTRAG_GRAPH_STORAGE=Neo4JStorage
+# LIGHTRAG_GRAPH_STORAGE=MemgraphStorage
####################################################################
### Default workspace for all storage types
### For the purpose of isolation of data for each LightRAG instance
### Valid characters: a-z, A-Z, 0-9, and _
####################################################################
-# WORKSPACE=doc—
+# WORKSPACE=space1
### PostgreSQL Configuration
POSTGRES_HOST=localhost
@@ -192,3 +193,10 @@ QDRANT_URL=http://localhost:6333
### Redis
REDIS_URI=redis://localhost:6379
# REDIS_WORKSPACE=forced_workspace_name
+
+### Memgraph Configuration
+MEMGRAPH_URI=bolt://localhost:7687
+MEMGRAPH_USERNAME=
+MEMGRAPH_PASSWORD=
+MEMGRAPH_DATABASE=memgraph
+# MEMGRAPH_WORKSPACE=forced_workspace_name
diff --git a/lightrag/kg/__init__.py b/lightrag/kg/__init__.py
index 1f5fd56f..b2a93e82 100644
--- a/lightrag/kg/__init__.py
+++ b/lightrag/kg/__init__.py
@@ -15,6 +15,7 @@ STORAGE_IMPLEMENTATIONS = {
"Neo4JStorage",
"PGGraphStorage",
"MongoGraphStorage",
+ "MemgraphStorage",
# "AGEStorage",
# "TiDBGraphStorage",
# "GremlinStorage",
@@ -57,6 +58,7 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
"NetworkXStorage": [],
"Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
"MongoGraphStorage": [],
+ "MemgraphStorage": ["MEMGRAPH_URI"],
# "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"AGEStorage": [
"AGE_POSTGRES_DB",
@@ -111,6 +113,7 @@ STORAGES = {
"PGDocStatusStorage": ".kg.postgres_impl",
"FaissVectorDBStorage": ".kg.faiss_impl",
"QdrantVectorDBStorage": ".kg.qdrant_impl",
+ "MemgraphStorage": ".kg.memgraph_impl",
}
diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py
new file mode 100644
index 00000000..8c6d6574
--- /dev/null
+++ b/lightrag/kg/memgraph_impl.py
@@ -0,0 +1,906 @@
+import os
+from dataclasses import dataclass
+from typing import final
+import configparser
+
+from ..utils import logger
+from ..base import BaseGraphStorage
+from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
+from ..constants import GRAPH_FIELD_SEP
+import pipmaster as pm
+
+if not pm.is_installed("neo4j"):
+ pm.install("neo4j")
+
+from neo4j import (
+ AsyncGraphDatabase,
+ AsyncManagedTransaction,
+)
+
+from dotenv import load_dotenv
+
+# use the .env that is inside the current folder
+load_dotenv(dotenv_path=".env", override=False)
+
+MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
+
+config = configparser.ConfigParser()
+config.read("config.ini", "utf-8")
+
+
+@final
+@dataclass
+class MemgraphStorage(BaseGraphStorage):
+ def __init__(self, namespace, global_config, embedding_func, workspace=None):
+ memgraph_workspace = os.environ.get("MEMGRAPH_WORKSPACE")
+ if memgraph_workspace and memgraph_workspace.strip():
+ workspace = memgraph_workspace
+ super().__init__(
+ namespace=namespace,
+ workspace=workspace or "",
+ global_config=global_config,
+ embedding_func=embedding_func,
+ )
+ self._driver = None
+
+ def _get_workspace_label(self) -> str:
+ """Get workspace label, return 'base' for compatibility when workspace is empty"""
+ workspace = getattr(self, "workspace", None)
+ return workspace if workspace else "base"
+
+ async def initialize(self):
+ URI = os.environ.get(
+ "MEMGRAPH_URI",
+ config.get("memgraph", "uri", fallback="bolt://localhost:7687"),
+ )
+ USERNAME = os.environ.get(
+ "MEMGRAPH_USERNAME", config.get("memgraph", "username", fallback="")
+ )
+ PASSWORD = os.environ.get(
+ "MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback="")
+ )
+ DATABASE = os.environ.get(
+ "MEMGRAPH_DATABASE", config.get("memgraph", "database", fallback="memgraph")
+ )
+
+ self._driver = AsyncGraphDatabase.driver(
+ URI,
+ auth=(USERNAME, PASSWORD),
+ )
+ self._DATABASE = DATABASE
+ try:
+            async with self._driver.session(database=DATABASE) as session:
+                # Create index for base nodes on entity_id if it doesn't exist
+                workspace_label = self._get_workspace_label()
+                try:
+ await session.run(
+ f"""CREATE INDEX ON :{workspace_label}(entity_id)"""
+ )
+ logger.info(
+ f"Created index on :{workspace_label}(entity_id) in Memgraph."
+ )
+ except Exception as e:
+ # Index may already exist, which is not an error
+ logger.warning(
+ f"Index creation on :{workspace_label}(entity_id) may have failed or already exists: {e}"
+ )
+ await session.run("RETURN 1")
+ logger.info(f"Connected to Memgraph at {URI}")
+ except Exception as e:
+ logger.error(f"Failed to connect to Memgraph at {URI}: {e}")
+ raise
+
+ async def finalize(self):
+ if self._driver is not None:
+ await self._driver.close()
+ self._driver = None
+
+ async def __aexit__(self, exc_type, exc, tb):
+ await self.finalize()
+
+ async def index_done_callback(self):
+ # Memgraph handles persistence automatically
+ pass
+
+ async def has_node(self, node_id: str) -> bool:
+ """
+ Check if a node exists in the graph.
+
+ Args:
+ node_id: The ID of the node to check.
+
+ Returns:
+ bool: True if the node exists, False otherwise.
+
+ Raises:
+ Exception: If there is an error checking the node existence.
+ """
+ if self._driver is None:
+ raise RuntimeError(
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
+ )
+ async with self._driver.session(
+ database=self._DATABASE, default_access_mode="READ"
+ ) as session:
+            result = None
+            try:
+ workspace_label = self._get_workspace_label()
+ query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
+ result = await session.run(query, entity_id=node_id)
+ single_result = await result.single()
+ await result.consume() # Ensure result is fully consumed
+ return (
+ single_result["node_exists"] if single_result is not None else False
+ )
+ except Exception as e:
+ logger.error(f"Error checking node existence for {node_id}: {str(e)}")
+                if result is not None:
+                    await result.consume()  # Ensure the result is consumed even on error
+ raise
+
+ async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
+ """
+ Check if an edge exists between two nodes in the graph.
+
+ Args:
+ source_node_id: The ID of the source node.
+ target_node_id: The ID of the target node.
+
+ Returns:
+ bool: True if the edge exists, False otherwise.
+
+ Raises:
+ Exception: If there is an error checking the edge existence.
+ """
+ if self._driver is None:
+ raise RuntimeError(
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
+ )
+ async with self._driver.session(
+ database=self._DATABASE, default_access_mode="READ"
+ ) as session:
+            result = None
+            try:
+ workspace_label = self._get_workspace_label()
+ query = (
+ f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) "
+ "RETURN COUNT(r) > 0 AS edgeExists"
+ )
+ result = await session.run(
+ query,
+ source_entity_id=source_node_id,
+ target_entity_id=target_node_id,
+ ) # type: ignore
+ single_result = await result.single()
+ await result.consume() # Ensure result is fully consumed
+ return (
+ single_result["edgeExists"] if single_result is not None else False
+ )
+ except Exception as e:
+ logger.error(
+ f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
+ )
+                if result is not None:
+                    await result.consume()  # Ensure the result is consumed even on error
+ raise
+
+ async def get_node(self, node_id: str) -> dict[str, str] | None:
+ """Get node by its label identifier, return only node properties
+
+ Args:
+ node_id: The node label to look up
+
+ Returns:
+ dict: Node properties if found
+ None: If node not found
+
+ Raises:
+ Exception: If there is an error executing the query
+ """
+ if self._driver is None:
+ raise RuntimeError(
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
+ )
+ async with self._driver.session(
+ database=self._DATABASE, default_access_mode="READ"
+ ) as session:
+ try:
+ workspace_label = self._get_workspace_label()
+ query = (
+ f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n"
+ )
+ result = await session.run(query, entity_id=node_id)
+ try:
+ records = await result.fetch(
+ 2
+ ) # Get 2 records for duplication check
+
+ if len(records) > 1:
+ logger.warning(
+ f"Multiple nodes found with label '{node_id}'. Using first node."
+ )
+ if records:
+ node = records[0]["n"]
+ node_dict = dict(node)
+ # Remove workspace label from labels list if it exists
+ if "labels" in node_dict:
+ node_dict["labels"] = [
+ label
+ for label in node_dict["labels"]
+ if label != workspace_label
+ ]
+ return node_dict
+ return None
+ finally:
+ await result.consume() # Ensure result is fully consumed
+ except Exception as e:
+ logger.error(f"Error getting node for {node_id}: {str(e)}")
+ raise
+
+ async def node_degree(self, node_id: str) -> int:
+ """Get the degree (number of relationships) of a node with the given label.
+ If multiple nodes have the same label, returns the degree of the first node.
+ If no node is found, returns 0.
+
+ Args:
+ node_id: The label of the node
+
+ Returns:
+ int: The number of relationships the node has, or 0 if no node found
+
+ Raises:
+ Exception: If there is an error executing the query
+ """
+ if self._driver is None:
+ raise RuntimeError(
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
+ )
+ async with self._driver.session(
+ database=self._DATABASE, default_access_mode="READ"
+ ) as session:
+ try:
+ workspace_label = self._get_workspace_label()
+ query = f"""
+ MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
+ OPTIONAL MATCH (n)-[r]-()
+ RETURN COUNT(r) AS degree
+ """
+ result = await session.run(query, entity_id=node_id)
+ try:
+ record = await result.single()
+
+ if not record:
+ logger.warning(f"No node found with label '{node_id}'")
+ return 0
+
+ degree = record["degree"]
+ return degree
+ finally:
+ await result.consume() # Ensure result is fully consumed
+ except Exception as e:
+ logger.error(f"Error getting node degree for {node_id}: {str(e)}")
+ raise
+
+ async def get_all_labels(self) -> list[str]:
+ """
+ Get all existing node labels in the database
+ Returns:
+ ["Person", "Company", ...] # Alphabetically sorted label list
+
+ Raises:
+ Exception: If there is an error executing the query
+ """
+ if self._driver is None:
+ raise RuntimeError(
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
+ )
+ async with self._driver.session(
+ database=self._DATABASE, default_access_mode="READ"
+ ) as session:
+            result = None
+            try:
+ workspace_label = self._get_workspace_label()
+ query = f"""
+ MATCH (n:`{workspace_label}`)
+ WHERE n.entity_id IS NOT NULL
+ RETURN DISTINCT n.entity_id AS label
+ ORDER BY label
+ """
+ result = await session.run(query)
+ labels = []
+ async for record in result:
+ labels.append(record["label"])
+ await result.consume()
+ return labels
+ except Exception as e:
+ logger.error(f"Error getting all labels: {str(e)}")
+                if result is not None:
+                    await result.consume()  # Ensure the result is consumed even on error
+ raise
+
+ async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
+ """Retrieves all edges (relationships) for a particular node identified by its label.
+
+ Args:
+ source_node_id: Label of the node to get edges for
+
+ Returns:
+            list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges
+                (an empty list if the node has no edges)
+
+ Raises:
+ Exception: If there is an error executing the query
+ """
+ if self._driver is None:
+ raise RuntimeError(
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
+ )
+ try:
+ async with self._driver.session(
+ database=self._DATABASE, default_access_mode="READ"
+ ) as session:
+                results = None
+                try:
+ workspace_label = self._get_workspace_label()
+ query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
+ OPTIONAL MATCH (n)-[r]-(connected:`{workspace_label}`)
+ WHERE connected.entity_id IS NOT NULL
+ RETURN n, r, connected"""
+ results = await session.run(query, entity_id=source_node_id)
+
+ edges = []
+ async for record in results:
+ source_node = record["n"]
+ connected_node = record["connected"]
+
+ # Skip if either node is None
+ if not source_node or not connected_node:
+ continue
+
+ source_label = (
+ source_node.get("entity_id")
+ if source_node.get("entity_id")
+ else None
+ )
+ target_label = (
+ connected_node.get("entity_id")
+ if connected_node.get("entity_id")
+ else None
+ )
+
+ if source_label and target_label:
+ edges.append((source_label, target_label))
+
+ await results.consume() # Ensure results are consumed
+ return edges
+ except Exception as e:
+ logger.error(
+ f"Error getting edges for node {source_node_id}: {str(e)}"
+ )
+                    if results is not None:
+                        await results.consume()  # Ensure results are consumed even on error
+ raise
+ except Exception as e:
+ logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
+ raise
+
+ async def get_edge(
+ self, source_node_id: str, target_node_id: str
+ ) -> dict[str, str] | None:
+ """Get edge properties between two nodes.
+
+ Args:
+ source_node_id: Label of the source node
+ target_node_id: Label of the target node
+
+ Returns:
+            dict: Edge properties if found (missing standard properties are filled with defaults)
+            None: If no edge exists between the two nodes
+
+ Raises:
+ Exception: If there is an error executing the query
+ """
+ if self._driver is None:
+ raise RuntimeError(
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
+ )
+ async with self._driver.session(
+ database=self._DATABASE, default_access_mode="READ"
+ ) as session:
+            result = None
+            try:
+ workspace_label = self._get_workspace_label()
+ query = f"""
+ MATCH (start:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(end:`{workspace_label}` {{entity_id: $target_entity_id}})
+ RETURN properties(r) as edge_properties
+ """
+ result = await session.run(
+ query,
+ source_entity_id=source_node_id,
+ target_entity_id=target_node_id,
+ )
+ records = await result.fetch(2)
+ await result.consume()
+ if records:
+ edge_result = dict(records[0]["edge_properties"])
+ for key, default_value in {
+ "weight": 0.0,
+ "source_id": None,
+ "description": None,
+ "keywords": None,
+ }.items():
+ if key not in edge_result:
+ edge_result[key] = default_value
+ logger.warning(
+ f"Edge between {source_node_id} and {target_node_id} is missing property: {key}. Using default value: {default_value}"
+ )
+ return edge_result
+ return None
+ except Exception as e:
+ logger.error(
+ f"Error getting edge between {source_node_id} and {target_node_id}: {str(e)}"
+ )
+                if result is not None:
+                    await result.consume()  # Ensure the result is consumed even on error
+ raise
+
+ async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
+ """
+        Upsert a node in the Memgraph database.
+
+ Args:
+ node_id: The unique identifier for the node (used as label)
+ node_data: Dictionary of node properties
+ """
+ if self._driver is None:
+ raise RuntimeError(
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
+ )
+ properties = node_data
+ entity_type = properties["entity_type"]
+ if "entity_id" not in properties:
+            raise ValueError("Memgraph: node properties must contain an 'entity_id' field")
+
+ try:
+ async with self._driver.session(database=self._DATABASE) as session:
+ workspace_label = self._get_workspace_label()
+
+ async def execute_upsert(tx: AsyncManagedTransaction):
+ query = f"""
+ MERGE (n:`{workspace_label}` {{entity_id: $entity_id}})
+ SET n += $properties
+ SET n:`{entity_type}`
+ """
+ result = await tx.run(
+ query, entity_id=node_id, properties=properties
+ )
+ await result.consume() # Ensure result is fully consumed
+
+ await session.execute_write(execute_upsert)
+ except Exception as e:
+ logger.error(f"Error during upsert: {str(e)}")
+ raise
+
+ async def upsert_edge(
+ self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+ ) -> None:
+ """
+ Upsert an edge and its properties between two nodes identified by their labels.
+ Ensures both source and target nodes exist and are unique before creating the edge.
+ Uses entity_id property to uniquely identify nodes.
+
+ Args:
+ source_node_id (str): Label of the source node (used as identifier)
+ target_node_id (str): Label of the target node (used as identifier)
+ edge_data (dict): Dictionary of properties to set on the edge
+
+ Raises:
+ Exception: If there is an error executing the query
+ """
+ if self._driver is None:
+ raise RuntimeError(
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
+ )
+ try:
+ edge_properties = edge_data
+ async with self._driver.session(database=self._DATABASE) as session:
+
+ async def execute_upsert(tx: AsyncManagedTransaction):
+ workspace_label = self._get_workspace_label()
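+                    # Note: the MERGE pattern below is undirected, so (A)-(B) and
+                    # (B)-(A) resolve to the same relationship; the DIRECTED type
+                    # name mirrors the Neo4j backend's convention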
+ query = f"""
+ MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})
+ WITH source
+ MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}})
+ MERGE (source)-[r:DIRECTED]-(target)
+ SET r += $properties
+ RETURN r, source, target
+ """
+ result = await tx.run(
+ query,
+ source_entity_id=source_node_id,
+ target_entity_id=target_node_id,
+ properties=edge_properties,
+ )
+ try:
+ await result.fetch(2)
+ finally:
+ await result.consume() # Ensure result is consumed
+
+ await session.execute_write(execute_upsert)
+ except Exception as e:
+ logger.error(f"Error during edge upsert: {str(e)}")
+ raise
+
+ async def delete_node(self, node_id: str) -> None:
+ """Delete a node with the specified label
+
+ Args:
+ node_id: The label of the node to delete
+
+ Raises:
+ Exception: If there is an error executing the query
+ """
+ if self._driver is None:
+ raise RuntimeError(
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
+ )
+
+ async def _do_delete(tx: AsyncManagedTransaction):
+ workspace_label = self._get_workspace_label()
+ query = f"""
+ MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
+ DETACH DELETE n
+ """
+ result = await tx.run(query, entity_id=node_id)
+ logger.debug(f"Deleted node with label {node_id}")
+ await result.consume()
+
+ try:
+ async with self._driver.session(database=self._DATABASE) as session:
+ await session.execute_write(_do_delete)
+ except Exception as e:
+ logger.error(f"Error during node deletion: {str(e)}")
+ raise
+
+ async def remove_nodes(self, nodes: list[str]):
+ """Delete multiple nodes
+
+ Args:
+ nodes: List of node labels to be deleted
+ """
+ if self._driver is None:
+ raise RuntimeError(
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
+ )
+ for node in nodes:
+ await self.delete_node(node)
+
+ async def remove_edges(self, edges: list[tuple[str, str]]):
+ """Delete multiple edges
+
+ Args:
+ edges: List of edges to be deleted, each edge is a (source, target) tuple
+
+ Raises:
+ Exception: If there is an error executing the query
+ """
+ if self._driver is None:
+ raise RuntimeError(
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
+ )
+ for source, target in edges:
+
+ async def _do_delete_edge(tx: AsyncManagedTransaction):
+ workspace_label = self._get_workspace_label()
+ query = f"""
+ MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(target:`{workspace_label}` {{entity_id: $target_entity_id}})
+ DELETE r
+ """
+ result = await tx.run(
+ query, source_entity_id=source, target_entity_id=target
+ )
+ logger.debug(f"Deleted edge from '{source}' to '{target}'")
+ await result.consume() # Ensure result is fully consumed
+
+ try:
+ async with self._driver.session(database=self._DATABASE) as session:
+ await session.execute_write(_do_delete_edge)
+ except Exception as e:
+ logger.error(f"Error during edge deletion: {str(e)}")
+ raise
+
+ async def drop(self) -> dict[str, str]:
+ """Drop all data from the current workspace and clean up resources
+
+ This method will delete all nodes and relationships in the Memgraph database.
+
+ Returns:
+ dict[str, str]: Operation status and message
+              - On success: {"status": "success", "message": "workspace data dropped"}
+ - On failure: {"status": "error", "message": ""}
+
+ Raises:
+ Exception: If there is an error executing the query
+ """
+ if self._driver is None:
+ raise RuntimeError(
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
+ )
+        workspace_label = self._get_workspace_label()
+        try:
+            async with self._driver.session(database=self._DATABASE) as session:
+ query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
+ result = await session.run(query)
+ await result.consume()
+ logger.info(
+ f"Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
+ )
+ return {"status": "success", "message": "workspace data dropped"}
+ except Exception as e:
+ logger.error(
+ f"Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
+ )
+ return {"status": "error", "message": str(e)}
+
+ async def edge_degree(self, src_id: str, tgt_id: str) -> int:
+ """Get the total degree (sum of relationships) of two nodes.
+
+ Args:
+ src_id: Label of the source node
+ tgt_id: Label of the target node
+
+ Returns:
+ int: Sum of the degrees of both nodes
+ """
+ if self._driver is None:
+ raise RuntimeError(
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
+ )
+ src_degree = await self.node_degree(src_id)
+ trg_degree = await self.node_degree(tgt_id)
+
+ # Convert None to 0 for addition
+ src_degree = 0 if src_degree is None else src_degree
+ trg_degree = 0 if trg_degree is None else trg_degree
+
+ degrees = int(src_degree) + int(trg_degree)
+ return degrees
+
+ async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
+ """Get all nodes that are associated with the given chunk_ids.
+
+ Args:
+ chunk_ids: List of chunk IDs to find associated nodes for
+
+ Returns:
+ list[dict]: A list of nodes, where each node is a dictionary of its properties.
+ An empty list if no matching nodes are found.
+ """
+ if self._driver is None:
+ raise RuntimeError(
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
+ )
+ workspace_label = self._get_workspace_label()
+ async with self._driver.session(
+ database=self._DATABASE, default_access_mode="READ"
+ ) as session:
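+            # n.source_id holds one or more chunk ids joined by GRAPH_FIELD_SEP,
+            # so split it and test membership for each requested chunk id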
+ query = f"""
+ UNWIND $chunk_ids AS chunk_id
+ MATCH (n:`{workspace_label}`)
+ WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
+ RETURN DISTINCT n
+ """
+ result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
+ nodes = []
+ async for record in result:
+ node = record["n"]
+ node_dict = dict(node)
+ node_dict["id"] = node_dict.get("entity_id")
+ nodes.append(node_dict)
+ await result.consume()
+ return nodes
+
+ async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
+ """Get all edges that are associated with the given chunk_ids.
+
+ Args:
+ chunk_ids: List of chunk IDs to find associated edges for
+
+ Returns:
+ list[dict]: A list of edges, where each edge is a dictionary of its properties.
+ An empty list if no matching edges are found.
+ """
+ if self._driver is None:
+ raise RuntimeError(
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
+ )
+ workspace_label = self._get_workspace_label()
+ async with self._driver.session(
+ database=self._DATABASE, default_access_mode="READ"
+ ) as session:
+ query = f"""
+ UNWIND $chunk_ids AS chunk_id
+ MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`)
+ WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
+ WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id
+ // Ensure we only return each unique edge once by ordering the source and target
+ WITH a, b, r,
+ CASE WHEN source_id <= target_id THEN source_id ELSE target_id END AS ordered_source,
+ CASE WHEN source_id <= target_id THEN target_id ELSE source_id END AS ordered_target
+ RETURN DISTINCT ordered_source AS source, ordered_target AS target, properties(r) AS properties
+ """
+ result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
+ edges = []
+ async for record in result:
+ edge_properties = record["properties"]
+ edge_properties["source"] = record["source"]
+ edge_properties["target"] = record["target"]
+ edges.append(edge_properties)
+ await result.consume()
+ return edges
+
+ async def get_knowledge_graph(
+ self,
+ node_label: str,
+ max_depth: int = 3,
+ max_nodes: int = MAX_GRAPH_NODES,
+ ) -> KnowledgeGraph:
+ """
+ Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
+
+ Args:
+ node_label: Label of the starting node, * means all nodes
+            max_depth: Maximum depth of the subgraph, defaults to 3
+            max_nodes: Maximum number of nodes to return by BFS, defaults to 1000
+
+ Returns:
+ KnowledgeGraph object containing nodes and edges, with an is_truncated flag
+ indicating whether the graph was truncated due to max_nodes limit
+
+ Raises:
+ Exception: If there is an error executing the query
+ """
+ if self._driver is None:
+ raise RuntimeError(
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
+ )
+
+ result = KnowledgeGraph()
+ seen_nodes = set()
+ seen_edges = set()
+ workspace_label = self._get_workspace_label()
+ async with self._driver.session(
+ database=self._DATABASE, default_access_mode="READ"
+ ) as session:
+ try:
+ if node_label == "*":
+                    # First check if the workspace has any nodes
+                    count_query = (
+                        f"MATCH (n:`{workspace_label}`) RETURN count(n) as total"
+                    )
+ count_result = None
+ total_count = 0
+ try:
+ count_result = await session.run(count_query)
+ count_record = await count_result.single()
+ if count_record:
+ total_count = count_record["total"]
+ if total_count == 0:
+ logger.debug("No nodes found in database")
+ return result
+ if total_count > max_nodes:
+ result.is_truncated = True
+ logger.info(
+ f"Graph truncated: {total_count} nodes found, limited to {max_nodes}"
+ )
+ finally:
+ if count_result:
+ await count_result.consume()
+
+ # Run the main query to get nodes with highest degree
+ main_query = f"""
+ MATCH (n:`{workspace_label}`)
+ OPTIONAL MATCH (n)-[r]-()
+ WITH n, COALESCE(count(r), 0) AS degree
+ ORDER BY degree DESC
+ LIMIT $max_nodes
+ WITH collect(n) AS kept_nodes
+ MATCH (a)-[r]-(b)
+ WHERE a IN kept_nodes AND b IN kept_nodes
+ RETURN [node IN kept_nodes | {{node: node}}] AS node_info,
+ collect(DISTINCT r) AS relationships
+ """
+ result_set = None
+ try:
+ result_set = await session.run(
+ main_query, {"max_nodes": max_nodes}
+ )
+ record = await result_set.single()
+ if not record:
+ logger.debug("No record returned from main query")
+ return result
+ finally:
+ if result_set:
+ await result_set.consume()
+
+ else:
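+                    # Expand from the start node up to max_depth; if the subgraph
+                    # exceeds max_nodes, trim the node list inside Cypher and set
+                    # is_truncated so the caller knows the graph was cut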
+ bfs_query = f"""
+ MATCH (start:`{workspace_label}`)
+ WHERE start.entity_id = $entity_id
+ WITH start
+ CALL {{
+ WITH start
+ MATCH path = (start)-[*0..{max_depth}]-(node)
+ WITH nodes(path) AS path_nodes, relationships(path) AS path_rels
+ UNWIND path_nodes AS n
+ WITH collect(DISTINCT n) AS all_nodes, collect(DISTINCT path_rels) AS all_rel_lists
+ WITH all_nodes, reduce(r = [], x IN all_rel_lists | r + x) AS all_rels
+ RETURN all_nodes, all_rels
+ }}
+ WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes
+ WITH
+ CASE
+ WHEN total_nodes <= {max_nodes} THEN nodes
+ ELSE nodes[0..{max_nodes}]
+ END AS limited_nodes,
+ relationships,
+ total_nodes,
+ total_nodes > {max_nodes} AS is_truncated
+ RETURN
+ [node IN limited_nodes | {{node: node}}] AS node_info,
+ relationships,
+ total_nodes,
+ is_truncated
+ """
+ result_set = None
+ try:
+ result_set = await session.run(
+ bfs_query,
+ {
+ "entity_id": node_label,
+ },
+ )
+ record = await result_set.single()
+ if not record:
+ logger.debug(f"No nodes found for entity_id: {node_label}")
+ return result
+
+ # Check if the query indicates truncation
+ if "is_truncated" in record and record["is_truncated"]:
+ result.is_truncated = True
+ logger.info(
+ f"Graph truncated: breadth-first search limited to {max_nodes} nodes"
+ )
+
+ finally:
+ if result_set:
+ await result_set.consume()
+
+ # Process the record if it exists
+ if record and record["node_info"]:
+ for node_info in record["node_info"]:
+ node = node_info["node"]
+ node_id = node.id
+ if node_id not in seen_nodes:
+ seen_nodes.add(node_id)
+ result.nodes.append(
+ KnowledgeGraphNode(
+ id=f"{node_id}",
+ labels=[node.get("entity_id")],
+ properties=dict(node),
+ )
+ )
+
+ for rel in record["relationships"]:
+ edge_id = rel.id
+ if edge_id not in seen_edges:
+ seen_edges.add(edge_id)
+ start = rel.start_node
+ end = rel.end_node
+ result.edges.append(
+ KnowledgeGraphEdge(
+ id=f"{edge_id}",
+ type=rel.type,
+ source=f"{start.id}",
+ target=f"{end.id}",
+ properties=dict(rel),
+ )
+ )
+
+ logger.info(
+ f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
+ )
+
+ except Exception as e:
+ logger.error(f"Error getting knowledge graph: {str(e)}")
+ # Return empty but properly initialized KnowledgeGraph on error
+ return KnowledgeGraph()
+
+ return result
diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py
index 57f016cf..eb74c2f1 100644
--- a/lightrag/llm/openai.py
+++ b/lightrag/llm/openai.py
@@ -210,9 +210,18 @@ async def openai_complete_if_cache(
async def inner():
# Track if we've started iterating
iteration_started = False
+ final_chunk_usage = None
+
try:
iteration_started = True
async for chunk in response:
+ # Check if this chunk has usage information (final chunk)
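+                # (OpenAI only emits usage on the final streaming chunk when the
+                # request sets stream_options={"include_usage": True})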
+ if hasattr(chunk, "usage") and chunk.usage:
+ final_chunk_usage = chunk.usage
+ logger.debug(
+ f"Received usage info in streaming chunk: {chunk.usage}"
+ )
+
# Check if choices exists and is not empty
if not hasattr(chunk, "choices") or not chunk.choices:
logger.warning(f"Received chunk without choices: {chunk}")
@@ -222,16 +231,31 @@ async def openai_complete_if_cache(
if not hasattr(chunk.choices[0], "delta") or not hasattr(
chunk.choices[0].delta, "content"
):
- logger.warning(
- f"Received chunk without delta content: {chunk.choices[0]}"
- )
+ # This might be the final chunk, continue to check for usage
continue
+
content = chunk.choices[0].delta.content
if content is None:
continue
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
+
yield content
+
+ # After streaming is complete, track token usage
+ if token_tracker and final_chunk_usage:
+ # Use actual usage from the API
+ token_counts = {
+ "prompt_tokens": getattr(final_chunk_usage, "prompt_tokens", 0),
+ "completion_tokens": getattr(
+ final_chunk_usage, "completion_tokens", 0
+ ),
+ "total_tokens": getattr(final_chunk_usage, "total_tokens", 0),
+ }
+ token_tracker.add_usage(token_counts)
+ logger.debug(f"Streaming token usage (from API): {token_counts}")
+ elif token_tracker:
+ logger.debug("No usage information available in streaming response")
except Exception as e:
logger.error(f"Error in stream response: {str(e)}")
# Try to clean up resources if possible
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 05fef78e..e2251fc7 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -26,6 +26,7 @@ from .utils import (
get_conversation_turns,
use_llm_func_with_cache,
update_chunk_cache_list,
+ remove_think_tags,
)
from .base import (
BaseGraphStorage,
@@ -1704,7 +1705,8 @@ async def extract_keywords_only(
result = await use_model_func(kw_prompt, keyword_extraction=True)
# 6. Parse out JSON from the LLM response
- match = re.search(r"\{.*\}", result, re.DOTALL)
+ result = remove_think_tags(result)
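+    # Non-greedy {...} match: the expected keyword JSON is flat (no nested
+    # braces), so take the first complete object and ignore trailing text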
+ match = re.search(r"\{.*?\}", result, re.DOTALL)
if not match:
        logger.error("No JSON-like structure found in the LLM response.")
return [], []
diff --git a/lightrag/utils.py b/lightrag/utils.py
index c6e2def9..386de3ab 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -1465,6 +1465,11 @@ async def update_chunk_cache_list(
)
+def remove_think_tags(text: str) -> str:
+    """Remove <think> tags from the text"""
+    return re.sub(r"^(<think>.*?</think>|<think>)", "", text, flags=re.DOTALL).strip()
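+
+# Example (hypothetical input): remove_think_tags('<think>step</think>{"a": 1}')
+# returns '{"a": 1}'; the leading reasoning block is stripped, the payload kept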
+
+
async def use_llm_func_with_cache(
input_text: str,
use_llm_func: callable,
@@ -1531,6 +1536,7 @@ async def use_llm_func_with_cache(
kwargs["max_tokens"] = max_tokens
res: str = await use_llm_func(input_text, **kwargs)
+ res = remove_think_tags(res)
if llm_response_cache.global_config.get("enable_llm_cache_for_entity_extract"):
await save_to_cache(
@@ -1557,8 +1563,9 @@ async def use_llm_func_with_cache(
if max_tokens is not None:
kwargs["max_tokens"] = max_tokens
- logger.info(f"Call LLM function with query text lenght: {len(input_text)}")
- return await use_llm_func(input_text, **kwargs)
+ logger.info(f"Call LLM function with query text length: {len(input_text)}")
+ res = await use_llm_func(input_text, **kwargs)
+ return remove_think_tags(res)
def get_content_summary(content: str, max_length: int = 250) -> str:
diff --git a/tests/test_graph_storage.py b/tests/test_graph_storage.py
index 258c8795..62f658ff 100644
--- a/tests/test_graph_storage.py
+++ b/tests/test_graph_storage.py
@@ -10,6 +10,7 @@
- Neo4JStorage
- MongoDBStorage
- PGGraphStorage
+- MemgraphStorage
"""
import asyncio