feat: add vector database interface
This commit is contained in:
parent
2fe437c92a
commit
a6b9c8a5bf
10 changed files with 92 additions and 21 deletions
|
|
@ -17,7 +17,6 @@ COPY pyproject.toml poetry.lock /app/
|
||||||
|
|
||||||
# Install the dependencies
|
# Install the dependencies
|
||||||
RUN poetry config virtualenvs.create false && \
|
RUN poetry config virtualenvs.create false && \
|
||||||
poetry lock --no-update && \
|
|
||||||
poetry install --no-root --no-dev
|
poetry install --no-root --no-dev
|
||||||
|
|
||||||
RUN apt-get update -q && \
|
RUN apt-get update -q && \
|
||||||
|
|
|
||||||
|
|
@ -1,26 +1,24 @@
|
||||||
""" This module contains the classifiers for the documents. """
|
""" This module contains the classifiers for the documents. """
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from langchain.prompts import ChatPromptTemplate
|
from langchain.prompts import ChatPromptTemplate
|
||||||
import json
|
|
||||||
from langchain.document_loaders import TextLoader
|
from langchain.document_loaders import TextLoader
|
||||||
from langchain.document_loaders import DirectoryLoader
|
from langchain.document_loaders import DirectoryLoader
|
||||||
from langchain.chains import create_extraction_chain
|
from langchain.chains import create_extraction_chain
|
||||||
from langchain.chat_models import ChatOpenAI
|
from langchain.chat_models import ChatOpenAI
|
||||||
|
|
||||||
from ..config import Config
|
from ..config import Config
|
||||||
from ..database.vectordb.loaders.loaders import _document_loader
|
|
||||||
|
|
||||||
config = Config()
|
config = Config()
|
||||||
config.load()
|
config.load()
|
||||||
OPENAI_API_KEY = config.openai_key
|
OPENAI_API_KEY = config.openai_key
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def classify_documents(query: str, document_id: str, content: str):
|
async def classify_documents(query: str, document_id: str, content: str):
|
||||||
"""Classify the documents based on the query and content."""
|
"""Classify the documents based on the query and content."""
|
||||||
document_context = content
|
document_context = content
|
||||||
logging.info("This is the document context", document_context)
|
logging.info("This is the document context %s", document_context)
|
||||||
|
|
||||||
llm = ChatOpenAI(temperature=0, model=config.model)
|
llm = ChatOpenAI(temperature=0, model=config.model)
|
||||||
prompt_classify = ChatPromptTemplate.from_template(
|
prompt_classify = ChatPromptTemplate.from_template(
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
""" This module contains the function to classify a summary of a document. """
|
""" This module contains the function to classify a summary of a document. """
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from langchain.prompts import ChatPromptTemplate
|
from langchain.prompts import ChatPromptTemplate
|
||||||
import json
|
|
||||||
from langchain.chains import create_extraction_chain
|
from langchain.chains import create_extraction_chain
|
||||||
from langchain.chat_models import ChatOpenAI
|
from langchain.chat_models import ChatOpenAI
|
||||||
|
|
||||||
|
|
@ -12,11 +13,6 @@ config = Config()
|
||||||
config.load()
|
config.load()
|
||||||
OPENAI_API_KEY = config.openai_key
|
OPENAI_API_KEY = config.openai_key
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def classify_summary(query, document_summaries):
|
async def classify_summary(query, document_summaries):
|
||||||
"""Classify the documents based on the query and content."""
|
"""Classify the documents based on the query and content."""
|
||||||
llm = ChatOpenAI(temperature=0, model=config.model)
|
llm = ChatOpenAI(temperature=0, model=config.model)
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,13 @@
|
||||||
""" This module contains the classifiers for the documents. """
|
""" This module contains the classifiers for the documents. """
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from langchain.prompts import ChatPromptTemplate
|
from langchain.prompts import ChatPromptTemplate
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
from langchain.chains import create_extraction_chain
|
from langchain.chains import create_extraction_chain
|
||||||
from langchain.chat_models import ChatOpenAI
|
from langchain.chat_models import ChatOpenAI
|
||||||
|
|
||||||
from ..config import Config
|
from ..config import Config
|
||||||
from ..database.vectordb.loaders.loaders import _document_loader
|
|
||||||
|
|
||||||
config = Config()
|
config = Config()
|
||||||
config.load()
|
config.load()
|
||||||
|
|
@ -20,7 +17,7 @@ async def classify_user_input(query, input_type):
|
||||||
""" Classify the user input based on the query and input type."""
|
""" Classify the user input based on the query and input type."""
|
||||||
llm = ChatOpenAI(temperature=0, model=config.model)
|
llm = ChatOpenAI(temperature=0, model=config.model)
|
||||||
prompt_classify = ChatPromptTemplate.from_template(
|
prompt_classify = ChatPromptTemplate.from_template(
|
||||||
"""You are a classifier.
|
"""You are a classifier.
|
||||||
Determine with a True or False if the following input: {query},
|
Determine with a True or False if the following input: {query},
|
||||||
is relevant for the following memory category: {input_type}"""
|
is relevant for the following memory category: {input_type}"""
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,19 @@
|
||||||
""" This module contains the function to classify the user query. """
|
""" This module contains the function to classify the user query. """
|
||||||
from langchain.prompts import ChatPromptTemplate
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
from langchain.prompts import ChatPromptTemplate
|
||||||
from langchain.chains import create_extraction_chain
|
from langchain.chains import create_extraction_chain
|
||||||
from langchain.chat_models import ChatOpenAI
|
from langchain.chat_models import ChatOpenAI
|
||||||
from langchain.document_loaders import TextLoader
|
from langchain.document_loaders import TextLoader
|
||||||
from langchain.document_loaders import DirectoryLoader
|
from langchain.document_loaders import DirectoryLoader
|
||||||
|
|
||||||
from ..config import Config
|
from ..config import Config
|
||||||
from ..database.vectordb.loaders.loaders import _document_loader
|
|
||||||
|
|
||||||
config = Config()
|
config = Config()
|
||||||
config.load()
|
config.load()
|
||||||
OPENAI_API_KEY = config.openai_key
|
OPENAI_API_KEY = config.openai_key
|
||||||
|
|
||||||
|
|
||||||
async def classify_user_query(query, context, document_types):
|
async def classify_user_query(query, context, document_types):
|
||||||
"""Classify the user query based on the context and document types."""
|
"""Classify the user query based on the context and document types."""
|
||||||
llm = ChatOpenAI(temperature=0, model=config.model)
|
llm = ChatOpenAI(temperature=0, model=config.model)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,13 @@
|
||||||
|
from vector.vector_db_interface import VectorDBInterface
|
||||||
|
from qdrant_client import AsyncQdrantClient
|
||||||
|
|
||||||
|
class QDrantAdapter(VectorDBInterface):
    """Vector database adapter that delegates to an ``AsyncQdrantClient``.

    NOTE(review): only ``create_collection`` is implemented here. If the
    imported ``VectorDBInterface`` declares further abstract methods, this
    class cannot be instantiated until they are implemented as well -
    confirm against the interface module.
    """

    def __init__(self, qdrant_url, qdrant_api_key):
        """Create the underlying async Qdrant client.

        Args:
            qdrant_url: URL of the Qdrant server.
            qdrant_api_key: API key sent with requests to that server.
        """
        # Pass by keyword: the client's leading positional parameters are
        # ``location`` and ``url``, so a positional API key would be taken
        # as a second address instead of as the key.
        self.qdrant_client = AsyncQdrantClient(
            url=qdrant_url,
            api_key=qdrant_api_key,
        )

    async def create_collection(
        self,
        collection_name: str,
        collection_config: object
    ):
        """Create ``collection_name`` on the Qdrant server.

        ``collection_config`` is forwarded unchanged as the client's second
        positional argument (presumably its vectors configuration - confirm
        against the installed qdrant-client version).

        Returns:
            Whatever the client's ``create_collection`` coroutine returns.
        """
        return await self.qdrant_client.create_collection(
            collection_name,
            collection_config,
        )
|
||||||
|
|
@ -0,0 +1,68 @@
|
||||||
|
from abc import abstractmethod
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
class VectorDBInterface(Protocol):
    """Async interface that vector database adapters implement.

    The methods fall into two groups: collection management and data point
    operations. Every method is abstract; the bodies here only raise
    ``NotImplementedError``. Configuration and payload arguments are typed
    as ``object``, so their concrete shape is left to each adapter.
    """

    # --- Collections ---

    @abstractmethod
    async def create_collection(
        self,
        collection_name: str,
        collection_config: object
    ):
        """Create the collection ``collection_name`` from ``collection_config``."""
        raise NotImplementedError

    @abstractmethod
    async def update_collection(
        self,
        collection_name: str,
        collection_config: object
    ):
        """Update the collection ``collection_name`` with ``collection_config``."""
        raise NotImplementedError

    @abstractmethod
    async def delete_collection(
        self,
        collection_name: str
    ):
        """Delete the collection ``collection_name``."""
        raise NotImplementedError

    @abstractmethod
    async def create_vector_index(
        self,
        collection_name: str,
        vector_index_config: object
    ):
        """Create a vector index on ``collection_name``."""
        raise NotImplementedError

    @abstractmethod
    async def create_data_index(
        self,
        collection_name: str,
        vector_index_config: object
    ):
        """Create a data index on ``collection_name``.

        NOTE(review): the config parameter is named ``vector_index_config``
        exactly as in ``create_vector_index``; this may be a copy-paste
        leftover. The name is kept so keyword callers keep working.
        """
        raise NotImplementedError

    # --- Data points ---

    @abstractmethod
    async def create_data_point(
        self,
        collection_name: str,
        payload: object
    ):
        """Store ``payload`` as a new data point in ``collection_name``."""
        raise NotImplementedError

    @abstractmethod
    async def get_data_point(
        self,
        collection_name: str,
        data_point_id: str
    ):
        """Fetch the data point ``data_point_id`` from ``collection_name``."""
        raise NotImplementedError

    @abstractmethod
    async def update_data_point(
        self,
        collection_name: str,
        data_point_id: str,
        payload: object
    ):
        """Update the data point ``data_point_id`` with ``payload``."""
        raise NotImplementedError

    @abstractmethod
    async def delete_data_point(
        self,
        collection_name: str,
        data_point_id: str
    ):
        """Delete the data point ``data_point_id`` from ``collection_name``."""
        raise NotImplementedError
|
||||||
Loading…
Add table
Reference in a new issue