Fix context retireval
This commit is contained in:
parent
50e4d7c1e6
commit
f41f09171c
3 changed files with 61 additions and 7 deletions
16
api.py
16
api.py
|
|
@ -165,10 +165,7 @@ async def document_to_graph_db(payload: Payload):
|
||||||
async def cognitive_context_enrichment(payload: Payload):
|
async def cognitive_context_enrichment(payload: Payload):
|
||||||
try:
|
try:
|
||||||
decoded_payload = payload.payload
|
decoded_payload = payload.payload
|
||||||
|
|
||||||
# Execute the query - replace this with the actual execution method
|
|
||||||
async with session_scope(session=AsyncSessionLocal()) as session:
|
async with session_scope(session=AsyncSessionLocal()) as session:
|
||||||
# Assuming you have a method in Neo4jGraphDB to execute the query
|
|
||||||
result = await user_context_enrichment(session, user_id = decoded_payload['user_id'], query= decoded_payload['query'], generative_response=decoded_payload['generative_response'], memory_type= decoded_payload['memory_type'])
|
result = await user_context_enrichment(session, user_id = decoded_payload['user_id'], query= decoded_payload['query'], generative_response=decoded_payload['generative_response'], memory_type= decoded_payload['memory_type'])
|
||||||
return JSONResponse(content={"response": result}, status_code=200)
|
return JSONResponse(content={"response": result}, status_code=200)
|
||||||
|
|
||||||
|
|
@ -176,6 +173,19 @@ async def cognitive_context_enrichment(payload: Payload):
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/classify-user-query")
|
||||||
|
async def classify_user_query(payload: Payload):
|
||||||
|
try:
|
||||||
|
decoded_payload = payload.payload
|
||||||
|
async with session_scope(session=AsyncSessionLocal()) as session:
|
||||||
|
from main import relevance_feedback
|
||||||
|
result = await relevance_feedback( query= decoded_payload['query'], input_type=decoded_payload['knowledge_type'])
|
||||||
|
return JSONResponse(content={"response": result}, status_code=200)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
@app.post("/user-query-classifier")
|
@app.post("/user-query-classifier")
|
||||||
async def user_query_classfier(payload: Payload):
|
async def user_query_classfier(payload: Payload):
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -72,6 +72,35 @@ def classify_retrieval():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def classify_user_input(query, input_type):
|
||||||
|
|
||||||
|
llm = ChatOpenAI(temperature=0, model=config.model)
|
||||||
|
prompt_classify = ChatPromptTemplate.from_template(
|
||||||
|
"""You are a classifier. Determine with a True or False if the following input: {query}, is relevant for the following memory category: {input_type}"""
|
||||||
|
)
|
||||||
|
json_structure = [{
|
||||||
|
"name": "classifier",
|
||||||
|
"description": "Classification",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"InputClassification": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "The classification of the input"
|
||||||
|
}
|
||||||
|
}, "required": ["InputClassification"] }
|
||||||
|
}]
|
||||||
|
chain_filter = prompt_classify | llm.bind(function_call={"name": "classifier"}, functions=json_structure)
|
||||||
|
classifier_output = await chain_filter.ainvoke({"query": query, "input_type": input_type})
|
||||||
|
arguments_str = classifier_output.additional_kwargs['function_call']['arguments']
|
||||||
|
logging.info("This is the arguments string %s", arguments_str)
|
||||||
|
arguments_dict = json.loads(arguments_str)
|
||||||
|
logging.info("Relevant summary is %s", arguments_dict.get('DocumentSummary', None))
|
||||||
|
InputClassification = arguments_dict.get('InputClassification', None)
|
||||||
|
logging.info("This is the classification %s", InputClassification)
|
||||||
|
return InputClassification
|
||||||
|
|
||||||
|
|
||||||
# classify documents according to type of document
|
# classify documents according to type of document
|
||||||
async def classify_call(query, document_summaries):
|
async def classify_call(query, document_summaries):
|
||||||
|
|
||||||
|
|
@ -102,6 +131,7 @@ async def classify_call(query, document_summaries):
|
||||||
arguments_str = classifier_output.additional_kwargs['function_call']['arguments']
|
arguments_str = classifier_output.additional_kwargs['function_call']['arguments']
|
||||||
print("This is the arguments string", arguments_str)
|
print("This is the arguments string", arguments_str)
|
||||||
arguments_dict = json.loads(arguments_str)
|
arguments_dict = json.loads(arguments_str)
|
||||||
|
logging.info("Relevant summary is %s", arguments_dict.get('DocumentSummary', None))
|
||||||
classfier_id = arguments_dict.get('d_id', None)
|
classfier_id = arguments_dict.get('d_id', None)
|
||||||
|
|
||||||
print("This is the classifier id ", classfier_id)
|
print("This is the classifier id ", classfier_id)
|
||||||
|
|
|
||||||
22
main.py
22
main.py
|
|
@ -392,7 +392,7 @@ async def user_context_enrichment(session, user_id:str, query:str, generative_re
|
||||||
|
|
||||||
if detect_language(query) != "en":
|
if detect_language(query) != "en":
|
||||||
query = translate_text(query, "sr", "en")
|
query = translate_text(query, "sr", "en")
|
||||||
logging.info("Translated query is", query)
|
logging.info("Translated query is %s", str(query))
|
||||||
|
|
||||||
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, username=config.graph_database_username,
|
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, username=config.graph_database_username,
|
||||||
password=config.graph_database_password)
|
password=config.graph_database_password)
|
||||||
|
|
@ -404,8 +404,15 @@ async def user_context_enrichment(session, user_id:str, query:str, generative_re
|
||||||
# summaries = [record.get("summary") for record in result]
|
# summaries = [record.get("summary") for record in result]
|
||||||
# logging.info('Possible document categories are', str(result))
|
# logging.info('Possible document categories are', str(result))
|
||||||
# logging.info('Possible document categories are', str(categories))
|
# logging.info('Possible document categories are', str(categories))
|
||||||
relevant_summary_id = await classify_call( query= query, document_summaries=str(summaries))
|
|
||||||
|
|
||||||
|
max_attempts = 3
|
||||||
|
relevant_summary_id = None
|
||||||
|
|
||||||
|
for _ in range(max_attempts):
|
||||||
|
relevant_summary_id = await classify_call( query= query, document_summaries=str(summaries))
|
||||||
|
|
||||||
|
if relevant_summary_id is not None:
|
||||||
|
break
|
||||||
|
|
||||||
# logging.info("Relevant categories after the classifier are %s", relevant_categories)
|
# logging.info("Relevant categories after the classifier are %s", relevant_categories)
|
||||||
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, username=config.graph_database_username,
|
neo4j_graph_db = Neo4jGraphDB(url=config.graph_database_url, username=config.graph_database_username,
|
||||||
|
|
@ -602,7 +609,11 @@ async def unlink_user_from_memory(user_id: str=None, labels:list=None, topic:str
|
||||||
logging.error(f"Error creating public memory node: {e}")
|
logging.error(f"Error creating public memory node: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def relevance_feedback( query: str, input_type: str):
|
||||||
|
from cognitive_architecture.classifiers.classifier import classify_user_input
|
||||||
|
|
||||||
|
result = await classify_user_input( query, input_type=input_type)
|
||||||
|
return result
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
user_id = "user"
|
user_id = "user"
|
||||||
|
|
@ -687,8 +698,11 @@ async def main():
|
||||||
|
|
||||||
# await attach_user_to_memory(user_id=user_id, labels=['sr'], topic="PublicMemory")
|
# await attach_user_to_memory(user_id=user_id, labels=['sr'], topic="PublicMemory")
|
||||||
|
|
||||||
return_ = await user_context_enrichment(user_id=user_id, query="Koja je minimalna visina ograde na balkonu na stambenom objektu", session=session, memory_type="PublicMemory", generative_response=True)
|
# return_ = await user_context_enrichment(user_id=user_id, query="hi how are you", session=session, memory_type="PublicMemory", generative_response=True)
|
||||||
print(return_)
|
# print(return_)
|
||||||
|
aa = await relevance_feedback("I need to understand how to build a staircase in an apartment building", "PublicMemory")
|
||||||
|
print(aa)
|
||||||
|
|
||||||
# document_summary = {
|
# document_summary = {
|
||||||
# 'DocumentCategory': 'Science',
|
# 'DocumentCategory': 'Science',
|
||||||
# 'Title': 'The Future of AI',
|
# 'Title': 'The Future of AI',
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue