Merge remote-tracking branch 'origin/adding_ollama' into adding_ollama

Boris Arzentar 2024-03-29 13:55:27 +01:00
commit fde01fec7b
12 changed files with 159 additions and 55 deletions


@@ -50,4 +50,4 @@ async def add_standalone(
def is_data_path(data: str) -> bool:
    return False if not isinstance(data, str) else data.startswith("file://")
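For reference, the predicate above treats any string carrying the file:// scheme as a path on disk; a quick sketch of the behaviour (the argument values are illustrative):

is_data_path("file:///tmp/notes.txt")  # True - resolved as a local file
is_data_path("raw text to cognify")    # False - treated as inline data
is_data_path(42)                       # False - not a string at all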


@@ -18,4 +18,4 @@ async def remember(user_id: str, memory_name: str, payload: List[str]):
    if await is_existing_memory(memory_name) is False:
        raise MemoryException(f"Memory with the name \"{memory_name}\" doesn't exist.")

    await create_information_points(memory_name, payload)
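A hypothetical caller sketch for the guard above (remember and MemoryException come from the diff; the user id, memory name, and payload values are invented):

try:
    await remember("user-123", "notes", ["first fact", "second fact"])
except MemoryException as error:
    print(error)  # Memory with the name "notes" doesn't exist.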


@@ -202,4 +202,4 @@ if __name__ == "__main__":
    print(graph_url)
    asyncio.run(main())


@@ -46,8 +46,9 @@ class Config:
    # Model parameters
    llm_provider: str = "openai"  # openai, custom, or ollama
-   custom_endpoint: str = ""  # pass claude endpoint
-   custom_key: Optional[str] = "custom"
    custom_model: str = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+   custom_endpoint: str = "https://api.endpoints.anyscale.com/v1"  # any OpenAI-compatible endpoint
+   custom_key: Optional[str] = os.getenv("ANYSCALE_API_KEY")
+   ollama_endpoint: str = "http://localhost:11434/v1"
+   ollama_key: Optional[str] = "ollama"
+   ollama_model: str = "mistral:instruct"
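A minimal sketch of how the new Ollama settings might be selected, assuming a local Ollama server on the default port (only the field names and defaults come from the diff above):

config = Config()
config.load()
config.llm_provider = "ollama"
# With the defaults above, requests go to http://localhost:11434/v1
# using the placeholder key "ollama" and the mistral:instruct model.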


@@ -6,6 +6,7 @@ from .EmbeddingEngine import EmbeddingEngine
config = Config()
config.load()

class DefaultEmbeddingEngine(EmbeddingEngine):
    async def embed_text(self, text: List[str]) -> List[float]:
        embedding_model = TextEmbedding(model_name = config.embedding_model, cache_dir = get_absolute_path("cache/embeddings"))
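For orientation, fastembed's TextEmbedding.embed() yields numpy vectors lazily, so embed_text presumably materializes them; a rough sketch under that assumption (the hunk above is truncated and does not show the return):

vectors = embedding_model.embed(["hello world"])      # lazy generator of numpy arrays
embeddings = [vector.tolist() for vector in vectors]  # materialize as plain lists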


@@ -0,0 +1,57 @@
import instructor
import anthropic
from typing import List, Type
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt

class AnthropicAdapter(LLMInterface):
    """Adapter for Anthropic's API"""

    def __init__(self, endpoint, api_key: str, model: str):
        # Use the async client so acreate_structured_output can actually be awaited.
        self.aclient = instructor.patch(
            create = anthropic.AsyncAnthropic(api_key = api_key).messages.create,
            mode = instructor.Mode.ANTHROPIC_TOOLS
        )
        self.model = model

    @retry(stop = stop_after_attempt(5))
    async def acreate_structured_output(self, text_input: str, system_prompt: str,
                                        response_model: Type[BaseModel]) -> BaseModel:
        """Generate a response from a user query."""
        return await self.aclient(
            model = self.model,
            max_tokens = 4096,
            max_retries = 0,
            messages = [
                {
                    "role": "user",
                    "content": f"""Use the given format to
                    extract information from the following input: {text_input}. {system_prompt}""",
                }
            ],
            response_model = response_model,
        )

    def show_prompt(self, text_input: str, system_prompt: str) -> str:
        """Format and display the prompt for a user query."""
        if not text_input:
            text_input = "No user input provided."
        if not system_prompt:
            raise ValueError("No system prompt path provided.")
        system_prompt = read_query_prompt(system_prompt)

        return f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
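A hypothetical end-to-end call against the adapter above (Summary is an invented response model, the model id is an assumption, and the endpoint argument is unused by the constructor):

import asyncio

class Summary(BaseModel):
    title: str
    key_points: List[str]

adapter = AnthropicAdapter(None, api_key = "sk-ant-...", model = "claude-3-opus-20240229")
summary = asyncio.run(adapter.acreate_structured_output(
    text_input = "Cognee is adding an Ollama backend alongside OpenAI and Anthropic.",
    system_prompt = "Extract a title and key points.",  # passed verbatim into the user message
    response_model = Summary,
))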


@@ -12,7 +12,7 @@ from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt

-class OllamaAPIAdapter(LLMInterface):
+class GenericAPIAdapter(LLMInterface):
    """Adapter for Ollama's API"""

    def __init__(self, ollama_endpoint, api_key: str, model: str):
@@ -89,9 +89,8 @@ class OllamaAPIAdapter(LLMInterface):
                {
                    "role": "user",
                    "content": f"""Use the given format to
-                   extract information from the following input: {text_input}. """,
-               },
-               {"role": "system", "content": system_prompt},
+                   extract information from the following input: {text_input}. {system_prompt} """,
+               }
            ],
            response_model = response_model,
        )
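The commit folds the system prompt into the user message, presumably because some OpenAI-compatible backends handle a single user message more reliably. Since the renamed adapter speaks any such API, a hypothetical construction mirroring the new Config defaults would be (the constructor signature is the one shown above):

adapter = GenericAPIAdapter("http://localhost:11434/v1", "ollama", "mistral:instruct")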


@@ -1,8 +1,9 @@
"""Get the LLM client."""
from enum import Enum
from cognee.config import Config
+from .anthropic.adapter import AnthropicAdapter
from .openai.adapter import OpenAIAdapter
-from .ollama.adapter import OllamaAPIAdapter
+from .generic_llm_api.adapter import GenericAPIAdapter
import logging

logging.basicConfig(level=logging.INFO)
@@ -10,6 +11,7 @@ logging.basicConfig(level=logging.INFO)
class LLMProvider(Enum):
    OPENAI = "openai"
    OLLAMA = "ollama"
+   ANTHROPIC = "anthropic"
    CUSTOM = "custom"
config = Config()
@@ -24,10 +26,13 @@ def get_llm_client():
        return OpenAIAdapter(config.openai_key, config.model)
    elif provider == LLMProvider.OLLAMA:
        print("Using Ollama API")
-       return OllamaAPIAdapter(config.ollama_endpoint, config.ollama_key, config.ollama_model)
+       return GenericAPIAdapter(config.ollama_endpoint, config.ollama_key, config.ollama_model)
+   elif provider == LLMProvider.ANTHROPIC:
+       print("Using Anthropic API")
+       return AnthropicAdapter(config.custom_endpoint, config.custom_key, config.custom_model)
    elif provider == LLMProvider.CUSTOM:
        print("Using Custom API")
-       return OllamaAPIAdapter(config.custom_endpoint, config.custom_key, config.model)
+       return GenericAPIAdapter(config.custom_endpoint, config.custom_key, config.custom_model)
        # Add your custom LLM provider here
    else:
        raise ValueError(f"Unsupported LLM provider: {provider}")


@@ -6,20 +6,20 @@ from pydantic import BaseModel
class LLMInterface(Protocol):
    """ LLM Interface """

-   @abstractmethod
-   async def async_get_embedding_with_backoff(self, text, model="text-embedding-ada-002"):
-       """To get text embeddings, import/call this function"""
-       raise NotImplementedError
-
-   @abstractmethod
-   def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-ada-002"):
-       """To get text embeddings, import/call this function"""
-       raise NotImplementedError
-
-   @abstractmethod
-   async def async_get_batch_embeddings_with_backoff(self, texts: List[str], models: List[str]):
-       """To get multiple text embeddings in parallel, import/call this function"""
-       raise NotImplementedError
+   # @abstractmethod
+   # async def async_get_embedding_with_backoff(self, text, model="text-embedding-ada-002"):
+   #     """To get text embeddings, import/call this function"""
+   #     raise NotImplementedError
+   #
+   # @abstractmethod
+   # def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-ada-002"):
+   #     """To get text embeddings, import/call this function"""
+   #     raise NotImplementedError
+   #
+   # @abstractmethod
+   # async def async_get_batch_embeddings_with_backoff(self, texts: List[str], models: List[str]):
+   #     """To get multiple text embeddings in parallel, import/call this function"""
+   #     raise NotImplementedError

    # """ Get completions """
    # async def acompletions_with_backoff(self, **kwargs):
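With the embedding methods commented out, a conforming implementation of the Protocol presumably only needs the structured-output members used by the adapters elsewhere in this commit; a hypothetical minimal stub (EchoAdapter is invented for illustration):

class EchoAdapter:
    async def acreate_structured_output(self, text_input, system_prompt, response_model):
        return response_model()  # placeholder instance, no real LLM call

    def show_prompt(self, text_input: str, system_prompt: str) -> str:
        return f"{system_prompt}\n{text_input}"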

poetry.lock (generated)

@@ -150,6 +150,30 @@ files = [
    {file = "annotated_types-0.6.0.tar.gz", hash = "sha256:563339e807e53ffd9c267e99fc6d9ea23eb8443c08f112651963e24e22f84a5d"},
]

+[[package]]
+name = "anthropic"
+version = "0.21.3"
+description = "The official Python library for the anthropic API"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "anthropic-0.21.3-py3-none-any.whl", hash = "sha256:5869115453b543a46ded6515c9f29b8d610b6e94bbba3230ad80ac947d2b0862"},
+    {file = "anthropic-0.21.3.tar.gz", hash = "sha256:02f1ab5694c497e2b2d42d30d51a4f2edcaca92d2ec86bb64fe78a9c7434a869"},
+]
+
+[package.dependencies]
+anyio = ">=3.5.0,<5"
+distro = ">=1.7.0,<2"
+httpx = ">=0.23.0,<1"
+pydantic = ">=1.9.0,<3"
+sniffio = "*"
+tokenizers = ">=0.13.0"
+typing-extensions = ">=4.7,<5"
+
+[package.extras]
+bedrock = ["boto3 (>=1.28.57)", "botocore (>=1.31.57)"]
+vertex = ["google-auth (>=2,<3)"]
[[package]]
name = "anyio"
version = "4.3.0"
@@ -518,17 +542,17 @@ css = ["tinycss2 (>=1.1.0,<1.3)"]
[[package]]
name = "boto3"
-version = "1.34.70"
+version = "1.34.73"
description = "The AWS SDK for Python"
optional = false
python-versions = ">=3.8"
files = [
-    {file = "boto3-1.34.70-py3-none-any.whl", hash = "sha256:8d7902e2c0c62837457ba18146e3feaf1dec62018617edc5c0336b65b305b682"},
-    {file = "boto3-1.34.70.tar.gz", hash = "sha256:54150a52eb93028b8e09df00319e8dcb68be7459333d5da00d706d75ba5130d6"},
+    {file = "boto3-1.34.73-py3-none-any.whl", hash = "sha256:4d68e7c7c1339e251c661fd6e2a34e31d281177106326712417fed839907fa84"},
+    {file = "boto3-1.34.73.tar.gz", hash = "sha256:f45503333286c03fb692a3ce497b6fdb4e88c51c98a3b8ff05071d7f56571448"},
]

[package.dependencies]
-botocore = ">=1.34.70,<1.35.0"
+botocore = ">=1.34.73,<1.35.0"
jmespath = ">=0.7.1,<2.0.0"
s3transfer = ">=0.10.0,<0.11.0"
@@ -537,13 +561,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
[[package]]
name = "botocore"
-version = "1.34.70"
+version = "1.34.73"
description = "Low-level, data-driven core of boto 3."
optional = false
python-versions = ">=3.8"
files = [
-    {file = "botocore-1.34.70-py3-none-any.whl", hash = "sha256:c86944114e85c8a8d5da06fb84f2609ed3bd23cd2fc06b30250bef7e37e8c589"},
-    {file = "botocore-1.34.70.tar.gz", hash = "sha256:fa03d4972cd57d505e6c0eb5d7c7a1caeb7dd49e84f963f7ebeca41fe8ab736e"},
+    {file = "botocore-1.34.73-py3-none-any.whl", hash = "sha256:88d660b711cc5b5b049e15d547cb09526f86e48c15b78dacad78522109502b91"},
+    {file = "botocore-1.34.73.tar.gz", hash = "sha256:8df020b6682b9f1e9ee7b0554d5d0c14b7b23e3de070c85bcdf07fb20bfe4e2b"},
]
]
[package.dependencies]
@@ -1891,13 +1915,13 @@ files = [
[[package]]
name = "httpcore"
-version = "1.0.4"
+version = "1.0.5"
description = "A minimal low-level HTTP client."
optional = false
python-versions = ">=3.8"
files = [
-    {file = "httpcore-1.0.4-py3-none-any.whl", hash = "sha256:ac418c1db41bade2ad53ae2f3834a3a0f5ae76b56cf5aa497d2d033384fc7d73"},
-    {file = "httpcore-1.0.4.tar.gz", hash = "sha256:cb2839ccfcba0d2d3c1131d3c3e26dfc327326fbe7a5dc0dbfe9f6c9151bb022"},
+    {file = "httpcore-1.0.5-py3-none-any.whl", hash = "sha256:421f18bac248b25d310f3cacd198d55b8e6125c107797b609ff9b7a6ba7991b5"},
+    {file = "httpcore-1.0.5.tar.gz", hash = "sha256:34a38e2f9291467ee3b44e89dd52615370e152954ba21721378a87b2960f7a61"},
]
]
[package.dependencies]
@@ -1908,7 +1932,7 @@ h11 = ">=0.13,<0.15"
asyncio = ["anyio (>=4.0,<5.0)"]
http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]
-trio = ["trio (>=0.22.0,<0.25.0)"]
+trio = ["trio (>=0.22.0,<0.26.0)"]
[[package]]
name = "httpx"
@@ -2030,13 +2054,13 @@ files = [
[[package]]
name = "instructor"
-version = "0.6.7"
+version = "0.6.8"
description = "structured outputs for llm"
optional = false
python-versions = "<4.0,>=3.10"
files = [
-    {file = "instructor-0.6.7-py3-none-any.whl", hash = "sha256:bb2cdc4b56ba9af763e01e590e051b13168038537a9ef12648142cec53472e53"},
-    {file = "instructor-0.6.7.tar.gz", hash = "sha256:cbae44db8c71796a6237432f8c929b15d021b13c82b5474dc2921b2cdcfe647f"},
+    {file = "instructor-0.6.8-py3-none-any.whl", hash = "sha256:f2099e49b21232ddb50ce9ba27e13159dcb3af17e8ede7cbcd93ce990fe6bc82"},
+    {file = "instructor-0.6.8.tar.gz", hash = "sha256:e261d73deb3535d62ee775c437b82aeb6e9c2b2f63bb533b53a9fa6a47dbb95a"},
]
]
[package.dependencies]
@@ -2048,15 +2072,18 @@ rich = ">=13.7.0,<14.0.0"
tenacity = ">=8.2.3,<9.0.0"
typer = ">=0.9.0,<0.10.0"

+[package.extras]
+anthropic = ["anthropic (>=0.18.1,<0.19.0)", "xmltodict (>=0.13.0,<0.14.0)"]

[[package]]
name = "ipykernel"
-version = "6.29.3"
+version = "6.29.4"
description = "IPython Kernel for Jupyter"
optional = false
python-versions = ">=3.8"
files = [
-    {file = "ipykernel-6.29.3-py3-none-any.whl", hash = "sha256:5aa086a4175b0229d4eca211e181fb473ea78ffd9869af36ba7694c947302a21"},
-    {file = "ipykernel-6.29.3.tar.gz", hash = "sha256:e14c250d1f9ea3989490225cc1a542781b095a18a19447fcf2b5eaf7d0ac5bd2"},
+    {file = "ipykernel-6.29.4-py3-none-any.whl", hash = "sha256:1181e653d95c6808039c509ef8e67c4126b3b3af7781496c7cbfb5ed938a27da"},
+    {file = "ipykernel-6.29.4.tar.gz", hash = "sha256:3d44070060f9475ac2092b760123fadf105d2e2493c24848b6691a7c4f42af5c"},
]
]
[package.dependencies]
@@ -4660,26 +4687,26 @@ cffi = {version = "*", markers = "implementation_name == \"pypy\""}
[[package]]
name = "qdrant-client"
-version = "1.8.0"
+version = "1.8.2"
description = "Client library for the Qdrant vector search engine"
optional = false
python-versions = ">=3.8"
files = [
-    {file = "qdrant_client-1.8.0-py3-none-any.whl", hash = "sha256:fa28d3eb64c0c57ec029c7c85c71f6c72c197f92502022655741f3632c518e29"},
-    {file = "qdrant_client-1.8.0.tar.gz", hash = "sha256:2a1a3f2cbacc7adba85644cf6cfdee20401cf25764b32da479c81fb63e178d15"},
+    {file = "qdrant_client-1.8.2-py3-none-any.whl", hash = "sha256:ee5341c0486d09e4346b0f5ef7781436e6d8cdbf1d5ecddfde7adb3647d353a8"},
+    {file = "qdrant_client-1.8.2.tar.gz", hash = "sha256:65078d5328bc0393f42a46a31cd319a989b8285bf3958360acf1dffffdf4cc4e"},
]

[package.dependencies]
grpcio = ">=1.41.0"
grpcio-tools = ">=1.41.0"
-httpx = {version = ">=0.14.0", extras = ["http2"]}
+httpx = {version = ">=0.20.0", extras = ["http2"]}
numpy = {version = ">=1.21", markers = "python_version >= \"3.8\" and python_version < \"3.12\""}
portalocker = ">=2.7.0,<3.0.0"
pydantic = ">=1.10.8"
urllib3 = ">=1.26.14,<3"

[package.extras]
-fastembed = ["fastembed (==0.2.2)"]
+fastembed = ["fastembed (==0.2.5)"]
[[package]]
name = "redis"
@@ -4839,17 +4866,18 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "requirements-parser"
-version = "0.5.0"
+version = "0.7.0"
description = "This is a small Python module for parsing Pip requirement files."
optional = false
-python-versions = ">=3.6,<4.0"
+python-versions = "<4.0,>=3.7"
files = [
-    {file = "requirements-parser-0.5.0.tar.gz", hash = "sha256:3336f3a3ae23e06d3f0f88595e4052396e3adf91688787f637e5d2ca1a904069"},
-    {file = "requirements_parser-0.5.0-py3-none-any.whl", hash = "sha256:e7fcdcd04f2049e73a9fb150d8a0f9d51ce4108f5f7cbeac74c484e17b12bcd9"},
+    {file = "requirements_parser-0.7.0-py3-none-any.whl", hash = "sha256:80569baa23b13cf0980fb2ceb5dc2e3b7ee05df203a26d83e3ed56c155c6597a"},
+    {file = "requirements_parser-0.7.0.tar.gz", hash = "sha256:33f1b1c668fa85df8c6a638c479ac743ea8541f5d8d56011591068757ce1a201"},
]

[package.dependencies]
-types-setuptools = ">=57.0.0"
+setuptools = ">=59.7.0"
+types-setuptools = ">=59.7.0"
[[package]]
name = "rfc3339-validator"
@@ -6065,6 +6093,17 @@ files = [
[package.extras]
dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"]

+[[package]]
+name = "xmltodict"
+version = "0.13.0"
+description = "Makes working with XML feel like you are working with JSON"
+optional = false
+python-versions = ">=3.4"
+files = [
+    {file = "xmltodict-0.13.0-py2.py3-none-any.whl", hash = "sha256:aa89e8fd76320154a40d19a0df04a4695fb9dc5ba977cbb68ab3e4eb225e7852"},
+    {file = "xmltodict-0.13.0.tar.gz", hash = "sha256:341595a488e3e01a85a9d8911d8912fd922ede5fecc4dce437eb4b6c8d037e56"},
+]
[[package]]
name = "yarl"
version = "1.9.4"
@@ -6196,4 +6235,4 @@ weaviate = ["weaviate-client"]
[metadata]
lock-version = "2.0"
python-versions = "~3.10"
-content-hash = "d929caab2d4114374cf2c36e1d956a7950476ff6e0a550e50011702c568f9195"
+content-hash = "35ad50753694260acc7e34b3c85e869e310fe2fb614fb5da3a1f3c1df4e82b1a"


@@ -26,7 +26,7 @@ boto3 = "^1.26.125"
gunicorn = "^20.1.0"
sqlalchemy = "^2.0.21"
asyncpg = "^0.28.0"
-instructor = "^0.6.7"
+instructor = "^0.6.8"
networkx = "^3.2.1"
graphviz = "^0.20.1"
langdetect = "^1.0.9"
@@ -52,6 +52,8 @@ weaviate-client = "^4.5.4"
scikit-learn = "^1.4.1.post1"
fastembed = "^0.2.5"
pypdf = "^4.1.0"
+anthropic = "^0.21.3"
+xmltodict = "^0.13.0"
[tool.poetry.extras]
dbt = ["dbt-core", "dbt-redshift", "dbt-bigquery", "dbt-duckdb", "dbt-snowflake", "dbt-athena-community", "dbt-databricks"]