feat: add vector database interface

This commit is contained in:
Boris Arzentar 2024-02-22 14:09:16 +01:00
parent 2fe437c92a
commit a6b9c8a5bf
10 changed files with 92 additions and 21 deletions

View file

@ -17,7 +17,6 @@ COPY pyproject.toml poetry.lock /app/
# Install the dependencies # Install the dependencies
RUN poetry config virtualenvs.create false && \ RUN poetry config virtualenvs.create false && \
poetry lock --no-update && \
poetry install --no-root --no-dev poetry install --no-root --no-dev
RUN apt-get update -q && \ RUN apt-get update -q && \

View file

@ -1,26 +1,24 @@
""" This module contains the classifiers for the documents. """ """ This module contains the classifiers for the documents. """
import json
import logging import logging
from langchain.prompts import ChatPromptTemplate from langchain.prompts import ChatPromptTemplate
import json
from langchain.document_loaders import TextLoader from langchain.document_loaders import TextLoader
from langchain.document_loaders import DirectoryLoader from langchain.document_loaders import DirectoryLoader
from langchain.chains import create_extraction_chain from langchain.chains import create_extraction_chain
from langchain.chat_models import ChatOpenAI from langchain.chat_models import ChatOpenAI
from ..config import Config from ..config import Config
from ..database.vectordb.loaders.loaders import _document_loader
config = Config() config = Config()
config.load() config.load()
OPENAI_API_KEY = config.openai_key OPENAI_API_KEY = config.openai_key
async def classify_documents(query: str, document_id: str, content: str): async def classify_documents(query: str, document_id: str, content: str):
"""Classify the documents based on the query and content.""" """Classify the documents based on the query and content."""
document_context = content document_context = content
logging.info("This is the document context", document_context) logging.info("This is the document context %s", document_context)
llm = ChatOpenAI(temperature=0, model=config.model) llm = ChatOpenAI(temperature=0, model=config.model)
prompt_classify = ChatPromptTemplate.from_template( prompt_classify = ChatPromptTemplate.from_template(

View file

@ -1,8 +1,9 @@
""" This module contains the function to classify a summary of a document. """ """ This module contains the function to classify a summary of a document. """
import json
import logging import logging
from langchain.prompts import ChatPromptTemplate from langchain.prompts import ChatPromptTemplate
import json
from langchain.chains import create_extraction_chain from langchain.chains import create_extraction_chain
from langchain.chat_models import ChatOpenAI from langchain.chat_models import ChatOpenAI
@ -12,11 +13,6 @@ config = Config()
config.load() config.load()
OPENAI_API_KEY = config.openai_key OPENAI_API_KEY = config.openai_key
async def classify_summary(query, document_summaries): async def classify_summary(query, document_summaries):
"""Classify the documents based on the query and content.""" """Classify the documents based on the query and content."""
llm = ChatOpenAI(temperature=0, model=config.model) llm = ChatOpenAI(temperature=0, model=config.model)

View file

@ -1,16 +1,13 @@
""" This module contains the classifiers for the documents. """ """ This module contains the classifiers for the documents. """
import json
import logging import logging
from langchain.prompts import ChatPromptTemplate from langchain.prompts import ChatPromptTemplate
import json
from langchain.chains import create_extraction_chain from langchain.chains import create_extraction_chain
from langchain.chat_models import ChatOpenAI from langchain.chat_models import ChatOpenAI
from ..config import Config from ..config import Config
from ..database.vectordb.loaders.loaders import _document_loader
config = Config() config = Config()
config.load() config.load()
@ -20,7 +17,7 @@ async def classify_user_input(query, input_type):
""" Classify the user input based on the query and input type.""" """ Classify the user input based on the query and input type."""
llm = ChatOpenAI(temperature=0, model=config.model) llm = ChatOpenAI(temperature=0, model=config.model)
prompt_classify = ChatPromptTemplate.from_template( prompt_classify = ChatPromptTemplate.from_template(
"""You are a classifier. """You are a classifier.
Determine with a True or False if the following input: {query}, Determine with a True or False if the following input: {query},
is relevant for the following memory category: {input_type}""" is relevant for the following memory category: {input_type}"""
) )
@ -52,4 +49,4 @@ async def classify_user_input(query, input_type):
logging.info("Relevant summary is %s", arguments_dict.get("DocumentSummary", None)) logging.info("Relevant summary is %s", arguments_dict.get("DocumentSummary", None))
InputClassification = arguments_dict.get("InputClassification", None) InputClassification = arguments_dict.get("InputClassification", None)
logging.info("This is the classification %s", InputClassification) logging.info("This is the classification %s", InputClassification)
return InputClassification return InputClassification

View file

@ -1,19 +1,19 @@
""" This module contains the function to classify the user query. """ """ This module contains the function to classify the user query. """
from langchain.prompts import ChatPromptTemplate
import json import json
from langchain.prompts import ChatPromptTemplate
from langchain.chains import create_extraction_chain from langchain.chains import create_extraction_chain
from langchain.chat_models import ChatOpenAI from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import TextLoader from langchain.document_loaders import TextLoader
from langchain.document_loaders import DirectoryLoader from langchain.document_loaders import DirectoryLoader
from ..config import Config from ..config import Config
from ..database.vectordb.loaders.loaders import _document_loader
config = Config() config = Config()
config.load() config.load()
OPENAI_API_KEY = config.openai_key OPENAI_API_KEY = config.openai_key
async def classify_user_query(query, context, document_types): async def classify_user_query(query, context, document_types):
"""Classify the user query based on the context and document types.""" """Classify the user query based on the context and document types."""
llm = ChatOpenAI(temperature=0, model=config.model) llm = ChatOpenAI(temperature=0, model=config.model)

View file

@ -0,0 +1,13 @@
from vector.vector_db_interface import VectorDBInterface
from qdrant_client import AsyncQdrantClient
class QDrantAdapter(VectorDBInterface):
    """ Qdrant-backed implementation of the VectorDBInterface protocol. """

    def __init__(self, qdrant_url, qdrant_api_key):
        # AsyncQdrantClient's second positional parameter is NOT the API key
        # (the signature is `(location, url, port, ..., api_key=...)`), so a
        # positional call silently drops the credential. Pass both by keyword.
        self.qdrant_client = AsyncQdrantClient(
            url = qdrant_url,
            api_key = qdrant_api_key,
        )

    async def create_collection(
        self,
        collection_name: str,
        collection_config: object
    ):
        """ Create (or recreate) a Qdrant collection.

        collection_name: name of the collection to create.
        collection_config: vector configuration forwarded to the client
            (qdrant's `vectors_config` parameter).
        Returns the client's awaited result (True on success).
        """
        # Keyword arguments make explicit that `collection_config` maps onto
        # qdrant's `vectors_config` parameter.
        return await self.qdrant_client.create_collection(
            collection_name = collection_name,
            vectors_config = collection_config,
        )

View file

@ -0,0 +1,68 @@
from abc import abstractmethod
from typing import Protocol
class VectorDBInterface(Protocol):
    """ Structural interface that every vector database adapter implements.

    Adapters provide async collection management, index creation, and
    CRUD operations on individual data points. Every method here is
    abstract and raises NotImplementedError if invoked directly.
    """

    # --- Collection management ---

    @abstractmethod
    async def create_collection(self, collection_name: str, collection_config: object):
        """ Create a new collection from the given configuration. """
        raise NotImplementedError

    @abstractmethod
    async def update_collection(self, collection_name: str, collection_config: object):
        """ Apply a new configuration to an existing collection. """
        raise NotImplementedError

    @abstractmethod
    async def delete_collection(self, collection_name: str):
        """ Remove a collection entirely. """
        raise NotImplementedError

    # --- Index management ---

    @abstractmethod
    async def create_vector_index(self, collection_name: str, vector_index_config: object):
        """ Build a vector (similarity) index on a collection. """
        raise NotImplementedError

    @abstractmethod
    async def create_data_index(self, collection_name: str, vector_index_config: object):
        """ Build a payload/data index on a collection. """
        raise NotImplementedError

    # --- Data points ---

    @abstractmethod
    async def create_data_point(self, collection_name: str, payload: object):
        """ Insert a single data point with the given payload. """
        raise NotImplementedError

    @abstractmethod
    async def get_data_point(self, collection_name: str, data_point_id: str):
        """ Fetch one data point by its identifier. """
        raise NotImplementedError

    @abstractmethod
    async def update_data_point(self, collection_name: str, data_point_id: str, payload: object):
        """ Replace the payload of an existing data point. """
        raise NotImplementedError

    @abstractmethod
    async def delete_data_point(self, collection_name: str, data_point_id: str):
        """ Remove one data point by its identifier. """
        raise NotImplementedError