Change data models for gemini (#600)
<!-- .github/pull_request_template.md --> ## Description Change Gemini adapter and data models so Gemini can use custom data models ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced provider-specific enhancements with updated data representations, including improved node labeling and enriched summary and description fields for graph displays. - Improved configuration management by automatically loading environment settings for better LLM operations. - **Refactor** - Streamlined response handling with a simplified approach for defining output formats. - Updated error handling by removing the try-except block for dotenv imports. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
parent
5eef212668
commit
cade574bbf
4 changed files with 63 additions and 98 deletions
|
|
@ -13,9 +13,6 @@ from .modules.data.operations.get_pipeline_run_metrics import get_pipeline_run_m
|
|||
# Pipelines
|
||||
from .modules import pipelines
|
||||
|
||||
try:
|
||||
import dotenv
|
||||
import dotenv
|
||||
|
||||
dotenv.load_dotenv(override=True)
|
||||
except ImportError:
|
||||
pass
|
||||
dotenv.load_dotenv(override=True)
|
||||
|
|
|
|||
|
|
@ -42,81 +42,12 @@ class GeminiAdapter(LLMInterface):
|
|||
) -> BaseModel:
|
||||
try:
|
||||
if response_model is str:
|
||||
simplified_prompt = system_prompt
|
||||
response_schema = {"type": "string"}
|
||||
else:
|
||||
response_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"summary": {"type": "string"},
|
||||
"description": {"type": "string"},
|
||||
"nodes": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"type": {"type": "string"},
|
||||
"description": {"type": "string"},
|
||||
"id": {"type": "string"},
|
||||
"label": {"type": "string"},
|
||||
},
|
||||
"required": ["name", "type", "description", "id", "label"],
|
||||
},
|
||||
},
|
||||
"edges": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"source_node_id": {"type": "string"},
|
||||
"target_node_id": {"type": "string"},
|
||||
"relationship_name": {"type": "string"},
|
||||
},
|
||||
"required": [
|
||||
"source_node_id",
|
||||
"target_node_id",
|
||||
"relationship_name",
|
||||
],
|
||||
},
|
||||
},
|
||||
},
|
||||
"required": ["summary", "description", "nodes", "edges"],
|
||||
}
|
||||
|
||||
simplified_prompt = f"""
|
||||
{system_prompt}
|
||||
|
||||
IMPORTANT: Your response must be a valid JSON object with these required fields:
|
||||
1. summary: A brief summary
|
||||
2. description: A detailed description
|
||||
3. nodes: Array of nodes with name, type, description, id, and label
|
||||
4. edges: Array of edges with source_node_id, target_node_id, and relationship_name
|
||||
|
||||
Example structure:
|
||||
{{
|
||||
"summary": "Brief summary",
|
||||
"description": "Detailed description",
|
||||
"nodes": [
|
||||
{{
|
||||
"name": "Example Node",
|
||||
"type": "Concept",
|
||||
"description": "Node description",
|
||||
"id": "example-id",
|
||||
"label": "Concept"
|
||||
}}
|
||||
],
|
||||
"edges": [
|
||||
{{
|
||||
"source_node_id": "source-id",
|
||||
"target_node_id": "target-id",
|
||||
"relationship_name": "relates_to"
|
||||
}}
|
||||
]
|
||||
}}"""
|
||||
response_schema = response_model
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": simplified_prompt},
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": text_input},
|
||||
]
|
||||
|
||||
|
|
@ -127,7 +58,7 @@ class GeminiAdapter(LLMInterface):
|
|||
api_key=self.api_key,
|
||||
max_tokens=self.max_tokens,
|
||||
temperature=0.1,
|
||||
response_format={"type": "json_object", "schema": response_schema},
|
||||
response_format=response_schema,
|
||||
timeout=100,
|
||||
num_retries=self.MAX_RETRIES,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,6 @@
|
|||
# Loading LLM Config from data_models.py requires to have dotenv imported first
|
||||
# and to have it loaded
|
||||
|
||||
import dotenv
|
||||
|
||||
dotenv.load_dotenv(override=True)
|
||||
|
|
@ -4,36 +4,67 @@ from enum import Enum, auto
|
|||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from cognee.infrastructure.llm.config import get_llm_config
|
||||
|
||||
if get_llm_config().llm_provider.lower() == "gemini":
|
||||
"""
|
||||
Note: Gemini doesn't allow for an empty dictionary to be a part of the data model
|
||||
so we created new data models to bypass that issue, but other LLMs have slightly worse performance
|
||||
when creating knowledge graphs with these data models compared to the old data models
|
||||
so now there's an if statement here so that the rest of the LLMs can use the old data models.
|
||||
"""
|
||||
|
||||
class Node(BaseModel):
|
||||
"""Node in a knowledge graph."""
|
||||
class Node(BaseModel):
|
||||
"""Node in a knowledge graph."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
type: str
|
||||
description: str
|
||||
properties: Optional[Dict[str, Any]] = Field(
|
||||
None, description="A dictionary of properties associated with the node."
|
||||
)
|
||||
id: str
|
||||
name: str
|
||||
type: str
|
||||
description: str
|
||||
label: str
|
||||
|
||||
class Edge(BaseModel):
|
||||
"""Edge in a knowledge graph."""
|
||||
|
||||
class Edge(BaseModel):
|
||||
"""Edge in a knowledge graph."""
|
||||
source_node_id: str
|
||||
target_node_id: str
|
||||
relationship_name: str
|
||||
|
||||
source_node_id: str
|
||||
target_node_id: str
|
||||
relationship_name: str
|
||||
properties: Optional[Dict[str, Any]] = Field(
|
||||
None, description="A dictionary of properties associated with the edge."
|
||||
)
|
||||
class KnowledgeGraph(BaseModel):
|
||||
"""Knowledge graph."""
|
||||
|
||||
summary: str
|
||||
description: str
|
||||
nodes: List[Node] = Field(..., default_factory=list)
|
||||
edges: List[Edge] = Field(..., default_factory=list)
|
||||
else:
|
||||
|
||||
class KnowledgeGraph(BaseModel):
|
||||
"""Knowledge graph."""
|
||||
class Node(BaseModel):
|
||||
"""Node in a knowledge graph."""
|
||||
|
||||
nodes: List[Node] = Field(..., default_factory=list)
|
||||
edges: List[Edge] = Field(..., default_factory=list)
|
||||
id: str
|
||||
name: str
|
||||
type: str
|
||||
description: str
|
||||
properties: Optional[Dict[str, Any]] = Field(
|
||||
None, description="A dictionary of properties associated with the node."
|
||||
)
|
||||
|
||||
class Edge(BaseModel):
|
||||
"""Edge in a knowledge graph."""
|
||||
|
||||
source_node_id: str
|
||||
target_node_id: str
|
||||
relationship_name: str
|
||||
properties: Optional[Dict[str, Any]] = Field(
|
||||
None, description="A dictionary of properties associated with the edge."
|
||||
)
|
||||
|
||||
class KnowledgeGraph(BaseModel):
|
||||
"""Knowledge graph."""
|
||||
|
||||
nodes: List[Node] = Field(..., default_factory=list)
|
||||
edges: List[Edge] = Field(..., default_factory=list)
|
||||
|
||||
|
||||
class GraphQLQuery(BaseModel):
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue