feat: Add adapter for structured output and LLM usage

This commit is contained in:
Igor Ilic 2025-08-05 18:49:32 +02:00
parent 9ca6750407
commit 7761b70229
54 changed files with 192 additions and 2193 deletions

View file

@@ -2,13 +2,9 @@ from uuid import NAMESPACE_OID, uuid5
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import (
render_prompt,
)
from cognee.low_level import DataPoint
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
get_llm_client,
)
from cognee.infrastructure.llm import LLMAdapter
from cognee.shared.logging_utils import get_logger
from cognee.modules.engine.models import NodeSet
from cognee.tasks.storage import add_data_points, index_graph_edges
@@ -95,16 +91,17 @@ async def get_origin_edges(data: str, rules: List[Rule]) -> list[Any]:
async def add_rule_associations(data: str, rules_nodeset_name: str):
llm_client = get_llm_client()
graph_engine = await get_graph_engine()
existing_rules = await get_existing_rules(rules_nodeset_name=rules_nodeset_name)
user_context = {"chat": data, "rules": existing_rules}
user_prompt = render_prompt("coding_rule_association_agent_user.txt", context=user_context)
system_prompt = render_prompt("coding_rule_association_agent_system.txt", context={})
user_prompt = LLMAdapter.render_prompt(
"coding_rule_association_agent_user.txt", context=user_context
)
system_prompt = LLMAdapter.render_prompt("coding_rule_association_agent_system.txt", context={})
rule_list = await llm_client.acreate_structured_output(
rule_list = await LLMAdapter.acreate_structured_output(
text_input=user_prompt, system_prompt=system_prompt, response_model=RuleSet
)

View file

@@ -7,7 +7,7 @@ from cognee.modules.cognify.config import get_cognify_config
from cognee.infrastructure.data.chunking.config import get_chunk_config
from cognee.infrastructure.databases.vector import get_vectordb_config
from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.config import (
from cognee.infrastructure.llm.config import (
get_llm_config,
)
from cognee.infrastructure.databases.relational import get_relational_config, get_migration_config

View file

@@ -17,7 +17,7 @@ from cognee.api.v1.responses.models import (
)
from cognee.api.v1.responses.dispatch_function import dispatch_function
from cognee.api.v1.responses.default_tools import DEFAULT_TOOLS
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.config import (
from cognee.infrastructure.llm.config import (
get_llm_config,
)
from cognee.modules.users.models import User

View file

@@ -15,9 +15,7 @@ class BaseConfig(BaseSettings):
langfuse_host: Optional[str] = os.getenv("LANGFUSE_HOST")
default_user_email: Optional[str] = os.getenv("DEFAULT_USER_EMAIL")
default_user_password: Optional[str] = os.getenv("DEFAULT_USER_PASSWORD")
structured_output_framework: str = os.getenv(
"STRUCTURED_OUTPUT_FRAMEWORK", "llitellm_instructor"
)
model_config = SettingsConfigDict(env_file=".env", extra="allow")
def to_dict(self) -> dict:

View file

@@ -1,15 +1,10 @@
from typing import Any, Dict, List
from pydantic import BaseModel
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
get_llm_client,
)
from cognee.eval_framework.evaluation.base_eval_adapter import BaseEvalAdapter
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import (
read_query_prompt,
render_prompt,
)
from cognee.eval_framework.eval_config import EvalConfig
from cognee.infrastructure.llm import LLMAdapter
class CorrectnessEvaluation(BaseModel):
"""Response model containing evaluation score and explanation."""
@@ -24,17 +19,16 @@ class DirectLLMEvalAdapter(BaseEvalAdapter):
config = EvalConfig()
self.system_prompt_path = config.direct_llm_system_prompt
self.eval_prompt_path = config.direct_llm_eval_prompt
self.llm_client = get_llm_client()
async def evaluate_correctness(
self, question: str, answer: str, golden_answer: str
) -> Dict[str, Any]:
args = {"question": question, "answer": answer, "golden_answer": golden_answer}
user_prompt = render_prompt(self.eval_prompt_path, args)
system_prompt = read_query_prompt(self.system_prompt_path)
user_prompt = LLMAdapter.render_prompt(self.eval_prompt_path, args)
system_prompt = LLMAdapter.read_query_prompt(self.system_prompt_path)
evaluation = await self.llm_client.acreate_structured_output(
evaluation = await LLMAdapter.acreate_structured_output(
text_input=user_prompt,
system_prompt=system_prompt,
response_model=CorrectnessEvaluation,

View file

@@ -5,7 +5,7 @@ import litellm
import os
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.tokenizer.TikToken import (
from cognee.infrastructure.llm.tokenizer.TikToken import (
TikTokenTokenizer,
)

View file

@@ -7,19 +7,19 @@ import litellm
import os
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.tokenizer.Gemini import (
from cognee.infrastructure.llm.tokenizer.Gemini import (
GeminiTokenizer,
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.tokenizer.HuggingFace import (
from cognee.infrastructure.llm.tokenizer.HuggingFace import (
HuggingFaceTokenizer,
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.tokenizer import (
from cognee.infrastructure.llm.tokenizer.Mistral import (
MistralTokenizer,
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.tokenizer.TikToken import (
from cognee.infrastructure.llm.tokenizer.TikToken import (
TikTokenTokenizer,
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.embedding_rate_limiter import (
from cognee.infrastructure.databases.vector.embeddings.embedding_rate_limiter import (
embedding_rate_limit_async,
embedding_sleep_and_retry_async,
)

View file

@@ -7,10 +7,10 @@ import os
import aiohttp.http_exceptions
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.tokenizer.HuggingFace import (
from cognee.infrastructure.llm.tokenizer.HuggingFace import (
HuggingFaceTokenizer,
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.embedding_rate_limiter import (
from cognee.infrastructure.databases.vector.embeddings.embedding_rate_limiter import (
embedding_rate_limit_async,
embedding_sleep_and_retry_async,
)

View file

@@ -1,5 +1,5 @@
from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.config import (
from cognee.infrastructure.llm.config import (
get_llm_config,
)
from .EmbeddingEngine import EmbeddingEngine

View file

@@ -1,15 +1,14 @@
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.config import (
from cognee.infrastructure.llm.config import (
get_llm_config,
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.utils import (
from cognee.infrastructure.llm.utils import (
get_max_chunk_tokens,
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.utils import (
from cognee.infrastructure.llm.utils import (
test_llm_connection,
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.utils import (
from cognee.infrastructure.llm.utils import (
test_embedding_connection,
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm import (
rate_limiter,
)
from cognee.infrastructure.llm import LLMAdapter

View file

@@ -1,156 +0,0 @@
"""Adapter for Generic API LLM provider API"""
import logging
import litellm
import instructor
from typing import Type
from pydantic import BaseModel
from openai import ContentFilterFinishReasonError
from litellm.exceptions import ContentPolicyViolationError
from instructor.exceptions import InstructorRetryException
from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.llm_interface import (
LLMInterface,
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.rate_limiter import (
rate_limit_async,
sleep_and_retry_async,
)
class GenericAPIAdapter(LLMInterface):
    """
    Adapter for a generic (OpenAI-compatible) LLM provider API.

    Wraps litellm's async completion through the instructor library so responses
    are parsed directly into Pydantic models. Optionally holds a complete fallback
    configuration (model + api_key + endpoint) that is used only when the primary
    request is rejected by the provider's content policy.

    Public methods:
    - acreate_structured_output(text_input: str, system_prompt: str, response_model:
      Type[BaseModel]) -> BaseModel
    """

    name: str
    model: str
    api_key: str

    def __init__(
        self,
        endpoint,
        api_key: str,
        model: str,
        name: str,
        max_tokens: int,
        fallback_model: str = None,
        fallback_api_key: str = None,
        fallback_endpoint: str = None,
    ):
        # Primary provider configuration.
        self.name = name
        self.model = model
        self.api_key = api_key
        self.endpoint = endpoint
        # NOTE(review): stored but not used inside this class — presumably read by
        # callers; confirm before removing.
        self.max_tokens = max_tokens
        # Fallback provider is only consulted on content-policy rejections, and only
        # when all three fallback settings are present.
        self.fallback_model = fallback_model
        self.fallback_api_key = fallback_api_key
        self.fallback_endpoint = fallback_endpoint
        # instructor JSON mode: responses are coerced into the given Pydantic model.
        self.aclient = instructor.from_litellm(
            litellm.acompletion, mode=instructor.Mode.JSON, api_key=api_key
        )

    @sleep_and_retry_async()
    @rate_limit_async
    async def acreate_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:
        """
        Generate a structured response from a user query.

        This asynchronous method sends a user query and a system prompt to a language
        model and retrieves the generated response. It handles API communication and
        retries up to a specified limit in case of request failures. If the primary
        provider rejects the input on content-policy grounds, the request is retried
        once against the configured fallback provider (when one is fully configured);
        otherwise a ContentPolicyFilterError is raised.

        Parameters:
        -----------
        - text_input (str): The input text from the user to generate a response for.
        - system_prompt (str): A prompt that provides context or instructions for the
          response generation.
        - response_model (Type[BaseModel]): A Pydantic model that defines the structure
          of the expected response.

        Returns:
        --------
        - BaseModel: An instance of the specified response model containing the
          structured output from the language model.

        Raises:
        -------
        - ContentPolicyFilterError: When the input is rejected by the provider's
          content policy and no (working) fallback is configured.
        - InstructorRetryException: When retries are exhausted for reasons other than
          content policy.
        """
        try:
            # NOTE(review): the user message is listed before the system message here;
            # litellm forwards the list as-is — confirm this ordering is intentional.
            return await self.aclient.chat.completions.create(
                model=self.model,
                messages=[
                    {
                        "role": "user",
                        "content": f"""{text_input}""",
                    },
                    {
                        "role": "system",
                        "content": system_prompt,
                    },
                ],
                max_retries=5,
                api_base=self.endpoint,
                response_model=response_model,
            )
        except (
            ContentFilterFinishReasonError,
            ContentPolicyViolationError,
            InstructorRetryException,
        ) as error:
            # InstructorRetryException also covers ordinary retry exhaustion; only
            # treat it as a policy block when the message says so — otherwise re-raise.
            if (
                isinstance(error, InstructorRetryException)
                and "content management policy" not in str(error).lower()
            ):
                raise error
            # Without a complete fallback configuration there is nothing to retry.
            if not (self.fallback_model and self.fallback_api_key and self.fallback_endpoint):
                raise ContentPolicyFilterError(
                    f"The provided input contains content that is not aligned with our content policy: {text_input}"
                )
            try:
                # Second attempt against the fallback provider.
                return await self.aclient.chat.completions.create(
                    model=self.fallback_model,
                    messages=[
                        {
                            "role": "user",
                            "content": f"""{text_input}""",
                        },
                        {
                            "role": "system",
                            "content": system_prompt,
                        },
                    ],
                    max_retries=5,
                    api_key=self.fallback_api_key,
                    api_base=self.fallback_endpoint,
                    response_model=response_model,
                )
            except (
                ContentFilterFinishReasonError,
                ContentPolicyViolationError,
                InstructorRetryException,
            ) as error:
                # Same filtering logic as above: non-policy retry failures propagate,
                # policy rejections become a ContentPolicyFilterError.
                if (
                    isinstance(error, InstructorRetryException)
                    and "content management policy" not in str(error).lower()
                ):
                    raise error
                else:
                    raise ContentPolicyFilterError(
                        f"The provided input contains content that is not aligned with our content policy: {text_input}"
                    )

View file

@@ -1,185 +0,0 @@
import os
from typing import Optional, ClassVar
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import model_validator
from baml_py import ClientRegistry
class LLMConfig(BaseSettings):
    """
    Configuration settings for the LLM (Large Language Model) provider and related options.

    Values are loaded from the environment / `.env` via pydantic-settings; on creation
    the instance also registers itself as a BAML client (see `model_post_init`).

    Public instance variables include:
    - llm_provider
    - llm_model
    - llm_endpoint
    - llm_api_key
    - llm_api_version
    - llm_temperature
    - llm_streaming
    - llm_max_tokens
    - transcription_model
    - graph_prompt_path
    - llm_rate_limit_enabled
    - llm_rate_limit_requests
    - llm_rate_limit_interval
    - embedding_rate_limit_enabled
    - embedding_rate_limit_requests
    - embedding_rate_limit_interval

    Public methods include:
    - ensure_env_vars_for_ollama
    - to_dict
    """

    llm_provider: str = "openai"
    llm_model: str = "gpt-4o-mini"
    llm_endpoint: str = ""
    llm_api_key: Optional[str] = None
    llm_api_version: Optional[str] = None
    llm_temperature: float = 0.0
    llm_streaming: bool = False
    llm_max_tokens: int = 16384
    transcription_model: str = "whisper-1"
    graph_prompt_path: str = "generate_graph_prompt.txt"
    llm_rate_limit_enabled: bool = False
    llm_rate_limit_requests: int = 60
    llm_rate_limit_interval: int = 60  # in seconds (default is 60 requests per minute)
    embedding_rate_limit_enabled: bool = False
    embedding_rate_limit_requests: int = 60
    embedding_rate_limit_interval: int = 60  # in seconds (default is 60 requests per minute)

    # NOTE(review): ClassVar — one registry shared by every LLMConfig instance;
    # each instantiation re-registers a client on it (see model_post_init). Confirm
    # that multiple instances are never created with conflicting providers.
    baml_registry: ClassVar[ClientRegistry] = ClientRegistry()

    model_config = SettingsConfigDict(env_file=".env", extra="allow")

    def model_post_init(self, __context) -> None:
        """Initialize the BAML registry after the model is created."""
        self.baml_registry.add_llm_client(
            name=self.llm_provider,
            provider=self.llm_provider,
            options={
                "model": self.llm_model,
                "temperature": self.llm_temperature,
                "api_key": self.llm_api_key,
            },
        )
        # Sets the primary client
        self.baml_registry.set_primary(self.llm_provider)

    @model_validator(mode="after")
    def ensure_env_vars_for_ollama(self) -> "LLMConfig":
        """
        Validate required environment variables for the 'ollama' LLM provider.

        Both the LLM and embedding env-var groups must be set all-or-nothing:
        raises ValueError if some required environment variables are set without
        the others. Checks are only performed when 'llm_provider' is 'ollama'.

        Returns:
        --------
        - 'LLMConfig': The instance of LLMConfig after validation.
        """
        if self.llm_provider != "ollama":
            # Skip checks unless provider is "ollama"
            return self

        def is_env_set(var_name: str) -> bool:
            """
            Check if a given environment variable is set and non-empty.

            Parameters:
            -----------
            - var_name (str): The name of the environment variable to check.

            Returns:
            --------
            - bool: True if the environment variable exists and is not empty,
              otherwise False.
            """
            val = os.environ.get(var_name)
            return val is not None and val.strip() != ""

        #
        # 1. Check LLM environment variables
        #
        llm_env_vars = {
            "LLM_MODEL": is_env_set("LLM_MODEL"),
            "LLM_ENDPOINT": is_env_set("LLM_ENDPOINT"),
            "LLM_API_KEY": is_env_set("LLM_API_KEY"),
        }
        # Partial configuration is an error; fully unset is allowed.
        if any(llm_env_vars.values()) and not all(llm_env_vars.values()):
            missing_llm = [key for key, is_set in llm_env_vars.items() if not is_set]
            raise ValueError(
                "You have set some but not all of the required environment variables "
                f"for LLM usage (LLM_MODEL, LLM_ENDPOINT, LLM_API_KEY). Missing: {missing_llm}"
            )
        #
        # 2. Check embedding environment variables
        #
        embedding_env_vars = {
            "EMBEDDING_PROVIDER": is_env_set("EMBEDDING_PROVIDER"),
            "EMBEDDING_MODEL": is_env_set("EMBEDDING_MODEL"),
            "EMBEDDING_DIMENSIONS": is_env_set("EMBEDDING_DIMENSIONS"),
            "HUGGINGFACE_TOKENIZER": is_env_set("HUGGINGFACE_TOKENIZER"),
        }
        if any(embedding_env_vars.values()) and not all(embedding_env_vars.values()):
            missing_embed = [key for key, is_set in embedding_env_vars.items() if not is_set]
            raise ValueError(
                "You have set some but not all of the required environment variables "
                "for embeddings (EMBEDDING_PROVIDER, EMBEDDING_MODEL, "
                "EMBEDDING_DIMENSIONS, HUGGINGFACE_TOKENIZER). Missing: "
                f"{missing_embed}"
            )
        return self

    def to_dict(self) -> dict:
        """
        Convert the LLMConfig instance into a dictionary representation.

        Returns:
        --------
        - dict: A dictionary containing the configuration settings of the LLMConfig
          instance.
        """
        return {
            "provider": self.llm_provider,
            "model": self.llm_model,
            "endpoint": self.llm_endpoint,
            "api_key": self.llm_api_key,
            "api_version": self.llm_api_version,
            "temperature": self.llm_temperature,
            "streaming": self.llm_streaming,
            "max_tokens": self.llm_max_tokens,
            "transcription_model": self.transcription_model,
            "graph_prompt_path": self.graph_prompt_path,
            "rate_limit_enabled": self.llm_rate_limit_enabled,
            "rate_limit_requests": self.llm_rate_limit_requests,
            "rate_limit_interval": self.llm_rate_limit_interval,
            "embedding_rate_limit_enabled": self.embedding_rate_limit_enabled,
            "embedding_rate_limit_requests": self.embedding_rate_limit_requests,
            "embedding_rate_limit_interval": self.embedding_rate_limit_interval,
        }
@lru_cache
def get_llm_config():
    """
    Return the process-wide LLM configuration.

    The first call constructs an LLMConfig (reading the environment / `.env`);
    `lru_cache` then hands back that same instance on every subsequent call, so
    the configuration behaves as a lazily-created singleton.

    Returns:
    --------
    - LLMConfig: The cached configuration instance for the LLM.
    """
    config = LLMConfig()
    return config

View file

@@ -1,109 +0,0 @@
// Content classification data models - matching shared/data_models.py
// Each *Content class carries the fixed category constant in `type` and the
// selected subclass labels in `subclass`.
class TextContent {
  type string
  subclass string[]
}

class AudioContent {
  type string
  subclass string[]
}

class ImageContent {
  type string
  subclass string[]
}

class VideoContent {
  type string
  subclass string[]
}

class MultimediaContent {
  type string
  subclass string[]
}

class Model3DContent {
  type string
  subclass string[]
}

class ProceduralContent {
  type string
  subclass string[]
}

// Closed union of the allowed top-level categories plus the chosen type/subclasses.
class ContentLabel {
  content_type "text" | "audio" | "image" | "video" | "multimedia" | "3d_model" | "procedural"
  type string
  subclass string[]
}

// Wrapper returned by the ExtractCategories function.
class DefaultContentPrediction {
  label ContentLabel
}
// Content classification prompt template
// System portion of ExtractCategories: enumerates every allowed
// (content_type, type, subclass) combination so the model picks from the
// closed label set instead of inventing labels.
template_string ClassifyContentPrompt() #"
You are a classification engine and should classify content. Make sure to use one of the existing classification options and not invent your own.
Classify the content into one of these main categories and their relevant subclasses:
**TEXT CONTENT** (content_type: "text"):
- type: "TEXTUAL_DOCUMENTS_USED_FOR_GENERAL_PURPOSES"
- subclass options: ["Articles, essays, and reports", "Books and manuscripts", "News stories and blog posts", "Research papers and academic publications", "Social media posts and comments", "Website content and product descriptions", "Personal narratives and stories", "Spreadsheets and tables", "Forms and surveys", "Databases and CSV files", "Source code in various programming languages", "Shell commands and scripts", "Markup languages (HTML, XML)", "Stylesheets (CSS) and configuration files (YAML, JSON, INI)", "Chat transcripts and messaging history", "Customer service logs and interactions", "Conversational AI training data", "Textbook content and lecture notes", "Exam questions and academic exercises", "E-learning course materials", "Poetry and prose", "Scripts for plays, movies, and television", "Song lyrics", "Manuals and user guides", "Technical specifications and API documentation", "Helpdesk articles and FAQs", "Contracts and agreements", "Laws, regulations, and legal case documents", "Policy documents and compliance materials", "Clinical trial reports", "Patient records and case notes", "Scientific journal articles", "Financial reports and statements", "Business plans and proposals", "Market research and analysis reports", "Ad copies and marketing slogans", "Product catalogs and brochures", "Press releases and promotional content", "Professional and formal correspondence", "Personal emails and letters", "Image and video captions", "Annotations and metadata for various media", "Vocabulary lists and grammar rules", "Language exercises and quizzes", "Other types of text data"]
**AUDIO CONTENT** (content_type: "audio"):
- type: "AUDIO_DOCUMENTS_USED_FOR_GENERAL_PURPOSES"
- subclass options: ["Music tracks and albums", "Podcasts and radio broadcasts", "Audiobooks and audio guides", "Recorded interviews and speeches", "Sound effects and ambient sounds", "Other types of audio recordings"]
**IMAGE CONTENT** (content_type: "image"):
- type: "IMAGE_DOCUMENTS_USED_FOR_GENERAL_PURPOSES"
- subclass options: ["Photographs and digital images", "Illustrations, diagrams, and charts", "Infographics and visual data representations", "Artwork and paintings", "Screenshots and graphical user interfaces", "Other types of images"]
**VIDEO CONTENT** (content_type: "video"):
- type: "VIDEO_DOCUMENTS_USED_FOR_GENERAL_PURPOSES"
- subclass options: ["Movies and short films", "Documentaries and educational videos", "Video tutorials and how-to guides", "Animated features and cartoons", "Live event recordings and sports broadcasts", "Other types of video content"]
**MULTIMEDIA CONTENT** (content_type: "multimedia"):
- type: "MULTIMEDIA_DOCUMENTS_USED_FOR_GENERAL_PURPOSES"
- subclass options: ["Interactive web content and games", "Virtual reality (VR) and augmented reality (AR) experiences", "Mixed media presentations and slide decks", "E-learning modules with integrated multimedia", "Digital exhibitions and virtual tours", "Other types of multimedia content"]
**3D MODEL CONTENT** (content_type: "3d_model"):
- type: "3D_MODEL_DOCUMENTS_USED_FOR_GENERAL_PURPOSES"
- subclass options: ["Architectural renderings and building plans", "Product design models and prototypes", "3D animations and character models", "Scientific simulations and visualizations", "Virtual objects for AR/VR applications", "Other types of 3D models"]
**PROCEDURAL CONTENT** (content_type: "procedural"):
- type: "PROCEDURAL_DOCUMENTS_USED_FOR_GENERAL_PURPOSES"
- subclass options: ["Tutorials and step-by-step guides", "Workflow and process descriptions", "Simulation and training exercises", "Recipes and crafting instructions", "Other types of procedural content"]
Select the most appropriate content_type, type, and relevant subclasses.
"#
// Uses the shared OpenAI client (defined once, in another BAML file, for all BAML files).
// Classification function: injects the classification rules, the expected output
// schema, then the user content to classify.
function ExtractCategories(content: string) -> DefaultContentPrediction {
  client OpenAI
  prompt #"
{{ ClassifyContentPrompt() }}
{{ ctx.output_format(prefix="Answer in this schema:\n") }}
{{ _.role('user') }}
{{ content }}
"#
}
// Test case for classification — runs ExtractCategories on a short NLP paragraph
// (expected to classify as text content).
test ExtractCategoriesExample {
  functions [ExtractCategories]
  args {
    content #"
Natural language processing (NLP) is an interdisciplinary subfield of computer science and information retrieval.
It deals with the interaction between computers and human language, in particular how to program computers to process and analyze large amounts of natural language data.
"#
  }
}

View file

@@ -1,343 +0,0 @@
// Knowledge-graph data models.
/// An entity/concept node; @@dynamic allows extra fields to be added at runtime.
class Node {
  id string
  name string
  type string
  description string
  @@dynamic
}

/// A directed, named relationship between two nodes of the knowledge graph.
class Edge {
  /// Human-readable ID of the node the edge starts from.
  source_node_id string
  /// Human-readable ID of the node the edge points to.
  target_node_id string
  /// snake_case relationship label (e.g. born_in).
  relationship_name string
}

class KnowledgeGraph {
  // NOTE(review): @stream.done presumably means nodes are only emitted once
  // fully complete when streaming — confirm against BAML streaming docs.
  nodes (Node @stream.done)[]
  edges Edge[]
}

// Summarization classes
class SummarizedContent {
  summary string
  description string
}

class SummarizedFunction {
  name string
  description string
  inputs string[]?
  outputs string[]?
  decorators string[]?
}

class SummarizedClass {
  name string
  description string
  methods SummarizedFunction[]?
  decorators string[]?
}

class SummarizedCode {
  high_level_summary string
  key_features string[]
  imports string[]
  constants string[]
  classes SummarizedClass[]
  functions SummarizedFunction[]
  workflow_description string?
}

/// Fully dynamic graph shape: all fields are supplied at runtime via @@dynamic.
class DynamicKnowledgeGraph {
  @@dynamic
}
// Simple template for basic extraction (fast, good quality).
// Default prompt used when no mode (or "simple") is selected in the extraction functions.
template_string ExtractContentGraphPrompt() #"
You are an advanced algorithm that extracts structured data into a knowledge graph.
- **Nodes**: Entities/concepts (like Wikipedia articles).
- **Edges**: Relationships (like Wikipedia links). Use snake_case (e.g., `acted_in`).
**Rules:**
1. **Node Labeling & IDs**
- Use basic types only (e.g., "Person", "Date", "Organization").
- Avoid overly specific or generic terms (e.g., no "Mathematician" or "Entity").
- Node IDs must be human-readable names from the text (no numbers).
2. **Dates & Numbers**
- Label dates as **"Date"** in "YYYY-MM-DD" format (use available parts if incomplete).
- Properties are key-value pairs; do not use escaped quotes.
3. **Coreference Resolution**
- Use a single, complete identifier for each entity (e.g., always "John Doe" not "Joe" or "he").
4. **Relationship Labels**:
- Use descriptive, lowercase, snake_case names for edges.
- *Example*: born_in, married_to, invented_by.
- Avoid vague or generic labels like isA, relatesTo, has.
- Avoid duplicated relationships like produces, produced by.
5. **Strict Compliance**
- Follow these rules exactly. Non-compliance results in termination.
"#
// Summarization prompt template — system prompt for the SummarizeContent function.
template_string SummarizeContentPrompt() #"
You are a top-tier summarization engine. Your task is to summarize text and make it versatile.
Be brief and concise, but keep the important information and the subject.
Use synonym words where possible in order to change the wording but keep the meaning.
"#
// Code summarization prompt template — system prompt used for source-code
// summarization into the SummarizedCode schema.
template_string SummarizeCodePrompt() #"
You are an expert code analyst. Analyze the provided source code and extract key information:
1. Provide a high-level summary of what the code does
2. List key features and functionality
3. Identify imports and dependencies
4. List constants and global variables
5. Summarize classes with their methods
6. Summarize standalone functions
7. Describe the overall workflow if applicable
Be precise and technical while remaining clear and concise.
"#
// Detailed template for complex extraction (slower, higher quality).
// Selected when mode == "base" in the extraction functions.
template_string DetailedExtractContentGraphPrompt() #"
You are a top-tier algorithm designed for extracting information in structured formats to build a knowledge graph.
**Nodes** represent entities and concepts. They're akin to Wikipedia nodes.
**Edges** represent relationships between concepts. They're akin to Wikipedia links.
The aim is to achieve simplicity and clarity in the knowledge graph.
# 1. Labeling Nodes
**Consistency**: Ensure you use basic or elementary types for node labels.
- For example, when you identify an entity representing a person, always label it as **"Person"**.
- Avoid using more specific terms like "Mathematician" or "Scientist", keep those as "profession" property.
- Don't use too generic terms like "Entity".
**Node IDs**: Never utilize integers as node IDs.
- Node IDs should be names or human-readable identifiers found in the text.
# 2. Handling Numerical Data and Dates
- For example, when you identify an entity representing a date, make sure it has type **"Date"**.
- Extract the date in the format "YYYY-MM-DD"
- If not possible to extract the whole date, extract month or year, or both if available.
- **Property Format**: Properties must be in a key-value format.
- **Quotation Marks**: Never use escaped single or double quotes within property values.
- **Naming Convention**: Use snake_case for relationship names, e.g., `acted_in`.
# 3. Coreference Resolution
- **Maintain Entity Consistency**: When extracting entities, it's vital to ensure consistency.
If an entity, such as "John Doe", is mentioned multiple times in the text but is referred to by different names or pronouns (e.g., "Joe", "he"),
always use the most complete identifier for that entity throughout the knowledge graph. In this example, use "John Doe" as the Person's ID.
Remember, the knowledge graph should be coherent and easily understandable, so maintaining consistency in entity references is crucial.
# 4. Strict Compliance
Adhere to the rules strictly. Non-compliance will result in termination.
"#
// Guided template with step-by-step instructions.
// Selected when mode == "guided" in the extraction functions.
template_string GuidedExtractContentGraphPrompt() #"
You are an advanced algorithm designed to extract structured information to build a clean, consistent, and human-readable knowledge graph.
**Objective**:
- Nodes represent entities and concepts, similar to Wikipedia articles.
- Edges represent typed relationships between nodes, similar to Wikipedia hyperlinks.
- The graph must be clear, minimal, consistent, and semantically precise.
**Node Guidelines**:
1. **Label Consistency**:
- Use consistent, basic types for all node labels.
- Do not switch between granular or vague labels for the same kind of entity.
- Pick one label for each category and apply it uniformly.
- Each entity type should be in a singular form and in a case of multiple words separated by whitespaces
2. **Node Identifiers**:
- Node IDs must be human-readable and derived directly from the text.
- Prefer full names and canonical terms.
- Never use integers or autogenerated IDs.
- *Example*: Use "Marie Curie", "Theory of Evolution", "Google".
3. **Coreference Resolution**:
- Maintain one consistent node ID for each real-world entity.
- Resolve aliases, acronyms, and pronouns to the most complete form.
- *Example*: Always use "John Doe" even if later referred to as "Doe" or "he".
**Edge Guidelines**:
4. **Relationship Labels**:
- Use descriptive, lowercase, snake_case names for edges.
- *Example*: born_in, married_to, invented_by.
- Avoid vague or generic labels like isA, relatesTo, has.
5. **Relationship Direction**:
- Edges must be directional and logically consistent.
- *Example*:
- "Marie Curie" —[born_in]→ "Warsaw"
- "Radioactivity" —[discovered_by]→ "Marie Curie"
**Compliance**:
Strict adherence to these guidelines is required. Any deviation will result in immediate termination of the task.
"#
// Strict template with zero-tolerance rules.
// Selected when mode == "strict" in the extraction functions.
template_string StrictExtractContentGraphPrompt() #"
You are a top-tier algorithm for **extracting structured information** from unstructured text to build a **knowledge graph**.
Your primary goal is to extract:
- **Nodes**: Representing **entities** and **concepts** (like Wikipedia nodes).
- **Edges**: Representing **relationships** between those concepts (like Wikipedia links).
The resulting knowledge graph must be **simple, consistent, and human-readable**.
## 1. Node Labeling and Identification
### Node Types
Use **basic atomic types** for node labels. Always prefer general types over specific roles or professions:
- "Person" for any human.
- "Organization" for companies, institutions, etc.
- "Location" for geographic or place entities.
- "Date" for any temporal expression.
- "Event" for historical or scheduled occurrences.
- "Work" for books, films, artworks, or research papers.
- "Concept" for abstract notions or ideas.
### Node IDs
- Always assign **human-readable and unambiguous identifiers**.
- Never use numeric or autogenerated IDs.
- Prioritize **most complete form** of entity names for consistency.
## 2. Relationship Handling
- Use **snake_case** for all relationship (edge) types.
- Keep relationship types semantically clear and consistent.
- Avoid vague relation names like "related_to" unless no better alternative exists.
## 3. Strict Compliance
Follow all rules exactly. Any deviation may lead to rejection or incorrect graph construction.
"#
// OpenAI client with environment model selection.
// NOTE(review): model/api_key are read from `client_registry.*` — presumably
// populated at runtime by the Python ClientRegistry (LLMConfig.baml_registry);
// confirm the registry wiring.
client<llm> OpenAI {
  provider openai
  options {
    model client_registry.model
    api_key client_registry.api_key
  }
}
// Function that returns raw structured output (for custom objects - to be handled in Python).
// `mode` selects which prompt template is injected; "custom" uses the caller-supplied
// custom_prompt_content, and any other/absent mode falls back to the simple template.
function ExtractContentGraphGeneric(
  content: string,
  mode: "simple" | "base" | "guided" | "strict" | "custom"?,
  custom_prompt_content: string?
) -> KnowledgeGraph {
  client OpenAI
  prompt #"
{% if mode == "base" %}
{{ DetailedExtractContentGraphPrompt() }}
{% elif mode == "guided" %}
{{ GuidedExtractContentGraphPrompt() }}
{% elif mode == "strict" %}
{{ StrictExtractContentGraphPrompt() }}
{% elif mode == "custom" and custom_prompt_content %}
{{ custom_prompt_content }}
{% else %}
{{ ExtractContentGraphPrompt() }}
{% endif %}
{{ ctx.output_format(prefix="Answer in this schema:\n") }}
Before answering, briefly describe what you'll extract from the text, then provide the structured output.
Example format:
I'll extract the main entities and their relationships from this text...
{ ... }
{{ _.role('user') }}
{{ content }}
"#
}
// Backward-compatible variant: identical prompt/template selection to
// ExtractContentGraphGeneric, but typed to return DynamicKnowledgeGraph.
function ExtractDynamicContentGraph(
content: string,
mode: "simple" | "base" | "guided" | "strict" | "custom"?,
custom_prompt_content: string?
) -> DynamicKnowledgeGraph {
client OpenAI
prompt #"
{% if mode == "base" %}
{{ DetailedExtractContentGraphPrompt() }}
{% elif mode == "guided" %}
{{ GuidedExtractContentGraphPrompt() }}
{% elif mode == "strict" %}
{{ StrictExtractContentGraphPrompt() }}
{% elif mode == "custom" and custom_prompt_content %}
{{ custom_prompt_content }}
{% else %}
{{ ExtractContentGraphPrompt() }}
{% endif %}
{{ ctx.output_format(prefix="Answer in this schema:\n") }}
Before answering, briefly describe what you'll extract from the text, then provide the structured output.
Example format:
I'll extract the main entities and their relationships from this text...
{ ... }
{{ _.role('user') }}
{{ content }}
"#
}
// Summarization functions.
// SummarizeContent: general text summarization into a SummarizedContent object.
function SummarizeContent(content: string) -> SummarizedContent {
client OpenAI
prompt #"
{{ SummarizeContentPrompt() }}
{{ ctx.output_format(prefix="Answer in this schema:\n") }}
{{ _.role('user') }}
{{ content }}
"#
}
// SummarizeCode: code-specific summarization into a SummarizedCode object.
function SummarizeCode(content: string) -> SummarizedCode {
client OpenAI
prompt #"
{{ SummarizeCodePrompt() }}
{{ ctx.output_format(prefix="Answer in this schema:\n") }}
{{ _.role('user') }}
{{ content }}
"#
}
// Smoke test: runs the generic extractor with the strict template on a
// one-sentence input (should yield Person/Work/Date style nodes).
test ExtractStrictExample {
functions [ExtractContentGraphGeneric]
args {
content #"
The Python programming language was created by Guido van Rossum in 1991.
"#
mode "strict"
}
}

View file

@ -1,2 +0,0 @@
from .knowledge_graph.extract_content_graph import extract_content_graph
from .extract_summary import extract_summary, extract_code_summary

View file

@ -1,114 +0,0 @@
import os
from typing import Type
from pydantic import BaseModel
from cognee.infrastructure.llm.structured_output_framework.baml_src.config import get_llm_config
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.async_client import b
from cognee.shared.data_models import SummarizedCode
from cognee.shared.logging_utils import get_logger
from baml_py import ClientRegistry
# NOTE(review): this module-level config is shadowed — each function below
# re-reads get_llm_config() locally; kept here only for import-time loading.
config = get_llm_config()
logger = get_logger("extract_summary_baml")
def get_mock_summarized_code():
    """Build a deterministic stand-in SummarizedCode (local to avoid circular imports)."""
    mock_fields = {
        "high_level_summary": "Mock code summary",
        "key_features": ["Mock feature 1", "Mock feature 2"],
        "imports": ["mock_import"],
        "constants": ["MOCK_CONSTANT"],
        "classes": [],
        "functions": [],
        "workflow_description": "Mock workflow description",
    }
    return SummarizedCode(**mock_fields)
async def extract_summary(content: str, response_model: Type[BaseModel]):
    """
    Extract a structured summary of *content* using the BAML framework.

    Parameters:
    -----------
    - content (str): The content to summarize.
    - response_model (Type[BaseModel]): The Pydantic model type for the response.
      When SummarizedCode is requested, the code-specific BAML function is used.

    Returns:
    --------
    - BaseModel: The summarized content in the requested format.
    """
    config = get_llm_config()
    baml_registry = ClientRegistry()
    baml_registry.add_llm_client(
        name="extract_category_client",
        # BUG FIX: the provider must be the configured provider name,
        # not the model name (previously passed config.llm_model).
        provider=config.llm_provider,
        options={
            "model": config.llm_model,
            "temperature": config.llm_temperature,
            "api_key": config.llm_api_key,
        },
    )
    baml_registry.set_primary("extract_category_client")
    # Dispatch on the requested model *first* so we never pay for an extra,
    # unused LLM call (the old code always ran SummarizeContent before
    # deciding it actually needed SummarizeCode).
    if response_model is SummarizedCode:
        return await b.SummarizeCode(content, baml_options={"client_registry": baml_registry})
    return await b.SummarizeContent(content, baml_options={"client_registry": baml_registry})
async def extract_code_summary(content: str):
    """
    Extract a code summary using the BAML framework, with mocking support.

    Set the MOCK_CODE_SUMMARY environment variable ("true"/"1"/"yes") to skip
    the LLM call and return a deterministic mock summary instead. Any failure
    of the real call also falls back to the mock summary (best-effort).

    Parameters:
    -----------
    - content (str): The code content to summarize.

    Returns:
    --------
    - SummarizedCode: The summarized code information.
    """
    # os.getenv always returns a string (or the string default), so the old
    # isinstance(enable_mocking, bool) branch was dead code.
    enable_mocking = os.getenv("MOCK_CODE_SUMMARY", "false").lower() in ("true", "1", "yes")
    if enable_mocking:
        return get_mock_summarized_code()
    try:
        config = get_llm_config()
        baml_registry = ClientRegistry()
        baml_registry.add_llm_client(
            name="extract_content_category",
            provider=config.llm_provider,
            options={
                "model": config.llm_model,
                "temperature": config.llm_temperature,
                "api_key": config.llm_api_key,
            },
        )
        baml_registry.set_primary("extract_content_category")
        return await b.SummarizeCode(content, baml_options={"client_registry": baml_registry})
    except Exception as e:
        logger.error(
            "Failed to extract code summary with BAML, falling back to mock summary", exc_info=e
        )
        return get_mock_summarized_code()

View file

@ -1,114 +0,0 @@
import os
from typing import Type
from pydantic import BaseModel
from baml_py import ClientRegistry
from cognee.shared.logging_utils import get_logger
from cognee.shared.data_models import SummarizedCode
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.async_client import b
from cognee.infrastructure.llm.structured_output_framework.baml_src.config import get_llm_config
# NOTE(review): this module-level config is shadowed — each function below
# re-reads get_llm_config() locally; kept here only for import-time loading.
config = get_llm_config()
logger = get_logger("extract_summary_baml")
def get_mock_summarized_code():
    """Build a deterministic stand-in SummarizedCode (local to avoid circular imports)."""
    mock_fields = {
        "high_level_summary": "Mock code summary",
        "key_features": ["Mock feature 1", "Mock feature 2"],
        "imports": ["mock_import"],
        "constants": ["MOCK_CONSTANT"],
        "classes": [],
        "functions": [],
        "workflow_description": "Mock workflow description",
    }
    return SummarizedCode(**mock_fields)
async def extract_summary(content: str, response_model: Type[BaseModel]):
    """
    Extract a structured summary of *content* using the BAML framework.

    Parameters:
    -----------
    - content (str): The content to summarize.
    - response_model (Type[BaseModel]): The Pydantic model type for the response.
      When SummarizedCode is requested, the code-specific BAML function is used.

    Returns:
    --------
    - BaseModel: The summarized content in the requested format.
    """
    config = get_llm_config()
    baml_registry = ClientRegistry()
    baml_registry.add_llm_client(
        # Descriptive client name (was the opaque "def").
        name="extract_summary_client",
        # BUG FIX: honor the configured provider instead of hardcoding "openai"
        # (the sibling extraction modules already use config.llm_provider).
        provider=config.llm_provider,
        options={
            "model": config.llm_model,
            "temperature": config.llm_temperature,
            "api_key": config.llm_api_key,
        },
    )
    baml_registry.set_primary("extract_summary_client")
    # Dispatch on the requested model *first* so we never pay for an extra,
    # unused LLM call (the old code always ran SummarizeContent first).
    if response_model is SummarizedCode:
        # BUG FIX: pass the local registry; the old code referenced
        # config.baml_registry, an attribute LLMConfig does not define,
        # which raised AttributeError on this path.
        return await b.SummarizeCode(content, baml_options={"client_registry": baml_registry})
    return await b.SummarizeContent(content, baml_options={"client_registry": baml_registry})
async def extract_code_summary(content: str):
    """
    Extract a code summary using the BAML framework, with mocking support.

    Set the MOCK_CODE_SUMMARY environment variable ("true"/"1"/"yes") to skip
    the LLM call and return a deterministic mock summary instead. Any failure
    of the real call also falls back to the mock summary (best-effort).

    Parameters:
    -----------
    - content (str): The code content to summarize.

    Returns:
    --------
    - SummarizedCode: The summarized code information.
    """
    # os.getenv always returns a string (or the string default), so the old
    # isinstance(enable_mocking, bool) branch was dead code.
    enable_mocking = os.getenv("MOCK_CODE_SUMMARY", "false").lower() in ("true", "1", "yes")
    if enable_mocking:
        return get_mock_summarized_code()
    try:
        config = get_llm_config()
        baml_registry = ClientRegistry()
        baml_registry.add_llm_client(
            # Descriptive client name (was the opaque "def").
            name="extract_code_summary_client",
            # BUG FIX: honor the configured provider instead of hardcoding
            # "openai", consistent with the sibling extraction modules.
            provider=config.llm_provider,
            options={
                "model": config.llm_model,
                "temperature": config.llm_temperature,
                "api_key": config.llm_api_key,
            },
        )
        baml_registry.set_primary("extract_code_summary_client")
        return await b.SummarizeCode(content, baml_options={"client_registry": baml_registry})
    except Exception as e:
        logger.error(
            "Failed to extract code summary with BAML, falling back to mock summary", exc_info=e
        )
        return get_mock_summarized_code()

View file

@ -1,49 +0,0 @@
from baml_py import ClientRegistry
from typing import Type
from pydantic import BaseModel
from cognee.infrastructure.llm.structured_output_framework.baml_src.config import get_llm_config
from cognee.shared.logging_utils import get_logger, setup_logging
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.async_client import b
config = get_llm_config()
async def extract_content_graph(
    content: str, response_model: Type[BaseModel], mode: str = "simple"
):
    """
    Extract a knowledge graph from *content* via the BAML framework.

    Parameters:
    -----------
    - content (str): Text to extract entities and relationships from.
    - response_model (Type[BaseModel]): Currently unused; kept for interface
      compatibility with callers that pass a custom graph model.
    - mode (str): Prompt-selection mode forwarded to BAML ("simple", "base",
      "guided", "strict" or "custom"). (default "simple")

    Returns:
    --------
    The knowledge graph produced by BAML's ExtractContentGraphGeneric.
    """
    config = get_llm_config()
    # Removed the per-call setup_logging()/get_logger() side effects the old
    # code performed: reconfiguring global logging on every extraction call
    # does not belong in this routine. Also removed the commented-out
    # dynamic-graph dead code.
    baml_registry = ClientRegistry()
    baml_registry.add_llm_client(
        name="extract_content_client",
        provider=config.llm_provider,
        options={
            "model": config.llm_model,
            "temperature": config.llm_temperature,
            "api_key": config.llm_api_key,
        },
    )
    baml_registry.set_primary("extract_content_client")
    graph = await b.ExtractContentGraphGeneric(
        content, mode=mode, baml_options={"client_registry": baml_registry}
    )
    return graph

View file

@ -1,18 +0,0 @@
// This helps us auto-generate libraries you can use in the language of
// your choice. You can have multiple generators if you use multiple languages.
// Just ensure that the output_dir is different for each generator.
generator target {
// Valid values: "python/pydantic", "typescript", "ruby/sorbet", "rest/openapi"
output_type "python/pydantic"
// Where the generated code will be saved (relative to baml_src/)
output_dir "../baml/"
// The version of the BAML package you have installed (e.g. same version as your baml-py or @boundaryml/baml).
// The BAML VSCode extension version should also match this version.
version "0.201.0"
// Valid values: "sync", "async"
// This controls what `b.FunctionName()` will be (sync or async).
// NOTE(review): Python callers import baml_client.async_client — confirm
// that generating with a sync default client mode is intentional.
default_client_mode sync
}

View file

@ -1,2 +1,3 @@
from .knowledge_graph.extract_content_graph import extract_content_graph
from .extract_categories import extract_categories
from .extract_summary import extract_summary, extract_code_summary

View file

@ -7,7 +7,7 @@ from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.l
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import (
render_prompt,
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.config import (
from cognee.infrastructure.llm.config import (
get_llm_config,
)

View file

@ -1,176 +0,0 @@
import os
from typing import Optional
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import model_validator
class LLMConfig(BaseSettings):
    """
    Configuration settings for the LLM (Large Language Model) provider and related options.

    Values are loaded by pydantic-settings from the process environment and a
    local ``.env`` file; unknown variables are tolerated (``extra="allow"``).

    Public instance variables include:
    - llm_provider
    - llm_model
    - llm_endpoint
    - llm_api_key
    - llm_api_version
    - llm_temperature
    - llm_streaming
    - llm_max_tokens
    - transcription_model
    - graph_prompt_path
    - llm_rate_limit_enabled
    - llm_rate_limit_requests
    - llm_rate_limit_interval
    - embedding_rate_limit_enabled
    - embedding_rate_limit_requests
    - embedding_rate_limit_interval
    Public methods include:
    - ensure_env_vars_for_ollama
    - to_dict
    """
    # --- Core provider / connection settings ---
    llm_provider: str = "openai"
    llm_model: str = "gpt-4o-mini"
    llm_endpoint: str = ""
    llm_api_key: Optional[str] = None
    llm_api_version: Optional[str] = None
    # --- Generation behavior ---
    llm_temperature: float = 0.0
    llm_streaming: bool = False
    llm_max_tokens: int = 16384
    transcription_model: str = "whisper-1"
    graph_prompt_path: str = "generate_graph_prompt.txt"
    # --- Rate limiting (LLM calls) ---
    llm_rate_limit_enabled: bool = False
    llm_rate_limit_requests: int = 60
    llm_rate_limit_interval: int = 60  # in seconds (default is 60 requests per minute)
    # --- Rate limiting (embedding calls, tracked independently) ---
    embedding_rate_limit_enabled: bool = False
    embedding_rate_limit_requests: int = 60
    embedding_rate_limit_interval: int = 60  # in seconds (default is 60 requests per minute)
    # --- Optional fallback client (empty strings disable the fallback) ---
    fallback_api_key: str = ""
    fallback_endpoint: str = ""
    fallback_model: str = ""
    model_config = SettingsConfigDict(env_file=".env", extra="allow")
    @model_validator(mode="after")
    def ensure_env_vars_for_ollama(self) -> "LLMConfig":
        """
        Validate required environment variables for the 'ollama' LLM provider.

        Each variable group (LLM_* and EMBEDDING_*) must be set all-or-nothing;
        a ValueError is raised if a group is only partially configured.
        Checks are only performed when 'llm_provider' is set to 'ollama'.

        Returns:
        --------
        - 'LLMConfig': The instance of LLMConfig after validation.
        """
        if self.llm_provider != "ollama":
            # Skip checks unless provider is "ollama"
            return self
        def is_env_set(var_name: str) -> bool:
            """Return True if *var_name* is present in the environment and non-empty."""
            val = os.environ.get(var_name)
            return val is not None and val.strip() != ""
        #
        # 1. Check LLM environment variables (all-or-nothing group)
        #
        llm_env_vars = {
            "LLM_MODEL": is_env_set("LLM_MODEL"),
            "LLM_ENDPOINT": is_env_set("LLM_ENDPOINT"),
            "LLM_API_KEY": is_env_set("LLM_API_KEY"),
        }
        if any(llm_env_vars.values()) and not all(llm_env_vars.values()):
            missing_llm = [key for key, is_set in llm_env_vars.items() if not is_set]
            raise ValueError(
                "You have set some but not all of the required environment variables "
                f"for LLM usage (LLM_MODEL, LLM_ENDPOINT, LLM_API_KEY). Missing: {missing_llm}"
            )
        #
        # 2. Check embedding environment variables (all-or-nothing group)
        #
        embedding_env_vars = {
            "EMBEDDING_PROVIDER": is_env_set("EMBEDDING_PROVIDER"),
            "EMBEDDING_MODEL": is_env_set("EMBEDDING_MODEL"),
            "EMBEDDING_DIMENSIONS": is_env_set("EMBEDDING_DIMENSIONS"),
            "HUGGINGFACE_TOKENIZER": is_env_set("HUGGINGFACE_TOKENIZER"),
        }
        if any(embedding_env_vars.values()) and not all(embedding_env_vars.values()):
            missing_embed = [key for key, is_set in embedding_env_vars.items() if not is_set]
            raise ValueError(
                "You have set some but not all of the required environment variables "
                "for embeddings (EMBEDDING_PROVIDER, EMBEDDING_MODEL, "
                "EMBEDDING_DIMENSIONS, HUGGINGFACE_TOKENIZER). Missing: "
                f"{missing_embed}"
            )
        return self
    def to_dict(self) -> dict:
        """
        Convert the LLMConfig instance into a dictionary representation.

        NOTE(review): the returned dict contains the raw API keys — avoid
        logging it verbatim.

        Returns:
        --------
        - dict: A dictionary containing the configuration settings of the LLMConfig
          instance.
        """
        return {
            "provider": self.llm_provider,
            "model": self.llm_model,
            "endpoint": self.llm_endpoint,
            "api_key": self.llm_api_key,
            "api_version": self.llm_api_version,
            "temperature": self.llm_temperature,
            "streaming": self.llm_streaming,
            "max_tokens": self.llm_max_tokens,
            "transcription_model": self.transcription_model,
            "graph_prompt_path": self.graph_prompt_path,
            "rate_limit_enabled": self.llm_rate_limit_enabled,
            "rate_limit_requests": self.llm_rate_limit_requests,
            "rate_limit_interval": self.llm_rate_limit_interval,
            "embedding_rate_limit_enabled": self.embedding_rate_limit_enabled,
            "embedding_rate_limit_requests": self.embedding_rate_limit_requests,
            "embedding_rate_limit_interval": self.embedding_rate_limit_interval,
            "fallback_api_key": self.fallback_api_key,
            "fallback_endpoint": self.fallback_endpoint,
            "fallback_model": self.fallback_model,
        }
@lru_cache
def get_llm_config():
    """
    Return the process-wide LLM configuration.

    The lru_cache decorator memoizes the call, so every caller shares one
    LLMConfig instance instead of re-reading settings on each invocation.

    Returns:
    --------
    - LLMConfig: The cached configuration instance.
    """
    return LLMConfig()

View file

@ -1,550 +0,0 @@
import threading
import logging
import functools
import os
import time
import asyncio
import random
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.config import (
get_llm_config,
)
logger = get_logger()
# Lowercase substrings that identify a rate-limit error in provider exception
# messages; the retry decorators below match them against str(error).lower().
RATE_LIMIT_ERROR_PATTERNS = [
    "rate limit",
    "rate_limit",
    "ratelimit",
    "too many requests",
    "retry after",
    "capacity",
    "quota",
    "limit exceeded",
    "tps limit exceeded",
    "request limit exceeded",
    "maximum requests",
    "exceeded your current quota",
    "throttled",
    "throttling",
]
# Default retry settings
DEFAULT_MAX_RETRIES = 5
DEFAULT_INITIAL_BACKOFF = 1.0  # seconds
DEFAULT_BACKOFF_FACTOR = 2.0  # exponential backoff multiplier
DEFAULT_JITTER = 0.1  # 10% jitter to avoid thundering herd
class EmbeddingRateLimiter:
    """
    Rate limiter for embedding API calls.

    Singleton: rate limiting is shared across all embedding requests. A
    moving-window strategy over recorded request timestamps controls the
    rate. It reuses the LLM rate-limit configuration values but tracks
    embedding API calls independently.

    Public Methods:
    - get_instance
    - reset_instance
    - hit_limit
    - wait_if_needed
    - async_wait_if_needed

    Instance Variables:
    - enabled
    - requests_limit
    - interval_seconds
    - request_times
    - lock
    """

    _instance = None
    # Class-level lock guarding singleton creation/reset. NOTE: instances
    # shadow this name with their own per-instance lock in __init__.
    lock = threading.Lock()

    @classmethod
    def get_instance(cls):
        """Return the singleton instance, lazily creating it (thread-safe)."""
        if cls._instance is None:
            with cls.lock:
                # Double-checked locking: re-test after acquiring the lock so
                # concurrent first callers build only one instance.
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    @classmethod
    def reset_instance(cls):
        """Drop the singleton so the next get_instance() builds a fresh one."""
        with cls.lock:
            cls._instance = None

    def __init__(self):
        config = get_llm_config()
        self.enabled = config.embedding_rate_limit_enabled
        self.requests_limit = config.embedding_rate_limit_requests
        self.interval_seconds = config.embedding_rate_limit_interval
        self.request_times = []
        self.lock = threading.Lock()
        # BUG FIX: use the module's project logger instead of the root
        # `logging` module, so this message honors the project logging setup.
        logger.info(
            f"EmbeddingRateLimiter initialized: enabled={self.enabled}, "
            f"requests_limit={self.requests_limit}, interval_seconds={self.interval_seconds}"
        )

    def hit_limit(self) -> bool:
        """
        Return True if one more request right now would exceed the rate limit.

        Also prunes recorded timestamps that have fallen out of the moving
        window. Always False when rate limiting is disabled.
        """
        if not self.enabled:
            return False
        current_time = time.time()
        with self.lock:
            # Remove expired request times
            cutoff_time = current_time - self.interval_seconds
            self.request_times = [t for t in self.request_times if t > cutoff_time]
            # Check if adding a new request would exceed the limit
            if len(self.request_times) >= self.requests_limit:
                logger.info(
                    f"Rate limit hit: {len(self.request_times)} requests in the last {self.interval_seconds} seconds"
                )
                return True
            return False

    def wait_if_needed(self) -> float:
        """
        Block until a request can be made without exceeding the limit, then
        record the request.

        Returns:
        --------
        - float: Time waited in seconds before proceeding (0 when disabled).
        """
        if not self.enabled:
            return 0
        wait_time = 0
        start_time = time.time()
        while self.hit_limit():
            time.sleep(0.5)  # Poll every 0.5 seconds
            wait_time = time.time() - start_time
        # Record this request
        with self.lock:
            self.request_times.append(time.time())
        return wait_time

    async def async_wait_if_needed(self) -> float:
        """
        Asynchronously wait until a request can be made without exceeding the
        limit, then record the request.

        Returns:
        --------
        - float: Time waited in seconds before proceeding (0 when disabled).
        """
        if not self.enabled:
            return 0
        wait_time = 0
        start_time = time.time()
        while self.hit_limit():
            await asyncio.sleep(0.5)  # Poll every 0.5 seconds
            wait_time = time.time() - start_time
        # Record this request
        with self.lock:
            self.request_times.append(time.time())
        return wait_time
def embedding_rate_limit_sync(func):
    """
    Decorator that enforces the embedding API rate limit on a sync function.

    If the moving-window limit is already exhausted, the call fails fast with
    an EmbeddingException; otherwise the request is recorded (waiting first if
    required) and the wrapped function runs.

    Parameters:
    -----------
    - func: Function to decorate with rate limiting logic.

    Returns:
    --------
    The decorated function with rate limiting applied.
    """

    @functools.wraps(func)
    def rate_limited(*args, **kwargs):
        limiter = EmbeddingRateLimiter.get_instance()
        if limiter.hit_limit():
            message = "Embedding API rate limit exceeded"
            logger.warning(message)
            # Imported here rather than at module top — presumably to avoid a
            # circular import; confirm before hoisting.
            from cognee.infrastructure.databases.exceptions.EmbeddingException import (
                EmbeddingException,
            )

            raise EmbeddingException(message)
        # Count this request (and wait if the window filled up meanwhile).
        limiter.wait_if_needed()
        return func(*args, **kwargs)

    return rate_limited
def embedding_rate_limit_async(func):
    """
    Decorator that enforces the embedding API rate limit on an async function.

    If the moving-window limit is already exhausted, the call fails fast with
    an EmbeddingException; otherwise the request is recorded (awaiting first
    if required) and the wrapped coroutine is awaited.

    Parameters:
    -----------
    - func: Async function to decorate.

    Returns:
    --------
    The decorated async function with rate limiting applied.
    """

    @functools.wraps(func)
    async def rate_limited(*args, **kwargs):
        limiter = EmbeddingRateLimiter.get_instance()
        if limiter.hit_limit():
            message = "Embedding API rate limit exceeded"
            logger.warning(message)
            # Imported here rather than at module top — presumably to avoid a
            # circular import; confirm before hoisting.
            from cognee.infrastructure.databases.exceptions.EmbeddingException import (
                EmbeddingException,
            )

            raise EmbeddingException(message)
        # Count this request (and wait if the window filled up meanwhile).
        await limiter.async_wait_if_needed()
        return await func(*args, **kwargs)

    return rate_limited
def embedding_sleep_and_retry_sync(max_retries=5, base_backoff=1.0, jitter=0.5):
    """
    Decorator factory: retry a sync embedding call on rate-limit errors.

    Retries use exponential backoff (base_backoff * 2**attempt) with
    multiplicative jitter to avoid thundering-herd retries. Non-rate-limit
    errors, and rate-limit errors after the final attempt, propagate.

    Parameters:
    -----------
    - max_retries: Maximum number of retries before giving up. (default 5)
    - base_backoff: Base backoff time in seconds. (default 1.0)
    - jitter: Fractional jitter applied to each backoff. (default 0.5)

    Returns:
    --------
    A decorator that retries the wrapped function on rate limit errors.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # If DISABLE_RETRIES is set, don't retry (used by tests).
            disable_retries = os.environ.get("DISABLE_RETRIES", "false").lower() in (
                "true",
                "1",
                "yes",
            )
            retries = 0
            last_error = None
            while retries <= max_retries:
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    if disable_retries:
                        # For testing, propagate the exception immediately
                        # (checked before any string work is done).
                        raise
                    error_str = str(e).lower()
                    error_type = type(e).__name__
                    # error_str is already lowercased; the old code
                    # redundantly lowered it a second time here.
                    is_rate_limit = any(
                        pattern in error_str for pattern in RATE_LIMIT_ERROR_PATTERNS
                    )
                    if is_rate_limit and retries < max_retries:
                        # Calculate backoff with jitter
                        backoff = (
                            base_backoff * (2**retries) * (1 + random.uniform(-jitter, jitter))
                        )
                        logger.warning(
                            f"Embedding rate limit hit, retrying in {backoff:.2f}s "
                            f"(attempt {retries + 1}/{max_retries}): "
                            f"({error_str!r}, {error_type!r})"
                        )
                        time.sleep(backoff)
                        retries += 1
                        last_error = e
                    else:
                        # Not a rate limit error or max retries reached, raise
                        raise
            # Defensive: re-raise the last error if the loop exits.
            if last_error:
                raise last_error

        return wrapper

    return decorator
def embedding_sleep_and_retry_async(max_retries=5, base_backoff=1.0, jitter=0.5):
    """
    Decorator factory: retry an async embedding call on rate-limit errors.

    Retries use exponential backoff (base_backoff * 2**attempt) with
    multiplicative jitter to avoid thundering-herd retries. Non-rate-limit
    errors, and rate-limit errors after the final attempt, propagate.

    Parameters:
    -----------
    - max_retries: Maximum number of retries before giving up. (default 5)
    - base_backoff: Base backoff time in seconds. (default 1.0)
    - jitter: Fractional jitter applied to each backoff. (default 0.5)

    Returns:
    --------
    A decorator that retries the wrapped async function on rate limit errors.
    """

    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            # If DISABLE_RETRIES is set, don't retry (used by tests).
            disable_retries = os.environ.get("DISABLE_RETRIES", "false").lower() in (
                "true",
                "1",
                "yes",
            )
            retries = 0
            last_error = None
            while retries <= max_retries:
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    if disable_retries:
                        # For testing, propagate the exception immediately
                        # (checked before any string work is done).
                        raise
                    error_str = str(e).lower()
                    error_type = type(e).__name__
                    # error_str is already lowercased; the old code
                    # redundantly lowered it a second time here.
                    is_rate_limit = any(
                        pattern in error_str for pattern in RATE_LIMIT_ERROR_PATTERNS
                    )
                    if is_rate_limit and retries < max_retries:
                        # Calculate backoff with jitter
                        backoff = (
                            base_backoff * (2**retries) * (1 + random.uniform(-jitter, jitter))
                        )
                        logger.warning(
                            f"Embedding rate limit hit, retrying in {backoff:.2f}s "
                            f"(attempt {retries + 1}/{max_retries}): "
                            f"({error_str!r}, {error_type!r})"
                        )
                        await asyncio.sleep(backoff)
                        retries += 1
                        last_error = e
                    else:
                        # Not a rate limit error or max retries reached, raise
                        raise
            # Defensive: re-raise the last error if the loop exits.
            if last_error:
                raise last_error

        return wrapper

    return decorator

View file

@ -1,9 +1,15 @@
"""Adapter for Generic API LLM provider API"""
from typing import Type
from pydantic import BaseModel
import logging
import litellm
import instructor
from typing import Type
from pydantic import BaseModel
from openai import ContentFilterFinishReasonError
from litellm.exceptions import ContentPolicyViolationError
from instructor.exceptions import InstructorRetryException
from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.llm_interface import (
LLMInterface,
)
@ -11,7 +17,6 @@ from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.l
rate_limit_async,
sleep_and_retry_async,
)
import litellm
class GenericAPIAdapter(LLMInterface):
@ -31,13 +36,27 @@ class GenericAPIAdapter(LLMInterface):
model: str
api_key: str
def __init__(self, endpoint, api_key: str, model: str, name: str, max_tokens: int):
def __init__(
self,
endpoint,
api_key: str,
model: str,
name: str,
max_tokens: int,
fallback_model: str = None,
fallback_api_key: str = None,
fallback_endpoint: str = None,
):
self.name = name
self.model = model
self.api_key = api_key
self.endpoint = endpoint
self.max_tokens = max_tokens
self.fallback_model = fallback_model
self.fallback_api_key = fallback_api_key
self.fallback_endpoint = fallback_endpoint
self.aclient = instructor.from_litellm(
litellm.acompletion, mode=instructor.Mode.JSON, api_key=api_key
)
@ -70,19 +89,68 @@ class GenericAPIAdapter(LLMInterface):
output from the language model.
"""
return await self.aclient.chat.completions.create(
model=self.model,
messages=[
{
"role": "user",
"content": f"""{text_input}""",
},
{
"role": "system",
"content": system_prompt,
},
],
max_retries=5,
api_base=self.endpoint,
response_model=response_model,
)
try:
return await self.aclient.chat.completions.create(
model=self.model,
messages=[
{
"role": "user",
"content": f"""{text_input}""",
},
{
"role": "system",
"content": system_prompt,
},
],
max_retries=5,
api_base=self.endpoint,
response_model=response_model,
)
except (
ContentFilterFinishReasonError,
ContentPolicyViolationError,
InstructorRetryException,
) as error:
if (
isinstance(error, InstructorRetryException)
and "content management policy" not in str(error).lower()
):
raise error
if not (self.fallback_model and self.fallback_api_key and self.fallback_endpoint):
raise ContentPolicyFilterError(
f"The provided input contains content that is not aligned with our content policy: {text_input}"
)
try:
return await self.aclient.chat.completions.create(
model=self.fallback_model,
messages=[
{
"role": "user",
"content": f"""{text_input}""",
},
{
"role": "system",
"content": system_prompt,
},
],
max_retries=5,
api_key=self.fallback_api_key,
api_base=self.fallback_endpoint,
response_model=response_model,
)
except (
ContentFilterFinishReasonError,
ContentPolicyViolationError,
InstructorRetryException,
) as error:
if (
isinstance(error, InstructorRetryException)
and "content management policy" not in str(error).lower()
):
raise error
else:
raise ContentPolicyFilterError(
f"The provided input contains content that is not aligned with our content policy: {text_input}"
)

View file

@ -50,7 +50,7 @@ def get_llm_client():
# Check if max_token value is defined in liteLLM for given model
# if not use value from cognee configuration
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.utils import (
from cognee.infrastructure.llm.utils import (
get_model_max_tokens,
) # imported here to avoid circular imports

View file

@ -49,9 +49,7 @@ from functools import wraps
from limits import RateLimitItemPerMinute, storage
from limits.strategies import MovingWindowRateLimiter
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.config import (
get_llm_config,
)
from cognee.infrastructure.llm.config import get_llm_config
logger = get_logger()

View file

@ -1,107 +0,0 @@
import litellm
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
get_llm_client,
)
from cognee.shared.logging_utils import get_logger
logger = get_logger()
def get_max_chunk_tokens():
"""
Calculate the maximum number of tokens allowed in a chunk.
The function determines the maximum chunk size based on the maximum token limit of the
embedding engine and half of the LLM maximum context token size. It ensures that the
chunk size does not exceed these constraints.
Returns:
--------
- int: The maximum number of tokens that can be included in a chunk, determined by
the smaller value of the embedding engine's max tokens and half of the LLM's
maximum tokens.
"""
# NOTE: Import must be done in function to avoid circular import issue
from cognee.infrastructure.databases.vector import get_vector_engine
# Calculate max chunk size based on the following formula
embedding_engine = get_vector_engine().embedding_engine
llm_client = get_llm_client()
# We need to make sure chunk size won't take more than half of LLM max context token size
# but it also can't be bigger than the embedding engine max token size
llm_cutoff_point = llm_client.max_tokens // 2 # Round down the division
max_chunk_tokens = min(embedding_engine.max_tokens, llm_cutoff_point)
return max_chunk_tokens
def get_model_max_tokens(model_name: str):
"""
Retrieve the maximum token limit for a specified model name if it exists.
Checks if the provided model name is present in the predefined model cost dictionary. If
found, it logs the maximum token count for that model and returns it. If the model name
is not recognized, it logs an informational message and returns None.
Parameters:
-----------
- model_name (str): Name of LLM or embedding model
Returns:
--------
Number of max tokens of model, or None if model is unknown
"""
max_tokens = None
if model_name in litellm.model_cost:
max_tokens = litellm.model_cost[model_name]["max_tokens"]
logger.debug(f"Max input tokens for {model_name}: {max_tokens}")
else:
logger.info("Model not found in LiteLLM's model_cost.")
return max_tokens
async def test_llm_connection():
"""
Establish a connection to the LLM and create a structured output.
Attempt to connect to the LLM client and uses the adapter to create a structured output
with a predefined text input and system prompt. Log any exceptions encountered during
the connection attempt and re-raise the exception for further handling.
"""
try:
llm_adapter = get_llm_client()
await llm_adapter.acreate_structured_output(
text_input="test",
system_prompt='Respond to me with the following string: "test"',
response_model=str,
)
except Exception as e:
logger.error(e)
logger.error("Connection to LLM could not be established.")
raise e
async def test_embedding_connection():
"""
Test the connection to the embedding engine by embedding a sample text.
Handles exceptions that may occur during the operation, logs the error, and re-raises
the exception if the connection to the embedding handler cannot be established.
"""
try:
# NOTE: Vector engine import must be done in function to avoid circular import issue
from cognee.infrastructure.databases.vector import get_vector_engine
await get_vector_engine().embedding_engine.embed_text("test")
except Exception as e:
logger.error(e)
logger.error("Connection to Embedding handler could not be established.")
raise e

View file

@ -1,7 +1,5 @@
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
get_llm_client,
)
from cognee.modules.chunking.Chunker import Chunker
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
from .Document import Document
@ -10,7 +8,7 @@ class AudioDocument(Document):
type: str = "audio"
async def create_transcript(self):
result = await get_llm_client().create_transcript(self.raw_data_location)
result = await LLMAdapter.create_transcript(self.raw_data_location)
return result.text
async def read(self, chunker_cls: Chunker, max_chunk_size: int):

View file

@ -1,6 +1,4 @@
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
get_llm_client,
)
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
from cognee.modules.chunking.Chunker import Chunker
from .Document import Document
@ -10,7 +8,7 @@ class ImageDocument(Document):
type: str = "image"
async def transcribe_image(self):
result = await get_llm_client().transcribe_image(self.raw_data_location)
result = await LLMAdapter.transcribe_image(self.raw_data_location)
return result.choices[0].message.content
async def read(self, chunker_cls: Chunker, max_chunk_size: int):

View file

@ -69,7 +69,7 @@ async def cognee_pipeline(
cognee_pipeline.first_run = True
if cognee_pipeline.first_run:
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.utils import (
from cognee.infrastructure.llm.utils import (
test_llm_connection,
test_embedding_connection,
)

View file

@ -7,12 +7,7 @@ from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
get_llm_client,
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import (
read_query_prompt,
)
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
logger = get_logger("CodeRetriever")
@ -46,11 +41,10 @@ class CodeRetriever(BaseRetriever):
f"Processing query with LLM: '{query[:100]}{'...' if len(query) > 100 else ''}'"
)
system_prompt = read_query_prompt("codegraph_retriever_system.txt")
llm_client = get_llm_client()
system_prompt = LLMAdapter.read_query_prompt("codegraph_retriever_system.txt")
try:
result = await llm_client.acreate_structured_output(
result = await LLMAdapter.acreate_structured_output(
text_input=query,
system_prompt=system_prompt,
response_model=self.CodeQueryInfo,

View file

@ -1,14 +1,9 @@
from typing import Any, Optional, List, Type
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
get_llm_client,
)
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import (
read_query_prompt,
render_prompt,
)
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
logger = get_logger()
@ -78,7 +73,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
- List[str]: A list containing the generated answer to the user's query.
"""
llm_client = get_llm_client()
followup_question = ""
triplets = []
answer = [""]
@ -100,27 +94,27 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
logger.info(f"Chain-of-thought: round {round_idx} - answer: {answer}")
if round_idx < max_iter:
valid_args = {"query": query, "answer": answer, "context": context}
valid_user_prompt = render_prompt(
valid_user_prompt = LLMAdapter.render_prompt(
filename=self.validation_user_prompt_path, context=valid_args
)
valid_system_prompt = read_query_prompt(
valid_system_prompt = LLMAdapter.read_query_prompt(
prompt_file_name=self.validation_system_prompt_path
)
reasoning = await llm_client.acreate_structured_output(
reasoning = await LLMAdapter.acreate_structured_output(
text_input=valid_user_prompt,
system_prompt=valid_system_prompt,
response_model=str,
)
followup_args = {"query": query, "answer": answer, "reasoning": reasoning}
followup_prompt = render_prompt(
followup_prompt = LLMAdapter.render_prompt(
filename=self.followup_user_prompt_path, context=followup_args
)
followup_system = read_query_prompt(
followup_system = LLMAdapter.read_query_prompt(
prompt_file_name=self.followup_system_prompt_path
)
followup_question = await llm_client.acreate_structured_output(
followup_question = await LLMAdapter.llm_client.acreate_structured_output(
text_input=followup_prompt, system_prompt=followup_system, response_model=str
)
logger.info(

View file

@ -2,12 +2,7 @@ from typing import Any, Optional
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.graph.networkx.adapter import NetworkXAdapter
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
get_llm_client,
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import (
render_prompt,
)
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions import SearchTypeNotSupported
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
@ -55,8 +50,7 @@ class NaturalLanguageRetriever(BaseRetriever):
async def _generate_cypher_query(self, query: str, edge_schemas, previous_attempts=None) -> str:
"""Generate a Cypher query using LLM based on natural language query and schema information."""
llm_client = get_llm_client()
system_prompt = render_prompt(
system_prompt = LLMAdapter.render_prompt(
self.system_prompt_path,
context={
"edge_schemas": edge_schemas,
@ -64,7 +58,7 @@ class NaturalLanguageRetriever(BaseRetriever):
},
)
return await llm_client.acreate_structured_output(
return await LLMAdapter.acreate_structured_output(
text_input=query,
system_prompt=system_prompt,
response_model=str,

View file

@ -1,10 +1,4 @@
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
get_llm_client,
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import (
read_query_prompt,
render_prompt,
)
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
async def generate_completion(
@ -15,11 +9,10 @@ async def generate_completion(
) -> str:
"""Generates a completion using LLM with given context and prompts."""
args = {"question": query, "context": context}
user_prompt = render_prompt(user_prompt_path, args)
system_prompt = read_query_prompt(system_prompt_path)
user_prompt = LLMAdapter.render_prompt(user_prompt_path, args)
system_prompt = LLMAdapter.read_query_prompt(system_prompt_path)
llm_client = get_llm_client()
return await llm_client.acreate_structured_output(
return await LLMAdapter.acreate_structured_output(
text_input=user_prompt,
system_prompt=system_prompt,
response_model=str,
@ -31,10 +24,9 @@ async def summarize_text(
prompt_path: str = "summarize_search_results.txt",
) -> str:
"""Summarizes text using LLM with the specified prompt."""
system_prompt = read_query_prompt(prompt_path)
llm_client = get_llm_client()
system_prompt = LLMAdapter.read_query_prompt(prompt_path)
return await llm_client.acreate_structured_output(
return await LLMAdapter.acreate_structured_output(
text_input=text,
system_prompt=system_prompt,
response_model=str,

View file

@ -9,9 +9,7 @@ from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User
from cognee.shared.utils import send_telemetry
from cognee.modules.search.methods import search
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
get_llm_client,
)
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
logger = get_logger(level=ERROR)
@ -73,8 +71,7 @@ async def code_description_to_code_part(
if isinstance(obj, dict) and "description" in obj
)
llm_client = get_llm_client()
context_from_documents = await llm_client.acreate_structured_output(
context_from_documents = await LLMAdapter.acreate_structured_output(
text_input=f"The retrieved context from documents is {concatenated_descriptions}.",
system_prompt="You are a Senior Software Engineer, summarize the context from documents"
f" in a way that it is gonna be provided next to codeparts as context"

View file

@ -4,7 +4,7 @@ from enum import Enum, auto
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.config import (
from cognee.infrastructure.llm.config import (
get_llm_config,
)

View file

@ -7,9 +7,7 @@ from pydantic import BaseModel
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.engine.models import DataPoint
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.extraction.extract_categories import (
extract_categories,
)
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
@ -42,7 +40,7 @@ async def chunk_naive_llm_classifier(
return data_chunks
chunk_classifications = await asyncio.gather(
*[extract_categories(chunk.text, classification_model) for chunk in data_chunks],
*[LLMAdapter.extract_categories(chunk.text, classification_model) for chunk in data_chunks],
)
classification_data_points = []

View file

@ -6,13 +6,7 @@ from pydantic import BaseModel
from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor
from cognee.modules.engine.models import Entity
from cognee.modules.engine.models.EntityType import EntityType
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import (
read_query_prompt,
render_prompt,
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
get_llm_client,
)
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
logger = get_logger("llm_entity_extractor")
@ -56,11 +50,10 @@ class LLMEntityExtractor(BaseEntityExtractor):
try:
logger.info(f"Extracting entities from text: {text[:100]}...")
llm_client = get_llm_client()
user_prompt = render_prompt(self.user_prompt_template, {"text": text})
system_prompt = read_query_prompt(self.system_prompt_template)
user_prompt = LLMAdapter.render_prompt(self.user_prompt_template, {"text": text})
system_prompt = LLMAdapter.read_query_prompt(self.system_prompt_template)
response = await llm_client.acreate_structured_output(
response = await LLMAdapter.acreate_structured_output(
text_input=user_prompt,
system_prompt=system_prompt,
response_model=EntityList,

View file

@ -1,13 +1,7 @@
from typing import List, Tuple
from pydantic import BaseModel
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
get_llm_client,
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import (
render_prompt,
read_query_prompt,
)
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
from cognee.root_dir import get_absolute_path
@ -22,7 +16,6 @@ async def extract_content_nodes_and_relationship_names(
content: str, existing_nodes: List[str], n_rounds: int = 2
) -> Tuple[List[str], List[str]]:
"""Extracts node names and relationship_names from content through multiple rounds of analysis."""
llm_client = get_llm_client()
all_nodes: List[str] = existing_nodes.copy()
all_relationship_names: List[str] = []
existing_node_set = {node.lower() for node in all_nodes}
@ -39,15 +32,15 @@ async def extract_content_nodes_and_relationship_names(
}
base_directory = get_absolute_path("./tasks/graph/cascade_extract/prompts")
text_input = render_prompt(
text_input = LLMAdapter.render_prompt(
"extract_graph_relationship_names_prompt_input.txt",
context,
base_directory=base_directory,
)
system_prompt = read_query_prompt(
system_prompt = LLMAdapter.read_query_prompt(
"extract_graph_relationship_names_prompt_system.txt", base_directory=base_directory
)
response = await llm_client.acreate_structured_output(
response = await LLMAdapter.acreate_structured_output(
text_input=text_input,
system_prompt=system_prompt,
response_model=PotentialNodesAndRelationshipNames,

View file

@ -1,11 +1,6 @@
from typing import List
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
get_llm_client,
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import (
render_prompt,
read_query_prompt,
)
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
from cognee.shared.data_models import KnowledgeGraph
from cognee.root_dir import get_absolute_path
@ -14,7 +9,6 @@ async def extract_edge_triplets(
content: str, nodes: List[str], relationship_names: List[str], n_rounds: int = 2
) -> KnowledgeGraph:
"""Creates a knowledge graph by identifying relationships between the provided nodes."""
llm_client = get_llm_client()
final_graph = KnowledgeGraph(nodes=[], edges=[])
existing_nodes = set()
existing_node_ids = set()
@ -32,13 +26,13 @@ async def extract_edge_triplets(
}
base_directory = get_absolute_path("./tasks/graph/cascade_extract/prompts")
text_input = render_prompt(
text_input = LLMAdapter.render_prompt(
"extract_graph_edge_triplets_prompt_input.txt", context, base_directory=base_directory
)
system_prompt = read_query_prompt(
system_prompt = LLMAdapter.read_query_prompt(
"extract_graph_edge_triplets_prompt_system.txt", base_directory=base_directory
)
extracted_graph = await llm_client.acreate_structured_output(
extracted_graph = await LLMAdapter.acreate_structured_output(
text_input=text_input, system_prompt=system_prompt, response_model=KnowledgeGraph
)

View file

@ -1,13 +1,7 @@
from typing import List
from pydantic import BaseModel
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
get_llm_client,
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import (
render_prompt,
read_query_prompt,
)
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
from cognee.root_dir import get_absolute_path
@ -19,7 +13,6 @@ class PotentialNodes(BaseModel):
async def extract_nodes(text: str, n_rounds: int = 2) -> List[str]:
"""Extracts node names from content through multiple rounds of analysis."""
llm_client = get_llm_client()
all_nodes: List[str] = []
existing_nodes = set()
@ -31,13 +24,13 @@ async def extract_nodes(text: str, n_rounds: int = 2) -> List[str]:
"text": text,
}
base_directory = get_absolute_path("./tasks/graph/cascade_extract/prompts")
text_input = render_prompt(
text_input = LLMAdapter.render_prompt(
"extract_graph_nodes_prompt_input.txt", context, base_directory=base_directory
)
system_prompt = read_query_prompt(
system_prompt = LLMAdapter.read_query_prompt(
"extract_graph_nodes_prompt_system.txt", base_directory=base_directory
)
response = await llm_client.acreate_structured_output(
response = await LLMAdapter.acreate_structured_output(
text_input=text_input, system_prompt=system_prompt, response_model=PotentialNodes
)

View file

@ -2,22 +2,9 @@ import asyncio
from typing import Type, List
from pydantic import BaseModel
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.tasks.storage import add_data_points
from cognee.base_config import get_base_config
# Framework selection
base = get_base_config()
if base.structured_output_framework == "BAML":
print(f"Using BAML framework: {base.structured_output_framework}")
from cognee.infrastructure.llm.structured_output_framework.baml_src.extraction import (
extract_content_graph,
)
else:
print(f"Using llitellm_instructor framework: {base.structured_output_framework}")
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.extraction import (
extract_content_graph,
)
async def extract_graph_from_code(
@ -31,7 +18,7 @@ async def extract_graph_from_code(
- Graph nodes are stored using the `add_data_points` function for later retrieval or analysis.
"""
chunk_graphs = await asyncio.gather(
*[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
*[LLMAdapter.extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
)
for chunk_index, chunk in enumerate(data_chunks):

View file

@ -6,25 +6,12 @@ from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.tasks.storage.add_data_points import add_data_points
from cognee.modules.ontology.rdf_xml.OntologyResolver import OntologyResolver
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.base_config import get_base_config
from cognee.modules.graph.utils import (
expand_with_nodes_and_edges,
retrieve_existing_edges,
)
from cognee.shared.data_models import KnowledgeGraph
# Framework selection
base = get_base_config()
if base.structured_output_framework == "BAML":
print(f"Using BAML framework: {base.structured_output_framework}")
from cognee.infrastructure.llm.structured_output_framework.baml_src.extraction import (
extract_content_graph,
)
else:
print(f"Using llitellm_instructor framework: {base.structured_output_framework}")
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.extraction import (
extract_content_graph,
)
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
async def integrate_chunk_graphs(
@ -69,7 +56,7 @@ async def extract_graph_from_data(
Extracts and integrates a knowledge graph from the text content of document chunks using a specified graph model.
"""
chunk_graphs = await asyncio.gather(
*[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
*[LLMAdapter.extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
)
# Note: Filter edges with missing source or target nodes

View file

@ -15,12 +15,7 @@ from pydantic import BaseModel
from cognee.modules.graph.exceptions import EntityNotFoundError
from cognee.modules.ingestion.exceptions import IngestionError
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import (
read_query_prompt,
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
get_llm_client,
)
from cognee.infrastructure.data.chunking.config import get_chunk_config
from cognee.infrastructure.data.chunking.get_chunking_engine import get_chunk_engine
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
@ -32,6 +27,7 @@ from cognee.modules.data.methods.add_model_class_to_graph import (
from cognee.tasks.graph.models import NodeModel, GraphOntology
from cognee.shared.data_models import KnowledgeGraph
from cognee.modules.engine.utils import generate_node_id, generate_node_name
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
logger = get_logger("task:infer_data_ontology")
@ -56,11 +52,10 @@ async def extract_ontology(content: str, response_model: Type[BaseModel]):
The structured ontology extracted from the content.
"""
llm_client = get_llm_client()
system_prompt = read_query_prompt("extract_ontology.txt")
system_prompt = LLMAdapter.read_query_prompt("extract_ontology.txt")
ontology = await llm_client.acreate_structured_output(content, system_prompt, response_model)
ontology = await LLMAdapter.acreate_structured_output(content, system_prompt, response_model)
return ontology

View file

@ -3,24 +3,9 @@ from typing import AsyncGenerator, Union
from uuid import uuid5
from cognee.infrastructure.engine import DataPoint
from cognee.base_config import get_base_config
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
from .models import CodeSummary
# Framework selection
base = get_base_config()
if base.structured_output_framework == "BAML":
print(f"Using BAML framework for code summarization: {base.structured_output_framework}")
from cognee.infrastructure.llm.structured_output_framework.baml_src.extraction import (
extract_code_summary,
)
else:
print(
f"Using llitellm_instructor framework for code summarization: {base.structured_output_framework}"
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.extraction import (
extract_code_summary,
)
async def summarize_code(
code_graph_nodes: list[DataPoint],
@ -31,7 +16,7 @@ async def summarize_code(
code_data_points = [file for file in code_graph_nodes if hasattr(file, "source_code")]
file_summaries = await asyncio.gather(
*[extract_code_summary(file.source_code) for file in code_data_points]
*[LLMAdapter.extract_code_summary(file.source_code) for file in code_data_points]
)
file_summaries_map = {

View file

@ -3,26 +3,11 @@ from typing import Type
from uuid import uuid5
from pydantic import BaseModel
from cognee.base_config import get_base_config
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
from cognee.modules.cognify.config import get_cognify_config
from .models import TextSummary
# Framework selection
base = get_base_config()
if base.structured_output_framework == "BAML":
print(f"Using BAML framework for text summarization: {base.structured_output_framework}")
from cognee.infrastructure.llm.structured_output_framework.baml_src.extraction import (
extract_summary,
)
else:
print(
f"Using llitellm_instructor framework for text summarization: {base.structured_output_framework}"
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.extraction import (
extract_summary,
)
async def summarize_text(
data_chunks: list[DocumentChunk], summarization_model: Type[BaseModel] = None
@ -58,7 +43,7 @@ async def summarize_text(
summarization_model = cognee_config.summarization_model
chunk_summaries = await asyncio.gather(
*[extract_summary(chunk.text, summarization_model) for chunk in data_chunks]
*[LLMAdapter.extract_summary(chunk.text, summarization_model) for chunk in data_chunks]
)
summaries = [

View file

@ -1 +0,0 @@
# Vector database tests module

View file

@ -4,7 +4,7 @@ from typing import List
from cognee.infrastructure.databases.vector.embeddings.LiteLLMEmbeddingEngine import (
LiteLLMEmbeddingEngine,
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.embedding_rate_limiter import (
from cognee.infrastructure.databases.vector.embeddings.embedding_rate_limiter import (
embedding_rate_limit_async,
embedding_sleep_and_retry_async,
)

View file

@ -3,10 +3,10 @@ import time
import asyncio
import logging
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.config import (
from cognee.infrastructure.llm.config import (
get_llm_config,
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.embedding_rate_limiter import (
from cognee.infrastructure.databases.vector.embeddings.embedding_rate_limiter import (
EmbeddingRateLimiter,
)
from cognee.tests.unit.infrastructure.mock_embedding_engine import MockEmbeddingEngine

View file

@ -5,7 +5,7 @@ from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.rate_limiter import (
llm_rate_limiter,
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.config import (
from cognee.infrastructure.llm.config import (
get_llm_config,
)

View file

@ -3,9 +3,7 @@ import logging
import cognee
import asyncio
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
get_llm_client,
)
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
from dotenv import load_dotenv
from cognee.api.v1.search import SearchType
from cognee.modules.engine.models import NodeSet
@ -187,7 +185,7 @@ async def run_procurement_example():
print(research_information)
print("\nPassing research to LLM for final procurement recommendation...\n")
final_decision = await get_llm_client().acreate_structured_output(
final_decision = await LLMAdapter.acreate_structured_output(
text_input=research_information,
system_prompt="""You are a procurement decision assistant. Use the provided QA pairs that were collected through a research phase. Recommend the best vendor,
based on pricing, delivery, warranty, policy fit, and past performance. Be concise and justify your choice with evidence.

View file

@ -12,13 +12,7 @@ from cognee.tasks.temporal_awareness.index_graphiti_objects import (
)
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import (
read_query_prompt,
render_prompt,
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
get_llm_client,
)
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
from cognee.modules.users.methods import get_default_user
text_list = [
@ -65,11 +59,10 @@ async def main():
"context": context,
}
user_prompt = render_prompt("graph_context_for_question.txt", args)
system_prompt = read_query_prompt("answer_simple_question_restricted.txt")
user_prompt = LLMAdapter.render_prompt("graph_context_for_question.txt", args)
system_prompt = LLMAdapter.read_query_prompt("answer_simple_question_restricted.txt")
llm_client = get_llm_client()
computed_answer = await llm_client.acreate_structured_output(
computed_answer = await LLMAdapter.acreate_structured_output(
text_input=user_prompt,
system_prompt=system_prompt,
response_model=str,