Added docs

Vasilije 2024-03-25 22:11:25 +01:00
parent d903eacdc7
commit a86978fb15
6 changed files with 858 additions and 297 deletions

View file

@@ -43,6 +43,10 @@ class Config:
     graph_filename = os.getenv("GRAPH_NAME", "cognee_graph.pkl")

     # Model parameters
+    llm_provider: str = "ollama"
+    ollama_endpoint: str = "http://localhost:11434/v1"
+    ollama_key: Optional[str] = os.getenv("OLLAMA_API_KEY")
+    ollama_model: str = "gpt-3.5-turbo"
     model: str = "gpt-4-0125-preview"
     # model: str = "gpt-3.5-turbo"
     model_endpoint: str = "openai"

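The four new fields above default to a local Ollama endpoint, and only the API key is read from the environment. A minimal sketch of reading them back, assuming the attribute names match the Config shown here (the example itself is not part of the commit):

import os

# OLLAMA_API_KEY is evaluated when cognee.config is imported, so set it first.
# A local Ollama server ignores the key, but the OpenAI-compatible client expects one.
os.environ.setdefault("OLLAMA_API_KEY", "ollama")

from cognee.config import Config

config = Config()
config.load()
print(config.llm_provider, config.ollama_endpoint, config.ollama_model)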
View file

@@ -1,6 +1,8 @@
 from cognee.config import Config
 from .databases.relational import SqliteEngine, DatabaseEngine
 from .databases.vector import WeaviateAdapter, VectorDBInterface
+from .llm.llm_interface import LLMInterface
+from .llm.openai.adapter import OpenAIAdapter

 config = Config()
 config.load()
@@ -8,11 +10,15 @@ config.load()
 class InfrastructureConfig():
     database_engine: DatabaseEngine = None
     vector_engine: VectorDBInterface = None
+    llm_engine: LLMInterface = None

     def get_config(self) -> dict:
         if self.database_engine is None:
             self.database_engine = SqliteEngine(config.db_path, config.db_name)

+        if self.llm_engine is None:
+            self.llm_engine = OpenAIAdapter(config.openai_key, config.model)
+
         if self.vector_engine is None:
             self.vector_engine = WeaviateAdapter(
                 config.weaviate_url,
@@ -28,5 +34,6 @@ class InfrastructureConfig():
     def set_config(self, new_config: dict):
         self.database_engine = new_config["database_engine"]
         self.vector_engine = new_config["vector_engine"]
+        self.llm_engine = new_config["llm_engine"]

 infrastructure_config = InfrastructureConfig()

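A short sketch of pulling the lazily built engines out of the shared config object; the import path and the keys returned by get_config() are assumptions mirroring set_config() above:

from cognee.infrastructure import infrastructure_config  # import path assumed

engines = infrastructure_config.get_config()
llm = engines["llm_engine"]            # OpenAIAdapter, built on first access
vector_db = engines["vector_engine"]   # WeaviateAdapter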
View file

@@ -1,10 +1,27 @@
 """Get the LLM client."""
+from enum import Enum
 from cognee.config import Config
 from .openai.adapter import OpenAIAdapter
+from .ollama.adapter import OllamaAPIAdapter

+# Define an Enum for LLM Providers
+class LLMProvider(Enum):
+    OPENAI = "openai"
+    OLLAMA = "ollama"
+
 config = Config()
 config.load()

 def get_llm_client():
-    """Get the LLM client."""
-    return OpenAIAdapter(config.openai_key, config.model)
+    """Get the LLM client based on the configuration using Enums."""
+    provider = LLMProvider(config.llm_provider)
+
+    if provider == LLMProvider.OPENAI:
+        return OpenAIAdapter(config.openai_key, config.model)
+    elif provider == LLMProvider.OLLAMA:
+        return OllamaAPIAdapter(config.ollama_endpoint, config.ollama_key, config.ollama_model)
+    else:
+        raise ValueError(f"Unsupported LLM provider: {provider}")
+
+# Usage example
+llm_client = get_llm_client()

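The Enum constructor does the string-to-provider mapping, so a misconfigured llm_provider fails before any client is built. A self-contained illustration of that lookup (mirroring the Enum above):

from enum import Enum

class LLMProvider(Enum):
    OPENAI = "openai"
    OLLAMA = "ollama"

print(LLMProvider("ollama"))   # LLMProvider.OLLAMA
LLMProvider("anthropic")       # raises ValueError: 'anthropic' is not a valid LLMProvider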
View file

@@ -0,0 +1,108 @@
import asyncio
import aiohttp
from typing import List, Type

from pydantic import BaseModel
import instructor
from tenacity import retry, stop_after_attempt
from openai import AsyncOpenAI
import openai

from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt


class OllamaAPIAdapter(LLMInterface):
    """Adapter for Ollama's API"""

    def __init__(self, ollama_endpoint, api_key: str, model: str):
        self.aclient = instructor.patch(
            AsyncOpenAI(
                base_url=ollama_endpoint,
                api_key=api_key,  # required, but unused
            ),
            mode=instructor.Mode.JSON,
        )
        self.model = model

    @retry(stop=stop_after_attempt(5))
    def completions_with_backoff(self, **kwargs):
        """Wrapper around ChatCompletion.create w/ backoff"""
        # Local model
        return openai.chat.completions.create(**kwargs)

    @retry(stop=stop_after_attempt(5))
    async def acompletions_with_backoff(self, **kwargs):
        """Wrapper around ChatCompletion.acreate w/ backoff"""
        return await openai.chat.completions.acreate(**kwargs)

    @retry(stop=stop_after_attempt(5))
    async def acreate_embedding_with_backoff(self, input: List[str], model: str = "text-embedding-3-large"):
        """Wrapper around Embedding.acreate w/ backoff"""
        return await self.aclient.embeddings.create(input=input, model=model)

    async def async_get_embedding_with_backoff(self, text, model="text-embedding-3-large"):
        """To get text embeddings, import/call this function
        It specifies defaults + handles rate-limiting + is async"""
        text = text.replace("\n", " ")
        response = await self.aclient.embeddings.create(input=text, model=model)
        embedding = response.data[0].embedding
        return embedding

    @retry(stop=stop_after_attempt(5))
    def create_embedding_with_backoff(self, **kwargs):
        """Wrapper around Embedding.create w/ backoff"""
        return openai.embeddings.create(**kwargs)

    def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-3-large"):
        """To get text embeddings, import/call this function
        It specifies defaults + handles rate-limiting
        :param text: str
        :param model: str
        """
        text = text.replace("\n", " ")
        response = self.create_embedding_with_backoff(input=[text], model=model)
        embedding = response.data[0].embedding
        return embedding

    async def async_get_batch_embeddings_with_backoff(self, texts: List[str], models: List[str]):
        """To get multiple text embeddings in parallel, import/call this function
        It specifies defaults + handles rate-limiting + is async"""
        # Collect all coroutines
        coroutines = (self.async_get_embedding_with_backoff(text, model)
                      for text, model in zip(texts, models))

        # Run the coroutines in parallel and gather the results
        embeddings = await asyncio.gather(*coroutines)
        return embeddings

    @retry(stop=stop_after_attempt(5))
    async def acreate_structured_output(self, text_input: str, system_prompt: str,
                                        response_model: Type[BaseModel]) -> BaseModel:
        """Generate a response from a user query."""
        return await self.aclient.chat.completions.create(
            model=self.model,
            messages=[
                {
                    "role": "user",
                    "content": f"""Use the given format to
                    extract information from the following input: {text_input}. """,
                },
                {"role": "system", "content": system_prompt},
            ],
            response_model=response_model,
        )

    def show_prompt(self, text_input: str, system_prompt: str) -> str:
        """Format and display the prompt for a user query."""
        if not text_input:
            text_input = "No user input provided."

        if not system_prompt:
            raise ValueError("No system prompt path provided.")

        system_prompt = read_query_prompt(system_prompt)

        formatted_prompt = f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n""" if system_prompt else None

        return formatted_prompt

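A hedged usage sketch for the new adapter; the endpoint, key, model name, prompt, and Summary schema below are placeholders rather than values taken from this commit:

import asyncio

from pydantic import BaseModel

from cognee.infrastructure.llm.ollama.adapter import OllamaAPIAdapter


class Summary(BaseModel):
    title: str
    bullet_points: list[str]


async def main():
    # Ollama's OpenAI-compatible endpoint; the key is required by the client but unused.
    adapter = OllamaAPIAdapter("http://localhost:11434/v1", "ollama", "llama2")

    summary = await adapter.acreate_structured_output(
        text_input="Cognee turns documents into a queryable knowledge graph.",
        system_prompt="Summarize the input as a title plus bullet points.",
        response_model=Summary,
    )
    print(summary)


asyncio.run(main())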
poetry.lock (generated, 1008 changed lines)

File diff suppressed because it is too large.

View file

@@ -18,7 +18,7 @@ classifiers = [
 [tool.poetry.dependencies]
 python = "~3.10"
-openai = "1.12.0"
+openai = "1.14.3"
 python-dotenv = "1.0.1"
 fastapi = "^0.109.2"
 uvicorn = "0.22.0"
@@ -26,7 +26,7 @@ boto3 = "^1.26.125"
 gunicorn = "^20.1.0"
 sqlalchemy = "^2.0.21"
 asyncpg = "^0.28.0"
-instructor = "^0.3.4"
+instructor = "^0.6.6"
 networkx = "^3.2.1"
 graphviz = "^0.20.1"
 langdetect = "^1.0.9"
@@ -46,11 +46,9 @@ dlt = "^0.4.6"
 duckdb = {version = "^0.10.0", extras = ["dlt"]}
 overrides = "^7.7.0"
 aiofiles = "^23.2.1"
-qdrant-client = "^1.8.0"
 duckdb-engine = "^0.11.2"
 graphistry = "^0.33.5"
 tenacity = "^8.2.3"
 weaviate-client = "^4.5.4"

 [tool.poetry.extras]
 dbt = ["dbt-core", "dbt-redshift", "dbt-bigquery", "dbt-duckdb", "dbt-snowflake", "dbt-athena-community", "dbt-databricks"]
@@ -61,6 +59,7 @@ postgres = ["psycopg2-binary", "psycopg2cffi"]
 redshift = ["psycopg2-binary", "psycopg2cffi"]
 parquet = ["pyarrow"]
 duckdb = ["duckdb"]
+qdrant = ["qdrant-client"]
 filesystem = ["s3fs", "botocore"]
 s3 = ["s3fs", "botocore"]
 gs = ["gcsfs"]