diff --git a/cognee/infrastructure/llm/gemini/adapter.py b/cognee/infrastructure/llm/gemini/adapter.py index f37fb1c80..f692ef485 100644 --- a/cognee/infrastructure/llm/gemini/adapter.py +++ b/cognee/infrastructure/llm/gemini/adapter.py @@ -41,71 +41,79 @@ class GeminiAdapter(LLMInterface): self, text_input: str, system_prompt: str, response_model: Type[BaseModel] ) -> BaseModel: try: - response_schema = { - "type": "object", - "properties": { - "summary": {"type": "string"}, - "description": {"type": "string"}, - "nodes": { - "type": "array", - "items": { - "type": "object", - "properties": { - "name": {"type": "string"}, - "type": {"type": "string"}, - "description": {"type": "string"}, - "id": {"type": "string"}, - "label": {"type": "string"}, + if response_model is str: + simplified_prompt = system_prompt + response_schema = {"type": "string"} + else: + response_schema = { + "type": "object", + "properties": { + "summary": {"type": "string"}, + "description": {"type": "string"}, + "nodes": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "type": {"type": "string"}, + "description": {"type": "string"}, + "id": {"type": "string"}, + "label": {"type": "string"}, + }, + "required": ["name", "type", "description", "id", "label"], + }, + }, + "edges": { + "type": "array", + "items": { + "type": "object", + "properties": { + "source_node_id": {"type": "string"}, + "target_node_id": {"type": "string"}, + "relationship_name": {"type": "string"}, + }, + "required": [ + "source_node_id", + "target_node_id", + "relationship_name", + ], }, - "required": ["name", "type", "description", "id", "label"], }, }, - "edges": { - "type": "array", - "items": { - "type": "object", - "properties": { - "source_node_id": {"type": "string"}, - "target_node_id": {"type": "string"}, - "relationship_name": {"type": "string"}, - }, - "required": ["source_node_id", "target_node_id", "relationship_name"], - }, - }, - }, - "required": ["summary", "description", "nodes", "edges"], - } + "required": ["summary", "description", "nodes", "edges"], + } - simplified_prompt = f""" -{system_prompt} + simplified_prompt = f""" + {system_prompt} -IMPORTANT: Your response must be a valid JSON object with these required fields: -1. summary: A brief summary -2. description: A detailed description -3. nodes: Array of nodes with name, type, description, id, and label -4. edges: Array of edges with source_node_id, target_node_id, and relationship_name + IMPORTANT: Your response must be a valid JSON object with these required fields: + 1. summary: A brief summary + 2. description: A detailed description + 3. nodes: Array of nodes with name, type, description, id, and label + 4. edges: Array of edges with source_node_id, target_node_id, and relationship_name -Example structure: -{{ - "summary": "Brief summary", - "description": "Detailed description", - "nodes": [ + Example structure: {{ - "name": "Example Node", - "type": "Concept", - "description": "Node description", - "id": "example-id", - "label": "Concept" - }} - ], - "edges": [ - {{ - "source_node_id": "source-id", - "target_node_id": "target-id", - "relationship_name": "relates_to" - }} - ] -}}""" + "summary": "Brief summary", + "description": "Detailed description", + "nodes": [ + {{ + "name": "Example Node", + "type": "Concept", + "description": "Node description", + "id": "example-id", + "label": "Concept" + }} + ], + "edges": [ + {{ + "source_node_id": "source-id", + "target_node_id": "target-id", + "relationship_name": "relates_to" + }} + ] + }}""" messages = [ {"role": "system", "content": simplified_prompt}, @@ -120,12 +128,14 @@ Example structure: max_tokens=self.max_tokens, temperature=0.1, response_format={"type": "json_object", "schema": response_schema}, - timeout=10, + timeout=100, num_retries=self.MAX_RETRIES, ) if response.choices and response.choices[0].message.content: content = response.choices[0].message.content + if response_model is str: + return content return response_model.model_validate_json(content) except litellm.exceptions.BadRequestError as e: diff --git a/cognee/infrastructure/llm/utils.py b/cognee/infrastructure/llm/utils.py index d0479fb30..4009c5ae0 100644 --- a/cognee/infrastructure/llm/utils.py +++ b/cognee/infrastructure/llm/utils.py @@ -46,6 +46,7 @@ async def test_llm_connection(): system_prompt='Respond to me with the following string: "test"', response_model=str, ) + except Exception as e: logger.error(e) logger.error("Connection to LLM could not be established.")