refactor: remove LLMGateway usage where not needed
This commit is contained in:
parent
89b51a244d
commit
f1144abc54
18 changed files with 46 additions and 101 deletions
|
|
@ -3,6 +3,7 @@ from pydantic import BaseModel
|
|||
from cognee.eval_framework.evaluation.base_eval_adapter import BaseEvalAdapter
|
||||
from cognee.eval_framework.eval_config import EvalConfig
|
||||
|
||||
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
|
||||
from cognee.infrastructure.llm import LLMGateway
|
||||
|
||||
|
||||
|
|
@ -25,8 +26,8 @@ class DirectLLMEvalAdapter(BaseEvalAdapter):
|
|||
) -> Dict[str, Any]:
|
||||
args = {"question": question, "answer": answer, "golden_answer": golden_answer}
|
||||
|
||||
user_prompt = LLMGateway.render_prompt(self.eval_prompt_path, args)
|
||||
system_prompt = LLMGateway.read_query_prompt(self.system_prompt_path)
|
||||
user_prompt = render_prompt(self.eval_prompt_path, args)
|
||||
system_prompt = read_query_prompt(self.system_prompt_path)
|
||||
|
||||
evaluation = await LLMGateway.acreate_structured_output(
|
||||
text_input=user_prompt,
|
||||
|
|
|
|||
|
|
@ -76,10 +76,10 @@ class LLMConfig(BaseSettings):
|
|||
provider=self.baml_llm_provider,
|
||||
options={
|
||||
"model": self.baml_llm_model,
|
||||
"temperature": self.baml_llm_temperature,
|
||||
# "temperature": self.baml_llm_temperature,
|
||||
"api_key": self.baml_llm_api_key,
|
||||
"base_url": self.baml_llm_endpoint,
|
||||
"api_version": self.baml_llm_api_version,
|
||||
# "base_url": self.baml_llm_endpoint,
|
||||
# "api_version": self.baml_llm_api_version,
|
||||
},
|
||||
)
|
||||
# Sets the primary client
|
||||
|
|
|
|||
|
|
@ -68,35 +68,3 @@ class AnthropicAdapter(LLMInterface):
|
|||
],
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
def show_prompt(self, text_input: str, system_prompt: str) -> str:
|
||||
"""
|
||||
Format and display the prompt for a user query.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- text_input (str): The input text from the user, defaults to a placeholder if
|
||||
empty.
|
||||
- system_prompt (str): The path to the system prompt to be read and formatted.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- str: A formatted string displaying the system prompt and user input.
|
||||
"""
|
||||
|
||||
if not text_input:
|
||||
text_input = "No user input provided."
|
||||
if not system_prompt:
|
||||
raise MissingSystemPromptPathError()
|
||||
|
||||
system_prompt = LLMGateway.read_query_prompt(system_prompt)
|
||||
|
||||
formatted_prompt = (
|
||||
f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
|
||||
if system_prompt
|
||||
else None
|
||||
)
|
||||
|
||||
return formatted_prompt
|
||||
|
|
|
|||
|
|
@ -113,34 +113,3 @@ class GeminiAdapter(LLMInterface):
|
|||
logger.error(f"Schema validation failed: {str(e)}")
|
||||
logger.debug(f"Raw response: {e.raw_response}")
|
||||
raise ValueError(f"Response failed schema validation: {str(e)}")
|
||||
|
||||
def show_prompt(self, text_input: str, system_prompt: str) -> str:
|
||||
"""
|
||||
Format and display the prompt for a user query.
|
||||
|
||||
Raises an MissingQueryParameterError if no system prompt is provided.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- text_input (str): The user input text to display along with the system prompt.
|
||||
- system_prompt (str): The path or content of the system prompt to be read and
|
||||
displayed.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- str: Returns a formatted string containing the system prompt and user input.
|
||||
"""
|
||||
if not text_input:
|
||||
text_input = "No user input provided."
|
||||
if not system_prompt:
|
||||
raise MissingSystemPromptPathError()
|
||||
system_prompt = LLMGateway.read_query_prompt(system_prompt)
|
||||
|
||||
formatted_prompt = (
|
||||
f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
|
||||
if system_prompt
|
||||
else None
|
||||
)
|
||||
return formatted_prompt
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from cognee.shared.logging_utils import get_logger
|
|||
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
|
||||
logger = get_logger("CodeRetriever")
|
||||
|
|
@ -41,7 +42,7 @@ class CodeRetriever(BaseRetriever):
|
|||
f"Processing query with LLM: '{query[:100]}{'...' if len(query) > 100 else ''}'"
|
||||
)
|
||||
|
||||
system_prompt = LLMGateway.read_query_prompt("codegraph_retriever_system.txt")
|
||||
system_prompt = read_query_prompt("codegraph_retriever_system.txt")
|
||||
|
||||
try:
|
||||
result = await LLMGateway.acreate_structured_output(
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Any, Optional, List, Tuple, Type
|
||||
from typing import Any, Optional, List, Type
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
from cognee.modules.retrieval.utils.completion import generate_completion
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
|
@ -104,10 +105,10 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
|
||||
if round_idx < max_iter:
|
||||
valid_args = {"query": query, "answer": completion, "context": context}
|
||||
valid_user_prompt = LLMGateway.render_prompt(
|
||||
valid_user_prompt = render_prompt(
|
||||
filename=self.validation_user_prompt_path, context=valid_args
|
||||
)
|
||||
valid_system_prompt = LLMGateway.read_query_prompt(
|
||||
valid_system_prompt = read_query_prompt(
|
||||
prompt_file_name=self.validation_system_prompt_path
|
||||
)
|
||||
|
||||
|
|
@ -117,10 +118,10 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
response_model=str,
|
||||
)
|
||||
followup_args = {"query": query, "answer": completion, "reasoning": reasoning}
|
||||
followup_prompt = LLMGateway.render_prompt(
|
||||
followup_prompt = render_prompt(
|
||||
filename=self.followup_user_prompt_path, context=followup_args
|
||||
)
|
||||
followup_system = LLMGateway.read_query_prompt(
|
||||
followup_system = read_query_prompt(
|
||||
prompt_file_name=self.followup_system_prompt_path
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from typing import Any, Optional
|
|||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
from cognee.infrastructure.llm.prompts import render_prompt
|
||||
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||
from cognee.modules.retrieval.exceptions import SearchTypeNotSupported
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
||||
|
|
@ -49,7 +50,7 @@ class NaturalLanguageRetriever(BaseRetriever):
|
|||
|
||||
async def _generate_cypher_query(self, query: str, edge_schemas, previous_attempts=None) -> str:
|
||||
"""Generate a Cypher query using LLM based on natural language query and schema information."""
|
||||
system_prompt = LLMGateway.render_prompt(
|
||||
system_prompt = render_prompt(
|
||||
self.system_prompt_path,
|
||||
context={
|
||||
"edge_schemas": edge_schemas,
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from operator import itemgetter
|
|||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.modules.retrieval.utils.completion import generate_completion
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.llm.prompts import render_prompt
|
||||
from cognee.infrastructure.llm import LLMGateway
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
|
@ -72,7 +73,7 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
else:
|
||||
base_directory = None
|
||||
|
||||
system_prompt = LLMGateway.render_prompt(prompt_path, {}, base_directory=base_directory)
|
||||
system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory)
|
||||
|
||||
interval = await LLMGateway.acreate_structured_output(query, system_prompt, QueryInterval)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from typing import Optional
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
|
||||
|
||||
|
||||
async def generate_completion(
|
||||
|
|
@ -12,10 +13,8 @@ async def generate_completion(
|
|||
) -> str:
|
||||
"""Generates a completion using LLM with given context and prompts."""
|
||||
args = {"question": query, "context": context}
|
||||
user_prompt = LLMGateway.render_prompt(user_prompt_path, args)
|
||||
system_prompt = (
|
||||
system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path)
|
||||
)
|
||||
user_prompt = render_prompt(user_prompt_path, args)
|
||||
system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path)
|
||||
|
||||
if only_context:
|
||||
return context
|
||||
|
|
@ -33,9 +32,7 @@ async def summarize_text(
|
|||
system_prompt: str = None,
|
||||
) -> str:
|
||||
"""Summarizes text using LLM with the specified prompt."""
|
||||
system_prompt = (
|
||||
system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path)
|
||||
)
|
||||
system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path)
|
||||
|
||||
return await LLMGateway.acreate_structured_output(
|
||||
text_input=text,
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from pydantic import BaseModel
|
|||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.engine.models import DataPoint
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
from cognee.infrastructure.llm.extraction import extract_categories
|
||||
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
||||
|
||||
|
||||
|
|
@ -40,7 +40,7 @@ async def chunk_naive_llm_classifier(
|
|||
return data_chunks
|
||||
|
||||
chunk_classifications = await asyncio.gather(
|
||||
*[LLMGateway.extract_categories(chunk.text, classification_model) for chunk in data_chunks],
|
||||
*[extract_categories(chunk.text, classification_model) for chunk in data_chunks],
|
||||
)
|
||||
|
||||
classification_data_points = []
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine
|
|||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
from cognee.low_level import DataPoint
|
||||
from cognee.infrastructure.llm.prompts import render_prompt
|
||||
from cognee.infrastructure.llm import LLMGateway
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.engine.models import NodeSet
|
||||
|
|
@ -106,8 +107,8 @@ async def add_rule_associations(
|
|||
|
||||
user_context = {"chat": data, "rules": existing_rules}
|
||||
|
||||
user_prompt = LLMGateway.render_prompt(user_prompt_location, context=user_context)
|
||||
system_prompt = LLMGateway.render_prompt(system_prompt_location, context={})
|
||||
user_prompt = render_prompt(user_prompt_location, context=user_context)
|
||||
system_prompt = render_prompt(system_prompt_location, context={})
|
||||
|
||||
rule_list = await LLMGateway.acreate_structured_output(
|
||||
text_input=user_prompt, system_prompt=system_prompt, response_model=RuleSet
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from typing import List
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
|
||||
from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor
|
||||
from cognee.modules.engine.models import Entity
|
||||
from cognee.modules.engine.models.EntityType import EntityType
|
||||
|
|
@ -50,8 +51,8 @@ class LLMEntityExtractor(BaseEntityExtractor):
|
|||
try:
|
||||
logger.info(f"Extracting entities from text: {text[:100]}...")
|
||||
|
||||
user_prompt = LLMGateway.render_prompt(self.user_prompt_template, {"text": text})
|
||||
system_prompt = LLMGateway.read_query_prompt(self.system_prompt_template)
|
||||
user_prompt = render_prompt(self.user_prompt_template, {"text": text})
|
||||
system_prompt = read_query_prompt(self.system_prompt_template)
|
||||
|
||||
response = await LLMGateway.acreate_structured_output(
|
||||
text_input=user_prompt,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from typing import List, Tuple
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
from cognee.root_dir import get_absolute_path
|
||||
|
||||
|
|
@ -32,12 +33,12 @@ async def extract_content_nodes_and_relationship_names(
|
|||
}
|
||||
|
||||
base_directory = get_absolute_path("./tasks/graph/cascade_extract/prompts")
|
||||
text_input = LLMGateway.render_prompt(
|
||||
text_input = render_prompt(
|
||||
"extract_graph_relationship_names_prompt_input.txt",
|
||||
context,
|
||||
base_directory=base_directory,
|
||||
)
|
||||
system_prompt = LLMGateway.read_query_prompt(
|
||||
system_prompt = read_query_prompt(
|
||||
"extract_graph_relationship_names_prompt_system.txt", base_directory=base_directory
|
||||
)
|
||||
response = await LLMGateway.acreate_structured_output(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from typing import List
|
||||
|
||||
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
from cognee.root_dir import get_absolute_path
|
||||
|
|
@ -26,10 +27,10 @@ async def extract_edge_triplets(
|
|||
}
|
||||
|
||||
base_directory = get_absolute_path("./tasks/graph/cascade_extract/prompts")
|
||||
text_input = LLMGateway.render_prompt(
|
||||
text_input = render_prompt(
|
||||
"extract_graph_edge_triplets_prompt_input.txt", context, base_directory=base_directory
|
||||
)
|
||||
system_prompt = LLMGateway.read_query_prompt(
|
||||
system_prompt = read_query_prompt(
|
||||
"extract_graph_edge_triplets_prompt_system.txt", base_directory=base_directory
|
||||
)
|
||||
extracted_graph = await LLMGateway.acreate_structured_output(
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from typing import List
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
from cognee.root_dir import get_absolute_path
|
||||
|
||||
|
|
@ -24,10 +25,10 @@ async def extract_nodes(text: str, n_rounds: int = 2) -> List[str]:
|
|||
"text": text,
|
||||
}
|
||||
base_directory = get_absolute_path("./tasks/graph/cascade_extract/prompts")
|
||||
text_input = LLMGateway.render_prompt(
|
||||
text_input = render_prompt(
|
||||
"extract_graph_nodes_prompt_input.txt", context, base_directory=base_directory
|
||||
)
|
||||
system_prompt = LLMGateway.read_query_prompt(
|
||||
system_prompt = read_query_prompt(
|
||||
"extract_graph_nodes_prompt_system.txt", base_directory=base_directory
|
||||
)
|
||||
response = await LLMGateway.acreate_structured_output(
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import asyncio
|
|||
from typing import Type, List
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
from cognee.infrastructure.llm.extraction import extract_content_graph
|
||||
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
||||
from cognee.tasks.storage import add_data_points
|
||||
|
||||
|
|
@ -18,7 +18,7 @@ async def extract_graph_from_code(
|
|||
- Graph nodes are stored using the `add_data_points` function for later retrieval or analysis.
|
||||
"""
|
||||
chunk_graphs = await asyncio.gather(
|
||||
*[LLMGateway.extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
|
||||
*[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
|
||||
)
|
||||
|
||||
for chunk_index, chunk in enumerate(data_chunks):
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import AsyncGenerator, Union
|
|||
from uuid import uuid5
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
from cognee.infrastructure.llm.extraction import extract_code_summary
|
||||
from .models import CodeSummary
|
||||
|
||||
|
||||
|
|
@ -16,7 +16,7 @@ async def summarize_code(
|
|||
code_data_points = [file for file in code_graph_nodes if hasattr(file, "source_code")]
|
||||
|
||||
file_summaries = await asyncio.gather(
|
||||
*[LLMGateway.extract_code_summary(file.source_code) for file in code_data_points]
|
||||
*[extract_code_summary(file.source_code) for file in code_data_points]
|
||||
)
|
||||
|
||||
file_summaries_map = {
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from cognee.tasks.temporal_awareness.index_graphiti_objects import (
|
|||
)
|
||||
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
|
||||
|
|
@ -59,8 +60,8 @@ async def main():
|
|||
"context": context,
|
||||
}
|
||||
|
||||
user_prompt = LLMGateway.render_prompt("graph_context_for_question.txt", args)
|
||||
system_prompt = LLMGateway.read_query_prompt("answer_simple_question_restricted.txt")
|
||||
user_prompt = render_prompt("graph_context_for_question.txt", args)
|
||||
system_prompt = read_query_prompt("answer_simple_question_restricted.txt")
|
||||
|
||||
computed_answer = await LLMGateway.acreate_structured_output(
|
||||
text_input=user_prompt,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue