Added docs
This commit is contained in:
parent
d903eacdc7
commit
a86978fb15
6 changed files with 858 additions and 297 deletions
|
|
@ -43,6 +43,10 @@ class Config:
|
||||||
graph_filename = os.getenv("GRAPH_NAME", "cognee_graph.pkl")
|
graph_filename = os.getenv("GRAPH_NAME", "cognee_graph.pkl")
|
||||||
|
|
||||||
# Model parameters
|
# Model parameters
|
||||||
|
llm_provider: str = "ollama"
|
||||||
|
ollama_endpoint: str = "http://localhost:11434/v1"
|
||||||
|
ollama_key: Optional[str] = os.getenv("OLLAMA_API_KEY")
|
||||||
|
ollama_model: str = "gpt-3.5-turbo"
|
||||||
model: str = "gpt-4-0125-preview"
|
model: str = "gpt-4-0125-preview"
|
||||||
# model: str = "gpt-3.5-turbo"
|
# model: str = "gpt-3.5-turbo"
|
||||||
model_endpoint: str = "openai"
|
model_endpoint: str = "openai"
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
from cognee.config import Config
|
from cognee.config import Config
|
||||||
from .databases.relational import SqliteEngine, DatabaseEngine
|
from .databases.relational import SqliteEngine, DatabaseEngine
|
||||||
from .databases.vector import WeaviateAdapter, VectorDBInterface
|
from .databases.vector import WeaviateAdapter, VectorDBInterface
|
||||||
|
from .llm.llm_interface import LLMInterface
|
||||||
|
from .llm.openai.adapter import OpenAIAdapter
|
||||||
|
|
||||||
config = Config()
|
config = Config()
|
||||||
config.load()
|
config.load()
|
||||||
|
|
@ -8,11 +10,15 @@ config.load()
|
||||||
class InfrastructureConfig():
|
class InfrastructureConfig():
|
||||||
database_engine: DatabaseEngine = None
|
database_engine: DatabaseEngine = None
|
||||||
vector_engine: VectorDBInterface = None
|
vector_engine: VectorDBInterface = None
|
||||||
|
llm_engine: LLMInterface = None
|
||||||
|
|
||||||
def get_config(self) -> dict:
|
def get_config(self) -> dict:
|
||||||
if self.database_engine is None:
|
if self.database_engine is None:
|
||||||
self.database_engine = SqliteEngine(config.db_path, config.db_name)
|
self.database_engine = SqliteEngine(config.db_path, config.db_name)
|
||||||
|
|
||||||
|
if self.llm_engine is None:
|
||||||
|
self.llm_engine = OpenAIAdapter(config.openai_key, config.model)
|
||||||
|
|
||||||
if self.vector_engine is None:
|
if self.vector_engine is None:
|
||||||
self.vector_engine = WeaviateAdapter(
|
self.vector_engine = WeaviateAdapter(
|
||||||
config.weaviate_url,
|
config.weaviate_url,
|
||||||
|
|
@ -28,5 +34,6 @@ class InfrastructureConfig():
|
||||||
def set_config(self, new_config: dict):
    """Install the engine handles supplied in *new_config*.

    Expects the keys ``database_engine``, ``vector_engine`` and
    ``llm_engine``; a missing key raises ``KeyError``.
    """
    for engine_name in ("database_engine", "vector_engine", "llm_engine"):
        setattr(self, engine_name, new_config[engine_name])
|
||||||
|
|
||||||
infrastructure_config = InfrastructureConfig()
|
infrastructure_config = InfrastructureConfig()
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,27 @@
|
||||||
"""Get the LLM client."""
|
"""Get the LLM client."""
|
||||||
|
from enum import Enum
|
||||||
from cognee.config import Config
|
from cognee.config import Config
|
||||||
from .openai.adapter import OpenAIAdapter
|
from .openai.adapter import OpenAIAdapter
|
||||||
|
from .ollama.adapter import OllamaAPIAdapter
|
||||||
|
|
||||||
|
class LLMProvider(Enum):
    """Closed set of LLM backends selectable via ``config.llm_provider``."""

    OPENAI = "openai"
    OLLAMA = "ollama"
|
||||||
|
|
||||||
config = Config()
|
config = Config()
|
||||||
config.load()
|
config.load()
|
||||||
|
|
||||||
def get_llm_client():
    """Return the adapter instance for the configured LLM provider.

    The configured ``config.llm_provider`` string is resolved through the
    ``LLMProvider`` enum; an unknown value raises ``ValueError``.
    """
    provider = LLMProvider(config.llm_provider)

    if provider is LLMProvider.OPENAI:
        return OpenAIAdapter(config.openai_key, config.model)

    if provider is LLMProvider.OLLAMA:
        return OllamaAPIAdapter(config.ollama_endpoint, config.ollama_key, config.ollama_model)

    raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||||
|
|
||||||
|
# Usage example
# NOTE(review): this runs at import time and eagerly constructs an adapter
# from the already-loaded config — importing this module therefore has side
# effects (and fails at import if the provider config is invalid). Confirm
# this module-level instantiation is intentional.
llm_client = get_llm_client()
|
||||||
|
|
|
||||||
108
cognee/infrastructure/llm/ollama/adapter.py
Normal file
108
cognee/infrastructure/llm/ollama/adapter.py
Normal file
|
|
@ -0,0 +1,108 @@
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from typing import List, Type
|
||||||
|
from pydantic import BaseModel
|
||||||
|
import instructor
|
||||||
|
from tenacity import retry, stop_after_attempt
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
|
import openai
|
||||||
|
from cognee.infrastructure.llm.llm_interface import LLMInterface
|
||||||
|
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaAPIAdapter(LLMInterface):
    """Adapter for Ollama's API.

    Talks to an Ollama server through its OpenAI-compatible endpoint:
    ``self.aclient`` is an ``AsyncOpenAI`` client pointed at
    ``ollama_endpoint`` and patched by ``instructor`` in JSON mode so chat
    completions can be parsed into pydantic ``response_model`` instances.
    """

    def __init__(self, ollama_endpoint, api_key: str, model: str):
        # instructor.patch adds `response_model=` support to the client;
        # Mode.JSON requests plain-JSON output instead of function calling.
        self.aclient = instructor.patch(
            AsyncOpenAI(
                base_url=ollama_endpoint,
                api_key=api_key,  # required, but unused
            ),
            mode=instructor.Mode.JSON,
        )
        # Model name forwarded on every structured-output call below.
        self.model = model

    @retry(stop=stop_after_attempt(5))
    def completions_with_backoff(self, **kwargs):
        """Wrapper around ChatCompletion.create w/ backoff (5 attempts).

        NOTE(review): this calls the module-level ``openai`` client rather
        than ``self.aclient``, so it does NOT target the Ollama endpoint
        configured in ``__init__`` — confirm this is intentional.
        """
        # Local model
        return openai.chat.completions.create(**kwargs)

    @retry(stop=stop_after_attempt(5))
    async def acompletions_with_backoff(self, **kwargs):
        """Wrapper around ChatCompletion.acreate w/ backoff (5 attempts).

        NOTE(review): ``openai.chat.completions`` exposes no ``acreate`` in
        the 1.x SDK this project pins (async calls go through
        ``AsyncOpenAI``); this line likely raises ``AttributeError`` — verify.
        """
        return await openai.chat.completions.acreate(**kwargs)

    @retry(stop=stop_after_attempt(5))
    async def acreate_embedding_with_backoff(self, input: List[str], model: str = "text-embedding-3-large"):
        """Wrapper around Embedding.acreate w/ backoff (5 attempts).

        NOTE(review): the default is an OpenAI embedding model name; an
        Ollama server presumably needs an Ollama model here — confirm.
        """

        return await self.aclient.embeddings.create(input=input, model=model)

    async def async_get_embedding_with_backoff(self, text, model="text-embedding-3-large"):
        """To get text embeddings, import/call this function.

        It specifies defaults + handles rate-limiting + is async.
        """
        # Flatten newlines before embedding — presumably input normalization
        # for the embeddings API; TODO confirm it matters for Ollama.
        text = text.replace("\n", " ")
        response = await self.aclient.embeddings.create(input=text, model=model)
        # The API returns a batch; the single vector is the first entry.
        embedding = response.data[0].embedding
        return embedding

    @retry(stop=stop_after_attempt(5))
    def create_embedding_with_backoff(self, **kwargs):
        """Wrapper around Embedding.create w/ backoff (5 attempts).

        NOTE(review): uses the module-level ``openai`` client, not
        ``self.aclient`` — so it does not hit the Ollama endpoint. Verify.
        """
        return openai.embeddings.create(**kwargs)

    def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-3-large"):
        """To get text embeddings, import/call this function.

        It specifies defaults + handles rate-limiting.

        :param text: str — text to embed (newlines are flattened first)
        :param model: str — embedding model name
        """
        text = text.replace("\n", " ")
        # Wrap in a list: the sync wrapper sends a one-element batch.
        response = self.create_embedding_with_backoff(input=[text], model=model)
        embedding = response.data[0].embedding
        return embedding

    async def async_get_batch_embeddings_with_backoff(self, texts: List[str], models: List[str]):
        """To get multiple text embeddings in parallel, import/call this function.

        It specifies defaults + handles rate-limiting + is async.
        ``texts`` and ``models`` are paired positionally; ``zip`` silently
        drops trailing items of the longer list.
        """
        # Collect all coroutines (lazy generator — nothing runs yet)
        coroutines = (self.async_get_embedding_with_backoff(text, model)
                      for text, model in zip(texts, models))

        # Run the coroutines in parallel and gather the results
        embeddings = await asyncio.gather(*coroutines)

        return embeddings

    @retry(stop=stop_after_attempt(5))
    async def acreate_structured_output(self, text_input: str, system_prompt: str,
                                        response_model: Type[BaseModel]) -> BaseModel:
        """Generate a response from a user query, parsed into *response_model*.

        The instructor-patched client is expected to coerce the model's JSON
        output into the pydantic ``response_model``; the whole call is
        retried up to 5 attempts.
        """
        return await self.aclient.chat.completions.create(
            model=self.model,
            messages=[
                {
                    "role": "user",
                    "content": f"""Use the given format to
                    extract information from the following input: {text_input}. """,
                },
                # NOTE(review): the system message is placed AFTER the user
                # message — unusual ordering; confirm it is deliberate.
                {"role": "system", "content": system_prompt},
            ],
            response_model=response_model,
        )

    def show_prompt(self, text_input: str, system_prompt: str) -> str:
        """Format and display the prompt for a user query.

        :param text_input: user input; a placeholder is substituted when empty
        :param system_prompt: path of the system prompt to load
        :raises ValueError: when no system prompt path is given
        """
        if not text_input:
            text_input = "No user input provided."
        if not system_prompt:
            raise ValueError("No system prompt path provided.")
        # Resolve the path into the actual prompt text.
        system_prompt = read_query_prompt(system_prompt)

        # If read_query_prompt returned something falsy, None is returned
        # despite the declared `-> str` — callers should handle that.
        formatted_prompt = f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n""" if system_prompt else None
        return formatted_prompt
|
||||||
1008
poetry.lock
generated
1008
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -18,7 +18,7 @@ classifiers = [
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = "~3.10"
|
python = "~3.10"
|
||||||
openai = "1.12.0"
|
openai = "1.14.3"
|
||||||
python-dotenv = "1.0.1"
|
python-dotenv = "1.0.1"
|
||||||
fastapi = "^0.109.2"
|
fastapi = "^0.109.2"
|
||||||
uvicorn = "0.22.0"
|
uvicorn = "0.22.0"
|
||||||
|
|
@ -26,7 +26,7 @@ boto3 = "^1.26.125"
|
||||||
gunicorn = "^20.1.0"
|
gunicorn = "^20.1.0"
|
||||||
sqlalchemy = "^2.0.21"
|
sqlalchemy = "^2.0.21"
|
||||||
asyncpg = "^0.28.0"
|
asyncpg = "^0.28.0"
|
||||||
instructor = "^0.3.4"
|
instructor = "^0.6.6"
|
||||||
networkx = "^3.2.1"
|
networkx = "^3.2.1"
|
||||||
graphviz = "^0.20.1"
|
graphviz = "^0.20.1"
|
||||||
langdetect = "^1.0.9"
|
langdetect = "^1.0.9"
|
||||||
|
|
@ -46,11 +46,9 @@ dlt = "^0.4.6"
|
||||||
duckdb = {version = "^0.10.0", extras = ["dlt"]}
|
duckdb = {version = "^0.10.0", extras = ["dlt"]}
|
||||||
overrides = "^7.7.0"
|
overrides = "^7.7.0"
|
||||||
aiofiles = "^23.2.1"
|
aiofiles = "^23.2.1"
|
||||||
qdrant-client = "^1.8.0"
|
|
||||||
duckdb-engine = "^0.11.2"
|
duckdb-engine = "^0.11.2"
|
||||||
graphistry = "^0.33.5"
|
graphistry = "^0.33.5"
|
||||||
tenacity = "^8.2.3"
|
tenacity = "^8.2.3"
|
||||||
weaviate-client = "^4.5.4"
|
|
||||||
|
|
||||||
[tool.poetry.extras]
|
[tool.poetry.extras]
|
||||||
dbt = ["dbt-core", "dbt-redshift", "dbt-bigquery", "dbt-duckdb", "dbt-snowflake", "dbt-athena-community", "dbt-databricks"]
|
dbt = ["dbt-core", "dbt-redshift", "dbt-bigquery", "dbt-duckdb", "dbt-snowflake", "dbt-athena-community", "dbt-databricks"]
|
||||||
|
|
@ -61,6 +59,7 @@ postgres = ["psycopg2-binary", "psycopg2cffi"]
|
||||||
redshift = ["psycopg2-binary", "psycopg2cffi"]
|
redshift = ["psycopg2-binary", "psycopg2cffi"]
|
||||||
parquet = ["pyarrow"]
|
parquet = ["pyarrow"]
|
||||||
duckdb = ["duckdb"]
|
duckdb = ["duckdb"]
|
||||||
|
qdrant = ["qdrant-client"]
|
||||||
filesystem = ["s3fs", "botocore"]
|
filesystem = ["s3fs", "botocore"]
|
||||||
s3 = ["s3fs", "botocore"]
|
s3 = ["s3fs", "botocore"]
|
||||||
gs = ["gcsfs"]
|
gs = ["gcsfs"]
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue