# Source: cognee/level_3/vectordb/basevectordb.py (239 lines, 7.5 KiB, Python)
import logging
from io import BytesIO
import os, sys
# Add the parent directory to sys.path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from vectordb.vectordb import PineconeVectorDB, WeaviateVectorDB
import sqlalchemy as sa
logging.basicConfig(level=logging.INFO)
import marvin
import requests
from dotenv import load_dotenv
from langchain.document_loaders import PyPDFLoader
from langchain.retrievers import WeaviateHybridSearchRetriever
from weaviate.gql.get import HybridFusion
from models.sessions import Session
from models.testset import TestSet
from models.testoutput import TestOutput
from models.metadatas import MetaDatas
from models.operation import Operation
from sqlalchemy.orm import sessionmaker
from database.database import engine
load_dotenv()
from typing import Optional
import time
import tracemalloc
tracemalloc.start()
from datetime import datetime
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.schema import Document
import uuid
import weaviate
from marshmallow import Schema, fields
import json
# OpenAI credentials are read from the environment (.env was loaded above via load_dotenv()).
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
# marvin shares the same key; os.environ.get returns None if the key is unset.
marvin.settings.openai.api_key = os.environ.get("OPENAI_API_KEY")
# Fallback identifiers used when a caller does not supply explicit memory ids.
# NOTE(review): LTM uses five zeros while ST/buffer use four — presumably
# intentional to distinguish the default LTM id; confirm.
LTM_MEMORY_ID_DEFAULT = "00000"
ST_MEMORY_ID_DEFAULT = "0000"
BUFFER_ID_DEFAULT = "0000"
class VectorDBFactory:
    """Factory that builds a concrete vector-database wrapper by name."""

    def create_vector_db(
        self,
        user_id: str,
        index_name: str,
        memory_id: str,
        ltm_memory_id: str = LTM_MEMORY_ID_DEFAULT,
        st_memory_id: str = ST_MEMORY_ID_DEFAULT,
        buffer_id: str = BUFFER_ID_DEFAULT,
        db_type: str = "pinecone",
        namespace: Optional[str] = None,  # fixed: was annotated `str` with a None default
        embeddings=None,
    ):
        """Return a PineconeVectorDB or WeaviateVectorDB instance.

        Args:
            user_id: Owner of the memory store.
            index_name: Name of the vector index to attach to.
            memory_id: Identifier of the memory being accessed.
            ltm_memory_id / st_memory_id / buffer_id: Sub-memory identifiers,
                defaulting to the module-level placeholder ids.
            db_type: Backend selector, "pinecone" (default) or "weaviate".
            namespace: Optional backend namespace.
            embeddings: Optional embeddings object passed through to the backend.

        Raises:
            ValueError: If ``db_type`` names an unsupported backend.
        """
        # Resolve the backend class before constructing anything; checking the
        # type first means an unsupported db_type fails fast without touching
        # either backend class.
        if db_type == "pinecone":
            db_class = PineconeVectorDB
        elif db_type == "weaviate":
            db_class = WeaviateVectorDB
        else:
            raise ValueError(f"Unsupported database type: {db_type}")
        return db_class(
            user_id,
            index_name,
            memory_id,
            ltm_memory_id,
            st_memory_id,
            buffer_id,
            namespace,
            embeddings,
        )
class BaseMemory:
def __init__(
    self,
    user_id: str,
    memory_id: Optional[str],
    index_name: Optional[str],
    db_type: str,
    namespace: str,
    embeddings: Optional[object],  # fixed: was `Optional[None]`, a meaningless annotation
):
    """Bind identity/config attributes and build the backing vector store.

    The concrete store (Pinecone or Weaviate, per ``db_type``) is created
    immediately via VectorDBFactory and exposed as ``self.vector_db``.
    """
    self.user_id = user_id
    self.memory_id = memory_id
    self.index_name = index_name
    self.namespace = namespace
    self.embeddings = embeddings
    self.db_type = db_type
    # Delegate backend construction to the factory; raises ValueError for an
    # unsupported db_type.
    factory = VectorDBFactory()
    self.vector_db = factory.create_vector_db(
        self.user_id,
        self.index_name,
        self.memory_id,
        db_type=self.db_type,
        namespace=self.namespace,
        embeddings=self.embeddings,
    )
def init_client(self, embeddings, namespace: str):
return self.vector_db.init_weaviate_client(embeddings, namespace)
def create_field(self, field_type, **kwargs):
    """Build a marshmallow field of the named kind.

    ``field_type`` must be one of "Str", "Int", "Float" or "Bool"; any other
    value raises KeyError. Keyword arguments are forwarded to the field class.
    """
    supported = {
        "Str": fields.Str,
        "Int": fields.Int,
        "Float": fields.Float,
        "Bool": fields.Bool,
    }
    field_cls = supported[field_type]
    return field_cls(**kwargs)
def create_dynamic_schema(self, params):
    """Build a marshmallow Schema instance whose fields mirror the keys of
    *params*; every field is declared as a plain string field."""
    str_fields = {}
    for field_name in params:
        str_fields[field_name] = fields.Str()
    # from_dict returns a Schema subclass; instantiate it before returning.
    schema_cls = Schema.from_dict(str_fields)
    return schema_cls()
async def get_version_from_db(self, user_id, memory_id):
    """Return ``[version, created_at]`` from the newest metadata row for
    *user_id*, or ``None`` when no row exists.

    The stored ``contract_metadata`` string is parsed with
    ``ast.literal_eval`` and its "version" key extracted.

    NOTE(review): *memory_id* is accepted but never used — the query filters
    on user_id only; confirm this is intentional.
    """
    from ast import literal_eval

    session_factory = sessionmaker(bind=engine)
    db_session = session_factory()
    try:
        # Newest row first: order by created_at descending and take one.
        row = (
            db_session.query(MetaDatas.contract_metadata, MetaDatas.created_at)
            .filter_by(user_id=user_id)
            .order_by(MetaDatas.created_at.desc())
            .first()
        )
        if row is None:
            return None
        raw_metadata, created_at = row
        logging.info(f"version_in_db: {raw_metadata}")
        parsed = literal_eval(raw_metadata)
        return [parsed.get("version"), created_at]
    finally:
        db_session.close()
async def update_metadata(self, user_id, memory_id, version_in_params, params):
    """Persist *params* as a new metadata row when none exists yet, or when
    the caller's version is newer than the one stored; always return *params*.

    NOTE(review): rows are written with ``self.user_id`` / ``self.memory_id``
    while the version lookup uses the *user_id* argument — confirm these
    always match for every caller.
    """
    version_from_db = await self.get_version_from_db(user_id, memory_id)
    session_factory = sessionmaker(bind=engine)
    session = session_factory()
    # BUG FIX: the session was previously never closed, leaking a DB
    # connection on every call; close it in all paths.
    try:
        if version_from_db is None:
            # First metadata row for this memory: stamp it with the current
            # unix time as its version.
            session.add(
                MetaDatas(
                    id=str(uuid.uuid4()),
                    user_id=self.user_id,
                    version=str(int(time.time())),
                    memory_id=self.memory_id,
                    contract_metadata=params,
                )
            )
            session.commit()
        elif version_in_params > version_from_db[0]:
            # Caller supplied a newer version: record the updated metadata.
            # NOTE(review): unlike the insert branch, no explicit version is
            # written here — presumably the model supplies a default; confirm.
            session.add(
                MetaDatas(
                    id=str(uuid.uuid4()),
                    user_id=self.user_id,
                    memory_id=self.memory_id,
                    contract_metadata=params,
                )
            )
            session.commit()
        # Stale or equal version: leave the DB untouched.
        return params
    finally:
        session.close()
async def add_memories(
    self,
    observation: Optional[str] = None,
    loader_settings: dict = None,
    params: Optional[dict] = None,
    namespace: Optional[str] = None,
    custom_fields: Optional[str] = None,
    embeddings: Optional[str] = None,
):
    """Validate *params* against a dynamically built schema, then delegate
    the write to the underlying vector store.

    Args:
        observation: Raw text to store (optional when a loader is used).
        loader_settings: Backend loader configuration, passed through as-is.
        params: Metadata dict; its keys define the dynamic schema fields.
        namespace: Optional backend namespace.
        custom_fields: Accepted for interface compatibility; currently unused.
        embeddings: Optional embeddings passed through to the backend.

    Cleanup note: the previous version carried dead code (an unused
    ``DynamicSchema`` class, a nested ``create_field`` duplicate, an unused
    ``literal_eval`` import, an unused version lookup, and a no-op
    self-assignment); all removed with no behavioural effect.
    """
    # Robustness: callers may pass params=None (the declared default);
    # treat that as an empty metadata dict instead of crashing.
    params = params or {}
    # Every params key becomes a Str field on the schema.
    schema_instance = self.create_dynamic_schema(params)
    logging.info(f"params : {params}")
    # Was a bare print(); routed through logging for consistency.
    logging.info("Schema fields: %s", list(schema_instance._declared_fields))
    # Validation: load() raises marshmallow.ValidationError on bad input.
    loaded_params = schema_instance.load(params)
    return await self.vector_db.add_memories(
        observation=observation,
        loader_settings=loader_settings,
        params=loaded_params,
        namespace=namespace,
        metadata_schema_class=schema_instance,
        embeddings=embeddings,
    )
# Add other db_type conditions if necessary
async def fetch_memories(
self,
observation: str,
search_type: Optional[str] = None,
params: Optional[str] = None,
namespace: Optional[str] = None,
n_of_observations: Optional[int] = 2,
):
return await self.vector_db.fetch_memories(
observation=observation, search_type= search_type, params=params,
namespace=namespace,
n_of_observations=n_of_observations
)
async def delete_memories(self, namespace:str, params: Optional[str] = None):
return await self.vector_db.delete_memories(namespace,params)