Ollama fixes, missing libs + config fixes
This commit is contained in:
parent
7bf80fb0f5
commit
90c41512ed
10 changed files with 64 additions and 21 deletions
|
|
@ -87,4 +87,4 @@ async def add(file_paths: Union[str, List[str]], dataset_name: str = None):
|
||||||
write_disposition = "merge",
|
write_disposition = "merge",
|
||||||
)
|
)
|
||||||
|
|
||||||
return run_info
|
return run_info
|
||||||
|
|
@ -50,4 +50,4 @@ async def add_standalone(
|
||||||
|
|
||||||
|
|
||||||
def is_data_path(data: str) -> bool:
|
def is_data_path(data: str) -> bool:
|
||||||
return False if not isinstance(data, str) else data.startswith("file://")
|
return False if not isinstance(data, str) else data.startswith("file://")
|
||||||
|
|
@ -18,4 +18,4 @@ async def remember(user_id: str, memory_name: str, payload: List[str]):
|
||||||
if await is_existing_memory(memory_name) is False:
|
if await is_existing_memory(memory_name) is False:
|
||||||
raise MemoryException(f"Memory with the name \"{memory_name}\" doesn't exist.")
|
raise MemoryException(f"Memory with the name \"{memory_name}\" doesn't exist.")
|
||||||
|
|
||||||
await create_information_points(memory_name, payload)
|
await create_information_points(memory_name, payload)
|
||||||
|
|
@ -193,4 +193,4 @@ if __name__ == "__main__":
|
||||||
print(graph_url)
|
print(graph_url)
|
||||||
|
|
||||||
|
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|
@ -43,7 +43,8 @@ class Config:
|
||||||
graph_filename = os.getenv("GRAPH_NAME", "cognee_graph.pkl")
|
graph_filename = os.getenv("GRAPH_NAME", "cognee_graph.pkl")
|
||||||
|
|
||||||
# Model parameters
|
# Model parameters
|
||||||
llm_provider: str = "openai" #openai, or custom or ollama
|
llm_provider: str = "anthropic" #openai, or custom or ollama
|
||||||
|
custom_model: str = "claude-3-haiku-20240307"
|
||||||
custom_endpoint: str = "" # pass claude endpoint
|
custom_endpoint: str = "" # pass claude endpoint
|
||||||
custom_key: Optional[str] = "custom"
|
custom_key: Optional[str] = "custom"
|
||||||
ollama_endpoint: str = "http://localhost:11434/v1"
|
ollama_endpoint: str = "http://localhost:11434/v1"
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from typing import List
|
from typing import List
|
||||||
from fastembed import TextEmbedding
|
from fastembed import TextEmbedding
|
||||||
from .EmbeddingEngine import EmbeddingEngine
|
from .EmbeddingEngine import EmbeddingEngine
|
||||||
from cognitive_architecture.config import Config
|
from cognee.config import Config
|
||||||
|
|
||||||
config = Config()
|
config = Config()
|
||||||
config.load()
|
config.load()
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
"""Get the LLM client."""
|
"""Get the LLM client."""
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from cognee.config import Config
|
from cognee.config import Config
|
||||||
|
from .anthropic.adapter import AnthropicAdapter
|
||||||
from .openai.adapter import OpenAIAdapter
|
from .openai.adapter import OpenAIAdapter
|
||||||
from .ollama.adapter import OllamaAPIAdapter
|
from .ollama.adapter import OllamaAPIAdapter
|
||||||
import logging
|
import logging
|
||||||
|
|
@ -10,6 +11,7 @@ logging.basicConfig(level=logging.INFO)
|
||||||
class LLMProvider(Enum):
|
class LLMProvider(Enum):
|
||||||
OPENAI = "openai"
|
OPENAI = "openai"
|
||||||
OLLAMA = "ollama"
|
OLLAMA = "ollama"
|
||||||
|
ANTHROPIC = "anthropic"
|
||||||
CUSTOM = "custom"
|
CUSTOM = "custom"
|
||||||
|
|
||||||
config = Config()
|
config = Config()
|
||||||
|
|
@ -25,6 +27,9 @@ def get_llm_client():
|
||||||
elif provider == LLMProvider.OLLAMA:
|
elif provider == LLMProvider.OLLAMA:
|
||||||
print("Using Ollama API")
|
print("Using Ollama API")
|
||||||
return OllamaAPIAdapter(config.ollama_endpoint, config.ollama_key, config.ollama_model)
|
return OllamaAPIAdapter(config.ollama_endpoint, config.ollama_key, config.ollama_model)
|
||||||
|
elif provider == LLMProvider.ANTHROPIC:
|
||||||
|
print("Using Anthropic API")
|
||||||
|
return AnthropicAdapter(config.ollama_endpoint, config.ollama_key, config.custom_model)
|
||||||
elif provider == LLMProvider.CUSTOM:
|
elif provider == LLMProvider.CUSTOM:
|
||||||
print("Using Custom API")
|
print("Using Custom API")
|
||||||
return OllamaAPIAdapter(config.custom_endpoint, config.custom_key, config.model)
|
return OllamaAPIAdapter(config.custom_endpoint, config.custom_key, config.model)
|
||||||
|
|
|
||||||
|
|
@ -6,20 +6,20 @@ from pydantic import BaseModel
|
||||||
class LLMInterface(Protocol):
|
class LLMInterface(Protocol):
|
||||||
""" LLM Interface """
|
""" LLM Interface """
|
||||||
|
|
||||||
@abstractmethod
|
# @abstractmethod
|
||||||
async def async_get_embedding_with_backoff(self, text, model="text-embedding-ada-002"):
|
# async def async_get_embedding_with_backoff(self, text, model="text-embedding-ada-002"):
|
||||||
"""To get text embeddings, import/call this function"""
|
# """To get text embeddings, import/call this function"""
|
||||||
raise NotImplementedError
|
# raise NotImplementedError
|
||||||
|
#
|
||||||
@abstractmethod
|
# @abstractmethod
|
||||||
def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-ada-002"):
|
# def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-ada-002"):
|
||||||
"""To get text embeddings, import/call this function"""
|
# """To get text embeddings, import/call this function"""
|
||||||
raise NotImplementedError
|
# raise NotImplementedError
|
||||||
|
#
|
||||||
@abstractmethod
|
# @abstractmethod
|
||||||
async def async_get_batch_embeddings_with_backoff(self, texts: List[str], models: List[str]):
|
# async def async_get_batch_embeddings_with_backoff(self, texts: List[str], models: List[str]):
|
||||||
"""To get multiple text embeddings in parallel, import/call this function"""
|
# """To get multiple text embeddings in parallel, import/call this function"""
|
||||||
raise NotImplementedError
|
# raise NotImplementedError
|
||||||
|
|
||||||
# """ Get completions """
|
# """ Get completions """
|
||||||
# async def acompletions_with_backoff(self, **kwargs):
|
# async def acompletions_with_backoff(self, **kwargs):
|
||||||
|
|
|
||||||
37
poetry.lock
generated
37
poetry.lock
generated
|
|
@ -150,6 +150,30 @@ files = [
|
||||||
{file = "annotated_types-0.6.0.tar.gz", hash = "sha256:563339e807e53ffd9c267e99fc6d9ea23eb8443c08f112651963e24e22f84a5d"},
|
{file = "annotated_types-0.6.0.tar.gz", hash = "sha256:563339e807e53ffd9c267e99fc6d9ea23eb8443c08f112651963e24e22f84a5d"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "anthropic"
|
||||||
|
version = "0.21.3"
|
||||||
|
description = "The official Python library for the anthropic API"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
files = [
|
||||||
|
{file = "anthropic-0.21.3-py3-none-any.whl", hash = "sha256:5869115453b543a46ded6515c9f29b8d610b6e94bbba3230ad80ac947d2b0862"},
|
||||||
|
{file = "anthropic-0.21.3.tar.gz", hash = "sha256:02f1ab5694c497e2b2d42d30d51a4f2edcaca92d2ec86bb64fe78a9c7434a869"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
anyio = ">=3.5.0,<5"
|
||||||
|
distro = ">=1.7.0,<2"
|
||||||
|
httpx = ">=0.23.0,<1"
|
||||||
|
pydantic = ">=1.9.0,<3"
|
||||||
|
sniffio = "*"
|
||||||
|
tokenizers = ">=0.13.0"
|
||||||
|
typing-extensions = ">=4.7,<5"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
bedrock = ["boto3 (>=1.28.57)", "botocore (>=1.31.57)"]
|
||||||
|
vertex = ["google-auth (>=2,<3)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "anyio"
|
name = "anyio"
|
||||||
version = "4.3.0"
|
version = "4.3.0"
|
||||||
|
|
@ -6065,6 +6089,17 @@ files = [
|
||||||
[package.extras]
|
[package.extras]
|
||||||
dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"]
|
dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "xmltodict"
|
||||||
|
version = "0.13.0"
|
||||||
|
description = "Makes working with XML feel like you are working with JSON"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.4"
|
||||||
|
files = [
|
||||||
|
{file = "xmltodict-0.13.0-py2.py3-none-any.whl", hash = "sha256:aa89e8fd76320154a40d19a0df04a4695fb9dc5ba977cbb68ab3e4eb225e7852"},
|
||||||
|
{file = "xmltodict-0.13.0.tar.gz", hash = "sha256:341595a488e3e01a85a9d8911d8912fd922ede5fecc4dce437eb4b6c8d037e56"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "yarl"
|
name = "yarl"
|
||||||
version = "1.9.4"
|
version = "1.9.4"
|
||||||
|
|
@ -6196,4 +6231,4 @@ weaviate = ["weaviate-client"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "~3.10"
|
python-versions = "~3.10"
|
||||||
content-hash = "d929caab2d4114374cf2c36e1d956a7950476ff6e0a550e50011702c568f9195"
|
content-hash = "1a6b4648fa95a43e76eef36fc6f9951b34988001ea24b8288bcc70962f05d7db"
|
||||||
|
|
|
||||||
|
|
@ -52,6 +52,8 @@ weaviate-client = "^4.5.4"
|
||||||
scikit-learn = "^1.4.1.post1"
|
scikit-learn = "^1.4.1.post1"
|
||||||
fastembed = "^0.2.5"
|
fastembed = "^0.2.5"
|
||||||
pypdf = "^4.1.0"
|
pypdf = "^4.1.0"
|
||||||
|
anthropic = "^0.21.3"
|
||||||
|
xmltodict = "^0.13.0"
|
||||||
|
|
||||||
[tool.poetry.extras]
|
[tool.poetry.extras]
|
||||||
dbt = ["dbt-core", "dbt-redshift", "dbt-bigquery", "dbt-duckdb", "dbt-snowflake", "dbt-athena-community", "dbt-databricks"]
|
dbt = ["dbt-core", "dbt-redshift", "dbt-bigquery", "dbt-duckdb", "dbt-snowflake", "dbt-athena-community", "dbt-databricks"]
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue