cognee/cognitive_architecture/classifiers/classifier.py
2024-01-11 01:18:38 +01:00

172 lines
No EOL
6.6 KiB
Python

import logging
from langchain.prompts import ChatPromptTemplate
import json
# TODO: add all classifiers here
from langchain.chains import create_extraction_chain
from langchain.chat_models import ChatOpenAI
from ..config import Config
from ..database.vectordb.loaders.loaders import _document_loader
# Module-level configuration, created and loaded once at import time.
# NOTE(review): what Config.load() reads is defined in ..config, not here.
config = Config()
config.load()
# OpenAI key taken from the loaded config; it is not referenced by the
# functions visible in this part of the module.
OPENAI_API_KEY = config.openai_key
from langchain.document_loaders import TextLoader
from langchain.document_loaders import DirectoryLoader
async def classify_documents(query: str, document_id: str, content: str):
    """Summarize and classify a document through a forced OpenAI function call.

    The chat model is bound to the ``summarizer`` function so that it must
    answer with arguments matching the JSON schema defined below.

    Args:
        query: Instruction text interpolated into the prompt.
        document_id: Identifier of the document, passed to the prompt as
            ``d_id``.
        content: Document context; it is converted with ``str()`` before
            being interpolated into the prompt.

    Returns:
        The function-call arguments decoded from JSON. The schema asks for
        the keys ``DocumentCategory``, ``Title``, ``Summary`` and ``d_id``.

    Raises:
        KeyError: If the model response carries no ``function_call`` entry.
        json.JSONDecodeError: If the returned arguments are not valid JSON.
    """
    # The placeholder is required: passing the context as a bare extra
    # argument makes the logging module report a formatting error instead
    # of emitting the record.
    logging.info("This is the document context %s", content)

    llm = ChatOpenAI(temperature=0, model=config.model)
    prompt_classify = ChatPromptTemplate.from_template(
        """You are a summarizer and classifier. Determine what book this is and where does it belong in the output : {query}, Id: {d_id} Document context is: {context}"""
    )
    json_structure = [{
        "name": "summarizer",
        "description": "Summarization and classification",
        "parameters": {
            "type": "object",
            "properties": {
                "DocumentCategory": {
                    "type": "string",
                    "description": "The classification of documents in groups such as legal, medical, etc."
                },
                "Title": {
                    "type": "string",
                    "description": "The title of the document"
                },
                "Summary": {
                    "type": "string",
                    "description": "The summary of the document"
                },
                "d_id": {
                    "type": "string",
                    "description": "The id of the document"
                },
            },
            "required": ["DocumentCategory", "Title", "Summary", "d_id"],
        },
    }]

    # Forcing function_call by name makes the model reply with arguments for
    # the "summarizer" function rather than free text.
    chain_filter = prompt_classify | llm.bind(
        function_call={"name": "summarizer"}, functions=json_structure
    )
    classifier_output = await chain_filter.ainvoke(
        {"query": query, "d_id": document_id, "context": str(content)}
    )
    arguments_str = classifier_output.additional_kwargs["function_call"]["arguments"]
    logging.info("This is the arguments string %s", arguments_str)
    return json.loads(arguments_str)
# classify retrievals according to type of retrieval
def classify_retrieval():
    """Placeholder for retrieval classification; currently does nothing."""
    return None
async def classify_user_input(query, input_type):
    """Ask the model whether an input is relevant for a memory category.

    The chat model is bound to the ``classifier`` function so that it must
    answer with a single boolean argument, ``InputClassification``.

    Args:
        query: The user input to classify.
        input_type: The memory category the input is checked against.

    Returns:
        The value of ``InputClassification`` from the function-call
        arguments, or ``None`` if the key is absent.

    Raises:
        KeyError: If the model response carries no ``function_call`` entry.
        json.JSONDecodeError: If the returned arguments are not valid JSON.
    """
    llm = ChatOpenAI(temperature=0, model=config.model)
    prompt_classify = ChatPromptTemplate.from_template(
        """You are a classifier. Determine with a True or False if the following input: {query}, is relevant for the following memory category: {input_type}"""
    )
    json_structure = [{
        "name": "classifier",
        "description": "Classification",
        "parameters": {
            "type": "object",
            "properties": {
                "InputClassification": {
                    "type": "boolean",
                    "description": "The classification of the input"
                },
            },
            "required": ["InputClassification"],
        },
    }]

    chain_filter = prompt_classify | llm.bind(
        function_call={"name": "classifier"}, functions=json_structure
    )
    classifier_output = await chain_filter.ainvoke(
        {"query": query, "input_type": input_type}
    )
    arguments_str = classifier_output.additional_kwargs["function_call"]["arguments"]
    logging.info("This is the arguments string %s", arguments_str)
    arguments_dict = json.loads(arguments_str)

    # Only "InputClassification" is defined by this schema, so the former
    # log of a "DocumentSummary" key (which could only ever be None here)
    # has been dropped.
    input_classification = arguments_dict.get("InputClassification", None)
    logging.info("This is the classification %s", input_classification)
    return input_classification
# Pick the document relevant to a query, given the document summaries and ids
async def classify_call(query, document_summaries):
    """Ask the model which summarized document is relevant to a query.

    The chat model is bound to the ``classifier`` function so that it must
    answer with a document summary and the id of the matching document.

    Args:
        query: The query to match documents against.
        document_summaries: Document summaries and ids, interpolated into
            the prompt as given.

    Returns:
        The ``d_id`` value from the function-call arguments, or ``None`` if
        the model did not supply one.

    Raises:
        KeyError: If the model response carries no ``function_call`` entry.
        json.JSONDecodeError: If the returned arguments are not valid JSON.
    """
    llm = ChatOpenAI(temperature=0, model=config.model)
    prompt_classify = ChatPromptTemplate.from_template(
        """You are a classifier. Determine what document are relevant for the given query: {query}, Document summaries and ids:{document_summaries}"""
    )
    json_structure = [{
        "name": "classifier",
        "description": "Classification",
        "parameters": {
            "type": "object",
            "properties": {
                "DocumentSummary": {
                    "type": "string",
                    "description": "The summary of the document and the topic it deals with."
                },
                "d_id": {
                    "type": "string",
                    "description": "The id of the document"
                },
            },
            # "d_id" is the value this function returns, so the schema
            # marks it as required alongside the summary.
            "required": ["DocumentSummary", "d_id"],
        },
    }]

    chain_filter = prompt_classify | llm.bind(
        function_call={"name": "classifier"}, functions=json_structure
    )
    classifier_output = await chain_filter.ainvoke(
        {"query": query, "document_summaries": document_summaries}
    )
    arguments_str = classifier_output.additional_kwargs["function_call"]["arguments"]
    logging.info("This is the arguments string %s", arguments_str)
    arguments_dict = json.loads(arguments_str)
    logging.info("Relevant summary is %s", arguments_dict.get("DocumentSummary", None))

    classifier_id = arguments_dict.get("d_id", None)
    logging.info("This is the classifier id %s", classifier_id)
    return classifier_id
async def classify_user_query(query, context, document_types):
    """Ask the model whether stored user memories are needed for a query.

    The chat model is bound to the ``classifier`` function so that it must
    answer with a single boolean argument, ``UserQueryClassifier``.

    Args:
        query: The user query to classify.
        context: Passed to the chain input; the prompt template only
            references ``{query}``.
        document_types: Passed to the chain input; the prompt template only
            references ``{query}``.

    Returns:
        The value of ``UserQueryClassifier`` from the function-call
        arguments, or ``None`` if the key is absent.

    Raises:
        KeyError: If the model response carries no ``function_call`` entry.
        json.JSONDecodeError: If the returned arguments are not valid JSON.
    """
    llm = ChatOpenAI(temperature=0, model=config.model)
    prompt_classify = ChatPromptTemplate.from_template(
        """You are a classifier. You store user memories, thoughts and feelings. Determine if you need to use them to answer this query : {query}"""
    )
    json_structure = [{
        "name": "classifier",
        "description": "Classification",
        "parameters": {
            "type": "object",
            "properties": {
                "UserQueryClassifier": {
                    # JSON Schema spells the type "boolean"; "bool" is not a
                    # valid type name.
                    "type": "boolean",
                    "description": "True if the stored user memories, thoughts and feelings are needed to answer the query, otherwise False."
                },
            },
            # Must match the property name exactly; the previous spelling
            # "UserQueryClassiffier" referred to no declared property.
            "required": ["UserQueryClassifier"],
        },
    }]

    chain_filter = prompt_classify | llm.bind(
        function_call={"name": "classifier"}, functions=json_structure
    )
    classifier_output = await chain_filter.ainvoke(
        {"query": query, "context": context, "document_types": document_types}
    )
    arguments_str = classifier_output.additional_kwargs["function_call"]["arguments"]
    logging.info("This is the arguments string %s", arguments_str)
    arguments_dict = json.loads(arguments_str)

    classifier_value = arguments_dict.get("UserQueryClassifier", None)
    logging.info("This is the classifier value %s", classifier_value)
    return classifier_value