Add initial classifier

This commit is contained in:
Vasilije 2023-12-11 11:15:43 +01:00
parent 0a07b1e96b
commit 1ed03efa70
2 changed files with 49 additions and 0 deletions

View file

@ -131,6 +131,22 @@ async def user_query_processor(payload: Payload):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/user-query-classifier")
async def user_query_classfier(payload: Payload):
try:
decoded_payload = payload.payload
# Execute the query - replace this with the actual execution method
async with session_scope(session=AsyncSessionLocal()) as session:
from cognitive_architecture.classifiers.classifier import classify_user_query
# Assuming you have a method in Neo4jGraphDB to execute the query
result = await classify_user_query(session, decoded_payload['user_id'], decoded_payload['query'])
return JSONResponse(content={"response": result}, status_code=200)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
def start_api_server(host: str = "0.0.0.0", port: int = 8000):
"""
Start the API server using uvicorn.

View file

@ -102,4 +102,37 @@ async def classify_call(query, context, document_types):
print("This is the classifier value", classfier_value)
return classfier_value
async def classify_user_query(query, context, document_types):
llm = ChatOpenAI(temperature=0, model=config.model)
prompt_classify = ChatPromptTemplate.from_template(
"""You are a classifier. You store user memories, thoughts and feelings. Determine if you need to use them to answer this query : {query}"""
)
json_structure = [{
"name": "classifier",
"description": "Classification",
"parameters": {
"type": "object",
"properties": {
"UserQueryClassifier": {
"type": "bool",
"description": "The classification of documents in groups such as legal, medical, etc."
}
}, "required": ["UserQueryClassiffier"] }
}]
chain_filter = prompt_classify | llm.bind(function_call={"name": "classifier"}, functions=json_structure)
classifier_output = await chain_filter.ainvoke({"query": query, "context": context, "document_types": document_types})
arguments_str = classifier_output.additional_kwargs['function_call']['arguments']
print("This is the arguments string", arguments_str)
arguments_dict = json.loads(arguments_str)
classfier_value = arguments_dict.get('UserQueryClassifier', None)
print("This is the classifier value", classfier_value)
return classfier_value