refactor: remove LLMGateway usage where not needed

This commit is contained in:
Igor Ilic 2025-09-09 13:50:16 +02:00
parent 89b51a244d
commit f1144abc54
18 changed files with 46 additions and 101 deletions

View file

@ -3,6 +3,7 @@ from pydantic import BaseModel
from cognee.eval_framework.evaluation.base_eval_adapter import BaseEvalAdapter from cognee.eval_framework.evaluation.base_eval_adapter import BaseEvalAdapter
from cognee.eval_framework.eval_config import EvalConfig from cognee.eval_framework.eval_config import EvalConfig
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
from cognee.infrastructure.llm import LLMGateway from cognee.infrastructure.llm import LLMGateway
@ -25,8 +26,8 @@ class DirectLLMEvalAdapter(BaseEvalAdapter):
) -> Dict[str, Any]: ) -> Dict[str, Any]:
args = {"question": question, "answer": answer, "golden_answer": golden_answer} args = {"question": question, "answer": answer, "golden_answer": golden_answer}
user_prompt = LLMGateway.render_prompt(self.eval_prompt_path, args) user_prompt = render_prompt(self.eval_prompt_path, args)
system_prompt = LLMGateway.read_query_prompt(self.system_prompt_path) system_prompt = read_query_prompt(self.system_prompt_path)
evaluation = await LLMGateway.acreate_structured_output( evaluation = await LLMGateway.acreate_structured_output(
text_input=user_prompt, text_input=user_prompt,

View file

@ -76,10 +76,10 @@ class LLMConfig(BaseSettings):
provider=self.baml_llm_provider, provider=self.baml_llm_provider,
options={ options={
"model": self.baml_llm_model, "model": self.baml_llm_model,
"temperature": self.baml_llm_temperature, # "temperature": self.baml_llm_temperature,
"api_key": self.baml_llm_api_key, "api_key": self.baml_llm_api_key,
"base_url": self.baml_llm_endpoint, # "base_url": self.baml_llm_endpoint,
"api_version": self.baml_llm_api_version, # "api_version": self.baml_llm_api_version,
}, },
) )
# Sets the primary client # Sets the primary client

View file

@ -68,35 +68,3 @@ class AnthropicAdapter(LLMInterface):
], ],
response_model=response_model, response_model=response_model,
) )
def show_prompt(self, text_input: str, system_prompt: str) -> str:
"""
Format and display the prompt for a user query.
Parameters:
-----------
- text_input (str): The input text from the user, defaults to a placeholder if
empty.
- system_prompt (str): The path to the system prompt to be read and formatted.
Returns:
--------
- str: A formatted string displaying the system prompt and user input.
"""
if not text_input:
text_input = "No user input provided."
if not system_prompt:
raise MissingSystemPromptPathError()
system_prompt = LLMGateway.read_query_prompt(system_prompt)
formatted_prompt = (
f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
if system_prompt
else None
)
return formatted_prompt

View file

@ -113,34 +113,3 @@ class GeminiAdapter(LLMInterface):
logger.error(f"Schema validation failed: {str(e)}") logger.error(f"Schema validation failed: {str(e)}")
logger.debug(f"Raw response: {e.raw_response}") logger.debug(f"Raw response: {e.raw_response}")
raise ValueError(f"Response failed schema validation: {str(e)}") raise ValueError(f"Response failed schema validation: {str(e)}")
def show_prompt(self, text_input: str, system_prompt: str) -> str:
"""
Format and display the prompt for a user query.
Raises an MissingQueryParameterError if no system prompt is provided.
Parameters:
-----------
- text_input (str): The user input text to display along with the system prompt.
- system_prompt (str): The path or content of the system prompt to be read and
displayed.
Returns:
--------
- str: Returns a formatted string containing the system prompt and user input.
"""
if not text_input:
text_input = "No user input provided."
if not system_prompt:
raise MissingSystemPromptPathError()
system_prompt = LLMGateway.read_query_prompt(system_prompt)
formatted_prompt = (
f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
if system_prompt
else None
)
return formatted_prompt

View file

@ -7,6 +7,7 @@ from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.infrastructure.llm.LLMGateway import LLMGateway from cognee.infrastructure.llm.LLMGateway import LLMGateway
logger = get_logger("CodeRetriever") logger = get_logger("CodeRetriever")
@ -41,7 +42,7 @@ class CodeRetriever(BaseRetriever):
f"Processing query with LLM: '{query[:100]}{'...' if len(query) > 100 else ''}'" f"Processing query with LLM: '{query[:100]}{'...' if len(query) > 100 else ''}'"
) )
system_prompt = LLMGateway.read_query_prompt("codegraph_retriever_system.txt") system_prompt = read_query_prompt("codegraph_retriever_system.txt")
try: try:
result = await LLMGateway.acreate_structured_output( result = await LLMGateway.acreate_structured_output(

View file

@ -1,9 +1,10 @@
from typing import Any, Optional, List, Tuple, Type from typing import Any, Optional, List, Type
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.utils.completion import generate_completion from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.infrastructure.llm.LLMGateway import LLMGateway from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
logger = get_logger() logger = get_logger()
@ -104,10 +105,10 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}") logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
if round_idx < max_iter: if round_idx < max_iter:
valid_args = {"query": query, "answer": completion, "context": context} valid_args = {"query": query, "answer": completion, "context": context}
valid_user_prompt = LLMGateway.render_prompt( valid_user_prompt = render_prompt(
filename=self.validation_user_prompt_path, context=valid_args filename=self.validation_user_prompt_path, context=valid_args
) )
valid_system_prompt = LLMGateway.read_query_prompt( valid_system_prompt = read_query_prompt(
prompt_file_name=self.validation_system_prompt_path prompt_file_name=self.validation_system_prompt_path
) )
@ -117,10 +118,10 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
response_model=str, response_model=str,
) )
followup_args = {"query": query, "answer": completion, "reasoning": reasoning} followup_args = {"query": query, "answer": completion, "reasoning": reasoning}
followup_prompt = LLMGateway.render_prompt( followup_prompt = render_prompt(
filename=self.followup_user_prompt_path, context=followup_args filename=self.followup_user_prompt_path, context=followup_args
) )
followup_system = LLMGateway.read_query_prompt( followup_system = read_query_prompt(
prompt_file_name=self.followup_system_prompt_path prompt_file_name=self.followup_system_prompt_path
) )

View file

@ -2,6 +2,7 @@ from typing import Any, Optional
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.llm.LLMGateway import LLMGateway from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.prompts import render_prompt
from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions import SearchTypeNotSupported from cognee.modules.retrieval.exceptions import SearchTypeNotSupported
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
@ -49,7 +50,7 @@ class NaturalLanguageRetriever(BaseRetriever):
async def _generate_cypher_query(self, query: str, edge_schemas, previous_attempts=None) -> str: async def _generate_cypher_query(self, query: str, edge_schemas, previous_attempts=None) -> str:
"""Generate a Cypher query using LLM based on natural language query and schema information.""" """Generate a Cypher query using LLM based on natural language query and schema information."""
system_prompt = LLMGateway.render_prompt( system_prompt = render_prompt(
self.system_prompt_path, self.system_prompt_path,
context={ context={
"edge_schemas": edge_schemas, "edge_schemas": edge_schemas,

View file

@ -6,6 +6,7 @@ from operator import itemgetter
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.utils.completion import generate_completion from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.llm.prompts import render_prompt
from cognee.infrastructure.llm import LLMGateway from cognee.infrastructure.llm import LLMGateway
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
@ -72,7 +73,7 @@ class TemporalRetriever(GraphCompletionRetriever):
else: else:
base_directory = None base_directory = None
system_prompt = LLMGateway.render_prompt(prompt_path, {}, base_directory=base_directory) system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory)
interval = await LLMGateway.acreate_structured_output(query, system_prompt, QueryInterval) interval = await LLMGateway.acreate_structured_output(query, system_prompt, QueryInterval)

View file

@ -1,5 +1,6 @@
from typing import Optional from typing import Optional
from cognee.infrastructure.llm.LLMGateway import LLMGateway from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
async def generate_completion( async def generate_completion(
@ -12,10 +13,8 @@ async def generate_completion(
) -> str: ) -> str:
"""Generates a completion using LLM with given context and prompts.""" """Generates a completion using LLM with given context and prompts."""
args = {"question": query, "context": context} args = {"question": query, "context": context}
user_prompt = LLMGateway.render_prompt(user_prompt_path, args) user_prompt = render_prompt(user_prompt_path, args)
system_prompt = ( system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path)
system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path)
)
if only_context: if only_context:
return context return context
@ -33,9 +32,7 @@ async def summarize_text(
system_prompt: str = None, system_prompt: str = None,
) -> str: ) -> str:
"""Summarizes text using LLM with the specified prompt.""" """Summarizes text using LLM with the specified prompt."""
system_prompt = ( system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path)
system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path)
)
return await LLMGateway.acreate_structured_output( return await LLMGateway.acreate_structured_output(
text_input=text, text_input=text,

View file

@ -7,7 +7,7 @@ from pydantic import BaseModel
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.engine.models import DataPoint from cognee.infrastructure.engine.models import DataPoint
from cognee.infrastructure.llm.LLMGateway import LLMGateway from cognee.infrastructure.llm.extraction import extract_categories
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
@ -40,7 +40,7 @@ async def chunk_naive_llm_classifier(
return data_chunks return data_chunks
chunk_classifications = await asyncio.gather( chunk_classifications = await asyncio.gather(
*[LLMGateway.extract_categories(chunk.text, classification_model) for chunk in data_chunks], *[extract_categories(chunk.text, classification_model) for chunk in data_chunks],
) )
classification_data_points = [] classification_data_points = []

View file

@ -4,6 +4,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.low_level import DataPoint from cognee.low_level import DataPoint
from cognee.infrastructure.llm.prompts import render_prompt
from cognee.infrastructure.llm import LLMGateway from cognee.infrastructure.llm import LLMGateway
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.modules.engine.models import NodeSet from cognee.modules.engine.models import NodeSet
@ -106,8 +107,8 @@ async def add_rule_associations(
user_context = {"chat": data, "rules": existing_rules} user_context = {"chat": data, "rules": existing_rules}
user_prompt = LLMGateway.render_prompt(user_prompt_location, context=user_context) user_prompt = render_prompt(user_prompt_location, context=user_context)
system_prompt = LLMGateway.render_prompt(system_prompt_location, context={}) system_prompt = render_prompt(system_prompt_location, context={})
rule_list = await LLMGateway.acreate_structured_output( rule_list = await LLMGateway.acreate_structured_output(
text_input=user_prompt, system_prompt=system_prompt, response_model=RuleSet text_input=user_prompt, system_prompt=system_prompt, response_model=RuleSet

View file

@ -3,6 +3,7 @@ from typing import List
from pydantic import BaseModel from pydantic import BaseModel
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor
from cognee.modules.engine.models import Entity from cognee.modules.engine.models import Entity
from cognee.modules.engine.models.EntityType import EntityType from cognee.modules.engine.models.EntityType import EntityType
@ -50,8 +51,8 @@ class LLMEntityExtractor(BaseEntityExtractor):
try: try:
logger.info(f"Extracting entities from text: {text[:100]}...") logger.info(f"Extracting entities from text: {text[:100]}...")
user_prompt = LLMGateway.render_prompt(self.user_prompt_template, {"text": text}) user_prompt = render_prompt(self.user_prompt_template, {"text": text})
system_prompt = LLMGateway.read_query_prompt(self.system_prompt_template) system_prompt = read_query_prompt(self.system_prompt_template)
response = await LLMGateway.acreate_structured_output( response = await LLMGateway.acreate_structured_output(
text_input=user_prompt, text_input=user_prompt,

View file

@ -1,6 +1,7 @@
from typing import List, Tuple from typing import List, Tuple
from pydantic import BaseModel from pydantic import BaseModel
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
from cognee.infrastructure.llm.LLMGateway import LLMGateway from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.root_dir import get_absolute_path from cognee.root_dir import get_absolute_path
@ -32,12 +33,12 @@ async def extract_content_nodes_and_relationship_names(
} }
base_directory = get_absolute_path("./tasks/graph/cascade_extract/prompts") base_directory = get_absolute_path("./tasks/graph/cascade_extract/prompts")
text_input = LLMGateway.render_prompt( text_input = render_prompt(
"extract_graph_relationship_names_prompt_input.txt", "extract_graph_relationship_names_prompt_input.txt",
context, context,
base_directory=base_directory, base_directory=base_directory,
) )
system_prompt = LLMGateway.read_query_prompt( system_prompt = read_query_prompt(
"extract_graph_relationship_names_prompt_system.txt", base_directory=base_directory "extract_graph_relationship_names_prompt_system.txt", base_directory=base_directory
) )
response = await LLMGateway.acreate_structured_output( response = await LLMGateway.acreate_structured_output(

View file

@ -1,5 +1,6 @@
from typing import List from typing import List
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
from cognee.infrastructure.llm.LLMGateway import LLMGateway from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.shared.data_models import KnowledgeGraph from cognee.shared.data_models import KnowledgeGraph
from cognee.root_dir import get_absolute_path from cognee.root_dir import get_absolute_path
@ -26,10 +27,10 @@ async def extract_edge_triplets(
} }
base_directory = get_absolute_path("./tasks/graph/cascade_extract/prompts") base_directory = get_absolute_path("./tasks/graph/cascade_extract/prompts")
text_input = LLMGateway.render_prompt( text_input = render_prompt(
"extract_graph_edge_triplets_prompt_input.txt", context, base_directory=base_directory "extract_graph_edge_triplets_prompt_input.txt", context, base_directory=base_directory
) )
system_prompt = LLMGateway.read_query_prompt( system_prompt = read_query_prompt(
"extract_graph_edge_triplets_prompt_system.txt", base_directory=base_directory "extract_graph_edge_triplets_prompt_system.txt", base_directory=base_directory
) )
extracted_graph = await LLMGateway.acreate_structured_output( extracted_graph = await LLMGateway.acreate_structured_output(

View file

@ -1,6 +1,7 @@
from typing import List from typing import List
from pydantic import BaseModel from pydantic import BaseModel
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
from cognee.infrastructure.llm.LLMGateway import LLMGateway from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.root_dir import get_absolute_path from cognee.root_dir import get_absolute_path
@ -24,10 +25,10 @@ async def extract_nodes(text: str, n_rounds: int = 2) -> List[str]:
"text": text, "text": text,
} }
base_directory = get_absolute_path("./tasks/graph/cascade_extract/prompts") base_directory = get_absolute_path("./tasks/graph/cascade_extract/prompts")
text_input = LLMGateway.render_prompt( text_input = render_prompt(
"extract_graph_nodes_prompt_input.txt", context, base_directory=base_directory "extract_graph_nodes_prompt_input.txt", context, base_directory=base_directory
) )
system_prompt = LLMGateway.read_query_prompt( system_prompt = read_query_prompt(
"extract_graph_nodes_prompt_system.txt", base_directory=base_directory "extract_graph_nodes_prompt_system.txt", base_directory=base_directory
) )
response = await LLMGateway.acreate_structured_output( response = await LLMGateway.acreate_structured_output(

View file

@ -2,7 +2,7 @@ import asyncio
from typing import Type, List from typing import Type, List
from pydantic import BaseModel from pydantic import BaseModel
from cognee.infrastructure.llm.LLMGateway import LLMGateway from cognee.infrastructure.llm.extraction import extract_content_graph
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.tasks.storage import add_data_points from cognee.tasks.storage import add_data_points
@ -18,7 +18,7 @@ async def extract_graph_from_code(
- Graph nodes are stored using the `add_data_points` function for later retrieval or analysis. - Graph nodes are stored using the `add_data_points` function for later retrieval or analysis.
""" """
chunk_graphs = await asyncio.gather( chunk_graphs = await asyncio.gather(
*[LLMGateway.extract_content_graph(chunk.text, graph_model) for chunk in data_chunks] *[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
) )
for chunk_index, chunk in enumerate(data_chunks): for chunk_index, chunk in enumerate(data_chunks):

View file

@ -3,7 +3,7 @@ from typing import AsyncGenerator, Union
from uuid import uuid5 from uuid import uuid5
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.llm.LLMGateway import LLMGateway from cognee.infrastructure.llm.extraction import extract_code_summary
from .models import CodeSummary from .models import CodeSummary
@ -16,7 +16,7 @@ async def summarize_code(
code_data_points = [file for file in code_graph_nodes if hasattr(file, "source_code")] code_data_points = [file for file in code_graph_nodes if hasattr(file, "source_code")]
file_summaries = await asyncio.gather( file_summaries = await asyncio.gather(
*[LLMGateway.extract_code_summary(file.source_code) for file in code_data_points] *[extract_code_summary(file.source_code) for file in code_data_points]
) )
file_summaries_map = { file_summaries_map = {

View file

@ -12,6 +12,7 @@ from cognee.tasks.temporal_awareness.index_graphiti_objects import (
) )
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
from cognee.infrastructure.llm.LLMGateway import LLMGateway from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.modules.users.methods import get_default_user from cognee.modules.users.methods import get_default_user
@ -59,8 +60,8 @@ async def main():
"context": context, "context": context,
} }
user_prompt = LLMGateway.render_prompt("graph_context_for_question.txt", args) user_prompt = render_prompt("graph_context_for_question.txt", args)
system_prompt = LLMGateway.read_query_prompt("answer_simple_question_restricted.txt") system_prompt = read_query_prompt("answer_simple_question_restricted.txt")
computed_answer = await LLMGateway.acreate_structured_output( computed_answer = await LLMGateway.acreate_structured_output(
text_input=user_prompt, text_input=user_prompt,