Fix context retrieval

This commit is contained in:
Vasilije 2024-01-11 01:18:38 +01:00
parent 50e4d7c1e6
commit f41f09171c
3 changed files with 61 additions and 7 deletions

16
api.py
View file

@ -165,10 +165,7 @@ async def document_to_graph_db(payload: Payload):
async def cognitive_context_enrichment(payload: Payload):
try:
decoded_payload = payload.payload
# Execute the query - replace this with the actual execution method
async with session_scope(session=AsyncSessionLocal()) as session:
# Assuming you have a method in Neo4jGraphDB to execute the query
result = await user_context_enrichment(session, user_id = decoded_payload['user_id'], query= decoded_payload['query'], generative_response=decoded_payload['generative_response'], memory_type= decoded_payload['memory_type'])
return JSONResponse(content={"response": result}, status_code=200)
@ -176,6 +173,19 @@ async def cognitive_context_enrichment(payload: Payload):
raise HTTPException(status_code=500, detail=str(e))
@app.post("/classify-user-query")
async def classify_user_query(payload: Payload):
    """Classify whether a user query is relevant for a given memory category.

    Expects ``payload.payload`` to carry ``query`` and ``knowledge_type``
    keys. Delegates to :func:`main.relevance_feedback` and wraps its result
    in a JSON response; any failure is surfaced as an HTTP 500.
    """
    try:
        body = payload.payload
        # Imported lazily to avoid a circular import with main.
        from main import relevance_feedback

        # NOTE(review): the session is opened but never used below —
        # presumably kept for parity with the other endpoints; confirm.
        async with session_scope(session=AsyncSessionLocal()) as session:
            classification = await relevance_feedback(
                query=body['query'],
                input_type=body['knowledge_type'],
            )
            return JSONResponse(content={"response": classification}, status_code=200)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/user-query-classifier")
async def user_query_classfier(payload: Payload):
try:

View file

@ -72,6 +72,35 @@ def classify_retrieval():
pass
async def classify_user_input(query, input_type):
    """Classify whether *query* is relevant for the memory category *input_type*.

    Forces the LLM into an OpenAI function call whose schema has a single
    required boolean field, then parses that field out of the response.

    :param query: The user input to classify.
    :param input_type: The memory category to test relevance against.
    :return: ``True``/``False`` classification, or ``None`` if the model
        omitted the ``InputClassification`` field.
    """
    llm = ChatOpenAI(temperature=0, model=config.model)
    prompt_classify = ChatPromptTemplate.from_template(
        """You are a classifier. Determine with a True or False if the following input: {query}, is relevant for the following memory category: {input_type}"""
    )
    # Function-call schema: one required boolean so the model cannot reply free-form.
    json_structure = [{
        "name": "classifier",
        "description": "Classification",
        "parameters": {
            "type": "object",
            "properties": {
                "InputClassification": {
                    "type": "boolean",
                    "description": "The classification of the input"
                }
            }, "required": ["InputClassification"] }
    }]
    # Binding function_call pins the model to the "classifier" function.
    chain_filter = prompt_classify | llm.bind(function_call={"name": "classifier"}, functions=json_structure)
    classifier_output = await chain_filter.ainvoke({"query": query, "input_type": input_type})
    arguments_str = classifier_output.additional_kwargs['function_call']['arguments']
    logging.info("This is the arguments string %s", arguments_str)
    arguments_dict = json.loads(arguments_str)
    # Fixed: removed a log of the non-existent 'DocumentSummary' key — it was
    # copy-pasted from classify_call and always logged None; this schema only
    # defines 'InputClassification'.
    InputClassification = arguments_dict.get('InputClassification', None)
    logging.info("This is the classification %s", InputClassification)
    return InputClassification
# classify documents according to type of document
async def classify_call(query, document_summaries):
@ -102,6 +131,7 @@ async def classify_call(query, document_summaries):
arguments_str = classifier_output.additional_kwargs['function_call']['arguments']
print("This is the arguments string", arguments_str)
arguments_dict = json.loads(arguments_str)
logging.info("Relevant summary is %s", arguments_dict.get('DocumentSummary', None))
classfier_id = arguments_dict.get('d_id', None)
print("This is the classifier id ", classfier_id)

22
main.py
View file

@ -392,7 +392,7 @@ async def user_context_enrichment(session, user_id:str, query:str, generative_re
if detect_language(query) != "en":
query = translate_text(query, "sr", "en")
logging.info("Translated query is", query)
logging.info("Translated query is %s", str(query))
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, username=config.graph_database_username,
password=config.graph_database_password)
@ -404,8 +404,15 @@ async def user_context_enrichment(session, user_id:str, query:str, generative_re
# summaries = [record.get("summary") for record in result]
# logging.info('Possible document categories are', str(result))
# logging.info('Possible document categories are', str(categories))
relevant_summary_id = await classify_call( query= query, document_summaries=str(summaries))
max_attempts = 3
relevant_summary_id = None
for _ in range(max_attempts):
relevant_summary_id = await classify_call( query= query, document_summaries=str(summaries))
if relevant_summary_id is not None:
break
# logging.info("Relevant categories after the classifier are %s", relevant_categories)
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, username=config.graph_database_username,
@ -602,7 +609,11 @@ async def unlink_user_from_memory(user_id: str=None, labels:list=None, topic:str
logging.error(f"Error creating public memory node: {e}")
return None
async def relevance_feedback(query: str, input_type: str):
    """Return the boolean relevance classification of *query* for *input_type*.

    Thin asynchronous wrapper around
    :func:`cognitive_architecture.classifiers.classifier.classify_user_input`;
    the import is deferred to avoid a circular dependency at module load.
    """
    from cognitive_architecture.classifiers.classifier import classify_user_input
    return await classify_user_input(query, input_type=input_type)
async def main():
user_id = "user"
@ -687,8 +698,11 @@ async def main():
# await attach_user_to_memory(user_id=user_id, labels=['sr'], topic="PublicMemory")
return_ = await user_context_enrichment(user_id=user_id, query="Koja je minimalna visina ograde na balkonu na stambenom objektu", session=session, memory_type="PublicMemory", generative_response=True)
print(return_)
# return_ = await user_context_enrichment(user_id=user_id, query="hi how are you", session=session, memory_type="PublicMemory", generative_response=True)
# print(return_)
aa = await relevance_feedback("I need to understand how to build a staircase in an apartment building", "PublicMemory")
print(aa)
# document_summary = {
# 'DocumentCategory': 'Science',
# 'Title': 'The Future of AI',