Add initial classifier
This commit is contained in:
parent
0a07b1e96b
commit
1ed03efa70
2 changed files with 49 additions and 0 deletions
|
|
@ -131,6 +131,22 @@ async def user_query_processor(payload: Payload):
|
|||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/user-query-classifier")
|
||||
async def user_query_classfier(payload: Payload):
|
||||
try:
|
||||
decoded_payload = payload.payload
|
||||
|
||||
# Execute the query - replace this with the actual execution method
|
||||
async with session_scope(session=AsyncSessionLocal()) as session:
|
||||
from cognitive_architecture.classifiers.classifier import classify_user_query
|
||||
# Assuming you have a method in Neo4jGraphDB to execute the query
|
||||
result = await classify_user_query(session, decoded_payload['user_id'], decoded_payload['query'])
|
||||
return JSONResponse(content={"response": result}, status_code=200)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
def start_api_server(host: str = "0.0.0.0", port: int = 8000):
|
||||
"""
|
||||
Start the API server using uvicorn.
|
||||
|
|
|
|||
|
|
@ -102,4 +102,37 @@ async def classify_call(query, context, document_types):
|
|||
|
||||
print("This is the classifier value", classfier_value)
|
||||
|
||||
return classfier_value
|
||||
|
||||
|
||||
|
||||
async def classify_user_query(query, context, document_types):
|
||||
|
||||
llm = ChatOpenAI(temperature=0, model=config.model)
|
||||
prompt_classify = ChatPromptTemplate.from_template(
|
||||
"""You are a classifier. You store user memories, thoughts and feelings. Determine if you need to use them to answer this query : {query}"""
|
||||
)
|
||||
json_structure = [{
|
||||
"name": "classifier",
|
||||
"description": "Classification",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"UserQueryClassifier": {
|
||||
"type": "bool",
|
||||
"description": "The classification of documents in groups such as legal, medical, etc."
|
||||
}
|
||||
|
||||
|
||||
}, "required": ["UserQueryClassiffier"] }
|
||||
}]
|
||||
chain_filter = prompt_classify | llm.bind(function_call={"name": "classifier"}, functions=json_structure)
|
||||
classifier_output = await chain_filter.ainvoke({"query": query, "context": context, "document_types": document_types})
|
||||
arguments_str = classifier_output.additional_kwargs['function_call']['arguments']
|
||||
print("This is the arguments string", arguments_str)
|
||||
arguments_dict = json.loads(arguments_str)
|
||||
classfier_value = arguments_dict.get('UserQueryClassifier', None)
|
||||
|
||||
print("This is the classifier value", classfier_value)
|
||||
|
||||
return classfier_value
|
||||
Loading…
Add table
Reference in a new issue