refactor: remove LLMGateway usage where not needed

This commit is contained in:
Igor Ilic 2025-09-09 13:50:16 +02:00
parent 89b51a244d
commit f1144abc54
18 changed files with 46 additions and 101 deletions

View file

@@ -3,6 +3,7 @@ from pydantic import BaseModel
from cognee.eval_framework.evaluation.base_eval_adapter import BaseEvalAdapter
from cognee.eval_framework.eval_config import EvalConfig
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
from cognee.infrastructure.llm import LLMGateway
@@ -25,8 +26,8 @@ class DirectLLMEvalAdapter(BaseEvalAdapter):
) -> Dict[str, Any]:
args = {"question": question, "answer": answer, "golden_answer": golden_answer}
user_prompt = LLMGateway.render_prompt(self.eval_prompt_path, args)
system_prompt = LLMGateway.read_query_prompt(self.system_prompt_path)
user_prompt = render_prompt(self.eval_prompt_path, args)
system_prompt = read_query_prompt(self.system_prompt_path)
evaluation = await LLMGateway.acreate_structured_output(
text_input=user_prompt,

View file

@@ -76,10 +76,10 @@ class LLMConfig(BaseSettings):
provider=self.baml_llm_provider,
options={
"model": self.baml_llm_model,
"temperature": self.baml_llm_temperature,
# "temperature": self.baml_llm_temperature,
"api_key": self.baml_llm_api_key,
"base_url": self.baml_llm_endpoint,
"api_version": self.baml_llm_api_version,
# "base_url": self.baml_llm_endpoint,
# "api_version": self.baml_llm_api_version,
},
)
# Sets the primary client

View file

@@ -68,35 +68,3 @@ class AnthropicAdapter(LLMInterface):
],
response_model=response_model,
)
def show_prompt(self, text_input: str, system_prompt: str) -> str:
"""
Format and display the prompt for a user query.
Parameters:
-----------
- text_input (str): The input text from the user, defaults to a placeholder if
empty.
- system_prompt (str): The path to the system prompt to be read and formatted.
Returns:
--------
- str: A formatted string displaying the system prompt and user input.
"""
if not text_input:
text_input = "No user input provided."
if not system_prompt:
raise MissingSystemPromptPathError()
system_prompt = LLMGateway.read_query_prompt(system_prompt)
formatted_prompt = (
f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
if system_prompt
else None
)
return formatted_prompt

View file

@@ -113,34 +113,3 @@ class GeminiAdapter(LLMInterface):
logger.error(f"Schema validation failed: {str(e)}")
logger.debug(f"Raw response: {e.raw_response}")
raise ValueError(f"Response failed schema validation: {str(e)}")
def show_prompt(self, text_input: str, system_prompt: str) -> str:
"""
Format and display the prompt for a user query.
Raises an MissingQueryParameterError if no system prompt is provided.
Parameters:
-----------
- text_input (str): The user input text to display along with the system prompt.
- system_prompt (str): The path or content of the system prompt to be read and
displayed.
Returns:
--------
- str: Returns a formatted string containing the system prompt and user input.
"""
if not text_input:
text_input = "No user input provided."
if not system_prompt:
raise MissingSystemPromptPathError()
system_prompt = LLMGateway.read_query_prompt(system_prompt)
formatted_prompt = (
f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
if system_prompt
else None
)
return formatted_prompt

View file

@@ -7,6 +7,7 @@ from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.infrastructure.llm.LLMGateway import LLMGateway
logger = get_logger("CodeRetriever")
@@ -41,7 +42,7 @@ class CodeRetriever(BaseRetriever):
f"Processing query with LLM: '{query[:100]}{'...' if len(query) > 100 else ''}'"
)
system_prompt = LLMGateway.read_query_prompt("codegraph_retriever_system.txt")
system_prompt = read_query_prompt("codegraph_retriever_system.txt")
try:
result = await LLMGateway.acreate_structured_output(

View file

@@ -1,9 +1,10 @@
from typing import Any, Optional, List, Tuple, Type
from typing import Any, Optional, List, Type
from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
logger = get_logger()
@@ -104,10 +105,10 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
if round_idx < max_iter:
valid_args = {"query": query, "answer": completion, "context": context}
valid_user_prompt = LLMGateway.render_prompt(
valid_user_prompt = render_prompt(
filename=self.validation_user_prompt_path, context=valid_args
)
valid_system_prompt = LLMGateway.read_query_prompt(
valid_system_prompt = read_query_prompt(
prompt_file_name=self.validation_system_prompt_path
)
@@ -117,10 +118,10 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
response_model=str,
)
followup_args = {"query": query, "answer": completion, "reasoning": reasoning}
followup_prompt = LLMGateway.render_prompt(
followup_prompt = render_prompt(
filename=self.followup_user_prompt_path, context=followup_args
)
followup_system = LLMGateway.read_query_prompt(
followup_system = read_query_prompt(
prompt_file_name=self.followup_system_prompt_path
)

View file

@@ -2,6 +2,7 @@ from typing import Any, Optional
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.prompts import render_prompt
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions import SearchTypeNotSupported
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
@@ -49,7 +50,7 @@ class NaturalLanguageRetriever(BaseRetriever):
async def _generate_cypher_query(self, query: str, edge_schemas, previous_attempts=None) -> str:
"""Generate a Cypher query using LLM based on natural language query and schema information."""
system_prompt = LLMGateway.render_prompt(
system_prompt = render_prompt(
self.system_prompt_path,
context={
"edge_schemas": edge_schemas,

View file

@@ -6,6 +6,7 @@ from operator import itemgetter
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.llm.prompts import render_prompt
from cognee.infrastructure.llm import LLMGateway
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.shared.logging_utils import get_logger
@@ -72,7 +73,7 @@ class TemporalRetriever(GraphCompletionRetriever):
else:
base_directory = None
system_prompt = LLMGateway.render_prompt(prompt_path, {}, base_directory=base_directory)
system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory)
interval = await LLMGateway.acreate_structured_output(query, system_prompt, QueryInterval)

View file

@@ -1,5 +1,6 @@
from typing import Optional
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
async def generate_completion(
@@ -12,10 +13,8 @@ async def generate_completion(
) -> str:
"""Generates a completion using LLM with given context and prompts."""
args = {"question": query, "context": context}
user_prompt = LLMGateway.render_prompt(user_prompt_path, args)
system_prompt = (
system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path)
)
user_prompt = render_prompt(user_prompt_path, args)
system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path)
if only_context:
return context
@@ -33,9 +32,7 @@ async def summarize_text(
system_prompt: str = None,
) -> str:
"""Summarizes text using LLM with the specified prompt."""
system_prompt = (
system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path)
)
system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path)
return await LLMGateway.acreate_structured_output(
text_input=text,

View file

@@ -7,7 +7,7 @@ from pydantic import BaseModel
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.engine.models import DataPoint
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.extraction import extract_categories
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
@@ -40,7 +40,7 @@ async def chunk_naive_llm_classifier(
return data_chunks
chunk_classifications = await asyncio.gather(
*[LLMGateway.extract_categories(chunk.text, classification_model) for chunk in data_chunks],
*[extract_categories(chunk.text, classification_model) for chunk in data_chunks],
)
classification_data_points = []

View file

@@ -4,6 +4,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.low_level import DataPoint
from cognee.infrastructure.llm.prompts import render_prompt
from cognee.infrastructure.llm import LLMGateway
from cognee.shared.logging_utils import get_logger
from cognee.modules.engine.models import NodeSet
@@ -106,8 +107,8 @@ async def add_rule_associations(
user_context = {"chat": data, "rules": existing_rules}
user_prompt = LLMGateway.render_prompt(user_prompt_location, context=user_context)
system_prompt = LLMGateway.render_prompt(system_prompt_location, context={})
user_prompt = render_prompt(user_prompt_location, context=user_context)
system_prompt = render_prompt(system_prompt_location, context={})
rule_list = await LLMGateway.acreate_structured_output(
text_input=user_prompt, system_prompt=system_prompt, response_model=RuleSet

View file

@@ -3,6 +3,7 @@ from typing import List
from pydantic import BaseModel
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor
from cognee.modules.engine.models import Entity
from cognee.modules.engine.models.EntityType import EntityType
@@ -50,8 +51,8 @@ class LLMEntityExtractor(BaseEntityExtractor):
try:
logger.info(f"Extracting entities from text: {text[:100]}...")
user_prompt = LLMGateway.render_prompt(self.user_prompt_template, {"text": text})
system_prompt = LLMGateway.read_query_prompt(self.system_prompt_template)
user_prompt = render_prompt(self.user_prompt_template, {"text": text})
system_prompt = read_query_prompt(self.system_prompt_template)
response = await LLMGateway.acreate_structured_output(
text_input=user_prompt,

View file

@@ -1,6 +1,7 @@
from typing import List, Tuple
from pydantic import BaseModel
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.root_dir import get_absolute_path
@@ -32,12 +33,12 @@ async def extract_content_nodes_and_relationship_names(
}
base_directory = get_absolute_path("./tasks/graph/cascade_extract/prompts")
text_input = LLMGateway.render_prompt(
text_input = render_prompt(
"extract_graph_relationship_names_prompt_input.txt",
context,
base_directory=base_directory,
)
system_prompt = LLMGateway.read_query_prompt(
system_prompt = read_query_prompt(
"extract_graph_relationship_names_prompt_system.txt", base_directory=base_directory
)
response = await LLMGateway.acreate_structured_output(

View file

@@ -1,5 +1,6 @@
from typing import List
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.shared.data_models import KnowledgeGraph
from cognee.root_dir import get_absolute_path
@@ -26,10 +27,10 @@ async def extract_edge_triplets(
}
base_directory = get_absolute_path("./tasks/graph/cascade_extract/prompts")
text_input = LLMGateway.render_prompt(
text_input = render_prompt(
"extract_graph_edge_triplets_prompt_input.txt", context, base_directory=base_directory
)
system_prompt = LLMGateway.read_query_prompt(
system_prompt = read_query_prompt(
"extract_graph_edge_triplets_prompt_system.txt", base_directory=base_directory
)
extracted_graph = await LLMGateway.acreate_structured_output(

View file

@@ -1,6 +1,7 @@
from typing import List
from pydantic import BaseModel
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.root_dir import get_absolute_path
@@ -24,10 +25,10 @@ async def extract_nodes(text: str, n_rounds: int = 2) -> List[str]:
"text": text,
}
base_directory = get_absolute_path("./tasks/graph/cascade_extract/prompts")
text_input = LLMGateway.render_prompt(
text_input = render_prompt(
"extract_graph_nodes_prompt_input.txt", context, base_directory=base_directory
)
system_prompt = LLMGateway.read_query_prompt(
system_prompt = read_query_prompt(
"extract_graph_nodes_prompt_system.txt", base_directory=base_directory
)
response = await LLMGateway.acreate_structured_output(

View file

@@ -2,7 +2,7 @@ import asyncio
from typing import Type, List
from pydantic import BaseModel
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.extraction import extract_content_graph
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.tasks.storage import add_data_points
@@ -18,7 +18,7 @@ async def extract_graph_from_code(
- Graph nodes are stored using the `add_data_points` function for later retrieval or analysis.
"""
chunk_graphs = await asyncio.gather(
*[LLMGateway.extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
*[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
)
for chunk_index, chunk in enumerate(data_chunks):

View file

@@ -3,7 +3,7 @@ from typing import AsyncGenerator, Union
from uuid import uuid5
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.extraction import extract_code_summary
from .models import CodeSummary
@@ -16,7 +16,7 @@ async def summarize_code(
code_data_points = [file for file in code_graph_nodes if hasattr(file, "source_code")]
file_summaries = await asyncio.gather(
*[LLMGateway.extract_code_summary(file.source_code) for file in code_data_points]
*[extract_code_summary(file.source_code) for file in code_data_points]
)
file_summaries_map = {

View file

@@ -12,6 +12,7 @@ from cognee.tasks.temporal_awareness.index_graphiti_objects import (
)
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.modules.users.methods import get_default_user
@@ -59,8 +60,8 @@ async def main():
"context": context,
}
user_prompt = LLMGateway.render_prompt("graph_context_for_question.txt", args)
system_prompt = LLMGateway.read_query_prompt("answer_simple_question_restricted.txt")
user_prompt = render_prompt("graph_context_for_question.txt", args)
system_prompt = read_query_prompt("answer_simple_question_restricted.txt")
computed_answer = await LLMGateway.acreate_structured_output(
text_input=user_prompt,