refactor: move prompts to LLM folder

This commit is contained in:
Igor Ilic 2025-08-05 19:06:40 +02:00
parent 7761b70229
commit cb21e08ad9
65 changed files with 1691 additions and 64 deletions

View file

@ -0,0 +1,548 @@
import threading
import logging
import functools
import os
import time
import asyncio
import random
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.llm.config import get_llm_config
logger = get_logger()

# Substrings matched (case-insensitively, against lowercased error text) to
# decide whether a provider exception indicates rate limiting rather than a
# genuine failure. Keep every entry lowercase.
RATE_LIMIT_ERROR_PATTERNS = [
    "rate limit",
    "rate_limit",
    "ratelimit",
    "too many requests",
    "retry after",
    "capacity",
    "quota",
    "limit exceeded",
    "tps limit exceeded",
    "request limit exceeded",
    "maximum requests",
    "exceeded your current quota",
    "throttled",
    "throttling",
]

# Default retry settings
DEFAULT_MAX_RETRIES = 5
DEFAULT_INITIAL_BACKOFF = 1.0  # seconds
DEFAULT_BACKOFF_FACTOR = 2.0  # exponential backoff multiplier
DEFAULT_JITTER = 0.1  # 10% jitter to avoid thundering herd
class EmbeddingRateLimiter:
    """
    Rate limiter for embedding API calls.

    Implements a thread-safe singleton so rate limiting is consistent across
    all embedding requests. A moving time window is used: timestamps of recent
    requests are kept and pruned as they age past ``interval_seconds``.

    The limiter reads the same configuration object as the LLM API rate
    limiter but tracks embedding API calls independently.

    Public Methods:
    - get_instance
    - reset_instance
    - hit_limit
    - wait_if_needed
    - async_wait_if_needed

    Instance Variables:
    - enabled: whether rate limiting is active
    - requests_limit: maximum requests allowed per window
    - interval_seconds: length of the moving window, in seconds
    - request_times: timestamps of requests inside the current window
    - lock: per-instance lock guarding request_times (note: this shadows the
      class-level ``lock``, which guards only singleton creation/reset)
    """

    _instance = None
    # Class-level lock used only by get_instance/reset_instance; instances
    # shadow this name with their own lock in __init__.
    lock = threading.Lock()

    @classmethod
    def get_instance(cls):
        """
        Retrieve the singleton instance of the EmbeddingRateLimiter.

        Uses double-checked locking so the common path avoids acquiring the
        class lock once the instance already exists.

        Returns:
        --------
        The singleton instance of the EmbeddingRateLimiter class.
        """
        if cls._instance is None:
            with cls.lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    @classmethod
    def reset_instance(cls):
        """
        Reset the singleton instance of the EmbeddingRateLimiter.

        Thread-safe; sets the instance to None so a fresh one (re-reading
        configuration) is created on the next get_instance() call.
        """
        with cls.lock:
            cls._instance = None

    def __init__(self):
        config = get_llm_config()
        self.enabled = config.embedding_rate_limit_enabled
        self.requests_limit = config.embedding_rate_limit_requests
        self.interval_seconds = config.embedding_rate_limit_interval
        self.request_times = []
        self.lock = threading.Lock()
        # Fix: use the configured module logger (get_logger()) instead of the
        # root stdlib `logging` module, consistent with the rest of this file.
        logger.info(
            f"EmbeddingRateLimiter initialized: enabled={self.enabled}, "
            f"requests_limit={self.requests_limit}, interval_seconds={self.interval_seconds}"
        )

    def hit_limit(self) -> bool:
        """
        Check if one more request right now would exceed the rate limit.

        Prunes timestamps that have fallen out of the moving window, then
        compares the remaining count against the configured limit. Always
        False when rate limiting is disabled.

        Returns:
        --------
        - bool: True if the rate limit would be exceeded, otherwise False.
        """
        if not self.enabled:
            return False

        current_time = time.time()

        with self.lock:
            # Drop request timestamps that have aged out of the window.
            cutoff_time = current_time - self.interval_seconds
            self.request_times = [t for t in self.request_times if t > cutoff_time]

            # Check if adding a new request would exceed the limit.
            if len(self.request_times) >= self.requests_limit:
                logger.info(
                    f"Rate limit hit: {len(self.request_times)} requests in the last {self.interval_seconds} seconds"
                )
                return True

            # Otherwise, we're under the limit.
            return False

    def wait_if_needed(self) -> float:
        """
        Block until a request can be made without exceeding the rate limit.

        Polls hit_limit() every 0.5 s, then records this request's timestamp.
        NOTE(review): the poll and the append are not atomic, so several
        threads waking together may briefly overshoot the limit.

        Returns:
        --------
        - float: Time waited in seconds before proceeding (0 when disabled).
        """
        if not self.enabled:
            return 0

        wait_time = 0
        start_time = time.time()

        while self.hit_limit():
            time.sleep(0.5)  # Poll every 0.5 seconds
            wait_time = time.time() - start_time

        # Record this request.
        with self.lock:
            self.request_times.append(time.time())

        return wait_time

    async def async_wait_if_needed(self) -> float:
        """
        Asynchronously wait until a request can be made without exceeding the limit.

        Same semantics as wait_if_needed(), but yields to the event loop
        (asyncio.sleep) instead of blocking the thread while polling.

        Returns:
        --------
        - float: Time waited in seconds before proceeding (0 when disabled).
        """
        if not self.enabled:
            return 0

        wait_time = 0
        start_time = time.time()

        while self.hit_limit():
            await asyncio.sleep(0.5)  # Poll every 0.5 seconds
            wait_time = time.time() - start_time

        # Record this request.
        with self.lock:
            self.request_times.append(time.time())

        return wait_time
def embedding_rate_limit_sync(func):
    """
    Decorator enforcing the embedding API rate limit on a synchronous function.

    Parameters:
    -----------
    - func: Synchronous function to wrap.

    Returns:
    --------
    The wrapped function, which raises EmbeddingException when the shared
    limiter reports the window is full, and otherwise records the request
    before delegating to the original function.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        """
        Consult the shared EmbeddingRateLimiter before running the call.

        When the moving-window limit is already reached, a warning is logged
        and an EmbeddingException is raised; otherwise the request is counted
        (waiting if necessary) and the original function is invoked.

        Parameters:
        -----------
        - *args: Positional arguments forwarded to the wrapped function.
        - **kwargs: Keyword arguments forwarded to the wrapped function.

        Returns:
        --------
        The wrapped function's result when the rate limit allows the call.
        """
        limiter = EmbeddingRateLimiter.get_instance()

        if not limiter.hit_limit():
            # Under the limit: record this request and run the wrapped call.
            limiter.wait_if_needed()
            return func(*args, **kwargs)

        error_msg = "Embedding API rate limit exceeded"
        logger.warning(error_msg)
        # Local import: the exception type is only needed on the failure path.
        from cognee.infrastructure.databases.exceptions.EmbeddingException import (
            EmbeddingException,
        )

        raise EmbeddingException(error_msg)

    return wrapper
def embedding_rate_limit_async(func):
    """
    Decorator enforcing the embedding API rate limit on an async function.

    Parameters:
    -----------
    - func: Coroutine function to wrap.

    Returns:
    --------
    The wrapped coroutine function, which raises EmbeddingException when the
    shared limiter reports the window is full, and otherwise records the
    request before awaiting the original function.
    """

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        """
        Consult the shared EmbeddingRateLimiter before awaiting the call.

        When the moving-window limit is already reached, a warning is logged
        and an EmbeddingException is raised; otherwise the request is counted
        (awaiting if necessary) and the original coroutine is awaited.

        Parameters:
        -----------
        - *args: Positional arguments forwarded to the wrapped function.
        - **kwargs: Keyword arguments forwarded to the wrapped function.

        Returns:
        --------
        The wrapped coroutine's result when the rate limit allows the call.
        """
        limiter = EmbeddingRateLimiter.get_instance()

        if not limiter.hit_limit():
            # Under the limit: record this request, then run the call.
            await limiter.async_wait_if_needed()
            return await func(*args, **kwargs)

        error_msg = "Embedding API rate limit exceeded"
        logger.warning(error_msg)
        # Local import: the exception type is only needed on the failure path.
        from cognee.infrastructure.databases.exceptions.EmbeddingException import (
            EmbeddingException,
        )

        raise EmbeddingException(error_msg)

    return wrapper
def embedding_sleep_and_retry_sync(max_retries=5, base_backoff=1.0, jitter=0.5):
    """
    Add retry with exponential backoff for synchronous embedding functions.

    Rate-limit errors (detected by matching RATE_LIMIT_ERROR_PATTERNS against
    the lowercased exception text) are retried with backoff
    ``base_backoff * 2**attempt`` scaled by a random factor in
    ``[1 - jitter, 1 + jitter]``. Other errors propagate immediately, as does
    any error once retries are exhausted. Setting the DISABLE_RETRIES
    environment variable to "true"/"1"/"yes" disables retrying entirely
    (used by tests).

    Parameters:
    -----------
    - max_retries: Maximum number of retries before giving up. (default 5)
    - base_backoff: Base backoff time in seconds for retry intervals. (default 1.0)
    - jitter: Fractional jitter applied to each backoff to avoid thundering
      herd. (default 0.5)

    Returns:
    --------
    A decorator that applies the retry policy to the wrapped function.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            """Call func, retrying with exponential backoff on rate-limit errors."""
            # Honor the test-only escape hatch before doing any other work.
            disable_retries = os.environ.get("DISABLE_RETRIES", "false").lower() in (
                "true",
                "1",
                "yes",
            )

            retries = 0
            # The loop only exits via return (success) or raise (failure),
            # so no post-loop handling is needed.
            while True:
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    if disable_retries:
                        # For testing, propagate the exception immediately.
                        raise

                    # error_str is already lowercased; the patterns are all
                    # lowercase, so they can be matched directly.
                    error_str = str(e).lower()
                    is_rate_limit = any(
                        pattern in error_str for pattern in RATE_LIMIT_ERROR_PATTERNS
                    )

                    if not is_rate_limit or retries >= max_retries:
                        # Not a rate limit error, or max retries reached.
                        raise

                    # Exponential backoff with multiplicative jitter.
                    backoff = (
                        base_backoff * (2**retries) * (1 + random.uniform(-jitter, jitter))
                    )
                    logger.warning(
                        f"Embedding rate limit hit, retrying in {backoff:.2f}s "
                        f"(attempt {retries + 1}/{max_retries}): "
                        f"({error_str!r}, {type(e).__name__!r})"
                    )
                    time.sleep(backoff)
                    retries += 1

        return wrapper

    return decorator
def embedding_sleep_and_retry_async(max_retries=5, base_backoff=1.0, jitter=0.5):
    """
    Add retry with exponential backoff for asynchronous embedding functions.

    Rate-limit errors (detected by matching RATE_LIMIT_ERROR_PATTERNS against
    the lowercased exception text) are retried with backoff
    ``base_backoff * 2**attempt`` scaled by a random factor in
    ``[1 - jitter, 1 + jitter]``, sleeping via asyncio so the event loop is
    not blocked. Other errors propagate immediately, as does any error once
    retries are exhausted. Setting the DISABLE_RETRIES environment variable
    to "true"/"1"/"yes" disables retrying entirely (used by tests).

    Parameters:
    -----------
    - max_retries: Maximum number of retries allowed before giving up. (default 5)
    - base_backoff: Base backoff time in seconds for retry intervals. (default 1.0)
    - jitter: Fractional jitter applied to each backoff to avoid thundering
      herd. (default 0.5)

    Returns:
    --------
    A decorator that applies the retry policy to the wrapped async function.
    """

    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            """Await func, retrying with exponential backoff on rate-limit errors."""
            # Honor the test-only escape hatch before doing any other work.
            disable_retries = os.environ.get("DISABLE_RETRIES", "false").lower() in (
                "true",
                "1",
                "yes",
            )

            retries = 0
            # The loop only exits via return (success) or raise (failure),
            # so no post-loop handling is needed.
            while True:
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    if disable_retries:
                        # For testing, propagate the exception immediately.
                        raise

                    # error_str is already lowercased; the patterns are all
                    # lowercase, so they can be matched directly.
                    error_str = str(e).lower()
                    is_rate_limit = any(
                        pattern in error_str for pattern in RATE_LIMIT_ERROR_PATTERNS
                    )

                    if not is_rate_limit or retries >= max_retries:
                        # Not a rate limit error, or max retries reached.
                        raise

                    # Exponential backoff with multiplicative jitter.
                    backoff = (
                        base_backoff * (2**retries) * (1 + random.uniform(-jitter, jitter))
                    )
                    logger.warning(
                        f"Embedding rate limit hit, retrying in {backoff:.2f}s "
                        f"(attempt {retries + 1}/{max_retries}): "
                        f"({error_str!r}, {type(e).__name__!r})"
                    )
                    await asyncio.sleep(backoff)
                    retries += 1

        return wrapper

    return decorator

View file

@ -0,0 +1,136 @@
from typing import Type
from pydantic import BaseModel
from cognee.infrastructure.llm import get_llm_config
# TODO: Check if Coroutines should be returned or awaited result values
class LLMAdapter:
    """
    Class handles selection of structured output frameworks and LLM functions.

    Used as a namespace for LLM-related functions and should not be
    instantiated; all methods are static. Framework-specific modules are
    imported lazily inside each method so importing this adapter stays cheap.
    """

    # TODO: Check if Coroutines should be returned or awaited result values

    @staticmethod
    def render_prompt(filename: str, context: dict, base_directory: str = None):
        """Render the Jinja2 prompt template *filename* with *context*."""
        from cognee.infrastructure.llm.prompts import render_prompt

        return render_prompt(filename=filename, context=context, base_directory=base_directory)

    @staticmethod
    def acreate_structured_output(
        text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:
        """Async structured-output call; returns the LLM client's coroutine."""
        from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
            get_llm_client,
        )

        llm_client = get_llm_client()
        return llm_client.acreate_structured_output(
            text_input=text_input, system_prompt=system_prompt, response_model=response_model
        )

    @staticmethod
    def create_structured_output(
        text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:
        """Synchronous structured-output call through the configured LLM client."""
        from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
            get_llm_client,
        )

        llm_client = get_llm_client()
        return llm_client.create_structured_output(
            text_input=text_input, system_prompt=system_prompt, response_model=response_model
        )

    @staticmethod
    def create_transcript(input):
        """Create a transcript of the given input via the LLM client."""
        from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
            get_llm_client,
        )

        llm_client = get_llm_client()
        return llm_client.create_transcript(input=input)

    @staticmethod
    def transcribe_image(input):
        """Transcribe the given image input via the LLM client."""
        from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
            get_llm_client,
        )

        llm_client = get_llm_client()
        return llm_client.transcribe_image(input=input)

    @staticmethod
    def show_prompt(text_input: str, system_prompt: str) -> str:
        """Return the full prompt the client would send, for inspection."""
        from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
            get_llm_client,
        )

        llm_client = get_llm_client()
        return llm_client.show_prompt(text_input=text_input, system_prompt=system_prompt)

    @staticmethod
    def read_query_prompt(prompt_file_name: str, base_directory: str = None):
        """Read a raw prompt file from the prompts directory."""
        from cognee.infrastructure.llm.prompts import (
            read_query_prompt,
        )

        return read_query_prompt(prompt_file_name=prompt_file_name, base_directory=base_directory)

    @staticmethod
    def extract_content_graph(content: str, response_model: Type[BaseModel], mode: str = "simple"):
        """Extract a knowledge graph from *content* using the configured framework."""
        llm_config = get_llm_config()

        # Case-insensitive comparison so "baml"/"BAML" both select BAML.
        if llm_config.structured_output_framework.upper() == "BAML":
            from cognee.infrastructure.llm.structured_output_framework.baml.baml_src.extraction import (
                extract_content_graph,
            )

            return extract_content_graph(content=content, response_model=response_model, mode=mode)
        else:
            from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.extraction import (
                extract_content_graph,
            )

            # NOTE: the instructor implementation takes no `mode` argument.
            return extract_content_graph(content=content, response_model=response_model)

    @staticmethod
    def extract_categories(content: str, response_model: Type[BaseModel]):
        """Classify *content* into categories (instructor framework only)."""
        # TODO: Add BAML version of category extraction and update this function
        from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.extraction import (
            extract_categories,
        )

        return extract_categories(content=content, response_model=response_model)

    @staticmethod
    def extract_code_summary(content: str):
        """Summarize source code using the configured framework."""
        llm_config = get_llm_config()

        # Fixed: compare case-insensitively, consistent with extract_content_graph
        # (was a case-sensitive `== "BAML"` check).
        if llm_config.structured_output_framework.upper() == "BAML":
            from cognee.infrastructure.llm.structured_output_framework.baml.baml_src.extraction import (
                extract_code_summary,
            )

            return extract_code_summary(content=content)
        else:
            from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.extraction import (
                extract_code_summary,
            )

            return extract_code_summary(content=content)

    @staticmethod
    def extract_summary(content: str, response_model: Type[BaseModel]):
        """Summarize *content* into *response_model* using the configured framework."""
        llm_config = get_llm_config()

        # Fixed: compare case-insensitively, consistent with extract_content_graph
        # (was a case-sensitive `== "BAML"` check).
        if llm_config.structured_output_framework.upper() == "BAML":
            from cognee.infrastructure.llm.structured_output_framework.baml.baml_src.extraction import (
                extract_summary,
            )

            return extract_summary(content=content, response_model=response_model)
        else:
            from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.extraction import (
                extract_summary,
            )

            return extract_summary(content=content, response_model=response_model)

View file

@ -11,4 +11,4 @@ from cognee.infrastructure.llm.utils import (
test_embedding_connection,
)
from LLMAdapter import LLMAdapter
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter

View file

@ -0,0 +1,194 @@
import os
from typing import Optional, ClassVar
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import model_validator
from baml_py import ClientRegistry
class LLMConfig(BaseSettings):
    """
    Configuration settings for the LLM (Large Language Model) provider and related options.

    Values are loaded by pydantic-settings from the environment and the
    ``.env`` file (extra variables are allowed).

    Public instance variables include:
    - structured_output_framework
    - llm_provider
    - llm_model
    - llm_endpoint
    - llm_api_key
    - llm_api_version
    - llm_temperature
    - llm_streaming
    - llm_max_tokens
    - transcription_model
    - graph_prompt_path
    - llm_rate_limit_enabled
    - llm_rate_limit_requests
    - llm_rate_limit_interval
    - embedding_rate_limit_enabled
    - embedding_rate_limit_requests
    - embedding_rate_limit_interval
    - fallback_api_key
    - fallback_endpoint
    - fallback_model

    Public methods include:
    - ensure_env_vars_for_ollama
    - to_dict
    """

    # Structured output backend selector; compared against "BAML" elsewhere.
    structured_output_framework: str = "instructor"
    llm_provider: str = "openai"
    llm_model: str = "gpt-4o-mini"
    llm_endpoint: str = ""
    llm_api_key: Optional[str] = None
    llm_api_version: Optional[str] = None
    llm_temperature: float = 0.0
    llm_streaming: bool = False
    llm_max_tokens: int = 16384
    transcription_model: str = "whisper-1"
    graph_prompt_path: str = "generate_graph_prompt.txt"
    llm_rate_limit_enabled: bool = False
    llm_rate_limit_requests: int = 60
    llm_rate_limit_interval: int = 60  # in seconds (default is 60 requests per minute)
    embedding_rate_limit_enabled: bool = False
    embedding_rate_limit_requests: int = 60
    embedding_rate_limit_interval: int = 60  # in seconds (default is 60 requests per minute)
    # Optional fallback client credentials (empty string means "not configured").
    fallback_api_key: str = ""
    fallback_endpoint: str = ""
    fallback_model: str = ""
    # Shared BAML client registry: class-level, so one registry per process.
    baml_registry: ClassVar[ClientRegistry] = ClientRegistry()

    model_config = SettingsConfigDict(env_file=".env", extra="allow")

    def model_post_init(self, __context) -> None:
        """Initialize the BAML registry after the model is created."""
        # Register the configured provider as a BAML client using the same
        # model/temperature/key this config was loaded with.
        self.baml_registry.add_llm_client(
            name=self.llm_provider,
            provider=self.llm_provider,
            options={
                "model": self.llm_model,
                "temperature": self.llm_temperature,
                "api_key": self.llm_api_key,
            },
        )
        # Sets the primary client
        self.baml_registry.set_primary(self.llm_provider)

    @model_validator(mode="after")
    def ensure_env_vars_for_ollama(self) -> "LLMConfig":
        """
        Validate environment variables when the 'ollama' LLM provider is used.

        Both the LLM variable group (LLM_MODEL, LLM_ENDPOINT, LLM_API_KEY) and
        the embedding variable group (EMBEDDING_PROVIDER, EMBEDDING_MODEL,
        EMBEDDING_DIMENSIONS, HUGGINGFACE_TOKENIZER) must be either fully set
        or fully unset; a partially-set group raises ValueError. No checks are
        performed for other providers.

        Returns:
        --------
        - 'LLMConfig': The instance of LLMConfig after validation.
        """
        if self.llm_provider != "ollama":
            # Skip checks unless provider is "ollama"
            return self

        def is_env_set(var_name: str) -> bool:
            """
            Check if a given environment variable is set and non-empty.

            Parameters:
            -----------
            - var_name (str): The name of the environment variable to check.

            Returns:
            --------
            - bool: True if the environment variable exists and is not empty, otherwise False.
            """
            val = os.environ.get(var_name)
            return val is not None and val.strip() != ""

        #
        # 1. Check LLM environment variables
        #
        llm_env_vars = {
            "LLM_MODEL": is_env_set("LLM_MODEL"),
            "LLM_ENDPOINT": is_env_set("LLM_ENDPOINT"),
            "LLM_API_KEY": is_env_set("LLM_API_KEY"),
        }

        if any(llm_env_vars.values()) and not all(llm_env_vars.values()):
            missing_llm = [key for key, is_set in llm_env_vars.items() if not is_set]
            raise ValueError(
                "You have set some but not all of the required environment variables "
                f"for LLM usage (LLM_MODEL, LLM_ENDPOINT, LLM_API_KEY). Missing: {missing_llm}"
            )

        #
        # 2. Check embedding environment variables
        #
        embedding_env_vars = {
            "EMBEDDING_PROVIDER": is_env_set("EMBEDDING_PROVIDER"),
            "EMBEDDING_MODEL": is_env_set("EMBEDDING_MODEL"),
            "EMBEDDING_DIMENSIONS": is_env_set("EMBEDDING_DIMENSIONS"),
            "HUGGINGFACE_TOKENIZER": is_env_set("HUGGINGFACE_TOKENIZER"),
        }

        if any(embedding_env_vars.values()) and not all(embedding_env_vars.values()):
            missing_embed = [key for key, is_set in embedding_env_vars.items() if not is_set]
            raise ValueError(
                "You have set some but not all of the required environment variables "
                "for embeddings (EMBEDDING_PROVIDER, EMBEDDING_MODEL, "
                "EMBEDDING_DIMENSIONS, HUGGINGFACE_TOKENIZER). Missing: "
                f"{missing_embed}"
            )

        return self

    def to_dict(self) -> dict:
        """
        Convert the LLMConfig instance into a dictionary representation.

        Note: keys use short names ("provider", "model", ...) rather than the
        attribute names, and the API key is included in clear text.

        Returns:
        --------
        - dict: A dictionary containing the configuration settings of the LLMConfig
          instance.
        """
        return {
            "provider": self.llm_provider,
            "model": self.llm_model,
            "endpoint": self.llm_endpoint,
            "api_key": self.llm_api_key,
            "api_version": self.llm_api_version,
            "temperature": self.llm_temperature,
            "streaming": self.llm_streaming,
            "max_tokens": self.llm_max_tokens,
            "transcription_model": self.transcription_model,
            "graph_prompt_path": self.graph_prompt_path,
            "rate_limit_enabled": self.llm_rate_limit_enabled,
            "rate_limit_requests": self.llm_rate_limit_requests,
            "rate_limit_interval": self.llm_rate_limit_interval,
            "embedding_rate_limit_enabled": self.embedding_rate_limit_enabled,
            "embedding_rate_limit_requests": self.embedding_rate_limit_requests,
            "embedding_rate_limit_interval": self.embedding_rate_limit_interval,
            "fallback_api_key": self.fallback_api_key,
            "fallback_endpoint": self.fallback_endpoint,
            "fallback_model": self.fallback_model,
        }
@lru_cache
def get_llm_config():
    """
    Return the process-wide LLM configuration object.

    The first call constructs an LLMConfig (loading values from the
    environment and .env); lru_cache then hands back that same instance on
    every subsequent call, so the configuration behaves as a singleton.

    Returns:
    --------
    - LLMConfig: The cached configuration instance.
    """
    return LLMConfig()

View file

@ -1,2 +1,2 @@
Answer the question using the provided context. Be as brief as possible.
Each entry in the context is a paragraph, which is represented as a list with two elements [title, sentences] and sentences is a list of strings.
Each entry in the context is a paragraph, which is represented as a list with two elements [title, sentences] and sentences is a list of strings.

View file

@ -1,2 +1,2 @@
Answer the question using the provided context. Be as brief as possible.
Each entry in the context is a tuple of length 3, representing an edge of a knowledge graph with its two nodes.
Each entry in the context is a tuple of length 3, representing an edge of a knowledge graph with its two nodes.

View file

@ -1 +1 @@
Answer the question using the provided context. Be as brief as possible.
Answer the question using the provided context. Be as brief as possible.

View file

@ -1 +1 @@
Answer the question using the provided context. If the provided context is not connected to the question, just answer "The provided knowledge base does not contain the answer to the question". Be as brief as possible.
Answer the question using the provided context. If the provided context is not connected to the question, just answer "The provided knowledge base does not contain the answer to the question". Be as brief as possible.

View file

@ -1,2 +1,2 @@
Choose the summary that is the most relevant to the query`{{ query }}`
Here are the categories:`{{ categories }}`
Here are the categories:`{{ categories }}`

View file

@ -174,4 +174,4 @@ The possible classifications are:
"Recipes and crafting instructions"
]
}
}
}

View file

@ -0,0 +1,2 @@
The question is: `{{ question }}`
And here is the context: `{{ context }}`

View file

@ -1,2 +1,2 @@
The question is: `{{ question }}`
and here is the context provided with a set of relationships from a knowledge graph separated by \n---\n each represented as node1 -- relation -- node2 triplet: `{{ context }}`
and here is the context provided with a set of relationships from a knowledge graph separated by \n---\n each represented as node1 -- relation -- node2 triplet: `{{ context }}`

View file

@ -63,4 +63,4 @@ This queries doesn't work. Do NOT use them:
Example 1:
Get all nodes connected to John
MATCH (n:Entity {'name': 'John'})--(neighbor)
RETURN n, neighbor
RETURN n, neighbor

View file

@ -1,2 +1,2 @@
I need you to solve this issue by generating a single patch file that I can apply directly to this repository using git apply.
Please respond with a single patch file in the following format.
Please respond with a single patch file in the following format.

View file

@ -25,9 +25,7 @@ def render_prompt(filename: str, context: dict, base_directory: str = None) -> s
# Set the base directory relative to the cognee root directory
if base_directory is None:
base_directory = get_absolute_path(
"./infrastructure/llm/structured_output_framework/llitellm_instructor/llm/prompts"
)
base_directory = get_absolute_path("./infrastructure/llm/prompts")
# Initialize the Jinja2 environment to load templates from the filesystem
env = Environment(

View file

@ -1,5 +1,5 @@
You are an expert Python programmer and technical writer. Your task is to summarize the given Python code snippet or file.
The code may contain multiple imports, classes, functions, constants and logic. Provide a clear, structured explanation of its components
The code may contain multiple imports, classes, functions, constants and logic. Provide a clear, structured explanation of its components
and their relationships.
Instructions:
@ -7,4 +7,4 @@ Provide an overview: Start with a high-level summary of what the code does as a
Break it down: Summarize each class and function individually, explaining their purpose and how they interact.
Describe the workflow: Outline how the classes and functions work together. Mention any control flow (e.g., main functions, entry points, loops).
Key features: Highlight important elements like arguments, return values, or unique logic.
Maintain clarity: Write in plain English for someone familiar with Python but unfamiliar with this code.
Maintain clarity: Write in plain English for someone familiar with Python but unfamiliar with this code.

View file

@ -0,0 +1,109 @@
// Content classification data models - matching shared/data_models.py
// Each *Content class mirrors one content modality: `type` holds the coarse
// category constant and `subclass` the finer-grained labels chosen by the model.
class TextContent {
  type string
  subclass string[]
}

class AudioContent {
  type string
  subclass string[]
}

class ImageContent {
  type string
  subclass string[]
}

class VideoContent {
  type string
  subclass string[]
}

class MultimediaContent {
  type string
  subclass string[]
}

class Model3DContent {
  type string
  subclass string[]
}

class ProceduralContent {
  type string
  subclass string[]
}

// Discriminated label: content_type selects the modality, while type and
// subclass carry the classification picked for that modality.
class ContentLabel {
  content_type "text" | "audio" | "image" | "video" | "multimedia" | "3d_model" | "procedural"
  type string
  subclass string[]
}

// Top-level result type returned by the ExtractCategories function.
class DefaultContentPrediction {
  label ContentLabel
}
// Content classification prompt template
template_string ClassifyContentPrompt() #"
You are a classification engine and should classify content. Make sure to use one of the existing classification options and not invent your own.
Classify the content into one of these main categories and their relevant subclasses:
**TEXT CONTENT** (content_type: "text"):
- type: "TEXTUAL_DOCUMENTS_USED_FOR_GENERAL_PURPOSES"
- subclass options: ["Articles, essays, and reports", "Books and manuscripts", "News stories and blog posts", "Research papers and academic publications", "Social media posts and comments", "Website content and product descriptions", "Personal narratives and stories", "Spreadsheets and tables", "Forms and surveys", "Databases and CSV files", "Source code in various programming languages", "Shell commands and scripts", "Markup languages (HTML, XML)", "Stylesheets (CSS) and configuration files (YAML, JSON, INI)", "Chat transcripts and messaging history", "Customer service logs and interactions", "Conversational AI training data", "Textbook content and lecture notes", "Exam questions and academic exercises", "E-learning course materials", "Poetry and prose", "Scripts for plays, movies, and television", "Song lyrics", "Manuals and user guides", "Technical specifications and API documentation", "Helpdesk articles and FAQs", "Contracts and agreements", "Laws, regulations, and legal case documents", "Policy documents and compliance materials", "Clinical trial reports", "Patient records and case notes", "Scientific journal articles", "Financial reports and statements", "Business plans and proposals", "Market research and analysis reports", "Ad copies and marketing slogans", "Product catalogs and brochures", "Press releases and promotional content", "Professional and formal correspondence", "Personal emails and letters", "Image and video captions", "Annotations and metadata for various media", "Vocabulary lists and grammar rules", "Language exercises and quizzes", "Other types of text data"]
**AUDIO CONTENT** (content_type: "audio"):
- type: "AUDIO_DOCUMENTS_USED_FOR_GENERAL_PURPOSES"
- subclass options: ["Music tracks and albums", "Podcasts and radio broadcasts", "Audiobooks and audio guides", "Recorded interviews and speeches", "Sound effects and ambient sounds", "Other types of audio recordings"]
**IMAGE CONTENT** (content_type: "image"):
- type: "IMAGE_DOCUMENTS_USED_FOR_GENERAL_PURPOSES"
- subclass options: ["Photographs and digital images", "Illustrations, diagrams, and charts", "Infographics and visual data representations", "Artwork and paintings", "Screenshots and graphical user interfaces", "Other types of images"]
**VIDEO CONTENT** (content_type: "video"):
- type: "VIDEO_DOCUMENTS_USED_FOR_GENERAL_PURPOSES"
- subclass options: ["Movies and short films", "Documentaries and educational videos", "Video tutorials and how-to guides", "Animated features and cartoons", "Live event recordings and sports broadcasts", "Other types of video content"]
**MULTIMEDIA CONTENT** (content_type: "multimedia"):
- type: "MULTIMEDIA_DOCUMENTS_USED_FOR_GENERAL_PURPOSES"
- subclass options: ["Interactive web content and games", "Virtual reality (VR) and augmented reality (AR) experiences", "Mixed media presentations and slide decks", "E-learning modules with integrated multimedia", "Digital exhibitions and virtual tours", "Other types of multimedia content"]
**3D MODEL CONTENT** (content_type: "3d_model"):
- type: "3D_MODEL_DOCUMENTS_USED_FOR_GENERAL_PURPOSES"
- subclass options: ["Architectural renderings and building plans", "Product design models and prototypes", "3D animations and character models", "Scientific simulations and visualizations", "Virtual objects for AR/VR applications", "Other types of 3D models"]
**PROCEDURAL CONTENT** (content_type: "procedural"):
- type: "PROCEDURAL_DOCUMENTS_USED_FOR_GENERAL_PURPOSES"
- subclass options: ["Tutorials and step-by-step guides", "Workflow and process descriptions", "Simulation and training exercises", "Recipes and crafting instructions", "Other types of procedural content"]
Select the most appropriate content_type, type, and relevant subclasses.
"#
// OpenAI client defined once for all BAML files
// Classification function: classifies `content` into the taxonomy described by
// ClassifyContentPrompt() and returns a DefaultContentPrediction shaped by the
// schema injected through ctx.output_format.
function ExtractCategories(content: string) -> DefaultContentPrediction {
client OpenAI
prompt #"
{{ ClassifyContentPrompt() }}
{{ ctx.output_format(prefix="Answer in this schema:\n") }}
{{ _.role('user') }}
{{ content }}
"#
}
// Test case for classification: a short NLP paragraph that should map to text content.
test ExtractCategoriesExample {
functions [ExtractCategories]
args {
content #"
Natural language processing (NLP) is an interdisciplinary subfield of computer science and information retrieval.
It deals with the interaction between computers and human language, in particular how to program computers to process and analyze large amounts of natural language data.
"#
}
}

View file

@ -0,0 +1,343 @@
/// A graph node (entity/concept). @@dynamic lets callers attach extra
/// properties at runtime via a TypeBuilder.
class Node {
id string
name string
type string
description string
@@dynamic
}
/// A directed relationship between two nodes, referenced by node id.
class Edge {
/// Id of the node the edge starts from.
source_node_id string
target_node_id string
relationship_name string
}
/// Full extraction result; nodes use @stream.done so each node is emitted
/// only once it is complete during streaming.
class KnowledgeGraph {
nodes (Node @stream.done)[]
edges Edge[]
}
// Summarization classes
/// Generic text summary returned by SummarizeContent.
class SummarizedContent {
summary string
description string
}
/// Summary of a single function; optional lists are omitted when unknown.
class SummarizedFunction {
name string
description string
inputs string[]?
outputs string[]?
decorators string[]?
}
/// Summary of a single class and its methods.
class SummarizedClass {
name string
description string
methods SummarizedFunction[]?
decorators string[]?
}
/// Structured summary of a source file returned by SummarizeCode.
class SummarizedCode {
high_level_summary string
key_features string[]
imports string[]
constants string[]
classes SummarizedClass[]
functions SummarizedFunction[]
workflow_description string?
}
/// Fully dynamic graph schema — all fields are added at runtime via @@dynamic.
class DynamicKnowledgeGraph {
@@dynamic
}
// Simple template for basic extraction (fast, good quality).
// Used by the Extract* functions below when mode is "simple" or unset.
template_string ExtractContentGraphPrompt() #"
You are an advanced algorithm that extracts structured data into a knowledge graph.
- **Nodes**: Entities/concepts (like Wikipedia articles).
- **Edges**: Relationships (like Wikipedia links). Use snake_case (e.g., `acted_in`).
**Rules:**
1. **Node Labeling & IDs**
- Use basic types only (e.g., "Person", "Date", "Organization").
- Avoid overly specific or generic terms (e.g., no "Mathematician" or "Entity").
- Node IDs must be human-readable names from the text (no numbers).
2. **Dates & Numbers**
- Label dates as **"Date"** in "YYYY-MM-DD" format (use available parts if incomplete).
- Properties are key-value pairs; do not use escaped quotes.
3. **Coreference Resolution**
- Use a single, complete identifier for each entity (e.g., always "John Doe" not "Joe" or "he").
4. **Relationship Labels**:
- Use descriptive, lowercase, snake_case names for edges.
- *Example*: born_in, married_to, invented_by.
- Avoid vague or generic labels like isA, relatesTo, has.
- Avoid duplicated relationships like produces, produced by.
5. **Strict Compliance**
- Follow these rules exactly. Non-compliance results in termination.
"#
// Summarization prompt template — system prompt for SummarizeContent.
template_string SummarizeContentPrompt() #"
You are a top-tier summarization engine. Your task is to summarize text and make it versatile.
Be brief and concise, but keep the important information and the subject.
Use synonym words where possible in order to change the wording but keep the meaning.
"#
// Code summarization prompt template — system prompt for SummarizeCode;
// its numbered points mirror the fields of SummarizedCode.
template_string SummarizeCodePrompt() #"
You are an expert code analyst. Analyze the provided source code and extract key information:
1. Provide a high-level summary of what the code does
2. List key features and functionality
3. Identify imports and dependencies
4. List constants and global variables
5. Summarize classes with their methods
6. Summarize standalone functions
7. Describe the overall workflow if applicable
Be precise and technical while remaining clear and concise.
"#
// Detailed template for complex extraction (slower, higher quality).
// Selected by the Extract* functions when mode == "base".
template_string DetailedExtractContentGraphPrompt() #"
You are a top-tier algorithm designed for extracting information in structured formats to build a knowledge graph.
**Nodes** represent entities and concepts. They're akin to Wikipedia nodes.
**Edges** represent relationships between concepts. They're akin to Wikipedia links.
The aim is to achieve simplicity and clarity in the knowledge graph.
# 1. Labeling Nodes
**Consistency**: Ensure you use basic or elementary types for node labels.
- For example, when you identify an entity representing a person, always label it as **"Person"**.
- Avoid using more specific terms like "Mathematician" or "Scientist", keep those as "profession" property.
- Don't use too generic terms like "Entity".
**Node IDs**: Never utilize integers as node IDs.
- Node IDs should be names or human-readable identifiers found in the text.
# 2. Handling Numerical Data and Dates
- For example, when you identify an entity representing a date, make sure it has type **"Date"**.
- Extract the date in the format "YYYY-MM-DD"
- If not possible to extract the whole date, extract month or year, or both if available.
- **Property Format**: Properties must be in a key-value format.
- **Quotation Marks**: Never use escaped single or double quotes within property values.
- **Naming Convention**: Use snake_case for relationship names, e.g., `acted_in`.
# 3. Coreference Resolution
- **Maintain Entity Consistency**: When extracting entities, it's vital to ensure consistency.
If an entity, such as "John Doe", is mentioned multiple times in the text but is referred to by different names or pronouns (e.g., "Joe", "he"),
always use the most complete identifier for that entity throughout the knowledge graph. In this example, use "John Doe" as the Person's ID.
Remember, the knowledge graph should be coherent and easily understandable, so maintaining consistency in entity references is crucial.
# 4. Strict Compliance
Adhere to the rules strictly. Non-compliance will result in termination.
"#
// Guided template with step-by-step instructions.
// Selected by the Extract* functions when mode == "guided".
template_string GuidedExtractContentGraphPrompt() #"
You are an advanced algorithm designed to extract structured information to build a clean, consistent, and human-readable knowledge graph.
**Objective**:
- Nodes represent entities and concepts, similar to Wikipedia articles.
- Edges represent typed relationships between nodes, similar to Wikipedia hyperlinks.
- The graph must be clear, minimal, consistent, and semantically precise.
**Node Guidelines**:
1. **Label Consistency**:
- Use consistent, basic types for all node labels.
- Do not switch between granular or vague labels for the same kind of entity.
- Pick one label for each category and apply it uniformly.
- Each entity type should be in a singular form and in a case of multiple words separated by whitespaces
2. **Node Identifiers**:
- Node IDs must be human-readable and derived directly from the text.
- Prefer full names and canonical terms.
- Never use integers or autogenerated IDs.
- *Example*: Use "Marie Curie", "Theory of Evolution", "Google".
3. **Coreference Resolution**:
- Maintain one consistent node ID for each real-world entity.
- Resolve aliases, acronyms, and pronouns to the most complete form.
- *Example*: Always use "John Doe" even if later referred to as "Doe" or "he".
**Edge Guidelines**:
4. **Relationship Labels**:
- Use descriptive, lowercase, snake_case names for edges.
- *Example*: born_in, married_to, invented_by.
- Avoid vague or generic labels like isA, relatesTo, has.
5. **Relationship Direction**:
- Edges must be directional and logically consistent.
- *Example*:
- "Marie Curie" —[born_in]→ "Warsaw"
- "Radioactivity" —[discovered_by]→ "Marie Curie"
**Compliance**:
Strict adherence to these guidelines is required. Any deviation will result in immediate termination of the task.
"#
// Strict template with zero-tolerance rules.
// Selected by the Extract* functions when mode == "strict".
template_string StrictExtractContentGraphPrompt() #"
You are a top-tier algorithm for **extracting structured information** from unstructured text to build a **knowledge graph**.
Your primary goal is to extract:
- **Nodes**: Representing **entities** and **concepts** (like Wikipedia nodes).
- **Edges**: Representing **relationships** between those concepts (like Wikipedia links).
The resulting knowledge graph must be **simple, consistent, and human-readable**.
## 1. Node Labeling and Identification
### Node Types
Use **basic atomic types** for node labels. Always prefer general types over specific roles or professions:
- "Person" for any human.
- "Organization" for companies, institutions, etc.
- "Location" for geographic or place entities.
- "Date" for any temporal expression.
- "Event" for historical or scheduled occurrences.
- "Work" for books, films, artworks, or research papers.
- "Concept" for abstract notions or ideas.
### Node IDs
- Always assign **human-readable and unambiguous identifiers**.
- Never use numeric or autogenerated IDs.
- Prioritize **most complete form** of entity names for consistency.
## 2. Relationship Handling
- Use **snake_case** for all relationship (edge) types.
- Keep relationship types semantically clear and consistent.
- Avoid vague relation names like "related_to" unless no better alternative exists.
## 3. Strict Compliance
Follow all rules exactly. Any deviation may lead to rejection or incorrect graph construction.
"#
// OpenAI client with environment model selection.
// NOTE(review): `client_registry.model` / `client_registry.api_key` are not the
// usual `env.*` references; Python callers appear to override this client at
// runtime through a ClientRegistry — confirm these placeholders resolve as intended.
client<llm> OpenAI {
provider openai
options {
model client_registry.model
api_key client_registry.api_key
}
}
// Function that returns raw structured output (for custom objects - to be handled in Python).
// `mode` selects one of the prompt templates above; "custom" uses the caller-supplied
// `custom_prompt_content`, and any other/missing value falls back to the simple template.
function ExtractContentGraphGeneric(
content: string,
mode: "simple" | "base" | "guided" | "strict" | "custom"?,
custom_prompt_content: string?
) -> KnowledgeGraph {
client OpenAI
prompt #"
{% if mode == "base" %}
{{ DetailedExtractContentGraphPrompt() }}
{% elif mode == "guided" %}
{{ GuidedExtractContentGraphPrompt() }}
{% elif mode == "strict" %}
{{ StrictExtractContentGraphPrompt() }}
{% elif mode == "custom" and custom_prompt_content %}
{{ custom_prompt_content }}
{% else %}
{{ ExtractContentGraphPrompt() }}
{% endif %}
{{ ctx.output_format(prefix="Answer in this schema:\n") }}
Before answering, briefly describe what you'll extract from the text, then provide the structured output.
Example format:
I'll extract the main entities and their relationships from this text...
{ ... }
{{ _.role('user') }}
{{ content }}
"#
}
// Backward-compatible function specifically for KnowledgeGraph.
// NOTE(review): identical prompt to ExtractContentGraphGeneric, differing only in
// the DynamicKnowledgeGraph return type — consider sharing the prompt if BAML allows.
function ExtractDynamicContentGraph(
content: string,
mode: "simple" | "base" | "guided" | "strict" | "custom"?,
custom_prompt_content: string?
) -> DynamicKnowledgeGraph {
client OpenAI
prompt #"
{% if mode == "base" %}
{{ DetailedExtractContentGraphPrompt() }}
{% elif mode == "guided" %}
{{ GuidedExtractContentGraphPrompt() }}
{% elif mode == "strict" %}
{{ StrictExtractContentGraphPrompt() }}
{% elif mode == "custom" and custom_prompt_content %}
{{ custom_prompt_content }}
{% else %}
{{ ExtractContentGraphPrompt() }}
{% endif %}
{{ ctx.output_format(prefix="Answer in this schema:\n") }}
Before answering, briefly describe what you'll extract from the text, then provide the structured output.
Example format:
I'll extract the main entities and their relationships from this text...
{ ... }
{{ _.role('user') }}
{{ content }}
"#
}
// Summarization functions
// Generic text summarization returning SummarizedContent.
function SummarizeContent(content: string) -> SummarizedContent {
client OpenAI
prompt #"
{{ SummarizeContentPrompt() }}
{{ ctx.output_format(prefix="Answer in this schema:\n") }}
{{ _.role('user') }}
{{ content }}
"#
}
// Source-code summarization returning the richer SummarizedCode schema.
function SummarizeCode(content: string) -> SummarizedCode {
client OpenAI
prompt #"
{{ SummarizeCodePrompt() }}
{{ ctx.output_format(prefix="Answer in this schema:\n") }}
{{ _.role('user') }}
{{ content }}
"#
}
// Test for the "strict" extraction mode with a one-sentence factual input.
test ExtractStrictExample {
functions [ExtractContentGraphGeneric]
args {
content #"
The Python programming language was created by Guido van Rossum in 1991.
"#
mode "strict"
}
}

View file

@ -0,0 +1,2 @@
from .knowledge_graph.extract_content_graph import extract_content_graph
from .extract_summary import extract_summary, extract_code_summary

View file

@ -0,0 +1,114 @@
import os
from typing import Type
from pydantic import BaseModel
from baml_py import ClientRegistry
from cognee.shared.logging_utils import get_logger
from cognee.shared.data_models import SummarizedCode
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.async_client import b
from cognee.infrastructure.llm.structured_output_framework.baml_src.config import get_llm_config
config = get_llm_config()
logger = get_logger("extract_summary_baml")
def get_mock_summarized_code():
    """Build a deterministic SummarizedCode stub (module-local to avoid circular imports)."""
    mock_fields = {
        "high_level_summary": "Mock code summary",
        "key_features": ["Mock feature 1", "Mock feature 2"],
        "imports": ["mock_import"],
        "constants": ["MOCK_CONSTANT"],
        "classes": [],
        "functions": [],
        "workflow_description": "Mock workflow description",
    }
    return SummarizedCode(**mock_fields)
async def extract_summary(content: str, response_model: Type[BaseModel]):
    """
    Extract a summary of content using the BAML framework.

    Args:
        content: The content to summarize.
        response_model: The Pydantic model type for the response. When it is
            SummarizedCode, the code-specific BAML function is used.

    Returns:
        BaseModel: The summarized content in the requested format.
    """
    config = get_llm_config()

    # NOTE(review): provider is hard-coded to "openai" here while the graph
    # extractor uses config.llm_provider — confirm this is intentional.
    baml_registry = ClientRegistry()
    baml_registry.add_llm_client(
        name="def",
        provider="openai",
        options={
            "model": config.llm_model,
            "temperature": config.llm_temperature,
            "api_key": config.llm_api_key,
        },
    )
    baml_registry.set_primary("def")

    if response_model is SummarizedCode:
        # Code summaries use the dedicated BAML function with a richer schema.
        # Fix vs. original: pass the locally built registry — the LLM config
        # object has no `baml_registry` attribute, so the old call raised
        # AttributeError; it also wastefully ran SummarizeContent first and
        # discarded the result.
        return await b.SummarizeCode(content, baml_options={"client_registry": baml_registry})

    # For all other response models, use the generic content summarizer.
    return await b.SummarizeContent(content, baml_options={"client_registry": baml_registry})
async def extract_code_summary(content: str):
    """
    Extract code summary using BAML framework with mocking support.

    Args:
        content: The code content to summarize.

    Returns:
        SummarizedCode: The summarized code information. A mock object is
        returned when the MOCK_CODE_SUMMARY env var is enabled or when the
        BAML call fails (best-effort fallback, logged).
    """
    # os.getenv always returns a string (or the str default), so parse it
    # directly; the original `isinstance(enable_mocking, bool)` branch was dead code.
    enable_mocking = os.getenv("MOCK_CODE_SUMMARY", "false").strip().lower() in (
        "true",
        "1",
        "yes",
    )

    if enable_mocking:
        return get_mock_summarized_code()

    try:
        config = get_llm_config()

        baml_registry = ClientRegistry()
        baml_registry.add_llm_client(
            name="def",
            provider="openai",
            options={
                "model": config.llm_model,
                "temperature": config.llm_temperature,
                "api_key": config.llm_api_key,
            },
        )
        baml_registry.set_primary("def")

        return await b.SummarizeCode(content, baml_options={"client_registry": baml_registry})
    except Exception as e:
        # Best-effort fallback: keep the pipeline running on LLM failures.
        logger.error(
            "Failed to extract code summary with BAML, falling back to mock summary", exc_info=e
        )
        return get_mock_summarized_code()

View file

@ -0,0 +1,49 @@
from baml_py import ClientRegistry
from typing import Type
from pydantic import BaseModel
from cognee.infrastructure.llm.config import get_llm_config
from cognee.shared.logging_utils import get_logger, setup_logging
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.async_client import b
config = get_llm_config()
async def extract_content_graph(
    content: str, response_model: Type[BaseModel], mode: str = "simple"
):
    """
    Extract a knowledge graph from content using the BAML framework.

    Args:
        content: Text to extract entities and relationships from.
        response_model: Caller-supplied Pydantic model. Currently unused — the
            generic BAML KnowledgeGraph schema is always returned; kept for
            interface compatibility with the instructor-based extractor.
        mode: Prompt variant — "simple", "base", "guided", "strict" or "custom".

    Returns:
        The knowledge graph produced by BAML's ExtractContentGraphGeneric.
    """
    # Fix vs. original: removed the per-call setup_logging()/get_logger() calls
    # (they reconfigured global logging on every extraction) and the large block
    # of commented-out dynamic-schema experiments.
    config = get_llm_config()

    baml_registry = ClientRegistry()
    baml_registry.add_llm_client(
        name="extract_content_client",
        provider=config.llm_provider,
        options={
            "model": config.llm_model,
            "temperature": config.llm_temperature,
            "api_key": config.llm_api_key,
        },
    )
    baml_registry.set_primary("extract_content_client")

    graph = await b.ExtractContentGraphGeneric(
        content, mode=mode, baml_options={"client_registry": baml_registry}
    )

    return graph

View file

@ -0,0 +1,18 @@
// This helps use auto generate libraries you can use in the language of
// your choice. You can have multiple generators if you use multiple languages.
// Just ensure that the output_dir is different for each generator.
generator target {
// Valid values: "python/pydantic", "typescript", "ruby/sorbet", "rest/openapi"
output_type "python/pydantic"
// Where the generated code will be saved (relative to baml_src/)
output_dir "../baml/"
// The version of the BAML package you have installed (e.g. same version as your baml-py or @boundaryml/baml).
// The BAML VSCode extension version should also match this version.
version "0.201.0"
// Valid values: "sync", "async"
// This controls what `b.FunctionName()` will be (sync or async).
// NOTE(review): Python callers in this commit import from baml_client.async_client
// explicitly, while the default here is sync — confirm this is the intended default.
default_client_mode sync
}

View file

@ -1,18 +1,12 @@
from typing import Type
from pydantic import BaseModel
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import (
read_query_prompt,
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
get_llm_client,
)
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
async def extract_categories(content: str, response_model: Type[BaseModel]):
llm_client = get_llm_client()
system_prompt = LLMAdapter.read_query_prompt("classify_content.txt")
system_prompt = read_query_prompt("classify_content.txt")
llm_output = await llm_client.acreate_structured_output(content, system_prompt, response_model)
llm_output = await LLMAdapter.acreate_structured_output(content, system_prompt, response_model)
return llm_output

View file

@ -5,12 +5,7 @@ from typing import Type
from instructor.exceptions import InstructorRetryException
from pydantic import BaseModel
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
get_llm_client,
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import (
read_query_prompt,
)
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
from cognee.shared.data_models import SummarizedCode
logger = get_logger("extract_summary")
@ -30,11 +25,9 @@ def get_mock_summarized_code():
async def extract_summary(content: str, response_model: Type[BaseModel]):
llm_client = get_llm_client()
system_prompt = LLMAdapter.read_query_prompt("summarize_content.txt")
system_prompt = read_query_prompt("summarize_content.txt")
llm_output = await llm_client.acreate_structured_output(content, system_prompt, response_model)
llm_output = await LLMAdapter.acreate_structured_output(content, system_prompt, response_model)
return llm_output

View file

@ -1,19 +1,14 @@
import os
from typing import Type
from pydantic import BaseModel
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
get_llm_client,
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import (
render_prompt,
)
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
from cognee.infrastructure.llm.config import (
get_llm_config,
)
async def extract_content_graph(content: str, response_model: Type[BaseModel]):
llm_client = get_llm_client()
llm_config = get_llm_config()
prompt_path = llm_config.graph_prompt_path
@ -27,9 +22,9 @@ async def extract_content_graph(content: str, response_model: Type[BaseModel]):
else:
base_directory = None
system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory)
system_prompt = LLMAdapter.render_prompt(prompt_path, {}, base_directory=base_directory)
content_graph = await llm_client.acreate_structured_output(
content_graph = await LLMAdapter.acreate_structured_output(
content, system_prompt, response_model
)

View file

@ -6,14 +6,13 @@ from cognee.exceptions import InvalidValueError
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.llm_interface import (
LLMInterface,
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import (
read_query_prompt,
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.rate_limiter import (
rate_limit_async,
sleep_and_retry_async,
)
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
class AnthropicAdapter(LLMInterface):
"""
@ -92,7 +91,7 @@ class AnthropicAdapter(LLMInterface):
if not system_prompt:
raise InvalidValueError(message="No system prompt path provided.")
system_prompt = read_query_prompt(system_prompt)
system_prompt = LLMAdapter.read_query_prompt(system_prompt)
formatted_prompt = (
f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""

View file

@ -9,9 +9,7 @@ from cognee.exceptions import InvalidValueError
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.llm_interface import (
LLMInterface,
)
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import (
read_query_prompt,
)
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.rate_limiter import (
rate_limit_async,
sleep_and_retry_async,
@ -138,7 +136,7 @@ class GeminiAdapter(LLMInterface):
text_input = "No user input provided."
if not system_prompt:
raise InvalidValueError(message="No system prompt path provided.")
system_prompt = read_query_prompt(system_prompt)
system_prompt = LLMAdapter.read_query_prompt(system_prompt)
formatted_prompt = (
f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""

View file

@ -1,6 +1,5 @@
"""Adapter for Generic API LLM provider API"""
import logging
import litellm
import instructor
from typing import Type

View file

@ -3,9 +3,7 @@
from typing import Type, Protocol
from abc import abstractmethod
from pydantic import BaseModel
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import (
read_query_prompt,
)
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
class LLMInterface(Protocol):
@ -59,7 +57,7 @@ class LLMInterface(Protocol):
text_input = "No user input provided."
if not system_prompt:
raise ValueError("No system prompt path provided.")
system_prompt = read_query_prompt(system_prompt)
system_prompt = LLMAdapter.read_query_prompt(system_prompt)
formatted_prompt = f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""

View file

@ -8,9 +8,7 @@ from litellm.exceptions import ContentPolicyViolationError
from instructor.exceptions import InstructorRetryException
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import (
read_query_prompt,
)
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.llm_interface import (
LLMInterface,
)
@ -328,7 +326,7 @@ class OpenAIAdapter(LLMInterface):
text_input = "No user input provided."
if not system_prompt:
raise InvalidValueError(message="No system prompt path provided.")
system_prompt = read_query_prompt(system_prompt)
system_prompt = LLMAdapter.read_query_prompt(system_prompt)
formatted_prompt = (
f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""

View file

@ -1,2 +0,0 @@
The question is: `{{ question }}`
And here is the context: `{{ context }}`

View file

@ -0,0 +1,107 @@
import litellm
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
get_llm_client,
)
from cognee.shared.logging_utils import get_logger
logger = get_logger()
def get_max_chunk_tokens():
    """
    Calculate the maximum number of tokens allowed in a chunk.

    The chunk budget is capped both by the embedding engine's token limit and
    by half of the LLM's maximum context size, whichever is smaller.

    Returns:
    --------
        - int: The maximum number of tokens a chunk may contain.
    """
    # NOTE: Import must be done in function to avoid circular import issue
    from cognee.infrastructure.databases.vector import get_vector_engine

    embedding_engine = get_vector_engine().embedding_engine
    llm_client = get_llm_client()

    # A chunk may consume at most half of the LLM context window (floor
    # division), and never more than the embedding engine can accept.
    half_llm_context = llm_client.max_tokens // 2
    return min(embedding_engine.max_tokens, half_llm_context)
def get_model_max_tokens(model_name: str):
    """
    Retrieve the maximum token limit for a specified model name if it exists.

    Looks the model up in LiteLLM's model cost registry. If found, logs and
    returns its max token count; otherwise logs which model was missing and
    returns None.

    Parameters:
    -----------
        - model_name (str): Name of LLM or embedding model

    Returns:
    --------
        Number of max tokens of model, or None if model is unknown
    """
    model_info = litellm.model_cost.get(model_name)

    if model_info is None:
        # Fix vs. original: include the model name so the log is actionable.
        logger.info(f"Model {model_name} not found in LiteLLM's model_cost.")
        return None

    # Use .get so a registry entry without a "max_tokens" key yields None
    # instead of raising KeyError.
    max_tokens = model_info.get("max_tokens")
    logger.debug(f"Max input tokens for {model_name}: {max_tokens}")
    return max_tokens
async def test_llm_connection():
    """
    Verify that a structured-output round trip to the configured LLM works.

    Sends a trivial prompt through the adapter returned by get_llm_client();
    any failure is logged and the original exception is re-raised so callers
    can react to an unreachable LLM.
    """
    try:
        adapter = get_llm_client()
        await adapter.acreate_structured_output(
            text_input="test",
            system_prompt='Respond to me with the following string: "test"',
            response_model=str,
        )
    except Exception as error:
        logger.error(error)
        logger.error("Connection to LLM could not be established.")
        raise error
async def test_embedding_connection():
    """
    Verify that the configured embedding engine can embed a sample text.

    Any failure is logged and the original exception is re-raised so callers
    know the embedding backend is unreachable.
    """
    try:
        # NOTE: Vector engine import must be done in function to avoid circular import issue
        from cognee.infrastructure.databases.vector import get_vector_engine

        embedding_engine = get_vector_engine().embedding_engine
        await embedding_engine.embed_text("test")
    except Exception as error:
        logger.error(error)
        logger.error("Connection to Embedding handler could not be established.")
        raise error

20
poetry.lock generated
View file

@ -616,6 +616,24 @@ files = [
[package.extras]
extras = ["regex"]
[[package]]
name = "baml-py"
version = "0.201.0"
description = "BAML python bindings (pyproject.toml)"
optional = false
python-versions = "*"
groups = ["main"]
files = [
{file = "baml_py-0.201.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:83228d2af2b0e845bbbb4e14f7cbd3376cec385aee01210ac522ab6076e07bec"},
{file = "baml_py-0.201.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:2a9d016139e3ae5b5ce98c7b05b5fbd53d5d38f04dc810ec4d70fb17dd6c10e4"},
{file = "baml_py-0.201.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b5058505b1a3c5f04fc1679aec4d730fa9bef2cbd96209b3ed50152f60b96baf"},
{file = "baml_py-0.201.0-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:36289d548581ba4accd5eaaab3246872542dd32dc6717e537654fa0cad884071"},
{file = "baml_py-0.201.0-cp38-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:5ab70e7bd6481d71edca8a33313347b29faccec78b9960138aa437522813ac9a"},
{file = "baml_py-0.201.0-cp38-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:7efc5c693a7142c230a4f3d6700415127fee0b9f5fdbb36db63e04e27ac4c0f1"},
{file = "baml_py-0.201.0-cp38-abi3-win_amd64.whl", hash = "sha256:56499857b7a27ae61a661c8ce0dddd0fb567a45c0b826157e44048a14cf586f9"},
{file = "baml_py-0.201.0-cp38-abi3-win_arm64.whl", hash = "sha256:1e52dc1151db84a302b746590fe2bc484bdd794f83fa5da7216d9394c559f33a"},
]
[[package]]
name = "banks"
version = "2.2.0"
@ -12235,4 +12253,4 @@ qdrant = ["qdrant-client"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<=3.13"
content-hash = "d15e6b5d065016613be0b8c015cccf85e7f63891353c97636d136d14e5c8f62e"
content-hash = "111afa2da80fc4baf4cfd72c8f8c782aec2cd0e4435b3e4a3b77e2536eaa37aa"

17
uv.lock generated
View file

@ -443,6 +443,21 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/41/ff/392bff89415399a979be4a65357a41d92729ae8580a66073d8ec8d810f98/backrefs-5.9-py39-none-any.whl", hash = "sha256:f48ee18f6252b8f5777a22a00a09a85de0ca931658f1dd96d4406a34f3748c60", size = 380265, upload-time = "2025-06-22T19:34:12.405Z" },
]
[[package]]
name = "baml-py"
version = "0.201.0"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/54/54/2b0edb3d22e95ce56f36610391c11108a4ef26ba2837736a32001687ae34/baml_py-0.201.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:83228d2af2b0e845bbbb4e14f7cbd3376cec385aee01210ac522ab6076e07bec", size = 17387971, upload-time = "2025-07-03T19:29:05.844Z" },
{ url = "https://files.pythonhosted.org/packages/c9/08/1d48c28c63eadea2c04360cbb7f64968599e99cd6b8fc0ec0bd4424d3cf1/baml_py-0.201.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:2a9d016139e3ae5b5ce98c7b05b5fbd53d5d38f04dc810ec4d70fb17dd6c10e4", size = 16191010, upload-time = "2025-07-03T19:29:09.323Z" },
{ url = "https://files.pythonhosted.org/packages/73/1a/20b2d46501e3dd0648af339825106a6ac5eeb5d22d7e6a10cf16b9aa1cb8/baml_py-0.201.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b5058505b1a3c5f04fc1679aec4d730fa9bef2cbd96209b3ed50152f60b96baf", size = 19950249, upload-time = "2025-07-03T19:29:11.974Z" },
{ url = "https://files.pythonhosted.org/packages/38/24/bc871059e905159ae1913c2e3032dd6ef2f5c3d0983999d2c2f1eebb65a4/baml_py-0.201.0-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:36289d548581ba4accd5eaaab3246872542dd32dc6717e537654fa0cad884071", size = 19231310, upload-time = "2025-07-03T19:29:14.857Z" },
{ url = "https://files.pythonhosted.org/packages/0e/11/4268a0b82b02c7202fe5aa0d7175712158d998c491cac723b2bac3d5d495/baml_py-0.201.0-cp38-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:5ab70e7bd6481d71edca8a33313347b29faccec78b9960138aa437522813ac9a", size = 19490012, upload-time = "2025-07-03T19:29:18.512Z" },
{ url = "https://files.pythonhosted.org/packages/31/21/c9f9aea1adba2a5978ffab11ba0948a9f3f81ec6ed3056067713260e93a1/baml_py-0.201.0-cp38-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:7efc5c693a7142c230a4f3d6700415127fee0b9f5fdbb36db63e04e27ac4c0f1", size = 20090620, upload-time = "2025-07-03T19:29:21.072Z" },
{ url = "https://files.pythonhosted.org/packages/99/cf/92123d8d753f1d1473e080c4c182139bfe3b9a6418e891cf1d96b6c33848/baml_py-0.201.0-cp38-abi3-win_amd64.whl", hash = "sha256:56499857b7a27ae61a661c8ce0dddd0fb567a45c0b826157e44048a14cf586f9", size = 17253005, upload-time = "2025-07-03T19:29:23.722Z" },
{ url = "https://files.pythonhosted.org/packages/59/88/5056aa1bc9480f758cd6e210d63bd1f9ad90b44c87f4121285906526495e/baml_py-0.201.0-cp38-abi3-win_arm64.whl", hash = "sha256:1e52dc1151db84a302b746590fe2bc484bdd794f83fa5da7216d9394c559f33a", size = 15612701, upload-time = "2025-07-03T19:29:26.712Z" },
]
[[package]]
name = "banks"
version = "2.2.0"
@ -864,6 +879,7 @@ dependencies = [
{ name = "aiohttp" },
{ name = "aiosqlite" },
{ name = "alembic" },
{ name = "baml-py" },
{ name = "dlt", extra = ["sqlalchemy"] },
{ name = "fastapi" },
{ name = "fastapi-users", extra = ["sqlalchemy"] },
@ -1020,6 +1036,7 @@ requires-dist = [
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.26.1,<0.27" },
{ name = "asyncpg", marker = "extra == 'postgres'", specifier = ">=0.30.0,<1.0.0" },
{ name = "asyncpg", marker = "extra == 'postgres-binary'", specifier = ">=0.30.0,<1.0.0" },
{ name = "baml-py", specifier = ">=0.201.0,<0.202.0" },
{ name = "chromadb", marker = "extra == 'chromadb'", specifier = ">=0.3.0,<0.7" },
{ name = "coverage", marker = "extra == 'dev'", specifier = ">=7.3.2,<8" },
{ name = "debugpy", marker = "extra == 'debug'", specifier = ">=1.8.9,<2.0.0" },