From 7761b702299ae12ca7ae408dd4e6f262a1039fc7 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Tue, 5 Aug 2025 18:49:32 +0200 Subject: [PATCH] feat: Add adapter for structured output and LLM usage --- .../codingagents/coding_rule_associations.py | 17 +- cognee/api/v1/config/config.py | 2 +- .../responses/routers/get_responses_router.py | 2 +- cognee/base_config.py | 4 +- .../evaluation/direct_llm_eval_adapter.py | 16 +- .../embeddings/FastembedEmbeddingEngine.py | 2 +- .../embeddings/LiteLLMEmbeddingEngine.py | 10 +- .../embeddings/OllamaEmbeddingEngine.py | 4 +- .../vector/embeddings/get_embedding_engine.py | 2 +- cognee/infrastructure/llm/__init__.py | 13 +- .../llm/generic_llm_api/adapter.py | 156 ----- .../baml_src/__init__.py | 0 .../baml_src/config.py | 185 ------ .../baml_src/extract_categories.baml | 109 ---- .../baml_src/extract_content_graph.baml | 343 ----------- .../baml_src/extraction/__init__.py | 2 - .../baml_src/extraction/extract_categories.py | 114 ---- .../baml_src/extraction/extract_summary.py | 114 ---- .../extraction/knowledge_graph/__init__.py | 0 .../knowledge_graph/extract_content_graph.py | 49 -- .../baml_src/generators.baml | 18 - .../extraction/__init__.py | 1 + .../knowledge_graph/extract_content_graph.py | 2 +- .../llitellm_instructor/llm/config.py | 176 ------ .../llm/embedding_rate_limiter.py | 550 ------------------ .../llm/generic_llm_api/adapter.py | 110 +++- .../llitellm_instructor/llm/get_llm_client.py | 2 +- .../llitellm_instructor/llm/rate_limiter.py | 4 +- .../llitellm_instructor/llm/utils.py | 107 ---- .../document_types/AudioDocument.py | 6 +- .../document_types/ImageDocument.py | 6 +- .../modules/pipelines/operations/pipeline.py | 2 +- cognee/modules/retrieval/code_retriever.py | 12 +- .../graph_completion_cot_retriever.py | 22 +- .../retrieval/natural_language_retriever.py | 12 +- cognee/modules/retrieval/utils/completion.py | 20 +- .../utils/description_to_codepart_search.py | 7 +- cognee/shared/data_models.py | 2 +- .../chunk_naive_llm_classifier.py | 6 +- .../entity_extractors/llm_entity_extractor.py | 15 +- ...ct_content_nodes_and_relationship_names.py | 15 +- .../utils/extract_edge_triplets.py | 16 +- .../cascade_extract/utils/extract_nodes.py | 15 +- cognee/tasks/graph/extract_graph_from_code.py | 17 +- cognee/tasks/graph/extract_graph_from_data.py | 17 +- cognee/tasks/graph/infer_data_ontology.py | 13 +- cognee/tasks/summarization/summarize_code.py | 19 +- cognee/tasks/summarization/summarize_text.py | 19 +- .../databases/vector/__init__.py | 1 - .../infrastructure/mock_embedding_engine.py | 2 +- .../test_embedding_rate_limiting_realistic.py | 4 +- .../test_rate_limiting_realistic.py | 2 +- .../agentic_reasoning_procurement_example.py | 6 +- examples/python/graphiti_example.py | 15 +- 54 files changed, 192 insertions(+), 2193 deletions(-) delete mode 100644 cognee/infrastructure/llm/generic_llm_api/adapter.py delete mode 100644 cognee/infrastructure/llm/structured_output_framework/baml_src/__init__.py delete mode 100644 cognee/infrastructure/llm/structured_output_framework/baml_src/config.py delete mode 100644 cognee/infrastructure/llm/structured_output_framework/baml_src/extract_categories.baml delete mode 100644 cognee/infrastructure/llm/structured_output_framework/baml_src/extract_content_graph.baml delete mode 100644 cognee/infrastructure/llm/structured_output_framework/baml_src/extraction/__init__.py delete mode 100644 cognee/infrastructure/llm/structured_output_framework/baml_src/extraction/extract_categories.py delete mode 100644 cognee/infrastructure/llm/structured_output_framework/baml_src/extraction/extract_summary.py delete mode 100644 cognee/infrastructure/llm/structured_output_framework/baml_src/extraction/knowledge_graph/__init__.py delete mode 100644 cognee/infrastructure/llm/structured_output_framework/baml_src/extraction/knowledge_graph/extract_content_graph.py delete mode 100644 cognee/infrastructure/llm/structured_output_framework/baml_src/generators.baml delete mode 100644 cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/llm/config.py delete mode 100644 cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/llm/embedding_rate_limiter.py delete mode 100644 cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/llm/utils.py diff --git a/cognee-mcp/src/codingagents/coding_rule_associations.py b/cognee-mcp/src/codingagents/coding_rule_associations.py index 952ddfd66..19d94b9f9 100644 --- a/cognee-mcp/src/codingagents/coding_rule_associations.py +++ b/cognee-mcp/src/codingagents/coding_rule_associations.py @@ -2,13 +2,9 @@ from uuid import NAMESPACE_OID, uuid5 from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import ( - render_prompt, -) + from cognee.low_level import DataPoint -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import ( - get_llm_client, -) +from cognee.infrastructure.llm import LLMAdapter from cognee.shared.logging_utils import get_logger from cognee.modules.engine.models import NodeSet from cognee.tasks.storage import add_data_points, index_graph_edges @@ -95,16 +91,17 @@ async def get_origin_edges(data: str, rules: List[Rule]) -> list[Any]: async def add_rule_associations(data: str, rules_nodeset_name: str): - llm_client = get_llm_client() graph_engine = await get_graph_engine() existing_rules = await get_existing_rules(rules_nodeset_name=rules_nodeset_name) user_context = {"chat": data, "rules": existing_rules} - user_prompt = render_prompt("coding_rule_association_agent_user.txt", context=user_context) - system_prompt = render_prompt("coding_rule_association_agent_system.txt", context={}) + user_prompt = LLMAdapter.render_prompt( + "coding_rule_association_agent_user.txt", context=user_context + ) + system_prompt = LLMAdapter.render_prompt("coding_rule_association_agent_system.txt", context={}) - rule_list = await llm_client.acreate_structured_output( + rule_list = await LLMAdapter.acreate_structured_output( text_input=user_prompt, system_prompt=system_prompt, response_model=RuleSet ) diff --git a/cognee/api/v1/config/config.py b/cognee/api/v1/config/config.py index 0628ca154..9970b7471 100644 --- a/cognee/api/v1/config/config.py +++ b/cognee/api/v1/config/config.py @@ -7,7 +7,7 @@ from cognee.modules.cognify.config import get_cognify_config from cognee.infrastructure.data.chunking.config import get_chunk_config from cognee.infrastructure.databases.vector import get_vectordb_config from cognee.infrastructure.databases.graph.config import get_graph_config -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.config import ( +from cognee.infrastructure.llm.config import ( get_llm_config, ) from cognee.infrastructure.databases.relational import get_relational_config, get_migration_config diff --git a/cognee/api/v1/responses/routers/get_responses_router.py b/cognee/api/v1/responses/routers/get_responses_router.py index d196ddc05..cf1f003c0 100644 --- a/cognee/api/v1/responses/routers/get_responses_router.py +++ b/cognee/api/v1/responses/routers/get_responses_router.py @@ -17,7 +17,7 @@ from cognee.api.v1.responses.models import ( ) from cognee.api.v1.responses.dispatch_function import dispatch_function from cognee.api.v1.responses.default_tools import DEFAULT_TOOLS -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.config import ( +from cognee.infrastructure.llm.config import ( get_llm_config, ) from cognee.modules.users.models import User diff --git a/cognee/base_config.py b/cognee/base_config.py index b306122c4..748683066 100644 --- a/cognee/base_config.py +++ b/cognee/base_config.py @@ -15,9 +15,7 @@ class BaseConfig(BaseSettings): langfuse_host: Optional[str] = os.getenv("LANGFUSE_HOST") default_user_email: Optional[str] = os.getenv("DEFAULT_USER_EMAIL") default_user_password: Optional[str] = os.getenv("DEFAULT_USER_PASSWORD") - structured_output_framework: str = os.getenv( - "STRUCTURED_OUTPUT_FRAMEWORK", "llitellm_instructor" - ) + model_config = SettingsConfigDict(env_file=".env", extra="allow") def to_dict(self) -> dict: diff --git a/cognee/eval_framework/evaluation/direct_llm_eval_adapter.py b/cognee/eval_framework/evaluation/direct_llm_eval_adapter.py index 2359c08ec..d9c7b9851 100644 --- a/cognee/eval_framework/evaluation/direct_llm_eval_adapter.py +++ b/cognee/eval_framework/evaluation/direct_llm_eval_adapter.py @@ -1,15 +1,10 @@ from typing import Any, Dict, List from pydantic import BaseModel -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import ( - get_llm_client, -) from cognee.eval_framework.evaluation.base_eval_adapter import BaseEvalAdapter -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import ( - read_query_prompt, - render_prompt, -) from cognee.eval_framework.eval_config import EvalConfig +from cognee.infrastructure.llm import LLMAdapter + class CorrectnessEvaluation(BaseModel): """Response model containing evaluation score and explanation.""" @@ -24,17 +19,16 @@ class DirectLLMEvalAdapter(BaseEvalAdapter): config = EvalConfig() self.system_prompt_path = config.direct_llm_system_prompt self.eval_prompt_path = config.direct_llm_eval_prompt - self.llm_client = get_llm_client() async def evaluate_correctness( self, question: str, answer: str, golden_answer: str ) -> Dict[str, Any]: args = {"question": question, "answer": answer, "golden_answer": golden_answer} - user_prompt = render_prompt(self.eval_prompt_path, args) - system_prompt = read_query_prompt(self.system_prompt_path) + user_prompt = LLMAdapter.render_prompt(self.eval_prompt_path, args) + system_prompt = LLMAdapter.read_query_prompt(self.system_prompt_path) - evaluation = await self.llm_client.acreate_structured_output( + evaluation = await LLMAdapter.acreate_structured_output( text_input=user_prompt, system_prompt=system_prompt, response_model=CorrectnessEvaluation, diff --git a/cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py index 5b45c96b4..f7925a3d2 100644 --- a/cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py @@ -5,7 +5,7 @@ import litellm import os from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.tokenizer.TikToken import ( +from cognee.infrastructure.llm.tokenizer.TikToken import ( TikTokenTokenizer, ) diff --git a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py index 94257699e..c6e94df91 100644 --- a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py @@ -7,19 +7,19 @@ import litellm import os from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.tokenizer.Gemini import ( +from cognee.infrastructure.llm.tokenizer.Gemini import ( GeminiTokenizer, ) -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.tokenizer.HuggingFace import ( +from cognee.infrastructure.llm.tokenizer.HuggingFace import ( HuggingFaceTokenizer, ) -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.tokenizer import ( +from cognee.infrastructure.llm.tokenizer.Mistral import ( MistralTokenizer, ) -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.tokenizer.TikToken import ( +from cognee.infrastructure.llm.tokenizer.TikToken import ( TikTokenTokenizer, ) -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.embedding_rate_limiter import ( +from cognee.infrastructure.databases.vector.embeddings.embedding_rate_limiter import ( embedding_rate_limit_async, embedding_sleep_and_retry_async, ) diff --git a/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py index 24cf58d2b..bfb24a2d3 100644 --- a/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py @@ -7,10 +7,10 @@ import os import aiohttp.http_exceptions from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.tokenizer.HuggingFace import ( +from cognee.infrastructure.llm.tokenizer.HuggingFace import ( HuggingFaceTokenizer, ) -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.embedding_rate_limiter import ( +from cognee.infrastructure.databases.vector.embeddings.embedding_rate_limiter import ( embedding_rate_limit_async, embedding_sleep_and_retry_async, ) diff --git a/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py b/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py index 4fa1473bf..d250525a3 100644 --- a/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py +++ b/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py @@ -1,5 +1,5 @@ from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.config import ( +from cognee.infrastructure.llm.config import ( get_llm_config, ) from .EmbeddingEngine import EmbeddingEngine diff --git a/cognee/infrastructure/llm/__init__.py b/cognee/infrastructure/llm/__init__.py index e1b5135fd..e1dd2550d 100644 --- a/cognee/infrastructure/llm/__init__.py +++ b/cognee/infrastructure/llm/__init__.py @@ -1,15 +1,14 @@ -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.config import ( +from cognee.infrastructure.llm.config import ( get_llm_config, ) -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.utils import ( +from cognee.infrastructure.llm.utils import ( get_max_chunk_tokens, ) -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.utils import ( +from cognee.infrastructure.llm.utils import ( test_llm_connection, ) -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.utils import ( +from cognee.infrastructure.llm.utils import ( test_embedding_connection, ) -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm import ( - rate_limiter, -) + +from LLMAdapter import LLMAdapter diff --git a/cognee/infrastructure/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/generic_llm_api/adapter.py deleted file mode 100644 index b97d2342c..000000000 --- a/cognee/infrastructure/llm/generic_llm_api/adapter.py +++ /dev/null @@ -1,156 +0,0 @@ -"""Adapter for Generic API LLM provider API""" - -import logging -import litellm -import instructor -from typing import Type -from pydantic import BaseModel -from openai import ContentFilterFinishReasonError -from litellm.exceptions import ContentPolicyViolationError -from instructor.exceptions import InstructorRetryException - -from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.llm_interface import ( - LLMInterface, -) -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.rate_limiter import ( - rate_limit_async, - sleep_and_retry_async, -) - - -class GenericAPIAdapter(LLMInterface): - """ - Adapter for Generic API LLM provider API. - - This class initializes the API adapter with necessary credentials and configurations for - interacting with a language model. It provides methods for creating structured outputs - based on user input and system prompts. - - Public methods: - - acreate_structured_output(text_input: str, system_prompt: str, response_model: - Type[BaseModel]) -> BaseModel - """ - - name: str - model: str - api_key: str - - def __init__( - self, - endpoint, - api_key: str, - model: str, - name: str, - max_tokens: int, - fallback_model: str = None, - fallback_api_key: str = None, - fallback_endpoint: str = None, - ): - self.name = name - self.model = model - self.api_key = api_key - self.endpoint = endpoint - self.max_tokens = max_tokens - - self.fallback_model = fallback_model - self.fallback_api_key = fallback_api_key - self.fallback_endpoint = fallback_endpoint - - self.aclient = instructor.from_litellm( - litellm.acompletion, mode=instructor.Mode.JSON, api_key=api_key - ) - - @sleep_and_retry_async() - @rate_limit_async - async def acreate_structured_output( - self, text_input: str, system_prompt: str, response_model: Type[BaseModel] - ) -> BaseModel: - """ - Generate a response from a user query. - - This asynchronous method sends a user query and a system prompt to a language model and - retrieves the generated response. It handles API communication and retries up to a - specified limit in case of request failures. - - Parameters: - ----------- - - - text_input (str): The input text from the user to generate a response for. - - system_prompt (str): A prompt that provides context or instructions for the - response generation. - - response_model (Type[BaseModel]): A Pydantic model that defines the structure of - the expected response. - - Returns: - -------- - - - BaseModel: An instance of the specified response model containing the structured - output from the language model. - """ - - try: - return await self.aclient.chat.completions.create( - model=self.model, - messages=[ - { - "role": "user", - "content": f"""{text_input}""", - }, - { - "role": "system", - "content": system_prompt, - }, - ], - max_retries=5, - api_base=self.endpoint, - response_model=response_model, - ) - except ( - ContentFilterFinishReasonError, - ContentPolicyViolationError, - InstructorRetryException, - ) as error: - if ( - isinstance(error, InstructorRetryException) - and "content management policy" not in str(error).lower() - ): - raise error - - if not (self.fallback_model and self.fallback_api_key and self.fallback_endpoint): - raise ContentPolicyFilterError( - f"The provided input contains content that is not aligned with our content policy: {text_input}" - ) - - try: - return await self.aclient.chat.completions.create( - model=self.fallback_model, - messages=[ - { - "role": "user", - "content": f"""{text_input}""", - }, - { - "role": "system", - "content": system_prompt, - }, - ], - max_retries=5, - api_key=self.fallback_api_key, - api_base=self.fallback_endpoint, - response_model=response_model, - ) - except ( - ContentFilterFinishReasonError, - ContentPolicyViolationError, - InstructorRetryException, - ) as error: - if ( - isinstance(error, InstructorRetryException) - and "content management policy" not in str(error).lower() - ): - raise error - else: - raise ContentPolicyFilterError( - f"The provided input contains content that is not aligned with our content policy: {text_input}" - ) diff --git a/cognee/infrastructure/llm/structured_output_framework/baml_src/__init__.py b/cognee/infrastructure/llm/structured_output_framework/baml_src/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/cognee/infrastructure/llm/structured_output_framework/baml_src/config.py b/cognee/infrastructure/llm/structured_output_framework/baml_src/config.py deleted file mode 100644 index 2a305cf9e..000000000 --- a/cognee/infrastructure/llm/structured_output_framework/baml_src/config.py +++ /dev/null @@ -1,185 +0,0 @@ -import os -from typing import Optional, ClassVar -from functools import lru_cache -from pydantic_settings import BaseSettings, SettingsConfigDict -from pydantic import model_validator -from baml_py import ClientRegistry - - -class LLMConfig(BaseSettings): - """ - Configuration settings for the LLM (Large Language Model) provider and related options. - - Public instance variables include: - - llm_provider - - llm_model - - llm_endpoint - - llm_api_key - - llm_api_version - - llm_temperature - - llm_streaming - - llm_max_tokens - - transcription_model - - graph_prompt_path - - llm_rate_limit_enabled - - llm_rate_limit_requests - - llm_rate_limit_interval - - embedding_rate_limit_enabled - - embedding_rate_limit_requests - - embedding_rate_limit_interval - - Public methods include: - - ensure_env_vars_for_ollama - - to_dict - """ - - llm_provider: str = "openai" - llm_model: str = "gpt-4o-mini" - llm_endpoint: str = "" - llm_api_key: Optional[str] = None - llm_api_version: Optional[str] = None - llm_temperature: float = 0.0 - llm_streaming: bool = False - llm_max_tokens: int = 16384 - transcription_model: str = "whisper-1" - graph_prompt_path: str = "generate_graph_prompt.txt" - llm_rate_limit_enabled: bool = False - llm_rate_limit_requests: int = 60 - llm_rate_limit_interval: int = 60 # in seconds (default is 60 requests per minute) - embedding_rate_limit_enabled: bool = False - embedding_rate_limit_requests: int = 60 - embedding_rate_limit_interval: int = 60 # in seconds (default is 60 requests per minute) - baml_registry: ClassVar[ClientRegistry] = ClientRegistry() - - model_config = SettingsConfigDict(env_file=".env", extra="allow") - - def model_post_init(self, __context) -> None: - """Initialize the BAML registry after the model is created.""" - self.baml_registry.add_llm_client( - name=self.llm_provider, - provider=self.llm_provider, - options={ - "model": self.llm_model, - "temperature": self.llm_temperature, - "api_key": self.llm_api_key, - }, - ) - # Sets the primary client - self.baml_registry.set_primary(self.llm_provider) - - @model_validator(mode="after") - def ensure_env_vars_for_ollama(self) -> "LLMConfig": - """ - Validate required environment variables for the 'ollama' LLM provider. - - Raises ValueError if some required environment variables are set without the others. - Only checks are performed when 'llm_provider' is set to 'ollama'. - - Returns: - -------- - - - 'LLMConfig': The instance of LLMConfig after validation. - """ - - if self.llm_provider != "ollama": - # Skip checks unless provider is "ollama" - return self - - def is_env_set(var_name: str) -> bool: - """ - Check if a given environment variable is set and non-empty. - - Parameters: - ----------- - - - var_name (str): The name of the environment variable to check. - - Returns: - -------- - - - bool: True if the environment variable exists and is not empty, otherwise False. - """ - val = os.environ.get(var_name) - return val is not None and val.strip() != "" - - # - # 1. Check LLM environment variables - # - llm_env_vars = { - "LLM_MODEL": is_env_set("LLM_MODEL"), - "LLM_ENDPOINT": is_env_set("LLM_ENDPOINT"), - "LLM_API_KEY": is_env_set("LLM_API_KEY"), - } - if any(llm_env_vars.values()) and not all(llm_env_vars.values()): - missing_llm = [key for key, is_set in llm_env_vars.items() if not is_set] - raise ValueError( - "You have set some but not all of the required environment variables " - f"for LLM usage (LLM_MODEL, LLM_ENDPOINT, LLM_API_KEY). Missing: {missing_llm}" - ) - - # - # 2. Check embedding environment variables - # - embedding_env_vars = { - "EMBEDDING_PROVIDER": is_env_set("EMBEDDING_PROVIDER"), - "EMBEDDING_MODEL": is_env_set("EMBEDDING_MODEL"), - "EMBEDDING_DIMENSIONS": is_env_set("EMBEDDING_DIMENSIONS"), - "HUGGINGFACE_TOKENIZER": is_env_set("HUGGINGFACE_TOKENIZER"), - } - if any(embedding_env_vars.values()) and not all(embedding_env_vars.values()): - missing_embed = [key for key, is_set in embedding_env_vars.items() if not is_set] - raise ValueError( - "You have set some but not all of the required environment variables " - "for embeddings (EMBEDDING_PROVIDER, EMBEDDING_MODEL, " - "EMBEDDING_DIMENSIONS, HUGGINGFACE_TOKENIZER). Missing: " - f"{missing_embed}" - ) - - return self - - def to_dict(self) -> dict: - """ - Convert the LLMConfig instance into a dictionary representation. - - Returns: - -------- - - - dict: A dictionary containing the configuration settings of the LLMConfig - instance. - """ - return { - "provider": self.llm_provider, - "model": self.llm_model, - "endpoint": self.llm_endpoint, - "api_key": self.llm_api_key, - "api_version": self.llm_api_version, - "temperature": self.llm_temperature, - "streaming": self.llm_streaming, - "max_tokens": self.llm_max_tokens, - "transcription_model": self.transcription_model, - "graph_prompt_path": self.graph_prompt_path, - "rate_limit_enabled": self.llm_rate_limit_enabled, - "rate_limit_requests": self.llm_rate_limit_requests, - "rate_limit_interval": self.llm_rate_limit_interval, - "embedding_rate_limit_enabled": self.embedding_rate_limit_enabled, - "embedding_rate_limit_requests": self.embedding_rate_limit_requests, - "embedding_rate_limit_interval": self.embedding_rate_limit_interval, - } - - -@lru_cache -def get_llm_config(): - """ - Retrieve and cache the LLM configuration. - - This function returns an instance of the LLMConfig class. It leverages - caching to ensure that repeated calls do not create new instances, - but instead return the already created configuration object. - - Returns: - -------- - - - LLMConfig: An instance of the LLMConfig class containing the configuration for the - LLM. - """ - return LLMConfig() diff --git a/cognee/infrastructure/llm/structured_output_framework/baml_src/extract_categories.baml b/cognee/infrastructure/llm/structured_output_framework/baml_src/extract_categories.baml deleted file mode 100644 index c718b754e..000000000 --- a/cognee/infrastructure/llm/structured_output_framework/baml_src/extract_categories.baml +++ /dev/null @@ -1,109 +0,0 @@ -// Content classification data models - matching shared/data_models.py -class TextContent { - type string - subclass string[] -} - -class AudioContent { - type string - subclass string[] -} - -class ImageContent { - type string - subclass string[] -} - -class VideoContent { - type string - subclass string[] -} - -class MultimediaContent { - type string - subclass string[] -} - -class Model3DContent { - type string - subclass string[] -} - -class ProceduralContent { - type string - subclass string[] -} - -class ContentLabel { - content_type "text" | "audio" | "image" | "video" | "multimedia" | "3d_model" | "procedural" - type string - subclass string[] -} - -class DefaultContentPrediction { - label ContentLabel -} - -// Content classification prompt template -template_string ClassifyContentPrompt() #" - You are a classification engine and should classify content. Make sure to use one of the existing classification options and not invent your own. - - Classify the content into one of these main categories and their relevant subclasses: - - **TEXT CONTENT** (content_type: "text"): - - type: "TEXTUAL_DOCUMENTS_USED_FOR_GENERAL_PURPOSES" - - subclass options: ["Articles, essays, and reports", "Books and manuscripts", "News stories and blog posts", "Research papers and academic publications", "Social media posts and comments", "Website content and product descriptions", "Personal narratives and stories", "Spreadsheets and tables", "Forms and surveys", "Databases and CSV files", "Source code in various programming languages", "Shell commands and scripts", "Markup languages (HTML, XML)", "Stylesheets (CSS) and configuration files (YAML, JSON, INI)", "Chat transcripts and messaging history", "Customer service logs and interactions", "Conversational AI training data", "Textbook content and lecture notes", "Exam questions and academic exercises", "E-learning course materials", "Poetry and prose", "Scripts for plays, movies, and television", "Song lyrics", "Manuals and user guides", "Technical specifications and API documentation", "Helpdesk articles and FAQs", "Contracts and agreements", "Laws, regulations, and legal case documents", "Policy documents and compliance materials", "Clinical trial reports", "Patient records and case notes", "Scientific journal articles", "Financial reports and statements", "Business plans and proposals", "Market research and analysis reports", "Ad copies and marketing slogans", "Product catalogs and brochures", "Press releases and promotional content", "Professional and formal correspondence", "Personal emails and letters", "Image and video captions", "Annotations and metadata for various media", "Vocabulary lists and grammar rules", "Language exercises and quizzes", "Other types of text data"] - - **AUDIO CONTENT** (content_type: "audio"): - - type: "AUDIO_DOCUMENTS_USED_FOR_GENERAL_PURPOSES" - - subclass options: ["Music tracks and albums", "Podcasts and radio broadcasts", "Audiobooks and audio guides", "Recorded interviews and speeches", "Sound effects and ambient sounds", "Other types of audio recordings"] - - **IMAGE CONTENT** (content_type: "image"): - - type: "IMAGE_DOCUMENTS_USED_FOR_GENERAL_PURPOSES" - - subclass options: ["Photographs and digital images", "Illustrations, diagrams, and charts", "Infographics and visual data representations", "Artwork and paintings", "Screenshots and graphical user interfaces", "Other types of images"] - - **VIDEO CONTENT** (content_type: "video"): - - type: "VIDEO_DOCUMENTS_USED_FOR_GENERAL_PURPOSES" - - subclass options: ["Movies and short films", "Documentaries and educational videos", "Video tutorials and how-to guides", "Animated features and cartoons", "Live event recordings and sports broadcasts", "Other types of video content"] - - **MULTIMEDIA CONTENT** (content_type: "multimedia"): - - type: "MULTIMEDIA_DOCUMENTS_USED_FOR_GENERAL_PURPOSES" - - subclass options: ["Interactive web content and games", "Virtual reality (VR) and augmented reality (AR) experiences", "Mixed media presentations and slide decks", "E-learning modules with integrated multimedia", "Digital exhibitions and virtual tours", "Other types of multimedia content"] - - **3D MODEL CONTENT** (content_type: "3d_model"): - - type: "3D_MODEL_DOCUMENTS_USED_FOR_GENERAL_PURPOSES" - - subclass options: ["Architectural renderings and building plans", "Product design models and prototypes", "3D animations and character models", "Scientific simulations and visualizations", "Virtual objects for AR/VR applications", "Other types of 3D models"] - - **PROCEDURAL CONTENT** (content_type: "procedural"): - - type: "PROCEDURAL_DOCUMENTS_USED_FOR_GENERAL_PURPOSES" - - subclass options: ["Tutorials and step-by-step guides", "Workflow and process descriptions", "Simulation and training exercises", "Recipes and crafting instructions", "Other types of procedural content"] - - Select the most appropriate content_type, type, and relevant subclasses. -"# - -// OpenAI client defined once for all BAML files - -// Classification function -function ExtractCategories(content: string) -> DefaultContentPrediction { - client OpenAI - - prompt #" - {{ ClassifyContentPrompt() }} - - {{ ctx.output_format(prefix="Answer in this schema:\n") }} - - {{ _.role('user') }} - {{ content }} - "# -} - -// Test case for classification -test ExtractCategoriesExample { - functions [ExtractCategories] - args { - content #" - Natural language processing (NLP) is an interdisciplinary subfield of computer science and information retrieval. - It deals with the interaction between computers and human language, in particular how to program computers to process and analyze large amounts of natural language data. - "# - } -} diff --git a/cognee/infrastructure/llm/structured_output_framework/baml_src/extract_content_graph.baml b/cognee/infrastructure/llm/structured_output_framework/baml_src/extract_content_graph.baml deleted file mode 100644 index 5b500f12e..000000000 --- a/cognee/infrastructure/llm/structured_output_framework/baml_src/extract_content_graph.baml +++ /dev/null @@ -1,343 +0,0 @@ -class Node { - id string - name string - type string - description string - @@dynamic -} - -/// doc string for edge -class Edge { - /// doc string for source_node_id - source_node_id string - target_node_id string - relationship_name string -} - -class KnowledgeGraph { - nodes (Node @stream.done)[] - edges Edge[] -} - -// Summarization classes -class SummarizedContent { - summary string - description string -} - -class SummarizedFunction { - name string - description string - inputs string[]? - outputs string[]? - decorators string[]? -} - -class SummarizedClass { - name string - description string - methods SummarizedFunction[]? - decorators string[]? -} - -class SummarizedCode { - high_level_summary string - key_features string[] - imports string[] - constants string[] - classes SummarizedClass[] - functions SummarizedFunction[] - workflow_description string? -} - -class DynamicKnowledgeGraph { - @@dynamic -} - - -// Simple template for basic extraction (fast, good quality) -template_string ExtractContentGraphPrompt() #" - You are an advanced algorithm that extracts structured data into a knowledge graph. - - - **Nodes**: Entities/concepts (like Wikipedia articles). - - **Edges**: Relationships (like Wikipedia links). Use snake_case (e.g., `acted_in`). - - **Rules:** - - 1. **Node Labeling & IDs** - - Use basic types only (e.g., "Person", "Date", "Organization"). - - Avoid overly specific or generic terms (e.g., no "Mathematician" or "Entity"). - - Node IDs must be human-readable names from the text (no numbers). - - 2. **Dates & Numbers** - - Label dates as **"Date"** in "YYYY-MM-DD" format (use available parts if incomplete). - - Properties are key-value pairs; do not use escaped quotes. - - 3. **Coreference Resolution** - - Use a single, complete identifier for each entity (e.g., always "John Doe" not "Joe" or "he"). - - 4. **Relationship Labels**: - - Use descriptive, lowercase, snake_case names for edges. - - *Example*: born_in, married_to, invented_by. - - Avoid vague or generic labels like isA, relatesTo, has. - - Avoid duplicated relationships like produces, produced by. - - 5. **Strict Compliance** - - Follow these rules exactly. Non-compliance results in termination. -"# - -// Summarization prompt template -template_string SummarizeContentPrompt() #" - You are a top-tier summarization engine. Your task is to summarize text and make it versatile. - Be brief and concise, but keep the important information and the subject. - Use synonym words where possible in order to change the wording but keep the meaning. -"# - -// Code summarization prompt template -template_string SummarizeCodePrompt() #" - You are an expert code analyst. Analyze the provided source code and extract key information: - - 1. Provide a high-level summary of what the code does - 2. List key features and functionality - 3. Identify imports and dependencies - 4. List constants and global variables - 5. Summarize classes with their methods - 6. Summarize standalone functions - 7. Describe the overall workflow if applicable - - Be precise and technical while remaining clear and concise. -"# - -// Detailed template for complex extraction (slower, higher quality) -template_string DetailedExtractContentGraphPrompt() #" - You are a top-tier algorithm designed for extracting information in structured formats to build a knowledge graph. - **Nodes** represent entities and concepts. They're akin to Wikipedia nodes. - **Edges** represent relationships between concepts. They're akin to Wikipedia links. - - The aim is to achieve simplicity and clarity in the knowledge graph. - - # 1. Labeling Nodes - **Consistency**: Ensure you use basic or elementary types for node labels. - - For example, when you identify an entity representing a person, always label it as **"Person"**. - - Avoid using more specific terms like "Mathematician" or "Scientist", keep those as "profession" property. - - Don't use too generic terms like "Entity". - **Node IDs**: Never utilize integers as node IDs. - - Node IDs should be names or human-readable identifiers found in the text. - - # 2. Handling Numerical Data and Dates - - For example, when you identify an entity representing a date, make sure it has type **"Date"**. - - Extract the date in the format "YYYY-MM-DD" - - If not possible to extract the whole date, extract month or year, or both if available. - - **Property Format**: Properties must be in a key-value format. - - **Quotation Marks**: Never use escaped single or double quotes within property values. - - **Naming Convention**: Use snake_case for relationship names, e.g., `acted_in`. - - # 3. Coreference Resolution - - **Maintain Entity Consistency**: When extracting entities, it's vital to ensure consistency. - If an entity, such as "John Doe", is mentioned multiple times in the text but is referred to by different names or pronouns (e.g., "Joe", "he"), - always use the most complete identifier for that entity throughout the knowledge graph. In this example, use "John Doe" as the Person's ID. - Remember, the knowledge graph should be coherent and easily understandable, so maintaining consistency in entity references is crucial. - - # 4. Strict Compliance - Adhere to the rules strictly. Non-compliance will result in termination. -"# - -// Guided template with step-by-step instructions -template_string GuidedExtractContentGraphPrompt() #" - You are an advanced algorithm designed to extract structured information to build a clean, consistent, and human-readable knowledge graph. - - **Objective**: - - Nodes represent entities and concepts, similar to Wikipedia articles. - - Edges represent typed relationships between nodes, similar to Wikipedia hyperlinks. - - The graph must be clear, minimal, consistent, and semantically precise. - - **Node Guidelines**: - - 1. **Label Consistency**: - - Use consistent, basic types for all node labels. - - Do not switch between granular or vague labels for the same kind of entity. - - Pick one label for each category and apply it uniformly. - - Each entity type should be in a singular form and in a case of multiple words separated by whitespaces - - 2. **Node Identifiers**: - - Node IDs must be human-readable and derived directly from the text. - - Prefer full names and canonical terms. - - Never use integers or autogenerated IDs. - - *Example*: Use "Marie Curie", "Theory of Evolution", "Google". - - 3. **Coreference Resolution**: - - Maintain one consistent node ID for each real-world entity. - - Resolve aliases, acronyms, and pronouns to the most complete form. - - *Example*: Always use "John Doe" even if later referred to as "Doe" or "he". - - **Edge Guidelines**: - - 4. **Relationship Labels**: - - Use descriptive, lowercase, snake_case names for edges. - - *Example*: born_in, married_to, invented_by. - - Avoid vague or generic labels like isA, relatesTo, has. - - 5. **Relationship Direction**: - - Edges must be directional and logically consistent. - - *Example*: - - "Marie Curie" —[born_in]→ "Warsaw" - - "Radioactivity" —[discovered_by]→ "Marie Curie" - - **Compliance**: - Strict adherence to these guidelines is required. Any deviation will result in immediate termination of the task. -"# - -// Strict template with zero-tolerance rules -template_string StrictExtractContentGraphPrompt() #" - You are a top-tier algorithm for **extracting structured information** from unstructured text to build a **knowledge graph**. - - Your primary goal is to extract: - - **Nodes**: Representing **entities** and **concepts** (like Wikipedia nodes). - - **Edges**: Representing **relationships** between those concepts (like Wikipedia links). - - The resulting knowledge graph must be **simple, consistent, and human-readable**. - - ## 1. Node Labeling and Identification - - ### Node Types - Use **basic atomic types** for node labels. Always prefer general types over specific roles or professions: - - "Person" for any human. - - "Organization" for companies, institutions, etc. - - "Location" for geographic or place entities. - - "Date" for any temporal expression. - - "Event" for historical or scheduled occurrences. - - "Work" for books, films, artworks, or research papers. - - "Concept" for abstract notions or ideas. - - ### Node IDs - - Always assign **human-readable and unambiguous identifiers**. - - Never use numeric or autogenerated IDs. - - Prioritize **most complete form** of entity names for consistency. - - ## 2. Relationship Handling - - Use **snake_case** for all relationship (edge) types. - - Keep relationship types semantically clear and consistent. - - Avoid vague relation names like "related_to" unless no better alternative exists. - - ## 3. Strict Compliance - Follow all rules exactly. Any deviation may lead to rejection or incorrect graph construction. -"# - -// OpenAI client with environment model selection -client OpenAI { - provider openai - options { - model client_registry.model - api_key client_registry.api_key - } -} - - - -// Function that returns raw structured output (for custom objects - to be handled in Python) -function ExtractContentGraphGeneric( - content: string, - mode: "simple" | "base" | "guided" | "strict" | "custom"?, - custom_prompt_content: string? -) -> KnowledgeGraph { - client OpenAI - - prompt #" - {% if mode == "base" %} - {{ DetailedExtractContentGraphPrompt() }} - {% elif mode == "guided" %} - {{ GuidedExtractContentGraphPrompt() }} - {% elif mode == "strict" %} - {{ StrictExtractContentGraphPrompt() }} - {% elif mode == "custom" and custom_prompt_content %} - {{ custom_prompt_content }} - {% else %} - {{ ExtractContentGraphPrompt() }} - {% endif %} - - {{ ctx.output_format(prefix="Answer in this schema:\n") }} - - Before answering, briefly describe what you'll extract from the text, then provide the structured output. - - Example format: - I'll extract the main entities and their relationships from this text... - - { ... } - - {{ _.role('user') }} - {{ content }} - "# -} - -// Backward-compatible function specifically for KnowledgeGraph -function ExtractDynamicContentGraph( - content: string, - mode: "simple" | "base" | "guided" | "strict" | "custom"?, - custom_prompt_content: string? -) -> DynamicKnowledgeGraph { - client OpenAI - - prompt #" - {% if mode == "base" %} - {{ DetailedExtractContentGraphPrompt() }} - {% elif mode == "guided" %} - {{ GuidedExtractContentGraphPrompt() }} - {% elif mode == "strict" %} - {{ StrictExtractContentGraphPrompt() }} - {% elif mode == "custom" and custom_prompt_content %} - {{ custom_prompt_content }} - {% else %} - {{ ExtractContentGraphPrompt() }} - {% endif %} - - {{ ctx.output_format(prefix="Answer in this schema:\n") }} - - Before answering, briefly describe what you'll extract from the text, then provide the structured output. - - Example format: - I'll extract the main entities and their relationships from this text... - - { ... } - - {{ _.role('user') }} - {{ content }} - "# -} - - -// Summarization functions -function SummarizeContent(content: string) -> SummarizedContent { - client OpenAI - - prompt #" - {{ SummarizeContentPrompt() }} - - {{ ctx.output_format(prefix="Answer in this schema:\n") }} - - {{ _.role('user') }} - {{ content }} - "# -} - -function SummarizeCode(content: string) -> SummarizedCode { - client OpenAI - - prompt #" - {{ SummarizeCodePrompt() }} - - {{ ctx.output_format(prefix="Answer in this schema:\n") }} - - {{ _.role('user') }} - {{ content }} - "# -} - -test ExtractStrictExample { - functions [ExtractContentGraphGeneric] - args { - content #" - The Python programming language was created by Guido van Rossum in 1991. - "# - mode "strict" - } -} \ No newline at end of file diff --git a/cognee/infrastructure/llm/structured_output_framework/baml_src/extraction/__init__.py b/cognee/infrastructure/llm/structured_output_framework/baml_src/extraction/__init__.py deleted file mode 100644 index 157cbe7e7..000000000 --- a/cognee/infrastructure/llm/structured_output_framework/baml_src/extraction/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .knowledge_graph.extract_content_graph import extract_content_graph -from .extract_summary import extract_summary, extract_code_summary diff --git a/cognee/infrastructure/llm/structured_output_framework/baml_src/extraction/extract_categories.py b/cognee/infrastructure/llm/structured_output_framework/baml_src/extraction/extract_categories.py deleted file mode 100644 index 58cb67043..000000000 --- a/cognee/infrastructure/llm/structured_output_framework/baml_src/extraction/extract_categories.py +++ /dev/null @@ -1,114 +0,0 @@ -import os -from typing import Type -from pydantic import BaseModel -from cognee.infrastructure.llm.structured_output_framework.baml_src.config import get_llm_config -from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.async_client import b -from cognee.shared.data_models import SummarizedCode -from cognee.shared.logging_utils import get_logger -from baml_py import ClientRegistry - -config = get_llm_config() - - -logger = get_logger("extract_summary_baml") - - -def get_mock_summarized_code(): - """Local mock function to avoid circular imports.""" - return SummarizedCode( - high_level_summary="Mock code summary", - key_features=["Mock feature 1", "Mock feature 2"], - imports=["mock_import"], - constants=["MOCK_CONSTANT"], - classes=[], - functions=[], - workflow_description="Mock workflow description", - ) - - -async def extract_summary(content: str, response_model: Type[BaseModel]): - """ - Extract summary using BAML framework. - - Args: - content: The content to summarize - response_model: The Pydantic model type for the response - - Returns: - BaseModel: The summarized content in the specified format - """ - config = get_llm_config() - - baml_registry = ClientRegistry() - - baml_registry.add_llm_client( - name="extract_category_client", - provider=config.llm_model, - options={ - "model": config.llm_model, - "temperature": config.llm_temperature, - "api_key": config.llm_api_key, - }, - ) - baml_registry.set_primary("extract_category_client") - - # Use BAML's SummarizeContent function - summary_result = await b.SummarizeContent( - content, baml_options={"client_registry": baml_registry} - ) - - # Convert BAML result to the expected response model - if response_model is SummarizedCode: - # If it's asking for SummarizedCode but we got SummarizedContent, - # we need to use SummarizeCode instead - code_result = await b.SummarizeCode( - content, baml_options={"client_registry": baml_registry} - ) - return code_result - else: - # For other models, return the summary result - return summary_result - - -async def extract_code_summary(content: str): - """ - Extract code summary using BAML framework with mocking support. - - Args: - content: The code content to summarize - - Returns: - SummarizedCode: The summarized code information - """ - enable_mocking = os.getenv("MOCK_CODE_SUMMARY", "false") - if isinstance(enable_mocking, bool): - enable_mocking = str(enable_mocking).lower() - enable_mocking = enable_mocking in ("true", "1", "yes") - - if enable_mocking: - result = get_mock_summarized_code() - return result - else: - try: - config = get_llm_config() - - baml_registry = ClientRegistry() - - baml_registry.add_llm_client( - name="extract_content_category", - provider=config.llm_provider, - options={ - "model": config.llm_model, - "temperature": config.llm_temperature, - "api_key": config.llm_api_key, - }, - ) - baml_registry.set_primary("extract_content_category") - result = await b.SummarizeCode(content, baml_options={"client_registry": baml_registry}) - except Exception as e: - logger.error( - "Failed to extract code summary with BAML, falling back to mock summary", exc_info=e - ) - result = get_mock_summarized_code() - - return result diff --git a/cognee/infrastructure/llm/structured_output_framework/baml_src/extraction/extract_summary.py b/cognee/infrastructure/llm/structured_output_framework/baml_src/extraction/extract_summary.py deleted file mode 100644 index 3ccc8f817..000000000 --- a/cognee/infrastructure/llm/structured_output_framework/baml_src/extraction/extract_summary.py +++ /dev/null @@ -1,114 +0,0 @@ -import os -from typing import Type -from pydantic import BaseModel -from baml_py import ClientRegistry -from cognee.shared.logging_utils import get_logger -from cognee.shared.data_models import SummarizedCode -from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.async_client import b -from cognee.infrastructure.llm.structured_output_framework.baml_src.config import get_llm_config - -config = get_llm_config() - - -logger = get_logger("extract_summary_baml") - - -def get_mock_summarized_code(): - """Local mock function to avoid circular imports.""" - return SummarizedCode( - high_level_summary="Mock code summary", - key_features=["Mock feature 1", "Mock feature 2"], - imports=["mock_import"], - constants=["MOCK_CONSTANT"], - classes=[], - functions=[], - workflow_description="Mock workflow description", - ) - - -async def extract_summary(content: str, response_model: Type[BaseModel]): - """ - Extract summary using BAML framework. - - Args: - content: The content to summarize - response_model: The Pydantic model type for the response - - Returns: - BaseModel: The summarized content in the specified format - """ - config = get_llm_config() - - baml_registry = ClientRegistry() - - baml_registry.add_llm_client( - name="def", - provider="openai", - options={ - "model": config.llm_model, - "temperature": config.llm_temperature, - "api_key": config.llm_api_key, - }, - ) - baml_registry.set_primary("def") - - # Use BAML's SummarizeContent function - summary_result = await b.SummarizeContent( - content, baml_options={"client_registry": baml_registry} - ) - - # Convert BAML result to the expected response model - if response_model is SummarizedCode: - # If it's asking for SummarizedCode but we got SummarizedContent, - # we need to use SummarizeCode instead - code_result = await b.SummarizeCode( - content, baml_options={"client_registry": config.baml_registry} - ) - return code_result - else: - # For other models, return the summary result - return summary_result - - -async def extract_code_summary(content: str): - """ - Extract code summary using BAML framework with mocking support. - - Args: - content: The code content to summarize - - Returns: - SummarizedCode: The summarized code information - """ - enable_mocking = os.getenv("MOCK_CODE_SUMMARY", "false") - if isinstance(enable_mocking, bool): - enable_mocking = str(enable_mocking).lower() - enable_mocking = enable_mocking in ("true", "1", "yes") - - if enable_mocking: - result = get_mock_summarized_code() - return result - else: - try: - config = get_llm_config() - - baml_registry = ClientRegistry() - - baml_registry.add_llm_client( - name="def", - provider="openai", - options={ - "model": config.llm_model, - "temperature": config.llm_temperature, - "api_key": config.llm_api_key, - }, - ) - baml_registry.set_primary("def") - result = await b.SummarizeCode(content, baml_options={"client_registry": baml_registry}) - except Exception as e: - logger.error( - "Failed to extract code summary with BAML, falling back to mock summary", exc_info=e - ) - result = get_mock_summarized_code() - - return result diff --git a/cognee/infrastructure/llm/structured_output_framework/baml_src/extraction/knowledge_graph/__init__.py b/cognee/infrastructure/llm/structured_output_framework/baml_src/extraction/knowledge_graph/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/cognee/infrastructure/llm/structured_output_framework/baml_src/extraction/knowledge_graph/extract_content_graph.py b/cognee/infrastructure/llm/structured_output_framework/baml_src/extraction/knowledge_graph/extract_content_graph.py deleted file mode 100644 index 465b4843c..000000000 --- a/cognee/infrastructure/llm/structured_output_framework/baml_src/extraction/knowledge_graph/extract_content_graph.py +++ /dev/null @@ -1,49 +0,0 @@ -from baml_py import ClientRegistry -from typing import Type -from pydantic import BaseModel -from cognee.infrastructure.llm.structured_output_framework.baml_src.config import get_llm_config -from cognee.shared.logging_utils import get_logger, setup_logging -from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.async_client import b - -config = get_llm_config() - - -async def extract_content_graph( - content: str, response_model: Type[BaseModel], mode: str = "simple" -): - config = get_llm_config() - setup_logging() - - get_logger(level="INFO") - - baml_registry = ClientRegistry() - - baml_registry.add_llm_client( - name="extract_content_client", - provider=config.llm_provider, - options={ - "model": config.llm_model, - "temperature": config.llm_temperature, - "api_key": config.llm_api_key, - }, - ) - baml_registry.set_primary("extract_content_client") - - # if response_model: - # # tb = TypeBuilder() - # # country = tb.union \ - # # ([tb.literal_string("USA"), tb.literal_string("UK"), tb.literal_string("Germany"), tb.literal_string("other")]) - # # tb.Node.add_property("country", country) - # - # graph = await b.ExtractDynamicContentGraph( - # content, mode=mode, baml_options={"client_registry": baml_registry} - # ) - # - # return graph - - # else: - graph = await b.ExtractContentGraphGeneric( - content, mode=mode, baml_options={"client_registry": baml_registry} - ) - - return graph diff --git a/cognee/infrastructure/llm/structured_output_framework/baml_src/generators.baml b/cognee/infrastructure/llm/structured_output_framework/baml_src/generators.baml deleted file mode 100644 index ac4f08efb..000000000 --- a/cognee/infrastructure/llm/structured_output_framework/baml_src/generators.baml +++ /dev/null @@ -1,18 +0,0 @@ -// This helps use auto generate libraries you can use in the language of -// your choice. You can have multiple generators if you use multiple languages. -// Just ensure that the output_dir is different for each generator. -generator target { - // Valid values: "python/pydantic", "typescript", "ruby/sorbet", "rest/openapi" - output_type "python/pydantic" - - // Where the generated code will be saved (relative to baml_src/) - output_dir "../baml/" - - // The version of the BAML package you have installed (e.g. same version as your baml-py or @boundaryml/baml). - // The BAML VSCode extension version should also match this version. - version "0.201.0" - - // Valid values: "sync", "async" - // This controls what `b.FunctionName()` will be (sync or async). - default_client_mode sync -} diff --git a/cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/extraction/__init__.py b/cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/extraction/__init__.py index 157cbe7e7..3d4edab27 100644 --- a/cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/extraction/__init__.py +++ b/cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/extraction/__init__.py @@ -1,2 +1,3 @@ from .knowledge_graph.extract_content_graph import extract_content_graph +from .extract_categories import extract_categories from .extract_summary import extract_summary, extract_code_summary diff --git a/cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/extraction/knowledge_graph/extract_content_graph.py b/cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/extraction/knowledge_graph/extract_content_graph.py index aa243246d..c532fcbd0 100644 --- a/cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/extraction/knowledge_graph/extract_content_graph.py +++ b/cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/extraction/knowledge_graph/extract_content_graph.py @@ -7,7 +7,7 @@ from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.l from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import ( render_prompt, ) -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.config import ( +from cognee.infrastructure.llm.config import ( get_llm_config, ) diff --git a/cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/llm/config.py b/cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/llm/config.py deleted file mode 100644 index b26d3d463..000000000 --- a/cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/llm/config.py +++ /dev/null @@ -1,176 +0,0 @@ -import os -from typing import Optional -from functools import lru_cache -from pydantic_settings import BaseSettings, SettingsConfigDict -from pydantic import model_validator - - -class LLMConfig(BaseSettings): - """ - Configuration settings for the LLM (Large Language Model) provider and related options. - - Public instance variables include: - - llm_provider - - llm_model - - llm_endpoint - - llm_api_key - - llm_api_version - - llm_temperature - - llm_streaming - - llm_max_tokens - - transcription_model - - graph_prompt_path - - llm_rate_limit_enabled - - llm_rate_limit_requests - - llm_rate_limit_interval - - embedding_rate_limit_enabled - - embedding_rate_limit_requests - - embedding_rate_limit_interval - - Public methods include: - - ensure_env_vars_for_ollama - - to_dict - """ - - llm_provider: str = "openai" - llm_model: str = "gpt-4o-mini" - llm_endpoint: str = "" - llm_api_key: Optional[str] = None - llm_api_version: Optional[str] = None - llm_temperature: float = 0.0 - llm_streaming: bool = False - llm_max_tokens: int = 16384 - transcription_model: str = "whisper-1" - graph_prompt_path: str = "generate_graph_prompt.txt" - llm_rate_limit_enabled: bool = False - llm_rate_limit_requests: int = 60 - llm_rate_limit_interval: int = 60 # in seconds (default is 60 requests per minute) - embedding_rate_limit_enabled: bool = False - embedding_rate_limit_requests: int = 60 - embedding_rate_limit_interval: int = 60 # in seconds (default is 60 requests per minute) - - fallback_api_key: str = "" - fallback_endpoint: str = "" - fallback_model: str = "" - - model_config = SettingsConfigDict(env_file=".env", extra="allow") - - @model_validator(mode="after") - def ensure_env_vars_for_ollama(self) -> "LLMConfig": - """ - Validate required environment variables for the 'ollama' LLM provider. - - Raises ValueError if some required environment variables are set without the others. - Only checks are performed when 'llm_provider' is set to 'ollama'. - - Returns: - -------- - - - 'LLMConfig': The instance of LLMConfig after validation. - """ - - if self.llm_provider != "ollama": - # Skip checks unless provider is "ollama" - return self - - def is_env_set(var_name: str) -> bool: - """ - Check if a given environment variable is set and non-empty. - - Parameters: - ----------- - - - var_name (str): The name of the environment variable to check. - - Returns: - -------- - - - bool: True if the environment variable exists and is not empty, otherwise False. - """ - val = os.environ.get(var_name) - return val is not None and val.strip() != "" - - # - # 1. Check LLM environment variables - # - llm_env_vars = { - "LLM_MODEL": is_env_set("LLM_MODEL"), - "LLM_ENDPOINT": is_env_set("LLM_ENDPOINT"), - "LLM_API_KEY": is_env_set("LLM_API_KEY"), - } - if any(llm_env_vars.values()) and not all(llm_env_vars.values()): - missing_llm = [key for key, is_set in llm_env_vars.items() if not is_set] - raise ValueError( - "You have set some but not all of the required environment variables " - f"for LLM usage (LLM_MODEL, LLM_ENDPOINT, LLM_API_KEY). Missing: {missing_llm}" - ) - - # - # 2. Check embedding environment variables - # - embedding_env_vars = { - "EMBEDDING_PROVIDER": is_env_set("EMBEDDING_PROVIDER"), - "EMBEDDING_MODEL": is_env_set("EMBEDDING_MODEL"), - "EMBEDDING_DIMENSIONS": is_env_set("EMBEDDING_DIMENSIONS"), - "HUGGINGFACE_TOKENIZER": is_env_set("HUGGINGFACE_TOKENIZER"), - } - if any(embedding_env_vars.values()) and not all(embedding_env_vars.values()): - missing_embed = [key for key, is_set in embedding_env_vars.items() if not is_set] - raise ValueError( - "You have set some but not all of the required environment variables " - "for embeddings (EMBEDDING_PROVIDER, EMBEDDING_MODEL, " - "EMBEDDING_DIMENSIONS, HUGGINGFACE_TOKENIZER). Missing: " - f"{missing_embed}" - ) - - return self - - def to_dict(self) -> dict: - """ - Convert the LLMConfig instance into a dictionary representation. - - Returns: - -------- - - - dict: A dictionary containing the configuration settings of the LLMConfig - instance. - """ - return { - "provider": self.llm_provider, - "model": self.llm_model, - "endpoint": self.llm_endpoint, - "api_key": self.llm_api_key, - "api_version": self.llm_api_version, - "temperature": self.llm_temperature, - "streaming": self.llm_streaming, - "max_tokens": self.llm_max_tokens, - "transcription_model": self.transcription_model, - "graph_prompt_path": self.graph_prompt_path, - "rate_limit_enabled": self.llm_rate_limit_enabled, - "rate_limit_requests": self.llm_rate_limit_requests, - "rate_limit_interval": self.llm_rate_limit_interval, - "embedding_rate_limit_enabled": self.embedding_rate_limit_enabled, - "embedding_rate_limit_requests": self.embedding_rate_limit_requests, - "embedding_rate_limit_interval": self.embedding_rate_limit_interval, - "fallback_api_key": self.fallback_api_key, - "fallback_endpoint": self.fallback_endpoint, - "fallback_model": self.fallback_model, - } - - -@lru_cache -def get_llm_config(): - """ - Retrieve and cache the LLM configuration. - - This function returns an instance of the LLMConfig class. It leverages - caching to ensure that repeated calls do not create new instances, - but instead return the already created configuration object. - - Returns: - -------- - - - LLMConfig: An instance of the LLMConfig class containing the configuration for the - LLM. - """ - return LLMConfig() diff --git a/cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/llm/embedding_rate_limiter.py b/cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/llm/embedding_rate_limiter.py deleted file mode 100644 index 145bc3457..000000000 --- a/cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/llm/embedding_rate_limiter.py +++ /dev/null @@ -1,550 +0,0 @@ -import threading -import logging -import functools -import os -import time -import asyncio -import random -from cognee.shared.logging_utils import get_logger -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.config import ( - get_llm_config, -) - - -logger = get_logger() - -# Common error patterns that indicate rate limiting -RATE_LIMIT_ERROR_PATTERNS = [ - "rate limit", - "rate_limit", - "ratelimit", - "too many requests", - "retry after", - "capacity", - "quota", - "limit exceeded", - "tps limit exceeded", - "request limit exceeded", - "maximum requests", - "exceeded your current quota", - "throttled", - "throttling", -] - -# Default retry settings -DEFAULT_MAX_RETRIES = 5 -DEFAULT_INITIAL_BACKOFF = 1.0 # seconds -DEFAULT_BACKOFF_FACTOR = 2.0 # exponential backoff multiplier -DEFAULT_JITTER = 0.1 # 10% jitter to avoid thundering herd - - -class EmbeddingRateLimiter: - """ - Rate limiter for embedding API calls. - - This class implements a singleton pattern to ensure that rate limiting - is consistent across all embedding requests. It uses the limits - library with a moving window strategy to control request rates. - - The rate limiter uses the same configuration as the LLM API rate limiter - but uses a separate key to track embedding API calls independently. - - Public Methods: - - get_instance - - reset_instance - - hit_limit - - wait_if_needed - - async_wait_if_needed - - Instance Variables: - - enabled - - requests_limit - - interval_seconds - - request_times - - lock - """ - - _instance = None - lock = threading.Lock() - - @classmethod - def get_instance(cls): - """ - Retrieve the singleton instance of the EmbeddingRateLimiter. - - This method ensures that only one instance of the class exists and - is thread-safe. It lazily initializes the instance if it doesn't - already exist. - - Returns: - -------- - - The singleton instance of the EmbeddingRateLimiter class. - """ - if cls._instance is None: - with cls.lock: - if cls._instance is None: - cls._instance = cls() - return cls._instance - - @classmethod - def reset_instance(cls): - """ - Reset the singleton instance of the EmbeddingRateLimiter. - - This method is thread-safe and sets the instance to None, allowing - for a new instance to be created when requested again. - """ - with cls.lock: - cls._instance = None - - def __init__(self): - config = get_llm_config() - self.enabled = config.embedding_rate_limit_enabled - self.requests_limit = config.embedding_rate_limit_requests - self.interval_seconds = config.embedding_rate_limit_interval - self.request_times = [] - self.lock = threading.Lock() - - logging.info( - f"EmbeddingRateLimiter initialized: enabled={self.enabled}, " - f"requests_limit={self.requests_limit}, interval_seconds={self.interval_seconds}" - ) - - def hit_limit(self) -> bool: - """ - Check if the current request would exceed the rate limit. - - This method checks if the rate limiter is enabled and evaluates - the number of requests made in the elapsed interval. - - Returns: - - bool: True if the rate limit would be exceeded, False otherwise. - - Returns: - -------- - - - bool: True if the rate limit would be exceeded, otherwise False. - """ - if not self.enabled: - return False - - current_time = time.time() - - with self.lock: - # Remove expired request times - cutoff_time = current_time - self.interval_seconds - self.request_times = [t for t in self.request_times if t > cutoff_time] - - # Check if adding a new request would exceed the limit - if len(self.request_times) >= self.requests_limit: - logger.info( - f"Rate limit hit: {len(self.request_times)} requests in the last {self.interval_seconds} seconds" - ) - return True - - # Otherwise, we're under the limit - return False - - def wait_if_needed(self) -> float: - """ - Block until a request can be made without exceeding the rate limit. - - This method will wait if the current request would exceed the - rate limit and returns the time waited in seconds. - - Returns: - - float: Time waited in seconds before a request is allowed. - - Returns: - -------- - - - float: Time waited in seconds before proceeding. - """ - if not self.enabled: - return 0 - - wait_time = 0 - start_time = time.time() - - while self.hit_limit(): - time.sleep(0.5) # Poll every 0.5 seconds - wait_time = time.time() - start_time - - # Record this request - with self.lock: - self.request_times.append(time.time()) - - return wait_time - - async def async_wait_if_needed(self) -> float: - """ - Asynchronously wait until a request can be made without exceeding the rate limit. - - This method will wait if the current request would exceed the - rate limit and returns the time waited in seconds. - - Returns: - - float: Time waited in seconds before a request is allowed. - - Returns: - -------- - - - float: Time waited in seconds before proceeding. - """ - if not self.enabled: - return 0 - - wait_time = 0 - start_time = time.time() - - while self.hit_limit(): - await asyncio.sleep(0.5) # Poll every 0.5 seconds - wait_time = time.time() - start_time - - # Record this request - with self.lock: - self.request_times.append(time.time()) - - return wait_time - - -def embedding_rate_limit_sync(func): - """ - Apply rate limiting to a synchronous embedding function. - - Parameters: - ----------- - - - func: Function to decorate with rate limiting logic. - - Returns: - -------- - - Returns the decorated function that applies rate limiting. - """ - - @functools.wraps(func) - def wrapper(*args, **kwargs): - """ - Wrap the given function with rate limiting logic to control the embedding API usage. - - Checks if the rate limit has been exceeded before allowing the function to execute. If - the limit is hit, it logs a warning and raises an EmbeddingException. Otherwise, it - updates the request count and proceeds to call the original function. - - Parameters: - ----------- - - - *args: Variable length argument list for the wrapped function. - - **kwargs: Keyword arguments for the wrapped function. - - Returns: - -------- - - Returns the result of the wrapped function if rate limiting conditions are met. - """ - limiter = EmbeddingRateLimiter.get_instance() - - # Check if rate limiting is enabled and if we're at the limit - if limiter.hit_limit(): - error_msg = "Embedding API rate limit exceeded" - logger.warning(error_msg) - - # Create a custom embedding rate limit exception - from cognee.infrastructure.databases.exceptions.EmbeddingException import ( - EmbeddingException, - ) - - raise EmbeddingException(error_msg) - - # Add this request to the counter and proceed - limiter.wait_if_needed() - return func(*args, **kwargs) - - return wrapper - - -def embedding_rate_limit_async(func): - """ - Decorator that applies rate limiting to an asynchronous embedding function. - - Parameters: - ----------- - - - func: Async function to decorate. - - Returns: - -------- - - Returns the decorated async function that applies rate limiting. - """ - - @functools.wraps(func) - async def wrapper(*args, **kwargs): - """ - Handle function calls with embedding rate limiting. - - This asynchronous wrapper checks if the embedding API rate limit is exceeded before - allowing the function to execute. If the limit is exceeded, it logs a warning and raises - an EmbeddingException. If not, it waits as necessary and proceeds with the function - call. - - Parameters: - ----------- - - - *args: Positional arguments passed to the wrapped function. - - **kwargs: Keyword arguments passed to the wrapped function. - - Returns: - -------- - - Returns the result of the wrapped function after handling rate limiting. - """ - limiter = EmbeddingRateLimiter.get_instance() - - # Check if rate limiting is enabled and if we're at the limit - if limiter.hit_limit(): - error_msg = "Embedding API rate limit exceeded" - logger.warning(error_msg) - - # Create a custom embedding rate limit exception - from cognee.infrastructure.databases.exceptions.EmbeddingException import ( - EmbeddingException, - ) - - raise EmbeddingException(error_msg) - - # Add this request to the counter and proceed - await limiter.async_wait_if_needed() - return await func(*args, **kwargs) - - return wrapper - - -def embedding_sleep_and_retry_sync(max_retries=5, base_backoff=1.0, jitter=0.5): - """ - Add retry with exponential backoff for synchronous embedding functions. - - Parameters: - ----------- - - - max_retries: Maximum number of retries before giving up. (default 5) - - base_backoff: Base backoff time in seconds for retry intervals. (default 1.0) - - jitter: Jitter factor to randomize the backoff time to avoid collision. (default - 0.5) - - Returns: - -------- - - A decorator that retries the wrapped function on rate limit errors, applying - exponential backoff with jitter. - """ - - def decorator(func): - """ - Wraps a function to apply retry logic on rate limit errors. - - Parameters: - ----------- - - - func: The function to be wrapped with retry logic. - - Returns: - -------- - - Returns the wrapped function with retry logic applied. - """ - - @functools.wraps(func) - def wrapper(*args, **kwargs): - """ - Retry the execution of a function with backoff on failure due to rate limit errors. - - This wrapper function will call the specified function and if it raises an exception, it - will handle retries according to defined conditions. It will check the environment for a - DISABLE_RETRIES flag to determine whether to retry or propagate errors immediately - during tests. If the error is identified as a rate limit error, it will apply an - exponential backoff strategy with jitter before retrying, up to a maximum number of - retries. If the retries are exhausted, it raises the last encountered error. - - Parameters: - ----------- - - - *args: Positional arguments passed to the wrapped function. - - **kwargs: Keyword arguments passed to the wrapped function. - - Returns: - -------- - - Returns the result of the wrapped function if successful; otherwise, raises the last - error encountered after maximum retries are exhausted. - """ - # If DISABLE_RETRIES is set, don't retry for testing purposes - disable_retries = os.environ.get("DISABLE_RETRIES", "false").lower() in ( - "true", - "1", - "yes", - ) - - retries = 0 - last_error = None - - while retries <= max_retries: - try: - return func(*args, **kwargs) - except Exception as e: - # Check if this is a rate limit error - error_str = str(e).lower() - error_type = type(e).__name__ - is_rate_limit = any( - pattern in error_str.lower() for pattern in RATE_LIMIT_ERROR_PATTERNS - ) - - if disable_retries: - # For testing, propagate the exception immediately - raise - - if is_rate_limit and retries < max_retries: - # Calculate backoff with jitter - backoff = ( - base_backoff * (2**retries) * (1 + random.uniform(-jitter, jitter)) - ) - - logger.warning( - f"Embedding rate limit hit, retrying in {backoff:.2f}s " - f"(attempt {retries + 1}/{max_retries}): " - f"({error_str!r}, {error_type!r})" - ) - - time.sleep(backoff) - retries += 1 - last_error = e - else: - # Not a rate limit error or max retries reached, raise - raise - - # If we exit the loop due to max retries, raise the last error - if last_error: - raise last_error - - return wrapper - - return decorator - - -def embedding_sleep_and_retry_async(max_retries=5, base_backoff=1.0, jitter=0.5): - """ - Add retry logic with exponential backoff for asynchronous embedding functions. - - This decorator retries the wrapped asynchronous function upon encountering rate limit - errors, utilizing exponential backoff with optional jitter to space out retry attempts. - It allows for a maximum number of retries before giving up and raising the last error - encountered. - - Parameters: - ----------- - - - max_retries: Maximum number of retries allowed before giving up. (default 5) - - base_backoff: Base amount of time in seconds to wait before retrying after a rate - limit error. (default 1.0) - - jitter: Amount of randomness to add to the backoff duration to help mitigate burst - issues on retries. (default 0.5) - - Returns: - -------- - - Returns a decorated asynchronous function that implements the retry logic on rate - limit errors. - """ - - def decorator(func): - """ - Handle retries for an async function with exponential backoff and jitter. - - Parameters: - ----------- - - - func: An asynchronous function to be wrapped with retry logic. - - Returns: - -------- - - Returns the wrapper function that manages the retry behavior for the wrapped async - function. - """ - - @functools.wraps(func) - async def wrapper(*args, **kwargs): - """ - Handle retries for an async function with exponential backoff and jitter. - - If the environment variable DISABLE_RETRIES is set to true, 1, or yes, the function will - not retry on errors. - It attempts to call the wrapped function until it succeeds or the maximum number of - retries is reached. If an exception occurs, it checks if it's a rate limit error to - determine if a retry is needed. - - Parameters: - ----------- - - - *args: Positional arguments passed to the wrapped function. - - **kwargs: Keyword arguments passed to the wrapped function. - - Returns: - -------- - - Returns the result of the wrapped async function if successful; raises the last - encountered error if all retries fail. - """ - # If DISABLE_RETRIES is set, don't retry for testing purposes - disable_retries = os.environ.get("DISABLE_RETRIES", "false").lower() in ( - "true", - "1", - "yes", - ) - - retries = 0 - last_error = None - - while retries <= max_retries: - try: - return await func(*args, **kwargs) - except Exception as e: - # Check if this is a rate limit error - error_str = str(e).lower() - error_type = type(e).__name__ - is_rate_limit = any( - pattern in error_str.lower() for pattern in RATE_LIMIT_ERROR_PATTERNS - ) - - if disable_retries: - # For testing, propagate the exception immediately - raise - - if is_rate_limit and retries < max_retries: - # Calculate backoff with jitter - backoff = ( - base_backoff * (2**retries) * (1 + random.uniform(-jitter, jitter)) - ) - - logger.warning( - f"Embedding rate limit hit, retrying in {backoff:.2f}s " - f"(attempt {retries + 1}/{max_retries}): " - f"({error_str!r}, {error_type!r})" - ) - - await asyncio.sleep(backoff) - retries += 1 - last_error = e - else: - # Not a rate limit error or max retries reached, raise - raise - - # If we exit the loop due to max retries, raise the last error - if last_error: - raise last_error - - return wrapper - - return decorator diff --git a/cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/llm/generic_llm_api/adapter.py index 421334cd0..b97d2342c 100644 --- a/cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/llm/generic_llm_api/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/llm/generic_llm_api/adapter.py @@ -1,9 +1,15 @@ """Adapter for Generic API LLM provider API""" -from typing import Type - -from pydantic import BaseModel +import logging +import litellm import instructor +from typing import Type +from pydantic import BaseModel +from openai import ContentFilterFinishReasonError +from litellm.exceptions import ContentPolicyViolationError +from instructor.exceptions import InstructorRetryException + +from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.llm_interface import ( LLMInterface, ) @@ -11,7 +17,6 @@ from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.l rate_limit_async, sleep_and_retry_async, ) -import litellm class GenericAPIAdapter(LLMInterface): @@ -31,13 +36,27 @@ class GenericAPIAdapter(LLMInterface): model: str api_key: str - def __init__(self, endpoint, api_key: str, model: str, name: str, max_tokens: int): + def __init__( + self, + endpoint, + api_key: str, + model: str, + name: str, + max_tokens: int, + fallback_model: str = None, + fallback_api_key: str = None, + fallback_endpoint: str = None, + ): self.name = name self.model = model self.api_key = api_key self.endpoint = endpoint self.max_tokens = max_tokens + self.fallback_model = fallback_model + self.fallback_api_key = fallback_api_key + self.fallback_endpoint = fallback_endpoint + self.aclient = instructor.from_litellm( litellm.acompletion, mode=instructor.Mode.JSON, api_key=api_key ) @@ -70,19 +89,68 @@ class GenericAPIAdapter(LLMInterface): output from the language model. """ - return await self.aclient.chat.completions.create( - model=self.model, - messages=[ - { - "role": "user", - "content": f"""{text_input}""", - }, - { - "role": "system", - "content": system_prompt, - }, - ], - max_retries=5, - api_base=self.endpoint, - response_model=response_model, - ) + try: + return await self.aclient.chat.completions.create( + model=self.model, + messages=[ + { + "role": "user", + "content": f"""{text_input}""", + }, + { + "role": "system", + "content": system_prompt, + }, + ], + max_retries=5, + api_base=self.endpoint, + response_model=response_model, + ) + except ( + ContentFilterFinishReasonError, + ContentPolicyViolationError, + InstructorRetryException, + ) as error: + if ( + isinstance(error, InstructorRetryException) + and "content management policy" not in str(error).lower() + ): + raise error + + if not (self.fallback_model and self.fallback_api_key and self.fallback_endpoint): + raise ContentPolicyFilterError( + f"The provided input contains content that is not aligned with our content policy: {text_input}" + ) + + try: + return await self.aclient.chat.completions.create( + model=self.fallback_model, + messages=[ + { + "role": "user", + "content": f"""{text_input}""", + }, + { + "role": "system", + "content": system_prompt, + }, + ], + max_retries=5, + api_key=self.fallback_api_key, + api_base=self.fallback_endpoint, + response_model=response_model, + ) + except ( + ContentFilterFinishReasonError, + ContentPolicyViolationError, + InstructorRetryException, + ) as error: + if ( + isinstance(error, InstructorRetryException) + and "content management policy" not in str(error).lower() + ): + raise error + else: + raise ContentPolicyFilterError( + f"The provided input contains content that is not aligned with our content policy: {text_input}" + ) diff --git a/cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/llm/get_llm_client.py b/cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/llm/get_llm_client.py index 9361f47fe..7fdf67765 100644 --- a/cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/llm/get_llm_client.py +++ b/cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/llm/get_llm_client.py @@ -50,7 +50,7 @@ def get_llm_client(): # Check if max_token value is defined in liteLLM for given model # if not use value from cognee configuration - from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.utils import ( + from cognee.infrastructure.llm.utils import ( get_model_max_tokens, ) # imported here to avoid circular imports diff --git a/cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/llm/rate_limiter.py b/cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/llm/rate_limiter.py index bb39713f0..4230bb3d0 100644 --- a/cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/llm/rate_limiter.py +++ b/cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/llm/rate_limiter.py @@ -49,9 +49,7 @@ from functools import wraps from limits import RateLimitItemPerMinute, storage from limits.strategies import MovingWindowRateLimiter from cognee.shared.logging_utils import get_logger -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.config import ( - get_llm_config, -) +from cognee.infrastructure.llm.config import get_llm_config logger = get_logger() diff --git a/cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/llm/utils.py b/cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/llm/utils.py deleted file mode 100644 index e66743762..000000000 --- a/cognee/infrastructure/llm/structured_output_framework/llitellm_instructor/llm/utils.py +++ /dev/null @@ -1,107 +0,0 @@ -import litellm - -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import ( - get_llm_client, -) -from cognee.shared.logging_utils import get_logger - -logger = get_logger() - - -def get_max_chunk_tokens(): - """ - Calculate the maximum number of tokens allowed in a chunk. - - The function determines the maximum chunk size based on the maximum token limit of the - embedding engine and half of the LLM maximum context token size. It ensures that the - chunk size does not exceed these constraints. - - Returns: - -------- - - - int: The maximum number of tokens that can be included in a chunk, determined by - the smaller value of the embedding engine's max tokens and half of the LLM's - maximum tokens. - """ - # NOTE: Import must be done in function to avoid circular import issue - from cognee.infrastructure.databases.vector import get_vector_engine - - # Calculate max chunk size based on the following formula - embedding_engine = get_vector_engine().embedding_engine - llm_client = get_llm_client() - - # We need to make sure chunk size won't take more than half of LLM max context token size - # but it also can't be bigger than the embedding engine max token size - llm_cutoff_point = llm_client.max_tokens // 2 # Round down the division - max_chunk_tokens = min(embedding_engine.max_tokens, llm_cutoff_point) - - return max_chunk_tokens - - -def get_model_max_tokens(model_name: str): - """ - Retrieve the maximum token limit for a specified model name if it exists. - - Checks if the provided model name is present in the predefined model cost dictionary. If - found, it logs the maximum token count for that model and returns it. If the model name - is not recognized, it logs an informational message and returns None. - - Parameters: - ----------- - - - model_name (str): Name of LLM or embedding model - - Returns: - -------- - - Number of max tokens of model, or None if model is unknown - """ - max_tokens = None - - if model_name in litellm.model_cost: - max_tokens = litellm.model_cost[model_name]["max_tokens"] - logger.debug(f"Max input tokens for {model_name}: {max_tokens}") - else: - logger.info("Model not found in LiteLLM's model_cost.") - - return max_tokens - - -async def test_llm_connection(): - """ - Establish a connection to the LLM and create a structured output. - - Attempt to connect to the LLM client and uses the adapter to create a structured output - with a predefined text input and system prompt. Log any exceptions encountered during - the connection attempt and re-raise the exception for further handling. - """ - try: - llm_adapter = get_llm_client() - await llm_adapter.acreate_structured_output( - text_input="test", - system_prompt='Respond to me with the following string: "test"', - response_model=str, - ) - - except Exception as e: - logger.error(e) - logger.error("Connection to LLM could not be established.") - raise e - - -async def test_embedding_connection(): - """ - Test the connection to the embedding engine by embedding a sample text. - - Handles exceptions that may occur during the operation, logs the error, and re-raises - the exception if the connection to the embedding handler cannot be established. - """ - try: - # NOTE: Vector engine import must be done in function to avoid circular import issue - from cognee.infrastructure.databases.vector import get_vector_engine - - await get_vector_engine().embedding_engine.embed_text("test") - except Exception as e: - logger.error(e) - logger.error("Connection to Embedding handler could not be established.") - raise e diff --git a/cognee/modules/data/processing/document_types/AudioDocument.py b/cognee/modules/data/processing/document_types/AudioDocument.py index b2cfcaabb..8d3971cdd 100644 --- a/cognee/modules/data/processing/document_types/AudioDocument.py +++ b/cognee/modules/data/processing/document_types/AudioDocument.py @@ -1,7 +1,5 @@ -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import ( - get_llm_client, -) from cognee.modules.chunking.Chunker import Chunker +from cognee.infrastructure.llm.LLMAdapter import LLMAdapter from .Document import Document @@ -10,7 +8,7 @@ class AudioDocument(Document): type: str = "audio" async def create_transcript(self): - result = await get_llm_client().create_transcript(self.raw_data_location) + result = await LLMAdapter.create_transcript(self.raw_data_location) return result.text async def read(self, chunker_cls: Chunker, max_chunk_size: int): diff --git a/cognee/modules/data/processing/document_types/ImageDocument.py b/cognee/modules/data/processing/document_types/ImageDocument.py index 5eec618b9..b99ecb8b7 100644 --- a/cognee/modules/data/processing/document_types/ImageDocument.py +++ b/cognee/modules/data/processing/document_types/ImageDocument.py @@ -1,6 +1,4 @@ -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import ( - get_llm_client, -) +from cognee.infrastructure.llm.LLMAdapter import LLMAdapter from cognee.modules.chunking.Chunker import Chunker from .Document import Document @@ -10,7 +8,7 @@ class ImageDocument(Document): type: str = "image" async def transcribe_image(self): - result = await get_llm_client().transcribe_image(self.raw_data_location) + result = await LLMAdapter.transcribe_image(self.raw_data_location) return result.choices[0].message.content async def read(self, chunker_cls: Chunker, max_chunk_size: int): diff --git a/cognee/modules/pipelines/operations/pipeline.py b/cognee/modules/pipelines/operations/pipeline.py index 9672dab11..938744733 100644 --- a/cognee/modules/pipelines/operations/pipeline.py +++ b/cognee/modules/pipelines/operations/pipeline.py @@ -69,7 +69,7 @@ async def cognee_pipeline( cognee_pipeline.first_run = True if cognee_pipeline.first_run: - from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.utils import ( + from cognee.infrastructure.llm.utils import ( test_llm_connection, test_embedding_connection, ) diff --git a/cognee/modules/retrieval/code_retriever.py b/cognee/modules/retrieval/code_retriever.py index 309e6ac1f..a177bc5dd 100644 --- a/cognee/modules/retrieval/code_retriever.py +++ b/cognee/modules/retrieval/code_retriever.py @@ -7,12 +7,7 @@ from cognee.shared.logging_utils import get_logger from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import ( - get_llm_client, -) -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import ( - read_query_prompt, -) +from cognee.infrastructure.llm.LLMAdapter import LLMAdapter logger = get_logger("CodeRetriever") @@ -46,11 +41,10 @@ class CodeRetriever(BaseRetriever): f"Processing query with LLM: '{query[:100]}{'...' if len(query) > 100 else ''}'" ) - system_prompt = read_query_prompt("codegraph_retriever_system.txt") - llm_client = get_llm_client() + system_prompt = LLMAdapter.read_query_prompt("codegraph_retriever_system.txt") try: - result = await llm_client.acreate_structured_output( + result = await LLMAdapter.acreate_structured_output( text_input=query, system_prompt=system_prompt, response_model=self.CodeQueryInfo, diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 722f0e04c..3ce480ab5 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -1,14 +1,9 @@ from typing import Any, Optional, List, Type from cognee.shared.logging_utils import get_logger -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import ( - get_llm_client, -) + from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.utils.completion import generate_completion -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import ( - read_query_prompt, - render_prompt, -) +from cognee.infrastructure.llm.LLMAdapter import LLMAdapter logger = get_logger() @@ -78,7 +73,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): - List[str]: A list containing the generated answer to the user's query. """ - llm_client = get_llm_client() followup_question = "" triplets = [] answer = [""] @@ -100,27 +94,27 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): logger.info(f"Chain-of-thought: round {round_idx} - answer: {answer}") if round_idx < max_iter: valid_args = {"query": query, "answer": answer, "context": context} - valid_user_prompt = render_prompt( + valid_user_prompt = LLMAdapter.render_prompt( filename=self.validation_user_prompt_path, context=valid_args ) - valid_system_prompt = read_query_prompt( + valid_system_prompt = LLMAdapter.read_query_prompt( prompt_file_name=self.validation_system_prompt_path ) - reasoning = await llm_client.acreate_structured_output( + reasoning = await LLMAdapter.acreate_structured_output( text_input=valid_user_prompt, system_prompt=valid_system_prompt, response_model=str, ) followup_args = {"query": query, "answer": answer, "reasoning": reasoning} - followup_prompt = render_prompt( + followup_prompt = LLMAdapter.render_prompt( filename=self.followup_user_prompt_path, context=followup_args ) - followup_system = read_query_prompt( + followup_system = LLMAdapter.read_query_prompt( prompt_file_name=self.followup_system_prompt_path ) - followup_question = await llm_client.acreate_structured_output( + followup_question = await LLMAdapter.llm_client.acreate_structured_output( text_input=followup_prompt, system_prompt=followup_system, response_model=str ) logger.info( diff --git a/cognee/modules/retrieval/natural_language_retriever.py b/cognee/modules/retrieval/natural_language_retriever.py index 1d1a916c3..67760eb99 100644 --- a/cognee/modules/retrieval/natural_language_retriever.py +++ b/cognee/modules/retrieval/natural_language_retriever.py @@ -2,12 +2,7 @@ from typing import Any, Optional from cognee.shared.logging_utils import get_logger from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph.networkx.adapter import NetworkXAdapter -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import ( - get_llm_client, -) -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import ( - render_prompt, -) +from cognee.infrastructure.llm.LLMAdapter import LLMAdapter from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.exceptions import SearchTypeNotSupported from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface @@ -55,8 +50,7 @@ class NaturalLanguageRetriever(BaseRetriever): async def _generate_cypher_query(self, query: str, edge_schemas, previous_attempts=None) -> str: """Generate a Cypher query using LLM based on natural language query and schema information.""" - llm_client = get_llm_client() - system_prompt = render_prompt( + system_prompt = LLMAdapter.render_prompt( self.system_prompt_path, context={ "edge_schemas": edge_schemas, @@ -64,7 +58,7 @@ class NaturalLanguageRetriever(BaseRetriever): }, ) - return await llm_client.acreate_structured_output( + return await LLMAdapter.acreate_structured_output( text_input=query, system_prompt=system_prompt, response_model=str, diff --git a/cognee/modules/retrieval/utils/completion.py b/cognee/modules/retrieval/utils/completion.py index 22a7fce49..5a87e9a89 100644 --- a/cognee/modules/retrieval/utils/completion.py +++ b/cognee/modules/retrieval/utils/completion.py @@ -1,10 +1,4 @@ -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import ( - get_llm_client, -) -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import ( - read_query_prompt, - render_prompt, -) +from cognee.infrastructure.llm.LLMAdapter import LLMAdapter async def generate_completion( @@ -15,11 +9,10 @@ async def generate_completion( ) -> str: """Generates a completion using LLM with given context and prompts.""" args = {"question": query, "context": context} - user_prompt = render_prompt(user_prompt_path, args) - system_prompt = read_query_prompt(system_prompt_path) + user_prompt = LLMAdapter.render_prompt(user_prompt_path, args) + system_prompt = LLMAdapter.read_query_prompt(system_prompt_path) - llm_client = get_llm_client() - return await llm_client.acreate_structured_output( + return await LLMAdapter.acreate_structured_output( text_input=user_prompt, system_prompt=system_prompt, response_model=str, @@ -31,10 +24,9 @@ async def summarize_text( prompt_path: str = "summarize_search_results.txt", ) -> str: """Summarizes text using LLM with the specified prompt.""" - system_prompt = read_query_prompt(prompt_path) - llm_client = get_llm_client() + system_prompt = LLMAdapter.read_query_prompt(prompt_path) - return await llm_client.acreate_structured_output( + return await LLMAdapter.acreate_structured_output( text_input=text, system_prompt=system_prompt, response_model=str, diff --git a/cognee/modules/retrieval/utils/description_to_codepart_search.py b/cognee/modules/retrieval/utils/description_to_codepart_search.py index 3470968f8..7c285011c 100644 --- a/cognee/modules/retrieval/utils/description_to_codepart_search.py +++ b/cognee/modules/retrieval/utils/description_to_codepart_search.py @@ -9,9 +9,7 @@ from cognee.modules.users.methods import get_default_user from cognee.modules.users.models import User from cognee.shared.utils import send_telemetry from cognee.modules.search.methods import search -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import ( - get_llm_client, -) +from cognee.infrastructure.llm.LLMAdapter import LLMAdapter logger = get_logger(level=ERROR) @@ -73,8 +71,7 @@ async def code_description_to_code_part( if isinstance(obj, dict) and "description" in obj ) - llm_client = get_llm_client() - context_from_documents = await llm_client.acreate_structured_output( + context_from_documents = await LLMAdapter.acreate_structured_output( text_input=f"The retrieved context from documents is {concatenated_descriptions}.", system_prompt="You are a Senior Software Engineer, summarize the context from documents" f" in a way that it is gonna be provided next to codeparts as context" diff --git a/cognee/shared/data_models.py b/cognee/shared/data_models.py index d5093d0cb..4e80ed81a 100644 --- a/cognee/shared/data_models.py +++ b/cognee/shared/data_models.py @@ -4,7 +4,7 @@ from enum import Enum, auto from typing import Any, Dict, List, Optional, Union from pydantic import BaseModel, Field -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.config import ( +from cognee.infrastructure.llm.config import ( get_llm_config, ) diff --git a/cognee/tasks/chunk_naive_llm_classifier/chunk_naive_llm_classifier.py b/cognee/tasks/chunk_naive_llm_classifier/chunk_naive_llm_classifier.py index 43cd38a5c..ef0eaf6f8 100644 --- a/cognee/tasks/chunk_naive_llm_classifier/chunk_naive_llm_classifier.py +++ b/cognee/tasks/chunk_naive_llm_classifier/chunk_naive_llm_classifier.py @@ -7,9 +7,7 @@ from pydantic import BaseModel from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.engine.models import DataPoint -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.extraction.extract_categories import ( - extract_categories, -) +from cognee.infrastructure.llm.LLMAdapter import LLMAdapter from cognee.modules.chunking.models.DocumentChunk import DocumentChunk @@ -42,7 +40,7 @@ async def chunk_naive_llm_classifier( return data_chunks chunk_classifications = await asyncio.gather( - *[extract_categories(chunk.text, classification_model) for chunk in data_chunks], + *[LLMAdapter.extract_categories(chunk.text, classification_model) for chunk in data_chunks], ) classification_data_points = [] diff --git a/cognee/tasks/entity_completion/entity_extractors/llm_entity_extractor.py b/cognee/tasks/entity_completion/entity_extractors/llm_entity_extractor.py index c3d89905a..d52e7bf4b 100644 --- a/cognee/tasks/entity_completion/entity_extractors/llm_entity_extractor.py +++ b/cognee/tasks/entity_completion/entity_extractors/llm_entity_extractor.py @@ -6,13 +6,7 @@ from pydantic import BaseModel from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor from cognee.modules.engine.models import Entity from cognee.modules.engine.models.EntityType import EntityType -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import ( - read_query_prompt, - render_prompt, -) -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import ( - get_llm_client, -) +from cognee.infrastructure.llm.LLMAdapter import LLMAdapter logger = get_logger("llm_entity_extractor") @@ -56,11 +50,10 @@ class LLMEntityExtractor(BaseEntityExtractor): try: logger.info(f"Extracting entities from text: {text[:100]}...") - llm_client = get_llm_client() - user_prompt = render_prompt(self.user_prompt_template, {"text": text}) - system_prompt = read_query_prompt(self.system_prompt_template) + user_prompt = LLMAdapter.render_prompt(self.user_prompt_template, {"text": text}) + system_prompt = LLMAdapter.read_query_prompt(self.system_prompt_template) - response = await llm_client.acreate_structured_output( + response = await LLMAdapter.acreate_structured_output( text_input=user_prompt, system_prompt=system_prompt, response_model=EntityList, diff --git a/cognee/tasks/graph/cascade_extract/utils/extract_content_nodes_and_relationship_names.py b/cognee/tasks/graph/cascade_extract/utils/extract_content_nodes_and_relationship_names.py index 105062cb5..359ace6f3 100644 --- a/cognee/tasks/graph/cascade_extract/utils/extract_content_nodes_and_relationship_names.py +++ b/cognee/tasks/graph/cascade_extract/utils/extract_content_nodes_and_relationship_names.py @@ -1,13 +1,7 @@ from typing import List, Tuple from pydantic import BaseModel -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import ( - get_llm_client, -) -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import ( - render_prompt, - read_query_prompt, -) +from cognee.infrastructure.llm.LLMAdapter import LLMAdapter from cognee.root_dir import get_absolute_path @@ -22,7 +16,6 @@ async def extract_content_nodes_and_relationship_names( content: str, existing_nodes: List[str], n_rounds: int = 2 ) -> Tuple[List[str], List[str]]: """Extracts node names and relationship_names from content through multiple rounds of analysis.""" - llm_client = get_llm_client() all_nodes: List[str] = existing_nodes.copy() all_relationship_names: List[str] = [] existing_node_set = {node.lower() for node in all_nodes} @@ -39,15 +32,15 @@ async def extract_content_nodes_and_relationship_names( } base_directory = get_absolute_path("./tasks/graph/cascade_extract/prompts") - text_input = render_prompt( + text_input = LLMAdapter.render_prompt( "extract_graph_relationship_names_prompt_input.txt", context, base_directory=base_directory, ) - system_prompt = read_query_prompt( + system_prompt = LLMAdapter.read_query_prompt( "extract_graph_relationship_names_prompt_system.txt", base_directory=base_directory ) - response = await llm_client.acreate_structured_output( + response = await LLMAdapter.acreate_structured_output( text_input=text_input, system_prompt=system_prompt, response_model=PotentialNodesAndRelationshipNames, diff --git a/cognee/tasks/graph/cascade_extract/utils/extract_edge_triplets.py b/cognee/tasks/graph/cascade_extract/utils/extract_edge_triplets.py index 119eb6661..fe8ce6696 100644 --- a/cognee/tasks/graph/cascade_extract/utils/extract_edge_triplets.py +++ b/cognee/tasks/graph/cascade_extract/utils/extract_edge_triplets.py @@ -1,11 +1,6 @@ from typing import List -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import ( - get_llm_client, -) -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import ( - render_prompt, - read_query_prompt, -) + +from cognee.infrastructure.llm.LLMAdapter import LLMAdapter from cognee.shared.data_models import KnowledgeGraph from cognee.root_dir import get_absolute_path @@ -14,7 +9,6 @@ async def extract_edge_triplets( content: str, nodes: List[str], relationship_names: List[str], n_rounds: int = 2 ) -> KnowledgeGraph: """Creates a knowledge graph by identifying relationships between the provided nodes.""" - llm_client = get_llm_client() final_graph = KnowledgeGraph(nodes=[], edges=[]) existing_nodes = set() existing_node_ids = set() @@ -32,13 +26,13 @@ async def extract_edge_triplets( } base_directory = get_absolute_path("./tasks/graph/cascade_extract/prompts") - text_input = render_prompt( + text_input = LLMAdapter.render_prompt( "extract_graph_edge_triplets_prompt_input.txt", context, base_directory=base_directory ) - system_prompt = read_query_prompt( + system_prompt = LLMAdapter.read_query_prompt( "extract_graph_edge_triplets_prompt_system.txt", base_directory=base_directory ) - extracted_graph = await llm_client.acreate_structured_output( + extracted_graph = await LLMAdapter.acreate_structured_output( text_input=text_input, system_prompt=system_prompt, response_model=KnowledgeGraph ) diff --git a/cognee/tasks/graph/cascade_extract/utils/extract_nodes.py b/cognee/tasks/graph/cascade_extract/utils/extract_nodes.py index 647be9de7..1cda91924 100644 --- a/cognee/tasks/graph/cascade_extract/utils/extract_nodes.py +++ b/cognee/tasks/graph/cascade_extract/utils/extract_nodes.py @@ -1,13 +1,7 @@ from typing import List from pydantic import BaseModel -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import ( - get_llm_client, -) -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import ( - render_prompt, - read_query_prompt, -) +from cognee.infrastructure.llm.LLMAdapter import LLMAdapter from cognee.root_dir import get_absolute_path @@ -19,7 +13,6 @@ class PotentialNodes(BaseModel): async def extract_nodes(text: str, n_rounds: int = 2) -> List[str]: """Extracts node names from content through multiple rounds of analysis.""" - llm_client = get_llm_client() all_nodes: List[str] = [] existing_nodes = set() @@ -31,13 +24,13 @@ async def extract_nodes(text: str, n_rounds: int = 2) -> List[str]: "text": text, } base_directory = get_absolute_path("./tasks/graph/cascade_extract/prompts") - text_input = render_prompt( + text_input = LLMAdapter.render_prompt( "extract_graph_nodes_prompt_input.txt", context, base_directory=base_directory ) - system_prompt = read_query_prompt( + system_prompt = LLMAdapter.read_query_prompt( "extract_graph_nodes_prompt_system.txt", base_directory=base_directory ) - response = await llm_client.acreate_structured_output( + response = await LLMAdapter.acreate_structured_output( text_input=text_input, system_prompt=system_prompt, response_model=PotentialNodes ) diff --git a/cognee/tasks/graph/extract_graph_from_code.py b/cognee/tasks/graph/extract_graph_from_code.py index 7e1316e02..72a6df2bb 100644 --- a/cognee/tasks/graph/extract_graph_from_code.py +++ b/cognee/tasks/graph/extract_graph_from_code.py @@ -2,22 +2,9 @@ import asyncio from typing import Type, List from pydantic import BaseModel +from cognee.infrastructure.llm.LLMAdapter import LLMAdapter from cognee.modules.chunking.models.DocumentChunk import DocumentChunk from cognee.tasks.storage import add_data_points -from cognee.base_config import get_base_config - -# Framework selection -base = get_base_config() -if base.structured_output_framework == "BAML": - print(f"Using BAML framework: {base.structured_output_framework}") - from cognee.infrastructure.llm.structured_output_framework.baml_src.extraction import ( - extract_content_graph, - ) -else: - print(f"Using llitellm_instructor framework: {base.structured_output_framework}") - from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.extraction import ( - extract_content_graph, - ) async def extract_graph_from_code( @@ -31,7 +18,7 @@ async def extract_graph_from_code( - Graph nodes are stored using the `add_data_points` function for later retrieval or analysis. """ chunk_graphs = await asyncio.gather( - *[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks] + *[LLMAdapter.extract_content_graph(chunk.text, graph_model) for chunk in data_chunks] ) for chunk_index, chunk in enumerate(data_chunks): diff --git a/cognee/tasks/graph/extract_graph_from_data.py b/cognee/tasks/graph/extract_graph_from_data.py index 9601568c8..581e548b5 100644 --- a/cognee/tasks/graph/extract_graph_from_data.py +++ b/cognee/tasks/graph/extract_graph_from_data.py @@ -6,25 +6,12 @@ from cognee.infrastructure.databases.graph import get_graph_engine from cognee.tasks.storage.add_data_points import add_data_points from cognee.modules.ontology.rdf_xml.OntologyResolver import OntologyResolver from cognee.modules.chunking.models.DocumentChunk import DocumentChunk -from cognee.base_config import get_base_config from cognee.modules.graph.utils import ( expand_with_nodes_and_edges, retrieve_existing_edges, ) from cognee.shared.data_models import KnowledgeGraph - -# Framework selection -base = get_base_config() -if base.structured_output_framework == "BAML": - print(f"Using BAML framework: {base.structured_output_framework}") - from cognee.infrastructure.llm.structured_output_framework.baml_src.extraction import ( - extract_content_graph, - ) -else: - print(f"Using llitellm_instructor framework: {base.structured_output_framework}") - from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.extraction import ( - extract_content_graph, - ) +from cognee.infrastructure.llm.LLMAdapter import LLMAdapter async def integrate_chunk_graphs( @@ -69,7 +56,7 @@ async def extract_graph_from_data( Extracts and integrates a knowledge graph from the text content of document chunks using a specified graph model. """ chunk_graphs = await asyncio.gather( - *[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks] + *[LLMAdapter.extract_content_graph(chunk.text, graph_model) for chunk in data_chunks] ) # Note: Filter edges with missing source or target nodes diff --git a/cognee/tasks/graph/infer_data_ontology.py b/cognee/tasks/graph/infer_data_ontology.py index 45ad74c04..d19225b04 100644 --- a/cognee/tasks/graph/infer_data_ontology.py +++ b/cognee/tasks/graph/infer_data_ontology.py @@ -15,12 +15,7 @@ from pydantic import BaseModel from cognee.modules.graph.exceptions import EntityNotFoundError from cognee.modules.ingestion.exceptions import IngestionError -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import ( - read_query_prompt, -) -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import ( - get_llm_client, -) + from cognee.infrastructure.data.chunking.config import get_chunk_config from cognee.infrastructure.data.chunking.get_chunking_engine import get_chunk_engine from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine @@ -32,6 +27,7 @@ from cognee.modules.data.methods.add_model_class_to_graph import ( from cognee.tasks.graph.models import NodeModel, GraphOntology from cognee.shared.data_models import KnowledgeGraph from cognee.modules.engine.utils import generate_node_id, generate_node_name +from cognee.infrastructure.llm.LLMAdapter import LLMAdapter logger = get_logger("task:infer_data_ontology") @@ -56,11 +52,10 @@ async def extract_ontology(content: str, response_model: Type[BaseModel]): The structured ontology extracted from the content. """ - llm_client = get_llm_client() - system_prompt = read_query_prompt("extract_ontology.txt") + system_prompt = LLMAdapter.read_query_prompt("extract_ontology.txt") - ontology = await llm_client.acreate_structured_output(content, system_prompt, response_model) + ontology = await LLMAdapter.acreate_structured_output(content, system_prompt, response_model) return ontology diff --git a/cognee/tasks/summarization/summarize_code.py b/cognee/tasks/summarization/summarize_code.py index 7a5801a83..ce48fbad2 100644 --- a/cognee/tasks/summarization/summarize_code.py +++ b/cognee/tasks/summarization/summarize_code.py @@ -3,24 +3,9 @@ from typing import AsyncGenerator, Union from uuid import uuid5 from cognee.infrastructure.engine import DataPoint -from cognee.base_config import get_base_config +from cognee.infrastructure.llm.LLMAdapter import LLMAdapter from .models import CodeSummary -# Framework selection -base = get_base_config() -if base.structured_output_framework == "BAML": - print(f"Using BAML framework for code summarization: {base.structured_output_framework}") - from cognee.infrastructure.llm.structured_output_framework.baml_src.extraction import ( - extract_code_summary, - ) -else: - print( - f"Using llitellm_instructor framework for code summarization: {base.structured_output_framework}" - ) - from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.extraction import ( - extract_code_summary, - ) - async def summarize_code( code_graph_nodes: list[DataPoint], @@ -31,7 +16,7 @@ async def summarize_code( code_data_points = [file for file in code_graph_nodes if hasattr(file, "source_code")] file_summaries = await asyncio.gather( - *[extract_code_summary(file.source_code) for file in code_data_points] + *[LLMAdapter.extract_code_summary(file.source_code) for file in code_data_points] ) file_summaries_map = { diff --git a/cognee/tasks/summarization/summarize_text.py b/cognee/tasks/summarization/summarize_text.py index f7f5de939..30a475ed5 100644 --- a/cognee/tasks/summarization/summarize_text.py +++ b/cognee/tasks/summarization/summarize_text.py @@ -3,26 +3,11 @@ from typing import Type from uuid import uuid5 from pydantic import BaseModel -from cognee.base_config import get_base_config from cognee.modules.chunking.models.DocumentChunk import DocumentChunk +from cognee.infrastructure.llm.LLMAdapter import LLMAdapter from cognee.modules.cognify.config import get_cognify_config from .models import TextSummary -# Framework selection -base = get_base_config() -if base.structured_output_framework == "BAML": - print(f"Using BAML framework for text summarization: {base.structured_output_framework}") - from cognee.infrastructure.llm.structured_output_framework.baml_src.extraction import ( - extract_summary, - ) -else: - print( - f"Using llitellm_instructor framework for text summarization: {base.structured_output_framework}" - ) - from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.extraction import ( - extract_summary, - ) - async def summarize_text( data_chunks: list[DocumentChunk], summarization_model: Type[BaseModel] = None @@ -58,7 +43,7 @@ async def summarize_text( summarization_model = cognee_config.summarization_model chunk_summaries = await asyncio.gather( - *[extract_summary(chunk.text, summarization_model) for chunk in data_chunks] + *[LLMAdapter.extract_summary(chunk.text, summarization_model) for chunk in data_chunks] ) summaries = [ diff --git a/cognee/tests/unit/infrastructure/databases/vector/__init__.py b/cognee/tests/unit/infrastructure/databases/vector/__init__.py index 9399921ac..e69de29bb 100644 --- a/cognee/tests/unit/infrastructure/databases/vector/__init__.py +++ b/cognee/tests/unit/infrastructure/databases/vector/__init__.py @@ -1 +0,0 @@ -# Vector database tests module diff --git a/cognee/tests/unit/infrastructure/mock_embedding_engine.py b/cognee/tests/unit/infrastructure/mock_embedding_engine.py index e600430bc..c114d1dc8 100644 --- a/cognee/tests/unit/infrastructure/mock_embedding_engine.py +++ b/cognee/tests/unit/infrastructure/mock_embedding_engine.py @@ -4,7 +4,7 @@ from typing import List from cognee.infrastructure.databases.vector.embeddings.LiteLLMEmbeddingEngine import ( LiteLLMEmbeddingEngine, ) -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.embedding_rate_limiter import ( +from cognee.infrastructure.databases.vector.embeddings.embedding_rate_limiter import ( embedding_rate_limit_async, embedding_sleep_and_retry_async, ) diff --git a/cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py b/cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py index 7d18bb6b9..25566dce1 100644 --- a/cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py +++ b/cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py @@ -3,10 +3,10 @@ import time import asyncio import logging -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.config import ( +from cognee.infrastructure.llm.config import ( get_llm_config, ) -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.embedding_rate_limiter import ( +from cognee.infrastructure.databases.vector.embeddings.embedding_rate_limiter import ( EmbeddingRateLimiter, ) from cognee.tests.unit.infrastructure.mock_embedding_engine import MockEmbeddingEngine diff --git a/cognee/tests/unit/infrastructure/test_rate_limiting_realistic.py b/cognee/tests/unit/infrastructure/test_rate_limiting_realistic.py index ec64bc1a4..c2953eb80 100644 --- a/cognee/tests/unit/infrastructure/test_rate_limiting_realistic.py +++ b/cognee/tests/unit/infrastructure/test_rate_limiting_realistic.py @@ -5,7 +5,7 @@ from cognee.shared.logging_utils import get_logger from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.rate_limiter import ( llm_rate_limiter, ) -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.config import ( +from cognee.infrastructure.llm.config import ( get_llm_config, ) diff --git a/examples/python/agentic_reasoning_procurement_example.py b/examples/python/agentic_reasoning_procurement_example.py index 4c99ab8bb..6e39d0515 100644 --- a/examples/python/agentic_reasoning_procurement_example.py +++ b/examples/python/agentic_reasoning_procurement_example.py @@ -3,9 +3,7 @@ import logging import cognee import asyncio -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import ( - get_llm_client, -) +from cognee.infrastructure.llm.LLMAdapter import LLMAdapter from dotenv import load_dotenv from cognee.api.v1.search import SearchType from cognee.modules.engine.models import NodeSet @@ -187,7 +185,7 @@ async def run_procurement_example(): print(research_information) print("\nPassing research to LLM for final procurement recommendation...\n") - final_decision = await get_llm_client().acreate_structured_output( + final_decision = await LLMAdapter.acreate_structured_output( text_input=research_information, system_prompt="""You are a procurement decision assistant. Use the provided QA pairs that were collected through a research phase. Recommend the best vendor, based on pricing, delivery, warranty, policy fit, and past performance. Be concise and justify your choice with evidence. diff --git a/examples/python/graphiti_example.py b/examples/python/graphiti_example.py index c372830b3..6c0ff5f54 100644 --- a/examples/python/graphiti_example.py +++ b/examples/python/graphiti_example.py @@ -12,13 +12,7 @@ from cognee.tasks.temporal_awareness.index_graphiti_objects import ( ) from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import ( - read_query_prompt, - render_prompt, -) -from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import ( - get_llm_client, -) +from cognee.infrastructure.llm.LLMAdapter import LLMAdapter from cognee.modules.users.methods import get_default_user text_list = [ @@ -65,11 +59,10 @@ async def main(): "context": context, } - user_prompt = render_prompt("graph_context_for_question.txt", args) - system_prompt = read_query_prompt("answer_simple_question_restricted.txt") + user_prompt = LLMAdapter.render_prompt("graph_context_for_question.txt", args) + system_prompt = LLMAdapter.read_query_prompt("answer_simple_question_restricted.txt") - llm_client = get_llm_client() - computed_answer = await llm_client.acreate_structured_output( + computed_answer = await LLMAdapter.acreate_structured_output( text_input=user_prompt, system_prompt=system_prompt, response_model=str,