refactor: move prompts to LLM folder
This commit is contained in:
parent
7761b70229
commit
cb21e08ad9
65 changed files with 1691 additions and 64 deletions
|
|
@ -0,0 +1,548 @@
|
|||
import threading
|
||||
import logging
|
||||
import functools
|
||||
import os
|
||||
import time
|
||||
import asyncio
|
||||
import random
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.llm.config import get_llm_config
|
||||
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
# Common error patterns that indicate rate limiting
|
||||
RATE_LIMIT_ERROR_PATTERNS = [
|
||||
"rate limit",
|
||||
"rate_limit",
|
||||
"ratelimit",
|
||||
"too many requests",
|
||||
"retry after",
|
||||
"capacity",
|
||||
"quota",
|
||||
"limit exceeded",
|
||||
"tps limit exceeded",
|
||||
"request limit exceeded",
|
||||
"maximum requests",
|
||||
"exceeded your current quota",
|
||||
"throttled",
|
||||
"throttling",
|
||||
]
|
||||
|
||||
# Default retry settings
|
||||
DEFAULT_MAX_RETRIES = 5
|
||||
DEFAULT_INITIAL_BACKOFF = 1.0 # seconds
|
||||
DEFAULT_BACKOFF_FACTOR = 2.0 # exponential backoff multiplier
|
||||
DEFAULT_JITTER = 0.1 # 10% jitter to avoid thundering herd
|
||||
|
||||
|
||||
class EmbeddingRateLimiter:
|
||||
"""
|
||||
Rate limiter for embedding API calls.
|
||||
|
||||
This class implements a singleton pattern to ensure that rate limiting
|
||||
is consistent across all embedding requests. It uses the limits
|
||||
library with a moving window strategy to control request rates.
|
||||
|
||||
The rate limiter uses the same configuration as the LLM API rate limiter
|
||||
but uses a separate key to track embedding API calls independently.
|
||||
|
||||
Public Methods:
|
||||
- get_instance
|
||||
- reset_instance
|
||||
- hit_limit
|
||||
- wait_if_needed
|
||||
- async_wait_if_needed
|
||||
|
||||
Instance Variables:
|
||||
- enabled
|
||||
- requests_limit
|
||||
- interval_seconds
|
||||
- request_times
|
||||
- lock
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
lock = threading.Lock()
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
"""
|
||||
Retrieve the singleton instance of the EmbeddingRateLimiter.
|
||||
|
||||
This method ensures that only one instance of the class exists and
|
||||
is thread-safe. It lazily initializes the instance if it doesn't
|
||||
already exist.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
The singleton instance of the EmbeddingRateLimiter class.
|
||||
"""
|
||||
if cls._instance is None:
|
||||
with cls.lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def reset_instance(cls):
|
||||
"""
|
||||
Reset the singleton instance of the EmbeddingRateLimiter.
|
||||
|
||||
This method is thread-safe and sets the instance to None, allowing
|
||||
for a new instance to be created when requested again.
|
||||
"""
|
||||
with cls.lock:
|
||||
cls._instance = None
|
||||
|
||||
def __init__(self):
|
||||
config = get_llm_config()
|
||||
self.enabled = config.embedding_rate_limit_enabled
|
||||
self.requests_limit = config.embedding_rate_limit_requests
|
||||
self.interval_seconds = config.embedding_rate_limit_interval
|
||||
self.request_times = []
|
||||
self.lock = threading.Lock()
|
||||
|
||||
logging.info(
|
||||
f"EmbeddingRateLimiter initialized: enabled={self.enabled}, "
|
||||
f"requests_limit={self.requests_limit}, interval_seconds={self.interval_seconds}"
|
||||
)
|
||||
|
||||
def hit_limit(self) -> bool:
|
||||
"""
|
||||
Check if the current request would exceed the rate limit.
|
||||
|
||||
This method checks if the rate limiter is enabled and evaluates
|
||||
the number of requests made in the elapsed interval.
|
||||
|
||||
Returns:
|
||||
- bool: True if the rate limit would be exceeded, False otherwise.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- bool: True if the rate limit would be exceeded, otherwise False.
|
||||
"""
|
||||
if not self.enabled:
|
||||
return False
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
with self.lock:
|
||||
# Remove expired request times
|
||||
cutoff_time = current_time - self.interval_seconds
|
||||
self.request_times = [t for t in self.request_times if t > cutoff_time]
|
||||
|
||||
# Check if adding a new request would exceed the limit
|
||||
if len(self.request_times) >= self.requests_limit:
|
||||
logger.info(
|
||||
f"Rate limit hit: {len(self.request_times)} requests in the last {self.interval_seconds} seconds"
|
||||
)
|
||||
return True
|
||||
|
||||
# Otherwise, we're under the limit
|
||||
return False
|
||||
|
||||
def wait_if_needed(self) -> float:
|
||||
"""
|
||||
Block until a request can be made without exceeding the rate limit.
|
||||
|
||||
This method will wait if the current request would exceed the
|
||||
rate limit and returns the time waited in seconds.
|
||||
|
||||
Returns:
|
||||
- float: Time waited in seconds before a request is allowed.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- float: Time waited in seconds before proceeding.
|
||||
"""
|
||||
if not self.enabled:
|
||||
return 0
|
||||
|
||||
wait_time = 0
|
||||
start_time = time.time()
|
||||
|
||||
while self.hit_limit():
|
||||
time.sleep(0.5) # Poll every 0.5 seconds
|
||||
wait_time = time.time() - start_time
|
||||
|
||||
# Record this request
|
||||
with self.lock:
|
||||
self.request_times.append(time.time())
|
||||
|
||||
return wait_time
|
||||
|
||||
async def async_wait_if_needed(self) -> float:
|
||||
"""
|
||||
Asynchronously wait until a request can be made without exceeding the rate limit.
|
||||
|
||||
This method will wait if the current request would exceed the
|
||||
rate limit and returns the time waited in seconds.
|
||||
|
||||
Returns:
|
||||
- float: Time waited in seconds before a request is allowed.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- float: Time waited in seconds before proceeding.
|
||||
"""
|
||||
if not self.enabled:
|
||||
return 0
|
||||
|
||||
wait_time = 0
|
||||
start_time = time.time()
|
||||
|
||||
while self.hit_limit():
|
||||
await asyncio.sleep(0.5) # Poll every 0.5 seconds
|
||||
wait_time = time.time() - start_time
|
||||
|
||||
# Record this request
|
||||
with self.lock:
|
||||
self.request_times.append(time.time())
|
||||
|
||||
return wait_time
|
||||
|
||||
|
||||
def embedding_rate_limit_sync(func):
|
||||
"""
|
||||
Apply rate limiting to a synchronous embedding function.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- func: Function to decorate with rate limiting logic.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
Returns the decorated function that applies rate limiting.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
"""
|
||||
Wrap the given function with rate limiting logic to control the embedding API usage.
|
||||
|
||||
Checks if the rate limit has been exceeded before allowing the function to execute. If
|
||||
the limit is hit, it logs a warning and raises an EmbeddingException. Otherwise, it
|
||||
updates the request count and proceeds to call the original function.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- *args: Variable length argument list for the wrapped function.
|
||||
- **kwargs: Keyword arguments for the wrapped function.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
Returns the result of the wrapped function if rate limiting conditions are met.
|
||||
"""
|
||||
limiter = EmbeddingRateLimiter.get_instance()
|
||||
|
||||
# Check if rate limiting is enabled and if we're at the limit
|
||||
if limiter.hit_limit():
|
||||
error_msg = "Embedding API rate limit exceeded"
|
||||
logger.warning(error_msg)
|
||||
|
||||
# Create a custom embedding rate limit exception
|
||||
from cognee.infrastructure.databases.exceptions.EmbeddingException import (
|
||||
EmbeddingException,
|
||||
)
|
||||
|
||||
raise EmbeddingException(error_msg)
|
||||
|
||||
# Add this request to the counter and proceed
|
||||
limiter.wait_if_needed()
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def embedding_rate_limit_async(func):
|
||||
"""
|
||||
Decorator that applies rate limiting to an asynchronous embedding function.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- func: Async function to decorate.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
Returns the decorated async function that applies rate limiting.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
"""
|
||||
Handle function calls with embedding rate limiting.
|
||||
|
||||
This asynchronous wrapper checks if the embedding API rate limit is exceeded before
|
||||
allowing the function to execute. If the limit is exceeded, it logs a warning and raises
|
||||
an EmbeddingException. If not, it waits as necessary and proceeds with the function
|
||||
call.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- *args: Positional arguments passed to the wrapped function.
|
||||
- **kwargs: Keyword arguments passed to the wrapped function.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
Returns the result of the wrapped function after handling rate limiting.
|
||||
"""
|
||||
limiter = EmbeddingRateLimiter.get_instance()
|
||||
|
||||
# Check if rate limiting is enabled and if we're at the limit
|
||||
if limiter.hit_limit():
|
||||
error_msg = "Embedding API rate limit exceeded"
|
||||
logger.warning(error_msg)
|
||||
|
||||
# Create a custom embedding rate limit exception
|
||||
from cognee.infrastructure.databases.exceptions.EmbeddingException import (
|
||||
EmbeddingException,
|
||||
)
|
||||
|
||||
raise EmbeddingException(error_msg)
|
||||
|
||||
# Add this request to the counter and proceed
|
||||
await limiter.async_wait_if_needed()
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def embedding_sleep_and_retry_sync(max_retries=5, base_backoff=1.0, jitter=0.5):
|
||||
"""
|
||||
Add retry with exponential backoff for synchronous embedding functions.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- max_retries: Maximum number of retries before giving up. (default 5)
|
||||
- base_backoff: Base backoff time in seconds for retry intervals. (default 1.0)
|
||||
- jitter: Jitter factor to randomize the backoff time to avoid collision. (default
|
||||
0.5)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
A decorator that retries the wrapped function on rate limit errors, applying
|
||||
exponential backoff with jitter.
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
"""
|
||||
Wraps a function to apply retry logic on rate limit errors.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- func: The function to be wrapped with retry logic.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
Returns the wrapped function with retry logic applied.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
"""
|
||||
Retry the execution of a function with backoff on failure due to rate limit errors.
|
||||
|
||||
This wrapper function will call the specified function and if it raises an exception, it
|
||||
will handle retries according to defined conditions. It will check the environment for a
|
||||
DISABLE_RETRIES flag to determine whether to retry or propagate errors immediately
|
||||
during tests. If the error is identified as a rate limit error, it will apply an
|
||||
exponential backoff strategy with jitter before retrying, up to a maximum number of
|
||||
retries. If the retries are exhausted, it raises the last encountered error.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- *args: Positional arguments passed to the wrapped function.
|
||||
- **kwargs: Keyword arguments passed to the wrapped function.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
Returns the result of the wrapped function if successful; otherwise, raises the last
|
||||
error encountered after maximum retries are exhausted.
|
||||
"""
|
||||
# If DISABLE_RETRIES is set, don't retry for testing purposes
|
||||
disable_retries = os.environ.get("DISABLE_RETRIES", "false").lower() in (
|
||||
"true",
|
||||
"1",
|
||||
"yes",
|
||||
)
|
||||
|
||||
retries = 0
|
||||
last_error = None
|
||||
|
||||
while retries <= max_retries:
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
# Check if this is a rate limit error
|
||||
error_str = str(e).lower()
|
||||
error_type = type(e).__name__
|
||||
is_rate_limit = any(
|
||||
pattern in error_str.lower() for pattern in RATE_LIMIT_ERROR_PATTERNS
|
||||
)
|
||||
|
||||
if disable_retries:
|
||||
# For testing, propagate the exception immediately
|
||||
raise
|
||||
|
||||
if is_rate_limit and retries < max_retries:
|
||||
# Calculate backoff with jitter
|
||||
backoff = (
|
||||
base_backoff * (2**retries) * (1 + random.uniform(-jitter, jitter))
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
f"Embedding rate limit hit, retrying in {backoff:.2f}s "
|
||||
f"(attempt {retries + 1}/{max_retries}): "
|
||||
f"({error_str!r}, {error_type!r})"
|
||||
)
|
||||
|
||||
time.sleep(backoff)
|
||||
retries += 1
|
||||
last_error = e
|
||||
else:
|
||||
# Not a rate limit error or max retries reached, raise
|
||||
raise
|
||||
|
||||
# If we exit the loop due to max retries, raise the last error
|
||||
if last_error:
|
||||
raise last_error
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def embedding_sleep_and_retry_async(max_retries=5, base_backoff=1.0, jitter=0.5):
|
||||
"""
|
||||
Add retry logic with exponential backoff for asynchronous embedding functions.
|
||||
|
||||
This decorator retries the wrapped asynchronous function upon encountering rate limit
|
||||
errors, utilizing exponential backoff with optional jitter to space out retry attempts.
|
||||
It allows for a maximum number of retries before giving up and raising the last error
|
||||
encountered.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- max_retries: Maximum number of retries allowed before giving up. (default 5)
|
||||
- base_backoff: Base amount of time in seconds to wait before retrying after a rate
|
||||
limit error. (default 1.0)
|
||||
- jitter: Amount of randomness to add to the backoff duration to help mitigate burst
|
||||
issues on retries. (default 0.5)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
Returns a decorated asynchronous function that implements the retry logic on rate
|
||||
limit errors.
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
"""
|
||||
Handle retries for an async function with exponential backoff and jitter.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- func: An asynchronous function to be wrapped with retry logic.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
Returns the wrapper function that manages the retry behavior for the wrapped async
|
||||
function.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
"""
|
||||
Handle retries for an async function with exponential backoff and jitter.
|
||||
|
||||
If the environment variable DISABLE_RETRIES is set to true, 1, or yes, the function will
|
||||
not retry on errors.
|
||||
It attempts to call the wrapped function until it succeeds or the maximum number of
|
||||
retries is reached. If an exception occurs, it checks if it's a rate limit error to
|
||||
determine if a retry is needed.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- *args: Positional arguments passed to the wrapped function.
|
||||
- **kwargs: Keyword arguments passed to the wrapped function.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
Returns the result of the wrapped async function if successful; raises the last
|
||||
encountered error if all retries fail.
|
||||
"""
|
||||
# If DISABLE_RETRIES is set, don't retry for testing purposes
|
||||
disable_retries = os.environ.get("DISABLE_RETRIES", "false").lower() in (
|
||||
"true",
|
||||
"1",
|
||||
"yes",
|
||||
)
|
||||
|
||||
retries = 0
|
||||
last_error = None
|
||||
|
||||
while retries <= max_retries:
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
# Check if this is a rate limit error
|
||||
error_str = str(e).lower()
|
||||
error_type = type(e).__name__
|
||||
is_rate_limit = any(
|
||||
pattern in error_str.lower() for pattern in RATE_LIMIT_ERROR_PATTERNS
|
||||
)
|
||||
|
||||
if disable_retries:
|
||||
# For testing, propagate the exception immediately
|
||||
raise
|
||||
|
||||
if is_rate_limit and retries < max_retries:
|
||||
# Calculate backoff with jitter
|
||||
backoff = (
|
||||
base_backoff * (2**retries) * (1 + random.uniform(-jitter, jitter))
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
f"Embedding rate limit hit, retrying in {backoff:.2f}s "
|
||||
f"(attempt {retries + 1}/{max_retries}): "
|
||||
f"({error_str!r}, {error_type!r})"
|
||||
)
|
||||
|
||||
await asyncio.sleep(backoff)
|
||||
retries += 1
|
||||
last_error = e
|
||||
else:
|
||||
# Not a rate limit error or max retries reached, raise
|
||||
raise
|
||||
|
||||
# If we exit the loop due to max retries, raise the last error
|
||||
if last_error:
|
||||
raise last_error
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
136
cognee/infrastructure/llm/LLMAdapter.py
Normal file
136
cognee/infrastructure/llm/LLMAdapter.py
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
from typing import Type
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cognee.infrastructure.llm import get_llm_config
|
||||
|
||||
|
||||
# TODO: Check if Coroutines should be returned or awaited result values
|
||||
class LLMAdapter:
|
||||
"""
|
||||
Class handles selection of structured output frameworks and LLM functions.
|
||||
Class used as a namespace for LLM related functions, should not be instantiated, all methods are static.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def render_prompt(filename: str, context: dict, base_directory: str = None):
|
||||
from cognee.infrastructure.llm.prompts import render_prompt
|
||||
|
||||
return render_prompt(filename=filename, context=context, base_directory=base_directory)
|
||||
|
||||
@staticmethod
|
||||
def acreate_structured_output(
|
||||
text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||
) -> BaseModel:
|
||||
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
|
||||
get_llm_client,
|
||||
)
|
||||
|
||||
llm_client = get_llm_client()
|
||||
return llm_client.acreate_structured_output(
|
||||
text_input=text_input, system_prompt=system_prompt, response_model=response_model
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_structured_output(
|
||||
text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||
) -> BaseModel:
|
||||
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
|
||||
get_llm_client,
|
||||
)
|
||||
|
||||
llm_client = get_llm_client()
|
||||
return llm_client.create_structured_output(
|
||||
text_input=text_input, system_prompt=system_prompt, response_model=response_model
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_transcript(input):
|
||||
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
|
||||
get_llm_client,
|
||||
)
|
||||
|
||||
llm_client = get_llm_client()
|
||||
return llm_client.create_transcript(input=input)
|
||||
|
||||
@staticmethod
|
||||
def transcribe_image(input):
|
||||
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
|
||||
get_llm_client,
|
||||
)
|
||||
|
||||
llm_client = get_llm_client()
|
||||
return llm_client.transcribe_image(input=input)
|
||||
|
||||
@staticmethod
|
||||
def show_prompt(text_input: str, system_prompt: str) -> str:
|
||||
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
|
||||
get_llm_client,
|
||||
)
|
||||
|
||||
llm_client = get_llm_client()
|
||||
return llm_client.show_prompt(text_input=text_input, system_prompt=system_prompt)
|
||||
|
||||
@staticmethod
|
||||
def read_query_prompt(prompt_file_name: str, base_directory: str = None):
|
||||
from cognee.infrastructure.llm.prompts import (
|
||||
read_query_prompt,
|
||||
)
|
||||
|
||||
return read_query_prompt(prompt_file_name=prompt_file_name, base_directory=base_directory)
|
||||
|
||||
@staticmethod
|
||||
def extract_content_graph(content: str, response_model: Type[BaseModel], mode: str = "simple"):
|
||||
llm_config = get_llm_config()
|
||||
if llm_config.structured_output_framework.upper() == "BAML":
|
||||
from cognee.infrastructure.llm.structured_output_framework.baml.baml_src.extraction import (
|
||||
extract_content_graph,
|
||||
)
|
||||
|
||||
return extract_content_graph(content=content, response_model=response_model, mode=mode)
|
||||
else:
|
||||
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.extraction import (
|
||||
extract_content_graph,
|
||||
)
|
||||
|
||||
return extract_content_graph(content=content, response_model=response_model)
|
||||
|
||||
@staticmethod
|
||||
def extract_categories(content: str, response_model: Type[BaseModel]):
|
||||
# TODO: Add BAML version of category and extraction and update function
|
||||
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.extraction import (
|
||||
extract_categories,
|
||||
)
|
||||
|
||||
return extract_categories(content=content, response_model=response_model)
|
||||
|
||||
@staticmethod
|
||||
def extract_code_summary(content: str):
|
||||
llm_config = get_llm_config()
|
||||
if llm_config.structured_output_framework == "BAML":
|
||||
from cognee.infrastructure.llm.structured_output_framework.baml.baml_src.extraction import (
|
||||
extract_code_summary,
|
||||
)
|
||||
|
||||
return extract_code_summary(content=content)
|
||||
else:
|
||||
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.extraction import (
|
||||
extract_code_summary,
|
||||
)
|
||||
|
||||
return extract_code_summary(content=content)
|
||||
|
||||
@staticmethod
|
||||
def extract_summary(content: str, response_model: Type[BaseModel]):
|
||||
llm_config = get_llm_config()
|
||||
if llm_config.structured_output_framework == "BAML":
|
||||
from cognee.infrastructure.llm.structured_output_framework.baml.baml_src.extraction import (
|
||||
extract_summary,
|
||||
)
|
||||
|
||||
return extract_summary(content=content, response_model=response_model)
|
||||
else:
|
||||
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.extraction import (
|
||||
extract_summary,
|
||||
)
|
||||
|
||||
return extract_summary(content=content, response_model=response_model)
|
||||
|
|
@ -11,4 +11,4 @@ from cognee.infrastructure.llm.utils import (
|
|||
test_embedding_connection,
|
||||
)
|
||||
|
||||
from LLMAdapter import LLMAdapter
|
||||
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
|
||||
|
|
|
|||
194
cognee/infrastructure/llm/config.py
Normal file
194
cognee/infrastructure/llm/config.py
Normal file
|
|
@ -0,0 +1,194 @@
|
|||
import os
|
||||
from typing import Optional, ClassVar
|
||||
from functools import lru_cache
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from pydantic import model_validator
|
||||
from baml_py import ClientRegistry
|
||||
|
||||
|
||||
class LLMConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for the LLM (Large Language Model) provider and related options.
|
||||
|
||||
Public instance variables include:
|
||||
- llm_provider
|
||||
- llm_model
|
||||
- llm_endpoint
|
||||
- llm_api_key
|
||||
- llm_api_version
|
||||
- llm_temperature
|
||||
- llm_streaming
|
||||
- llm_max_tokens
|
||||
- transcription_model
|
||||
- graph_prompt_path
|
||||
- llm_rate_limit_enabled
|
||||
- llm_rate_limit_requests
|
||||
- llm_rate_limit_interval
|
||||
- embedding_rate_limit_enabled
|
||||
- embedding_rate_limit_requests
|
||||
- embedding_rate_limit_interval
|
||||
|
||||
Public methods include:
|
||||
- ensure_env_vars_for_ollama
|
||||
- to_dict
|
||||
"""
|
||||
|
||||
structured_output_framework: str = "instructor"
|
||||
llm_provider: str = "openai"
|
||||
llm_model: str = "gpt-4o-mini"
|
||||
llm_endpoint: str = ""
|
||||
llm_api_key: Optional[str] = None
|
||||
llm_api_version: Optional[str] = None
|
||||
llm_temperature: float = 0.0
|
||||
llm_streaming: bool = False
|
||||
llm_max_tokens: int = 16384
|
||||
transcription_model: str = "whisper-1"
|
||||
graph_prompt_path: str = "generate_graph_prompt.txt"
|
||||
llm_rate_limit_enabled: bool = False
|
||||
llm_rate_limit_requests: int = 60
|
||||
llm_rate_limit_interval: int = 60 # in seconds (default is 60 requests per minute)
|
||||
embedding_rate_limit_enabled: bool = False
|
||||
embedding_rate_limit_requests: int = 60
|
||||
embedding_rate_limit_interval: int = 60 # in seconds (default is 60 requests per minute)
|
||||
|
||||
fallback_api_key: str = ""
|
||||
fallback_endpoint: str = ""
|
||||
fallback_model: str = ""
|
||||
|
||||
baml_registry: ClassVar[ClientRegistry] = ClientRegistry()
|
||||
|
||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||
|
||||
def model_post_init(self, __context) -> None:
|
||||
"""Initialize the BAML registry after the model is created."""
|
||||
self.baml_registry.add_llm_client(
|
||||
name=self.llm_provider,
|
||||
provider=self.llm_provider,
|
||||
options={
|
||||
"model": self.llm_model,
|
||||
"temperature": self.llm_temperature,
|
||||
"api_key": self.llm_api_key,
|
||||
},
|
||||
)
|
||||
# Sets the primary client
|
||||
self.baml_registry.set_primary(self.llm_provider)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def ensure_env_vars_for_ollama(self) -> "LLMConfig":
|
||||
"""
|
||||
Validate required environment variables for the 'ollama' LLM provider.
|
||||
|
||||
Raises ValueError if some required environment variables are set without the others.
|
||||
Only checks are performed when 'llm_provider' is set to 'ollama'.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- 'LLMConfig': The instance of LLMConfig after validation.
|
||||
"""
|
||||
|
||||
if self.llm_provider != "ollama":
|
||||
# Skip checks unless provider is "ollama"
|
||||
return self
|
||||
|
||||
def is_env_set(var_name: str) -> bool:
|
||||
"""
|
||||
Check if a given environment variable is set and non-empty.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- var_name (str): The name of the environment variable to check.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- bool: True if the environment variable exists and is not empty, otherwise False.
|
||||
"""
|
||||
val = os.environ.get(var_name)
|
||||
return val is not None and val.strip() != ""
|
||||
|
||||
#
|
||||
# 1. Check LLM environment variables
|
||||
#
|
||||
llm_env_vars = {
|
||||
"LLM_MODEL": is_env_set("LLM_MODEL"),
|
||||
"LLM_ENDPOINT": is_env_set("LLM_ENDPOINT"),
|
||||
"LLM_API_KEY": is_env_set("LLM_API_KEY"),
|
||||
}
|
||||
if any(llm_env_vars.values()) and not all(llm_env_vars.values()):
|
||||
missing_llm = [key for key, is_set in llm_env_vars.items() if not is_set]
|
||||
raise ValueError(
|
||||
"You have set some but not all of the required environment variables "
|
||||
f"for LLM usage (LLM_MODEL, LLM_ENDPOINT, LLM_API_KEY). Missing: {missing_llm}"
|
||||
)
|
||||
|
||||
#
|
||||
# 2. Check embedding environment variables
|
||||
#
|
||||
embedding_env_vars = {
|
||||
"EMBEDDING_PROVIDER": is_env_set("EMBEDDING_PROVIDER"),
|
||||
"EMBEDDING_MODEL": is_env_set("EMBEDDING_MODEL"),
|
||||
"EMBEDDING_DIMENSIONS": is_env_set("EMBEDDING_DIMENSIONS"),
|
||||
"HUGGINGFACE_TOKENIZER": is_env_set("HUGGINGFACE_TOKENIZER"),
|
||||
}
|
||||
if any(embedding_env_vars.values()) and not all(embedding_env_vars.values()):
|
||||
missing_embed = [key for key, is_set in embedding_env_vars.items() if not is_set]
|
||||
raise ValueError(
|
||||
"You have set some but not all of the required environment variables "
|
||||
"for embeddings (EMBEDDING_PROVIDER, EMBEDDING_MODEL, "
|
||||
"EMBEDDING_DIMENSIONS, HUGGINGFACE_TOKENIZER). Missing: "
|
||||
f"{missing_embed}"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""
|
||||
Convert the LLMConfig instance into a dictionary representation.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- dict: A dictionary containing the configuration settings of the LLMConfig
|
||||
instance.
|
||||
"""
|
||||
return {
|
||||
"provider": self.llm_provider,
|
||||
"model": self.llm_model,
|
||||
"endpoint": self.llm_endpoint,
|
||||
"api_key": self.llm_api_key,
|
||||
"api_version": self.llm_api_version,
|
||||
"temperature": self.llm_temperature,
|
||||
"streaming": self.llm_streaming,
|
||||
"max_tokens": self.llm_max_tokens,
|
||||
"transcription_model": self.transcription_model,
|
||||
"graph_prompt_path": self.graph_prompt_path,
|
||||
"rate_limit_enabled": self.llm_rate_limit_enabled,
|
||||
"rate_limit_requests": self.llm_rate_limit_requests,
|
||||
"rate_limit_interval": self.llm_rate_limit_interval,
|
||||
"embedding_rate_limit_enabled": self.embedding_rate_limit_enabled,
|
||||
"embedding_rate_limit_requests": self.embedding_rate_limit_requests,
|
||||
"embedding_rate_limit_interval": self.embedding_rate_limit_interval,
|
||||
"fallback_api_key": self.fallback_api_key,
|
||||
"fallback_endpoint": self.fallback_endpoint,
|
||||
"fallback_model": self.fallback_model,
|
||||
}
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_llm_config():
|
||||
"""
|
||||
Retrieve and cache the LLM configuration.
|
||||
|
||||
This function returns an instance of the LLMConfig class. It leverages
|
||||
caching to ensure that repeated calls do not create new instances,
|
||||
but instead return the already created configuration object.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- LLMConfig: An instance of the LLMConfig class containing the configuration for the
|
||||
LLM.
|
||||
"""
|
||||
return LLMConfig()
|
||||
|
|
@ -1,2 +1,2 @@
|
|||
Answer the question using the provided context. Be as brief as possible.
|
||||
Each entry in the context is a paragraph, which is represented as a list with two elements [title, sentences] and sentences is a list of strings.
|
||||
Each entry in the context is a paragraph, which is represented as a list with two elements [title, sentences] and sentences is a list of strings.
|
||||
|
|
@ -1,2 +1,2 @@
|
|||
Answer the question using the provided context. Be as brief as possible.
|
||||
Each entry in the context is tuple of length 3, representing an edge of a knowledge graph with its two nodes.
|
||||
Each entry in the context is tuple of length 3, representing an edge of a knowledge graph with its two nodes.
|
||||
|
|
@ -1 +1 @@
|
|||
Answer the question using the provided context. Be as brief as possible.
|
||||
Answer the question using the provided context. Be as brief as possible.
|
||||
|
|
@ -1 +1 @@
|
|||
Answer the question using the provided context. If the provided context is not connected to the question, just answer "The provided knowledge base does not contain the answer to the question". Be as brief as possible.
|
||||
Answer the question using the provided context. If the provided context is not connected to the question, just answer "The provided knowledge base does not contain the answer to the question". Be as brief as possible.
|
||||
|
|
@ -1,2 +1,2 @@
|
|||
Chose the summary that is the most relevant to the query`{{ query }}`
|
||||
Here are the categories:`{{ categories }}`
|
||||
Here are the categories:`{{ categories }}`
|
||||
|
|
@ -174,4 +174,4 @@ The possible classifications are:
|
|||
"Recipes and crafting instructions"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
The question is: `{{ question }}`
|
||||
And here is the context: `{{ context }}`
|
||||
|
|
@ -1,2 +1,2 @@
|
|||
The question is: `{{ question }}`
|
||||
and here is the context provided with a set of relationships from a knowledge graph separated by \n---\n each represented as node1 -- relation -- node2 triplet: `{{ context }}`
|
||||
and here is the context provided with a set of relationships from a knowledge graph separated by \n---\n each represented as node1 -- relation -- node2 triplet: `{{ context }}`
|
||||
|
|
@ -63,4 +63,4 @@ This queries doesn't work. Do NOT use them:
|
|||
Example 1:
|
||||
Get all nodes connected to John
|
||||
MATCH (n:Entity {'name': 'John'})--(neighbor)
|
||||
RETURN n, neighbor
|
||||
RETURN n, neighbor
|
||||
|
|
@ -1,2 +1,2 @@
|
|||
I need you to solve this issue by generating a single patch file that I can apply directly to this repository using git apply.
|
||||
Please respond with a single patch file in the following format.
|
||||
Please respond with a single patch file in the following format.
|
||||
|
|
@ -25,9 +25,7 @@ def render_prompt(filename: str, context: dict, base_directory: str = None) -> s
|
|||
|
||||
# Set the base directory relative to the cognee root directory
|
||||
if base_directory is None:
|
||||
base_directory = get_absolute_path(
|
||||
"./infrastructure/llm/structured_output_framework/llitellm_instructor/llm/prompts"
|
||||
)
|
||||
base_directory = get_absolute_path("./infrastructure/llm/prompts")
|
||||
|
||||
# Initialize the Jinja2 environment to load templates from the filesystem
|
||||
env = Environment(
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
You are an expert Python programmer and technical writer. Your task is to summarize the given Python code snippet or file.
|
||||
The code may contain multiple imports, classes, functions, constants and logic. Provide a clear, structured explanation of its components
|
||||
The code may contain multiple imports, classes, functions, constants and logic. Provide a clear, structured explanation of its components
|
||||
and their relationships.
|
||||
|
||||
Instructions:
|
||||
|
|
@ -7,4 +7,4 @@ Provide an overview: Start with a high-level summary of what the code does as a
|
|||
Break it down: Summarize each class and function individually, explaining their purpose and how they interact.
|
||||
Describe the workflow: Outline how the classes and functions work together. Mention any control flow (e.g., main functions, entry points, loops).
|
||||
Key features: Highlight important elements like arguments, return values, or unique logic.
|
||||
Maintain clarity: Write in plain English for someone familiar with Python but unfamiliar with this code.
|
||||
Maintain clarity: Write in plain English for someone familiar with Python but unfamiliar with this code.
|
||||
|
|
@ -0,0 +1,109 @@
|
|||
// Content classification data models - matching shared/data_models.py
|
||||
class TextContent {
|
||||
type string
|
||||
subclass string[]
|
||||
}
|
||||
|
||||
class AudioContent {
|
||||
type string
|
||||
subclass string[]
|
||||
}
|
||||
|
||||
class ImageContent {
|
||||
type string
|
||||
subclass string[]
|
||||
}
|
||||
|
||||
class VideoContent {
|
||||
type string
|
||||
subclass string[]
|
||||
}
|
||||
|
||||
class MultimediaContent {
|
||||
type string
|
||||
subclass string[]
|
||||
}
|
||||
|
||||
class Model3DContent {
|
||||
type string
|
||||
subclass string[]
|
||||
}
|
||||
|
||||
class ProceduralContent {
|
||||
type string
|
||||
subclass string[]
|
||||
}
|
||||
|
||||
class ContentLabel {
|
||||
content_type "text" | "audio" | "image" | "video" | "multimedia" | "3d_model" | "procedural"
|
||||
type string
|
||||
subclass string[]
|
||||
}
|
||||
|
||||
class DefaultContentPrediction {
|
||||
label ContentLabel
|
||||
}
|
||||
|
||||
// Content classification prompt template
|
||||
template_string ClassifyContentPrompt() #"
|
||||
You are a classification engine and should classify content. Make sure to use one of the existing classification options and not invent your own.
|
||||
|
||||
Classify the content into one of these main categories and their relevant subclasses:
|
||||
|
||||
**TEXT CONTENT** (content_type: "text"):
|
||||
- type: "TEXTUAL_DOCUMENTS_USED_FOR_GENERAL_PURPOSES"
|
||||
- subclass options: ["Articles, essays, and reports", "Books and manuscripts", "News stories and blog posts", "Research papers and academic publications", "Social media posts and comments", "Website content and product descriptions", "Personal narratives and stories", "Spreadsheets and tables", "Forms and surveys", "Databases and CSV files", "Source code in various programming languages", "Shell commands and scripts", "Markup languages (HTML, XML)", "Stylesheets (CSS) and configuration files (YAML, JSON, INI)", "Chat transcripts and messaging history", "Customer service logs and interactions", "Conversational AI training data", "Textbook content and lecture notes", "Exam questions and academic exercises", "E-learning course materials", "Poetry and prose", "Scripts for plays, movies, and television", "Song lyrics", "Manuals and user guides", "Technical specifications and API documentation", "Helpdesk articles and FAQs", "Contracts and agreements", "Laws, regulations, and legal case documents", "Policy documents and compliance materials", "Clinical trial reports", "Patient records and case notes", "Scientific journal articles", "Financial reports and statements", "Business plans and proposals", "Market research and analysis reports", "Ad copies and marketing slogans", "Product catalogs and brochures", "Press releases and promotional content", "Professional and formal correspondence", "Personal emails and letters", "Image and video captions", "Annotations and metadata for various media", "Vocabulary lists and grammar rules", "Language exercises and quizzes", "Other types of text data"]
|
||||
|
||||
**AUDIO CONTENT** (content_type: "audio"):
|
||||
- type: "AUDIO_DOCUMENTS_USED_FOR_GENERAL_PURPOSES"
|
||||
- subclass options: ["Music tracks and albums", "Podcasts and radio broadcasts", "Audiobooks and audio guides", "Recorded interviews and speeches", "Sound effects and ambient sounds", "Other types of audio recordings"]
|
||||
|
||||
**IMAGE CONTENT** (content_type: "image"):
|
||||
- type: "IMAGE_DOCUMENTS_USED_FOR_GENERAL_PURPOSES"
|
||||
- subclass options: ["Photographs and digital images", "Illustrations, diagrams, and charts", "Infographics and visual data representations", "Artwork and paintings", "Screenshots and graphical user interfaces", "Other types of images"]
|
||||
|
||||
**VIDEO CONTENT** (content_type: "video"):
|
||||
- type: "VIDEO_DOCUMENTS_USED_FOR_GENERAL_PURPOSES"
|
||||
- subclass options: ["Movies and short films", "Documentaries and educational videos", "Video tutorials and how-to guides", "Animated features and cartoons", "Live event recordings and sports broadcasts", "Other types of video content"]
|
||||
|
||||
**MULTIMEDIA CONTENT** (content_type: "multimedia"):
|
||||
- type: "MULTIMEDIA_DOCUMENTS_USED_FOR_GENERAL_PURPOSES"
|
||||
- subclass options: ["Interactive web content and games", "Virtual reality (VR) and augmented reality (AR) experiences", "Mixed media presentations and slide decks", "E-learning modules with integrated multimedia", "Digital exhibitions and virtual tours", "Other types of multimedia content"]
|
||||
|
||||
**3D MODEL CONTENT** (content_type: "3d_model"):
|
||||
- type: "3D_MODEL_DOCUMENTS_USED_FOR_GENERAL_PURPOSES"
|
||||
- subclass options: ["Architectural renderings and building plans", "Product design models and prototypes", "3D animations and character models", "Scientific simulations and visualizations", "Virtual objects for AR/VR applications", "Other types of 3D models"]
|
||||
|
||||
**PROCEDURAL CONTENT** (content_type: "procedural"):
|
||||
- type: "PROCEDURAL_DOCUMENTS_USED_FOR_GENERAL_PURPOSES"
|
||||
- subclass options: ["Tutorials and step-by-step guides", "Workflow and process descriptions", "Simulation and training exercises", "Recipes and crafting instructions", "Other types of procedural content"]
|
||||
|
||||
Select the most appropriate content_type, type, and relevant subclasses.
|
||||
"#
|
||||
|
||||
// OpenAI client defined once for all BAML files
|
||||
|
||||
// Classification function
|
||||
function ExtractCategories(content: string) -> DefaultContentPrediction {
|
||||
client OpenAI
|
||||
|
||||
prompt #"
|
||||
{{ ClassifyContentPrompt() }}
|
||||
|
||||
{{ ctx.output_format(prefix="Answer in this schema:\n") }}
|
||||
|
||||
{{ _.role('user') }}
|
||||
{{ content }}
|
||||
"#
|
||||
}
|
||||
|
||||
// Test case for classification
|
||||
test ExtractCategoriesExample {
|
||||
functions [ExtractCategories]
|
||||
args {
|
||||
content #"
|
||||
Natural language processing (NLP) is an interdisciplinary subfield of computer science and information retrieval.
|
||||
It deals with the interaction between computers and human language, in particular how to program computers to process and analyze large amounts of natural language data.
|
||||
"#
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,343 @@
|
|||
class Node {
|
||||
id string
|
||||
name string
|
||||
type string
|
||||
description string
|
||||
@@dynamic
|
||||
}
|
||||
|
||||
/// doc string for edge
|
||||
class Edge {
|
||||
/// doc string for source_node_id
|
||||
source_node_id string
|
||||
target_node_id string
|
||||
relationship_name string
|
||||
}
|
||||
|
||||
class KnowledgeGraph {
|
||||
nodes (Node @stream.done)[]
|
||||
edges Edge[]
|
||||
}
|
||||
|
||||
// Summarization classes
|
||||
class SummarizedContent {
|
||||
summary string
|
||||
description string
|
||||
}
|
||||
|
||||
class SummarizedFunction {
|
||||
name string
|
||||
description string
|
||||
inputs string[]?
|
||||
outputs string[]?
|
||||
decorators string[]?
|
||||
}
|
||||
|
||||
class SummarizedClass {
|
||||
name string
|
||||
description string
|
||||
methods SummarizedFunction[]?
|
||||
decorators string[]?
|
||||
}
|
||||
|
||||
class SummarizedCode {
|
||||
high_level_summary string
|
||||
key_features string[]
|
||||
imports string[]
|
||||
constants string[]
|
||||
classes SummarizedClass[]
|
||||
functions SummarizedFunction[]
|
||||
workflow_description string?
|
||||
}
|
||||
|
||||
class DynamicKnowledgeGraph {
|
||||
@@dynamic
|
||||
}
|
||||
|
||||
|
||||
// Simple template for basic extraction (fast, good quality)
|
||||
template_string ExtractContentGraphPrompt() #"
|
||||
You are an advanced algorithm that extracts structured data into a knowledge graph.
|
||||
|
||||
- **Nodes**: Entities/concepts (like Wikipedia articles).
|
||||
- **Edges**: Relationships (like Wikipedia links). Use snake_case (e.g., `acted_in`).
|
||||
|
||||
**Rules:**
|
||||
|
||||
1. **Node Labeling & IDs**
|
||||
- Use basic types only (e.g., "Person", "Date", "Organization").
|
||||
- Avoid overly specific or generic terms (e.g., no "Mathematician" or "Entity").
|
||||
- Node IDs must be human-readable names from the text (no numbers).
|
||||
|
||||
2. **Dates & Numbers**
|
||||
- Label dates as **"Date"** in "YYYY-MM-DD" format (use available parts if incomplete).
|
||||
- Properties are key-value pairs; do not use escaped quotes.
|
||||
|
||||
3. **Coreference Resolution**
|
||||
- Use a single, complete identifier for each entity (e.g., always "John Doe" not "Joe" or "he").
|
||||
|
||||
4. **Relationship Labels**:
|
||||
- Use descriptive, lowercase, snake_case names for edges.
|
||||
- *Example*: born_in, married_to, invented_by.
|
||||
- Avoid vague or generic labels like isA, relatesTo, has.
|
||||
- Avoid duplicated relationships like produces, produced by.
|
||||
|
||||
5. **Strict Compliance**
|
||||
- Follow these rules exactly. Non-compliance results in termination.
|
||||
"#
|
||||
|
||||
// Summarization prompt template
|
||||
template_string SummarizeContentPrompt() #"
|
||||
You are a top-tier summarization engine. Your task is to summarize text and make it versatile.
|
||||
Be brief and concise, but keep the important information and the subject.
|
||||
Use synonym words where possible in order to change the wording but keep the meaning.
|
||||
"#
|
||||
|
||||
// Code summarization prompt template
|
||||
template_string SummarizeCodePrompt() #"
|
||||
You are an expert code analyst. Analyze the provided source code and extract key information:
|
||||
|
||||
1. Provide a high-level summary of what the code does
|
||||
2. List key features and functionality
|
||||
3. Identify imports and dependencies
|
||||
4. List constants and global variables
|
||||
5. Summarize classes with their methods
|
||||
6. Summarize standalone functions
|
||||
7. Describe the overall workflow if applicable
|
||||
|
||||
Be precise and technical while remaining clear and concise.
|
||||
"#
|
||||
|
||||
// Detailed template for complex extraction (slower, higher quality)
|
||||
template_string DetailedExtractContentGraphPrompt() #"
|
||||
You are a top-tier algorithm designed for extracting information in structured formats to build a knowledge graph.
|
||||
**Nodes** represent entities and concepts. They're akin to Wikipedia nodes.
|
||||
**Edges** represent relationships between concepts. They're akin to Wikipedia links.
|
||||
|
||||
The aim is to achieve simplicity and clarity in the knowledge graph.
|
||||
|
||||
# 1. Labeling Nodes
|
||||
**Consistency**: Ensure you use basic or elementary types for node labels.
|
||||
- For example, when you identify an entity representing a person, always label it as **"Person"**.
|
||||
- Avoid using more specific terms like "Mathematician" or "Scientist", keep those as "profession" property.
|
||||
- Don't use too generic terms like "Entity".
|
||||
**Node IDs**: Never utilize integers as node IDs.
|
||||
- Node IDs should be names or human-readable identifiers found in the text.
|
||||
|
||||
# 2. Handling Numerical Data and Dates
|
||||
- For example, when you identify an entity representing a date, make sure it has type **"Date"**.
|
||||
- Extract the date in the format "YYYY-MM-DD"
|
||||
- If not possible to extract the whole date, extract month or year, or both if available.
|
||||
- **Property Format**: Properties must be in a key-value format.
|
||||
- **Quotation Marks**: Never use escaped single or double quotes within property values.
|
||||
- **Naming Convention**: Use snake_case for relationship names, e.g., `acted_in`.
|
||||
|
||||
# 3. Coreference Resolution
|
||||
- **Maintain Entity Consistency**: When extracting entities, it's vital to ensure consistency.
|
||||
If an entity, such as "John Doe", is mentioned multiple times in the text but is referred to by different names or pronouns (e.g., "Joe", "he"),
|
||||
always use the most complete identifier for that entity throughout the knowledge graph. In this example, use "John Doe" as the Person's ID.
|
||||
Remember, the knowledge graph should be coherent and easily understandable, so maintaining consistency in entity references is crucial.
|
||||
|
||||
# 4. Strict Compliance
|
||||
Adhere to the rules strictly. Non-compliance will result in termination.
|
||||
"#
|
||||
|
||||
// Guided template with step-by-step instructions
|
||||
template_string GuidedExtractContentGraphPrompt() #"
|
||||
You are an advanced algorithm designed to extract structured information to build a clean, consistent, and human-readable knowledge graph.
|
||||
|
||||
**Objective**:
|
||||
- Nodes represent entities and concepts, similar to Wikipedia articles.
|
||||
- Edges represent typed relationships between nodes, similar to Wikipedia hyperlinks.
|
||||
- The graph must be clear, minimal, consistent, and semantically precise.
|
||||
|
||||
**Node Guidelines**:
|
||||
|
||||
1. **Label Consistency**:
|
||||
- Use consistent, basic types for all node labels.
|
||||
- Do not switch between granular or vague labels for the same kind of entity.
|
||||
- Pick one label for each category and apply it uniformly.
|
||||
- Each entity type should be in a singular form and in a case of multiple words separated by whitespaces
|
||||
|
||||
2. **Node Identifiers**:
|
||||
- Node IDs must be human-readable and derived directly from the text.
|
||||
- Prefer full names and canonical terms.
|
||||
- Never use integers or autogenerated IDs.
|
||||
- *Example*: Use "Marie Curie", "Theory of Evolution", "Google".
|
||||
|
||||
3. **Coreference Resolution**:
|
||||
- Maintain one consistent node ID for each real-world entity.
|
||||
- Resolve aliases, acronyms, and pronouns to the most complete form.
|
||||
- *Example*: Always use "John Doe" even if later referred to as "Doe" or "he".
|
||||
|
||||
**Edge Guidelines**:
|
||||
|
||||
4. **Relationship Labels**:
|
||||
- Use descriptive, lowercase, snake_case names for edges.
|
||||
- *Example*: born_in, married_to, invented_by.
|
||||
- Avoid vague or generic labels like isA, relatesTo, has.
|
||||
|
||||
5. **Relationship Direction**:
|
||||
- Edges must be directional and logically consistent.
|
||||
- *Example*:
|
||||
- "Marie Curie" —[born_in]→ "Warsaw"
|
||||
- "Radioactivity" —[discovered_by]→ "Marie Curie"
|
||||
|
||||
**Compliance**:
|
||||
Strict adherence to these guidelines is required. Any deviation will result in immediate termination of the task.
|
||||
"#
|
||||
|
||||
// Strict template with zero-tolerance rules
|
||||
template_string StrictExtractContentGraphPrompt() #"
|
||||
You are a top-tier algorithm for **extracting structured information** from unstructured text to build a **knowledge graph**.
|
||||
|
||||
Your primary goal is to extract:
|
||||
- **Nodes**: Representing **entities** and **concepts** (like Wikipedia nodes).
|
||||
- **Edges**: Representing **relationships** between those concepts (like Wikipedia links).
|
||||
|
||||
The resulting knowledge graph must be **simple, consistent, and human-readable**.
|
||||
|
||||
## 1. Node Labeling and Identification
|
||||
|
||||
### Node Types
|
||||
Use **basic atomic types** for node labels. Always prefer general types over specific roles or professions:
|
||||
- "Person" for any human.
|
||||
- "Organization" for companies, institutions, etc.
|
||||
- "Location" for geographic or place entities.
|
||||
- "Date" for any temporal expression.
|
||||
- "Event" for historical or scheduled occurrences.
|
||||
- "Work" for books, films, artworks, or research papers.
|
||||
- "Concept" for abstract notions or ideas.
|
||||
|
||||
### Node IDs
|
||||
- Always assign **human-readable and unambiguous identifiers**.
|
||||
- Never use numeric or autogenerated IDs.
|
||||
- Prioritize **most complete form** of entity names for consistency.
|
||||
|
||||
## 2. Relationship Handling
|
||||
- Use **snake_case** for all relationship (edge) types.
|
||||
- Keep relationship types semantically clear and consistent.
|
||||
- Avoid vague relation names like "related_to" unless no better alternative exists.
|
||||
|
||||
## 3. Strict Compliance
|
||||
Follow all rules exactly. Any deviation may lead to rejection or incorrect graph construction.
|
||||
"#
|
||||
|
||||
// OpenAI client with environment model selection
|
||||
client<llm> OpenAI {
|
||||
provider openai
|
||||
options {
|
||||
model client_registry.model
|
||||
api_key client_registry.api_key
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
// Function that returns raw structured output (for custom objects - to be handled in Python)
|
||||
function ExtractContentGraphGeneric(
|
||||
content: string,
|
||||
mode: "simple" | "base" | "guided" | "strict" | "custom"?,
|
||||
custom_prompt_content: string?
|
||||
) -> KnowledgeGraph {
|
||||
client OpenAI
|
||||
|
||||
prompt #"
|
||||
{% if mode == "base" %}
|
||||
{{ DetailedExtractContentGraphPrompt() }}
|
||||
{% elif mode == "guided" %}
|
||||
{{ GuidedExtractContentGraphPrompt() }}
|
||||
{% elif mode == "strict" %}
|
||||
{{ StrictExtractContentGraphPrompt() }}
|
||||
{% elif mode == "custom" and custom_prompt_content %}
|
||||
{{ custom_prompt_content }}
|
||||
{% else %}
|
||||
{{ ExtractContentGraphPrompt() }}
|
||||
{% endif %}
|
||||
|
||||
{{ ctx.output_format(prefix="Answer in this schema:\n") }}
|
||||
|
||||
Before answering, briefly describe what you'll extract from the text, then provide the structured output.
|
||||
|
||||
Example format:
|
||||
I'll extract the main entities and their relationships from this text...
|
||||
|
||||
{ ... }
|
||||
|
||||
{{ _.role('user') }}
|
||||
{{ content }}
|
||||
"#
|
||||
}
|
||||
|
||||
// Backward-compatible function specifically for KnowledgeGraph
|
||||
function ExtractDynamicContentGraph(
|
||||
content: string,
|
||||
mode: "simple" | "base" | "guided" | "strict" | "custom"?,
|
||||
custom_prompt_content: string?
|
||||
) -> DynamicKnowledgeGraph {
|
||||
client OpenAI
|
||||
|
||||
prompt #"
|
||||
{% if mode == "base" %}
|
||||
{{ DetailedExtractContentGraphPrompt() }}
|
||||
{% elif mode == "guided" %}
|
||||
{{ GuidedExtractContentGraphPrompt() }}
|
||||
{% elif mode == "strict" %}
|
||||
{{ StrictExtractContentGraphPrompt() }}
|
||||
{% elif mode == "custom" and custom_prompt_content %}
|
||||
{{ custom_prompt_content }}
|
||||
{% else %}
|
||||
{{ ExtractContentGraphPrompt() }}
|
||||
{% endif %}
|
||||
|
||||
{{ ctx.output_format(prefix="Answer in this schema:\n") }}
|
||||
|
||||
Before answering, briefly describe what you'll extract from the text, then provide the structured output.
|
||||
|
||||
Example format:
|
||||
I'll extract the main entities and their relationships from this text...
|
||||
|
||||
{ ... }
|
||||
|
||||
{{ _.role('user') }}
|
||||
{{ content }}
|
||||
"#
|
||||
}
|
||||
|
||||
|
||||
// Summarization functions
|
||||
function SummarizeContent(content: string) -> SummarizedContent {
|
||||
client OpenAI
|
||||
|
||||
prompt #"
|
||||
{{ SummarizeContentPrompt() }}
|
||||
|
||||
{{ ctx.output_format(prefix="Answer in this schema:\n") }}
|
||||
|
||||
{{ _.role('user') }}
|
||||
{{ content }}
|
||||
"#
|
||||
}
|
||||
|
||||
function SummarizeCode(content: string) -> SummarizedCode {
|
||||
client OpenAI
|
||||
|
||||
prompt #"
|
||||
{{ SummarizeCodePrompt() }}
|
||||
|
||||
{{ ctx.output_format(prefix="Answer in this schema:\n") }}
|
||||
|
||||
{{ _.role('user') }}
|
||||
{{ content }}
|
||||
"#
|
||||
}
|
||||
|
||||
test ExtractStrictExample {
|
||||
functions [ExtractContentGraphGeneric]
|
||||
args {
|
||||
content #"
|
||||
The Python programming language was created by Guido van Rossum in 1991.
|
||||
"#
|
||||
mode "strict"
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
from .knowledge_graph.extract_content_graph import extract_content_graph
|
||||
from .extract_summary import extract_summary, extract_code_summary
|
||||
|
|
@ -0,0 +1,114 @@
|
|||
import os
|
||||
from typing import Type
|
||||
from pydantic import BaseModel
|
||||
from baml_py import ClientRegistry
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.shared.data_models import SummarizedCode
|
||||
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.async_client import b
|
||||
from cognee.infrastructure.llm.structured_output_framework.baml_src.config import get_llm_config
|
||||
|
||||
config = get_llm_config()
|
||||
|
||||
|
||||
logger = get_logger("extract_summary_baml")
|
||||
|
||||
|
||||
def get_mock_summarized_code():
|
||||
"""Local mock function to avoid circular imports."""
|
||||
return SummarizedCode(
|
||||
high_level_summary="Mock code summary",
|
||||
key_features=["Mock feature 1", "Mock feature 2"],
|
||||
imports=["mock_import"],
|
||||
constants=["MOCK_CONSTANT"],
|
||||
classes=[],
|
||||
functions=[],
|
||||
workflow_description="Mock workflow description",
|
||||
)
|
||||
|
||||
|
||||
async def extract_summary(content: str, response_model: Type[BaseModel]):
|
||||
"""
|
||||
Extract summary using BAML framework.
|
||||
|
||||
Args:
|
||||
content: The content to summarize
|
||||
response_model: The Pydantic model type for the response
|
||||
|
||||
Returns:
|
||||
BaseModel: The summarized content in the specified format
|
||||
"""
|
||||
config = get_llm_config()
|
||||
|
||||
baml_registry = ClientRegistry()
|
||||
|
||||
baml_registry.add_llm_client(
|
||||
name="def",
|
||||
provider="openai",
|
||||
options={
|
||||
"model": config.llm_model,
|
||||
"temperature": config.llm_temperature,
|
||||
"api_key": config.llm_api_key,
|
||||
},
|
||||
)
|
||||
baml_registry.set_primary("def")
|
||||
|
||||
# Use BAML's SummarizeContent function
|
||||
summary_result = await b.SummarizeContent(
|
||||
content, baml_options={"client_registry": baml_registry}
|
||||
)
|
||||
|
||||
# Convert BAML result to the expected response model
|
||||
if response_model is SummarizedCode:
|
||||
# If it's asking for SummarizedCode but we got SummarizedContent,
|
||||
# we need to use SummarizeCode instead
|
||||
code_result = await b.SummarizeCode(
|
||||
content, baml_options={"client_registry": config.baml_registry}
|
||||
)
|
||||
return code_result
|
||||
else:
|
||||
# For other models, return the summary result
|
||||
return summary_result
|
||||
|
||||
|
||||
async def extract_code_summary(content: str):
|
||||
"""
|
||||
Extract code summary using BAML framework with mocking support.
|
||||
|
||||
Args:
|
||||
content: The code content to summarize
|
||||
|
||||
Returns:
|
||||
SummarizedCode: The summarized code information
|
||||
"""
|
||||
enable_mocking = os.getenv("MOCK_CODE_SUMMARY", "false")
|
||||
if isinstance(enable_mocking, bool):
|
||||
enable_mocking = str(enable_mocking).lower()
|
||||
enable_mocking = enable_mocking in ("true", "1", "yes")
|
||||
|
||||
if enable_mocking:
|
||||
result = get_mock_summarized_code()
|
||||
return result
|
||||
else:
|
||||
try:
|
||||
config = get_llm_config()
|
||||
|
||||
baml_registry = ClientRegistry()
|
||||
|
||||
baml_registry.add_llm_client(
|
||||
name="def",
|
||||
provider="openai",
|
||||
options={
|
||||
"model": config.llm_model,
|
||||
"temperature": config.llm_temperature,
|
||||
"api_key": config.llm_api_key,
|
||||
},
|
||||
)
|
||||
baml_registry.set_primary("def")
|
||||
result = await b.SummarizeCode(content, baml_options={"client_registry": baml_registry})
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to extract code summary with BAML, falling back to mock summary", exc_info=e
|
||||
)
|
||||
result = get_mock_summarized_code()
|
||||
|
||||
return result
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
from baml_py import ClientRegistry
|
||||
from typing import Type
|
||||
from pydantic import BaseModel
|
||||
from cognee.infrastructure.llm.config import get_llm_config
|
||||
from cognee.shared.logging_utils import get_logger, setup_logging
|
||||
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.async_client import b
|
||||
|
||||
config = get_llm_config()
|
||||
|
||||
|
||||
async def extract_content_graph(
|
||||
content: str, response_model: Type[BaseModel], mode: str = "simple"
|
||||
):
|
||||
config = get_llm_config()
|
||||
setup_logging()
|
||||
|
||||
get_logger(level="INFO")
|
||||
|
||||
baml_registry = ClientRegistry()
|
||||
|
||||
baml_registry.add_llm_client(
|
||||
name="extract_content_client",
|
||||
provider=config.llm_provider,
|
||||
options={
|
||||
"model": config.llm_model,
|
||||
"temperature": config.llm_temperature,
|
||||
"api_key": config.llm_api_key,
|
||||
},
|
||||
)
|
||||
baml_registry.set_primary("extract_content_client")
|
||||
|
||||
# if response_model:
|
||||
# # tb = TypeBuilder()
|
||||
# # country = tb.union \
|
||||
# # ([tb.literal_string("USA"), tb.literal_string("UK"), tb.literal_string("Germany"), tb.literal_string("other")])
|
||||
# # tb.Node.add_property("country", country)
|
||||
#
|
||||
# graph = await b.ExtractDynamicContentGraph(
|
||||
# content, mode=mode, baml_options={"client_registry": baml_registry}
|
||||
# )
|
||||
#
|
||||
# return graph
|
||||
|
||||
# else:
|
||||
graph = await b.ExtractContentGraphGeneric(
|
||||
content, mode=mode, baml_options={"client_registry": baml_registry}
|
||||
)
|
||||
|
||||
return graph
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
// This helps use auto generate libraries you can use in the language of
|
||||
// your choice. You can have multiple generators if you use multiple languages.
|
||||
// Just ensure that the output_dir is different for each generator.
|
||||
generator target {
|
||||
// Valid values: "python/pydantic", "typescript", "ruby/sorbet", "rest/openapi"
|
||||
output_type "python/pydantic"
|
||||
|
||||
// Where the generated code will be saved (relative to baml_src/)
|
||||
output_dir "../baml/"
|
||||
|
||||
// The version of the BAML package you have installed (e.g. same version as your baml-py or @boundaryml/baml).
|
||||
// The BAML VSCode extension version should also match this version.
|
||||
version "0.201.0"
|
||||
|
||||
// Valid values: "sync", "async"
|
||||
// This controls what `b.FunctionName()` will be (sync or async).
|
||||
default_client_mode sync
|
||||
}
|
||||
|
|
@ -1,18 +1,12 @@
|
|||
from typing import Type
|
||||
from pydantic import BaseModel
|
||||
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import (
|
||||
read_query_prompt,
|
||||
)
|
||||
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
|
||||
get_llm_client,
|
||||
)
|
||||
|
||||
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
|
||||
|
||||
|
||||
async def extract_categories(content: str, response_model: Type[BaseModel]):
|
||||
llm_client = get_llm_client()
|
||||
system_prompt = LLMAdapter.read_query_prompt("classify_content.txt")
|
||||
|
||||
system_prompt = read_query_prompt("classify_content.txt")
|
||||
|
||||
llm_output = await llm_client.acreate_structured_output(content, system_prompt, response_model)
|
||||
llm_output = await LLMAdapter.acreate_structured_output(content, system_prompt, response_model)
|
||||
|
||||
return llm_output
|
||||
|
|
|
|||
|
|
@ -5,12 +5,7 @@ from typing import Type
|
|||
from instructor.exceptions import InstructorRetryException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
|
||||
get_llm_client,
|
||||
)
|
||||
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import (
|
||||
read_query_prompt,
|
||||
)
|
||||
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
|
||||
from cognee.shared.data_models import SummarizedCode
|
||||
|
||||
logger = get_logger("extract_summary")
|
||||
|
|
@ -30,11 +25,9 @@ def get_mock_summarized_code():
|
|||
|
||||
|
||||
async def extract_summary(content: str, response_model: Type[BaseModel]):
|
||||
llm_client = get_llm_client()
|
||||
system_prompt = LLMAdapter.read_query_prompt("summarize_content.txt")
|
||||
|
||||
system_prompt = read_query_prompt("summarize_content.txt")
|
||||
|
||||
llm_output = await llm_client.acreate_structured_output(content, system_prompt, response_model)
|
||||
llm_output = await LLMAdapter.acreate_structured_output(content, system_prompt, response_model)
|
||||
|
||||
return llm_output
|
||||
|
||||
|
|
|
|||
|
|
@ -1,19 +1,14 @@
|
|||
import os
|
||||
from typing import Type
|
||||
from pydantic import BaseModel
|
||||
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
|
||||
get_llm_client,
|
||||
)
|
||||
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import (
|
||||
render_prompt,
|
||||
)
|
||||
|
||||
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
|
||||
from cognee.infrastructure.llm.config import (
|
||||
get_llm_config,
|
||||
)
|
||||
|
||||
|
||||
async def extract_content_graph(content: str, response_model: Type[BaseModel]):
|
||||
llm_client = get_llm_client()
|
||||
llm_config = get_llm_config()
|
||||
|
||||
prompt_path = llm_config.graph_prompt_path
|
||||
|
|
@ -27,9 +22,9 @@ async def extract_content_graph(content: str, response_model: Type[BaseModel]):
|
|||
else:
|
||||
base_directory = None
|
||||
|
||||
system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory)
|
||||
system_prompt = LLMAdapter.render_prompt(prompt_path, {}, base_directory=base_directory)
|
||||
|
||||
content_graph = await llm_client.acreate_structured_output(
|
||||
content_graph = await LLMAdapter.acreate_structured_output(
|
||||
content, system_prompt, response_model
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,14 +6,13 @@ from cognee.exceptions import InvalidValueError
|
|||
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.llm_interface import (
|
||||
LLMInterface,
|
||||
)
|
||||
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import (
|
||||
read_query_prompt,
|
||||
)
|
||||
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.rate_limiter import (
|
||||
rate_limit_async,
|
||||
sleep_and_retry_async,
|
||||
)
|
||||
|
||||
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
|
||||
|
||||
|
||||
class AnthropicAdapter(LLMInterface):
|
||||
"""
|
||||
|
|
@ -92,7 +91,7 @@ class AnthropicAdapter(LLMInterface):
|
|||
if not system_prompt:
|
||||
raise InvalidValueError(message="No system prompt path provided.")
|
||||
|
||||
system_prompt = read_query_prompt(system_prompt)
|
||||
system_prompt = LLMAdapter.read_query_prompt(system_prompt)
|
||||
|
||||
formatted_prompt = (
|
||||
f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
|
||||
|
|
|
|||
|
|
@ -9,9 +9,7 @@ from cognee.exceptions import InvalidValueError
|
|||
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.llm_interface import (
|
||||
LLMInterface,
|
||||
)
|
||||
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import (
|
||||
read_query_prompt,
|
||||
)
|
||||
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
|
||||
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.rate_limiter import (
|
||||
rate_limit_async,
|
||||
sleep_and_retry_async,
|
||||
|
|
@ -138,7 +136,7 @@ class GeminiAdapter(LLMInterface):
|
|||
text_input = "No user input provided."
|
||||
if not system_prompt:
|
||||
raise InvalidValueError(message="No system prompt path provided.")
|
||||
system_prompt = read_query_prompt(system_prompt)
|
||||
system_prompt = LLMAdapter.read_query_prompt(system_prompt)
|
||||
|
||||
formatted_prompt = (
|
||||
f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
"""Adapter for Generic API LLM provider API"""
|
||||
|
||||
import logging
|
||||
import litellm
|
||||
import instructor
|
||||
from typing import Type
|
||||
|
|
|
|||
|
|
@ -3,9 +3,7 @@
|
|||
from typing import Type, Protocol
|
||||
from abc import abstractmethod
|
||||
from pydantic import BaseModel
|
||||
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import (
|
||||
read_query_prompt,
|
||||
)
|
||||
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
|
||||
|
||||
|
||||
class LLMInterface(Protocol):
|
||||
|
|
@ -59,7 +57,7 @@ class LLMInterface(Protocol):
|
|||
text_input = "No user input provided."
|
||||
if not system_prompt:
|
||||
raise ValueError("No system prompt path provided.")
|
||||
system_prompt = read_query_prompt(system_prompt)
|
||||
system_prompt = LLMAdapter.read_query_prompt(system_prompt)
|
||||
|
||||
formatted_prompt = f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
|
||||
|
||||
|
|
|
|||
|
|
@ -8,9 +8,7 @@ from litellm.exceptions import ContentPolicyViolationError
|
|||
from instructor.exceptions import InstructorRetryException
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.prompts import (
|
||||
read_query_prompt,
|
||||
)
|
||||
from cognee.infrastructure.llm.LLMAdapter import LLMAdapter
|
||||
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.llm_interface import (
|
||||
LLMInterface,
|
||||
)
|
||||
|
|
@ -328,7 +326,7 @@ class OpenAIAdapter(LLMInterface):
|
|||
text_input = "No user input provided."
|
||||
if not system_prompt:
|
||||
raise InvalidValueError(message="No system prompt path provided.")
|
||||
system_prompt = read_query_prompt(system_prompt)
|
||||
system_prompt = LLMAdapter.read_query_prompt(system_prompt)
|
||||
|
||||
formatted_prompt = (
|
||||
f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
|
||||
|
|
|
|||
|
|
@ -1,2 +0,0 @@
|
|||
The question is: `{{ question }}`
|
||||
And here is the context: `{{ context }}`
|
||||
107
cognee/infrastructure/llm/utils.py
Normal file
107
cognee/infrastructure/llm/utils.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
import litellm
|
||||
|
||||
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
|
||||
get_llm_client,
|
||||
)
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def get_max_chunk_tokens():
|
||||
"""
|
||||
Calculate the maximum number of tokens allowed in a chunk.
|
||||
|
||||
The function determines the maximum chunk size based on the maximum token limit of the
|
||||
embedding engine and half of the LLM maximum context token size. It ensures that the
|
||||
chunk size does not exceed these constraints.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- int: The maximum number of tokens that can be included in a chunk, determined by
|
||||
the smaller value of the embedding engine's max tokens and half of the LLM's
|
||||
maximum tokens.
|
||||
"""
|
||||
# NOTE: Import must be done in function to avoid circular import issue
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
# Calculate max chunk size based on the following formula
|
||||
embedding_engine = get_vector_engine().embedding_engine
|
||||
llm_client = get_llm_client()
|
||||
|
||||
# We need to make sure chunk size won't take more than half of LLM max context token size
|
||||
# but it also can't be bigger than the embedding engine max token size
|
||||
llm_cutoff_point = llm_client.max_tokens // 2 # Round down the division
|
||||
max_chunk_tokens = min(embedding_engine.max_tokens, llm_cutoff_point)
|
||||
|
||||
return max_chunk_tokens
|
||||
|
||||
|
||||
def get_model_max_tokens(model_name: str):
|
||||
"""
|
||||
Retrieve the maximum token limit for a specified model name if it exists.
|
||||
|
||||
Checks if the provided model name is present in the predefined model cost dictionary. If
|
||||
found, it logs the maximum token count for that model and returns it. If the model name
|
||||
is not recognized, it logs an informational message and returns None.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- model_name (str): Name of LLM or embedding model
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
Number of max tokens of model, or None if model is unknown
|
||||
"""
|
||||
max_tokens = None
|
||||
|
||||
if model_name in litellm.model_cost:
|
||||
max_tokens = litellm.model_cost[model_name]["max_tokens"]
|
||||
logger.debug(f"Max input tokens for {model_name}: {max_tokens}")
|
||||
else:
|
||||
logger.info("Model not found in LiteLLM's model_cost.")
|
||||
|
||||
return max_tokens
|
||||
|
||||
|
||||
async def test_llm_connection():
|
||||
"""
|
||||
Establish a connection to the LLM and create a structured output.
|
||||
|
||||
Attempt to connect to the LLM client and uses the adapter to create a structured output
|
||||
with a predefined text input and system prompt. Log any exceptions encountered during
|
||||
the connection attempt and re-raise the exception for further handling.
|
||||
"""
|
||||
try:
|
||||
llm_adapter = get_llm_client()
|
||||
await llm_adapter.acreate_structured_output(
|
||||
text_input="test",
|
||||
system_prompt='Respond to me with the following string: "test"',
|
||||
response_model=str,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error("Connection to LLM could not be established.")
|
||||
raise e
|
||||
|
||||
|
||||
async def test_embedding_connection():
|
||||
"""
|
||||
Test the connection to the embedding engine by embedding a sample text.
|
||||
|
||||
Handles exceptions that may occur during the operation, logs the error, and re-raises
|
||||
the exception if the connection to the embedding handler cannot be established.
|
||||
"""
|
||||
try:
|
||||
# NOTE: Vector engine import must be done in function to avoid circular import issue
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
await get_vector_engine().embedding_engine.embed_text("test")
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error("Connection to Embedding handler could not be established.")
|
||||
raise e
|
||||
20
poetry.lock
generated
20
poetry.lock
generated
|
|
@ -616,6 +616,24 @@ files = [
|
|||
[package.extras]
|
||||
extras = ["regex"]
|
||||
|
||||
[[package]]
|
||||
name = "baml-py"
|
||||
version = "0.201.0"
|
||||
description = "BAML python bindings (pyproject.toml)"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "baml_py-0.201.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:83228d2af2b0e845bbbb4e14f7cbd3376cec385aee01210ac522ab6076e07bec"},
|
||||
{file = "baml_py-0.201.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:2a9d016139e3ae5b5ce98c7b05b5fbd53d5d38f04dc810ec4d70fb17dd6c10e4"},
|
||||
{file = "baml_py-0.201.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b5058505b1a3c5f04fc1679aec4d730fa9bef2cbd96209b3ed50152f60b96baf"},
|
||||
{file = "baml_py-0.201.0-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:36289d548581ba4accd5eaaab3246872542dd32dc6717e537654fa0cad884071"},
|
||||
{file = "baml_py-0.201.0-cp38-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:5ab70e7bd6481d71edca8a33313347b29faccec78b9960138aa437522813ac9a"},
|
||||
{file = "baml_py-0.201.0-cp38-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:7efc5c693a7142c230a4f3d6700415127fee0b9f5fdbb36db63e04e27ac4c0f1"},
|
||||
{file = "baml_py-0.201.0-cp38-abi3-win_amd64.whl", hash = "sha256:56499857b7a27ae61a661c8ce0dddd0fb567a45c0b826157e44048a14cf586f9"},
|
||||
{file = "baml_py-0.201.0-cp38-abi3-win_arm64.whl", hash = "sha256:1e52dc1151db84a302b746590fe2bc484bdd794f83fa5da7216d9394c559f33a"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "banks"
|
||||
version = "2.2.0"
|
||||
|
|
@ -12235,4 +12253,4 @@ qdrant = ["qdrant-client"]
|
|||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<=3.13"
|
||||
content-hash = "d15e6b5d065016613be0b8c015cccf85e7f63891353c97636d136d14e5c8f62e"
|
||||
content-hash = "111afa2da80fc4baf4cfd72c8f8c782aec2cd0e4435b3e4a3b77e2536eaa37aa"
|
||||
|
|
|
|||
17
uv.lock
generated
17
uv.lock
generated
|
|
@ -443,6 +443,21 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/41/ff/392bff89415399a979be4a65357a41d92729ae8580a66073d8ec8d810f98/backrefs-5.9-py39-none-any.whl", hash = "sha256:f48ee18f6252b8f5777a22a00a09a85de0ca931658f1dd96d4406a34f3748c60", size = 380265, upload-time = "2025-06-22T19:34:12.405Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "baml-py"
|
||||
version = "0.201.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/54/54/2b0edb3d22e95ce56f36610391c11108a4ef26ba2837736a32001687ae34/baml_py-0.201.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:83228d2af2b0e845bbbb4e14f7cbd3376cec385aee01210ac522ab6076e07bec", size = 17387971, upload-time = "2025-07-03T19:29:05.844Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c9/08/1d48c28c63eadea2c04360cbb7f64968599e99cd6b8fc0ec0bd4424d3cf1/baml_py-0.201.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:2a9d016139e3ae5b5ce98c7b05b5fbd53d5d38f04dc810ec4d70fb17dd6c10e4", size = 16191010, upload-time = "2025-07-03T19:29:09.323Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/73/1a/20b2d46501e3dd0648af339825106a6ac5eeb5d22d7e6a10cf16b9aa1cb8/baml_py-0.201.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b5058505b1a3c5f04fc1679aec4d730fa9bef2cbd96209b3ed50152f60b96baf", size = 19950249, upload-time = "2025-07-03T19:29:11.974Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/38/24/bc871059e905159ae1913c2e3032dd6ef2f5c3d0983999d2c2f1eebb65a4/baml_py-0.201.0-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:36289d548581ba4accd5eaaab3246872542dd32dc6717e537654fa0cad884071", size = 19231310, upload-time = "2025-07-03T19:29:14.857Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0e/11/4268a0b82b02c7202fe5aa0d7175712158d998c491cac723b2bac3d5d495/baml_py-0.201.0-cp38-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:5ab70e7bd6481d71edca8a33313347b29faccec78b9960138aa437522813ac9a", size = 19490012, upload-time = "2025-07-03T19:29:18.512Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/31/21/c9f9aea1adba2a5978ffab11ba0948a9f3f81ec6ed3056067713260e93a1/baml_py-0.201.0-cp38-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:7efc5c693a7142c230a4f3d6700415127fee0b9f5fdbb36db63e04e27ac4c0f1", size = 20090620, upload-time = "2025-07-03T19:29:21.072Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/99/cf/92123d8d753f1d1473e080c4c182139bfe3b9a6418e891cf1d96b6c33848/baml_py-0.201.0-cp38-abi3-win_amd64.whl", hash = "sha256:56499857b7a27ae61a661c8ce0dddd0fb567a45c0b826157e44048a14cf586f9", size = 17253005, upload-time = "2025-07-03T19:29:23.722Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/59/88/5056aa1bc9480f758cd6e210d63bd1f9ad90b44c87f4121285906526495e/baml_py-0.201.0-cp38-abi3-win_arm64.whl", hash = "sha256:1e52dc1151db84a302b746590fe2bc484bdd794f83fa5da7216d9394c559f33a", size = 15612701, upload-time = "2025-07-03T19:29:26.712Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "banks"
|
||||
version = "2.2.0"
|
||||
|
|
@ -864,6 +879,7 @@ dependencies = [
|
|||
{ name = "aiohttp" },
|
||||
{ name = "aiosqlite" },
|
||||
{ name = "alembic" },
|
||||
{ name = "baml-py" },
|
||||
{ name = "dlt", extra = ["sqlalchemy"] },
|
||||
{ name = "fastapi" },
|
||||
{ name = "fastapi-users", extra = ["sqlalchemy"] },
|
||||
|
|
@ -1020,6 +1036,7 @@ requires-dist = [
|
|||
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.26.1,<0.27" },
|
||||
{ name = "asyncpg", marker = "extra == 'postgres'", specifier = ">=0.30.0,<1.0.0" },
|
||||
{ name = "asyncpg", marker = "extra == 'postgres-binary'", specifier = ">=0.30.0,<1.0.0" },
|
||||
{ name = "baml-py", specifier = ">=0.201.0,<0.202.0" },
|
||||
{ name = "chromadb", marker = "extra == 'chromadb'", specifier = ">=0.3.0,<0.7" },
|
||||
{ name = "coverage", marker = "extra == 'dev'", specifier = ">=7.3.2,<8" },
|
||||
{ name = "debugpy", marker = "extra == 'debug'", specifier = ">=1.8.9,<2.0.0" },
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue