diff --git a/README.md b/README.md index 79479da9..dd8238f0 100644 --- a/README.md +++ b/README.md @@ -870,6 +870,41 @@ rag = LightRAG( +
+<details>
+<summary> <b>Using Memgraph for Storage</b> </summary>
+
+* Memgraph is a high-performance, in-memory graph database compatible with the Neo4j Bolt protocol.
+* You can run Memgraph locally with Docker for easy testing; see https://memgraph.com/download
+
+```python
+export MEMGRAPH_URI="bolt://localhost:7687"
+
+# Setup logger for LightRAG
+setup_logger("lightrag", level="INFO")
+
+# When you launch the project, override the default graph storage (NetworkX)
+# by specifying graph_storage="MemgraphStorage".
+
+# Note: Default settings use NetworkX
+# Initialize LightRAG with the Memgraph implementation.
+async def initialize_rag():
+    rag = LightRAG(
+        working_dir=WORKING_DIR,
+        llm_model_func=gpt_4o_mini_complete,  # Use gpt_4o_mini_complete LLM model
+        graph_storage="MemgraphStorage",  # <----------- override KG default
+    )
+
+    # Initialize database connections
+    await rag.initialize_storages()
+    # Initialize pipeline status for document processing
+    await initialize_pipeline_status()
+
+    return rag
+```
+
+</details>
+
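For reference, here is a minimal end-to-end sketch that uses the `initialize_rag()` helper above. It is an illustration, not part of the patch: `WORKING_DIR` and the sample texts are placeholder values, and the imports mirror the Quick Start examples earlier in this README.

```python
import asyncio
import os

# Imports used by the initialize_rag() helper defined above
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import gpt_4o_mini_complete
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.utils import setup_logger

WORKING_DIR = "./rag_storage"  # placeholder working directory

# Point LightRAG at the local Memgraph instance
os.environ.setdefault("MEMGRAPH_URI", "bolt://localhost:7687")

async def main():
    rag = await initialize_rag()  # helper from the snippet above
    try:
        # Index a document, then query it through the Memgraph-backed graph
        await rag.ainsert("Memgraph is an in-memory graph database that speaks the Bolt protocol.")
        answer = await rag.aquery("What is Memgraph?", param=QueryParam(mode="hybrid"))
        print(answer)
    finally:
        await rag.finalize_storages()  # close the Memgraph driver cleanly

if __name__ == "__main__":
    asyncio.run(main())
```

Only the graph storage backend changes here; ingestion, querying, and the rest of the LightRAG pipeline are unchanged relative to the default NetworkX storage.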
+ ## Edit Entities and Relations LightRAG now supports comprehensive knowledge graph management capabilities, allowing you to create, edit, and delete entities and relationships within your knowledge graph. diff --git a/config.ini.example b/config.ini.example index 63d9c2c0..94d300a1 100644 --- a/config.ini.example +++ b/config.ini.example @@ -21,3 +21,6 @@ password = your_password database = your_database workspace = default # Optional, defaults to "default" max_connections = 12 + +[memgraph] +uri = bolt://localhost:7687 diff --git a/env.example b/env.example index 874ebf3c..d8a4dfd6 100644 --- a/env.example +++ b/env.example @@ -147,13 +147,14 @@ EMBEDDING_BINDING_HOST=http://localhost:11434 # LIGHTRAG_VECTOR_STORAGE=QdrantVectorDBStorage ### Graph Storage (Recommended for production deployment) # LIGHTRAG_GRAPH_STORAGE=Neo4JStorage +# LIGHTRAG_GRAPH_STORAGE=MemgraphStorage #################################################################### ### Default workspace for all storage types ### For the purpose of isolation of data for each LightRAG instance ### Valid characters: a-z, A-Z, 0-9, and _ #################################################################### -# WORKSPACE=doc— +# WORKSPACE=space1 ### PostgreSQL Configuration POSTGRES_HOST=localhost @@ -192,3 +193,10 @@ QDRANT_URL=http://localhost:6333 ### Redis REDIS_URI=redis://localhost:6379 # REDIS_WORKSPACE=forced_workspace_name + +### Memgraph Configuration +MEMGRAPH_URI=bolt://localhost:7687 +MEMGRAPH_USERNAME= +MEMGRAPH_PASSWORD= +MEMGRAPH_DATABASE=memgraph +# MEMGRAPH_WORKSPACE=forced_workspace_name diff --git a/lightrag/kg/__init__.py b/lightrag/kg/__init__.py index 1f5fd56f..b2a93e82 100644 --- a/lightrag/kg/__init__.py +++ b/lightrag/kg/__init__.py @@ -15,6 +15,7 @@ STORAGE_IMPLEMENTATIONS = { "Neo4JStorage", "PGGraphStorage", "MongoGraphStorage", + "MemgraphStorage", # "AGEStorage", # "TiDBGraphStorage", # "GremlinStorage", @@ -57,6 +58,7 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = { "NetworkXStorage": [], "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"], "MongoGraphStorage": [], + "MemgraphStorage": ["MEMGRAPH_URI"], # "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"], "AGEStorage": [ "AGE_POSTGRES_DB", @@ -111,6 +113,7 @@ STORAGES = { "PGDocStatusStorage": ".kg.postgres_impl", "FaissVectorDBStorage": ".kg.faiss_impl", "QdrantVectorDBStorage": ".kg.qdrant_impl", + "MemgraphStorage": ".kg.memgraph_impl", } diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py new file mode 100644 index 00000000..8c6d6574 --- /dev/null +++ b/lightrag/kg/memgraph_impl.py @@ -0,0 +1,906 @@ +import os +from dataclasses import dataclass +from typing import final +import configparser + +from ..utils import logger +from ..base import BaseGraphStorage +from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge +from ..constants import GRAPH_FIELD_SEP +import pipmaster as pm + +if not pm.is_installed("neo4j"): + pm.install("neo4j") + +from neo4j import ( + AsyncGraphDatabase, + AsyncManagedTransaction, +) + +from dotenv import load_dotenv + +# use the .env that is inside the current folder +load_dotenv(dotenv_path=".env", override=False) + +MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) + +config = configparser.ConfigParser() +config.read("config.ini", "utf-8") + + +@final +@dataclass +class MemgraphStorage(BaseGraphStorage): + def __init__(self, namespace, global_config, embedding_func, workspace=None): + memgraph_workspace = os.environ.get("MEMGRAPH_WORKSPACE") + if 
memgraph_workspace and memgraph_workspace.strip(): + workspace = memgraph_workspace + super().__init__( + namespace=namespace, + workspace=workspace or "", + global_config=global_config, + embedding_func=embedding_func, + ) + self._driver = None + + def _get_workspace_label(self) -> str: + """Get workspace label, return 'base' for compatibility when workspace is empty""" + workspace = getattr(self, "workspace", None) + return workspace if workspace else "base" + + async def initialize(self): + URI = os.environ.get( + "MEMGRAPH_URI", + config.get("memgraph", "uri", fallback="bolt://localhost:7687"), + ) + USERNAME = os.environ.get( + "MEMGRAPH_USERNAME", config.get("memgraph", "username", fallback="") + ) + PASSWORD = os.environ.get( + "MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback="") + ) + DATABASE = os.environ.get( + "MEMGRAPH_DATABASE", config.get("memgraph", "database", fallback="memgraph") + ) + + self._driver = AsyncGraphDatabase.driver( + URI, + auth=(USERNAME, PASSWORD), + ) + self._DATABASE = DATABASE + try: + async with self._driver.session(database=DATABASE) as session: + # Create index for base nodes on entity_id if it doesn't exist + try: + workspace_label = self._get_workspace_label() + await session.run( + f"""CREATE INDEX ON :{workspace_label}(entity_id)""" + ) + logger.info( + f"Created index on :{workspace_label}(entity_id) in Memgraph." + ) + except Exception as e: + # Index may already exist, which is not an error + logger.warning( + f"Index creation on :{workspace_label}(entity_id) may have failed or already exists: {e}" + ) + await session.run("RETURN 1") + logger.info(f"Connected to Memgraph at {URI}") + except Exception as e: + logger.error(f"Failed to connect to Memgraph at {URI}: {e}") + raise + + async def finalize(self): + if self._driver is not None: + await self._driver.close() + self._driver = None + + async def __aexit__(self, exc_type, exc, tb): + await self.finalize() + + async def index_done_callback(self): + # Memgraph handles persistence automatically + pass + + async def has_node(self, node_id: str) -> bool: + """ + Check if a node exists in the graph. + + Args: + node_id: The ID of the node to check. + + Returns: + bool: True if the node exists, False otherwise. + + Raises: + Exception: If there is an error checking the node existence. + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + try: + workspace_label = self._get_workspace_label() + query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists" + result = await session.run(query, entity_id=node_id) + single_result = await result.single() + await result.consume() # Ensure result is fully consumed + return ( + single_result["node_exists"] if single_result is not None else False + ) + except Exception as e: + logger.error(f"Error checking node existence for {node_id}: {str(e)}") + await result.consume() # Ensure the result is consumed even on error + raise + + async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: + """ + Check if an edge exists between two nodes in the graph. + + Args: + source_node_id: The ID of the source node. + target_node_id: The ID of the target node. + + Returns: + bool: True if the edge exists, False otherwise. + + Raises: + Exception: If there is an error checking the edge existence. 
+ """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + try: + workspace_label = self._get_workspace_label() + query = ( + f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) " + "RETURN COUNT(r) > 0 AS edgeExists" + ) + result = await session.run( + query, + source_entity_id=source_node_id, + target_entity_id=target_node_id, + ) # type: ignore + single_result = await result.single() + await result.consume() # Ensure result is fully consumed + return ( + single_result["edgeExists"] if single_result is not None else False + ) + except Exception as e: + logger.error( + f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}" + ) + await result.consume() # Ensure the result is consumed even on error + raise + + async def get_node(self, node_id: str) -> dict[str, str] | None: + """Get node by its label identifier, return only node properties + + Args: + node_id: The node label to look up + + Returns: + dict: Node properties if found + None: If node not found + + Raises: + Exception: If there is an error executing the query + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + try: + workspace_label = self._get_workspace_label() + query = ( + f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n" + ) + result = await session.run(query, entity_id=node_id) + try: + records = await result.fetch( + 2 + ) # Get 2 records for duplication check + + if len(records) > 1: + logger.warning( + f"Multiple nodes found with label '{node_id}'. Using first node." + ) + if records: + node = records[0]["n"] + node_dict = dict(node) + # Remove workspace label from labels list if it exists + if "labels" in node_dict: + node_dict["labels"] = [ + label + for label in node_dict["labels"] + if label != workspace_label + ] + return node_dict + return None + finally: + await result.consume() # Ensure result is fully consumed + except Exception as e: + logger.error(f"Error getting node for {node_id}: {str(e)}") + raise + + async def node_degree(self, node_id: str) -> int: + """Get the degree (number of relationships) of a node with the given label. + If multiple nodes have the same label, returns the degree of the first node. + If no node is found, returns 0. + + Args: + node_id: The label of the node + + Returns: + int: The number of relationships the node has, or 0 if no node found + + Raises: + Exception: If there is an error executing the query + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." 
+ ) + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + try: + workspace_label = self._get_workspace_label() + query = f""" + MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) + OPTIONAL MATCH (n)-[r]-() + RETURN COUNT(r) AS degree + """ + result = await session.run(query, entity_id=node_id) + try: + record = await result.single() + + if not record: + logger.warning(f"No node found with label '{node_id}'") + return 0 + + degree = record["degree"] + return degree + finally: + await result.consume() # Ensure result is fully consumed + except Exception as e: + logger.error(f"Error getting node degree for {node_id}: {str(e)}") + raise + + async def get_all_labels(self) -> list[str]: + """ + Get all existing node labels in the database + Returns: + ["Person", "Company", ...] # Alphabetically sorted label list + + Raises: + Exception: If there is an error executing the query + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + try: + workspace_label = self._get_workspace_label() + query = f""" + MATCH (n:`{workspace_label}`) + WHERE n.entity_id IS NOT NULL + RETURN DISTINCT n.entity_id AS label + ORDER BY label + """ + result = await session.run(query) + labels = [] + async for record in result: + labels.append(record["label"]) + await result.consume() + return labels + except Exception as e: + logger.error(f"Error getting all labels: {str(e)}") + await result.consume() # Ensure the result is consumed even on error + raise + + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: + """Retrieves all edges (relationships) for a particular node identified by its label. + + Args: + source_node_id: Label of the node to get edges for + + Returns: + list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges + None: If no edges found + + Raises: + Exception: If there is an error executing the query + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." 
+ ) + try: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + try: + workspace_label = self._get_workspace_label() + query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) + OPTIONAL MATCH (n)-[r]-(connected:`{workspace_label}`) + WHERE connected.entity_id IS NOT NULL + RETURN n, r, connected""" + results = await session.run(query, entity_id=source_node_id) + + edges = [] + async for record in results: + source_node = record["n"] + connected_node = record["connected"] + + # Skip if either node is None + if not source_node or not connected_node: + continue + + source_label = ( + source_node.get("entity_id") + if source_node.get("entity_id") + else None + ) + target_label = ( + connected_node.get("entity_id") + if connected_node.get("entity_id") + else None + ) + + if source_label and target_label: + edges.append((source_label, target_label)) + + await results.consume() # Ensure results are consumed + return edges + except Exception as e: + logger.error( + f"Error getting edges for node {source_node_id}: {str(e)}" + ) + await results.consume() # Ensure results are consumed even on error + raise + except Exception as e: + logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}") + raise + + async def get_edge( + self, source_node_id: str, target_node_id: str + ) -> dict[str, str] | None: + """Get edge properties between two nodes. + + Args: + source_node_id: Label of the source node + target_node_id: Label of the target node + + Returns: + dict: Edge properties if found, default properties if not found or on error + + Raises: + Exception: If there is an error executing the query + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + try: + workspace_label = self._get_workspace_label() + query = f""" + MATCH (start:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(end:`{workspace_label}` {{entity_id: $target_entity_id}}) + RETURN properties(r) as edge_properties + """ + result = await session.run( + query, + source_entity_id=source_node_id, + target_entity_id=target_node_id, + ) + records = await result.fetch(2) + await result.consume() + if records: + edge_result = dict(records[0]["edge_properties"]) + for key, default_value in { + "weight": 0.0, + "source_id": None, + "description": None, + "keywords": None, + }.items(): + if key not in edge_result: + edge_result[key] = default_value + logger.warning( + f"Edge between {source_node_id} and {target_node_id} is missing property: {key}. Using default value: {default_value}" + ) + return edge_result + return None + except Exception as e: + logger.error( + f"Error getting edge between {source_node_id} and {target_node_id}: {str(e)}" + ) + await result.consume() # Ensure the result is consumed even on error + raise + + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: + """ + Upsert a node in the Memgraph database. + + Args: + node_id: The unique identifier for the node (used as label) + node_data: Dictionary of node properties + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first."
+ ) + properties = node_data + entity_type = properties["entity_type"] + if "entity_id" not in properties: + raise ValueError("Memgraph: node properties must contain an 'entity_id' field") + + try: + async with self._driver.session(database=self._DATABASE) as session: + workspace_label = self._get_workspace_label() + + async def execute_upsert(tx: AsyncManagedTransaction): + query = f""" + MERGE (n:`{workspace_label}` {{entity_id: $entity_id}}) + SET n += $properties + SET n:`{entity_type}` + """ + result = await tx.run( + query, entity_id=node_id, properties=properties + ) + await result.consume() # Ensure result is fully consumed + + await session.execute_write(execute_upsert) + except Exception as e: + logger.error(f"Error during upsert: {str(e)}") + raise + + async def upsert_edge( + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ) -> None: + """ + Upsert an edge and its properties between two nodes identified by their labels. + Ensures both source and target nodes exist and are unique before creating the edge. + Uses entity_id property to uniquely identify nodes. + + Args: + source_node_id (str): Label of the source node (used as identifier) + target_node_id (str): Label of the target node (used as identifier) + edge_data (dict): Dictionary of properties to set on the edge + + Raises: + Exception: If there is an error executing the query + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) + try: + edge_properties = edge_data + async with self._driver.session(database=self._DATABASE) as session: + + async def execute_upsert(tx: AsyncManagedTransaction): + workspace_label = self._get_workspace_label() + query = f""" + MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}}) + WITH source + MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}}) + MERGE (source)-[r:DIRECTED]-(target) + SET r += $properties + RETURN r, source, target + """ + result = await tx.run( + query, + source_entity_id=source_node_id, + target_entity_id=target_node_id, + properties=edge_properties, + ) + try: + await result.fetch(2) + finally: + await result.consume() # Ensure result is consumed + + await session.execute_write(execute_upsert) + except Exception as e: + logger.error(f"Error during edge upsert: {str(e)}") + raise + + async def delete_node(self, node_id: str) -> None: + """Delete a node with the specified label + + Args: + node_id: The label of the node to delete + + Raises: + Exception: If there is an error executing the query + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) + + async def _do_delete(tx: AsyncManagedTransaction): + workspace_label = self._get_workspace_label() + query = f""" + MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) + DETACH DELETE n + """ + result = await tx.run(query, entity_id=node_id) + logger.debug(f"Deleted node with label {node_id}") + await result.consume() + + try: + async with self._driver.session(database=self._DATABASE) as session: + await session.execute_write(_do_delete) + except Exception as e: + logger.error(f"Error during node deletion: {str(e)}") + raise + + async def remove_nodes(self, nodes: list[str]): + """Delete multiple nodes + + Args: + nodes: List of node labels to be deleted + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first."
+ ) + for node in nodes: + await self.delete_node(node) + + async def remove_edges(self, edges: list[tuple[str, str]]): + """Delete multiple edges + + Args: + edges: List of edges to be deleted, each edge is a (source, target) tuple + + Raises: + Exception: If there is an error executing the query + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) + for source, target in edges: + + async def _do_delete_edge(tx: AsyncManagedTransaction): + workspace_label = self._get_workspace_label() + query = f""" + MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(target:`{workspace_label}` {{entity_id: $target_entity_id}}) + DELETE r + """ + result = await tx.run( + query, source_entity_id=source, target_entity_id=target + ) + logger.debug(f"Deleted edge from '{source}' to '{target}'") + await result.consume() # Ensure result is fully consumed + + try: + async with self._driver.session(database=self._DATABASE) as session: + await session.execute_write(_do_delete_edge) + except Exception as e: + logger.error(f"Error during edge deletion: {str(e)}") + raise + + async def drop(self) -> dict[str, str]: + """Drop all data from the current workspace and clean up resources + + This method will delete all nodes and relationships in the Memgraph database. + + Returns: + dict[str, str]: Operation status and message + - On success: {"status": "success", "message": "data dropped"} + - On failure: {"status": "error", "message": "<error message>"} + + Raises: + Exception: If there is an error executing the query + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) + try: + async with self._driver.session(database=self._DATABASE) as session: + workspace_label = self._get_workspace_label() + query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n" + result = await session.run(query) + await result.consume() + logger.info( + f"Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}" + ) + return {"status": "success", "message": "workspace data dropped"} + except Exception as e: + logger.error( + f"Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}" + ) + return {"status": "error", "message": str(e)} + + async def edge_degree(self, src_id: str, tgt_id: str) -> int: + """Get the total degree (sum of relationships) of two nodes. + + Args: + src_id: Label of the source node + tgt_id: Label of the target node + + Returns: + int: Sum of the degrees of both nodes + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) + src_degree = await self.node_degree(src_id) + trg_degree = await self.node_degree(tgt_id) + + # Convert None to 0 for addition + src_degree = 0 if src_degree is None else src_degree + trg_degree = 0 if trg_degree is None else trg_degree + + degrees = int(src_degree) + int(trg_degree) + return degrees + + async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + """Get all nodes that are associated with the given chunk_ids. + + Args: + chunk_ids: List of chunk IDs to find associated nodes for + + Returns: + list[dict]: A list of nodes, where each node is a dictionary of its properties. + An empty list if no matching nodes are found. + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first."
+ ) + workspace_label = self._get_workspace_label() + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + query = f""" + UNWIND $chunk_ids AS chunk_id + MATCH (n:`{workspace_label}`) + WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep) + RETURN DISTINCT n + """ + result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP) + nodes = [] + async for record in result: + node = record["n"] + node_dict = dict(node) + node_dict["id"] = node_dict.get("entity_id") + nodes.append(node_dict) + await result.consume() + return nodes + + async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + """Get all edges that are associated with the given chunk_ids. + + Args: + chunk_ids: List of chunk IDs to find associated edges for + + Returns: + list[dict]: A list of edges, where each edge is a dictionary of its properties. + An empty list if no matching edges are found. + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) + workspace_label = self._get_workspace_label() + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + query = f""" + UNWIND $chunk_ids AS chunk_id + MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`) + WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep) + WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id + // Ensure we only return each unique edge once by ordering the source and target + WITH a, b, r, + CASE WHEN source_id <= target_id THEN source_id ELSE target_id END AS ordered_source, + CASE WHEN source_id <= target_id THEN target_id ELSE source_id END AS ordered_target + RETURN DISTINCT ordered_source AS source, ordered_target AS target, properties(r) AS properties + """ + result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP) + edges = [] + async for record in result: + edge_properties = record["properties"] + edge_properties["source"] = record["source"] + edge_properties["target"] = record["target"] + edges.append(edge_properties) + await result.consume() + return edges + + async def get_knowledge_graph( + self, + node_label: str, + max_depth: int = 3, + max_nodes: int = MAX_GRAPH_NODES, + ) -> KnowledgeGraph: + """ + Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. + + Args: + node_label: Label of the starting node, * means all nodes + max_depth: Maximum depth of the subgraph, Defaults to 3 + max_nodes: Maximum nodes to return by BFS, Defaults to 1000 + + Returns: + KnowledgeGraph object containing nodes and edges, with an is_truncated flag + indicating whether the graph was truncated due to max_nodes limit + + Raises: + Exception: If there is an error executing the query + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first."
+ ) + + result = KnowledgeGraph() + seen_nodes = set() + seen_edges = set() + workspace_label = self._get_workspace_label() + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + try: + if node_label == "*": + # First check if database has any nodes + count_query = "MATCH (n) RETURN count(n) as total" + count_result = None + total_count = 0 + try: + count_result = await session.run(count_query) + count_record = await count_result.single() + if count_record: + total_count = count_record["total"] + if total_count == 0: + logger.debug("No nodes found in database") + return result + if total_count > max_nodes: + result.is_truncated = True + logger.info( + f"Graph truncated: {total_count} nodes found, limited to {max_nodes}" + ) + finally: + if count_result: + await count_result.consume() + + # Run the main query to get nodes with highest degree + main_query = f""" + MATCH (n:`{workspace_label}`) + OPTIONAL MATCH (n)-[r]-() + WITH n, COALESCE(count(r), 0) AS degree + ORDER BY degree DESC + LIMIT $max_nodes + WITH collect(n) AS kept_nodes + MATCH (a)-[r]-(b) + WHERE a IN kept_nodes AND b IN kept_nodes + RETURN [node IN kept_nodes | {{node: node}}] AS node_info, + collect(DISTINCT r) AS relationships + """ + result_set = None + try: + result_set = await session.run( + main_query, {"max_nodes": max_nodes} + ) + record = await result_set.single() + if not record: + logger.debug("No record returned from main query") + return result + finally: + if result_set: + await result_set.consume() + + else: + bfs_query = f""" + MATCH (start:`{workspace_label}`) + WHERE start.entity_id = $entity_id + WITH start + CALL {{ + WITH start + MATCH path = (start)-[*0..{max_depth}]-(node) + WITH nodes(path) AS path_nodes, relationships(path) AS path_rels + UNWIND path_nodes AS n + WITH collect(DISTINCT n) AS all_nodes, collect(DISTINCT path_rels) AS all_rel_lists + WITH all_nodes, reduce(r = [], x IN all_rel_lists | r + x) AS all_rels + RETURN all_nodes, all_rels + }} + WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes + WITH + CASE + WHEN total_nodes <= {max_nodes} THEN nodes + ELSE nodes[0..{max_nodes}] + END AS limited_nodes, + relationships, + total_nodes, + total_nodes > {max_nodes} AS is_truncated + RETURN + [node IN limited_nodes | {{node: node}}] AS node_info, + relationships, + total_nodes, + is_truncated + """ + result_set = None + try: + result_set = await session.run( + bfs_query, + { + "entity_id": node_label, + }, + ) + record = await result_set.single() + if not record: + logger.debug(f"No nodes found for entity_id: {node_label}") + return result + + # Check if the query indicates truncation + if "is_truncated" in record and record["is_truncated"]: + result.is_truncated = True + logger.info( + f"Graph truncated: breadth-first search limited to {max_nodes} nodes" + ) + + finally: + if result_set: + await result_set.consume() + + # Process the record if it exists + if record and record["node_info"]: + for node_info in record["node_info"]: + node = node_info["node"] + node_id = node.id + if node_id not in seen_nodes: + seen_nodes.add(node_id) + result.nodes.append( + KnowledgeGraphNode( + id=f"{node_id}", + labels=[node.get("entity_id")], + properties=dict(node), + ) + ) + + for rel in record["relationships"]: + edge_id = rel.id + if edge_id not in seen_edges: + seen_edges.add(edge_id) + start = rel.start_node + end = rel.end_node + result.edges.append( + KnowledgeGraphEdge( + id=f"{edge_id}", + type=rel.type, + 
source=f"{start.id}", + target=f"{end.id}", + properties=dict(rel), + ) + ) + + logger.info( + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) + + except Exception as e: + logger.error(f"Error getting knowledge graph: {str(e)}") + # Return empty but properly initialized KnowledgeGraph on error + return KnowledgeGraph() + + return result diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py index 57f016cf..eb74c2f1 100644 --- a/lightrag/llm/openai.py +++ b/lightrag/llm/openai.py @@ -210,9 +210,18 @@ async def openai_complete_if_cache( async def inner(): # Track if we've started iterating iteration_started = False + final_chunk_usage = None + try: iteration_started = True async for chunk in response: + # Check if this chunk has usage information (final chunk) + if hasattr(chunk, "usage") and chunk.usage: + final_chunk_usage = chunk.usage + logger.debug( + f"Received usage info in streaming chunk: {chunk.usage}" + ) + # Check if choices exists and is not empty if not hasattr(chunk, "choices") or not chunk.choices: logger.warning(f"Received chunk without choices: {chunk}") @@ -222,16 +231,31 @@ async def openai_complete_if_cache( if not hasattr(chunk.choices[0], "delta") or not hasattr( chunk.choices[0].delta, "content" ): - logger.warning( - f"Received chunk without delta content: {chunk.choices[0]}" - ) + # This might be the final chunk, continue to check for usage continue + content = chunk.choices[0].delta.content if content is None: continue if r"\u" in content: content = safe_unicode_decode(content.encode("utf-8")) + yield content + + # After streaming is complete, track token usage + if token_tracker and final_chunk_usage: + # Use actual usage from the API + token_counts = { + "prompt_tokens": getattr(final_chunk_usage, "prompt_tokens", 0), + "completion_tokens": getattr( + final_chunk_usage, "completion_tokens", 0 + ), + "total_tokens": getattr(final_chunk_usage, "total_tokens", 0), + } + token_tracker.add_usage(token_counts) + logger.debug(f"Streaming token usage (from API): {token_counts}") + elif token_tracker: + logger.debug("No usage information available in streaming response") except Exception as e: logger.error(f"Error in stream response: {str(e)}") # Try to clean up resources if possible diff --git a/lightrag/operate.py b/lightrag/operate.py index 05fef78e..e2251fc7 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -26,6 +26,7 @@ from .utils import ( get_conversation_turns, use_llm_func_with_cache, update_chunk_cache_list, + remove_think_tags, ) from .base import ( BaseGraphStorage, @@ -1704,7 +1705,8 @@ async def extract_keywords_only( result = await use_model_func(kw_prompt, keyword_extraction=True) # 6. 
Parse out JSON from the LLM response - match = re.search(r"\{.*\}", result, re.DOTALL) + result = remove_think_tags(result) + match = re.search(r"\{.*?\}", result, re.DOTALL) if not match: logger.error("No JSON-like structure found in the LLM respond.") return [], [] diff --git a/lightrag/utils.py b/lightrag/utils.py index c6e2def9..386de3ab 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -1465,6 +1465,11 @@ async def update_chunk_cache_list( ) +def remove_think_tags(text: str) -> str: + """Remove <think> tags from the text""" + return re.sub(r"^(<think>.*?</think>|<think>)", "", text, flags=re.DOTALL).strip() + + async def use_llm_func_with_cache( input_text: str, use_llm_func: callable, @@ -1531,6 +1536,7 @@ async def use_llm_func_with_cache( kwargs["max_tokens"] = max_tokens res: str = await use_llm_func(input_text, **kwargs) + res = remove_think_tags(res) if llm_response_cache.global_config.get("enable_llm_cache_for_entity_extract"): await save_to_cache( @@ -1557,8 +1563,9 @@ async def use_llm_func_with_cache( if max_tokens is not None: kwargs["max_tokens"] = max_tokens - logger.info(f"Call LLM function with query text lenght: {len(input_text)}") - return await use_llm_func(input_text, **kwargs) + logger.info(f"Call LLM function with query text length: {len(input_text)}") + res = await use_llm_func(input_text, **kwargs) + return remove_think_tags(res) def get_content_summary(content: str, max_length: int = 250) -> str: diff --git a/tests/test_graph_storage.py b/tests/test_graph_storage.py index 258c8795..62f658ff 100644 --- a/tests/test_graph_storage.py +++ b/tests/test_graph_storage.py @@ -10,6 +10,7 @@ - Neo4JStorage - MongoDBStorage - PGGraphStorage +- MemgraphStorage """ import asyncio
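Beyond the shared test harness, a rough standalone smoke test can exercise the new backend directly through the methods defined in `memgraph_impl.py` above. This is a sketch under stated assumptions: the `namespace` value, the empty `global_config`, and `embedding_func=None` are illustrative constructor arguments, and a Memgraph instance must be reachable at `MEMGRAPH_URI`.

```python
import asyncio
import os

from lightrag.kg.memgraph_impl import MemgraphStorage

# Assumed local Memgraph endpoint
os.environ.setdefault("MEMGRAPH_URI", "bolt://localhost:7687")

async def main():
    # Constructor arguments below are illustrative assumptions
    storage = MemgraphStorage(
        namespace="smoke_test",
        global_config={},
        embedding_func=None,
    )
    await storage.initialize()  # connects and creates the entity_id index
    try:
        # upsert_node requires 'entity_id' and 'entity_type' in node_data
        await storage.upsert_node("Alice", {"entity_id": "Alice", "entity_type": "Person"})
        await storage.upsert_node("Bob", {"entity_id": "Bob", "entity_type": "Person"})
        await storage.upsert_edge("Alice", "Bob", {"weight": 1.0, "description": "knows"})

        print("has Alice:", await storage.has_node("Alice"))
        print("Alice-Bob edge:", await storage.has_edge("Alice", "Bob"))
        print("Alice degree:", await storage.node_degree("Alice"))

        await storage.drop()  # delete everything in this workspace
    finally:
        await storage.finalize()

if __name__ == "__main__":
    asyncio.run(main())
```

Note that the implementation deliberately reuses the `neo4j` async driver: Memgraph speaks the same Bolt protocol, so no new client dependency is introduced.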