From 90c41512ed5309f1199cbdfc7df164a9153f7494 Mon Sep 17 00:00:00 2001 From: Vasilije <8619304+Vasilije1990@users.noreply.github.com> Date: Thu, 28 Mar 2024 11:26:22 +0100 Subject: [PATCH] Ollama fixes, missing libs + config fixes --- cognee/api/v1/add/add.py | 2 +- cognee/api/v1/add/add_standalone.py | 2 +- cognee/api/v1/add/remember.py | 2 +- cognee/api/v1/cognify/cognify.py | 2 +- cognee/config.py | 3 +- .../embeddings/DefaultEmbeddingEngine.py | 2 +- cognee/infrastructure/llm/get_llm_client.py | 5 +++ cognee/infrastructure/llm/llm_interface.py | 28 +++++++------- poetry.lock | 37 ++++++++++++++++++- pyproject.toml | 2 + 10 files changed, 64 insertions(+), 21 deletions(-) diff --git a/cognee/api/v1/add/add.py b/cognee/api/v1/add/add.py index b91a979a8..8e94a7d35 100644 --- a/cognee/api/v1/add/add.py +++ b/cognee/api/v1/add/add.py @@ -87,4 +87,4 @@ async def add(file_paths: Union[str, List[str]], dataset_name: str = None): write_disposition = "merge", ) - return run_info + return run_info \ No newline at end of file diff --git a/cognee/api/v1/add/add_standalone.py b/cognee/api/v1/add/add_standalone.py index 7def013eb..ef7be455b 100644 --- a/cognee/api/v1/add/add_standalone.py +++ b/cognee/api/v1/add/add_standalone.py @@ -50,4 +50,4 @@ async def add_standalone( def is_data_path(data: str) -> bool: - return False if not isinstance(data, str) else data.startswith("file://") + return False if not isinstance(data, str) else data.startswith("file://") \ No newline at end of file diff --git a/cognee/api/v1/add/remember.py b/cognee/api/v1/add/remember.py index 4a6dbf8a2..de11aa71b 100644 --- a/cognee/api/v1/add/remember.py +++ b/cognee/api/v1/add/remember.py @@ -18,4 +18,4 @@ async def remember(user_id: str, memory_name: str, payload: List[str]): if await is_existing_memory(memory_name) is False: raise MemoryException(f"Memory with the name \"{memory_name}\" doesn't exist.") - await create_information_points(memory_name, payload) + await 
create_information_points(memory_name, payload) \ No newline at end of file diff --git a/cognee/api/v1/cognify/cognify.py b/cognee/api/v1/cognify/cognify.py index f6b2d9534..e3eecc656 100644 --- a/cognee/api/v1/cognify/cognify.py +++ b/cognee/api/v1/cognify/cognify.py @@ -193,4 +193,4 @@ if __name__ == "__main__": print(graph_url) - asyncio.run(main()) + asyncio.run(main()) \ No newline at end of file diff --git a/cognee/config.py b/cognee/config.py index bb97d4372..a9a210f5a 100644 --- a/cognee/config.py +++ b/cognee/config.py @@ -43,7 +43,8 @@ class Config: graph_filename = os.getenv("GRAPH_NAME", "cognee_graph.pkl") # Model parameters - llm_provider: str = "openai" #openai, or custom or ollama + llm_provider: str = "anthropic" #openai, or custom or ollama + custom_model: str = "claude-3-haiku-20240307" custom_endpoint: str = "" # pass claude endpoint custom_key: Optional[str] = "custom" ollama_endpoint: str = "http://localhost:11434/v1" diff --git a/cognee/infrastructure/databases/vector/embeddings/DefaultEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/DefaultEmbeddingEngine.py index 3192cb5ef..c91d1cf29 100644 --- a/cognee/infrastructure/databases/vector/embeddings/DefaultEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/DefaultEmbeddingEngine.py @@ -1,7 +1,7 @@ from typing import List from fastembed import TextEmbedding from .EmbeddingEngine import EmbeddingEngine -from cognitive_architecture.config import Config +from cognee.config import Config config = Config() config.load() diff --git a/cognee/infrastructure/llm/get_llm_client.py b/cognee/infrastructure/llm/get_llm_client.py index 6c3fedfda..7720f3a6e 100644 --- a/cognee/infrastructure/llm/get_llm_client.py +++ b/cognee/infrastructure/llm/get_llm_client.py @@ -1,6 +1,7 @@ """Get the LLM client.""" from enum import Enum from cognee.config import Config +from .anthropic.adapter import AnthropicAdapter from .openai.adapter import OpenAIAdapter from 
.ollama.adapter import OllamaAPIAdapter import logging @@ -10,6 +11,7 @@ logging.basicConfig(level=logging.INFO) class LLMProvider(Enum): OPENAI = "openai" OLLAMA = "ollama" + ANTHROPIC = "anthropic" CUSTOM = "custom" config = Config() @@ -25,6 +27,9 @@ def get_llm_client(): elif provider == LLMProvider.OLLAMA: print("Using Ollama API") return OllamaAPIAdapter(config.ollama_endpoint, config.ollama_key, config.ollama_model) + elif provider == LLMProvider.ANTHROPIC: + print("Using Anthropic API") + return AnthropicAdapter(config.custom_endpoint, config.custom_key, config.custom_model) elif provider == LLMProvider.CUSTOM: print("Using Custom API") return OllamaAPIAdapter(config.custom_endpoint, config.custom_key, config.model) diff --git a/cognee/infrastructure/llm/llm_interface.py b/cognee/infrastructure/llm/llm_interface.py index dfc56536a..230082482 100644 --- a/cognee/infrastructure/llm/llm_interface.py +++ b/cognee/infrastructure/llm/llm_interface.py @@ -6,20 +6,20 @@ from pydantic import BaseModel class LLMInterface(Protocol): """ LLM Interface """ - @abstractmethod - async def async_get_embedding_with_backoff(self, text, model="text-embedding-ada-002"): - """To get text embeddings, import/call this function""" - raise NotImplementedError - - @abstractmethod - def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-ada-002"): - """To get text embeddings, import/call this function""" - raise NotImplementedError - - @abstractmethod - async def async_get_batch_embeddings_with_backoff(self, texts: List[str], models: List[str]): - """To get multiple text embeddings in parallel, import/call this function""" - raise NotImplementedError + # @abstractmethod + # async def async_get_embedding_with_backoff(self, text, model="text-embedding-ada-002"): + # """To get text embeddings, import/call this function""" + # raise NotImplementedError + # + # @abstractmethod + # def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-ada-002"): 
+ # """To get text embeddings, import/call this function""" + # raise NotImplementedError + # + # @abstractmethod + # async def async_get_batch_embeddings_with_backoff(self, texts: List[str], models: List[str]): + # """To get multiple text embeddings in parallel, import/call this function""" + # raise NotImplementedError # """ Get completions """ # async def acompletions_with_backoff(self, **kwargs): diff --git a/poetry.lock b/poetry.lock index 0db77dd7d..62191ff44 100644 --- a/poetry.lock +++ b/poetry.lock @@ -150,6 +150,30 @@ files = [ {file = "annotated_types-0.6.0.tar.gz", hash = "sha256:563339e807e53ffd9c267e99fc6d9ea23eb8443c08f112651963e24e22f84a5d"}, ] +[[package]] +name = "anthropic" +version = "0.21.3" +description = "The official Python library for the anthropic API" +optional = false +python-versions = ">=3.7" +files = [ + {file = "anthropic-0.21.3-py3-none-any.whl", hash = "sha256:5869115453b543a46ded6515c9f29b8d610b6e94bbba3230ad80ac947d2b0862"}, + {file = "anthropic-0.21.3.tar.gz", hash = "sha256:02f1ab5694c497e2b2d42d30d51a4f2edcaca92d2ec86bb64fe78a9c7434a869"}, +] + +[package.dependencies] +anyio = ">=3.5.0,<5" +distro = ">=1.7.0,<2" +httpx = ">=0.23.0,<1" +pydantic = ">=1.9.0,<3" +sniffio = "*" +tokenizers = ">=0.13.0" +typing-extensions = ">=4.7,<5" + +[package.extras] +bedrock = ["boto3 (>=1.28.57)", "botocore (>=1.31.57)"] +vertex = ["google-auth (>=2,<3)"] + [[package]] name = "anyio" version = "4.3.0" @@ -6065,6 +6089,17 @@ files = [ [package.extras] dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"] +[[package]] +name = "xmltodict" +version = "0.13.0" +description = "Makes working with XML feel like you are working with JSON" +optional = false +python-versions = ">=3.4" +files = [ + {file = "xmltodict-0.13.0-py2.py3-none-any.whl", hash = "sha256:aa89e8fd76320154a40d19a0df04a4695fb9dc5ba977cbb68ab3e4eb225e7852"}, + {file = "xmltodict-0.13.0.tar.gz", hash = "sha256:341595a488e3e01a85a9d8911d8912fd922ede5fecc4dce437eb4b6c8d037e56"}, +] + 
[[package]] name = "yarl" version = "1.9.4" @@ -6196,4 +6231,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = "~3.10" -content-hash = "d929caab2d4114374cf2c36e1d956a7950476ff6e0a550e50011702c568f9195" +content-hash = "1a6b4648fa95a43e76eef36fc6f9951b34988001ea24b8288bcc70962f05d7db" diff --git a/pyproject.toml b/pyproject.toml index 92214ad92..8879b2040 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,8 @@ weaviate-client = "^4.5.4" scikit-learn = "^1.4.1.post1" fastembed = "^0.2.5" pypdf = "^4.1.0" +anthropic = "^0.21.3" +xmltodict = "^0.13.0" [tool.poetry.extras] dbt = ["dbt-core", "dbt-redshift", "dbt-bigquery", "dbt-duckdb", "dbt-snowflake", "dbt-athena-community", "dbt-databricks"]