fix: allow alternative vector db engine to be used
This commit is contained in:
parent
472d1a0df6
commit
f79631d5da
24 changed files with 143 additions and 163 deletions
|
|
@ -112,7 +112,7 @@ export default function Home() {
|
||||||
expireIn={notification.expireIn}
|
expireIn={notification.expireIn}
|
||||||
onClose={notification.delete}
|
onClose={notification.delete}
|
||||||
>
|
>
|
||||||
<Text>{notification.message}</Text>
|
<Text nowrap>{notification.message}</Text>
|
||||||
</Notification>
|
</Notification>
|
||||||
))}
|
))}
|
||||||
</NotificationContainer>
|
</NotificationContainer>
|
||||||
|
|
|
||||||
|
|
@ -72,15 +72,21 @@ async def get_dataset_graph(dataset_id: str):
|
||||||
from cognee.infrastructure.databases.graph import get_graph_config
|
from cognee.infrastructure.databases.graph import get_graph_config
|
||||||
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
||||||
|
|
||||||
graph_config = get_graph_config()
|
try:
|
||||||
graph_engine = graph_config.graph_engine
|
graph_config = get_graph_config()
|
||||||
graph_client = await get_graph_client(graph_engine)
|
graph_engine = graph_config.graph_engine
|
||||||
graph_url = await render_graph(graph_client.graph)
|
graph_client = await get_graph_client(graph_engine)
|
||||||
|
graph_url = await render_graph(graph_client.graph)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code = 200,
|
status_code = 200,
|
||||||
content = str(graph_url),
|
content = str(graph_url),
|
||||||
)
|
)
|
||||||
|
except:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code = 409,
|
||||||
|
content = "Graphistry credentials are not set. Please set them in your .env file.",
|
||||||
|
)
|
||||||
|
|
||||||
@app.get("/datasets/{dataset_id}/data", response_model=list)
|
@app.get("/datasets/{dataset_id}/data", response_model=list)
|
||||||
async def get_dataset_data(dataset_id: str):
|
async def get_dataset_data(dataset_id: str):
|
||||||
|
|
@ -106,7 +112,7 @@ async def get_dataset_status(datasets: Annotated[List[str], Query(alias="dataset
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code = 200,
|
status_code = 200,
|
||||||
content = { dataset["data_id"]: dataset["status"] for dataset in datasets_statuses },
|
content = datasets_statuses,
|
||||||
)
|
)
|
||||||
|
|
||||||
@app.get("/datasets/{dataset_id}/data/{data_id}/raw", response_class=FileResponse)
|
@app.get("/datasets/{dataset_id}/data/{data_id}/raw", response_class=FileResponse)
|
||||||
|
|
@ -264,8 +270,7 @@ def start_api_server(host: str = "0.0.0.0", port: int = 8000):
|
||||||
relational_config.create_engine()
|
relational_config.create_engine()
|
||||||
|
|
||||||
vector_config = get_vectordb_config()
|
vector_config = get_vectordb_config()
|
||||||
vector_config.vector_db_path = databases_directory_path
|
vector_config.vector_db_url = os.path.join(databases_directory_path, "cognee.lancedb")
|
||||||
vector_config.create_engine()
|
|
||||||
|
|
||||||
base_config = get_base_config()
|
base_config = get_base_config()
|
||||||
data_directory_path = os.path.abspath(".data_storage")
|
data_directory_path = os.path.abspath(".data_storage")
|
||||||
|
|
|
||||||
|
|
@ -61,14 +61,19 @@ async def cognify(datasets: Union[str, List[str]] = None):
|
||||||
async with update_status_lock:
|
async with update_status_lock:
|
||||||
task_status = get_task_status([dataset_name])
|
task_status = get_task_status([dataset_name])
|
||||||
|
|
||||||
if task_status == "DATASET_PROCESSING_STARTED":
|
if dataset_name in task_status and task_status[dataset_name] == "DATASET_PROCESSING_STARTED":
|
||||||
logger.info(f"Dataset {dataset_name} is being processed.")
|
logger.info(f"Dataset {dataset_name} is being processed.")
|
||||||
return
|
return
|
||||||
|
|
||||||
update_task_status(dataset_name, "DATASET_PROCESSING_STARTED")
|
update_task_status(dataset_name, "DATASET_PROCESSING_STARTED")
|
||||||
|
|
||||||
await cognify(dataset_name)
|
try:
|
||||||
update_task_status(dataset_name, "DATASET_PROCESSING_FINISHED")
|
await cognify(dataset_name)
|
||||||
|
update_task_status(dataset_name, "DATASET_PROCESSING_FINISHED")
|
||||||
|
except Exception as error:
|
||||||
|
update_task_status(dataset_name, "DATASET_PROCESSING_ERROR")
|
||||||
|
raise error
|
||||||
|
|
||||||
|
|
||||||
# datasets is a list of dataset names
|
# datasets is a list of dataset names
|
||||||
if isinstance(datasets, list):
|
if isinstance(datasets, list):
|
||||||
|
|
|
||||||
|
|
@ -17,8 +17,8 @@ class config():
|
||||||
relational_config.create_engine()
|
relational_config.create_engine()
|
||||||
|
|
||||||
vector_config = get_vectordb_config()
|
vector_config = get_vectordb_config()
|
||||||
vector_config.vector_db_path = databases_directory_path
|
if vector_config.vector_engine_provider == "lancedb":
|
||||||
vector_config.create_engine()
|
vector_config.vector_db_url = os.path.join(databases_directory_path, "cognee.lancedb")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def data_root_directory(data_root_directory: str):
|
def data_root_directory(data_root_directory: str):
|
||||||
|
|
|
||||||
|
|
@ -2,4 +2,5 @@ from .models.DataPoint import DataPoint
|
||||||
from .models.VectorConfig import VectorConfig
|
from .models.VectorConfig import VectorConfig
|
||||||
from .models.CollectionConfig import CollectionConfig
|
from .models.CollectionConfig import CollectionConfig
|
||||||
from .vector_db_interface import VectorDBInterface
|
from .vector_db_interface import VectorDBInterface
|
||||||
from .config import get_vectordb_config
|
from .config import get_vectordb_config
|
||||||
|
from .get_vector_engine import get_vector_engine
|
||||||
|
|
|
||||||
|
|
@ -1,37 +1,18 @@
|
||||||
import os
|
import os
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config
|
|
||||||
from cognee.root_dir import get_absolute_path
|
from cognee.root_dir import get_absolute_path
|
||||||
from .create_vector_engine import create_vector_engine
|
|
||||||
|
|
||||||
class VectorConfig(BaseSettings):
|
class VectorConfig(BaseSettings):
|
||||||
vector_db_path: str = os.path.join(get_absolute_path(".cognee_system"), "databases")
|
vector_db_url: str = os.path.join(
|
||||||
vector_db_url: str = os.path.join(vector_db_path, "cognee.lancedb")
|
os.path.join(get_absolute_path(".cognee_system"), "databases"),
|
||||||
|
"cognee.lancedb"
|
||||||
|
)
|
||||||
vector_db_key: str = ""
|
vector_db_key: str = ""
|
||||||
vector_engine_provider: str = "lancedb"
|
vector_engine_provider: str = "lancedb"
|
||||||
vector_engine: object = create_vector_engine(
|
|
||||||
{
|
|
||||||
"vector_db_key": None,
|
|
||||||
"vector_db_url": vector_db_url,
|
|
||||||
"vector_db_provider": "lancedb",
|
|
||||||
},
|
|
||||||
get_embedding_config().embedding_engine,
|
|
||||||
)
|
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
|
model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
|
||||||
|
|
||||||
def create_engine(self):
|
|
||||||
if self.vector_engine_provider == "lancedb":
|
|
||||||
self.vector_db_url = os.path.join(self.vector_db_path, "cognee.lancedb")
|
|
||||||
else:
|
|
||||||
self.vector_db_path = None
|
|
||||||
|
|
||||||
self.vector_engine = create_vector_engine(
|
|
||||||
get_vectordb_config().to_dict(),
|
|
||||||
get_embedding_config().embedding_engine,
|
|
||||||
)
|
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
return {
|
return {
|
||||||
"vector_db_url": self.vector_db_url,
|
"vector_db_url": self.vector_db_url,
|
||||||
|
|
|
||||||
|
|
@ -28,9 +28,6 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
from .lancedb.LanceDBAdapter import LanceDBAdapter
|
from .lancedb.LanceDBAdapter import LanceDBAdapter
|
||||||
# from cognee.infrastructure.files.storage import LocalStorage
|
|
||||||
|
|
||||||
# LocalStorage.ensure_directory_exists(config["vector_db_url"])
|
|
||||||
|
|
||||||
return LanceDBAdapter(
|
return LanceDBAdapter(
|
||||||
url = config["vector_db_url"],
|
url = config["vector_db_url"],
|
||||||
|
|
|
||||||
|
|
@ -1,69 +0,0 @@
|
||||||
import asyncio
|
|
||||||
from typing import List, Optional
|
|
||||||
from fastembed import TextEmbedding
|
|
||||||
import litellm
|
|
||||||
from litellm import aembedding
|
|
||||||
from cognee.root_dir import get_absolute_path
|
|
||||||
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
|
||||||
|
|
||||||
litellm.set_verbose = True
|
|
||||||
|
|
||||||
class DefaultEmbeddingEngine(EmbeddingEngine):
|
|
||||||
embedding_model: str
|
|
||||||
embedding_dimensions: int
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
embedding_model: Optional[str],
|
|
||||||
embedding_dimensions: Optional[int],
|
|
||||||
):
|
|
||||||
self.embedding_model = embedding_model
|
|
||||||
self.embedding_dimensions = embedding_dimensions
|
|
||||||
|
|
||||||
async def embed_text(self, text: List[str]) -> List[float]:
|
|
||||||
embedding_model = TextEmbedding(model_name = self.embedding_model, cache_dir = get_absolute_path("cache/embeddings"))
|
|
||||||
embeddings_list = list(map(lambda embedding: embedding.tolist(), embedding_model.embed(text)))
|
|
||||||
|
|
||||||
return embeddings_list
|
|
||||||
|
|
||||||
def get_vector_size(self) -> int:
|
|
||||||
return self.embedding_dimensions
|
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|
||||||
embedding_model: str
|
|
||||||
embedding_dimensions: int
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
embedding_model: Optional[str],
|
|
||||||
embedding_dimensions: Optional[int],
|
|
||||||
):
|
|
||||||
self.embedding_model = embedding_model
|
|
||||||
self.embedding_dimensions = embedding_dimensions
|
|
||||||
|
|
||||||
async def embed_text(self, text: List[str]) -> List[List[float]]:
|
|
||||||
async def get_embedding(text_):
|
|
||||||
response = await aembedding(self.embedding_model, input=text_)
|
|
||||||
return response.data[0]['embedding']
|
|
||||||
|
|
||||||
tasks = [get_embedding(text_) for text_ in text]
|
|
||||||
result = await asyncio.gather(*tasks)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def get_vector_size(self) -> int:
|
|
||||||
return self.embedding_dimensions
|
|
||||||
|
|
||||||
|
|
||||||
# if __name__ == "__main__":
|
|
||||||
# async def gg():
|
|
||||||
# openai_embedding_engine = LiteLLMEmbeddingEngine()
|
|
||||||
# # print(openai_embedding_engine.embed_text(["Hello, how are you?"]))
|
|
||||||
# # print(openai_embedding_engine.get_vector_size())
|
|
||||||
# # default_embedding_engine = DefaultEmbeddingEngine()
|
|
||||||
# sds = await openai_embedding_engine.embed_text(["Hello, sadasdas are you?"])
|
|
||||||
# print(sds)
|
|
||||||
# # print(default_embedding_engine.get_vector_size())
|
|
||||||
|
|
||||||
# asyncio.run(gg())
|
|
||||||
|
|
||||||
|
|
@ -0,0 +1,25 @@
|
||||||
|
from typing import List, Optional
|
||||||
|
from fastembed import TextEmbedding
|
||||||
|
from cognee.root_dir import get_absolute_path
|
||||||
|
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
|
|
||||||
|
class FastembedEmbeddingEngine(EmbeddingEngine):
|
||||||
|
embedding_model: str
|
||||||
|
embedding_dimensions: int
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_model: Optional[str] = "BAAI/bge-large-en-v1.5",
|
||||||
|
embedding_dimensions: Optional[int] = 1024,
|
||||||
|
):
|
||||||
|
self.embedding_model = embedding_model
|
||||||
|
self.embedding_dimensions = embedding_dimensions
|
||||||
|
|
||||||
|
async def embed_text(self, text: List[str]) -> List[float]:
|
||||||
|
embedding_model = TextEmbedding(model_name = self.embedding_model, cache_dir = get_absolute_path("cache/embeddings"))
|
||||||
|
embeddings_list = list(map(lambda embedding: embedding.tolist(), embedding_model.embed(text)))
|
||||||
|
|
||||||
|
return embeddings_list
|
||||||
|
|
||||||
|
def get_vector_size(self) -> int:
|
||||||
|
return self.embedding_dimensions
|
||||||
|
|
@ -0,0 +1,39 @@
|
||||||
|
import asyncio
|
||||||
|
from typing import List, Optional
|
||||||
|
import litellm
|
||||||
|
from litellm import aembedding
|
||||||
|
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
|
|
||||||
|
litellm.set_verbose = True
|
||||||
|
|
||||||
|
class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
|
api_key: str
|
||||||
|
embedding_model: str
|
||||||
|
embedding_dimensions: int
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_model: Optional[str] = "text-embedding-3-large",
|
||||||
|
embedding_dimensions: Optional[int] = 3072,
|
||||||
|
api_key: str = None,
|
||||||
|
):
|
||||||
|
self.api_key = api_key
|
||||||
|
self.embedding_model = embedding_model
|
||||||
|
self.embedding_dimensions = embedding_dimensions
|
||||||
|
|
||||||
|
async def embed_text(self, text: List[str]) -> List[List[float]]:
|
||||||
|
async def get_embedding(text_):
|
||||||
|
response = await aembedding(
|
||||||
|
self.embedding_model,
|
||||||
|
input = text_,
|
||||||
|
api_key = self.api_key
|
||||||
|
)
|
||||||
|
|
||||||
|
return response.data[0]["embedding"]
|
||||||
|
|
||||||
|
tasks = [get_embedding(text_) for text_ in text]
|
||||||
|
result = await asyncio.gather(*tasks)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_vector_size(self) -> int:
|
||||||
|
return self.embedding_dimensions
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
from .get_embedding_engine import get_embedding_engine
|
||||||
|
|
@ -1,15 +1,12 @@
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
from cognee.infrastructure.databases.vector.embeddings.DefaultEmbeddingEngine import DefaultEmbeddingEngine
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingConfig(BaseSettings):
|
class EmbeddingConfig(BaseSettings):
|
||||||
openai_embedding_model: str = "text-embedding-3-large"
|
openai_embedding_model: str = "text-embedding-3-large"
|
||||||
openai_embedding_dimensions: int = 3072
|
openai_embedding_dimensions: int = 3072
|
||||||
litellm_embedding_model: str = "BAAI/bge-large-en-v1.5"
|
litellm_embedding_model: str = "BAAI/bge-large-en-v1.5"
|
||||||
litellm_embedding_dimensions: int = 1024
|
litellm_embedding_dimensions: int = 1024
|
||||||
embedding_engine:object = DefaultEmbeddingEngine(embedding_model=litellm_embedding_model, embedding_dimensions=litellm_embedding_dimensions)
|
# embedding_engine:object = DefaultEmbeddingEngine(embedding_model=litellm_embedding_model, embedding_dimensions=litellm_embedding_dimensions)
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
|
model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
from cognee.infrastructure.llm import get_llm_config
|
||||||
|
from .EmbeddingEngine import EmbeddingEngine
|
||||||
|
from .LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine
|
||||||
|
|
||||||
|
def get_embedding_engine() -> EmbeddingEngine:
|
||||||
|
llm_config = get_llm_config()
|
||||||
|
return LiteLLMEmbeddingEngine(api_key = llm_config.llm_api_key)
|
||||||
|
|
@ -0,0 +1,6 @@
|
||||||
|
from .config import get_vectordb_config
|
||||||
|
from .embeddings import get_embedding_engine
|
||||||
|
from .create_vector_engine import create_vector_engine
|
||||||
|
|
||||||
|
def get_vector_engine():
|
||||||
|
return create_vector_engine(get_vectordb_config().to_dict(), get_embedding_engine())
|
||||||
|
|
@ -2,11 +2,9 @@ from datetime import datetime
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from typing import List, Tuple, TypedDict
|
from typing import List, Tuple, TypedDict
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from cognee.infrastructure.databases.vector import DataPoint
|
from cognee.infrastructure.databases.vector import DataPoint, get_vector_engine
|
||||||
|
|
||||||
# from cognee.shared.utils import extract_pos_tags, extract_named_entities, extract_sentiment_vader
|
|
||||||
from cognee.infrastructure.databases.graph.config import get_graph_config
|
from cognee.infrastructure.databases.graph.config import get_graph_config
|
||||||
from cognee.infrastructure.databases.vector.config import get_vectordb_config
|
# from cognee.shared.utils import extract_pos_tags, extract_named_entities, extract_sentiment_vader
|
||||||
|
|
||||||
|
|
||||||
class GraphLike(TypedDict):
|
class GraphLike(TypedDict):
|
||||||
|
|
@ -20,8 +18,7 @@ async def add_cognitive_layer_graphs(
|
||||||
chunk_id: str,
|
chunk_id: str,
|
||||||
layer_graphs: List[Tuple[str, GraphLike]],
|
layer_graphs: List[Tuple[str, GraphLike]],
|
||||||
):
|
):
|
||||||
vectordb_config = get_vectordb_config()
|
vector_engine = get_vector_engine()
|
||||||
vector_client = vectordb_config.vector_engine
|
|
||||||
|
|
||||||
graph_config = get_graph_config()
|
graph_config = get_graph_config()
|
||||||
graph_model = graph_config.graph_model
|
graph_model = graph_config.graph_model
|
||||||
|
|
@ -127,7 +124,7 @@ async def add_cognitive_layer_graphs(
|
||||||
references: References
|
references: References
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await vector_client.create_collection(layer_id, payload_schema = PayloadSchema)
|
await vector_engine.create_collection(layer_id, payload_schema = PayloadSchema)
|
||||||
except Exception:
|
except Exception:
|
||||||
# It's ok if the collection already exists.
|
# It's ok if the collection already exists.
|
||||||
pass
|
pass
|
||||||
|
|
@ -146,8 +143,8 @@ async def add_cognitive_layer_graphs(
|
||||||
) for (node_id, node_data) in graph_nodes
|
) for (node_id, node_data) in graph_nodes
|
||||||
]
|
]
|
||||||
|
|
||||||
await vector_client.create_data_points(layer_id, data_points)
|
await vector_engine.create_data_points(layer_id, data_points)
|
||||||
|
|
||||||
|
|
||||||
def generate_node_id(node_id: str) -> str:
|
def generate_node_id(node_id: str) -> str:
|
||||||
return node_id.upper().replace(' ', '_').replace("'", "")
|
return node_id.upper().replace(" ", "_").replace("'", "")
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,7 @@
|
||||||
|
|
||||||
from typing import TypedDict
|
from typing import TypedDict
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from cognee.infrastructure.databases.vector.config import get_vectordb_config
|
from cognee.infrastructure.databases.vector import DataPoint, get_vector_engine
|
||||||
from cognee.infrastructure.databases.vector import DataPoint
|
|
||||||
|
|
||||||
class TextChunk(TypedDict):
|
class TextChunk(TypedDict):
|
||||||
text: str
|
text: str
|
||||||
|
|
@ -10,8 +9,7 @@ class TextChunk(TypedDict):
|
||||||
file_metadata: dict
|
file_metadata: dict
|
||||||
|
|
||||||
async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]):
|
async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]):
|
||||||
config = get_vectordb_config()
|
vector_engine = get_vector_engine()
|
||||||
vector_client = config.vector_engine
|
|
||||||
|
|
||||||
identified_chunks = []
|
identified_chunks = []
|
||||||
|
|
||||||
|
|
@ -21,7 +19,7 @@ async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]):
|
||||||
for (dataset_name, chunks) in dataset_data_chunks.items():
|
for (dataset_name, chunks) in dataset_data_chunks.items():
|
||||||
try:
|
try:
|
||||||
|
|
||||||
await vector_client.create_collection(dataset_name, payload_schema = PayloadSchema)
|
await vector_engine.create_collection(dataset_name, payload_schema = PayloadSchema)
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
print(error)
|
print(error)
|
||||||
pass
|
pass
|
||||||
|
|
@ -38,7 +36,7 @@ async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]):
|
||||||
|
|
||||||
identified_chunks.extend(dataset_chunks)
|
identified_chunks.extend(dataset_chunks)
|
||||||
|
|
||||||
await vector_client.create_data_points(
|
await vector_engine.create_data_points(
|
||||||
dataset_name,
|
dataset_name,
|
||||||
[
|
[
|
||||||
DataPoint[PayloadSchema](
|
DataPoint[PayloadSchema](
|
||||||
|
|
@ -53,8 +51,7 @@ async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]):
|
||||||
|
|
||||||
|
|
||||||
async def add_data_chunks_basic_rag(dataset_data_chunks: dict[str, list[TextChunk]]):
|
async def add_data_chunks_basic_rag(dataset_data_chunks: dict[str, list[TextChunk]]):
|
||||||
config = get_vectordb_config()
|
vector_engine = get_vector_engine()
|
||||||
vector_client = config.vector_engine
|
|
||||||
|
|
||||||
identified_chunks = []
|
identified_chunks = []
|
||||||
|
|
||||||
|
|
@ -64,7 +61,7 @@ async def add_data_chunks_basic_rag(dataset_data_chunks: dict[str, list[TextChun
|
||||||
for (dataset_name, chunks) in dataset_data_chunks.items():
|
for (dataset_name, chunks) in dataset_data_chunks.items():
|
||||||
try:
|
try:
|
||||||
|
|
||||||
await vector_client.create_collection("basic_rag", payload_schema = PayloadSchema)
|
await vector_engine.create_collection("basic_rag", payload_schema = PayloadSchema)
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
print(error)
|
print(error)
|
||||||
|
|
||||||
|
|
@ -80,7 +77,7 @@ async def add_data_chunks_basic_rag(dataset_data_chunks: dict[str, list[TextChun
|
||||||
|
|
||||||
identified_chunks.extend(dataset_chunks)
|
identified_chunks.extend(dataset_chunks)
|
||||||
|
|
||||||
await vector_client.create_data_points(
|
await vector_engine.create_data_points(
|
||||||
"basic_rag",
|
"basic_rag",
|
||||||
[
|
[
|
||||||
DataPoint[PayloadSchema](
|
DataPoint[PayloadSchema](
|
||||||
|
|
|
||||||
|
|
@ -2,13 +2,10 @@ from uuid import uuid4
|
||||||
from typing import List
|
from typing import List
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from cognee.infrastructure.databases.vector import DataPoint, get_vector_engine
|
||||||
from cognee.infrastructure.databases.vector import DataPoint
|
|
||||||
from cognee.infrastructure.databases.vector.config import get_vectordb_config
|
|
||||||
|
|
||||||
async def add_label_nodes(graph_client, parent_node_id: str, keywords: List[str]) -> None:
|
async def add_label_nodes(graph_client, parent_node_id: str, keywords: List[str]) -> None:
|
||||||
vectordb_config = get_vectordb_config()
|
vector_engine = get_vector_engine()
|
||||||
vector_client = vectordb_config.vector_engine
|
|
||||||
|
|
||||||
keyword_nodes = []
|
keyword_nodes = []
|
||||||
|
|
||||||
|
|
@ -62,9 +59,9 @@ async def add_label_nodes(graph_client, parent_node_id: str, keywords: List[str]
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await vector_client.create_collection(parent_node_id, payload_schema = PayloadSchema)
|
await vector_engine.create_collection(parent_node_id, payload_schema = PayloadSchema)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# It's ok if the collection already exists.
|
# It's ok if the collection already exists.
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
await vector_client.create_data_points(parent_node_id, keyword_data_points)
|
await vector_engine.create_data_points(parent_node_id, keyword_data_points)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
from cognee.infrastructure.databases.vector.config import get_vectordb_config
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
|
||||||
async def resolve_cross_graph_references(nodes_by_layer: Dict):
|
async def resolve_cross_graph_references(nodes_by_layer: Dict):
|
||||||
results = []
|
results = []
|
||||||
|
|
@ -16,8 +16,7 @@ async def resolve_cross_graph_references(nodes_by_layer: Dict):
|
||||||
return results
|
return results
|
||||||
|
|
||||||
async def get_nodes_by_layer(layer_id: str, layer_nodes: List):
|
async def get_nodes_by_layer(layer_id: str, layer_nodes: List):
|
||||||
vectordb_config = get_vectordb_config()
|
vector_engine = get_vector_engine()
|
||||||
vector_engine = vectordb_config.vector_engine
|
|
||||||
|
|
||||||
score_points = await vector_engine.batch_search(
|
score_points = await vector_engine.batch_search(
|
||||||
layer_id,
|
layer_id,
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
from cognee.infrastructure.databases.graph.config import get_graph_config
|
from cognee.infrastructure.databases.graph.config import get_graph_config
|
||||||
from cognee.infrastructure.databases.vector.config import get_vectordb_config
|
|
||||||
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
||||||
|
|
||||||
async def prune_system(graph = True, vector = True):
|
async def prune_system(graph = True, vector = True):
|
||||||
|
|
@ -9,6 +9,5 @@ async def prune_system(graph = True, vector = True):
|
||||||
await graph_client.delete_graph()
|
await graph_client.delete_graph()
|
||||||
|
|
||||||
if vector:
|
if vector:
|
||||||
vector_config = get_vectordb_config()
|
vector_engine = get_vector_engine()
|
||||||
vector_client = vector_config.vector_engine
|
await vector_engine.prune()
|
||||||
await vector_client.prune()
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
||||||
from cognee.infrastructure.databases.graph.config import get_graph_config
|
from cognee.infrastructure.databases.graph.config import get_graph_config
|
||||||
from cognee.infrastructure.databases.vector.config import get_vectordb_config
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
|
||||||
async def search_similarity(query: str, graph):
|
async def search_similarity(query: str, graph):
|
||||||
graph_config = get_graph_config()
|
graph_config = get_graph_config()
|
||||||
|
|
@ -17,10 +17,8 @@ async def search_similarity(query: str, graph):
|
||||||
|
|
||||||
graph_nodes = []
|
graph_nodes = []
|
||||||
|
|
||||||
vector_config = get_vectordb_config()
|
|
||||||
|
|
||||||
for layer_id in unique_layer_uuids:
|
for layer_id in unique_layer_uuids:
|
||||||
vector_engine = vector_config.vector_engine
|
vector_engine = get_vector_engine()
|
||||||
|
|
||||||
results = await vector_engine.search(layer_id, query_text = query, limit = 10)
|
results = await vector_engine.search(layer_id, query_text = query, limit = 10)
|
||||||
print("results", results)
|
print("results", results)
|
||||||
|
|
|
||||||
|
|
@ -40,7 +40,7 @@ def get_settings():
|
||||||
"value": llm_config.llm_model,
|
"value": llm_config.llm_model,
|
||||||
"label": llm_config.llm_model,
|
"label": llm_config.llm_model,
|
||||||
} if llm_config.llm_model else None,
|
} if llm_config.llm_model else None,
|
||||||
"apiKey": llm_config.llm_api_key[:-10] + "**********" if llm_config.llm_api_key else None,
|
"apiKey": (llm_config.llm_api_key[:-10] + "**********") if llm_config.llm_api_key else None,
|
||||||
"providers": llm_providers,
|
"providers": llm_providers,
|
||||||
"models": {
|
"models": {
|
||||||
"openai": [{
|
"openai": [{
|
||||||
|
|
|
||||||
|
|
@ -13,4 +13,3 @@ async def save_vector_db_config(vector_db_config: VectorDBConfig):
|
||||||
vector_config.vector_db_url = vector_db_config.url
|
vector_config.vector_db_url = vector_db_config.url
|
||||||
vector_config.vector_db_key = vector_db_config.apiKey
|
vector_config.vector_db_key = vector_db_config.apiKey
|
||||||
vector_config.vector_engine_provider = vector_db_config.provider
|
vector_config.vector_engine_provider = vector_db_config.provider
|
||||||
vector_config.create_engine()
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ def get_task_status(data_ids: [str]):
|
||||||
|
|
||||||
formatted_data_ids = ", ".join([f"'{data_id}'" for data_id in data_ids])
|
formatted_data_ids = ", ".join([f"'{data_id}'" for data_id in data_ids])
|
||||||
|
|
||||||
results = db_engine.execute_query(
|
datasets_statuses = db_engine.execute_query(
|
||||||
f"""SELECT data_id, status
|
f"""SELECT data_id, status
|
||||||
FROM (
|
FROM (
|
||||||
SELECT data_id, status, ROW_NUMBER() OVER (PARTITION BY data_id ORDER BY created_at DESC) as rn
|
SELECT data_id, status, ROW_NUMBER() OVER (PARTITION BY data_id ORDER BY created_at DESC) as rn
|
||||||
|
|
@ -16,4 +16,4 @@ def get_task_status(data_ids: [str]):
|
||||||
WHERE rn = 1;"""
|
WHERE rn = 1;"""
|
||||||
)
|
)
|
||||||
|
|
||||||
return results[0] if len(results) > 0 else None
|
return { dataset["data_id"]: dataset["status"] for dataset in datasets_statuses }
|
||||||
|
|
|
||||||
|
|
@ -82,18 +82,17 @@ async def run_cognify_base_rag():
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from cognee.base_config import get_base_config
|
from cognee.base_config import get_base_config
|
||||||
from cognee.infrastructure.databases.vector import get_vectordb_config
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
|
||||||
async def cognify_search_base_rag(content:str, context:str):
|
async def cognify_search_base_rag(content:str, context:str):
|
||||||
base_config = get_base_config()
|
base_config = get_base_config()
|
||||||
|
|
||||||
cognee_directory_path = os.path.abspath(".cognee_system")
|
cognee_directory_path = os.path.abspath(".cognee_system")
|
||||||
base_config.system_root_directory = cognee_directory_path
|
base_config.system_root_directory = cognee_directory_path
|
||||||
|
|
||||||
vector_config = get_vectordb_config()
|
vector_engine = get_vector_engine()
|
||||||
vector_client = vector_config.vector_engine
|
|
||||||
|
|
||||||
return_ = await vector_client.search(collection_name="basic_rag", query_text=content, limit=10)
|
return_ = await vector_engine.search(collection_name="basic_rag", query_text=content, limit=10)
|
||||||
|
|
||||||
print("results", return_)
|
print("results", return_)
|
||||||
return return_
|
return return_
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue