diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 49c535d81..1770b5b9e 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -1099,9 +1099,20 @@ class GeminiChat(Base):
 
         if system:
             self.model._system_instruction = content_types.to_content(system)
-        response = self.model.generate_content(hist, generation_config=gen_conf)
-        ans = response.text
-        return ans, response.usage_metadata.total_token_count
+        retry_count = 0
+        max_retries = 3
+        while retry_count < max_retries:
+            try:
+                response = self.model.generate_content(hist, generation_config=gen_conf)
+                ans = response.text
+                return ans, response.usage_metadata.total_token_count
+            except Exception as e:
+                retry_count += 1
+                if retry_count >= max_retries:
+                    raise e
+                else:
+                    import time
+                    time.sleep(50)
 
     def chat_streamly(self, system, history, gen_conf={}, **kwargs):
         from google.generativeai.types import content_types