diff --git a/cognee/infrastructure/llm/gemini/adapter.py b/cognee/infrastructure/llm/gemini/adapter.py
index 7beea94dc..f692ef485 100644
--- a/cognee/infrastructure/llm/gemini/adapter.py
+++ b/cognee/infrastructure/llm/gemini/adapter.py
@@ -42,12 +42,81 @@ class GeminiAdapter(LLMInterface):
     ) -> BaseModel:
         try:
             if response_model is str:
+                simplified_prompt = system_prompt
                 response_schema = {"type": "string"}
             else:
-                response_schema = response_model
+                response_schema = {
+                    "type": "object",
+                    "properties": {
+                        "summary": {"type": "string"},
+                        "description": {"type": "string"},
+                        "nodes": {
+                            "type": "array",
+                            "items": {
+                                "type": "object",
+                                "properties": {
+                                    "name": {"type": "string"},
+                                    "type": {"type": "string"},
+                                    "description": {"type": "string"},
+                                    "id": {"type": "string"},
+                                    "label": {"type": "string"},
+                                },
+                                "required": ["name", "type", "description", "id", "label"],
+                            },
+                        },
+                        "edges": {
+                            "type": "array",
+                            "items": {
+                                "type": "object",
+                                "properties": {
+                                    "source_node_id": {"type": "string"},
+                                    "target_node_id": {"type": "string"},
+                                    "relationship_name": {"type": "string"},
+                                },
+                                "required": [
+                                    "source_node_id",
+                                    "target_node_id",
+                                    "relationship_name",
+                                ],
+                            },
+                        },
+                    },
+                    "required": ["summary", "description", "nodes", "edges"],
+                }
+
+                simplified_prompt = f"""
+                {system_prompt}
+
+                IMPORTANT: Your response must be a valid JSON object with these required fields:
+                1. summary: A brief summary
+                2. description: A detailed description
+                3. nodes: Array of nodes with name, type, description, id, and label
+                4. edges: Array of edges with source_node_id, target_node_id, and relationship_name
+
+                Example structure:
+                {{
+                    "summary": "Brief summary",
+                    "description": "Detailed description",
+                    "nodes": [
+                        {{
+                            "name": "Example Node",
+                            "type": "Concept",
+                            "description": "Node description",
+                            "id": "example-id",
+                            "label": "Concept"
+                        }}
+                    ],
+                    "edges": [
+                        {{
+                            "source_node_id": "source-id",
+                            "target_node_id": "target-id",
+                            "relationship_name": "relates_to"
+                        }}
+                    ]
+                }}"""

             messages = [
-                {"role": "system", "content": system_prompt},
+                {"role": "system", "content": simplified_prompt},
                 {"role": "user", "content": text_input},
             ]

@@ -58,7 +127,7 @@ class GeminiAdapter(LLMInterface):
                 api_key=self.api_key,
                 max_tokens=self.max_tokens,
                 temperature=0.1,
-                response_format=response_schema,
+                response_format={"type": "json_object", "schema": response_schema},
                 timeout=100,
                 num_retries=self.MAX_RETRIES,
             )
diff --git a/cognee/shared/data_models.py b/cognee/shared/data_models.py
index d87d645a6..d23d2841c 100644
--- a/cognee/shared/data_models.py
+++ b/cognee/shared/data_models.py
@@ -4,67 +4,36 @@ from enum import Enum, auto
 from typing import Any, Dict, List, Optional, Union

 from pydantic import BaseModel, Field
-from cognee.infrastructure.llm.config import get_llm_config

-if get_llm_config().llm_provider.lower() == "gemini":
-    """
-    Note: Gemini doesn't allow for an empty dictionary to be a part of the data model
-    so we created new data models to bypass that issue, but other LLMs have slightly worse performance
-    when creating knowledge graphs with these data models compared to the old data models
-    so now there's an if statement here so that the rest of the LLMs can use the old data models.
-    """

-    class Node(BaseModel):
-        """Node in a knowledge graph."""
+class Node(BaseModel):
+    """Node in a knowledge graph."""

-        id: str
-        name: str
-        type: str
-        description: str
-        label: str
+    id: str
+    name: str
+    type: str
+    description: str
+    properties: Optional[Dict[str, Any]] = Field(
+        None, description="A dictionary of properties associated with the node."
+    )

-    class Edge(BaseModel):
-        """Edge in a knowledge graph."""

-        source_node_id: str
-        target_node_id: str
-        relationship_name: str
+class Edge(BaseModel):
+    """Edge in a knowledge graph."""

-    class KnowledgeGraph(BaseModel):
-        """Knowledge graph."""
+    source_node_id: str
+    target_node_id: str
+    relationship_name: str
+    properties: Optional[Dict[str, Any]] = Field(
+        None, description="A dictionary of properties associated with the edge."
+    )

-        summary: str
-        description: str
-        nodes: List[Node] = Field(..., default_factory=list)
-        edges: List[Edge] = Field(..., default_factory=list)
-else:

-    class Node(BaseModel):
-        """Node in a knowledge graph."""
+class KnowledgeGraph(BaseModel):
+    """Knowledge graph."""

-        id: str
-        name: str
-        type: str
-        description: str
-        properties: Optional[Dict[str, Any]] = Field(
-            None, description="A dictionary of properties associated with the node."
-        )
-
-    class Edge(BaseModel):
-        """Edge in a knowledge graph."""
-
-        source_node_id: str
-        target_node_id: str
-        relationship_name: str
-        properties: Optional[Dict[str, Any]] = Field(
-            None, description="A dictionary of properties associated with the edge."
-        )
-
-    class KnowledgeGraph(BaseModel):
-        """Knowledge graph."""
-
-        nodes: List[Node] = Field(..., default_factory=list)
-        edges: List[Edge] = Field(..., default_factory=list)
+    nodes: List[Node] = Field(..., default_factory=list)
+    edges: List[Edge] = Field(..., default_factory=list)


 class GraphQLQuery(BaseModel):
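Not part of the patch: a minimal sketch, assuming Pydantic v2 (`model_validate`) and the unified models above, showing that a Gemini-style payload which omits `properties` still validates against `KnowledgeGraph`, since `properties` is now Optional and defaults to None. All payload values below are illustrative, not taken from the diff.

# Sketch only (not in the diff): exercises the unified models from cognee/shared/data_models.py.
from cognee.shared.data_models import KnowledgeGraph

payload = {
    "nodes": [
        {
            "id": "example-id",  # illustrative values
            "name": "Example Node",
            "type": "Concept",
            "description": "Node description",
            # "properties" omitted on purpose: it is Optional and defaults to None,
            # so a response without it still validates.
        }
    ],
    "edges": [
        {
            "source_node_id": "example-id",
            "target_node_id": "other-id",
            "relationship_name": "relates_to",
        }
    ],
}

graph = KnowledgeGraph.model_validate(payload)  # Pydantic v2 API
print(graph.nodes[0].properties)  # -> None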