feat: add vector database interface
This commit is contained in:
parent
2fe437c92a
commit
a6b9c8a5bf
10 changed files with 92 additions and 21 deletions
|
|
@ -17,7 +17,6 @@ COPY pyproject.toml poetry.lock /app/
|
|||
|
||||
# Install the dependencies
|
||||
RUN poetry config virtualenvs.create false && \
|
||||
poetry lock --no-update && \
|
||||
poetry install --no-root --no-dev
|
||||
|
||||
RUN apt-get update -q && \
|
||||
|
|
|
|||
|
|
@ -1,26 +1,24 @@
|
|||
""" This module contains the classifiers for the documents. """
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
import json
|
||||
from langchain.document_loaders import TextLoader
|
||||
from langchain.document_loaders import DirectoryLoader
|
||||
from langchain.chains import create_extraction_chain
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
|
||||
from ..config import Config
|
||||
from ..database.vectordb.loaders.loaders import _document_loader
|
||||
|
||||
config = Config()
|
||||
config.load()
|
||||
OPENAI_API_KEY = config.openai_key
|
||||
|
||||
|
||||
|
||||
async def classify_documents(query: str, document_id: str, content: str):
|
||||
"""Classify the documents based on the query and content."""
|
||||
document_context = content
|
||||
logging.info("This is the document context", document_context)
|
||||
logging.info("This is the document context %s", document_context)
|
||||
|
||||
llm = ChatOpenAI(temperature=0, model=config.model)
|
||||
prompt_classify = ChatPromptTemplate.from_template(
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
""" This module contains the function to classify a summary of a document. """
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
import json
|
||||
from langchain.chains import create_extraction_chain
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
|
||||
|
|
@ -12,11 +13,6 @@ config = Config()
|
|||
config.load()
|
||||
OPENAI_API_KEY = config.openai_key
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
async def classify_summary(query, document_summaries):
|
||||
"""Classify the documents based on the query and content."""
|
||||
llm = ChatOpenAI(temperature=0, model=config.model)
|
||||
|
|
|
|||
|
|
@ -1,16 +1,13 @@
|
|||
""" This module contains the classifiers for the documents. """
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
import json
|
||||
|
||||
|
||||
from langchain.chains import create_extraction_chain
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
|
||||
from ..config import Config
|
||||
from ..database.vectordb.loaders.loaders import _document_loader
|
||||
|
||||
config = Config()
|
||||
config.load()
|
||||
|
|
@ -20,7 +17,7 @@ async def classify_user_input(query, input_type):
|
|||
""" Classify the user input based on the query and input type."""
|
||||
llm = ChatOpenAI(temperature=0, model=config.model)
|
||||
prompt_classify = ChatPromptTemplate.from_template(
|
||||
"""You are a classifier.
|
||||
"""You are a classifier.
|
||||
Determine with a True or False if the following input: {query},
|
||||
is relevant for the following memory category: {input_type}"""
|
||||
)
|
||||
|
|
@ -52,4 +49,4 @@ async def classify_user_input(query, input_type):
|
|||
logging.info("Relevant summary is %s", arguments_dict.get("DocumentSummary", None))
|
||||
InputClassification = arguments_dict.get("InputClassification", None)
|
||||
logging.info("This is the classification %s", InputClassification)
|
||||
return InputClassification
|
||||
return InputClassification
|
||||
|
|
|
|||
|
|
@ -1,19 +1,19 @@
|
|||
""" This module contains the function to classify the user query. """
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
|
||||
import json
|
||||
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
from langchain.chains import create_extraction_chain
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.document_loaders import TextLoader
|
||||
from langchain.document_loaders import DirectoryLoader
|
||||
|
||||
from ..config import Config
|
||||
from ..database.vectordb.loaders.loaders import _document_loader
|
||||
|
||||
config = Config()
|
||||
config.load()
|
||||
OPENAI_API_KEY = config.openai_key
|
||||
|
||||
|
||||
async def classify_user_query(query, context, document_types):
|
||||
"""Classify the user query based on the context and document types."""
|
||||
llm = ChatOpenAI(temperature=0, model=config.model)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,13 @@
|
|||
from vector.vector_db_interface import VectorDBInterface
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
|
||||
class QDrantAdapter(VectorDBInterface):
    """Vector database adapter that forwards operations to a Qdrant server.

    The adapter owns a single ``AsyncQdrantClient`` and exposes it through
    the ``VectorDBInterface`` method names.

    NOTE(review): only ``create_collection`` is implemented here, while
    ``VectorDBInterface`` declares further abstract methods. Explicitly
    subclassing a Protocol with unimplemented abstract methods presumably
    prevents instantiation -- confirm and implement the rest before use.
    """

    def __init__(self, qdrant_url, qdrant_api_key):
        """Create the underlying asynchronous Qdrant client.

        Args:
            qdrant_url: Location of the Qdrant server, passed as the
                client's first positional argument exactly as before.
            qdrant_api_key: API key for the Qdrant server.
        """
        # The API key must be passed by keyword: the client's second
        # positional parameter is ``url``, not ``api_key``, so passing the
        # key positionally would silently drop authentication.
        self.qdrant_client = AsyncQdrantClient(
            qdrant_url,
            api_key=qdrant_api_key,
        )

    async def create_collection(
        self,
        collection_name: str,
        collection_config: object,
    ):
        """Create a collection through the Qdrant client.

        Args:
            collection_name: Name of the collection to create.
            collection_config: Forwarded unchanged as the client's second
                positional argument.

        Returns:
            Whatever the client's ``create_collection`` coroutine returns.
        """
        return await self.qdrant_client.create_collection(collection_name, collection_config)
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
from abc import abstractmethod
|
||||
from typing import Protocol
|
||||
|
||||
class VectorDBInterface(Protocol):
    """Asynchronous contract that vector database adapters implement.

    Every method is abstract and raises ``NotImplementedError`` when the
    base implementation is awaited. The methods fall into two groups:
    collection management and data point management.
    """

    # ---- Collections ----

    @abstractmethod
    async def create_collection(
        self,
        collection_name: str,
        collection_config: object,
    ):
        """Create the collection ``collection_name`` from ``collection_config``."""
        raise NotImplementedError

    @abstractmethod
    async def update_collection(
        self,
        collection_name: str,
        collection_config: object,
    ):
        """Update the collection ``collection_name`` with ``collection_config``."""
        raise NotImplementedError

    @abstractmethod
    async def delete_collection(
        self,
        collection_name: str,
    ):
        """Delete the collection ``collection_name``."""
        raise NotImplementedError

    @abstractmethod
    async def create_vector_index(
        self,
        collection_name: str,
        vector_index_config: object,
    ):
        """Create a vector index on ``collection_name``."""
        raise NotImplementedError

    @abstractmethod
    async def create_data_index(
        self,
        collection_name: str,
        vector_index_config: object,
    ):
        """Create a data index on ``collection_name``.

        The configuration parameter is named ``vector_index_config``, the
        same as in ``create_vector_index``.
        """
        raise NotImplementedError

    # ---- Data points ----

    @abstractmethod
    async def create_data_point(
        self,
        collection_name: str,
        payload: object,
    ):
        """Create a data point in ``collection_name`` from ``payload``."""
        raise NotImplementedError

    @abstractmethod
    async def get_data_point(
        self,
        collection_name: str,
        data_point_id: str,
    ):
        """Fetch the data point ``data_point_id`` from ``collection_name``."""
        raise NotImplementedError

    @abstractmethod
    async def update_data_point(
        self,
        collection_name: str,
        data_point_id: str,
        payload: object,
    ):
        """Update the data point ``data_point_id`` with ``payload``."""
        raise NotImplementedError

    @abstractmethod
    async def delete_data_point(
        self,
        collection_name: str,
        data_point_id: str,
    ):
        """Delete the data point ``data_point_id`` from ``collection_name``."""
        raise NotImplementedError
|
||||
Loading…
Add table
Reference in a new issue