Merge remote-tracking branch 'origin/COG-23' into COG-23
This commit is contained in:
commit
6afecd268e
3 changed files with 4 additions and 4 deletions
|
|
@ -59,6 +59,6 @@ async def classify_documents(query: str, document_id: str, content: str):
|
|||
{"query": query, "d_id": document_id, "context": str(document_context)}
|
||||
)
|
||||
arguments_str = classifier_output.additional_kwargs["function_call"]["arguments"]
|
||||
print("This is the arguments string", arguments_str)
|
||||
logging.info("This is the arguments string %s", arguments_str)
|
||||
arguments_dict = json.loads(arguments_str)
|
||||
return arguments_dict
|
||||
|
|
@ -52,11 +52,11 @@ async def classify_summary(query, document_summaries):
|
|||
{"query": query, "document_summaries": document_summaries}
|
||||
)
|
||||
arguments_str = classifier_output.additional_kwargs["function_call"]["arguments"]
|
||||
print("This is the arguments string", arguments_str)
|
||||
logging.info("This is the arguments string %s", arguments_str)
|
||||
arguments_dict = json.loads(arguments_str)
|
||||
logging.info("Relevant summary is %s", arguments_dict.get("DocumentSummary", None))
|
||||
classfier_id = arguments_dict.get("d_id", None)
|
||||
|
||||
print("This is the classifier id ", classfier_id)
|
||||
logging.info("This is the classifier id %s", classfier_id)
|
||||
|
||||
return classfier_id
|
||||
|
|
@ -36,7 +36,7 @@ async def classify_user_query(query, context, document_types):
|
|||
"description": "The classification of documents in groups such as legal, medical, etc.",
|
||||
}
|
||||
},
|
||||
"required": ["UserQueryClassiffier"],
|
||||
"required": ["UserQueryClassifier"],
|
||||
},
|
||||
}
|
||||
]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue