cognee/level_3/vectordb/vectordb.py
2023-10-12 13:18:26 +02:00

288 lines
11 KiB
Python

# Make sure to install the following packages: dlt, langchain, duckdb, python-dotenv, openai, weaviate-client
import logging
from marshmallow import Schema, fields
from loaders.loaders import _document_loader
# Add the parent directory to sys.path
logging.basicConfig(level=logging.INFO)
from langchain.retrievers import WeaviateHybridSearchRetriever
from weaviate.gql.get import HybridFusion
import tracemalloc
tracemalloc.start()
import os
from langchain.embeddings.openai import OpenAIEmbeddings
from dotenv import load_dotenv
from langchain.schema import Document
import weaviate
load_dotenv()
from typing import Optional

# Fallback identifiers used when a caller does not pass explicit memory ids.
# NOTE(review): "LTM" / "ST" presumably mean long-term / short-term memory — confirm.
LTM_MEMORY_ID_DEFAULT = "00000"
ST_MEMORY_ID_DEFAULT = "0000"
BUFFER_ID_DEFAULT = "0000"


class VectorDB:
    """Base class that stores the identifiers shared by vector-store backends.

    The constructor performs no I/O; it only records its arguments as
    instance attributes of the same name so that subclasses can use them.
    """

    # Read once, when the class body is executed; empty string if the
    # environment variable is not set.
    OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")

    def __init__(
        self,
        user_id: str,
        index_name: str,
        memory_id: str,
        ltm_memory_id: str = LTM_MEMORY_ID_DEFAULT,
        st_memory_id: str = ST_MEMORY_ID_DEFAULT,
        buffer_id: str = BUFFER_ID_DEFAULT,
        namespace: Optional[str] = None,
        embeddings=None,
    ):
        """Record the identifiers and optional settings on the instance.

        Args:
            user_id: Identifier of the user the stored data belongs to.
            index_name: Name of the index the backend should work with.
            memory_id: Identifier of the memory this instance represents.
            ltm_memory_id: Defaults to ``LTM_MEMORY_ID_DEFAULT``.
            st_memory_id: Defaults to ``ST_MEMORY_ID_DEFAULT``.
            buffer_id: Defaults to ``BUFFER_ID_DEFAULT``.
            namespace: Optional namespace; ``None`` when not provided.
            embeddings: Optional embeddings object; stored as given.
        """
        self.user_id = user_id
        self.index_name = index_name
        self.namespace = namespace
        self.memory_id = memory_id
        self.ltm_memory_id = ltm_memory_id
        self.st_memory_id = st_memory_id
        self.buffer_id = buffer_id
        self.embeddings = embeddings
class PineconeVectorDB(VectorDB):
    """Pinecone flavour of ``VectorDB``.

    Construction forwards every argument to ``VectorDB.__init__`` and then
    calls ``init_pinecone`` with the stored index name.
    """

    def __init__(self, *args, **kwargs):
        """Store the shared identifiers, then run the Pinecone setup hook."""
        super().__init__(*args, **kwargs)
        target_index = self.index_name
        self.init_pinecone(target_index)

    def init_pinecone(self, index_name):
        """Set up Pinecone for ``index_name``; currently does nothing."""
        # Pinecone initialization logic
        return None
import langchain.embeddings
class WeaviateVectorDB(VectorDB):
    """Weaviate-backed implementation of ``VectorDB``.

    Connection settings are read from the environment on every client
    creation: ``WEAVIATE_URL``, ``WEAVIATE_API_KEY`` and ``OPENAI_API_KEY``.
    """

    # Search modes accepted by ``fetch_memories``.
    _SEARCH_TYPES = ("text", "hybrid", "bm25", "generate", "generate_grouped")

    # Properties copied from ``params`` by ``update_memories``; a missing or
    # falsy value is written as an empty string.
    _UPDATABLE_METADATA_KEYS = (
        "version",
        "agreement_id",
        "privacy_policy",
        "terms_of_service",
        "format",
        "schema_version",
        "checksum",
        "owner",
        "license",
        "validity_start",
        "validity_end",
    )

    def __init__(self, *args, **kwargs):
        """Store the shared identifiers and build a retriever once.

        The retriever returned by ``init_weaviate`` is not kept. The call is
        retained because the retriever is built with
        ``create_schema_if_missing=True``.
        NOTE(review): presumably this is what creates the class for
        ``self.namespace`` in Weaviate — confirm before removing the call.
        """
        super().__init__(*args, **kwargs)
        self.init_weaviate(embeddings=self.embeddings, namespace=self.namespace)

    def init_weaviate(self, embeddings=None, namespace: str = None):
        """Return a ``WeaviateHybridSearchRetriever`` bound to ``namespace``.

        Args:
            embeddings: Object passed to the retriever as ``embedding``. When
                ``None``, an ``OpenAIEmbeddings()`` instance is created.
            namespace: Used as the retriever's ``index_name``.

        Returns:
            The newly constructed retriever.
        """
        if embeddings is None:
            # Built lazily: a default of ``OpenAIEmbeddings()`` written in the
            # signature is evaluated once when the class body runs, so merely
            # importing this module would construct it.
            embeddings = OpenAIEmbeddings()
        client = self.init_weaviate_client(namespace)
        return WeaviateHybridSearchRetriever(
            client=client,
            index_name=namespace,
            text_key="text",
            attributes=[],
            embedding=embeddings,
            create_schema_if_missing=True,
        )

    def init_weaviate_client(self, namespace: str):
        """Create a Weaviate client from environment configuration.

        Args:
            namespace: Accepted for interface compatibility; it is not used
                when building the client.

        Returns:
            A ``weaviate.Client`` authenticated with ``WEAVIATE_API_KEY`` and
            carrying the OpenAI key in the ``X-OpenAI-Api-Key`` header.
        """
        auth_config = weaviate.auth.AuthApiKey(
            api_key=os.environ.get("WEAVIATE_API_KEY")
        )
        return weaviate.Client(
            url=os.environ.get("WEAVIATE_URL"),
            auth_client_secret=auth_config,
            additional_headers={"X-OpenAI-Api-Key": os.environ.get("OPENAI_API_KEY")},
        )

    def _stuct(self, observation, params, metadata_schema_class=None):
        """Validate one document and return it wrapped in a single-item list.

        The method name keeps its original spelling so existing callers keep
        working.

        Args:
            observation: Text stored as ``page_content``.
            params: Metadata stored under ``metadata``.
            metadata_schema_class: Optional marshmallow ``Schema`` subclass
                used to validate ``params``. When ``None``, ``params`` is
                validated as a plain dict instead.

        Returns:
            ``[loaded]`` where ``loaded`` is the result of the schema's
            ``load`` and has the keys ``metadata`` and ``page_content``.

        Raises:
            marshmallow.ValidationError: If the data does not satisfy the
                schema (raised by ``Schema.load``).
        """
        if metadata_schema_class is None:
            # ``fields.Nested`` needs a schema to delegate to; without one,
            # accept any mapping so the default arguments are usable.
            metadata_field = fields.Dict(required=True)
        else:
            metadata_field = fields.Nested(metadata_schema_class, required=True)

        class DynamicDocumentSchema(Schema):
            metadata = metadata_field
            page_content = fields.Str(required=True)

        document_data = {
            "metadata": params,
            "page_content": observation,
        }
        loaded_document = DynamicDocumentSchema().load(document_data)
        return [loaded_document]

    async def add_memories(self, observation, loader_settings=None, params=None, namespace=None,
                           metadata_schema_class=None, embeddings='hybrid'):
        """Validate and store one or more documents in Weaviate.

        Args:
            observation: The text to store, or the input handed to
                ``_document_loader`` when ``loader_settings`` is given.
            loader_settings: When truthy, ``_document_loader(observation,
                loader_settings)`` is called and every returned document's
                ``page_content`` is stored; otherwise ``observation`` itself
                is stored as a single document.
            params: Metadata attached to every stored document.
            namespace: Target class; defaults to ``self.namespace``.
            metadata_schema_class: Forwarded to ``_stuct`` for validation.
            embeddings: Forwarded to ``init_weaviate``.
        """
        if namespace is None:
            namespace = self.namespace
        retriever = self.init_weaviate(embeddings=embeddings, namespace=namespace)

        if loader_settings:
            # Assuming _document_loader returns an iterable of documents
            documents = _document_loader(observation, loader_settings)
            logging.info("here are the docs %s", str(documents))
            # Generator keeps validate-then-store interleaved per document.
            contents = (doc.page_content for doc in documents)
        else:
            contents = [observation]

        for content in contents:
            validated = self._stuct(content, params, metadata_schema_class)[0]
            logging.info("here is the doc to load %s", validated)
            retriever.add_documents([
                Document(
                    metadata=validated["metadata"],
                    page_content=validated["page_content"],
                )
            ])

    async def fetch_memories(self, observation: str, namespace: str = None, search_type: str = 'hybrid', **kwargs):
        """
        Fetch documents from weaviate.

        Parameters:
        - observation (str): User query.
        - namespace (str, optional): Class to query. Defaults to self.namespace.
        - search_type (str, optional): Type of search ('text', 'hybrid', 'bm25',
          'generate', 'generate_grouped'). Defaults to 'hybrid'.
        - **kwargs: Recognised keys are 'n_of_observations' (autocut value,
          default 2) and 'generate_prompt' (prompt for the 'generate' and
          'generate_grouped' modes, default "").

        Returns:
            The result of the query's ``.do()`` call, or an empty list when
            the search type is invalid or the schema lookup / query raises.

        Example:
            fetch_memories(observation="some query", search_type='text', n_of_observations=3)
        """
        if search_type is None:
            search_type = 'hybrid'
        if search_type not in self._SEARCH_TYPES:
            # Reject unknown modes before any network round-trip.
            logging.error(f"Invalid search_type: {search_type}")
            return []
        if not namespace:
            namespace = self.namespace

        client = self.init_weaviate_client(self.namespace)
        n_of_observations = kwargs.get('n_of_observations', 2)
        generate_prompt = kwargs.get('generate_prompt', "")
        near_text = {"concepts": [observation]}

        # Restrict results to objects belonging to this user.
        # NOTE(review): "Like" is kept from the original; "Equal" may be what
        # is intended for an exact user match — confirm.
        params_user_id = {
            "path": ["user_id"],
            "operator": "Like",
            "valueText": self.user_id,
        }

        try:
            # Request every property declared for the class in the schema.
            property_names = [
                prop["name"]
                for class_obj in client.schema.get()["classes"]
                if class_obj["class"] == namespace
                for prop in class_obj["properties"]
            ]
            base_query = client.query.get(
                namespace, property_names
            ).with_additional(
                ["id", "creationTimeUnix", "lastUpdateTimeUnix", "score", 'distance']
            ).with_where(params_user_id).with_limit(10)

            if search_type == 'text':
                query = base_query.with_near_text(near_text)
            elif search_type == 'hybrid':
                query = base_query.with_hybrid(
                    query=observation, fusion_type=HybridFusion.RELATIVE_SCORE
                )
            elif search_type == 'bm25':
                query = base_query.with_bm25(query=observation)
            elif search_type == 'generate':
                query = base_query.with_generate(
                    single_prompt=generate_prompt
                ).with_near_text(near_text)
            else:  # 'generate_grouped'
                query = base_query.with_generate(
                    grouped_task=generate_prompt
                ).with_near_text(near_text)

            return query.with_autocut(n_of_observations).do()
        except Exception as e:
            logging.error(f"Error executing query: {str(e)}")
            return []

    async def delete_memories(self, namespace: str, params: dict = None):
        """Batch-delete objects from a Weaviate class.

        Args:
            namespace: Class to delete from; ``None`` falls back to
                ``self.namespace``.
            params: When truthy, objects whose ``id`` equals
                ``params.get("id")`` are deleted. Otherwise objects whose
                ``version`` equals ``"1.0"`` are deleted.

        Returns:
            Whatever ``client.batch.delete_objects`` returns.
        """
        if namespace is None:
            namespace = self.namespace
        client = self.init_weaviate_client(namespace)

        if params:
            where_filter = {
                "path": ["id"],
                "operator": "Equal",
                "valueText": params.get("id", None),
            }
        else:
            # Delete all objects
            # NOTE(review): this filter matches on version only; it is not
            # restricted to self.user_id — confirm that is intended.
            logging.info("HERE IS THE USER ID %s", self.user_id)
            where_filter = {
                "path": ["version"],
                "operator": "Equal",
                "valueText": "1.0",
            }

        # Both branches act on the resolved namespace, so an explicitly
        # passed namespace is honoured for id-based deletes as well.
        return client.batch.delete_objects(
            class_name=namespace,
            # Same `where` filter as in the GraphQL API
            where=where_filter,
        )

    def update_memories(self, observation, namespace: str, params: dict = None):
        """Update the metadata properties of one stored object.

        Args:
            observation: Currently unused; the object's text is not updated.
            namespace: Class holding the object; a falsy value falls back to
                ``self.namespace``.
            params: Mapping with the object's ``id`` plus any of the keys in
                ``_UPDATABLE_METADATA_KEYS``. ``None`` is treated as empty.

        Returns:
            None.
        """
        params = params or {}
        class_name = namespace or self.namespace
        client = self.init_weaviate_client(class_name)

        data_object = {
            "user_id": str(self.user_id),
            "memory_id": str(self.memory_id),
            "ltm_memory_id": str(self.ltm_memory_id),
            "st_memory_id": str(self.st_memory_id),
            "buffer_id": str(self.buffer_id),
        }
        for key in self._UPDATABLE_METADATA_KEYS:
            data_object[key] = params.get(key, None) or ""

        client.data_object.update(
            data_object=data_object,
            # Target the caller's class rather than a fixed class name.
            class_name=class_name,
            uuid=params.get("id", None),
            consistency_level=weaviate.data.replication.ConsistencyLevel.ALL,  # default QUORUM
        )
        return