Change data models for Gemini (#600)
<!-- .github/pull_request_template.md -->

## Description

Change the Gemini adapter and data models so that Gemini can use custom data models.

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

- **New Features**
  - Introduced provider-specific enhancements with updated data representations, including improved node labeling and enriched summary and description fields for graph displays.
  - Improved configuration management by automatically loading environment settings for better LLM operations.
- **Refactor**
  - Streamlined response handling with a simplified approach for defining output formats.
  - Updated error handling by removing the try-except block for dotenv imports.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
parent
5eef212668
commit
cade574bbf
4 changed files with 63 additions and 98 deletions
|
|
@ -13,9 +13,6 @@ from .modules.data.operations.get_pipeline_run_metrics import get_pipeline_run_m
|
||||||
# Pipelines
|
# Pipelines
|
||||||
from .modules import pipelines
|
from .modules import pipelines
|
||||||
|
|
||||||
try:
|
import dotenv
|
||||||
import dotenv
|
|
||||||
|
|
||||||
dotenv.load_dotenv(override=True)
|
dotenv.load_dotenv(override=True)
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
|
||||||
|
|
@ -42,81 +42,12 @@ class GeminiAdapter(LLMInterface):
|
||||||
) -> BaseModel:
|
) -> BaseModel:
|
||||||
try:
|
try:
|
||||||
if response_model is str:
|
if response_model is str:
|
||||||
simplified_prompt = system_prompt
|
|
||||||
response_schema = {"type": "string"}
|
response_schema = {"type": "string"}
|
||||||
else:
|
else:
|
||||||
response_schema = {
|
response_schema = response_model
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"summary": {"type": "string"},
|
|
||||||
"description": {"type": "string"},
|
|
||||||
"nodes": {
|
|
||||||
"type": "array",
|
|
||||||
"items": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"name": {"type": "string"},
|
|
||||||
"type": {"type": "string"},
|
|
||||||
"description": {"type": "string"},
|
|
||||||
"id": {"type": "string"},
|
|
||||||
"label": {"type": "string"},
|
|
||||||
},
|
|
||||||
"required": ["name", "type", "description", "id", "label"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"edges": {
|
|
||||||
"type": "array",
|
|
||||||
"items": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"source_node_id": {"type": "string"},
|
|
||||||
"target_node_id": {"type": "string"},
|
|
||||||
"relationship_name": {"type": "string"},
|
|
||||||
},
|
|
||||||
"required": [
|
|
||||||
"source_node_id",
|
|
||||||
"target_node_id",
|
|
||||||
"relationship_name",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["summary", "description", "nodes", "edges"],
|
|
||||||
}
|
|
||||||
|
|
||||||
simplified_prompt = f"""
|
|
||||||
{system_prompt}
|
|
||||||
|
|
||||||
IMPORTANT: Your response must be a valid JSON object with these required fields:
|
|
||||||
1. summary: A brief summary
|
|
||||||
2. description: A detailed description
|
|
||||||
3. nodes: Array of nodes with name, type, description, id, and label
|
|
||||||
4. edges: Array of edges with source_node_id, target_node_id, and relationship_name
|
|
||||||
|
|
||||||
Example structure:
|
|
||||||
{{
|
|
||||||
"summary": "Brief summary",
|
|
||||||
"description": "Detailed description",
|
|
||||||
"nodes": [
|
|
||||||
{{
|
|
||||||
"name": "Example Node",
|
|
||||||
"type": "Concept",
|
|
||||||
"description": "Node description",
|
|
||||||
"id": "example-id",
|
|
||||||
"label": "Concept"
|
|
||||||
}}
|
|
||||||
],
|
|
||||||
"edges": [
|
|
||||||
{{
|
|
||||||
"source_node_id": "source-id",
|
|
||||||
"target_node_id": "target-id",
|
|
||||||
"relationship_name": "relates_to"
|
|
||||||
}}
|
|
||||||
]
|
|
||||||
}}"""
|
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": simplified_prompt},
|
{"role": "system", "content": system_prompt},
|
||||||
{"role": "user", "content": text_input},
|
{"role": "user", "content": text_input},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -127,7 +58,7 @@ class GeminiAdapter(LLMInterface):
|
||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
max_tokens=self.max_tokens,
|
max_tokens=self.max_tokens,
|
||||||
temperature=0.1,
|
temperature=0.1,
|
||||||
response_format={"type": "json_object", "schema": response_schema},
|
response_format=response_schema,
|
||||||
timeout=100,
|
timeout=100,
|
||||||
num_retries=self.MAX_RETRIES,
|
num_retries=self.MAX_RETRIES,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,6 @@
|
||||||
|
# Loading LLM Config from data_models.py requires to have dotenv imported first
|
||||||
|
# and to have it loaded
|
||||||
|
|
||||||
|
import dotenv
|
||||||
|
|
||||||
|
dotenv.load_dotenv(override=True)
|
||||||
|
|
@ -4,36 +4,67 @@ from enum import Enum, auto
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from cognee.infrastructure.llm.config import get_llm_config
|
||||||
|
|
||||||
|
if get_llm_config().llm_provider.lower() == "gemini":
|
||||||
|
"""
|
||||||
|
Note: Gemini doesn't allow for an empty dictionary to be a part of the data model
|
||||||
|
so we created new data models to bypass that issue, but other LLMs have slightly worse performance
|
||||||
|
when creating knowledge graphs with these data models compared to the old data models
|
||||||
|
so now there's an if statement here so that the rest of the LLMs can use the old data models.
|
||||||
|
"""
|
||||||
|
|
||||||
class Node(BaseModel):
|
class Node(BaseModel):
|
||||||
"""Node in a knowledge graph."""
|
"""Node in a knowledge graph."""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
type: str
|
type: str
|
||||||
description: str
|
description: str
|
||||||
properties: Optional[Dict[str, Any]] = Field(
|
label: str
|
||||||
None, description="A dictionary of properties associated with the node."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
class Edge(BaseModel):
|
||||||
|
"""Edge in a knowledge graph."""
|
||||||
|
|
||||||
class Edge(BaseModel):
|
source_node_id: str
|
||||||
"""Edge in a knowledge graph."""
|
target_node_id: str
|
||||||
|
relationship_name: str
|
||||||
|
|
||||||
source_node_id: str
|
class KnowledgeGraph(BaseModel):
|
||||||
target_node_id: str
|
"""Knowledge graph."""
|
||||||
relationship_name: str
|
|
||||||
properties: Optional[Dict[str, Any]] = Field(
|
|
||||||
None, description="A dictionary of properties associated with the edge."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
summary: str
|
||||||
|
description: str
|
||||||
|
nodes: List[Node] = Field(..., default_factory=list)
|
||||||
|
edges: List[Edge] = Field(..., default_factory=list)
|
||||||
|
else:
|
||||||
|
|
||||||
class KnowledgeGraph(BaseModel):
|
class Node(BaseModel):
|
||||||
"""Knowledge graph."""
|
"""Node in a knowledge graph."""
|
||||||
|
|
||||||
nodes: List[Node] = Field(..., default_factory=list)
|
id: str
|
||||||
edges: List[Edge] = Field(..., default_factory=list)
|
name: str
|
||||||
|
type: str
|
||||||
|
description: str
|
||||||
|
properties: Optional[Dict[str, Any]] = Field(
|
||||||
|
None, description="A dictionary of properties associated with the node."
|
||||||
|
)
|
||||||
|
|
||||||
|
class Edge(BaseModel):
|
||||||
|
"""Edge in a knowledge graph."""
|
||||||
|
|
||||||
|
source_node_id: str
|
||||||
|
target_node_id: str
|
||||||
|
relationship_name: str
|
||||||
|
properties: Optional[Dict[str, Any]] = Field(
|
||||||
|
None, description="A dictionary of properties associated with the edge."
|
||||||
|
)
|
||||||
|
|
||||||
|
class KnowledgeGraph(BaseModel):
|
||||||
|
"""Knowledge graph."""
|
||||||
|
|
||||||
|
nodes: List[Node] = Field(..., default_factory=list)
|
||||||
|
edges: List[Edge] = Field(..., default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class GraphQLQuery(BaseModel):
|
class GraphQLQuery(BaseModel):
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue