Added docs

This commit is contained in:
Vasilije 2024-03-25 22:11:25 +01:00
parent d903eacdc7
commit a86978fb15
6 changed files with 858 additions and 297 deletions

View file

@@ -43,6 +43,10 @@ class Config:
graph_filename = os.getenv("GRAPH_NAME", "cognee_graph.pkl") graph_filename = os.getenv("GRAPH_NAME", "cognee_graph.pkl")
# Model parameters # Model parameters
llm_provider: str = "ollama"
ollama_endpoint: str = "http://localhost:11434/v1"
ollama_key: Optional[str] = os.getenv("OLLAMA_API_KEY")
ollama_model: str = "gpt-3.5-turbo"
model: str = "gpt-4-0125-preview" model: str = "gpt-4-0125-preview"
# model: str = "gpt-3.5-turbo" # model: str = "gpt-3.5-turbo"
model_endpoint: str = "openai" model_endpoint: str = "openai"

View file

@@ -1,6 +1,8 @@
from cognee.config import Config from cognee.config import Config
from .databases.relational import SqliteEngine, DatabaseEngine from .databases.relational import SqliteEngine, DatabaseEngine
from .databases.vector import WeaviateAdapter, VectorDBInterface from .databases.vector import WeaviateAdapter, VectorDBInterface
from .llm.llm_interface import LLMInterface
from .llm.openai.adapter import OpenAIAdapter
config = Config() config = Config()
config.load() config.load()
@@ -8,11 +10,15 @@ config.load()
class InfrastructureConfig(): class InfrastructureConfig():
database_engine: DatabaseEngine = None database_engine: DatabaseEngine = None
vector_engine: VectorDBInterface = None vector_engine: VectorDBInterface = None
llm_engine: LLMInterface = None
def get_config(self) -> dict: def get_config(self) -> dict:
if self.database_engine is None: if self.database_engine is None:
self.database_engine = SqliteEngine(config.db_path, config.db_name) self.database_engine = SqliteEngine(config.db_path, config.db_name)
if self.llm_engine is None:
self.llm_engine = OpenAIAdapter(config.openai_key, config.model)
if self.vector_engine is None: if self.vector_engine is None:
self.vector_engine = WeaviateAdapter( self.vector_engine = WeaviateAdapter(
config.weaviate_url, config.weaviate_url,
@@ -28,5 +34,6 @@ class InfrastructureConfig():
def set_config(self, new_config: dict): def set_config(self, new_config: dict):
self.database_engine = new_config["database_engine"] self.database_engine = new_config["database_engine"]
self.vector_engine = new_config["vector_engine"] self.vector_engine = new_config["vector_engine"]
self.llm_engine = new_config["llm_engine"]
infrastructure_config = InfrastructureConfig() infrastructure_config = InfrastructureConfig()

View file

@@ -1,10 +1,27 @@
"""Get the LLM client.""" """Get the LLM client."""
from enum import Enum
from cognee.config import Config from cognee.config import Config
from .openai.adapter import OpenAIAdapter from .openai.adapter import OpenAIAdapter
from .ollama.adapter import OllamaAPIAdapter
# Define an Enum for LLM Providers
class LLMProvider(Enum):
    """Closed set of supported LLM backends.

    Each member's value is the string expected in ``config.llm_provider``,
    so ``LLMProvider(config.llm_provider)`` maps config to a member.
    """
    OPENAI = "openai"
    OLLAMA = "ollama"
config = Config() config = Config()
config.load() config.load()
def get_llm_client():
    """Get the LLM client based on the configuration using Enums.

    Selects the adapter named by ``config.llm_provider``.

    Returns:
        An ``OpenAIAdapter`` or ``OllamaAPIAdapter`` instance.

    Raises:
        ValueError: if ``config.llm_provider`` is not a known provider
            (raised by the ``LLMProvider`` conversion itself for unknown
            strings, so the final raise below is defensive).
    """
    provider = LLMProvider(config.llm_provider)

    # Guard-clause dispatch: return as soon as the provider matches.
    if provider == LLMProvider.OPENAI:
        return OpenAIAdapter(config.openai_key, config.model)
    if provider == LLMProvider.OLLAMA:
        return OllamaAPIAdapter(config.ollama_endpoint, config.ollama_key, config.ollama_model)

    raise ValueError(f"Unsupported LLM provider: {provider}")
# Usage example
# NOTE(review): this runs at import time and eagerly constructs an LLM
# client (network-capable object). Confirm importers actually rely on the
# module-level ``llm_client`` name before removing it.
llm_client = get_llm_client()

View file

@@ -0,0 +1,108 @@
import asyncio
import aiohttp
from typing import List, Type
from pydantic import BaseModel
import instructor
from tenacity import retry, stop_after_attempt
from openai import AsyncOpenAI
import openai
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt
class OllamaAPIAdapter(LLMInterface):
    """Adapter for Ollama's OpenAI-compatible API.

    Talks to an Ollama server through the ``openai`` client library by
    pointing ``AsyncOpenAI`` at ``ollama_endpoint``. The client is patched
    with ``instructor`` (JSON mode) so chat completions can return validated
    Pydantic models via ``response_model``.
    """

    def __init__(self, ollama_endpoint: str, api_key: str, model: str):
        """Create the adapter.

        :param ollama_endpoint: Base URL of the Ollama OpenAI-compatible API.
        :param api_key: API key; required by the client library but unused
            by Ollama itself.
        :param model: Model name used for structured chat completions.
        """
        self.aclient = instructor.patch(
            AsyncOpenAI(
                base_url=ollama_endpoint,
                api_key=api_key,  # required, but unused
            ),
            mode=instructor.Mode.JSON,
        )
        self.model = model

    @retry(stop=stop_after_attempt(5))
    def completions_with_backoff(self, **kwargs):
        """Wrapper around ChatCompletion.create w/ backoff.

        NOTE(review): this uses the module-level ``openai`` client, which
        targets the default OpenAI endpoint — NOT the configured
        ``ollama_endpoint``. Confirm this is intentional.
        """
        # Local model
        return openai.chat.completions.create(**kwargs)

    @retry(stop=stop_after_attempt(5))
    async def acompletions_with_backoff(self, **kwargs):
        """Async wrapper around chat.completions.create w/ backoff.

        Bugfix: the openai>=1.0 module API has no ``acreate`` — async calls
        must go through an ``AsyncOpenAI`` client. Route the request through
        ``self.aclient``, which also honors the configured Ollama endpoint.
        """
        return await self.aclient.chat.completions.create(**kwargs)

    @retry(stop=stop_after_attempt(5))
    async def acreate_embedding_with_backoff(self, input: List[str], model: str = "text-embedding-3-large"):
        """Wrapper around Embedding.acreate w/ backoff.

        :param input: Texts to embed.
        :param model: Embedding model name.
        """
        return await self.aclient.embeddings.create(input=input, model=model)

    async def async_get_embedding_with_backoff(self, text, model="text-embedding-3-large"):
        """Return the embedding vector for a single text (async).

        It specifies defaults + handles rate-limiting + is async.
        Newlines are flattened to spaces before embedding.
        """
        text = text.replace("\n", " ")
        response = await self.aclient.embeddings.create(input=text, model=model)
        embedding = response.data[0].embedding
        return embedding

    @retry(stop=stop_after_attempt(5))
    def create_embedding_with_backoff(self, **kwargs):
        """Wrapper around Embedding.create w/ backoff.

        NOTE(review): uses the module-level ``openai`` client (OpenAI
        endpoint), not ``self.aclient`` — confirm this is intentional.
        """
        return openai.embeddings.create(**kwargs)

    def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-3-large"):
        """Return the embedding vector for a single text (sync).

        It specifies defaults + handles rate-limiting.

        :param text: str
        :param model: str
        """
        text = text.replace("\n", " ")
        response = self.create_embedding_with_backoff(input=[text], model=model)
        embedding = response.data[0].embedding
        return embedding

    async def async_get_batch_embeddings_with_backoff(self, texts: List[str], models: List[str]):
        """Embed multiple texts in parallel; one model name per text.

        It specifies defaults + handles rate-limiting + is async.
        """
        # Collect all coroutines
        coroutines = (self.async_get_embedding_with_backoff(text, model)
                      for text, model in zip(texts, models))

        # Run the coroutines in parallel and gather the results
        embeddings = await asyncio.gather(*coroutines)
        return embeddings

    @retry(stop=stop_after_attempt(5))
    async def acreate_structured_output(self, text_input: str, system_prompt: str,
                                        response_model: Type[BaseModel]) -> BaseModel:
        """Generate a structured (Pydantic-validated) response from a user query.

        NOTE(review): the user message precedes the system message here —
        confirm that ordering is intentional for the target models.
        """
        return await self.aclient.chat.completions.create(
            model=self.model,
            messages=[
                {
                    "role": "user",
                    "content": f"""Use the given format to
                extract information from the following input: {text_input}. """,
                },
                {"role": "system", "content": system_prompt},
            ],
            response_model=response_model,
        )

    def show_prompt(self, text_input: str, system_prompt: str) -> str:
        """Format and display the prompt for a user query.

        :param text_input: User input; a placeholder is substituted if empty.
        :param system_prompt: Path to the system prompt file (resolved via
            ``read_query_prompt``).
        :raises ValueError: if no system prompt path is provided.
        """
        if not text_input:
            text_input = "No user input provided."
        if not system_prompt:
            raise ValueError("No system prompt path provided.")
        system_prompt = read_query_prompt(system_prompt)

        # Returns None if the prompt file resolved to empty content.
        formatted_prompt = f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n""" if system_prompt else None
        return formatted_prompt

1008
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@@ -18,7 +18,7 @@ classifiers = [
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "~3.10" python = "~3.10"
openai = "1.12.0" openai = "1.14.3"
python-dotenv = "1.0.1" python-dotenv = "1.0.1"
fastapi = "^0.109.2" fastapi = "^0.109.2"
uvicorn = "0.22.0" uvicorn = "0.22.0"
@@ -26,7 +26,7 @@ boto3 = "^1.26.125"
gunicorn = "^20.1.0" gunicorn = "^20.1.0"
sqlalchemy = "^2.0.21" sqlalchemy = "^2.0.21"
asyncpg = "^0.28.0" asyncpg = "^0.28.0"
instructor = "^0.3.4" instructor = "^0.6.6"
networkx = "^3.2.1" networkx = "^3.2.1"
graphviz = "^0.20.1" graphviz = "^0.20.1"
langdetect = "^1.0.9" langdetect = "^1.0.9"
@@ -46,11 +46,9 @@ dlt = "^0.4.6"
duckdb = {version = "^0.10.0", extras = ["dlt"]} duckdb = {version = "^0.10.0", extras = ["dlt"]}
overrides = "^7.7.0" overrides = "^7.7.0"
aiofiles = "^23.2.1" aiofiles = "^23.2.1"
qdrant-client = "^1.8.0"
duckdb-engine = "^0.11.2" duckdb-engine = "^0.11.2"
graphistry = "^0.33.5" graphistry = "^0.33.5"
tenacity = "^8.2.3" tenacity = "^8.2.3"
weaviate-client = "^4.5.4"
[tool.poetry.extras] [tool.poetry.extras]
dbt = ["dbt-core", "dbt-redshift", "dbt-bigquery", "dbt-duckdb", "dbt-snowflake", "dbt-athena-community", "dbt-databricks"] dbt = ["dbt-core", "dbt-redshift", "dbt-bigquery", "dbt-duckdb", "dbt-snowflake", "dbt-athena-community", "dbt-databricks"]
@@ -61,6 +59,7 @@ postgres = ["psycopg2-binary", "psycopg2cffi"]
redshift = ["psycopg2-binary", "psycopg2cffi"] redshift = ["psycopg2-binary", "psycopg2cffi"]
parquet = ["pyarrow"] parquet = ["pyarrow"]
duckdb = ["duckdb"] duckdb = ["duckdb"]
qdrant = ["qdrant-client"]
filesystem = ["s3fs", "botocore"] filesystem = ["s3fs", "botocore"]
s3 = ["s3fs", "botocore"] s3 = ["s3fs", "botocore"]
gs = ["gcsfs"] gs = ["gcsfs"]