feat: add vector database interface

Boris Arzentar 2024-02-22 14:09:16 +01:00
parent 2fe437c92a
commit a6b9c8a5bf
10 changed files with 92 additions and 21 deletions

View file

@@ -17,7 +17,6 @@ COPY pyproject.toml poetry.lock /app/
# Install the dependencies
RUN poetry config virtualenvs.create false && \
poetry lock --no-update && \
poetry install --no-root --no-dev
RUN apt-get update -q && \
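
For orientation, here is a hedged sketch of how this hunk could sit in the full Dockerfile. The base image, workdir, poetry installation, and apt packages are assumptions; the diff only shows the poetry install step and the start of an apt-get run.

# Hedged sketch -- base image, workdir, and apt packages are assumptions, not shown in the diff.
FROM python:3.11
WORKDIR /app
RUN pip install poetry
COPY pyproject.toml poetry.lock /app/
# Install the dependencies
RUN poetry config virtualenvs.create false && \
    poetry lock --no-update && \
    poetry install --no-root --no-dev
# Placeholder packages; the real apt-get line is truncated in the hunk above.
RUN apt-get update -q && \
    apt-get install -y -q --no-install-recommends curl && \
    rm -rf /var/lib/apt/lists/*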

View file

@@ -1,26 +1,24 @@
""" This module contains the classifiers for the documents. """
import json
import logging
from langchain.prompts import ChatPromptTemplate
import json
from langchain.document_loaders import TextLoader
from langchain.document_loaders import DirectoryLoader
from langchain.chains import create_extraction_chain
from langchain.chat_models import ChatOpenAI
from ..config import Config
from ..database.vectordb.loaders.loaders import _document_loader
config = Config()
config.load()
OPENAI_API_KEY = config.openai_key
async def classify_documents(query: str, document_id: str, content: str):
"""Classify the documents based on the query and content."""
document_context = content
logging.info("This is the document context", document_context)
logging.info("This is the document context %s", document_context)
llm = ChatOpenAI(temperature=0, model=config.model)
prompt_classify = ChatPromptTemplate.from_template(

View file

@@ -1,8 +1,9 @@
""" This module contains the function to classify a summary of a document. """
import json
import logging
from langchain.prompts import ChatPromptTemplate
import json
from langchain.chains import create_extraction_chain
from langchain.chat_models import ChatOpenAI
@@ -12,11 +13,6 @@ config = Config()
config.load()
OPENAI_API_KEY = config.openai_key
async def classify_summary(query, document_summaries):
"""Classify the documents based on the query and content."""
llm = ChatOpenAI(temperature=0, model=config.model)

View file

@@ -1,16 +1,13 @@
""" This module contains the classifiers for the documents. """
import json
import logging
from langchain.prompts import ChatPromptTemplate
import json
from langchain.chains import create_extraction_chain
from langchain.chat_models import ChatOpenAI
from ..config import Config
from ..database.vectordb.loaders.loaders import _document_loader
config = Config()
config.load()
@@ -20,7 +17,7 @@ async def classify_user_input(query, input_type):
""" Classify the user input based on the query and input type."""
llm = ChatOpenAI(temperature=0, model=config.model)
prompt_classify = ChatPromptTemplate.from_template(
"""You are a classifier.
"""You are a classifier.
Determine with a True or False if the following input: {query},
is relevant for the following memory category: {input_type}"""
)
@@ -52,4 +49,4 @@ async def classify_user_input(query, input_type):
logging.info("Relevant summary is %s", arguments_dict.get("DocumentSummary", None))
InputClassification = arguments_dict.get("InputClassification", None)
logging.info("This is the classification %s", InputClassification)
return InputClassification
return InputClassification

View file

@@ -1,19 +1,19 @@
""" This module contains the function to classify the user query. """
from langchain.prompts import ChatPromptTemplate
import json
from langchain.prompts import ChatPromptTemplate
from langchain.chains import create_extraction_chain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import TextLoader
from langchain.document_loaders import DirectoryLoader
from ..config import Config
from ..database.vectordb.loaders.loaders import _document_loader
config = Config()
config.load()
OPENAI_API_KEY = config.openai_key
async def classify_user_query(query, context, document_types):
"""Classify the user query based on the context and document types."""
llm = ChatOpenAI(temperature=0, model=config.model)
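
The classifier hunks above all share one pattern: format a ChatPromptTemplate, run it through ChatOpenAI via create_extraction_chain, and read the result out of the extracted arguments. Below is a hedged reconstruction of that pattern; the extraction schema, the chain invocation, and the default model value are assumptions, since the diff does not show them.

# Hedged reconstruction of the shared classification pattern (not part of the commit).
# The schema dict and the chain invocation are assumptions; only the prompt text and
# the arguments_dict / InputClassification handling appear in the diff context above.
import logging
from langchain.prompts import ChatPromptTemplate
from langchain.chains import create_extraction_chain
from langchain.chat_models import ChatOpenAI

async def classify(query: str, input_type: str, model: str = "gpt-4"):
    # In the real modules the model comes from config.model; "gpt-4" is a placeholder.
    llm = ChatOpenAI(temperature=0, model=model)
    prompt_classify = ChatPromptTemplate.from_template(
        """You are a classifier.
        Determine with a True or False if the following input: {query},
        is relevant for the following memory category: {input_type}"""
    )
    # Hypothetical schema; the real one is not shown in the diff.
    schema = {
        "properties": {"InputClassification": {"type": "boolean"}},
        "required": ["InputClassification"],
    }
    chain = create_extraction_chain(schema, llm, prompt=prompt_classify)
    extracted = await chain.arun(query=query, input_type=input_type)
    arguments_dict = extracted[0] if extracted else {}
    InputClassification = arguments_dict.get("InputClassification", None)
    logging.info("This is the classification %s", InputClassification)
    return InputClassification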

View file

@@ -0,0 +1,13 @@
from vector.vector_db_interface import VectorDBInterface
from qdrant_client import AsyncQdrantClient
class QDrantAdapter(VectorDBInterface):
def __init__(self, qdrant_url, qdrant_api_key):
# Keyword arguments so the API key is not consumed by the wrong positional parameter.
self.qdrant_client = AsyncQdrantClient(url=qdrant_url, api_key=qdrant_api_key)
async def create_collection(
self,
collection_name: str,
collection_config: object
):
return await self.qdrant_client.create_collection(collection_name, collection_config)
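
A minimal usage sketch of the new adapter follows (not part of the commit). The URL, API key, collection name, vector size, and the module path in the import are placeholders.

# Hedged usage sketch; connection details and vector parameters are placeholders.
import asyncio
from qdrant_client.http.models import Distance, VectorParams
from vector.qdrant_adapter import QDrantAdapter  # hypothetical module path

async def main():
    adapter = QDrantAdapter("https://example-cluster.qdrant.io:6333", "QDRANT_API_KEY")
    # collection_config is forwarded to qdrant_client.create_collection, where a
    # VectorParams object is one valid value for the vectors configuration.
    await adapter.create_collection(
        collection_name="documents",
        collection_config=VectorParams(size=1536, distance=Distance.COSINE),
    )

asyncio.run(main())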

View file

@@ -0,0 +1,68 @@
from abc import abstractmethod
from typing import Protocol
class VectorDBInterface(Protocol):
""" Collections """
@abstractmethod
async def create_collection(
self,
collection_name: str,
collection_config: object
): raise NotImplementedError
@abstractmethod
async def update_collection(
self,
collection_name: str,
collection_config: object
): raise NotImplementedError
@abstractmethod
async def delete_collection(
self,
collection_name: str
): raise NotImplementedError
@abstractmethod
async def create_vector_index(
self,
collection_name: str,
vector_index_config: object
): raise NotImplementedError
@abstractmethod
async def create_data_index(
self,
collection_name: str,
vector_index_config: object
): raise NotImplementedError
""" Data points """
@abstractmethod
async def create_data_point(
self,
collection_name: str,
payload: object
): raise NotImplementedError
@abstractmethod
async def get_data_point(
self,
collection_name: str,
data_point_id: str
): raise NotImplementedError
@abstractmethod
async def update_data_point(
self,
collection_name: str,
data_point_id: str,
payload: object
): raise NotImplementedError
@abstractmethod
async def delete_data_point(
self,
collection_name: str,
data_point_id: str
): raise NotImplementedError
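
Because VectorDBInterface is a typing.Protocol, callers can be typed against the interface rather than a concrete store; structural typing means adapters do not even have to inherit from it. A hedged sketch (the helper name and its arguments are hypothetical):

# Hedged sketch (not part of the commit): code that depends on the protocol only.
async def ensure_collection(db: VectorDBInterface, name: str, config: object):
    # Any object implementing the protocol, such as QDrantAdapter, can be passed here.
    return await db.create_collection(name, config)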