Compare commits
1 commit
main
...
feature/co
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
968ad45ca2 |
8 changed files with 141 additions and 98 deletions
|
|
@ -24,13 +24,13 @@ class IndexSchema(DataPoint):
|
|||
class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
||||
def __init__(
|
||||
self,
|
||||
database_url: str,
|
||||
database_port: int,
|
||||
url: str,
|
||||
port: int,
|
||||
embedding_engine=EmbeddingEngine,
|
||||
):
|
||||
self.driver = FalkorDB(
|
||||
host=database_url,
|
||||
port=database_port,
|
||||
host=url,
|
||||
port=port,
|
||||
)
|
||||
self.embedding_engine = embedding_engine
|
||||
self.graph_name = "cognee_graph"
|
||||
|
|
|
|||
|
|
@ -3,3 +3,4 @@ from .models.CollectionConfig import CollectionConfig
|
|||
from .vector_db_interface import VectorDBInterface
|
||||
from .config import get_vectordb_config
|
||||
from .get_vector_engine import get_vector_engine
|
||||
from .use_vector_adapter import use_vector_adapter
|
||||
|
|
|
|||
|
|
@ -0,0 +1,108 @@
|
|||
from .embeddings import get_embedding_engine
|
||||
|
||||
from functools import lru_cache
|
||||
|
||||
|
||||
@lru_cache
|
||||
def create_vector_engine(
|
||||
vector_db_url: str,
|
||||
vector_db_port: str,
|
||||
vector_db_key: str,
|
||||
vector_db_provider: str,
|
||||
):
|
||||
embedding_engine = get_embedding_engine()
|
||||
|
||||
if vector_db_provider == "weaviate":
|
||||
from .weaviate_db import WeaviateAdapter
|
||||
|
||||
if not (vector_db_url and vector_db_key):
|
||||
raise EnvironmentError("Missing requred Weaviate credentials!")
|
||||
|
||||
return WeaviateAdapter(vector_db_url, vector_db_key, embedding_engine=embedding_engine)
|
||||
|
||||
elif vector_db_provider == "qdrant":
|
||||
if not (vector_db_url and vector_db_key):
|
||||
raise EnvironmentError("Missing requred Qdrant credentials!")
|
||||
|
||||
from .qdrant.QDrantAdapter import QDrantAdapter
|
||||
|
||||
return QDrantAdapter(
|
||||
url=vector_db_url,
|
||||
api_key=vector_db_key,
|
||||
embedding_engine=embedding_engine,
|
||||
)
|
||||
|
||||
elif vector_db_provider == "milvus":
|
||||
from .milvus.MilvusAdapter import MilvusAdapter
|
||||
|
||||
if not vector_db_url:
|
||||
raise EnvironmentError("Missing required Milvus credentials!")
|
||||
|
||||
return MilvusAdapter(
|
||||
url=vector_db_url,
|
||||
api_key=vector_db_key,
|
||||
embedding_engine=embedding_engine,
|
||||
)
|
||||
|
||||
elif vector_db_provider == "pgvector":
|
||||
from cognee.infrastructure.databases.relational import get_relational_config
|
||||
|
||||
# Get configuration for postgres database
|
||||
relational_config = get_relational_config()
|
||||
db_username = relational_config.db_username
|
||||
db_password = relational_config.db_password
|
||||
db_host = relational_config.db_host
|
||||
db_port = relational_config.db_port
|
||||
db_name = relational_config.db_name
|
||||
|
||||
if not (db_host and db_port and db_name and db_username and db_password):
|
||||
raise EnvironmentError("Missing requred pgvector credentials!")
|
||||
|
||||
connection_string: str = (
|
||||
f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
|
||||
)
|
||||
|
||||
from .pgvector.PGVectorAdapter import PGVectorAdapter
|
||||
|
||||
return PGVectorAdapter(
|
||||
connection_string,
|
||||
vector_db_key,
|
||||
embedding_engine,
|
||||
)
|
||||
|
||||
elif vector_db_provider == "falkordb":
|
||||
if not (vector_db_url and vector_db_port):
|
||||
raise EnvironmentError("Missing requred FalkorDB credentials!")
|
||||
|
||||
from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
|
||||
|
||||
return FalkorDBAdapter(
|
||||
url=vector_db_url,
|
||||
port=vector_db_port,
|
||||
embedding_engine=embedding_engine,
|
||||
)
|
||||
|
||||
elif vector_db_provider == "chromadb":
|
||||
try:
|
||||
import chromadb
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"ChromaDB is not installed. Please install it with 'pip install chromadb'"
|
||||
)
|
||||
|
||||
from .chromadb.ChromaDBAdapter import ChromaDBAdapter
|
||||
|
||||
return ChromaDBAdapter(
|
||||
url=vector_db_url,
|
||||
api_key=vector_db_key,
|
||||
embedding_engine=embedding_engine,
|
||||
)
|
||||
|
||||
else:
|
||||
from .lancedb.LanceDBAdapter import LanceDBAdapter
|
||||
|
||||
return LanceDBAdapter(
|
||||
url=vector_db_url,
|
||||
api_key=vector_db_key,
|
||||
embedding_engine=embedding_engine,
|
||||
)
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
from .embeddings import get_embedding_engine
|
||||
from .supported_adapters import supported_adapters
|
||||
|
||||
from functools import lru_cache
|
||||
|
||||
|
|
@ -12,97 +13,6 @@ def create_vector_engine(
|
|||
):
|
||||
embedding_engine = get_embedding_engine()
|
||||
|
||||
if vector_db_provider == "weaviate":
|
||||
from .weaviate_db import WeaviateAdapter
|
||||
vector_db_adapter = supported_adapters.get(vector_db_provider, None)
|
||||
|
||||
if not (vector_db_url and vector_db_key):
|
||||
raise EnvironmentError("Missing requred Weaviate credentials!")
|
||||
|
||||
return WeaviateAdapter(vector_db_url, vector_db_key, embedding_engine=embedding_engine)
|
||||
|
||||
elif vector_db_provider == "qdrant":
|
||||
if not (vector_db_url and vector_db_key):
|
||||
raise EnvironmentError("Missing requred Qdrant credentials!")
|
||||
|
||||
from .qdrant.QDrantAdapter import QDrantAdapter
|
||||
|
||||
return QDrantAdapter(
|
||||
url=vector_db_url,
|
||||
api_key=vector_db_key,
|
||||
embedding_engine=embedding_engine,
|
||||
)
|
||||
|
||||
elif vector_db_provider == "milvus":
|
||||
from .milvus.MilvusAdapter import MilvusAdapter
|
||||
|
||||
if not vector_db_url:
|
||||
raise EnvironmentError("Missing required Milvus credentials!")
|
||||
|
||||
return MilvusAdapter(
|
||||
url=vector_db_url,
|
||||
api_key=vector_db_key,
|
||||
embedding_engine=embedding_engine,
|
||||
)
|
||||
|
||||
elif vector_db_provider == "pgvector":
|
||||
from cognee.infrastructure.databases.relational import get_relational_config
|
||||
|
||||
# Get configuration for postgres database
|
||||
relational_config = get_relational_config()
|
||||
db_username = relational_config.db_username
|
||||
db_password = relational_config.db_password
|
||||
db_host = relational_config.db_host
|
||||
db_port = relational_config.db_port
|
||||
db_name = relational_config.db_name
|
||||
|
||||
if not (db_host and db_port and db_name and db_username and db_password):
|
||||
raise EnvironmentError("Missing requred pgvector credentials!")
|
||||
|
||||
connection_string: str = (
|
||||
f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
|
||||
)
|
||||
|
||||
from .pgvector.PGVectorAdapter import PGVectorAdapter
|
||||
|
||||
return PGVectorAdapter(
|
||||
connection_string,
|
||||
vector_db_key,
|
||||
embedding_engine,
|
||||
)
|
||||
|
||||
elif vector_db_provider == "falkordb":
|
||||
if not (vector_db_url and vector_db_port):
|
||||
raise EnvironmentError("Missing requred FalkorDB credentials!")
|
||||
|
||||
from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
|
||||
|
||||
return FalkorDBAdapter(
|
||||
database_url=vector_db_url,
|
||||
database_port=vector_db_port,
|
||||
embedding_engine=embedding_engine,
|
||||
)
|
||||
|
||||
elif vector_db_provider == "chromadb":
|
||||
try:
|
||||
import chromadb
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"ChromaDB is not installed. Please install it with 'pip install chromadb'"
|
||||
)
|
||||
|
||||
from .chromadb.ChromaDBAdapter import ChromaDBAdapter
|
||||
|
||||
return ChromaDBAdapter(
|
||||
url=vector_db_url,
|
||||
api_key=vector_db_key,
|
||||
embedding_engine=embedding_engine,
|
||||
)
|
||||
|
||||
else:
|
||||
from .lancedb.LanceDBAdapter import LanceDBAdapter
|
||||
|
||||
return LanceDBAdapter(
|
||||
url=vector_db_url,
|
||||
api_key=vector_db_key,
|
||||
embedding_engine=embedding_engine,
|
||||
)
|
||||
return vector_db_adapter(url=vector_db_url, api_key=vector_db_key, embedding_engine=embedding_engine)
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
from .exceptions import CollectionNotFoundError
|
||||
from .exceptions import *
|
||||
|
|
|
|||
|
|
@ -12,3 +12,15 @@ class CollectionNotFoundError(CriticalError):
|
|||
log_level="DEBUG",
|
||||
):
|
||||
super().__init__(message, name, status_code, log, log_level)
|
||||
|
||||
|
||||
class VectorEngineInitializationError(CriticalError):
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
name: str = "VectorEngineInitializationError",
|
||||
status_code: int = status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
log=True,
|
||||
log_level="ERROR",
|
||||
):
|
||||
super().__init__(message, name, status_code, log, log_level)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,6 @@
|
|||
from .lancedb.LanceDBAdapter import LanceDBAdapter
|
||||
|
||||
|
||||
supported_adapters = {
|
||||
"lancedb": LanceDBAdapter,
|
||||
}
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
from .vector_db_interface import VectorDBInterface
|
||||
from .supported_adapters import supported_adapters
|
||||
|
||||
|
||||
def use_vector_adapter(vector_adapter_name: str, vector_adapter: VectorDBInterface):
|
||||
supported_adapters[vector_adapter_name] = vector_adapter
|
||||
Loading…
Add table
Reference in a new issue