Added docs

parent d903eacdc7
commit a86978fb15

6 changed files with 858 additions and 297 deletions

@@ -43,6 +43,10 @@ class Config:
    graph_filename = os.getenv("GRAPH_NAME", "cognee_graph.pkl")

    # Model parameters
    llm_provider: str = "ollama"
    ollama_endpoint: str = "http://localhost:11434/v1"
    ollama_key: Optional[str] = os.getenv("OLLAMA_API_KEY")
    ollama_model: str = "gpt-3.5-turbo"
    model: str = "gpt-4-0125-preview"
    # model: str = "gpt-3.5-turbo"
    model_endpoint: str = "openai"

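The new attributes make the model provider switchable at runtime while keeping the existing OpenAI defaults. A minimal sketch of overriding them in code before any client is built (attribute names come from the diff above; the local model name is only an illustration):

from cognee.config import Config

config = Config()
config.load()

# Point the pipeline at a local Ollama server instead of OpenAI.
config.llm_provider = "ollama"
config.ollama_endpoint = "http://localhost:11434/v1"  # Ollama's OpenAI-compatible route
config.ollama_model = "llama2"  # hypothetical model pulled into Ollama
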
@@ -1,6 +1,8 @@
from cognee.config import Config
from .databases.relational import SqliteEngine, DatabaseEngine
from .databases.vector import WeaviateAdapter, VectorDBInterface
from .llm.llm_interface import LLMInterface
from .llm.openai.adapter import OpenAIAdapter

config = Config()
config.load()

@@ -8,11 +10,15 @@ config.load()
class InfrastructureConfig():
    database_engine: DatabaseEngine = None
    vector_engine: VectorDBInterface = None
    llm_engine: LLMInterface = None

    def get_config(self) -> dict:
        if self.database_engine is None:
            self.database_engine = SqliteEngine(config.db_path, config.db_name)

        if self.llm_engine is None:
            self.llm_engine = OpenAIAdapter(config.openai_key, config.model)

        if self.vector_engine is None:
            self.vector_engine = WeaviateAdapter(
                config.weaviate_url,

@@ -28,5 +34,6 @@ class InfrastructureConfig():
    def set_config(self, new_config: dict):
        self.database_engine = new_config["database_engine"]
        self.vector_engine = new_config["vector_engine"]
        self.llm_engine = new_config["llm_engine"]

infrastructure_config = InfrastructureConfig()

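get_config lazily builds whichever engines are still None, and set_config swaps in all three at once. A rough usage sketch, assuming the singleton is importable as shown (import paths are inferred from the relative imports above) and that get_config returns the same three keys set_config reads:

from cognee.config import Config
from cognee.infrastructure import infrastructure_config  # import path assumed
from cognee.infrastructure.llm.openai.adapter import OpenAIAdapter  # path assumed

config = Config()
config.load()

current = infrastructure_config.get_config()  # assumed to return the same three keys

# Replace only the LLM engine; keep the lazily created database and vector engines.
infrastructure_config.set_config({
    "database_engine": current["database_engine"],
    "vector_engine": current["vector_engine"],
    "llm_engine": OpenAIAdapter(config.openai_key, config.model),
})
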
@@ -1,10 +1,27 @@
"""Get the LLM client."""
from enum import Enum
from cognee.config import Config
from .openai.adapter import OpenAIAdapter
from .ollama.adapter import OllamaAPIAdapter

# Define an Enum for LLM Providers
class LLMProvider(Enum):
    OPENAI = "openai"
    OLLAMA = "ollama"

config = Config()
config.load()

def get_llm_client():
    """Get the LLM client."""
    return OpenAIAdapter(config.openai_key, config.model)
    """Get the LLM client based on the configuration using Enums."""
    provider = LLMProvider(config.llm_provider)

    if provider == LLMProvider.OPENAI:
        return OpenAIAdapter(config.openai_key, config.model)
    elif provider == LLMProvider.OLLAMA:
        return OllamaAPIAdapter(config.ollama_endpoint, config.ollama_key, config.ollama_model)
    else:
        raise ValueError(f"Unsupported LLM provider: {provider}")

# Usage example
llm_client = get_llm_client()

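The provider is resolved from the configured string before any branch runs, so an unknown value fails as soon as the enum is constructed. A quick sketch of the selection behaviour (values taken from the enum above):

LLMProvider("openai")   # LLMProvider.OPENAI -> OpenAIAdapter(config.openai_key, config.model)
LLMProvider("ollama")   # LLMProvider.OLLAMA -> OllamaAPIAdapter(endpoint, key, model)
LLMProvider("other")    # raises ValueError: 'other' is not a valid LLMProvider
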
cognee/infrastructure/llm/ollama/adapter.py (new file, 108 lines)

@@ -0,0 +1,108 @@
import asyncio

import aiohttp
from typing import List, Type
from pydantic import BaseModel
import instructor
from tenacity import retry, stop_after_attempt
from openai import AsyncOpenAI

import openai
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt


class OllamaAPIAdapter(LLMInterface):
    """Adapter for Ollama's API"""

    def __init__(self, ollama_endpoint, api_key: str, model: str):

        self.aclient = instructor.patch(
            AsyncOpenAI(
                base_url=ollama_endpoint,
                api_key=api_key,  # required, but unused
            ),
            mode=instructor.Mode.JSON,
        )
        self.model = model

    @retry(stop=stop_after_attempt(5))
    def completions_with_backoff(self, **kwargs):
        """Wrapper around ChatCompletion.create w/ backoff"""
        # Local model
        return openai.chat.completions.create(**kwargs)

    @retry(stop=stop_after_attempt(5))
    async def acompletions_with_backoff(self, **kwargs):
        """Wrapper around ChatCompletion.acreate w/ backoff"""
        return await openai.chat.completions.acreate(**kwargs)

    @retry(stop=stop_after_attempt(5))
    async def acreate_embedding_with_backoff(self, input: List[str], model: str = "text-embedding-3-large"):
        """Wrapper around Embedding.acreate w/ backoff"""

        return await self.aclient.embeddings.create(input=input, model=model)

    async def async_get_embedding_with_backoff(self, text, model="text-embedding-3-large"):
        """To get text embeddings, import/call this function
        It specifies defaults + handles rate-limiting + is async"""
        text = text.replace("\n", " ")
        response = await self.aclient.embeddings.create(input=text, model=model)
        embedding = response.data[0].embedding
        return embedding

    @retry(stop=stop_after_attempt(5))
    def create_embedding_with_backoff(self, **kwargs):
        """Wrapper around Embedding.create w/ backoff"""
        return openai.embeddings.create(**kwargs)

    def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-3-large"):
        """To get text embeddings, import/call this function
        It specifies defaults + handles rate-limiting
        :param text: str
        :param model: str
        """
        text = text.replace("\n", " ")
        response = self.create_embedding_with_backoff(input=[text], model=model)
        embedding = response.data[0].embedding
        return embedding

    async def async_get_batch_embeddings_with_backoff(self, texts: List[str], models: List[str]):
        """To get multiple text embeddings in parallel, import/call this function
        It specifies defaults + handles rate-limiting + is async"""
        # Collect all coroutines
        coroutines = (self.async_get_embedding_with_backoff(text, model)
                      for text, model in zip(texts, models))

        # Run the coroutines in parallel and gather the results
        embeddings = await asyncio.gather(*coroutines)

        return embeddings

    @retry(stop=stop_after_attempt(5))
    async def acreate_structured_output(self, text_input: str, system_prompt: str,
                                        response_model: Type[BaseModel]) -> BaseModel:
        """Generate a response from a user query."""
        return await self.aclient.chat.completions.create(
            model=self.model,
            messages=[
                {
                    "role": "user",
                    "content": f"""Use the given format to
                    extract information from the following input: {text_input}. """,
                },
                {"role": "system", "content": system_prompt},
            ],
            response_model=response_model,
        )

    def show_prompt(self, text_input: str, system_prompt: str) -> str:
        """Format and display the prompt for a user query."""
        if not text_input:
            text_input = "No user input provided."
        if not system_prompt:
            raise ValueError("No system prompt path provided.")
        system_prompt = read_query_prompt(system_prompt)

        formatted_prompt = f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n""" if system_prompt else None
        return formatted_prompt

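A rough end-to-end sketch of exercising the new adapter through get_llm_client, assuming config.llm_provider is "ollama", an Ollama server is listening on the configured endpoint, and the served model can emit JSON; the Summary schema, prompt strings and import path are illustrative only:

import asyncio
from typing import List
from pydantic import BaseModel

from cognee.infrastructure.llm import get_llm_client  # import path assumed


class Summary(BaseModel):
    # Illustrative response schema for structured output.
    title: str
    key_points: List[str]


async def main():
    client = get_llm_client()  # OllamaAPIAdapter when llm_provider == "ollama"
    summary = await client.acreate_structured_output(
        text_input="Cognee builds knowledge graphs from documents.",
        system_prompt="Summarise the input as structured data.",
        response_model=Summary,
    )
    print(summary)

asyncio.run(main())
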
poetry.lock (generated, 1008 lines changed)
File diff suppressed because it is too large

@@ -18,7 +18,7 @@ classifiers = [

[tool.poetry.dependencies]
python = "~3.10"
openai = "1.12.0"
openai = "1.14.3"
python-dotenv = "1.0.1"
fastapi = "^0.109.2"
uvicorn = "0.22.0"

@@ -26,7 +26,7 @@ boto3 = "^1.26.125"
gunicorn = "^20.1.0"
sqlalchemy = "^2.0.21"
asyncpg = "^0.28.0"
instructor = "^0.3.4"
instructor = "^0.6.6"
networkx = "^3.2.1"
graphviz = "^0.20.1"
langdetect = "^1.0.9"

@@ -46,11 +46,9 @@ dlt = "^0.4.6"
duckdb = {version = "^0.10.0", extras = ["dlt"]}
overrides = "^7.7.0"
aiofiles = "^23.2.1"
qdrant-client = "^1.8.0"
duckdb-engine = "^0.11.2"
graphistry = "^0.33.5"
tenacity = "^8.2.3"
weaviate-client = "^4.5.4"

[tool.poetry.extras]
dbt = ["dbt-core", "dbt-redshift", "dbt-bigquery", "dbt-duckdb", "dbt-snowflake", "dbt-athena-community", "dbt-databricks"]

@@ -61,6 +59,7 @@ postgres = ["psycopg2-binary", "psycopg2cffi"]
redshift = ["psycopg2-binary", "psycopg2cffi"]
parquet = ["pyarrow"]
duckdb = ["duckdb"]
qdrant = ["qdrant-client"]
filesystem = ["s3fs", "botocore"]
s3 = ["s3fs", "botocore"]
gs = ["gcsfs"]

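A small sanity-check sketch after reinstalling dependencies, confirming the bumped libraries expose what the new adapter relies on (expected version taken from the diff above):

import instructor
import openai
from openai import AsyncOpenAI

print(openai.__version__)    # expected: 1.14.3
print(instructor.Mode.JSON)  # JSON mode used by instructor.patch(...) in the adapter
AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="unused")  # client construction only
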