diff --git a/cognee/api/v1/add/add_standalone.py b/cognee/api/v1/add/add_standalone.py
index 7def013eb..ef7be455b 100644
--- a/cognee/api/v1/add/add_standalone.py
+++ b/cognee/api/v1/add/add_standalone.py
@@ -50,4 +50,4 @@ async def add_standalone(
 
 def is_data_path(data: str) -> bool:
-    return False if not isinstance(data, str) else data.startswith("file://")
+    return False if not isinstance(data, str) else data.startswith("file://")
\ No newline at end of file
diff --git a/cognee/api/v1/add/remember.py b/cognee/api/v1/add/remember.py
index 4a6dbf8a2..de11aa71b 100644
--- a/cognee/api/v1/add/remember.py
+++ b/cognee/api/v1/add/remember.py
@@ -18,4 +18,4 @@ async def remember(user_id: str, memory_name: str, payload: List[str]):
     if await is_existing_memory(memory_name) is False:
         raise MemoryException(f"Memory with the name \"{memory_name}\" doesn't exist.")
 
-    await create_information_points(memory_name, payload)
+    await create_information_points(memory_name, payload)
\ No newline at end of file
diff --git a/cognee/api/v1/cognify/cognify.py b/cognee/api/v1/cognify/cognify.py
index 98e617a8c..c4d5786cd 100644
--- a/cognee/api/v1/cognify/cognify.py
+++ b/cognee/api/v1/cognify/cognify.py
@@ -202,4 +202,4 @@ if __name__ == "__main__":
         print(graph_url)
 
-    asyncio.run(main())
+    asyncio.run(main())
\ No newline at end of file
diff --git a/cognee/config.py b/cognee/config.py
index 90db9e49e..e23b60dcf 100644
--- a/cognee/config.py
+++ b/cognee/config.py
@@ -46,8 +46,9 @@ class Config:
     # Model parameters
-    llm_provider: str = "openai" #openai, or custom or ollama
-    custom_endpoint: str = "" # pass claude endpoint
-    custom_key: Optional[str] = "custom"
+    llm_provider: str = "openai" # openai, anthropic, ollama or custom
+    custom_model: str = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+    custom_endpoint: str = "https://api.endpoints.anyscale.com/v1" # an OpenAI-compatible endpoint, e.g. Anyscale
+    custom_key: Optional[str] = os.getenv("ANYSCALE_API_KEY")
     ollama_endpoint: str = "http://localhost:11434/v1"
     ollama_key: Optional[str] = "ollama"
     ollama_model: str = "mistral:instruct"
diff --git a/cognee/infrastructure/databases/vector/embeddings/DefaultEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/DefaultEmbeddingEngine.py
index 4dd5ad4c0..d22345919 100644
--- a/cognee/infrastructure/databases/vector/embeddings/DefaultEmbeddingEngine.py
+++ b/cognee/infrastructure/databases/vector/embeddings/DefaultEmbeddingEngine.py
@@ -6,6 +6,7 @@ from .EmbeddingEngine import EmbeddingEngine
 config = Config()
 config.load()
 
+
 class DefaultEmbeddingEngine(EmbeddingEngine):
     async def embed_text(self, text: List[str]) -> List[float]:
         embedding_model = TextEmbedding(model_name = config.embedding_model, cache_dir = get_absolute_path("cache/embeddings"))
diff --git a/cognee/infrastructure/llm/anthropic/__init__.py b/cognee/infrastructure/llm/anthropic/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/cognee/infrastructure/llm/anthropic/adapter.py b/cognee/infrastructure/llm/anthropic/adapter.py
new file mode 100644
index 000000000..92decfca0
--- /dev/null
+++ b/cognee/infrastructure/llm/anthropic/adapter.py
@@ -0,0 +1,50 @@
+from typing import Type
+
+import anthropic
+import instructor
+from pydantic import BaseModel
+from tenacity import retry, stop_after_attempt
+
+from cognee.infrastructure.llm.llm_interface import LLMInterface
+from cognee.infrastructure.llm.prompts import read_query_prompt
+
+
+class AnthropicAdapter(LLMInterface):
+    """Adapter for Anthropic's API"""
+
+    def __init__(self, endpoint, api_key: str, model: str):
+        # The Anthropic SDK reads ANTHROPIC_API_KEY from the environment; instructor
+        # patches the async client so responses are parsed into the given pydantic model.
+        self.aclient = instructor.patch(
+            create=anthropic.AsyncAnthropic().messages.create,
+            mode=instructor.Mode.ANTHROPIC_TOOLS
+        )
+        self.model = model
+
+    @retry(stop=stop_after_attempt(5))
+    async def acreate_structured_output(self, text_input: str, system_prompt: str,
+                                        response_model: Type[BaseModel]) -> BaseModel:
+        """Generate a response from a user query."""
+        return await self.aclient(
+            model=self.model,
+            max_tokens=4096,
+            max_retries=0,
+            messages=[
+                {
+                    "role": "user",
+                    "content": f"""Use the given format to
+                    extract information from the following input: {text_input}. {system_prompt}""",
+                }
+            ],
+            response_model=response_model,
+        )
+
+    def show_prompt(self, text_input: str, system_prompt: str) -> str:
+        """Format and display the prompt for a user query."""
+        if not text_input:
+            text_input = "No user input provided."
+        if not system_prompt:
+            raise ValueError("No system prompt path provided.")
+        system_prompt = read_query_prompt(system_prompt)
+
+        return f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
diff --git a/cognee/infrastructure/llm/ollama/adapter.py b/cognee/infrastructure/llm/generic_llm_api/adapter.py
similarity index 96%
rename from cognee/infrastructure/llm/ollama/adapter.py
rename to cognee/infrastructure/llm/generic_llm_api/adapter.py
index 2372bd1c8..0dac5e76d 100644
--- a/cognee/infrastructure/llm/ollama/adapter.py
+++ b/cognee/infrastructure/llm/generic_llm_api/adapter.py
@@ -12,7 +12,7 @@ from cognee.infrastructure.llm.llm_interface import LLMInterface
 from cognee.infrastructure.llm.prompts import read_query_prompt
 
 
-class OllamaAPIAdapter(LLMInterface):
-    """Adapter for Ollama's API"""
+class GenericAPIAdapter(LLMInterface):
+    """Adapter for OpenAI-compatible LLM APIs (Ollama, custom endpoints)"""
 
     def __init__(self, ollama_endpoint, api_key: str, model: str):
@@ -89,9 +89,8 @@ class OllamaAPIAdapter(LLMInterface):
             {
                 "role": "user",
                 "content": f"""Use the given format to
-                extract information from the following input: {text_input}. """,
-            },
-            {"role": "system", "content": system_prompt},
+                extract information from the following input: {text_input}. {system_prompt} """,
+            }
         ],
         response_model=response_model,
     )
diff --git a/cognee/infrastructure/llm/get_llm_client.py b/cognee/infrastructure/llm/get_llm_client.py
index 6c3fedfda..37c685571 100644
--- a/cognee/infrastructure/llm/get_llm_client.py
+++ b/cognee/infrastructure/llm/get_llm_client.py
@@ -1,8 +1,9 @@
 """Get the LLM client."""
 from enum import Enum
 from cognee.config import Config
+from .anthropic.adapter import AnthropicAdapter
 from .openai.adapter import OpenAIAdapter
-from .ollama.adapter import OllamaAPIAdapter
+from .generic_llm_api.adapter import GenericAPIAdapter
 import logging
 
 logging.basicConfig(level=logging.INFO)
@@ -10,6 +11,7 @@ logging.basicConfig(level=logging.INFO)
 class LLMProvider(Enum):
     OPENAI = "openai"
     OLLAMA = "ollama"
+    ANTHROPIC = "anthropic"
     CUSTOM = "custom"
 
 config = Config()
@@ -24,10 +26,13 @@ def get_llm_client():
         return OpenAIAdapter(config.openai_key, config.model)
     elif provider == LLMProvider.OLLAMA:
         print("Using Ollama API")
-        return OllamaAPIAdapter(config.ollama_endpoint, config.ollama_key, config.ollama_model)
+        return GenericAPIAdapter(config.ollama_endpoint, config.ollama_key, config.ollama_model)
+    elif provider == LLMProvider.ANTHROPIC:
+        print("Using Anthropic API")
+        return AnthropicAdapter(config.custom_endpoint, config.custom_key, config.custom_model)
     elif provider == LLMProvider.CUSTOM:
         print("Using Custom API")
-        return OllamaAPIAdapter(config.custom_endpoint, config.custom_key, config.model)
+        return GenericAPIAdapter(config.custom_endpoint, config.custom_key, config.custom_model)
         # Add your custom LLM provider here
     else:
         raise ValueError(f"Unsupported LLM provider: {provider}")
diff --git a/cognee/infrastructure/llm/llm_interface.py b/cognee/infrastructure/llm/llm_interface.py
index dfc56536a..230082482 100644
--- a/cognee/infrastructure/llm/llm_interface.py
+++ b/cognee/infrastructure/llm/llm_interface.py
@@ -6,20 +6,20 @@ from pydantic import BaseModel
 class LLMInterface(Protocol):
     """ LLM Interface """
 
-    @abstractmethod
-    async def async_get_embedding_with_backoff(self, text, model="text-embedding-ada-002"):
-        """To get text embeddings, import/call this function"""
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-ada-002"):
-        """To get text embeddings, import/call this function"""
-        raise NotImplementedError
-
-    @abstractmethod
-    async def async_get_batch_embeddings_with_backoff(self, texts: List[str], models: List[str]):
-        """To get multiple text embeddings in parallel, import/call this function"""
-        raise NotImplementedError
+    # @abstractmethod
+    # async def async_get_embedding_with_backoff(self, text, model="text-embedding-ada-002"):
+    #     """To get text embeddings, import/call this function"""
+    #     raise NotImplementedError
+    #
+    # @abstractmethod
+    # def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-ada-002"):
+    #     """To get text embeddings, import/call this function"""
+    #     raise NotImplementedError
+    #
+    # @abstractmethod
+    # async def async_get_batch_embeddings_with_backoff(self, texts: List[str], models: List[str]):
+    #     """To get multiple text embeddings in parallel, import/call this function"""
+    #     raise NotImplementedError
 
 # """ Get completions """
 # async def acompletions_with_backoff(self, **kwargs):
diff --git a/poetry.lock b/poetry.lock
index 0db77dd7d..80ff2219f 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -150,6 +150,30 @@ files = [
     {file = "annotated_types-0.6.0.tar.gz", hash = 
"sha256:563339e807e53ffd9c267e99fc6d9ea23eb8443c08f112651963e24e22f84a5d"}, ] +[[package]] +name = "anthropic" +version = "0.21.3" +description = "The official Python library for the anthropic API" +optional = false +python-versions = ">=3.7" +files = [ + {file = "anthropic-0.21.3-py3-none-any.whl", hash = "sha256:5869115453b543a46ded6515c9f29b8d610b6e94bbba3230ad80ac947d2b0862"}, + {file = "anthropic-0.21.3.tar.gz", hash = "sha256:02f1ab5694c497e2b2d42d30d51a4f2edcaca92d2ec86bb64fe78a9c7434a869"}, +] + +[package.dependencies] +anyio = ">=3.5.0,<5" +distro = ">=1.7.0,<2" +httpx = ">=0.23.0,<1" +pydantic = ">=1.9.0,<3" +sniffio = "*" +tokenizers = ">=0.13.0" +typing-extensions = ">=4.7,<5" + +[package.extras] +bedrock = ["boto3 (>=1.28.57)", "botocore (>=1.31.57)"] +vertex = ["google-auth (>=2,<3)"] + [[package]] name = "anyio" version = "4.3.0" @@ -518,17 +542,17 @@ css = ["tinycss2 (>=1.1.0,<1.3)"] [[package]] name = "boto3" -version = "1.34.70" +version = "1.34.73" description = "The AWS SDK for Python" optional = false python-versions = ">=3.8" files = [ - {file = "boto3-1.34.70-py3-none-any.whl", hash = "sha256:8d7902e2c0c62837457ba18146e3feaf1dec62018617edc5c0336b65b305b682"}, - {file = "boto3-1.34.70.tar.gz", hash = "sha256:54150a52eb93028b8e09df00319e8dcb68be7459333d5da00d706d75ba5130d6"}, + {file = "boto3-1.34.73-py3-none-any.whl", hash = "sha256:4d68e7c7c1339e251c661fd6e2a34e31d281177106326712417fed839907fa84"}, + {file = "boto3-1.34.73.tar.gz", hash = "sha256:f45503333286c03fb692a3ce497b6fdb4e88c51c98a3b8ff05071d7f56571448"}, ] [package.dependencies] -botocore = ">=1.34.70,<1.35.0" +botocore = ">=1.34.73,<1.35.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.10.0,<0.11.0" @@ -537,13 +561,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.34.70" +version = "1.34.73" description = "Low-level, data-driven core of boto 3." optional = false python-versions = ">=3.8" files = [ - {file = "botocore-1.34.70-py3-none-any.whl", hash = "sha256:c86944114e85c8a8d5da06fb84f2609ed3bd23cd2fc06b30250bef7e37e8c589"}, - {file = "botocore-1.34.70.tar.gz", hash = "sha256:fa03d4972cd57d505e6c0eb5d7c7a1caeb7dd49e84f963f7ebeca41fe8ab736e"}, + {file = "botocore-1.34.73-py3-none-any.whl", hash = "sha256:88d660b711cc5b5b049e15d547cb09526f86e48c15b78dacad78522109502b91"}, + {file = "botocore-1.34.73.tar.gz", hash = "sha256:8df020b6682b9f1e9ee7b0554d5d0c14b7b23e3de070c85bcdf07fb20bfe4e2b"}, ] [package.dependencies] @@ -1891,13 +1915,13 @@ files = [ [[package]] name = "httpcore" -version = "1.0.4" +version = "1.0.5" description = "A minimal low-level HTTP client." 
optional = false python-versions = ">=3.8" files = [ - {file = "httpcore-1.0.4-py3-none-any.whl", hash = "sha256:ac418c1db41bade2ad53ae2f3834a3a0f5ae76b56cf5aa497d2d033384fc7d73"}, - {file = "httpcore-1.0.4.tar.gz", hash = "sha256:cb2839ccfcba0d2d3c1131d3c3e26dfc327326fbe7a5dc0dbfe9f6c9151bb022"}, + {file = "httpcore-1.0.5-py3-none-any.whl", hash = "sha256:421f18bac248b25d310f3cacd198d55b8e6125c107797b609ff9b7a6ba7991b5"}, + {file = "httpcore-1.0.5.tar.gz", hash = "sha256:34a38e2f9291467ee3b44e89dd52615370e152954ba21721378a87b2960f7a61"}, ] [package.dependencies] @@ -1908,7 +1932,7 @@ h11 = ">=0.13,<0.15" asyncio = ["anyio (>=4.0,<5.0)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] -trio = ["trio (>=0.22.0,<0.25.0)"] +trio = ["trio (>=0.22.0,<0.26.0)"] [[package]] name = "httpx" @@ -2030,13 +2054,13 @@ files = [ [[package]] name = "instructor" -version = "0.6.7" +version = "0.6.8" description = "structured outputs for llm" optional = false python-versions = "<4.0,>=3.10" files = [ - {file = "instructor-0.6.7-py3-none-any.whl", hash = "sha256:bb2cdc4b56ba9af763e01e590e051b13168038537a9ef12648142cec53472e53"}, - {file = "instructor-0.6.7.tar.gz", hash = "sha256:cbae44db8c71796a6237432f8c929b15d021b13c82b5474dc2921b2cdcfe647f"}, + {file = "instructor-0.6.8-py3-none-any.whl", hash = "sha256:f2099e49b21232ddb50ce9ba27e13159dcb3af17e8ede7cbcd93ce990fe6bc82"}, + {file = "instructor-0.6.8.tar.gz", hash = "sha256:e261d73deb3535d62ee775c437b82aeb6e9c2b2f63bb533b53a9fa6a47dbb95a"}, ] [package.dependencies] @@ -2048,15 +2072,18 @@ rich = ">=13.7.0,<14.0.0" tenacity = ">=8.2.3,<9.0.0" typer = ">=0.9.0,<0.10.0" +[package.extras] +anthropic = ["anthropic (>=0.18.1,<0.19.0)", "xmltodict (>=0.13.0,<0.14.0)"] + [[package]] name = "ipykernel" -version = "6.29.3" +version = "6.29.4" description = "IPython Kernel for Jupyter" optional = false python-versions = ">=3.8" files = [ - {file = "ipykernel-6.29.3-py3-none-any.whl", hash = "sha256:5aa086a4175b0229d4eca211e181fb473ea78ffd9869af36ba7694c947302a21"}, - {file = "ipykernel-6.29.3.tar.gz", hash = "sha256:e14c250d1f9ea3989490225cc1a542781b095a18a19447fcf2b5eaf7d0ac5bd2"}, + {file = "ipykernel-6.29.4-py3-none-any.whl", hash = "sha256:1181e653d95c6808039c509ef8e67c4126b3b3af7781496c7cbfb5ed938a27da"}, + {file = "ipykernel-6.29.4.tar.gz", hash = "sha256:3d44070060f9475ac2092b760123fadf105d2e2493c24848b6691a7c4f42af5c"}, ] [package.dependencies] @@ -4660,26 +4687,26 @@ cffi = {version = "*", markers = "implementation_name == \"pypy\""} [[package]] name = "qdrant-client" -version = "1.8.0" +version = "1.8.2" description = "Client library for the Qdrant vector search engine" optional = false python-versions = ">=3.8" files = [ - {file = "qdrant_client-1.8.0-py3-none-any.whl", hash = "sha256:fa28d3eb64c0c57ec029c7c85c71f6c72c197f92502022655741f3632c518e29"}, - {file = "qdrant_client-1.8.0.tar.gz", hash = "sha256:2a1a3f2cbacc7adba85644cf6cfdee20401cf25764b32da479c81fb63e178d15"}, + {file = "qdrant_client-1.8.2-py3-none-any.whl", hash = "sha256:ee5341c0486d09e4346b0f5ef7781436e6d8cdbf1d5ecddfde7adb3647d353a8"}, + {file = "qdrant_client-1.8.2.tar.gz", hash = "sha256:65078d5328bc0393f42a46a31cd319a989b8285bf3958360acf1dffffdf4cc4e"}, ] [package.dependencies] grpcio = ">=1.41.0" grpcio-tools = ">=1.41.0" -httpx = {version = ">=0.14.0", extras = ["http2"]} +httpx = {version = ">=0.20.0", extras = ["http2"]} numpy = {version = ">=1.21", markers = "python_version >= \"3.8\" and python_version < \"3.12\""} portalocker = ">=2.7.0,<3.0.0" pydantic = ">=1.10.8" urllib3 = 
">=1.26.14,<3" [package.extras] -fastembed = ["fastembed (==0.2.2)"] +fastembed = ["fastembed (==0.2.5)"] [[package]] name = "redis" @@ -4839,17 +4866,18 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [[package]] name = "requirements-parser" -version = "0.5.0" +version = "0.7.0" description = "This is a small Python module for parsing Pip requirement files." optional = false -python-versions = ">=3.6,<4.0" +python-versions = "<4.0,>=3.7" files = [ - {file = "requirements-parser-0.5.0.tar.gz", hash = "sha256:3336f3a3ae23e06d3f0f88595e4052396e3adf91688787f637e5d2ca1a904069"}, - {file = "requirements_parser-0.5.0-py3-none-any.whl", hash = "sha256:e7fcdcd04f2049e73a9fb150d8a0f9d51ce4108f5f7cbeac74c484e17b12bcd9"}, + {file = "requirements_parser-0.7.0-py3-none-any.whl", hash = "sha256:80569baa23b13cf0980fb2ceb5dc2e3b7ee05df203a26d83e3ed56c155c6597a"}, + {file = "requirements_parser-0.7.0.tar.gz", hash = "sha256:33f1b1c668fa85df8c6a638c479ac743ea8541f5d8d56011591068757ce1a201"}, ] [package.dependencies] -types-setuptools = ">=57.0.0" +setuptools = ">=59.7.0" +types-setuptools = ">=59.7.0" [[package]] name = "rfc3339-validator" @@ -6065,6 +6093,17 @@ files = [ [package.extras] dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"] +[[package]] +name = "xmltodict" +version = "0.13.0" +description = "Makes working with XML feel like you are working with JSON" +optional = false +python-versions = ">=3.4" +files = [ + {file = "xmltodict-0.13.0-py2.py3-none-any.whl", hash = "sha256:aa89e8fd76320154a40d19a0df04a4695fb9dc5ba977cbb68ab3e4eb225e7852"}, + {file = "xmltodict-0.13.0.tar.gz", hash = "sha256:341595a488e3e01a85a9d8911d8912fd922ede5fecc4dce437eb4b6c8d037e56"}, +] + [[package]] name = "yarl" version = "1.9.4" @@ -6196,4 +6235,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = "~3.10" -content-hash = "d929caab2d4114374cf2c36e1d956a7950476ff6e0a550e50011702c568f9195" +content-hash = "35ad50753694260acc7e34b3c85e869e310fe2fb614fb5da3a1f3c1df4e82b1a" diff --git a/pyproject.toml b/pyproject.toml index 92214ad92..2ef060ad0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ boto3 = "^1.26.125" gunicorn = "^20.1.0" sqlalchemy = "^2.0.21" asyncpg = "^0.28.0" -instructor = "^0.6.7" +instructor = "^0.6.8" networkx = "^3.2.1" graphviz = "^0.20.1" langdetect = "^1.0.9" @@ -52,6 +52,8 @@ weaviate-client = "^4.5.4" scikit-learn = "^1.4.1.post1" fastembed = "^0.2.5" pypdf = "^4.1.0" +anthropic = "^0.21.3" +xmltodict = "^0.13.0" [tool.poetry.extras] dbt = ["dbt-core", "dbt-redshift", "dbt-bigquery", "dbt-duckdb", "dbt-snowflake", "dbt-athena-community", "dbt-databricks"]