Ollama fixes, missing libs + config fixes

This commit is contained in:
Vasilije 2024-03-28 11:26:22 +01:00
parent 7bf80fb0f5
commit 90c41512ed
10 changed files with 64 additions and 21 deletions

View file

@ -87,4 +87,4 @@ async def add(file_paths: Union[str, List[str]], dataset_name: str = None):
write_disposition = "merge", write_disposition = "merge",
) )
return run_info return run_info

View file

@ -50,4 +50,4 @@ async def add_standalone(
def is_data_path(data: str) -> bool: def is_data_path(data: str) -> bool:
return False if not isinstance(data, str) else data.startswith("file://") return False if not isinstance(data, str) else data.startswith("file://")

View file

@ -18,4 +18,4 @@ async def remember(user_id: str, memory_name: str, payload: List[str]):
if await is_existing_memory(memory_name) is False: if await is_existing_memory(memory_name) is False:
raise MemoryException(f"Memory with the name \"{memory_name}\" doesn't exist.") raise MemoryException(f"Memory with the name \"{memory_name}\" doesn't exist.")
await create_information_points(memory_name, payload) await create_information_points(memory_name, payload)

View file

@ -193,4 +193,4 @@ if __name__ == "__main__":
print(graph_url) print(graph_url)
asyncio.run(main()) asyncio.run(main())

View file

@ -43,7 +43,8 @@ class Config:
graph_filename = os.getenv("GRAPH_NAME", "cognee_graph.pkl") graph_filename = os.getenv("GRAPH_NAME", "cognee_graph.pkl")
# Model parameters # Model parameters
llm_provider: str = "openai" #openai, or custom or ollama llm_provider: str = "anthropic" #openai, or custom or ollama
custom_model: str = "claude-3-haiku-20240307"
custom_endpoint: str = "" # pass claude endpoint custom_endpoint: str = "" # pass claude endpoint
custom_key: Optional[str] = "custom" custom_key: Optional[str] = "custom"
ollama_endpoint: str = "http://localhost:11434/v1" ollama_endpoint: str = "http://localhost:11434/v1"

View file

@ -1,7 +1,7 @@
from typing import List from typing import List
from fastembed import TextEmbedding from fastembed import TextEmbedding
from .EmbeddingEngine import EmbeddingEngine from .EmbeddingEngine import EmbeddingEngine
from cognitive_architecture.config import Config from cognee.config import Config
config = Config() config = Config()
config.load() config.load()

View file

@ -1,6 +1,7 @@
"""Get the LLM client.""" """Get the LLM client."""
from enum import Enum from enum import Enum
from cognee.config import Config from cognee.config import Config
from .anthropic.adapter import AnthropicAdapter
from .openai.adapter import OpenAIAdapter from .openai.adapter import OpenAIAdapter
from .ollama.adapter import OllamaAPIAdapter from .ollama.adapter import OllamaAPIAdapter
import logging import logging
@ -10,6 +11,7 @@ logging.basicConfig(level=logging.INFO)
class LLMProvider(Enum): class LLMProvider(Enum):
OPENAI = "openai" OPENAI = "openai"
OLLAMA = "ollama" OLLAMA = "ollama"
ANTHROPIC = "anthropic"
CUSTOM = "custom" CUSTOM = "custom"
config = Config() config = Config()
@ -25,6 +27,9 @@ def get_llm_client():
elif provider == LLMProvider.OLLAMA: elif provider == LLMProvider.OLLAMA:
print("Using Ollama API") print("Using Ollama API")
return OllamaAPIAdapter(config.ollama_endpoint, config.ollama_key, config.ollama_model) return OllamaAPIAdapter(config.ollama_endpoint, config.ollama_key, config.ollama_model)
elif provider == LLMProvider.ANTHROPIC:
print("Using Anthropic API")
return AnthropicAdapter(config.ollama_endpoint, config.ollama_key, config.custom_model)
elif provider == LLMProvider.CUSTOM: elif provider == LLMProvider.CUSTOM:
print("Using Custom API") print("Using Custom API")
return OllamaAPIAdapter(config.custom_endpoint, config.custom_key, config.model) return OllamaAPIAdapter(config.custom_endpoint, config.custom_key, config.model)

View file

@ -6,20 +6,20 @@ from pydantic import BaseModel
class LLMInterface(Protocol): class LLMInterface(Protocol):
""" LLM Interface """ """ LLM Interface """
@abstractmethod # @abstractmethod
async def async_get_embedding_with_backoff(self, text, model="text-embedding-ada-002"): # async def async_get_embedding_with_backoff(self, text, model="text-embedding-ada-002"):
"""To get text embeddings, import/call this function""" # """To get text embeddings, import/call this function"""
raise NotImplementedError # raise NotImplementedError
#
@abstractmethod # @abstractmethod
def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-ada-002"): # def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-ada-002"):
"""To get text embeddings, import/call this function""" # """To get text embeddings, import/call this function"""
raise NotImplementedError # raise NotImplementedError
#
@abstractmethod # @abstractmethod
async def async_get_batch_embeddings_with_backoff(self, texts: List[str], models: List[str]): # async def async_get_batch_embeddings_with_backoff(self, texts: List[str], models: List[str]):
"""To get multiple text embeddings in parallel, import/call this function""" # """To get multiple text embeddings in parallel, import/call this function"""
raise NotImplementedError # raise NotImplementedError
# """ Get completions """ # """ Get completions """
# async def acompletions_with_backoff(self, **kwargs): # async def acompletions_with_backoff(self, **kwargs):

37
poetry.lock generated
View file

@ -150,6 +150,30 @@ files = [
{file = "annotated_types-0.6.0.tar.gz", hash = "sha256:563339e807e53ffd9c267e99fc6d9ea23eb8443c08f112651963e24e22f84a5d"}, {file = "annotated_types-0.6.0.tar.gz", hash = "sha256:563339e807e53ffd9c267e99fc6d9ea23eb8443c08f112651963e24e22f84a5d"},
] ]
[[package]]
name = "anthropic"
version = "0.21.3"
description = "The official Python library for the anthropic API"
optional = false
python-versions = ">=3.7"
files = [
{file = "anthropic-0.21.3-py3-none-any.whl", hash = "sha256:5869115453b543a46ded6515c9f29b8d610b6e94bbba3230ad80ac947d2b0862"},
{file = "anthropic-0.21.3.tar.gz", hash = "sha256:02f1ab5694c497e2b2d42d30d51a4f2edcaca92d2ec86bb64fe78a9c7434a869"},
]
[package.dependencies]
anyio = ">=3.5.0,<5"
distro = ">=1.7.0,<2"
httpx = ">=0.23.0,<1"
pydantic = ">=1.9.0,<3"
sniffio = "*"
tokenizers = ">=0.13.0"
typing-extensions = ">=4.7,<5"
[package.extras]
bedrock = ["boto3 (>=1.28.57)", "botocore (>=1.31.57)"]
vertex = ["google-auth (>=2,<3)"]
[[package]] [[package]]
name = "anyio" name = "anyio"
version = "4.3.0" version = "4.3.0"
@ -6065,6 +6089,17 @@ files = [
[package.extras] [package.extras]
dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"] dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"]
[[package]]
name = "xmltodict"
version = "0.13.0"
description = "Makes working with XML feel like you are working with JSON"
optional = false
python-versions = ">=3.4"
files = [
{file = "xmltodict-0.13.0-py2.py3-none-any.whl", hash = "sha256:aa89e8fd76320154a40d19a0df04a4695fb9dc5ba977cbb68ab3e4eb225e7852"},
{file = "xmltodict-0.13.0.tar.gz", hash = "sha256:341595a488e3e01a85a9d8911d8912fd922ede5fecc4dce437eb4b6c8d037e56"},
]
[[package]] [[package]]
name = "yarl" name = "yarl"
version = "1.9.4" version = "1.9.4"
@ -6196,4 +6231,4 @@ weaviate = ["weaviate-client"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "~3.10" python-versions = "~3.10"
content-hash = "d929caab2d4114374cf2c36e1d956a7950476ff6e0a550e50011702c568f9195" content-hash = "1a6b4648fa95a43e76eef36fc6f9951b34988001ea24b8288bcc70962f05d7db"

View file

@ -52,6 +52,8 @@ weaviate-client = "^4.5.4"
scikit-learn = "^1.4.1.post1" scikit-learn = "^1.4.1.post1"
fastembed = "^0.2.5" fastembed = "^0.2.5"
pypdf = "^4.1.0" pypdf = "^4.1.0"
anthropic = "^0.21.3"
xmltodict = "^0.13.0"
[tool.poetry.extras] [tool.poetry.extras]
dbt = ["dbt-core", "dbt-redshift", "dbt-bigquery", "dbt-duckdb", "dbt-snowflake", "dbt-athena-community", "dbt-databricks"] dbt = ["dbt-core", "dbt-redshift", "dbt-bigquery", "dbt-duckdb", "dbt-snowflake", "dbt-athena-community", "dbt-databricks"]