Add the ability to create a new vector memory and store text data points using OpenAI embeddings.
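A minimal usage sketch of the Weaviate-backed memory (illustrative only: it assumes the VectorDBInterface constructor accepts user_id, namespace, and embeddings keyword arguments, and that WEAVIATE_URL, WEAVIATE_API_KEY, and OPENAI_API_KEY are set in the environment):

    import asyncio
    from langchain.embeddings.openai import OpenAIEmbeddings

    async def main():
        db = WeaviateVectorDB(
            user_id="user-123",            # placeholder user identifier
            namespace="PersonalMemories",  # Weaviate class used as the index
            embeddings=OpenAIEmbeddings(),
        )
        # Store a text data point; it is chunked and embedded before insertion.
        await db.add_memories(observation="The user prefers morning meetings.")
        # Retrieve it with a hybrid (BM25 + vector) search.
        results = await db.fetch_memories(
            observation="When does the user like to meet?", search_type="hybrid"
        )
        print(results)

    asyncio.run(main())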
import os
import logging

import weaviate
from weaviate.gql.get import HybridFusion

from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.retrievers import WeaviateHybridSearchRetriever, ParentDocumentRetriever
from langchain.schema import Document
from marshmallow import Schema, fields

from databases.vector.vector_db_interface import VectorDBInterface

# from langchain.text_splitter import RecursiveCharacterTextSplitter

from cognitive_architecture.database.vectordb.loaders.loaders import _document_loader

class WeaviateVectorDB(VectorDBInterface):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.init_weaviate(embeddings=self.embeddings, namespace=self.namespace)

    def init_weaviate(
        self,
        embeddings=OpenAIEmbeddings(openai_api_key=os.getenv("OPENAI_API_KEY", "")),
        namespace=None,
        retriever_type="",
    ):
        # Weaviate initialization logic
        auth_config = weaviate.auth.AuthApiKey(
            api_key=os.environ.get("WEAVIATE_API_KEY")
        )
        client = weaviate.Client(
            url=os.environ.get("WEAVIATE_URL"),
            auth_client_secret=auth_config,
            additional_headers={"X-OpenAI-Api-Key": os.environ.get("OPENAI_API_KEY")},
        )

        if retriever_type in ("single_document_context", "multi_document_context"):
            # Both retriever types currently share the same hybrid-search configuration.
            retriever = WeaviateHybridSearchRetriever(
                client=client,
                index_name=namespace,
                text_key="text",
                attributes=[],
                embedding=embeddings,
                create_schema_if_missing=True,
            )
            return retriever
        else:
            return client
            # child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
            # store = InMemoryStore()
            # retriever = ParentDocumentRetriever(
            #     vectorstore=vectorstore,
            #     docstore=store,
            #     child_splitter=child_splitter,
            # )

    def create_document_structure(observation, params, metadata_schema_class=None):
        """
        Create and validate a document structure with optional custom fields.

        :param observation: Content of the document.
        :param params: Metadata information.
        :param metadata_schema_class: Custom metadata schema class (optional).
        :return: A list containing the validated document data.
        """
        document_data = {"metadata": params, "page_content": observation}

        def get_document_schema():
            class DynamicDocumentSchema(Schema):
                metadata = fields.Nested(metadata_schema_class, required=True)
                page_content = fields.Str(required=True)

            return DynamicDocumentSchema

        # Validate and deserialize the document data against the dynamic schema.
        CurrentDocumentSchema = get_document_schema()
        loaded_document = CurrentDocumentSchema().load(document_data)
        return [loaded_document]

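    # Illustrative sketch (hypothetical schema, not defined in this module) of a
    # metadata class that could be passed as `metadata_schema_class`; the field names
    # mirror the params keys used elsewhere in this file (user_id, source, chunk_order):
    #
    #   class ExampleMetadataSchema(Schema):
    #       user_id = fields.Str(required=True)
    #       source = fields.Str(required=False)
    #       chunk_order = fields.Int(required=False)
    #
    #   docs = create_document_structure(
    #       observation="raw text to store",
    #       params={"user_id": "user-123", "source": "User loaded", "chunk_order": 1},
    #       metadata_schema_class=ExampleMetadataSchema,
    #   )
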
    def _stuct(self, observation, params, metadata_schema_class=None):
        """Utility function to create the document structure with optional custom fields."""
        # Construct document data
        document_data = {"metadata": params, "page_content": observation}

        def get_document_schema():
            class DynamicDocumentSchema(Schema):
                metadata = fields.Nested(metadata_schema_class, required=True)
                page_content = fields.Str(required=True)

            return DynamicDocumentSchema

        # Validate and deserialize the document data against the dynamic schema.
        CurrentDocumentSchema = get_document_schema()
        loaded_document = CurrentDocumentSchema().load(document_data)
        return [loaded_document]

    async def add_memories(
        self,
        observation,
        loader_settings=None,
        params=None,
        namespace=None,
        metadata_schema_class=None,
        embeddings="hybrid",
    ):
        """Add text data points to Weaviate, chunking and embedding them first."""
        if namespace is None:
            namespace = self.namespace
        if params is None:
            params = {}
        params["user_id"] = self.user_id
        logging.info("User id is %s", self.user_id)
        retriever = self.init_weaviate(
            embeddings=OpenAIEmbeddings(),
            namespace=namespace,
            retriever_type="single_document_context",
        )
        if loader_settings:
            # _document_loader is expected to return a list of document lists
            documents = await _document_loader(observation, loader_settings)
            logging.info("here are the docs %s", str(documents))
            chunk_count = 0
            for doc_list in documents:
                for doc in doc_list:
                    chunk_count += 1
                    params["chunk_count"] = doc.metadata.get("chunk_count", "None")
                    logging.info(
                        "Loading document with provided loader settings %s", str(doc)
                    )
                    params["source"] = doc.metadata.get("source", "None")
                    logging.info("Params are %s", str(params))
                    retriever.add_documents(
                        [Document(metadata=params, page_content=doc.page_content)]
                    )
        else:
            chunk_count = 0
            from cognitive_architecture.database.vectordb.chunkers.chunkers import (
                chunk_data,
            )

            documents = [
                chunk_data(
                    chunk_strategy="VANILLA",
                    source_data=observation,
                    chunk_size=300,
                    chunk_overlap=20,
                )
            ]
            for doc in documents[0]:
                chunk_count += 1
                params["chunk_order"] = chunk_count
                params["source"] = "User loaded"
                logging.info(
                    "Loading document with default loader settings %s", str(doc)
                )
                logging.info("Params are %s", str(params))
                retriever.add_documents(
                    [Document(metadata=params, page_content=doc.page_content)]
                )

    async def fetch_memories(
        self,
        observation: str,
        namespace: str = None,
        search_type: str = "hybrid",
        params=None,
        **kwargs,
    ):
        """
        Fetch documents from Weaviate.

        Parameters:
        - observation (str): User query.
        - namespace (str, optional): Type of memory accessed.
        - search_type (str, optional): Type of search ('text', 'hybrid', 'bm25', 'summary',
          'summary_filter_by_object_name', 'generate', 'generate_grouped'). Defaults to 'hybrid'.
        - params (optional): Used by 'summary_filter_by_object_name' as the doc_id to filter on.
        - **kwargs: Additional parameters for flexibility.

        Returns:
        List of documents matching the query or an empty list in case of error.

        Example:
        fetch_memories(observation="some query", search_type='text', additional_param='value')
        """
        client = self.init_weaviate(namespace=self.namespace)
        if search_type is None:
            search_type = "hybrid"

        if not namespace:
            namespace = self.namespace

        logging.info("Query on namespace %s", namespace)

        params_user_id = {
            "path": ["user_id"],
            "operator": "Like",
            "valueText": self.user_id,
        }

        def list_objects_of_class(class_name, schema):
            return [
                prop["name"]
                for class_obj in schema["classes"]
                if class_obj["class"] == class_name
                for prop in class_obj["properties"]
            ]

        base_query = (
            client.query.get(
                namespace, list(list_objects_of_class(namespace, client.schema.get()))
            )
            .with_additional(
                ["id", "creationTimeUnix", "lastUpdateTimeUnix", "score", "distance"]
            )
            .with_where(params_user_id)
            .with_limit(10)
        )

        n_of_observations = kwargs.get("n_of_observations", 2)

        # try:
        if search_type == "text":
            query_output = (
                base_query.with_near_text({"concepts": [observation]})
                .with_autocut(n_of_observations)
                .do()
            )
        elif search_type == "hybrid":
            query_output = (
                base_query.with_hybrid(
                    query=observation, fusion_type=HybridFusion.RELATIVE_SCORE
                )
                .with_autocut(n_of_observations)
                .do()
            )
        elif search_type == "bm25":
            query_output = (
                base_query.with_bm25(query=observation)
                .with_autocut(n_of_observations)
                .do()
            )
        elif search_type == "summary":
            filter_object = {
                "operator": "And",
                "operands": [
                    {
                        "path": ["user_id"],
                        "operator": "Equal",
                        "valueText": self.user_id,
                    },
                    {
                        "path": ["chunk_order"],
                        "operator": "LessThan",
                        "valueNumber": 30,
                    },
                ],
            }
            base_query = (
                client.query.get(
                    namespace,
                    list(list_objects_of_class(namespace, client.schema.get())),
                )
                .with_additional(
                    [
                        "id",
                        "creationTimeUnix",
                        "lastUpdateTimeUnix",
                        "score",
                        "distance",
                    ]
                )
                .with_where(filter_object)
                .with_limit(30)
            )
            query_output = (
                base_query
                # .with_hybrid(query=observation, fusion_type=HybridFusion.RELATIVE_SCORE)
                .do()
            )

        elif search_type == "summary_filter_by_object_name":
            filter_object = {
                "operator": "And",
                "operands": [
                    {
                        "path": ["user_id"],
                        "operator": "Equal",
                        "valueText": self.user_id,
                    },
                    {
                        "path": ["doc_id"],
                        "operator": "Equal",
                        "valueText": params,
                    },
                ],
            }
            base_query = (
                client.query.get(
                    namespace,
                    list(list_objects_of_class(namespace, client.schema.get())),
                )
                .with_additional(
                    [
                        "id",
                        "creationTimeUnix",
                        "lastUpdateTimeUnix",
                        "score",
                        "distance",
                    ]
                )
                .with_where(filter_object)
                .with_limit(30)
                .with_hybrid(query=observation, fusion_type=HybridFusion.RELATIVE_SCORE)
            )
            query_output = base_query.do()

            return query_output
        elif search_type == "generate":
            generate_prompt = kwargs.get("generate_prompt", "")
            query_output = (
                base_query.with_generate(single_prompt=observation)
                .with_near_text({"concepts": [observation]})
                .with_autocut(n_of_observations)
                .do()
            )
        elif search_type == "generate_grouped":
            generate_prompt = kwargs.get("generate_prompt", "")
            query_output = (
                base_query.with_generate(grouped_task=observation)
                .with_near_text({"concepts": [observation]})
                .with_autocut(n_of_observations)
                .do()
            )
        else:
            logging.error(f"Invalid search_type: {search_type}")
            return []
        # except Exception as e:
        #     logging.error(f"Error executing query: {str(e)}")
        #     return []

        return query_output

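    # Illustrative sketch, assuming the standard Weaviate v3 GraphQL response shape
    # (exact keys depend on the schema; "PersonalMemories" is a placeholder class name):
    #
    #   query_output = await db.fetch_memories(observation="some query")
    #   # query_output looks roughly like:
    #   # {"data": {"Get": {"PersonalMemories": [
    #   #     {"text": "...", "user_id": "...", "_additional": {"id": "...", "score": "..."}}
    #   # ]}}}
    #   docs = query_output["data"]["Get"]["PersonalMemories"]
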
    async def delete_memories(self, namespace: str, params: dict = None):
        if namespace is None:
            namespace = self.namespace
        client = self.init_weaviate(namespace=self.namespace)
        if params:
            where_filter = {
                "path": ["id"],
                "operator": "Equal",
                "valueText": params.get("id", None),
            }
            return client.batch.delete_objects(
                class_name=self.namespace,
                # Same `where` filter as in the GraphQL API
                where=where_filter,
            )
        else:
            # Delete all objects tagged with version "1.0"
            return client.batch.delete_objects(
                class_name=namespace,
                where={
                    "path": ["version"],
                    "operator": "Equal",
                    "valueText": "1.0",
                },
            )

    async def count_memories(self, namespace: str = None, params: dict = None):
        """
        Count memories in a Weaviate database.

        Args:
            namespace (str, optional): The Weaviate namespace to count memories in. If not provided, uses the default namespace.

        Returns:
            The meta-count aggregate response for the specified namespace, or 0 on error.
        """
        if namespace is None:
            namespace = self.namespace

        client = self.init_weaviate(namespace=namespace)

        try:
            object_count = client.query.aggregate(namespace).with_meta_count().do()
            return object_count
        except Exception as e:
            logging.error(f"Error counting memories: {str(e)}")
            # Handle the error or log it
            return 0

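    # Illustrative sketch, assuming the standard Weaviate v3 aggregate response shape
    # ("PersonalMemories" is a placeholder class name): the integer count can typically
    # be unpacked from the value returned above like so:
    #
    #   response = await db.count_memories(namespace="PersonalMemories")
    #   count = response["data"]["Aggregate"]["PersonalMemories"][0]["meta"]["count"]
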
    def update_memories(self, observation, namespace: str, params: dict = None):
        client = self.init_weaviate(namespace=self.namespace)
        params = params or {}

        client.data_object.update(
            data_object={
                # "text": observation,
                "user_id": str(self.user_id),
                "version": params.get("version", None) or "",
                "agreement_id": params.get("agreement_id", None) or "",
                "privacy_policy": params.get("privacy_policy", None) or "",
                "terms_of_service": params.get("terms_of_service", None) or "",
                "format": params.get("format", None) or "",
                "schema_version": params.get("schema_version", None) or "",
                "checksum": params.get("checksum", None) or "",
                "owner": params.get("owner", None) or "",
                "license": params.get("license", None) or "",
                "validity_start": params.get("validity_start", None) or "",
                "validity_end": params.get("validity_end", None) or ""
                # **source_metadata,
            },
            class_name=namespace,
            uuid=params.get("id", None),
            consistency_level=weaviate.data.replication.ConsistencyLevel.ALL,  # default QUORUM
        )
        return