cognee/level_2/vectordb/basevectordb.py
Vasilije 59c53f7339 Added following:
1. Dynamic metadata retrieval, refactored function
2. Load with using marshmallow, allows dynamic fields now
3. Added chunkers, different varieties
4. Fixed PDF loading so it is better standardized
2023-09-13 17:15:33 +02:00

133 lines
3.6 KiB
Python

# Make sure to install the following packages: dlt, langchain, duckdb, python-dotenv, openai, weaviate-client
import logging
from io import BytesIO
from level_2.vectordb.vectordb import PineconeVectorDB, WeaviateVectorDB
# Configure the root logger at import time so INFO-level records are emitted.
logging.basicConfig(level=logging.INFO)
import marvin
import requests
from dotenv import load_dotenv
from langchain.document_loaders import PyPDFLoader
from langchain.retrievers import WeaviateHybridSearchRetriever
from weaviate.gql.get import HybridFusion
# Load variables from a .env file into the process environment before they are read below.
load_dotenv()
from typing import Optional
import tracemalloc
# NOTE(review): allocation tracing is switched on for every importer of this module;
# confirm this is intended outside of debugging sessions.
tracemalloc.start()
import os
from datetime import datetime
from langchain.embeddings.openai import OpenAIEmbeddings
# NOTE(review): `load_dotenv` is imported and called a second time here; it duplicates the call above.
from dotenv import load_dotenv
from langchain.schema import Document
import uuid
import weaviate
load_dotenv()
# OpenAI key as read from the environment; falls back to an empty string when unset.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
# Hand the same environment variable to marvin. Unlike the constant above, this
# lookup has no fallback, so marvin receives None when the variable is unset.
marvin.settings.openai.api_key = os.environ.get("OPENAI_API_KEY")
# Default identifiers used as keyword defaults by VectorDBFactory.create_vector_db.
# NOTE(review): the LTM default has five zeros while the other two have four — confirm this is deliberate.
LTM_MEMORY_ID_DEFAULT = "00000"
ST_MEMORY_ID_DEFAULT = "0000"
BUFFER_ID_DEFAULT = "0000"
class VectorDBFactory:
    """Create the vector-store backend that matches a ``db_type`` key."""

    def create_vector_db(
        self,
        user_id: str,
        index_name: str,
        memory_id: str,
        ltm_memory_id: str = LTM_MEMORY_ID_DEFAULT,
        st_memory_id: str = ST_MEMORY_ID_DEFAULT,
        buffer_id: str = BUFFER_ID_DEFAULT,
        db_type: str = "pinecone",
        namespace: str = None,
    ):
        """Instantiate the backend class registered under ``db_type``.

        Args:
            user_id: Forwarded to the backend constructor (1st positional).
            index_name: Forwarded to the backend constructor (2nd positional).
            memory_id: Forwarded to the backend constructor (3rd positional).
            ltm_memory_id: Forwarded (4th positional); defaults to
                ``LTM_MEMORY_ID_DEFAULT``.
            st_memory_id: Forwarded (5th positional); defaults to
                ``ST_MEMORY_ID_DEFAULT``.
            buffer_id: Forwarded (6th positional); defaults to
                ``BUFFER_ID_DEFAULT``.
            db_type: Lookup key, either ``"pinecone"`` or ``"weaviate"``.
            namespace: Forwarded (7th positional); may be ``None``.

        Returns:
            A new ``PineconeVectorDB`` or ``WeaviateVectorDB`` instance.

        Raises:
            ValueError: If ``db_type`` is not one of the registered keys.
        """
        backends = {"pinecone": PineconeVectorDB, "weaviate": WeaviateVectorDB}

        backend_cls = backends.get(db_type)
        if backend_cls is None:
            raise ValueError(f"Unsupported database type: {db_type}")

        # Both backends are constructed with the same seven positional values.
        ctor_args = (
            user_id,
            index_name,
            memory_id,
            ltm_memory_id,
            st_memory_id,
            buffer_id,
            namespace,
        )
        return backend_cls(*ctor_args)
class BaseMemory:
    """Own a vector-store backend and forward memory operations to it.

    The backend is built once in ``__init__`` through ``VectorDBFactory`` and
    kept on ``self.vector_db``; every public method delegates to it.
    """

    def __init__(
        self,
        user_id: str,
        memory_id: Optional[str],
        index_name: Optional[str],
        db_type: str,
        namespace: str,
    ):
        """Store the identifying fields and create the backend.

        Args:
            user_id: Stored on the instance and passed to the factory.
            memory_id: Stored on the instance and passed to the factory.
            index_name: Stored on the instance and passed to the factory.
            db_type: Backend key handed to the factory as ``db_type``.
            namespace: Stored on the instance and passed to the factory.

        Raises:
            ValueError: Propagated from the factory for an unknown ``db_type``.
        """
        # A fresh random identifier is generated for every instance.
        self.memory_type_id = str(uuid.uuid4())

        self.user_id = user_id
        self.memory_id = memory_id
        self.index_name = index_name
        self.db_type = db_type
        self.namespace = namespace

        self.vector_db = VectorDBFactory().create_vector_db(
            self.user_id,
            self.index_name,
            self.memory_id,
            db_type=self.db_type,
            namespace=self.namespace,
        )

    def init_client(self, namespace: str):
        """Return whatever the backend's ``init_weaviate_client`` returns.

        NOTE(review): this calls a Weaviate-named method on the backend
        regardless of ``db_type`` — confirm other backends provide it.
        """
        backend = self.vector_db
        return backend.init_weaviate_client(namespace)

    async def add_memories(
        self,
        observation: Optional[str] = None,
        loader_settings: dict = None,
        params: Optional[dict] = None,
        namespace: Optional[str] = None,
        custom_fields: Optional[str] = None,
    ):
        """Forward all arguments, by keyword, to the backend's ``add_memories``.

        Returns:
            The awaited result of the backend call, unchanged.
        """
        forwarded = {
            "observation": observation,
            "loader_settings": loader_settings,
            "params": params,
            "namespace": namespace,
            "custom_fields": custom_fields,
        }
        return await self.vector_db.add_memories(**forwarded)

    # Add other db_type conditions if necessary

    async def fetch_memories(
        self,
        observation: str,
        params: Optional[str] = None,
        namespace: Optional[str] = None,
        n_of_observations: Optional[int] = 2,
    ):
        """Forward all arguments, by keyword, to the backend's ``fetch_memories``.

        Returns:
            The awaited result of the backend call, unchanged.
        """
        query = {
            "observation": observation,
            "params": params,
            "namespace": namespace,
            "n_of_observations": n_of_observations,
        }
        return await self.vector_db.fetch_memories(**query)

    async def delete_memories(self, params: Optional[str] = None):
        """Pass ``params`` positionally to the backend's ``delete_memories``.

        Returns:
            The awaited result of the backend call, unchanged.
        """
        outcome = await self.vector_db.delete_memories(params)
        return outcome

    # Additional methods for specific Memory can be added here