diff --git a/main.py b/main.py index 93d8b0553..b91c6cec4 100644 --- a/main.py +++ b/main.py @@ -39,7 +39,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from cognitive_architecture.utils import get_document_names, generate_letter_uuid, get_memory_name_by_doc_id, get_unsumarized_vector_db_namespace, get_vectordb_namespace, get_vectordb_document_name from cognitive_architecture.shared.language_processing import translate_text, detect_language - +from cognitive_architecture.classifiers.classifier import classify_user_input async def fetch_document_vectordb_namespace(session: AsyncSession, user_id: str, namespace_id:str, doc_id:str=None): logging.info("user id is", user_id) @@ -609,10 +609,14 @@ async def unlink_user_from_memory(user_id: str=None, labels:list=None, topic:str logging.error(f"Error creating public memory node: {e}") return None -async def relevance_feedback( query: str, input_type: str): - from cognitive_architecture.classifiers.classifier import classify_user_input +async def relevance_feedback(query: str, input_type: str): - result = await classify_user_input( query, input_type=input_type) + max_attempts = 6 + result = None + for attempt in range(1, max_attempts + 1): + result = await classify_user_input(query, input_type=input_type) + if isinstance(result, bool): + break # Exit the loop if a result of type bool is obtained return result async def main():