Merge remote-tracking branch 'origin/adding_ollama' into adding_ollama

Boris Arzentar 2024-03-29 13:55:27 +01:00
commit fde01fec7b
12 changed files with 159 additions and 55 deletions

View file

@@ -50,4 +50,4 @@ async def add_standalone(
 def is_data_path(data: str) -> bool:
     return False if not isinstance(data, str) else data.startswith("file://")

View file

@@ -18,4 +18,4 @@ async def remember(user_id: str, memory_name: str, payload: List[str]):
     if await is_existing_memory(memory_name) is False:
         raise MemoryException(f"Memory with the name \"{memory_name}\" doesn't exist.")

     await create_information_points(memory_name, payload)

View file

@@ -202,4 +202,4 @@ if __name__ == "__main__":
     print(graph_url)

     asyncio.run(main())

View file

@@ -46,8 +46,9 @@ class Config:
     # Model parameters
     llm_provider: str = "openai" # openai, custom, or ollama
-    custom_endpoint: str = "" # pass claude endpoint
-    custom_key: Optional[str] = "custom"
+    custom_model: str = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+    custom_endpoint: str = "https://api.endpoints.anyscale.com/v1" # pass claude endpoint
+    custom_key: Optional[str] = os.getenv("ANYSCALE_API_KEY")
     ollama_endpoint: str = "http://localhost:11434/v1"
     ollama_key: Optional[str] = "ollama"
     ollama_model: str = "mistral:instruct"

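For orientation, a minimal sketch of how the reworked "custom" settings would be consumed, assuming Config.load() honors pre-set environment variables as the os.getenv call above suggests; ANYSCALE_API_KEY comes from the diff, the rest is illustrative.

    import os
    from cognee.config import Config

    os.environ["ANYSCALE_API_KEY"] = "esecret_..."  # hypothetical key value

    config = Config()
    config.load()
    config.llm_provider = "custom"  # routes to GenericAPIAdapter (see get_llm_client below)
    # custom_model / custom_endpoint now default to Anyscale's OpenAI-compatible
    # endpoint serving mistralai/Mixtral-8x7B-Instruct-v0.1.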
View file

@@ -6,6 +6,7 @@ from .EmbeddingEngine import EmbeddingEngine
 config = Config()
 config.load()

 class DefaultEmbeddingEngine(EmbeddingEngine):
     async def embed_text(self, text: List[str]) -> List[float]:
         embedding_model = TextEmbedding(model_name = config.embedding_model, cache_dir = get_absolute_path("cache/embeddings"))

View file

@@ -0,0 +1,57 @@
+from typing import Type
+
+from pydantic import BaseModel
+import instructor
+from tenacity import retry, stop_after_attempt
+import anthropic
+
+from cognee.infrastructure.llm.llm_interface import LLMInterface
+from cognee.infrastructure.llm.prompts import read_query_prompt
+
+
+class AnthropicAdapter(LLMInterface):
+    """Adapter for Anthropic's API."""
+
+    def __init__(self, api_endpoint, api_key: str, model: str):
+        # The endpoint argument is accepted for interface parity but unused;
+        # the anthropic SDK manages its own base URL. The async client is
+        # required because acreate_structured_output awaits the patched callable.
+        # A None api_key falls back to the ANTHROPIC_API_KEY environment variable.
+        self.aclient = instructor.patch(
+            create=anthropic.AsyncAnthropic(api_key=api_key).messages.create,
+            mode=instructor.Mode.ANTHROPIC_TOOLS,
+        )
+        self.model = model
+
+    @retry(stop=stop_after_attempt(5))
+    async def acreate_structured_output(self, text_input: str, system_prompt: str,
+                                        response_model: Type[BaseModel]) -> BaseModel:
+        """Generate a structured response from a user query."""
+        return await self.aclient(
+            model=self.model,
+            max_tokens=4096,
+            max_retries=0,
+            messages=[{
+                "role": "user",
+                "content": f"""Use the given format to
+                extract information from the following input: {text_input}. {system_prompt}""",
+            }],
+            response_model=response_model,
+        )
+
+    def show_prompt(self, text_input: str, system_prompt: str) -> str:
+        """Format and display the prompt for a user query."""
+        if not text_input:
+            text_input = "No user input provided."
+        if not system_prompt:
+            raise ValueError("No system prompt path provided.")
+
+        system_prompt = read_query_prompt(system_prompt)
+        return f"System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"

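A usage sketch for the new adapter; the model name, response model, and input text are illustrative, the import path is assumed from the relative import in get_llm_client below, and the key falls back to ANTHROPIC_API_KEY per the anthropic SDK default.

    import asyncio
    import os
    from pydantic import BaseModel
    from cognee.infrastructure.llm.anthropic.adapter import AnthropicAdapter  # path assumed

    class Summary(BaseModel):  # hypothetical structured-output model
        text: str

    async def main():
        adapter = AnthropicAdapter(None, os.getenv("ANTHROPIC_API_KEY"), "claude-3-opus-20240229")
        summary = await adapter.acreate_structured_output(
            text_input="Cognee turns documents into knowledge graphs.",
            system_prompt="Summarize the input in one sentence.",
            response_model=Summary,
        )
        print(summary.text)

    asyncio.run(main())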
View file

@@ -12,7 +12,7 @@ from cognee.infrastructure.llm.llm_interface import LLMInterface
 from cognee.infrastructure.llm.prompts import read_query_prompt

-class OllamaAPIAdapter(LLMInterface):
+class GenericAPIAdapter(LLMInterface):
     """Adapter for Ollama's API"""

     def __init__(self, ollama_endpoint, api_key: str, model: str):

@@ -89,9 +89,8 @@ class OllamaAPIAdapter(LLMInterface):
                 {
                     "role": "user",
                     "content": f"""Use the given format to
-                    extract information from the following input: {text_input}. """,
-                },
-                {"role": "system", "content": system_prompt},
+                    extract information from the following input: {text_input}. {system_prompt} """,
+                }
             ],
             response_model=response_model,
         )

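The net effect of the second hunk: the system prompt is no longer sent as a separate system-role message but is folded into the single user message, so the request payload reduces to the sketch below. Plausibly this keeps the call uniform across OpenAI-compatible endpoints that handle system messages inconsistently, though the diff does not say.

    messages = [{
        "role": "user",
        "content": f"Use the given format to extract information "
                   f"from the following input: {text_input}. {system_prompt}",
    }]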
View file

@@ -1,8 +1,9 @@
 """Get the LLM client."""
 from enum import Enum
 from cognee.config import Config
+from .anthropic.adapter import AnthropicAdapter
 from .openai.adapter import OpenAIAdapter
-from .ollama.adapter import OllamaAPIAdapter
+from .generic_llm_api.adapter import GenericAPIAdapter
 import logging

 logging.basicConfig(level=logging.INFO)

@@ -10,6 +11,7 @@ logging.basicConfig(level=logging.INFO)
 class LLMProvider(Enum):
     OPENAI = "openai"
     OLLAMA = "ollama"
+    ANTHROPIC = "anthropic"
     CUSTOM = "custom"

 config = Config()

@@ -24,10 +26,13 @@ def get_llm_client():
         return OpenAIAdapter(config.openai_key, config.model)
     elif provider == LLMProvider.OLLAMA:
         print("Using Ollama API")
-        return OllamaAPIAdapter(config.ollama_endpoint, config.ollama_key, config.ollama_model)
+        return GenericAPIAdapter(config.ollama_endpoint, config.ollama_key, config.ollama_model)
+    elif provider == LLMProvider.ANTHROPIC:
+        print("Using Anthropic API")
+        return AnthropicAdapter(config.custom_endpoint, config.custom_key, config.custom_model)
     elif provider == LLMProvider.CUSTOM:
         print("Using Custom API")
-        return OllamaAPIAdapter(config.custom_endpoint, config.custom_key, config.model)
+        return GenericAPIAdapter(config.custom_endpoint, config.custom_key, config.custom_model)
     # Add your custom LLM provider here
     else:
         raise ValueError(f"Unsupported LLM provider: {provider}")

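A dispatch sketch, assuming the module path matches the cognee.infrastructure.llm imports above and that Config state is shared across instances once loaded:

    from cognee.config import Config
    from cognee.infrastructure.llm.get_llm_client import get_llm_client  # path assumed

    config = Config()
    config.load()
    config.llm_provider = "anthropic"  # or "openai", "ollama", "custom"
    client = get_llm_client()  # prints "Using Anthropic API" and returns an AnthropicAdapter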
View file

@@ -6,20 +6,20 @@ from pydantic import BaseModel
 class LLMInterface(Protocol):
     """ LLM Interface """

-    @abstractmethod
-    async def async_get_embedding_with_backoff(self, text, model="text-embedding-ada-002"):
-        """To get text embeddings, import/call this function"""
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-ada-002"):
-        """To get text embeddings, import/call this function"""
-        raise NotImplementedError
-
-    @abstractmethod
-    async def async_get_batch_embeddings_with_backoff(self, texts: List[str], models: List[str]):
-        """To get multiple text embeddings in parallel, import/call this function"""
-        raise NotImplementedError
+    # @abstractmethod
+    # async def async_get_embedding_with_backoff(self, text, model="text-embedding-ada-002"):
+    #     """To get text embeddings, import/call this function"""
+    #     raise NotImplementedError
+    #
+    # @abstractmethod
+    # def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-ada-002"):
+    #     """To get text embeddings, import/call this function"""
+    #     raise NotImplementedError
+    #
+    # @abstractmethod
+    # async def async_get_batch_embeddings_with_backoff(self, texts: List[str], models: List[str]):
+    #     """To get multiple text embeddings in parallel, import/call this function"""
+    #     raise NotImplementedError

     # """ Get completions """
     # async def acompletions_with_backoff(self, **kwargs):

poetry.lock generated (95 changed lines)
View file

@@ -150,6 +150,30 @@ files = [
     {file = "annotated_types-0.6.0.tar.gz", hash = "sha256:563339e807e53ffd9c267e99fc6d9ea23eb8443c08f112651963e24e22f84a5d"},
 ]

+[[package]]
+name = "anthropic"
+version = "0.21.3"
+description = "The official Python library for the anthropic API"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "anthropic-0.21.3-py3-none-any.whl", hash = "sha256:5869115453b543a46ded6515c9f29b8d610b6e94bbba3230ad80ac947d2b0862"},
+    {file = "anthropic-0.21.3.tar.gz", hash = "sha256:02f1ab5694c497e2b2d42d30d51a4f2edcaca92d2ec86bb64fe78a9c7434a869"},
+]
+
+[package.dependencies]
+anyio = ">=3.5.0,<5"
+distro = ">=1.7.0,<2"
+httpx = ">=0.23.0,<1"
+pydantic = ">=1.9.0,<3"
+sniffio = "*"
+tokenizers = ">=0.13.0"
+typing-extensions = ">=4.7,<5"
+
+[package.extras]
+bedrock = ["boto3 (>=1.28.57)", "botocore (>=1.31.57)"]
+vertex = ["google-auth (>=2,<3)"]
+
 [[package]]
 name = "anyio"
 version = "4.3.0"
@@ -518,17 +542,17 @@ css = ["tinycss2 (>=1.1.0,<1.3)"]
 [[package]]
 name = "boto3"
-version = "1.34.70"
+version = "1.34.73"
 description = "The AWS SDK for Python"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "boto3-1.34.70-py3-none-any.whl", hash = "sha256:8d7902e2c0c62837457ba18146e3feaf1dec62018617edc5c0336b65b305b682"},
-    {file = "boto3-1.34.70.tar.gz", hash = "sha256:54150a52eb93028b8e09df00319e8dcb68be7459333d5da00d706d75ba5130d6"},
+    {file = "boto3-1.34.73-py3-none-any.whl", hash = "sha256:4d68e7c7c1339e251c661fd6e2a34e31d281177106326712417fed839907fa84"},
+    {file = "boto3-1.34.73.tar.gz", hash = "sha256:f45503333286c03fb692a3ce497b6fdb4e88c51c98a3b8ff05071d7f56571448"},
 ]

 [package.dependencies]
-botocore = ">=1.34.70,<1.35.0"
+botocore = ">=1.34.73,<1.35.0"
 jmespath = ">=0.7.1,<2.0.0"
 s3transfer = ">=0.10.0,<0.11.0"
@@ -537,13 +561,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
 [[package]]
 name = "botocore"
-version = "1.34.70"
+version = "1.34.73"
 description = "Low-level, data-driven core of boto 3."
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "botocore-1.34.70-py3-none-any.whl", hash = "sha256:c86944114e85c8a8d5da06fb84f2609ed3bd23cd2fc06b30250bef7e37e8c589"},
-    {file = "botocore-1.34.70.tar.gz", hash = "sha256:fa03d4972cd57d505e6c0eb5d7c7a1caeb7dd49e84f963f7ebeca41fe8ab736e"},
+    {file = "botocore-1.34.73-py3-none-any.whl", hash = "sha256:88d660b711cc5b5b049e15d547cb09526f86e48c15b78dacad78522109502b91"},
+    {file = "botocore-1.34.73.tar.gz", hash = "sha256:8df020b6682b9f1e9ee7b0554d5d0c14b7b23e3de070c85bcdf07fb20bfe4e2b"},
 ]

 [package.dependencies]
@@ -1891,13 +1915,13 @@ files = [
 [[package]]
 name = "httpcore"
-version = "1.0.4"
+version = "1.0.5"
 description = "A minimal low-level HTTP client."
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "httpcore-1.0.4-py3-none-any.whl", hash = "sha256:ac418c1db41bade2ad53ae2f3834a3a0f5ae76b56cf5aa497d2d033384fc7d73"},
-    {file = "httpcore-1.0.4.tar.gz", hash = "sha256:cb2839ccfcba0d2d3c1131d3c3e26dfc327326fbe7a5dc0dbfe9f6c9151bb022"},
+    {file = "httpcore-1.0.5-py3-none-any.whl", hash = "sha256:421f18bac248b25d310f3cacd198d55b8e6125c107797b609ff9b7a6ba7991b5"},
+    {file = "httpcore-1.0.5.tar.gz", hash = "sha256:34a38e2f9291467ee3b44e89dd52615370e152954ba21721378a87b2960f7a61"},
 ]

 [package.dependencies]
@@ -1908,7 +1932,7 @@ h11 = ">=0.13,<0.15"
 asyncio = ["anyio (>=4.0,<5.0)"]
 http2 = ["h2 (>=3,<5)"]
 socks = ["socksio (==1.*)"]
-trio = ["trio (>=0.22.0,<0.25.0)"]
+trio = ["trio (>=0.22.0,<0.26.0)"]

 [[package]]
 name = "httpx"
@@ -2030,13 +2054,13 @@ files = [
 [[package]]
 name = "instructor"
-version = "0.6.7"
+version = "0.6.8"
 description = "structured outputs for llm"
 optional = false
 python-versions = "<4.0,>=3.10"
 files = [
-    {file = "instructor-0.6.7-py3-none-any.whl", hash = "sha256:bb2cdc4b56ba9af763e01e590e051b13168038537a9ef12648142cec53472e53"},
-    {file = "instructor-0.6.7.tar.gz", hash = "sha256:cbae44db8c71796a6237432f8c929b15d021b13c82b5474dc2921b2cdcfe647f"},
+    {file = "instructor-0.6.8-py3-none-any.whl", hash = "sha256:f2099e49b21232ddb50ce9ba27e13159dcb3af17e8ede7cbcd93ce990fe6bc82"},
+    {file = "instructor-0.6.8.tar.gz", hash = "sha256:e261d73deb3535d62ee775c437b82aeb6e9c2b2f63bb533b53a9fa6a47dbb95a"},
 ]

 [package.dependencies]
@@ -2048,15 +2072,18 @@ rich = ">=13.7.0,<14.0.0"
 tenacity = ">=8.2.3,<9.0.0"
 typer = ">=0.9.0,<0.10.0"

+[package.extras]
+anthropic = ["anthropic (>=0.18.1,<0.19.0)", "xmltodict (>=0.13.0,<0.14.0)"]
+
 [[package]]
 name = "ipykernel"
-version = "6.29.3"
+version = "6.29.4"
 description = "IPython Kernel for Jupyter"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "ipykernel-6.29.3-py3-none-any.whl", hash = "sha256:5aa086a4175b0229d4eca211e181fb473ea78ffd9869af36ba7694c947302a21"},
-    {file = "ipykernel-6.29.3.tar.gz", hash = "sha256:e14c250d1f9ea3989490225cc1a542781b095a18a19447fcf2b5eaf7d0ac5bd2"},
+    {file = "ipykernel-6.29.4-py3-none-any.whl", hash = "sha256:1181e653d95c6808039c509ef8e67c4126b3b3af7781496c7cbfb5ed938a27da"},
+    {file = "ipykernel-6.29.4.tar.gz", hash = "sha256:3d44070060f9475ac2092b760123fadf105d2e2493c24848b6691a7c4f42af5c"},
 ]

 [package.dependencies]
@@ -4660,26 +4687,26 @@ cffi = {version = "*", markers = "implementation_name == \"pypy\""}
 [[package]]
 name = "qdrant-client"
-version = "1.8.0"
+version = "1.8.2"
 description = "Client library for the Qdrant vector search engine"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "qdrant_client-1.8.0-py3-none-any.whl", hash = "sha256:fa28d3eb64c0c57ec029c7c85c71f6c72c197f92502022655741f3632c518e29"},
-    {file = "qdrant_client-1.8.0.tar.gz", hash = "sha256:2a1a3f2cbacc7adba85644cf6cfdee20401cf25764b32da479c81fb63e178d15"},
+    {file = "qdrant_client-1.8.2-py3-none-any.whl", hash = "sha256:ee5341c0486d09e4346b0f5ef7781436e6d8cdbf1d5ecddfde7adb3647d353a8"},
+    {file = "qdrant_client-1.8.2.tar.gz", hash = "sha256:65078d5328bc0393f42a46a31cd319a989b8285bf3958360acf1dffffdf4cc4e"},
 ]

 [package.dependencies]
 grpcio = ">=1.41.0"
 grpcio-tools = ">=1.41.0"
-httpx = {version = ">=0.14.0", extras = ["http2"]}
+httpx = {version = ">=0.20.0", extras = ["http2"]}
 numpy = {version = ">=1.21", markers = "python_version >= \"3.8\" and python_version < \"3.12\""}
 portalocker = ">=2.7.0,<3.0.0"
 pydantic = ">=1.10.8"
 urllib3 = ">=1.26.14,<3"

 [package.extras]
-fastembed = ["fastembed (==0.2.2)"]
+fastembed = ["fastembed (==0.2.5)"]

 [[package]]
 name = "redis"
@@ -4839,17 +4866,18 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
 [[package]]
 name = "requirements-parser"
-version = "0.5.0"
+version = "0.7.0"
 description = "This is a small Python module for parsing Pip requirement files."
 optional = false
-python-versions = ">=3.6,<4.0"
+python-versions = "<4.0,>=3.7"
 files = [
-    {file = "requirements-parser-0.5.0.tar.gz", hash = "sha256:3336f3a3ae23e06d3f0f88595e4052396e3adf91688787f637e5d2ca1a904069"},
-    {file = "requirements_parser-0.5.0-py3-none-any.whl", hash = "sha256:e7fcdcd04f2049e73a9fb150d8a0f9d51ce4108f5f7cbeac74c484e17b12bcd9"},
+    {file = "requirements_parser-0.7.0-py3-none-any.whl", hash = "sha256:80569baa23b13cf0980fb2ceb5dc2e3b7ee05df203a26d83e3ed56c155c6597a"},
+    {file = "requirements_parser-0.7.0.tar.gz", hash = "sha256:33f1b1c668fa85df8c6a638c479ac743ea8541f5d8d56011591068757ce1a201"},
 ]

 [package.dependencies]
-types-setuptools = ">=57.0.0"
+setuptools = ">=59.7.0"
+types-setuptools = ">=59.7.0"

 [[package]]
 name = "rfc3339-validator"
@@ -6065,6 +6093,17 @@ files = [
 [package.extras]
 dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"]

+[[package]]
+name = "xmltodict"
+version = "0.13.0"
+description = "Makes working with XML feel like you are working with JSON"
+optional = false
+python-versions = ">=3.4"
+files = [
+    {file = "xmltodict-0.13.0-py2.py3-none-any.whl", hash = "sha256:aa89e8fd76320154a40d19a0df04a4695fb9dc5ba977cbb68ab3e4eb225e7852"},
+    {file = "xmltodict-0.13.0.tar.gz", hash = "sha256:341595a488e3e01a85a9d8911d8912fd922ede5fecc4dce437eb4b6c8d037e56"},
+]
+
 [[package]]
 name = "yarl"
 version = "1.9.4"
@@ -6196,4 +6235,4 @@ weaviate = ["weaviate-client"]
 [metadata]
 lock-version = "2.0"
 python-versions = "~3.10"
-content-hash = "d929caab2d4114374cf2c36e1d956a7950476ff6e0a550e50011702c568f9195"
+content-hash = "35ad50753694260acc7e34b3c85e869e310fe2fb614fb5da3a1f3c1df4e82b1a"

View file

@@ -26,7 +26,7 @@ boto3 = "^1.26.125"
 gunicorn = "^20.1.0"
 sqlalchemy = "^2.0.21"
 asyncpg = "^0.28.0"
-instructor = "^0.6.7"
+instructor = "^0.6.8"
 networkx = "^3.2.1"
 graphviz = "^0.20.1"
 langdetect = "^1.0.9"

@@ -52,6 +52,8 @@ weaviate-client = "^4.5.4"
 scikit-learn = "^1.4.1.post1"
 fastembed = "^0.2.5"
 pypdf = "^4.1.0"
+anthropic = "^0.21.3"
+xmltodict = "^0.13.0"

 [tool.poetry.extras]
 dbt = ["dbt-core", "dbt-redshift", "dbt-bigquery", "dbt-duckdb", "dbt-snowflake", "dbt-athena-community", "dbt-databricks"]