Feat: Add WORKSPACE support to all storage types

parent 1b2d295a4f
commit 033098c1bc

17 changed files with 566 additions and 137 deletions

env.example (52 changed lines)
@@ -111,25 +111,37 @@ EMBEDDING_BINDING_HOST=http://localhost:11434
 ###########################
 ### Data storage selection
 ###########################
 ### In-memory database with data persistence to local files
 # LIGHTRAG_KV_STORAGE=JsonKVStorage
 # LIGHTRAG_DOC_STATUS_STORAGE=JsonDocStatusStorage
 # LIGHTRAG_GRAPH_STORAGE=NetworkXStorage
 # LIGHTRAG_VECTOR_STORAGE=NanoVectorDBStorage
 # LIGHTRAG_VECTOR_STORAGE=FaissVectorDBStorage
 ### PostgreSQL
 # LIGHTRAG_KV_STORAGE=PGKVStorage
 # LIGHTRAG_DOC_STATUS_STORAGE=PGDocStatusStorage
 # LIGHTRAG_VECTOR_STORAGE=PGVectorStorage
 # LIGHTRAG_GRAPH_STORAGE=PGGraphStorage
-### MongoDB
+### MongoDB (recommended for production deploy)
 # LIGHTRAG_KV_STORAGE=MongoKVStorage
 # LIGHTRAG_DOC_STATUS_STORAGE=MongoDocStatusStorage
 # LIGHTRAG_VECTOR_STORAGE=MongoVectorDBStorage
 # LIGHTRAG_GRAPH_STORAGE=MongoGraphStorage
-### KV Storage
+### Redis Storage (recommended for production deploy)
 # LIGHTRAG_KV_STORAGE=RedisKVStorage
 # LIGHTRAG_DOC_STATUS_STORAGE=RedisDocStatusStorage
-### Vector Storage
+### Vector Storage (recommended for production deploy)
 # LIGHTRAG_VECTOR_STORAGE=MilvusVectorDBStorage
 # LIGHTRAG_VECTOR_STORAGE=QdrantVectorDBStorage
-### Graph Storage
+### Graph Storage (recommended for production deploy)
 # LIGHTRAG_GRAPH_STORAGE=Neo4JStorage
 # LIGHTRAG_GRAPH_STORAGE=MemgraphStorage

+####################################################################
+### Default workspace for all storage types
+### For the purpose of isolation of data for each LightRAG instance
+### Valid characters: a-z, A-Z, 0-9, and _
+####################################################################
+# WORKSPACE=doc

 ### PostgreSQL Configuration
 POSTGRES_HOST=localhost
@@ -138,31 +150,18 @@ POSTGRES_USER=your_username
 POSTGRES_PASSWORD='your_password'
 POSTGRES_DATABASE=your_database
 POSTGRES_MAX_CONNECTIONS=12
 ### separating all data from different LightRAG instances
-# POSTGRES_WORKSPACE=default
+# POSTGRES_WORKSPACE=forced_workspace_name

 ### Neo4j Configuration
 NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io
 NEO4J_USERNAME=neo4j
 NEO4J_PASSWORD='your_password'

-### Independent AGE Configuration (not for AGE embedded in PostgreSQL)
-# AGE_POSTGRES_DB=
-# AGE_POSTGRES_USER=
-# AGE_POSTGRES_PASSWORD=
-# AGE_POSTGRES_HOST=
-# AGE_POSTGRES_PORT=8529
-
-# AGE Graph Name (applies to PostgreSQL and independent AGE)
+### AGE_GRAPH_NAME is deprecated
 # AGE_GRAPH_NAME=lightrag

 ### MongoDB Configuration
 MONGO_URI=mongodb://root:root@localhost:27017/
 #MONGO_URI=mongodb+srv://root:rooot@cluster0.xxxx.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0
 MONGO_DATABASE=LightRAG
-### separating all data from different LightRAG instances (deprecating)
-# MONGODB_WORKSPACE=default
+### separating all data from different LightRAG instances
+# MONGODB_WORKSPACE=forced_workspace_name

 ### Milvus Configuration
 MILVUS_URI=http://localhost:19530

@@ -170,10 +169,13 @@ MILVUS_DB_NAME=lightrag
 # MILVUS_USER=root
 # MILVUS_PASSWORD=your_password
 # MILVUS_TOKEN=your_token
+# MILVUS_WORKSPACE=forced_workspace_name

 ### Qdrant
-QDRANT_URL=http://localhost:16333
+QDRANT_URL=http://localhost:6333
 # QDRANT_API_KEY=your-api-key
+# QDRANT_WORKSPACE=forced_workspace_name

 ### Redis
 REDIS_URI=redis://localhost:6379
+# REDIS_WORKSPACE=forced_workspace_name
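As a hedged illustration of what the new WORKSPACE setting does (the name "team_a" and the path/collection patterns below mirror the storage changes in this commit but are illustrative, not part of the diff): each LightRAG instance reads WORKSPACE once at startup, and every storage backend then derives its on-disk path, collection name, or graph label from it, so instances sharing one backend never see each other's data.

import os

# The same process-level switch the env file exposes.
os.environ["WORKSPACE"] = "team_a"
ws = os.getenv("WORKSPACE", "")
print(f"{ws}/kv_store_full_docs.json" if ws else "kv_store_full_docs.json")
print(f"{ws}_chunks" if ws else "chunks")  # e.g. a Milvus/Qdrant collection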
@@ -184,10 +184,10 @@ def parse_args() -> argparse.Namespace:

     # Namespace
     parser.add_argument(
-        "--namespace-prefix",
+        "--workspace",
         type=str,
-        default=get_env_value("NAMESPACE_PREFIX", ""),
-        help="Prefix of the namespace",
+        default=get_env_value("WORKSPACE", ""),
+        help="Default workspace for all storage",
     )

     parser.add_argument(
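A minimal sketch of the renamed flag's behavior (get_env_value in the diff is assumed here to act like os.getenv with a default): an explicit --workspace wins, otherwise the WORKSPACE environment variable, otherwise the empty string, which means no isolation.

import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument(
    "--workspace",
    type=str,
    default=os.getenv("WORKSPACE", ""),
    help="Default workspace for all storage",
)
print(parser.parse_args(["--workspace", "team_a"]).workspace)  # team_a
print(parser.parse_args([]).workspace)  # "" unless WORKSPACE is set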
@@ -112,8 +112,8 @@ def create_app(args):
     # Check if API key is provided either through env var or args
     api_key = os.getenv("LIGHTRAG_API_KEY") or args.key

-    # Initialize document manager
-    doc_manager = DocumentManager(args.input_dir)
+    # Initialize document manager with workspace support for data isolation
+    doc_manager = DocumentManager(args.input_dir, workspace=args.workspace)

     @asynccontextmanager
     async def lifespan(app: FastAPI):

@@ -295,6 +295,7 @@ def create_app(args):
     if args.llm_binding in ["lollms", "ollama", "openai"]:
         rag = LightRAG(
             working_dir=args.working_dir,
+            workspace=args.workspace,
             llm_model_func=lollms_model_complete
             if args.llm_binding == "lollms"
             else ollama_model_complete

@@ -330,6 +331,7 @@ def create_app(args):
     else:  # azure_openai
         rag = LightRAG(
             working_dir=args.working_dir,
+            workspace=args.workspace,
             llm_model_func=azure_openai_model_complete,
             chunk_token_size=int(args.chunk_size),
             chunk_overlap_token_size=int(args.chunk_overlap_size),
@@ -475,6 +475,7 @@ class DocumentManager:
     def __init__(
         self,
         input_dir: str,
+        workspace: str = "",  # New parameter for workspace isolation
         supported_extensions: tuple = (
             ".txt",
             ".md",

@@ -515,10 +516,19 @@ class DocumentManager:
             ".less",  # LESS CSS
         ),
     ):
-        self.input_dir = Path(input_dir)
+        # Store the base input directory and workspace
+        self.base_input_dir = Path(input_dir)
+        self.workspace = workspace
         self.supported_extensions = supported_extensions
         self.indexed_files = set()

+        # Create workspace-specific input directory
+        # If workspace is provided, create a subdirectory for data isolation
+        if workspace:
+            self.input_dir = self.base_input_dir / workspace
+        else:
+            self.input_dir = self.base_input_dir
+
         # Create input directory if it doesn't exist
         self.input_dir.mkdir(parents=True, exist_ok=True)

@@ -716,7 +726,9 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
         if content:
             # Check if content contains only whitespace characters
             if not content.strip():
-                logger.warning(f"File contains only whitespace characters. file_paths={file_path.name}")
+                logger.warning(
+                    f"File contains only whitespace characters. file_paths={file_path.name}"
+                )

         await rag.apipeline_enqueue_documents(content, file_paths=file_path.name)
         logger.info(f"Successfully fetched and enqueued file: {file_path.name}")
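The effect on the upload directory, as a runnable sketch (directory names are illustrative; the branch is the same one DocumentManager now takes):

from pathlib import Path

def resolve_input_dir(base: str, workspace: str = "") -> Path:
    # With a workspace, documents are scanned from a per-workspace
    # subdirectory; otherwise the base input directory is used unchanged.
    input_dir = Path(base) / workspace if workspace else Path(base)
    input_dir.mkdir(parents=True, exist_ok=True)
    return input_dir

print(resolve_input_dir("./inputs"))            # inputs
print(resolve_input_dir("./inputs", "team_a"))  # inputs/team_a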
@@ -103,6 +103,7 @@ class QueryParam:

 @dataclass
 class StorageNameSpace(ABC):
     namespace: str
+    workspace: str
     global_config: dict[str, Any]

     async def initialize(self):
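Because workspace is added as a required dataclass field between namespace and global_config, positional construction would silently shift arguments; that is why call sites in this commit pass everything by keyword. A simplified, non-abstract copy of the class above:

from dataclasses import dataclass
from typing import Any

@dataclass
class StorageNameSpace:
    namespace: str
    workspace: str
    global_config: dict[str, Any]

s = StorageNameSpace(namespace="chunks", workspace="team_a", global_config={})
print(s.workspace)  # team_a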
@@ -38,9 +38,19 @@ class FaissVectorDBStorage(BaseVectorStorage):
         self.cosine_better_than_threshold = cosine_threshold

-        # Where to save index file if you want persistent storage
-        self._faiss_index_file = os.path.join(
-            self.global_config["working_dir"], f"faiss_index_{self.namespace}.index"
-        )
+        working_dir = self.global_config["working_dir"]
+        if self.workspace:
+            # Include workspace in the file path for data isolation
+            workspace_dir = os.path.join(working_dir, self.workspace)
+            os.makedirs(workspace_dir, exist_ok=True)
+            self._faiss_index_file = os.path.join(
+                workspace_dir, f"faiss_index_{self.namespace}.index"
+            )
+        else:
+            # Default behavior when workspace is empty
+            self._faiss_index_file = os.path.join(
+                working_dir, f"faiss_index_{self.namespace}.index"
+            )
         self._meta_file = self._faiss_index_file + ".meta.json"

         self._max_batch_size = self.global_config["embedding_batch_num"]
@@ -30,7 +30,18 @@ class JsonDocStatusStorage(DocStatusStorage):

     def __post_init__(self):
         working_dir = self.global_config["working_dir"]
-        self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
+        if self.workspace:
+            # Include workspace in the file path for data isolation
+            workspace_dir = os.path.join(working_dir, self.workspace)
+            os.makedirs(workspace_dir, exist_ok=True)
+            self._file_name = os.path.join(
+                workspace_dir, f"kv_store_{self.namespace}.json"
+            )
+        else:
+            # Default behavior when workspace is empty
+            self._file_name = os.path.join(
+                working_dir, f"kv_store_{self.namespace}.json"
+            )
         self._data = None
         self._storage_lock = None
         self.storage_updated = None
@@ -26,7 +26,18 @@ from .shared_storage import (
 class JsonKVStorage(BaseKVStorage):
     def __post_init__(self):
         working_dir = self.global_config["working_dir"]
-        self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
+        if self.workspace:
+            # Include workspace in the file path for data isolation
+            workspace_dir = os.path.join(working_dir, self.workspace)
+            os.makedirs(workspace_dir, exist_ok=True)
+            self._file_name = os.path.join(
+                workspace_dir, f"kv_store_{self.namespace}.json"
+            )
+        else:
+            # Default behavior when workspace is empty
+            self._file_name = os.path.join(
+                working_dir, f"kv_store_{self.namespace}.json"
+            )
         self._data = None
         self._storage_lock = None
         self.storage_updated = None
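The same branch now appears in four file-based backends (Faiss, NanoVectorDB, JsonKVStorage, JsonDocStatusStorage). A hypothetical helper capturing the shared rule, if one wanted to factor it out (workspace_path is not in the diff):

import os

def workspace_path(working_dir: str, workspace: str, filename: str) -> str:
    # File-based backends place their data file under
    # <working_dir>/<workspace>/ when a workspace is set,
    # and directly under <working_dir> otherwise.
    if workspace:
        target_dir = os.path.join(working_dir, workspace)
        os.makedirs(target_dir, exist_ok=True)
    else:
        target_dir = working_dir
    return os.path.join(target_dir, filename)

print(workspace_path("./rag_storage", "team_a", "kv_store_full_docs.json"))
# ./rag_storage/team_a/kv_store_full_docs.json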
@@ -7,10 +7,6 @@ from lightrag.utils import logger, compute_mdhash_id
 from ..base import BaseVectorStorage
 import pipmaster as pm

-if not pm.is_installed("configparser"):
-    pm.install("configparser")
-
 if not pm.is_installed("pymilvus"):
     pm.install("pymilvus")

@@ -660,6 +656,29 @@ class MilvusVectorDBStorage(BaseVectorStorage):
             raise

+    def __post_init__(self):
+        # Check for MILVUS_WORKSPACE environment variable first (higher priority)
+        # This allows administrators to force a specific workspace for all Milvus storage instances
+        milvus_workspace = os.environ.get("MILVUS_WORKSPACE")
+        if milvus_workspace and milvus_workspace.strip():
+            # Use environment variable value, overriding the passed workspace parameter
+            effective_workspace = milvus_workspace.strip()
+            logger.info(
+                f"Using MILVUS_WORKSPACE environment variable: '{effective_workspace}' (overriding passed workspace: '{self.workspace}')"
+            )
+        else:
+            # Use the workspace parameter passed during initialization
+            effective_workspace = self.workspace
+            if effective_workspace:
+                logger.debug(
+                    f"Using passed workspace parameter: '{effective_workspace}'"
+                )
+
+        # Build namespace with workspace prefix for data isolation
+        if effective_workspace:
+            self.namespace = f"{effective_workspace}_{self.namespace}"
+            logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'")
+        # When workspace is empty, keep the original namespace unchanged
+
         kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
         cosine_threshold = kwargs.get("cosine_better_than_threshold")
         if cosine_threshold is None:
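This override-then-prefix pattern recurs verbatim for MongoDB, Qdrant, and Redis with their own MONGODB_WORKSPACE, QDRANT_WORKSPACE, and REDIS_WORKSPACE variables. A condensed sketch of the rule (the helper name is illustrative):

import os

def effective_workspace(env_var: str, passed: str) -> str:
    # A backend-specific environment variable, when set, forces the
    # workspace for every storage instance of that backend; otherwise
    # the workspace passed in by LightRAG is used.
    forced = os.environ.get(env_var, "").strip()
    return forced if forced else passed

ws = effective_workspace("MILVUS_WORKSPACE", "team_a")
namespace = "chunks"
print(f"{ws}_{namespace}" if ws else namespace)  # team_a_chunks (unless forced)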
@@ -25,6 +25,7 @@ if not pm.is_installed("pymongo"):
     pm.install("pymongo")

 from pymongo import AsyncMongoClient  # type: ignore
+from pymongo import UpdateOne  # type: ignore
 from pymongo.asynchronous.database import AsyncDatabase  # type: ignore
 from pymongo.asynchronous.collection import AsyncCollection  # type: ignore
 from pymongo.operations import SearchIndexModel  # type: ignore

@@ -81,7 +82,39 @@ class MongoKVStorage(BaseKVStorage):
     db: AsyncDatabase = field(default=None)
     _data: AsyncCollection = field(default=None)

+    def __init__(self, namespace, global_config, embedding_func, workspace=None):
+        super().__init__(
+            namespace=namespace,
+            workspace=workspace or "",
+            global_config=global_config,
+            embedding_func=embedding_func,
+        )
+        self.__post_init__()
+
+    def __post_init__(self):
+        # Check for MONGODB_WORKSPACE environment variable first (higher priority)
+        # This allows administrators to force a specific workspace for all MongoDB storage instances
+        mongodb_workspace = os.environ.get("MONGODB_WORKSPACE")
+        if mongodb_workspace and mongodb_workspace.strip():
+            # Use environment variable value, overriding the passed workspace parameter
+            effective_workspace = mongodb_workspace.strip()
+            logger.info(
+                f"Using MONGODB_WORKSPACE environment variable: '{effective_workspace}' (overriding passed workspace: '{self.workspace}')"
+            )
+        else:
+            # Use the workspace parameter passed during initialization
+            effective_workspace = self.workspace
+            if effective_workspace:
+                logger.debug(
+                    f"Using passed workspace parameter: '{effective_workspace}'"
+                )
+
+        # Build namespace with workspace prefix for data isolation
+        if effective_workspace:
+            self.namespace = f"{effective_workspace}_{self.namespace}"
+            logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'")
+        # When workspace is empty, keep the original namespace unchanged
+
+        self._collection_name = self.namespace
+
     async def initialize(self):

@@ -142,7 +175,6 @@ class MongoKVStorage(BaseKVStorage):

         # Unified handling for all namespaces with flattened keys
         # Use bulk_write for better performance
-        from pymongo import UpdateOne

         operations = []
         current_time = int(time.time())  # Get current Unix timestamp

@@ -252,7 +284,39 @@ class MongoDocStatusStorage(DocStatusStorage):
     db: AsyncDatabase = field(default=None)
     _data: AsyncCollection = field(default=None)

+    def __init__(self, namespace, global_config, embedding_func, workspace=None):
+        super().__init__(
+            namespace=namespace,
+            workspace=workspace or "",
+            global_config=global_config,
+            embedding_func=embedding_func,
+        )
+        self.__post_init__()
+
+    def __post_init__(self):
+        # Check for MONGODB_WORKSPACE environment variable first (higher priority)
+        # This allows administrators to force a specific workspace for all MongoDB storage instances
+        mongodb_workspace = os.environ.get("MONGODB_WORKSPACE")
+        if mongodb_workspace and mongodb_workspace.strip():
+            # Use environment variable value, overriding the passed workspace parameter
+            effective_workspace = mongodb_workspace.strip()
+            logger.info(
+                f"Using MONGODB_WORKSPACE environment variable: '{effective_workspace}' (overriding passed workspace: '{self.workspace}')"
+            )
+        else:
+            # Use the workspace parameter passed during initialization
+            effective_workspace = self.workspace
+            if effective_workspace:
+                logger.debug(
+                    f"Using passed workspace parameter: '{effective_workspace}'"
+                )
+
+        # Build namespace with workspace prefix for data isolation
+        if effective_workspace:
+            self.namespace = f"{effective_workspace}_{self.namespace}"
+            logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'")
+        # When workspace is empty, keep the original namespace unchanged
+
+        self._collection_name = self.namespace
+
     async def initialize(self):

@@ -367,12 +431,36 @@ class MongoGraphStorage(BaseGraphStorage):
     # edge collection storing source_node_id, target_node_id, and edge_properties
     edgeCollection: AsyncCollection = field(default=None)

-    def __init__(self, namespace, global_config, embedding_func):
+    def __init__(self, namespace, global_config, embedding_func, workspace=None):
         super().__init__(
             namespace=namespace,
+            workspace=workspace or "",
             global_config=global_config,
             embedding_func=embedding_func,
         )
+        # Check for MONGODB_WORKSPACE environment variable first (higher priority)
+        # This allows administrators to force a specific workspace for all MongoDB storage instances
+        mongodb_workspace = os.environ.get("MONGODB_WORKSPACE")
+        if mongodb_workspace and mongodb_workspace.strip():
+            # Use environment variable value, overriding the passed workspace parameter
+            effective_workspace = mongodb_workspace.strip()
+            logger.info(
+                f"Using MONGODB_WORKSPACE environment variable: '{effective_workspace}' (overriding passed workspace: '{self.workspace}')"
+            )
+        else:
+            # Use the workspace parameter passed during initialization
+            effective_workspace = self.workspace
+            if effective_workspace:
+                logger.debug(
+                    f"Using passed workspace parameter: '{effective_workspace}'"
+                )
+
+        # Build namespace with workspace prefix for data isolation
+        if effective_workspace:
+            self.namespace = f"{effective_workspace}_{self.namespace}"
+            logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'")
+        # When workspace is empty, keep the original namespace unchanged
+
         self._collection_name = self.namespace
         self._edge_collection_name = f"{self._collection_name}_edges"

@@ -1231,7 +1319,42 @@ class MongoVectorDBStorage(BaseVectorStorage):
     db: AsyncDatabase | None = field(default=None)
     _data: AsyncCollection | None = field(default=None)

+    def __init__(
+        self, namespace, global_config, embedding_func, workspace=None, meta_fields=None
+    ):
+        super().__init__(
+            namespace=namespace,
+            workspace=workspace or "",
+            global_config=global_config,
+            embedding_func=embedding_func,
+            meta_fields=meta_fields or set(),
+        )
+        self.__post_init__()
+
+    def __post_init__(self):
+        # Check for MONGODB_WORKSPACE environment variable first (higher priority)
+        # This allows administrators to force a specific workspace for all MongoDB storage instances
+        mongodb_workspace = os.environ.get("MONGODB_WORKSPACE")
+        if mongodb_workspace and mongodb_workspace.strip():
+            # Use environment variable value, overriding the passed workspace parameter
+            effective_workspace = mongodb_workspace.strip()
+            logger.info(
+                f"Using MONGODB_WORKSPACE environment variable: '{effective_workspace}' (overriding passed workspace: '{self.workspace}')"
+            )
+        else:
+            # Use the workspace parameter passed during initialization
+            effective_workspace = self.workspace
+            if effective_workspace:
+                logger.debug(
+                    f"Using passed workspace parameter: '{effective_workspace}'"
+                )
+
+        # Build namespace with workspace prefix for data isolation
+        if effective_workspace:
+            self.namespace = f"{effective_workspace}_{self.namespace}"
+            logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'")
+        # When workspace is empty, keep the original namespace unchanged
+
         kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
         cosine_threshold = kwargs.get("cosine_better_than_threshold")
         if cosine_threshold is None:
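A note on the constructor pattern used for the Mongo (and Qdrant) classes: the storages are dataclasses, so these hand-written __init__ methods route through the generated dataclass initializer and then invoke __post_init__ explicitly. A minimal runnable sketch of that interplay, with simplified names that are not the real classes:

from dataclasses import dataclass

@dataclass
class StorageBase:
    namespace: str
    workspace: str

class MongoLikeStorage(StorageBase):
    def __init__(self, namespace, workspace=None):
        # The dataclass-generated __init__ of the base sets the fields;
        # __post_init__ is then called by hand to apply workspace prefixing.
        super().__init__(namespace=namespace, workspace=workspace or "")
        self.__post_init__()

    def __post_init__(self):
        if self.workspace:
            self.namespace = f"{self.workspace}_{self.namespace}"

print(MongoLikeStorage("chunks", workspace="team_a").namespace)  # team_a_chunks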
@@ -41,9 +41,19 @@ class NanoVectorDBStorage(BaseVectorStorage):
         )
         self.cosine_better_than_threshold = cosine_threshold

-        self._client_file_name = os.path.join(
-            self.global_config["working_dir"], f"vdb_{self.namespace}.json"
-        )
+        working_dir = self.global_config["working_dir"]
+        if self.workspace:
+            # Include workspace in the file path for data isolation
+            workspace_dir = os.path.join(working_dir, self.workspace)
+            os.makedirs(workspace_dir, exist_ok=True)
+            self._client_file_name = os.path.join(
+                workspace_dir, f"vdb_{self.namespace}.json"
+            )
+        else:
+            # Default behavior when workspace is empty
+            self._client_file_name = os.path.join(
+                working_dir, f"vdb_{self.namespace}.json"
+            )
         self._max_batch_size = self.global_config["embedding_batch_num"]

         self._client = NanoVectorDB(
@@ -50,14 +50,20 @@ logging.getLogger("neo4j").setLevel(logging.ERROR)
 @final
 @dataclass
 class Neo4JStorage(BaseGraphStorage):
-    def __init__(self, namespace, global_config, embedding_func):
+    def __init__(self, namespace, global_config, embedding_func, workspace=None):
         super().__init__(
             namespace=namespace,
+            workspace=workspace or "",
             global_config=global_config,
             embedding_func=embedding_func,
         )
         self._driver = None

+    def _get_workspace_label(self) -> str:
+        """Get workspace label, return 'base' for compatibility when workspace is empty"""
+        workspace = getattr(self, "workspace", None)
+        return workspace if workspace else "base"
+
     async def initialize(self):
         URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
         USERNAME = os.environ.get(

@@ -153,13 +159,14 @@ class Neo4JStorage(BaseGraphStorage):
             raise e

         if connected:
-            # Create index for base nodes on entity_id if it doesn't exist
+            # Create index for workspace nodes on entity_id if it doesn't exist
+            workspace_label = self._get_workspace_label()
             try:
                 async with self._driver.session(database=database) as session:
                     # Check if index exists first
-                    check_query = """
+                    check_query = f"""
                     CALL db.indexes() YIELD name, labelsOrTypes, properties
-                    WHERE labelsOrTypes = ['base'] AND properties = ['entity_id']
+                    WHERE labelsOrTypes = ['{workspace_label}'] AND properties = ['entity_id']
                     RETURN count(*) > 0 AS exists
                     """
                     try:

@@ -172,16 +179,16 @@ class Neo4JStorage(BaseGraphStorage):
                     if not index_exists:
                         # Create index only if it doesn't exist
                         result = await session.run(
-                            "CREATE INDEX FOR (n:base) ON (n.entity_id)"
+                            f"CREATE INDEX FOR (n:`{workspace_label}`) ON (n.entity_id)"
                         )
                         await result.consume()
                         logger.info(
-                            f"Created index for base nodes on entity_id in {database}"
+                            f"Created index for {workspace_label} nodes on entity_id in {database}"
                         )
                 except Exception:
                     # Fallback if db.indexes() is not supported in this Neo4j version
                     result = await session.run(
-                        "CREATE INDEX IF NOT EXISTS FOR (n:base) ON (n.entity_id)"
+                        f"CREATE INDEX IF NOT EXISTS FOR (n:`{workspace_label}`) ON (n.entity_id)"
                     )
                     await result.consume()
         except Exception as e:
@@ -216,11 +223,12 @@ class Neo4JStorage(BaseGraphStorage):
             ValueError: If node_id is invalid
             Exception: If there is an error executing the query
         """
+        workspace_label = self._get_workspace_label()
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
             try:
-                query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists"
+                query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
                 result = await session.run(query, entity_id=node_id)
                 single_result = await result.single()
                 await result.consume()  # Ensure result is fully consumed

@@ -245,12 +253,13 @@ class Neo4JStorage(BaseGraphStorage):
             ValueError: If either node_id is invalid
             Exception: If there is an error executing the query
         """
+        workspace_label = self._get_workspace_label()
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
             try:
                 query = (
-                    "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) "
+                    f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) "
                     "RETURN COUNT(r) > 0 AS edgeExists"
                 )
                 result = await session.run(

@@ -282,11 +291,14 @@ class Neo4JStorage(BaseGraphStorage):
             ValueError: If node_id is invalid
             Exception: If there is an error executing the query
         """
+        workspace_label = self._get_workspace_label()
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
             try:
-                query = "MATCH (n:base {entity_id: $entity_id}) RETURN n"
+                query = (
+                    f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n"
+                )
                 result = await session.run(query, entity_id=node_id)
                 try:
                     records = await result.fetch(

@@ -300,12 +312,12 @@ class Neo4JStorage(BaseGraphStorage):
                 if records:
                     node = records[0]["n"]
                     node_dict = dict(node)
-                    # Remove base label from labels list if it exists
+                    # Remove workspace label from labels list if it exists
                     if "labels" in node_dict:
                         node_dict["labels"] = [
                             label
                             for label in node_dict["labels"]
-                            if label != "base"
+                            if label != workspace_label
                         ]
                     # logger.debug(f"Neo4j query node {query} return: {node_dict}")
                     return node_dict

@@ -326,12 +338,13 @@ class Neo4JStorage(BaseGraphStorage):
         Returns:
             A dictionary mapping each node_id to its node data (or None if not found).
         """
+        workspace_label = self._get_workspace_label()
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
-            query = """
+            query = f"""
                 UNWIND $node_ids AS id
-                MATCH (n:base {entity_id: id})
+                MATCH (n:`{workspace_label}` {{entity_id: id}})
                 RETURN n.entity_id AS entity_id, n
             """
             result = await session.run(query, node_ids=node_ids)

@@ -340,10 +353,12 @@ class Neo4JStorage(BaseGraphStorage):
                 entity_id = record["entity_id"]
                 node = record["n"]
                 node_dict = dict(node)
-                # Remove the 'base' label if present in a 'labels' property
+                # Remove the workspace label if present in a 'labels' property
                 if "labels" in node_dict:
                     node_dict["labels"] = [
-                        label for label in node_dict["labels"] if label != "base"
+                        label
+                        for label in node_dict["labels"]
+                        if label != workspace_label
                     ]
                 nodes[entity_id] = node_dict
             await result.consume()  # Make sure to consume the result fully
@@ -364,12 +379,13 @@ class Neo4JStorage(BaseGraphStorage):
             ValueError: If node_id is invalid
             Exception: If there is an error executing the query
         """
+        workspace_label = self._get_workspace_label()
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
             try:
-                query = """
-                    MATCH (n:base {entity_id: $entity_id})
+                query = f"""
+                    MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
                     OPTIONAL MATCH (n)-[r]-()
                     RETURN COUNT(r) AS degree
                 """

@@ -403,13 +419,14 @@ class Neo4JStorage(BaseGraphStorage):
             A dictionary mapping each node_id to its degree (number of relationships).
             If a node is not found, its degree will be set to 0.
         """
+        workspace_label = self._get_workspace_label()
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
-            query = """
+            query = f"""
                 UNWIND $node_ids AS id
-                MATCH (n:base {entity_id: id})
-                RETURN n.entity_id AS entity_id, count { (n)--() } AS degree;
+                MATCH (n:`{workspace_label}` {{entity_id: id}})
+                RETURN n.entity_id AS entity_id, count {{ (n)--() }} AS degree;
             """
             result = await session.run(query, node_ids=node_ids)
             degrees = {}

@@ -489,12 +506,13 @@ class Neo4JStorage(BaseGraphStorage):
             ValueError: If either node_id is invalid
             Exception: If there is an error executing the query
         """
+        workspace_label = self._get_workspace_label()
         try:
             async with self._driver.session(
                 database=self._DATABASE, default_access_mode="READ"
             ) as session:
-                query = """
-                    MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id})
+                query = f"""
+                    MATCH (start:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(end:`{workspace_label}` {{entity_id: $target_entity_id}})
                     RETURN properties(r) as edge_properties
                 """
                 result = await session.run(

@@ -571,12 +589,13 @@ class Neo4JStorage(BaseGraphStorage):
         Returns:
             A dictionary mapping (src, tgt) tuples to their edge properties.
         """
+        workspace_label = self._get_workspace_label()
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
-            query = """
+            query = f"""
                 UNWIND $pairs AS pair
-                MATCH (start:base {entity_id: pair.src})-[r:DIRECTED]-(end:base {entity_id: pair.tgt})
+                MATCH (start:`{workspace_label}` {{entity_id: pair.src}})-[r:DIRECTED]-(end:`{workspace_label}` {{entity_id: pair.tgt}})
                 RETURN pair.src AS src_id, pair.tgt AS tgt_id, collect(properties(r)) AS edges
             """
             result = await session.run(query, pairs=pairs)
@@ -627,8 +646,9 @@ class Neo4JStorage(BaseGraphStorage):
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
             try:
-                query = """MATCH (n:base {entity_id: $entity_id})
-                    OPTIONAL MATCH (n)-[r]-(connected:base)
+                workspace_label = self._get_workspace_label()
+                query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
+                    OPTIONAL MATCH (n)-[r]-(connected:`{workspace_label}`)
                     WHERE connected.entity_id IS NOT NULL
                     RETURN n, r, connected"""
                 results = await session.run(query, entity_id=source_node_id)

@@ -689,10 +709,11 @@ class Neo4JStorage(BaseGraphStorage):
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
             # Query to get both outgoing and incoming edges
-            query = """
+            workspace_label = self._get_workspace_label()
+            query = f"""
                 UNWIND $node_ids AS id
-                MATCH (n:base {entity_id: id})
-                OPTIONAL MATCH (n)-[r]-(connected:base)
+                MATCH (n:`{workspace_label}` {{entity_id: id}})
+                OPTIONAL MATCH (n)-[r]-(connected:`{workspace_label}`)
                 RETURN id AS queried_id, n.entity_id AS node_entity_id,
                        connected.entity_id AS connected_entity_id,
                        startNode(r).entity_id AS start_entity_id

@@ -727,12 +748,13 @@ class Neo4JStorage(BaseGraphStorage):
         return edges_dict

     async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
+        workspace_label = self._get_workspace_label()
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
-            query = """
+            query = f"""
                 UNWIND $chunk_ids AS chunk_id
-                MATCH (n:base)
+                MATCH (n:`{workspace_label}`)
                 WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
                 RETURN DISTINCT n
             """

@@ -748,12 +770,13 @@ class Neo4JStorage(BaseGraphStorage):
         return nodes

     async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
+        workspace_label = self._get_workspace_label()
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
-            query = """
+            query = f"""
                 UNWIND $chunk_ids AS chunk_id
-                MATCH (a:base)-[r]-(b:base)
+                MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`)
                 WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
                 RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
             """
@@ -787,6 +810,7 @@ class Neo4JStorage(BaseGraphStorage):
             node_id: The unique identifier for the node (used as label)
             node_data: Dictionary of node properties
         """
+        workspace_label = self._get_workspace_label()
         properties = node_data
         entity_type = properties["entity_type"]
         if "entity_id" not in properties:

@@ -796,14 +820,11 @@ class Neo4JStorage(BaseGraphStorage):
         async with self._driver.session(database=self._DATABASE) as session:

             async def execute_upsert(tx: AsyncManagedTransaction):
-                query = (
-                    """
-                    MERGE (n:base {entity_id: $entity_id})
+                query = f"""
+                    MERGE (n:`{workspace_label}` {{entity_id: $entity_id}})
                     SET n += $properties
-                    SET n:`%s`
+                    SET n:`{entity_type}`
                     """
-                    % entity_type
-                )
                 result = await tx.run(
                     query, entity_id=node_id, properties=properties
                 )

@@ -847,10 +868,11 @@ class Neo4JStorage(BaseGraphStorage):
         async with self._driver.session(database=self._DATABASE) as session:

             async def execute_upsert(tx: AsyncManagedTransaction):
-                query = """
-                    MATCH (source:base {entity_id: $source_entity_id})
+                workspace_label = self._get_workspace_label()
+                query = f"""
+                    MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})
                     WITH source
-                    MATCH (target:base {entity_id: $target_entity_id})
+                    MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}})
                     MERGE (source)-[r:DIRECTED]-(target)
                     SET r += $properties
                     RETURN r, source, target
@@ -889,6 +911,7 @@ class Neo4JStorage(BaseGraphStorage):
             KnowledgeGraph object containing nodes and edges, with an is_truncated flag
             indicating whether the graph was truncated due to max_nodes limit
         """
+        workspace_label = self._get_workspace_label()
         result = KnowledgeGraph()
         seen_nodes = set()
         seen_edges = set()

@@ -899,7 +922,9 @@ class Neo4JStorage(BaseGraphStorage):
             try:
                 if node_label == "*":
                     # First check total node count to determine if graph is truncated
-                    count_query = "MATCH (n) RETURN count(n) as total"
+                    count_query = (
+                        f"MATCH (n:`{workspace_label}`) RETURN count(n) as total"
+                    )
                     count_result = None
                     try:
                         count_result = await session.run(count_query)

@@ -915,13 +940,13 @@ class Neo4JStorage(BaseGraphStorage):
                         await count_result.consume()

                     # Run main query to get nodes with highest degree
-                    main_query = """
-                        MATCH (n)
+                    main_query = f"""
+                        MATCH (n:`{workspace_label}`)
                         OPTIONAL MATCH (n)-[r]-()
                         WITH n, COALESCE(count(r), 0) AS degree
                         ORDER BY degree DESC
                         LIMIT $max_nodes
-                        WITH collect({node: n}) AS filtered_nodes
+                        WITH collect({{node: n}}) AS filtered_nodes
                         UNWIND filtered_nodes AS node_info
                         WITH collect(node_info.node) AS kept_nodes, filtered_nodes
                         OPTIONAL MATCH (a)-[r]-(b)

@@ -943,20 +968,21 @@ class Neo4JStorage(BaseGraphStorage):
                 else:
                     # return await self._robust_fallback(node_label, max_depth, max_nodes)
                     # First try without limit to check if we need to truncate
-                    full_query = """
-                        MATCH (start)
+                    full_query = f"""
+                        MATCH (start:`{workspace_label}`)
                         WHERE start.entity_id = $entity_id
                         WITH start
-                        CALL apoc.path.subgraphAll(start, {
+                        CALL apoc.path.subgraphAll(start, {{
                             relationshipFilter: '',
+                            labelFilter: '{workspace_label}',
                             minLevel: 0,
                             maxLevel: $max_depth,
                             bfs: true
-                        })
+                        }})
                         YIELD nodes, relationships
                         WITH nodes, relationships, size(nodes) AS total_nodes
                         UNWIND nodes AS node
-                        WITH collect({node: node}) AS node_info, relationships, total_nodes
+                        WITH collect({{node: node}}) AS node_info, relationships, total_nodes
                         RETURN node_info, relationships, total_nodes
                     """

@@ -994,20 +1020,21 @@ class Neo4JStorage(BaseGraphStorage):
                     )

                     # Run limited query
-                    limited_query = """
-                        MATCH (start)
+                    limited_query = f"""
+                        MATCH (start:`{workspace_label}`)
                         WHERE start.entity_id = $entity_id
                         WITH start
-                        CALL apoc.path.subgraphAll(start, {
+                        CALL apoc.path.subgraphAll(start, {{
                             relationshipFilter: '',
+                            labelFilter: '{workspace_label}',
                             minLevel: 0,
                             maxLevel: $max_depth,
                             limit: $max_nodes,
                             bfs: true
-                        })
+                        }})
                         YIELD nodes, relationships
                         UNWIND nodes AS node
-                        WITH collect({node: node}) AS node_info, relationships
+                        WITH collect({{node: node}}) AS node_info, relationships
                         RETURN node_info, relationships
                     """
                     result_set = None
@@ -1094,11 +1121,12 @@ class Neo4JStorage(BaseGraphStorage):
         visited_edge_pairs = set()

         # Get the starting node's data
+        workspace_label = self._get_workspace_label()
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
-            query = """
-                MATCH (n:base {entity_id: $entity_id})
+            query = f"""
+                MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
                 RETURN id(n) as node_id, n
             """
             node_result = await session.run(query, entity_id=node_label)

@@ -1156,8 +1184,9 @@ class Neo4JStorage(BaseGraphStorage):
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
-            query = """
-                MATCH (a:base {entity_id: $entity_id})-[r]-(b)
+            workspace_label = self._get_workspace_label()
+            query = f"""
+                MATCH (a:`{workspace_label}` {{entity_id: $entity_id}})-[r]-(b)
                 WITH r, b, id(r) as edge_id, id(b) as target_id
                 RETURN r, b, edge_id, target_id
             """

@@ -1241,6 +1270,7 @@ class Neo4JStorage(BaseGraphStorage):
         Returns:
             ["Person", "Company", ...]  # Alphabetically sorted label list
         """
+        workspace_label = self._get_workspace_label()
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:

@@ -1248,8 +1278,8 @@ class Neo4JStorage(BaseGraphStorage):
             # query = "CALL db.labels() YIELD label RETURN label"

             # Method 2: Query compatible with older versions
-            query = """
-                MATCH (n:base)
+            query = f"""
+                MATCH (n:`{workspace_label}`)
                 WHERE n.entity_id IS NOT NULL
                 RETURN DISTINCT n.entity_id AS label
                 ORDER BY label
@@ -1285,8 +1315,9 @@ class Neo4JStorage(BaseGraphStorage):
         """

         async def _do_delete(tx: AsyncManagedTransaction):
-            query = """
-                MATCH (n:base {entity_id: $entity_id})
+            workspace_label = self._get_workspace_label()
+            query = f"""
+                MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
                 DETACH DELETE n
             """
             result = await tx.run(query, entity_id=node_id)

@@ -1342,8 +1373,9 @@ class Neo4JStorage(BaseGraphStorage):
         for source, target in edges:

             async def _do_delete_edge(tx: AsyncManagedTransaction):
-                query = """
-                    MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id})
+                workspace_label = self._get_workspace_label()
+                query = f"""
+                    MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(target:`{workspace_label}` {{entity_id: $target_entity_id}})
                     DELETE r
                 """
                 result = await tx.run(

@@ -1360,26 +1392,32 @@ class Neo4JStorage(BaseGraphStorage):
             raise

     async def drop(self) -> dict[str, str]:
-        """Drop all data from storage and clean up resources
+        """Drop all data from current workspace storage and clean up resources

-        This method will delete all nodes and relationships in the Neo4j database.
+        This method will delete all nodes and relationships in the current workspace only.

         Returns:
             dict[str, str]: Operation status and message
-            - On success: {"status": "success", "message": "data dropped"}
+            - On success: {"status": "success", "message": "workspace data dropped"}
             - On failure: {"status": "error", "message": "<error details>"}
         """
+        workspace_label = self._get_workspace_label()
         try:
             async with self._driver.session(database=self._DATABASE) as session:
-                # Delete all nodes and relationships
-                query = "MATCH (n) DETACH DELETE n"
+                # Delete all nodes and relationships in current workspace only
+                query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
                 result = await session.run(query)
                 await result.consume()  # Ensure result is fully consumed

                 logger.info(
-                    f"Process {os.getpid()} drop Neo4j database {self._DATABASE}"
+                    f"Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}"
                 )
-                return {"status": "success", "message": "data dropped"}
+                return {
+                    "status": "success",
+                    "message": f"workspace '{workspace_label}' data dropped",
+                }
         except Exception as e:
-            logger.error(f"Error dropping Neo4j database {self._DATABASE}: {e}")
+            logger.error(
+                f"Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}"
+            )
             return {"status": "error", "message": str(e)}
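The Neo4j strategy differs from the path/prefix approaches above: isolation is by node label, with 'base' kept as the label when no workspace is set so existing graphs keep working, and drop() now deletes only the current workspace's label. A small runnable sketch of the query construction (illustrative only; the real queries are parameterized as shown in the hunks):

def node_query(workspace: str) -> str:
    # Empty workspace falls back to the legacy 'base' label.
    label = workspace if workspace else "base"
    return f"MATCH (n:`{label}` {{entity_id: $entity_id}}) RETURN n"

print(node_query(""))        # MATCH (n:`base` {entity_id: $entity_id}) RETURN n
print(node_query("team_a"))  # MATCH (n:`team_a` {entity_id: $entity_id}) RETURN n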
@@ -46,9 +46,19 @@ class NetworkXStorage(BaseGraphStorage):
         nx.write_graphml(graph, file_name)

     def __post_init__(self):
-        self._graphml_xml_file = os.path.join(
-            self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
-        )
+        working_dir = self.global_config["working_dir"]
+        if self.workspace:
+            # Include workspace in the file path for data isolation
+            workspace_dir = os.path.join(working_dir, self.workspace)
+            os.makedirs(workspace_dir, exist_ok=True)
+            self._graphml_xml_file = os.path.join(
+                workspace_dir, f"graph_{self.namespace}.graphml"
+            )
+        else:
+            # Default behavior when workspace is empty
+            self._graphml_xml_file = os.path.join(
+                working_dir, f"graph_{self.namespace}.graphml"
+            )
         self._storage_lock = None
         self.storage_updated = None
         self._graph = None
@@ -1,6 +1,7 @@
 import asyncio
 import json
 import os
+import re
 import datetime
 from datetime import timezone
 from dataclasses import dataclass, field

@@ -319,7 +320,7 @@ class PostgreSQLDB:
         # Get all old format data
         old_data_sql = """
             SELECT id, mode, original_prompt, return_value, chunk_id,
-                   create_time, update_time
+                   workspace, create_time, update_time
             FROM LIGHTRAG_LLM_CACHE
             WHERE id NOT LIKE '%:%'
         """

@@ -364,7 +365,9 @@ class PostgreSQLDB:
                 await self.execute(
                     insert_sql,
                     {
-                        "workspace": self.workspace,
+                        "workspace": record[
+                            "workspace"
+                        ],  # Use original record's workspace
                         "id": new_key,
                         "mode": record["mode"],
                         "original_prompt": record["original_prompt"],

@@ -384,7 +387,9 @@ class PostgreSQLDB:
                 await self.execute(
                     delete_sql,
                     {
-                        "workspace": self.workspace,
+                        "workspace": record[
+                            "workspace"
+                        ],  # Use original record's workspace
                         "mode": record["mode"],
                         "id": record["id"],  # Old id
                     },

@@ -670,7 +675,7 @@ class ClientManager:
             ),
             "workspace": os.environ.get(
                 "POSTGRES_WORKSPACE",
-                config.get("postgres", "workspace", fallback="default"),
+                config.get("postgres", "workspace", fallback=None),
             ),
             "max_connections": os.environ.get(
                 "POSTGRES_MAX_CONNECTIONS",
@@ -716,6 +721,18 @@ class PGKVStorage(BaseKVStorage):
     async def initialize(self):
         if self.db is None:
             self.db = await ClientManager.get_client()
+        # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
+        if self.db.workspace:
+            # Use PostgreSQLDB's workspace (highest priority)
+            final_workspace = self.db.workspace
+        elif hasattr(self, "workspace") and self.workspace:
+            # Use storage class's workspace (medium priority)
+            final_workspace = self.workspace
+            self.db.workspace = final_workspace
+        else:
+            # Use "default" for compatibility (lowest priority)
+            final_workspace = "default"
+            self.db.workspace = final_workspace

     async def finalize(self):
         if self.db is not None:

@@ -1047,6 +1064,18 @@ class PGVectorStorage(BaseVectorStorage):
     async def initialize(self):
         if self.db is None:
             self.db = await ClientManager.get_client()
+        # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
+        if self.db.workspace:
+            # Use PostgreSQLDB's workspace (highest priority)
+            final_workspace = self.db.workspace
+        elif hasattr(self, "workspace") and self.workspace:
+            # Use storage class's workspace (medium priority)
+            final_workspace = self.workspace
+            self.db.workspace = final_workspace
+        else:
+            # Use "default" for compatibility (lowest priority)
+            final_workspace = "default"
+            self.db.workspace = final_workspace

     async def finalize(self):
         if self.db is not None:

@@ -1328,6 +1357,18 @@ class PGDocStatusStorage(DocStatusStorage):
     async def initialize(self):
         if self.db is None:
             self.db = await ClientManager.get_client()
+        # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
+        if self.db.workspace:
+            # Use PostgreSQLDB's workspace (highest priority)
+            final_workspace = self.db.workspace
+        elif hasattr(self, "workspace") and self.workspace:
+            # Use storage class's workspace (medium priority)
+            final_workspace = self.workspace
+            self.db.workspace = final_workspace
+        else:
+            # Use "default" for compatibility (lowest priority)
+            final_workspace = "default"
+            self.db.workspace = final_workspace

     async def finalize(self):
         if self.db is not None:
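The same twelve lines repeat in PGKVStorage, PGVectorStorage, and PGDocStatusStorage. Condensed into one function for clarity (resolve_pg_workspace is illustrative, not in the diff):

def resolve_pg_workspace(db_workspace: str | None, storage_workspace: str) -> str:
    # POSTGRES_WORKSPACE (already set on the shared client) wins, then the
    # workspace handed to the storage, then "default" for backward
    # compatibility with existing PostgreSQL tables.
    if db_workspace:
        return db_workspace
    if storage_workspace:
        return storage_workspace
    return "default"

print(resolve_pg_workspace(None, "team_a"))      # team_a
print(resolve_pg_workspace(None, ""))            # default
print(resolve_pg_workspace("forced", "team_a"))  # forced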
@@ -1606,9 +1647,34 @@ class PGGraphQueryException(Exception):

 @dataclass
 class PGGraphStorage(BaseGraphStorage):
     def __post_init__(self):
-        self.graph_name = self.namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag")
+        # Graph name will be dynamically generated in initialize() based on workspace
         self.db: PostgreSQLDB | None = None

+    def _get_workspace_graph_name(self) -> str:
+        """
+        Generate graph name based on workspace and namespace for data isolation.
+        Rules:
+        - If workspace is empty: graph_name = namespace
+        - If workspace has value: graph_name = workspace_namespace
+
+        Returns:
+            str: The graph name for the current workspace
+        """
+        workspace = getattr(self, "workspace", None)
+        namespace = self.namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag")
+
+        if workspace and workspace.strip():
+            # Ensure names comply with PostgreSQL identifier specifications
+            safe_workspace = re.sub(r"[^a-zA-Z0-9_]", "_", workspace.strip())
+            safe_namespace = re.sub(r"[^a-zA-Z0-9_]", "_", namespace)
+            return f"{safe_workspace}_{safe_namespace}"
+        else:
+            # When workspace is empty, use namespace directly
+            return re.sub(r"[^a-zA-Z0-9_]", "_", namespace)
+
     @staticmethod
     def _normalize_node_id(node_id: str) -> str:
         """

@@ -1629,6 +1695,27 @@ class PGGraphStorage(BaseGraphStorage):
     async def initialize(self):
         if self.db is None:
             self.db = await ClientManager.get_client()
+        # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > None
+        if self.db.workspace:
+            # Use PostgreSQLDB's workspace (highest priority)
+            final_workspace = self.db.workspace
+        elif hasattr(self, "workspace") and self.workspace:
+            # Use storage class's workspace (medium priority)
+            final_workspace = self.workspace
+            self.db.workspace = final_workspace
+        else:
+            # Use None for compatibility (lowest priority)
+            final_workspace = None
+            self.db.workspace = final_workspace
+
+        # Dynamically generate graph name based on workspace
+        self.workspace = self.db.workspace
+        self.graph_name = self._get_workspace_graph_name()
+
+        # Log the graph initialization for debugging
+        logger.info(
+            f"PostgreSQL Graph initialized: workspace='{self.workspace}', graph_name='{self.graph_name}'"
+        )

         # Execute each statement separately and ignore errors
         queries = [

@@ -2833,7 +2920,10 @@ class PGGraphStorage(BaseGraphStorage):
             $$) AS (result agtype)"""

             await self._query(drop_query, readonly=False)
-            return {"status": "success", "message": "graph data dropped"}
+            return {
+                "status": "success",
+                "message": f"workspace '{self.workspace}' graph data dropped",
+            }
         except Exception as e:
             logger.error(f"Error dropping graph: {e}")
             return {"status": "error", "message": str(e)}
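Because Apache AGE graph names live in PostgreSQL's identifier namespace, the workspace and namespace are sanitized before being joined. A runnable sketch of the rule _get_workspace_graph_name applies:

import re

def graph_name(workspace: str, namespace: str) -> str:
    # Replace anything outside [a-zA-Z0-9_] so the result is a valid
    # PostgreSQL identifier; join workspace and namespace with "_".
    safe_ns = re.sub(r"[^a-zA-Z0-9_]", "_", namespace)
    if workspace and workspace.strip():
        safe_ws = re.sub(r"[^a-zA-Z0-9_]", "_", workspace.strip())
        return f"{safe_ws}_{safe_ns}"
    return safe_ns

print(graph_name("team-a", "chunk_entity_relation"))
# team_a_chunk_entity_relation
print(graph_name("", "chunk_entity_relation"))
# chunk_entity_relation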
@@ -50,6 +50,18 @@ def compute_mdhash_id_for_qdrant(
 @final
 @dataclass
 class QdrantVectorDBStorage(BaseVectorStorage):
+    def __init__(
+        self, namespace, global_config, embedding_func, workspace=None, meta_fields=None
+    ):
+        super().__init__(
+            namespace=namespace,
+            workspace=workspace or "",
+            global_config=global_config,
+            embedding_func=embedding_func,
+            meta_fields=meta_fields or set(),
+        )
+        self.__post_init__()
+
     @staticmethod
     def create_collection_if_not_exist(
         client: QdrantClient, collection_name: str, **kwargs

@@ -59,6 +71,29 @@ class QdrantVectorDBStorage(BaseVectorStorage):
             client.create_collection(collection_name, **kwargs)

     def __post_init__(self):
+        # Check for QDRANT_WORKSPACE environment variable first (higher priority)
+        # This allows administrators to force a specific workspace for all Qdrant storage instances
+        qdrant_workspace = os.environ.get("QDRANT_WORKSPACE")
+        if qdrant_workspace and qdrant_workspace.strip():
+            # Use environment variable value, overriding the passed workspace parameter
+            effective_workspace = qdrant_workspace.strip()
+            logger.info(
+                f"Using QDRANT_WORKSPACE environment variable: '{effective_workspace}' (overriding passed workspace: '{self.workspace}')"
+            )
+        else:
+            # Use the workspace parameter passed during initialization
+            effective_workspace = self.workspace
+            if effective_workspace:
+                logger.debug(
+                    f"Using passed workspace parameter: '{effective_workspace}'"
+                )
+
+        # Build namespace with workspace prefix for data isolation
+        if effective_workspace:
+            self.namespace = f"{effective_workspace}_{self.namespace}"
+            logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'")
+        # When workspace is empty, keep the original namespace unchanged
+
         kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
         cosine_threshold = kwargs.get("cosine_better_than_threshold")
         if cosine_threshold is None:
@@ -71,6 +71,29 @@ class RedisConnectionManager:
 @dataclass
 class RedisKVStorage(BaseKVStorage):
     def __post_init__(self):
+        # Check for REDIS_WORKSPACE environment variable first (higher priority)
+        # This allows administrators to force a specific workspace for all Redis storage instances
+        redis_workspace = os.environ.get("REDIS_WORKSPACE")
+        if redis_workspace and redis_workspace.strip():
+            # Use environment variable value, overriding the passed workspace parameter
+            effective_workspace = redis_workspace.strip()
+            logger.info(
+                f"Using REDIS_WORKSPACE environment variable: '{effective_workspace}' (overriding passed workspace: '{self.workspace}')"
+            )
+        else:
+            # Use the workspace parameter passed during initialization
+            effective_workspace = self.workspace
+            if effective_workspace:
+                logger.debug(
+                    f"Using passed workspace parameter: '{effective_workspace}'"
+                )
+
+        # Build namespace with workspace prefix for data isolation
+        if effective_workspace:
+            self.namespace = f"{effective_workspace}_{self.namespace}"
+            logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'")
+        # When workspace is empty, keep the original namespace unchanged
+
         redis_url = os.environ.get(
             "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
         )

@@ -461,6 +484,29 @@ class RedisDocStatusStorage(DocStatusStorage):
     """Redis implementation of document status storage"""

     def __post_init__(self):
+        # Check for REDIS_WORKSPACE environment variable first (higher priority)
+        # This allows administrators to force a specific workspace for all Redis storage instances
+        redis_workspace = os.environ.get("REDIS_WORKSPACE")
+        if redis_workspace and redis_workspace.strip():
+            # Use environment variable value, overriding the passed workspace parameter
+            effective_workspace = redis_workspace.strip()
+            logger.info(
+                f"Using REDIS_WORKSPACE environment variable: '{effective_workspace}' (overriding passed workspace: '{self.workspace}')"
+            )
+        else:
+            # Use the workspace parameter passed during initialization
+            effective_workspace = self.workspace
+            if effective_workspace:
+                logger.debug(
+                    f"Using passed workspace parameter: '{effective_workspace}'"
+                )
+
+        # Build namespace with workspace prefix for data isolation
+        if effective_workspace:
+            self.namespace = f"{effective_workspace}_{self.namespace}"
+            logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'")
+        # When workspace is empty, keep the original namespace unchanged
+
         redis_url = os.environ.get(
             "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
         )
@@ -97,9 +97,7 @@ class LightRAG:
     # Directory
     # ---

-    working_dir: str = field(
-        default=f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
-    )
+    working_dir: str = field(default="./rag_storage")
     """Directory where cache and temporary files are stored."""

     # Storage
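A side effect worth noting (my reading of the removed default, not stated in the diff): the old expression minted a fresh, timestamped cache directory on every process start, so restarts never reused state.

from datetime import datetime

# What the removed default evaluated to at import time:
print(f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}")
# e.g. ./lightrag_cache_2024-01-01-12:00:00
# The fixed "./rag_storage" default lets runs share state, with per-instance
# isolation now delegated to the new `workspace` field instead.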
@@ -117,6 +115,12 @@ class LightRAG:
     doc_status_storage: str = field(default="JsonDocStatusStorage")
     """Storage type for tracking document processing statuses."""

+    # Workspace
+    # ---
+
+    workspace: str = field(default_factory=lambda: os.getenv("WORKSPACE", ""))
+    """Workspace for data isolation. Defaults to an empty string if the WORKSPACE environment variable is not set."""
+
     # Logging (Deprecated, use setup_logger in utils.py instead)
     # ---
     log_level: int | None = field(default=None)

@@ -242,7 +246,6 @@ class LightRAG:
     vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict)
     """Additional parameters for vector database storage."""

-
     enable_llm_cache: bool = field(default=True)
     """Enables caching for LLM responses to avoid redundant computations."""

@@ -380,39 +383,44 @@ class LightRAG:

         self.llm_response_cache: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
             namespace=NameSpace.KV_STORE_LLM_RESPONSE_CACHE,
-            global_config=asdict(
-                self
-            ),  # Add global_config to ensure cache works properly
+            workspace=self.workspace,
+            global_config=global_config,
             embedding_func=self.embedding_func,
         )

         self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
             namespace=NameSpace.KV_STORE_FULL_DOCS,
+            workspace=self.workspace,
             embedding_func=self.embedding_func,
         )

         self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
             namespace=NameSpace.KV_STORE_TEXT_CHUNKS,
+            workspace=self.workspace,
             embedding_func=self.embedding_func,
         )

         self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls(  # type: ignore
             namespace=NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION,
+            workspace=self.workspace,
             embedding_func=self.embedding_func,
         )

         self.entities_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
             namespace=NameSpace.VECTOR_STORE_ENTITIES,
+            workspace=self.workspace,
             embedding_func=self.embedding_func,
             meta_fields={"entity_name", "source_id", "content", "file_path"},
         )
         self.relationships_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
             namespace=NameSpace.VECTOR_STORE_RELATIONSHIPS,
+            workspace=self.workspace,
             embedding_func=self.embedding_func,
             meta_fields={"src_id", "tgt_id", "source_id", "content", "file_path"},
         )
         self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
             namespace=NameSpace.VECTOR_STORE_CHUNKS,
+            workspace=self.workspace,
             embedding_func=self.embedding_func,
             meta_fields={"full_doc_id", "content", "file_path"},
         )

@@ -420,6 +428,7 @@ class LightRAG:
         # Initialize document status storage
         self.doc_status: DocStatusStorage = self.doc_status_storage_cls(
             namespace=NameSpace.DOC_STATUS,
+            workspace=self.workspace,
             global_config=global_config,
             embedding_func=None,
         )