Merge remote-tracking branch 'origin/main' into feature/cog-537-implement-retrieval-algorithm-from-research-paper
This commit is contained in:
commit
59f8ec665f
31 changed files with 729 additions and 231 deletions
BIN
.data/multimedia/example.png
Normal file
BIN
.data/multimedia/example.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 10 KiB |
BIN
.data/multimedia/text_to_speech.mp3
Normal file
BIN
.data/multimedia/text_to_speech.mp3
Normal file
Binary file not shown.
63
.github/workflows/test_cognee_multimedia_notebook.yml
vendored
Normal file
63
.github/workflows/test_cognee_multimedia_notebook.yml
vendored
Normal file
|
|
@ -0,0 +1,63 @@
|
||||||
|
name: test | multimedia notebook
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
types: [labeled, synchronize]
|
||||||
|
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
env:
|
||||||
|
RUNTIME__LOG_LEVEL: ERROR
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
get_docs_changes:
|
||||||
|
name: docs changes
|
||||||
|
uses: ./.github/workflows/get_docs_changes.yml
|
||||||
|
|
||||||
|
run_notebook_test:
|
||||||
|
name: test
|
||||||
|
needs: get_docs_changes
|
||||||
|
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
defaults:
|
||||||
|
run:
|
||||||
|
shell: bash
|
||||||
|
steps:
|
||||||
|
- name: Check out
|
||||||
|
uses: actions/checkout@master
|
||||||
|
|
||||||
|
- name: Setup Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: '3.11.x'
|
||||||
|
|
||||||
|
- name: Install Poetry
|
||||||
|
uses: snok/install-poetry@v1.3.2
|
||||||
|
with:
|
||||||
|
virtualenvs-create: true
|
||||||
|
virtualenvs-in-project: true
|
||||||
|
installer-parallel: true
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
poetry install --no-interaction
|
||||||
|
poetry add jupyter --no-interaction
|
||||||
|
|
||||||
|
- name: Execute Jupyter Notebook
|
||||||
|
env:
|
||||||
|
ENV: 'dev'
|
||||||
|
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
|
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
|
||||||
|
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
|
||||||
|
run: |
|
||||||
|
poetry run jupyter nbconvert \
|
||||||
|
--to notebook \
|
||||||
|
--execute notebooks/cognee_multimedia_demo.ipynb \
|
||||||
|
--output executed_notebook.ipynb \
|
||||||
|
--ExecutePreprocessor.timeout=1200
|
||||||
1
.github/workflows/test_python_3_10.yml
vendored
1
.github/workflows/test_python_3_10.yml
vendored
|
|
@ -13,6 +13,7 @@ concurrency:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
RUNTIME__LOG_LEVEL: ERROR
|
RUNTIME__LOG_LEVEL: ERROR
|
||||||
|
ENV: 'dev'
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
get_docs_changes:
|
get_docs_changes:
|
||||||
|
|
|
||||||
1
.github/workflows/test_python_3_11.yml
vendored
1
.github/workflows/test_python_3_11.yml
vendored
|
|
@ -13,6 +13,7 @@ concurrency:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
RUNTIME__LOG_LEVEL: ERROR
|
RUNTIME__LOG_LEVEL: ERROR
|
||||||
|
ENV: 'dev'
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
get_docs_changes:
|
get_docs_changes:
|
||||||
|
|
|
||||||
1
.github/workflows/test_python_3_9.yml
vendored
1
.github/workflows/test_python_3_9.yml
vendored
|
|
@ -13,6 +13,7 @@ concurrency:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
RUNTIME__LOG_LEVEL: ERROR
|
RUNTIME__LOG_LEVEL: ERROR
|
||||||
|
ENV: 'dev'
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
get_docs_changes:
|
get_docs_changes:
|
||||||
|
|
|
||||||
52
README.md
52
README.md
|
|
@ -105,37 +105,65 @@ import asyncio
|
||||||
from cognee.api.v1.search import SearchType
|
from cognee.api.v1.search import SearchType
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
# Reset cognee data
|
# Create a clean slate for cognee -- reset data and system state
|
||||||
|
print("Resetting cognee data...")
|
||||||
await cognee.prune.prune_data()
|
await cognee.prune.prune_data()
|
||||||
# Reset cognee system state
|
|
||||||
await cognee.prune.prune_system(metadata=True)
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
print("Data reset complete.\n")
|
||||||
|
|
||||||
|
# cognee knowledge graph will be created based on this text
|
||||||
text = """
|
text = """
|
||||||
Natural language processing (NLP) is an interdisciplinary
|
Natural language processing (NLP) is an interdisciplinary
|
||||||
subfield of computer science and information retrieval.
|
subfield of computer science and information retrieval.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Add text to cognee
|
print("Adding text to cognee:")
|
||||||
|
print(text.strip())
|
||||||
|
# Add the text, and make it available for cognify
|
||||||
await cognee.add(text)
|
await cognee.add(text)
|
||||||
|
print("Text added successfully.\n")
|
||||||
|
|
||||||
|
|
||||||
|
print("Running cognify to create knowledge graph...\n")
|
||||||
|
print("Cognify process steps:")
|
||||||
|
print("1. Classifying the document: Determining the type and category of the input text.")
|
||||||
|
print("2. Checking permissions: Ensuring the user has the necessary rights to process the text.")
|
||||||
|
print("3. Extracting text chunks: Breaking down the text into sentences or phrases for analysis.")
|
||||||
|
print("4. Adding data points: Storing the extracted chunks for processing.")
|
||||||
|
print("5. Generating knowledge graph: Extracting entities and relationships to form a knowledge graph.")
|
||||||
|
print("6. Summarizing text: Creating concise summaries of the content for quick insights.\n")
|
||||||
|
|
||||||
# Use LLMs and cognee to create knowledge graph
|
# Use LLMs and cognee to create knowledge graph
|
||||||
await cognee.cognify()
|
await cognee.cognify()
|
||||||
|
print("Cognify process complete.\n")
|
||||||
|
|
||||||
# Search cognee for insights
|
|
||||||
|
query_text = 'Tell me about NLP'
|
||||||
|
print(f"Searching cognee for insights with query: '{query_text}'")
|
||||||
|
# Query cognee for insights on the added text
|
||||||
search_results = await cognee.search(
|
search_results = await cognee.search(
|
||||||
SearchType.INSIGHTS,
|
SearchType.INSIGHTS, query_text=query_text
|
||||||
"Tell me about NLP",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
print("Search results:")
|
||||||
# Display results
|
# Display results
|
||||||
for result_text in search_results:
|
for result_text in search_results:
|
||||||
print(result_text)
|
print(result_text)
|
||||||
# natural_language_processing is_a field
|
|
||||||
# natural_language_processing is_subfield_of computer_science
|
|
||||||
# natural_language_processing is_subfield_of information_retrieval
|
|
||||||
|
|
||||||
asyncio.run(main())
|
# Example output:
|
||||||
|
# ({'id': UUID('bc338a39-64d6-549a-acec-da60846dd90d'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 1, 211808, tzinfo=datetime.timezone.utc), 'name': 'natural language processing', 'description': 'An interdisciplinary subfield of computer science and information retrieval.'}, {'relationship_name': 'is_a_subfield_of', 'source_node_id': UUID('bc338a39-64d6-549a-acec-da60846dd90d'), 'target_node_id': UUID('6218dbab-eb6a-5759-a864-b3419755ffe0'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 15, 473137, tzinfo=datetime.timezone.utc)}, {'id': UUID('6218dbab-eb6a-5759-a864-b3419755ffe0'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 1, 211808, tzinfo=datetime.timezone.utc), 'name': 'computer science', 'description': 'The study of computation and information processing.'})
|
||||||
|
# (...)
|
||||||
|
#
|
||||||
|
# It represents nodes and relationships in the knowledge graph:
|
||||||
|
# - The first element is the source node (e.g., 'natural language processing').
|
||||||
|
# - The second element is the relationship between nodes (e.g., 'is_a_subfield_of').
|
||||||
|
# - The third element is the target node (e.g., 'computer science').
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
asyncio.run(main())
|
||||||
|
|
||||||
```
|
```
|
||||||
|
When you run this script, you will see step-by-step messages in the console that help you trace the execution flow and understand what the script is doing at each stage.
|
||||||
A version of this example is here: `examples/python/simple_example.py`
|
A version of this example is here: `examples/python/simple_example.py`
|
||||||
|
|
||||||
### Create your own memory store
|
### Create your own memory store
|
||||||
|
|
|
||||||
|
|
@ -9,18 +9,21 @@ async def get_graph_engine() -> GraphDBInterface :
|
||||||
config = get_graph_config()
|
config = get_graph_config()
|
||||||
|
|
||||||
if config.graph_database_provider == "neo4j":
|
if config.graph_database_provider == "neo4j":
|
||||||
try:
|
if not (config.graph_database_url and config.graph_database_username and config.graph_database_password):
|
||||||
from .neo4j_driver.adapter import Neo4jAdapter
|
raise EnvironmentError("Missing required Neo4j credentials.")
|
||||||
|
|
||||||
|
from .neo4j_driver.adapter import Neo4jAdapter
|
||||||
|
|
||||||
return Neo4jAdapter(
|
return Neo4jAdapter(
|
||||||
graph_database_url = config.graph_database_url,
|
graph_database_url = config.graph_database_url,
|
||||||
graph_database_username = config.graph_database_username,
|
graph_database_username = config.graph_database_username,
|
||||||
graph_database_password = config.graph_database_password
|
graph_database_password = config.graph_database_password
|
||||||
)
|
)
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
elif config.graph_database_provider == "falkordb":
|
elif config.graph_database_provider == "falkordb":
|
||||||
|
if not (config.graph_database_url and config.graph_database_username and config.graph_database_password):
|
||||||
|
raise EnvironmentError("Missing required FalkorDB credentials.")
|
||||||
|
|
||||||
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
|
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
|
||||||
from cognee.infrastructure.databases.hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
|
from cognee.infrastructure.databases.hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -130,6 +130,29 @@ class SQLAlchemyAdapter():
|
||||||
return metadata.tables[full_table_name]
|
return metadata.tables[full_table_name]
|
||||||
raise ValueError(f"Table '{full_table_name}' not found.")
|
raise ValueError(f"Table '{full_table_name}' not found.")
|
||||||
|
|
||||||
|
async def get_table_names(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
Return a list of all tables names in database
|
||||||
|
"""
|
||||||
|
table_names = []
|
||||||
|
async with self.engine.begin() as connection:
|
||||||
|
if self.engine.dialect.name == "sqlite":
|
||||||
|
await connection.run_sync(Base.metadata.reflect)
|
||||||
|
for table in Base.metadata.tables:
|
||||||
|
table_names.append(str(table))
|
||||||
|
else:
|
||||||
|
schema_list = await self.get_schema_list()
|
||||||
|
# Create a MetaData instance to load table information
|
||||||
|
metadata = MetaData()
|
||||||
|
# Drop all tables from all schemas
|
||||||
|
for schema_name in schema_list:
|
||||||
|
# Load the schema information into the MetaData object
|
||||||
|
await connection.run_sync(metadata.reflect, schema=schema_name)
|
||||||
|
for table in metadata.sorted_tables:
|
||||||
|
table_names.append(str(table))
|
||||||
|
metadata.clear()
|
||||||
|
return table_names
|
||||||
|
|
||||||
|
|
||||||
async def get_data(self, table_name: str, filters: dict = None):
|
async def get_data(self, table_name: str, filters: dict = None):
|
||||||
async with self.engine.begin() as connection:
|
async with self.engine.begin() as connection:
|
||||||
|
|
|
||||||
|
|
@ -10,26 +10,29 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
|
||||||
if config["vector_db_provider"] == "weaviate":
|
if config["vector_db_provider"] == "weaviate":
|
||||||
from .weaviate_db import WeaviateAdapter
|
from .weaviate_db import WeaviateAdapter
|
||||||
|
|
||||||
if config["vector_db_url"] is None and config["vector_db_key"] is None:
|
if not (config["vector_db_url"] and config["vector_db_key"]):
|
||||||
raise EnvironmentError("Weaviate is not configured!")
|
raise EnvironmentError("Missing requred Weaviate credentials!")
|
||||||
|
|
||||||
return WeaviateAdapter(
|
return WeaviateAdapter(
|
||||||
config["vector_db_url"],
|
config["vector_db_url"],
|
||||||
config["vector_db_key"],
|
config["vector_db_key"],
|
||||||
embedding_engine = embedding_engine
|
embedding_engine = embedding_engine
|
||||||
)
|
)
|
||||||
elif config["vector_db_provider"] == "qdrant":
|
|
||||||
if config["vector_db_url"] and config["vector_db_key"]:
|
|
||||||
from .qdrant.QDrantAdapter import QDrantAdapter
|
|
||||||
|
|
||||||
return QDrantAdapter(
|
elif config["vector_db_provider"] == "qdrant":
|
||||||
url = config["vector_db_url"],
|
if not (config["vector_db_url"] and config["vector_db_key"]):
|
||||||
api_key = config["vector_db_key"],
|
raise EnvironmentError("Missing requred Qdrant credentials!")
|
||||||
embedding_engine = embedding_engine
|
|
||||||
)
|
from .qdrant.QDrantAdapter import QDrantAdapter
|
||||||
|
|
||||||
|
return QDrantAdapter(
|
||||||
|
url = config["vector_db_url"],
|
||||||
|
api_key = config["vector_db_key"],
|
||||||
|
embedding_engine = embedding_engine
|
||||||
|
)
|
||||||
|
|
||||||
elif config["vector_db_provider"] == "pgvector":
|
elif config["vector_db_provider"] == "pgvector":
|
||||||
from cognee.infrastructure.databases.relational import get_relational_config
|
from cognee.infrastructure.databases.relational import get_relational_config
|
||||||
from .pgvector.PGVectorAdapter import PGVectorAdapter
|
|
||||||
|
|
||||||
# Get configuration for postgres database
|
# Get configuration for postgres database
|
||||||
relational_config = get_relational_config()
|
relational_config = get_relational_config()
|
||||||
|
|
@ -39,16 +42,25 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
|
||||||
db_port = relational_config.db_port
|
db_port = relational_config.db_port
|
||||||
db_name = relational_config.db_name
|
db_name = relational_config.db_name
|
||||||
|
|
||||||
|
if not (db_host and db_port and db_name and db_username and db_password):
|
||||||
|
raise EnvironmentError("Missing requred pgvector credentials!")
|
||||||
|
|
||||||
connection_string: str = (
|
connection_string: str = (
|
||||||
f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
|
f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .pgvector.PGVectorAdapter import PGVectorAdapter
|
||||||
|
|
||||||
return PGVectorAdapter(
|
return PGVectorAdapter(
|
||||||
connection_string,
|
connection_string,
|
||||||
config["vector_db_key"],
|
config["vector_db_key"],
|
||||||
embedding_engine,
|
embedding_engine,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif config["vector_db_provider"] == "falkordb":
|
elif config["vector_db_provider"] == "falkordb":
|
||||||
|
if not (config["vector_db_url"] and config["vector_db_key"]):
|
||||||
|
raise EnvironmentError("Missing requred FalkorDB credentials!")
|
||||||
|
|
||||||
from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
|
from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
|
||||||
|
|
||||||
return FalkorDBAdapter(
|
return FalkorDBAdapter(
|
||||||
|
|
@ -56,6 +68,7 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
|
||||||
database_port = config["vector_db_port"],
|
database_port = config["vector_db_port"],
|
||||||
embedding_engine = embedding_engine,
|
embedding_engine = embedding_engine,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
from .lancedb.LanceDBAdapter import LanceDBAdapter
|
from .lancedb.LanceDBAdapter import LanceDBAdapter
|
||||||
|
|
||||||
|
|
@ -64,5 +77,3 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
|
||||||
api_key = config["vector_db_key"],
|
api_key = config["vector_db_key"],
|
||||||
embedding_engine = embedding_engine,
|
embedding_engine = embedding_engine,
|
||||||
)
|
)
|
||||||
|
|
||||||
raise EnvironmentError(f"Vector provider not configured correctly: {config['vector_db_provider']}")
|
|
||||||
|
|
|
||||||
|
|
@ -1,32 +1,39 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import aembedding
|
|
||||||
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
|
|
||||||
litellm.set_verbose = False
|
litellm.set_verbose = False
|
||||||
|
|
||||||
class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
api_key: str
|
api_key: str
|
||||||
embedding_model: str
|
endpoint: str
|
||||||
embedding_dimensions: int
|
api_version: str
|
||||||
|
model: str
|
||||||
|
dimensions: int
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
embedding_model: Optional[str] = "text-embedding-3-large",
|
model: Optional[str] = "text-embedding-3-large",
|
||||||
embedding_dimensions: Optional[int] = 3072,
|
dimensions: Optional[int] = 3072,
|
||||||
api_key: str = None,
|
api_key: str = None,
|
||||||
|
endpoint: str = None,
|
||||||
|
api_version: str = None,
|
||||||
):
|
):
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.embedding_model = embedding_model
|
self.endpoint = endpoint
|
||||||
self.embedding_dimensions = embedding_dimensions
|
self.api_version = api_version
|
||||||
|
self.model = model
|
||||||
|
self.dimensions = dimensions
|
||||||
|
|
||||||
async def embed_text(self, text: List[str]) -> List[List[float]]:
|
async def embed_text(self, text: List[str]) -> List[List[float]]:
|
||||||
async def get_embedding(text_):
|
async def get_embedding(text_):
|
||||||
response = await aembedding(
|
response = await litellm.aembedding(
|
||||||
self.embedding_model,
|
self.model,
|
||||||
input = text_,
|
input = text_,
|
||||||
api_key = self.api_key
|
api_key = self.api_key,
|
||||||
|
api_base = self.endpoint,
|
||||||
|
api_version = self.api_version
|
||||||
)
|
)
|
||||||
|
|
||||||
return response.data[0]["embedding"]
|
return response.data[0]["embedding"]
|
||||||
|
|
@ -36,4 +43,4 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_vector_size(self) -> int:
|
def get_vector_size(self) -> int:
|
||||||
return self.embedding_dimensions
|
return self.dimensions
|
||||||
|
|
|
||||||
|
|
@ -1,23 +1,16 @@
|
||||||
|
from typing import Optional
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
class EmbeddingConfig(BaseSettings):
|
class EmbeddingConfig(BaseSettings):
|
||||||
openai_embedding_model: str = "text-embedding-3-large"
|
embedding_model: Optional[str] = "text-embedding-3-large"
|
||||||
openai_embedding_dimensions: int = 3072
|
embedding_dimensions: Optional[int] = 3072
|
||||||
litellm_embedding_model: str = "BAAI/bge-large-en-v1.5"
|
embedding_endpoint: Optional[str] = None
|
||||||
litellm_embedding_dimensions: int = 1024
|
embedding_api_key: Optional[str] = None
|
||||||
# embedding_engine:object = DefaultEmbeddingEngine(embedding_model=litellm_embedding_model, embedding_dimensions=litellm_embedding_dimensions)
|
embedding_api_version: Optional[str] = None
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
|
model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
|
||||||
return {
|
|
||||||
"openai_embedding_model": self.openai_embedding_model,
|
|
||||||
"openai_embedding_dimensions": self.openai_embedding_dimensions,
|
|
||||||
"litellm_embedding_model": self.litellm_embedding_model,
|
|
||||||
"litellm_embedding_dimensions": self.litellm_embedding_dimensions,
|
|
||||||
}
|
|
||||||
|
|
||||||
@lru_cache
|
@lru_cache
|
||||||
def get_embedding_config():
|
def get_embedding_config():
|
||||||
return EmbeddingConfig()
|
return EmbeddingConfig()
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,17 @@
|
||||||
from cognee.infrastructure.llm import get_llm_config
|
from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config
|
||||||
|
from cognee.infrastructure.llm.config import get_llm_config
|
||||||
from .EmbeddingEngine import EmbeddingEngine
|
from .EmbeddingEngine import EmbeddingEngine
|
||||||
from .LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine
|
from .LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine
|
||||||
|
|
||||||
def get_embedding_engine() -> EmbeddingEngine:
|
def get_embedding_engine() -> EmbeddingEngine:
|
||||||
|
config = get_embedding_config()
|
||||||
llm_config = get_llm_config()
|
llm_config = get_llm_config()
|
||||||
return LiteLLMEmbeddingEngine(api_key = llm_config.llm_api_key)
|
|
||||||
|
return LiteLLMEmbeddingEngine(
|
||||||
|
# If OpenAI API is used for embeddings, litellm needs only the api_key.
|
||||||
|
api_key = config.embedding_api_key or llm_config.llm_api_key,
|
||||||
|
endpoint = config.embedding_endpoint,
|
||||||
|
api_version = config.embedding_api_version,
|
||||||
|
model = config.embedding_model,
|
||||||
|
dimensions = config.embedding_dimensions,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -10,5 +10,3 @@ async def create_db_and_tables():
|
||||||
await vector_engine.create_database()
|
await vector_engine.create_database()
|
||||||
async with vector_engine.engine.begin() as connection:
|
async with vector_engine.engine.begin() as connection:
|
||||||
await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
|
await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ class LLMConfig(BaseSettings):
|
||||||
llm_model: str = "gpt-4o-mini"
|
llm_model: str = "gpt-4o-mini"
|
||||||
llm_endpoint: str = ""
|
llm_endpoint: str = ""
|
||||||
llm_api_key: Optional[str] = None
|
llm_api_key: Optional[str] = None
|
||||||
|
llm_api_version: Optional[str] = None
|
||||||
llm_temperature: float = 0.0
|
llm_temperature: float = 0.0
|
||||||
llm_streaming: bool = False
|
llm_streaming: bool = False
|
||||||
transcription_model: str = "whisper-1"
|
transcription_model: str = "whisper-1"
|
||||||
|
|
@ -19,6 +20,7 @@ class LLMConfig(BaseSettings):
|
||||||
"model": self.llm_model,
|
"model": self.llm_model,
|
||||||
"endpoint": self.llm_endpoint,
|
"endpoint": self.llm_endpoint,
|
||||||
"api_key": self.llm_api_key,
|
"api_key": self.llm_api_key,
|
||||||
|
"api_version": self.llm_api_version,
|
||||||
"temperature": self.llm_temperature,
|
"temperature": self.llm_temperature,
|
||||||
"streaming": self.llm_streaming,
|
"streaming": self.llm_streaming,
|
||||||
"transcription_model": self.transcription_model
|
"transcription_model": self.transcription_model
|
||||||
|
|
|
||||||
|
|
@ -20,21 +20,33 @@ def get_llm_client():
|
||||||
raise ValueError("LLM API key is not set.")
|
raise ValueError("LLM API key is not set.")
|
||||||
|
|
||||||
from .openai.adapter import OpenAIAdapter
|
from .openai.adapter import OpenAIAdapter
|
||||||
return OpenAIAdapter(api_key=llm_config.llm_api_key, model=llm_config.llm_model, transcription_model=llm_config.transcription_model, streaming=llm_config.llm_streaming)
|
|
||||||
|
return OpenAIAdapter(
|
||||||
|
api_key = llm_config.llm_api_key,
|
||||||
|
endpoint = llm_config.llm_endpoint,
|
||||||
|
api_version = llm_config.llm_api_version,
|
||||||
|
model = llm_config.llm_model,
|
||||||
|
transcription_model = llm_config.transcription_model,
|
||||||
|
streaming = llm_config.llm_streaming,
|
||||||
|
)
|
||||||
|
|
||||||
elif provider == LLMProvider.OLLAMA:
|
elif provider == LLMProvider.OLLAMA:
|
||||||
if llm_config.llm_api_key is None:
|
if llm_config.llm_api_key is None:
|
||||||
raise ValueError("LLM API key is not set.")
|
raise ValueError("LLM API key is not set.")
|
||||||
|
|
||||||
from .generic_llm_api.adapter import GenericAPIAdapter
|
from .generic_llm_api.adapter import GenericAPIAdapter
|
||||||
return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama")
|
return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama")
|
||||||
|
|
||||||
elif provider == LLMProvider.ANTHROPIC:
|
elif provider == LLMProvider.ANTHROPIC:
|
||||||
from .anthropic.adapter import AnthropicAdapter
|
from .anthropic.adapter import AnthropicAdapter
|
||||||
return AnthropicAdapter(llm_config.llm_model)
|
return AnthropicAdapter(llm_config.llm_model)
|
||||||
|
|
||||||
elif provider == LLMProvider.CUSTOM:
|
elif provider == LLMProvider.CUSTOM:
|
||||||
if llm_config.llm_api_key is None:
|
if llm_config.llm_api_key is None:
|
||||||
raise ValueError("LLM API key is not set.")
|
raise ValueError("LLM API key is not set.")
|
||||||
|
|
||||||
from .generic_llm_api.adapter import GenericAPIAdapter
|
from .generic_llm_api.adapter import GenericAPIAdapter
|
||||||
return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom")
|
return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||||
|
|
|
||||||
|
|
@ -1,174 +1,127 @@
|
||||||
import asyncio
|
|
||||||
import base64
|
|
||||||
import os
|
import os
|
||||||
|
import base64
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Type
|
from typing import Type
|
||||||
|
|
||||||
import openai
|
import litellm
|
||||||
import instructor
|
import instructor
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from tenacity import retry, stop_after_attempt
|
|
||||||
|
|
||||||
from cognee.base_config import get_base_config
|
|
||||||
from cognee.infrastructure.llm.llm_interface import LLMInterface
|
from cognee.infrastructure.llm.llm_interface import LLMInterface
|
||||||
from cognee.infrastructure.llm.prompts import read_query_prompt
|
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||||
# from cognee.shared.data_models import MonitoringTool
|
|
||||||
|
|
||||||
class OpenAIAdapter(LLMInterface):
|
class OpenAIAdapter(LLMInterface):
|
||||||
name = "OpenAI"
|
name = "OpenAI"
|
||||||
model: str
|
model: str
|
||||||
api_key: str
|
api_key: str
|
||||||
|
api_version: str
|
||||||
|
|
||||||
"""Adapter for OpenAI's GPT-3, GPT=4 API"""
|
"""Adapter for OpenAI's GPT-3, GPT=4 API"""
|
||||||
def __init__(self, api_key: str, model: str, transcription_model:str, streaming: bool = False):
|
def __init__(
|
||||||
base_config = get_base_config()
|
self,
|
||||||
|
api_key: str,
|
||||||
# if base_config.monitoring_tool == MonitoringTool.LANGFUSE:
|
endpoint: str,
|
||||||
# from langfuse.openai import AsyncOpenAI, OpenAI
|
api_version: str,
|
||||||
# elif base_config.monitoring_tool == MonitoringTool.LANGSMITH:
|
model: str,
|
||||||
# from langsmith import wrappers
|
transcription_model: str,
|
||||||
# from openai import AsyncOpenAI
|
streaming: bool = False,
|
||||||
# AsyncOpenAI = wrappers.wrap_openai(AsyncOpenAI())
|
):
|
||||||
# else:
|
self.aclient = instructor.from_litellm(litellm.acompletion)
|
||||||
from openai import AsyncOpenAI, OpenAI
|
self.client = instructor.from_litellm(litellm.completion)
|
||||||
|
self.transcription_model = transcription_model
|
||||||
self.aclient = instructor.from_openai(AsyncOpenAI(api_key = api_key))
|
|
||||||
self.client = instructor.from_openai(OpenAI(api_key = api_key))
|
|
||||||
self.base_openai_client = OpenAI(api_key = api_key)
|
|
||||||
self.transcription_model = "whisper-1"
|
|
||||||
self.model = model
|
self.model = model
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
|
self.endpoint = endpoint
|
||||||
|
self.api_version = api_version
|
||||||
self.streaming = streaming
|
self.streaming = streaming
|
||||||
@retry(stop = stop_after_attempt(5))
|
|
||||||
def completions_with_backoff(self, **kwargs):
|
|
||||||
"""Wrapper around ChatCompletion.create w/ backoff"""
|
|
||||||
return openai.chat.completions.create(**kwargs)
|
|
||||||
|
|
||||||
@retry(stop = stop_after_attempt(5))
|
|
||||||
async def acompletions_with_backoff(self,**kwargs):
|
|
||||||
"""Wrapper around ChatCompletion.acreate w/ backoff"""
|
|
||||||
return await openai.chat.completions.acreate(**kwargs)
|
|
||||||
|
|
||||||
@retry(stop = stop_after_attempt(5))
|
|
||||||
async def acreate_embedding_with_backoff(self, input: List[str], model: str = "text-embedding-3-large"):
|
|
||||||
"""Wrapper around Embedding.acreate w/ backoff"""
|
|
||||||
|
|
||||||
return await self.aclient.embeddings.create(input = input, model = model)
|
|
||||||
|
|
||||||
async def async_get_embedding_with_backoff(self, text, model="text-embedding-3-large"):
|
|
||||||
"""To get text embeddings, import/call this function
|
|
||||||
It specifies defaults + handles rate-limiting + is async"""
|
|
||||||
text = text.replace("\n", " ")
|
|
||||||
response = await self.aclient.embeddings.create(input = text, model = model)
|
|
||||||
embedding = response.data[0].embedding
|
|
||||||
return embedding
|
|
||||||
|
|
||||||
@retry(stop = stop_after_attempt(5))
|
|
||||||
def create_embedding_with_backoff(self, **kwargs):
|
|
||||||
"""Wrapper around Embedding.create w/ backoff"""
|
|
||||||
return openai.embeddings.create(**kwargs)
|
|
||||||
|
|
||||||
def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-3-large"):
|
|
||||||
"""To get text embeddings, import/call this function
|
|
||||||
It specifies defaults + handles rate-limiting
|
|
||||||
:param text: str
|
|
||||||
:param model: str
|
|
||||||
"""
|
|
||||||
text = text.replace("\n", " ")
|
|
||||||
response = self.create_embedding_with_backoff(input=[text], model=model)
|
|
||||||
embedding = response.data[0].embedding
|
|
||||||
return embedding
|
|
||||||
|
|
||||||
async def async_get_batch_embeddings_with_backoff(self, texts: List[str], models: List[str]):
|
|
||||||
"""To get multiple text embeddings in parallel, import/call this function
|
|
||||||
It specifies defaults + handles rate-limiting + is async"""
|
|
||||||
# Collect all coroutines
|
|
||||||
coroutines = (self.async_get_embedding_with_backoff(text, model)
|
|
||||||
for text, model in zip(texts, models))
|
|
||||||
|
|
||||||
# Run the coroutines in parallel and gather the results
|
|
||||||
embeddings = await asyncio.gather(*coroutines)
|
|
||||||
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
@retry(stop = stop_after_attempt(5))
|
|
||||||
async def acreate_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
|
async def acreate_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
|
||||||
"""Generate a response from a user query."""
|
"""Generate a response from a user query."""
|
||||||
|
|
||||||
return await self.aclient.chat.completions.create(
|
return await self.aclient.chat.completions.create(
|
||||||
model = self.model,
|
model = self.model,
|
||||||
messages = [
|
messages = [{
|
||||||
{
|
"role": "user",
|
||||||
"role": "user",
|
"content": f"""Use the given format to
|
||||||
"content": f"""Use the given format to
|
extract information from the following input: {text_input}. """,
|
||||||
extract information from the following input: {text_input}. """,
|
}, {
|
||||||
},
|
"role": "system",
|
||||||
{"role": "system", "content": system_prompt},
|
"content": system_prompt,
|
||||||
],
|
}],
|
||||||
|
api_key = self.api_key,
|
||||||
|
api_base = self.endpoint,
|
||||||
|
api_version = self.api_version,
|
||||||
response_model = response_model,
|
response_model = response_model,
|
||||||
|
max_retries = 5,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@retry(stop = stop_after_attempt(5))
|
|
||||||
def create_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
|
def create_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
|
||||||
"""Generate a response from a user query."""
|
"""Generate a response from a user query."""
|
||||||
|
|
||||||
return self.client.chat.completions.create(
|
return self.client.chat.completions.create(
|
||||||
model = self.model,
|
model = self.model,
|
||||||
messages = [
|
messages = [{
|
||||||
{
|
"role": "user",
|
||||||
"role": "user",
|
"content": f"""Use the given format to
|
||||||
"content": f"""Use the given format to
|
extract information from the following input: {text_input}. """,
|
||||||
extract information from the following input: {text_input}. """,
|
}, {
|
||||||
},
|
"role": "system",
|
||||||
{"role": "system", "content": system_prompt},
|
"content": system_prompt,
|
||||||
],
|
}],
|
||||||
|
api_key = self.api_key,
|
||||||
|
api_base = self.endpoint,
|
||||||
|
api_version = self.api_version,
|
||||||
response_model = response_model,
|
response_model = response_model,
|
||||||
|
max_retries = 5,
|
||||||
)
|
)
|
||||||
|
|
||||||
@retry(stop = stop_after_attempt(5))
|
|
||||||
def create_transcript(self, input):
|
def create_transcript(self, input):
|
||||||
"""Generate a audio transcript from a user query."""
|
"""Generate a audio transcript from a user query."""
|
||||||
|
|
||||||
if not os.path.isfile(input):
|
if not os.path.isfile(input):
|
||||||
raise FileNotFoundError(f"The file {input} does not exist.")
|
raise FileNotFoundError(f"The file {input} does not exist.")
|
||||||
|
|
||||||
with open(input, 'rb') as audio_file:
|
# with open(input, 'rb') as audio_file:
|
||||||
audio_data = audio_file.read()
|
# audio_data = audio_file.read()
|
||||||
|
|
||||||
|
transcription = litellm.transcription(
|
||||||
|
model = self.transcription_model,
|
||||||
transcription = self.base_openai_client.audio.transcriptions.create(
|
file = Path(input),
|
||||||
model=self.transcription_model ,
|
api_key=self.api_key,
|
||||||
file=Path(input),
|
api_base=self.endpoint,
|
||||||
)
|
api_version=self.api_version,
|
||||||
|
max_retries = 5,
|
||||||
|
)
|
||||||
|
|
||||||
return transcription
|
return transcription
|
||||||
|
|
||||||
|
|
||||||
@retry(stop = stop_after_attempt(5))
|
|
||||||
def transcribe_image(self, input) -> BaseModel:
|
def transcribe_image(self, input) -> BaseModel:
|
||||||
with open(input, "rb") as image_file:
|
with open(input, "rb") as image_file:
|
||||||
encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
|
encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
|
||||||
|
|
||||||
return self.base_openai_client.chat.completions.create(
|
return litellm.completion(
|
||||||
model=self.model,
|
model = self.model,
|
||||||
messages=[
|
messages = [{
|
||||||
{
|
"role": "user",
|
||||||
"role": "user",
|
"content": [
|
||||||
"content": [
|
{
|
||||||
{"type": "text", "text": "What’s in this image?"},
|
"type": "text",
|
||||||
{
|
"text": "What’s in this image?",
|
||||||
"type": "image_url",
|
}, {
|
||||||
"image_url": {
|
"type": "image_url",
|
||||||
"url": f"data:image/jpeg;base64,{encoded_image}",
|
"image_url": {
|
||||||
},
|
"url": f"data:image/jpeg;base64,{encoded_image}",
|
||||||
},
|
},
|
||||||
],
|
},
|
||||||
}
|
],
|
||||||
],
|
}],
|
||||||
max_tokens=300,
|
api_key=self.api_key,
|
||||||
|
api_base=self.endpoint,
|
||||||
|
api_version=self.api_version,
|
||||||
|
max_tokens = 300,
|
||||||
|
max_retries = 5,
|
||||||
)
|
)
|
||||||
|
|
||||||
def show_prompt(self, text_input: str, system_prompt: str) -> str:
|
def show_prompt(self, text_input: str, system_prompt: str) -> str:
|
||||||
"""Format and display the prompt for a user query."""
|
"""Format and display the prompt for a user query."""
|
||||||
if not text_input:
|
if not text_input:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
|
import json
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
|
from cognee.modules.settings import get_current_settings
|
||||||
from cognee.shared.utils import send_telemetry
|
from cognee.shared.utils import send_telemetry
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
from cognee.modules.users.methods import get_default_user
|
from cognee.modules.users.methods import get_default_user
|
||||||
|
|
@ -157,7 +159,7 @@ async def run_tasks_base(tasks: list[Task], data = None, user: User = None):
|
||||||
})
|
})
|
||||||
raise error
|
raise error
|
||||||
|
|
||||||
async def run_tasks(tasks: [Task], data = None, pipeline_name: str = "default_pipeline"):
|
async def run_tasks_with_telemetry(tasks: list[Task], data, pipeline_name: str):
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -185,3 +187,10 @@ async def run_tasks(tasks: [Task], data = None, pipeline_name: str = "default_pi
|
||||||
})
|
})
|
||||||
|
|
||||||
raise error
|
raise error
|
||||||
|
|
||||||
|
async def run_tasks(tasks: list[Task], data = None, pipeline_name: str = "default_pipeline"):
|
||||||
|
config = get_current_settings()
|
||||||
|
logger.debug("\nRunning pipeline with configuration:\n%s\n", json.dumps(config, indent = 1))
|
||||||
|
|
||||||
|
async for result in run_tasks_with_telemetry(tasks, data, pipeline_name):
|
||||||
|
yield result
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
from .get_current_settings import get_current_settings
|
||||||
from .get_settings import get_settings, SettingsDict
|
from .get_settings import get_settings, SettingsDict
|
||||||
from .save_llm_config import save_llm_config
|
from .save_llm_config import save_llm_config
|
||||||
from .save_vector_db_config import save_vector_db_config
|
from .save_vector_db_config import save_vector_db_config
|
||||||
|
|
|
||||||
54
cognee/modules/settings/get_current_settings.py
Normal file
54
cognee/modules/settings/get_current_settings.py
Normal file
|
|
@ -0,0 +1,54 @@
|
||||||
|
from typing import TypedDict
|
||||||
|
from cognee.infrastructure.llm import get_llm_config
|
||||||
|
from cognee.infrastructure.databases.graph import get_graph_config
|
||||||
|
from cognee.infrastructure.databases.vector import get_vectordb_config
|
||||||
|
from cognee.infrastructure.databases.relational.config import get_relational_config
|
||||||
|
|
||||||
|
class LLMConfig(TypedDict):
|
||||||
|
model: str
|
||||||
|
provider: str
|
||||||
|
|
||||||
|
class VectorDBConfig(TypedDict):
|
||||||
|
url: str
|
||||||
|
provider: str
|
||||||
|
|
||||||
|
class GraphDBConfig(TypedDict):
|
||||||
|
url: str
|
||||||
|
provider: str
|
||||||
|
|
||||||
|
class RelationalConfig(TypedDict):
|
||||||
|
url: str
|
||||||
|
provider: str
|
||||||
|
|
||||||
|
class SettingsDict(TypedDict):
|
||||||
|
llm: LLMConfig
|
||||||
|
graph: GraphDBConfig
|
||||||
|
vector: VectorDBConfig
|
||||||
|
relational: RelationalConfig
|
||||||
|
|
||||||
|
def get_current_settings() -> SettingsDict:
|
||||||
|
llm_config = get_llm_config()
|
||||||
|
graph_config = get_graph_config()
|
||||||
|
vector_config = get_vectordb_config()
|
||||||
|
relational_config = get_relational_config()
|
||||||
|
|
||||||
|
return dict(
|
||||||
|
llm = {
|
||||||
|
"provider": llm_config.llm_provider,
|
||||||
|
"model": llm_config.llm_model,
|
||||||
|
},
|
||||||
|
graph = {
|
||||||
|
"provider": graph_config.graph_database_provider,
|
||||||
|
"url": graph_config.graph_database_url or graph_config.graph_file_path,
|
||||||
|
},
|
||||||
|
vector = {
|
||||||
|
"provider": vector_config.vector_db_provider,
|
||||||
|
"url": vector_config.vector_db_url,
|
||||||
|
},
|
||||||
|
relational = {
|
||||||
|
"provider": relational_config.db_provider,
|
||||||
|
"url": f"{relational_config.db_host}:{relational_config.db_port}" \
|
||||||
|
if relational_config.db_host \
|
||||||
|
else f"{relational_config.db_path}/{relational_config.db_name}",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
@ -1,16 +1,51 @@
|
||||||
from cognee.modules.data.models import Data
|
from cognee.modules.data.models import Data
|
||||||
from cognee.modules.data.processing.document_types import Document, PdfDocument, AudioDocument, ImageDocument, TextDocument
|
from cognee.modules.data.processing.document_types import (
|
||||||
|
Document,
|
||||||
|
PdfDocument,
|
||||||
|
AudioDocument,
|
||||||
|
ImageDocument,
|
||||||
|
TextDocument,
|
||||||
|
)
|
||||||
|
|
||||||
EXTENSION_TO_DOCUMENT_CLASS = {
|
EXTENSION_TO_DOCUMENT_CLASS = {
|
||||||
"pdf": PdfDocument,
|
"pdf": PdfDocument, # Text documents
|
||||||
"audio": AudioDocument,
|
"txt": TextDocument,
|
||||||
"image": ImageDocument,
|
"png": ImageDocument, # Image documents
|
||||||
"txt": TextDocument
|
"dwg": ImageDocument,
|
||||||
|
"xcf": ImageDocument,
|
||||||
|
"jpg": ImageDocument,
|
||||||
|
"jpx": ImageDocument,
|
||||||
|
"apng": ImageDocument,
|
||||||
|
"gif": ImageDocument,
|
||||||
|
"webp": ImageDocument,
|
||||||
|
"cr2": ImageDocument,
|
||||||
|
"tif": ImageDocument,
|
||||||
|
"bmp": ImageDocument,
|
||||||
|
"jxr": ImageDocument,
|
||||||
|
"psd": ImageDocument,
|
||||||
|
"ico": ImageDocument,
|
||||||
|
"heic": ImageDocument,
|
||||||
|
"avif": ImageDocument,
|
||||||
|
"aac": AudioDocument, # Audio documents
|
||||||
|
"mid": AudioDocument,
|
||||||
|
"mp3": AudioDocument,
|
||||||
|
"m4a": AudioDocument,
|
||||||
|
"ogg": AudioDocument,
|
||||||
|
"flac": AudioDocument,
|
||||||
|
"wav": AudioDocument,
|
||||||
|
"amr": AudioDocument,
|
||||||
|
"aiff": AudioDocument,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def classify_documents(data_documents: list[Data]) -> list[Document]:
|
def classify_documents(data_documents: list[Data]) -> list[Document]:
|
||||||
documents = [
|
documents = [
|
||||||
EXTENSION_TO_DOCUMENT_CLASS[data_item.extension](id = data_item.id, title=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location, name=data_item.name)
|
EXTENSION_TO_DOCUMENT_CLASS[data_item.extension](
|
||||||
|
id=data_item.id,
|
||||||
|
title=f"{data_item.name}.{data_item.extension}",
|
||||||
|
raw_data_location=data_item.raw_data_location,
|
||||||
|
name=data_item.name,
|
||||||
|
)
|
||||||
for data_item in data_documents
|
for data_item in data_documents
|
||||||
]
|
]
|
||||||
return documents
|
return documents
|
||||||
|
|
|
||||||
|
|
@ -57,6 +57,24 @@ async def main():
|
||||||
|
|
||||||
assert len(history) == 6, "Search history is not correct."
|
assert len(history) == 6, "Search history is not correct."
|
||||||
|
|
||||||
|
# Assert local data files are cleaned properly
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
||||||
|
|
||||||
|
# Assert relational, vector and graph databases have been cleaned properly
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
|
||||||
|
connection = await vector_engine.get_connection()
|
||||||
|
collection_names = await connection.table_names()
|
||||||
|
assert len(collection_names) == 0, "LanceDB vector database is not empty"
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
|
assert not os.path.exists(get_relational_engine().db_path), "SQLite relational database is not empty"
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.graph import get_graph_config
|
||||||
|
graph_config = get_graph_config()
|
||||||
|
assert not os.path.exists(graph_config.graph_file_path), "Networkx graph database is not empty"
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import asyncio
|
import asyncio
|
||||||
asyncio.run(main(), debug=True)
|
asyncio.run(main(), debug=True)
|
||||||
|
|
|
||||||
|
|
@ -62,6 +62,15 @@ async def main():
|
||||||
|
|
||||||
assert len(history) == 6, "Search history is not correct."
|
assert len(history) == 6, "Search history is not correct."
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
||||||
|
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
graph_engine = await get_graph_engine()
|
||||||
|
nodes, edges = await graph_engine.get_graph_data()
|
||||||
|
assert len(nodes) == 0 and len(edges) == 0, "Neo4j graph database is not empty"
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import asyncio
|
import asyncio
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|
|
||||||
|
|
@ -87,9 +87,15 @@ async def main():
|
||||||
print(f"{result}\n")
|
print(f"{result}\n")
|
||||||
|
|
||||||
history = await cognee.get_search_history()
|
history = await cognee.get_search_history()
|
||||||
|
|
||||||
assert len(history) == 6, "Search history is not correct."
|
assert len(history) == 6, "Search history is not correct."
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
||||||
|
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
tables_in_database = await vector_engine.get_table_names()
|
||||||
|
assert len(tables_in_database) == 0, "PostgreSQL database is not empty"
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -59,9 +59,16 @@ async def main():
|
||||||
print(f"{result}\n")
|
print(f"{result}\n")
|
||||||
|
|
||||||
history = await cognee.get_search_history()
|
history = await cognee.get_search_history()
|
||||||
|
|
||||||
assert len(history) == 6, "Search history is not correct."
|
assert len(history) == 6, "Search history is not correct."
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
||||||
|
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
qdrant_client = get_vector_engine().get_qdrant_client()
|
||||||
|
collections_response = await qdrant_client.get_collections()
|
||||||
|
assert len(collections_response.collections) == 0, "QDrant vector database is not empty"
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import asyncio
|
import asyncio
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|
|
||||||
|
|
@ -57,9 +57,15 @@ async def main():
|
||||||
print(f"{result}\n")
|
print(f"{result}\n")
|
||||||
|
|
||||||
history = await cognee.get_search_history()
|
history = await cognee.get_search_history()
|
||||||
|
|
||||||
assert len(history) == 6, "Search history is not correct."
|
assert len(history) == 6, "Search history is not correct."
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
||||||
|
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
collections = get_vector_engine().client.collections.list_all()
|
||||||
|
assert len(collections) == 0, "Weaviate vector database is not empty"
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import asyncio
|
import asyncio
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|
|
||||||
48
examples/python/multimedia_example.py
Normal file
48
examples/python/multimedia_example.py
Normal file
|
|
@ -0,0 +1,48 @@
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
import cognee
|
||||||
|
from cognee.api.v1.search import SearchType
|
||||||
|
|
||||||
|
# Prerequisites:
|
||||||
|
# 1. Copy `.env.template` and rename it to `.env`.
|
||||||
|
# 2. Add your OpenAI API key to the `.env` file in the `LLM_API_KEY` field:
|
||||||
|
# LLM_API_KEY = "your_key_here"
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
# Create a clean slate for cognee -- reset data and system state
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
|
||||||
|
# cognee knowledge graph will be created based on the text
|
||||||
|
# and description of these files
|
||||||
|
mp3_file_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent.parent.parent,
|
||||||
|
".data/multimedia/text_to_speech.mp3",
|
||||||
|
)
|
||||||
|
png_file_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent.parent.parent,
|
||||||
|
".data/multimedia/example.png",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add the files, and make it available for cognify
|
||||||
|
await cognee.add([mp3_file_path, png_file_path])
|
||||||
|
|
||||||
|
# Use LLMs and cognee to create knowledge graph
|
||||||
|
await cognee.cognify()
|
||||||
|
|
||||||
|
# Query cognee for summaries of the data in the multimedia files
|
||||||
|
search_results = await cognee.search(
|
||||||
|
SearchType.SUMMARIES,
|
||||||
|
query_text="What is in the multimedia files?",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Display search results
|
||||||
|
for result_text in search_results:
|
||||||
|
print(result_text)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
import cognee
|
import cognee
|
||||||
from cognee.api.v1.search import SearchType
|
from cognee.api.v1.search import SearchType
|
||||||
|
|
||||||
|
|
@ -11,29 +10,57 @@ from cognee.api.v1.search import SearchType
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
# Create a clean slate for cognee -- reset data and system state
|
# Create a clean slate for cognee -- reset data and system state
|
||||||
|
print("Resetting cognee data...")
|
||||||
await cognee.prune.prune_data()
|
await cognee.prune.prune_data()
|
||||||
await cognee.prune.prune_system(metadata=True)
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
print("Data reset complete.\n")
|
||||||
|
|
||||||
# cognee knowledge graph will be created based on this text
|
# cognee knowledge graph will be created based on this text
|
||||||
text = """
|
text = """
|
||||||
Natural language processing (NLP) is an interdisciplinary
|
Natural language processing (NLP) is an interdisciplinary
|
||||||
subfield of computer science and information retrieval.
|
subfield of computer science and information retrieval.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
print("Adding text to cognee:")
|
||||||
|
print(text.strip())
|
||||||
# Add the text, and make it available for cognify
|
# Add the text, and make it available for cognify
|
||||||
await cognee.add(text)
|
await cognee.add(text)
|
||||||
|
print("Text added successfully.\n")
|
||||||
|
|
||||||
|
|
||||||
|
print("Running cognify to create knowledge graph...\n")
|
||||||
|
print("Cognify process steps:")
|
||||||
|
print("1. Classifying the document: Determining the type and category of the input text.")
|
||||||
|
print("2. Checking permissions: Ensuring the user has the necessary rights to process the text.")
|
||||||
|
print("3. Extracting text chunks: Breaking down the text into sentences or phrases for analysis.")
|
||||||
|
print("4. Adding data points: Storing the extracted chunks for processing.")
|
||||||
|
print("5. Generating knowledge graph: Extracting entities and relationships to form a knowledge graph.")
|
||||||
|
print("6. Summarizing text: Creating concise summaries of the content for quick insights.\n")
|
||||||
|
|
||||||
# Use LLMs and cognee to create knowledge graph
|
# Use LLMs and cognee to create knowledge graph
|
||||||
await cognee.cognify()
|
await cognee.cognify()
|
||||||
|
print("Cognify process complete.\n")
|
||||||
|
|
||||||
|
|
||||||
|
query_text = 'Tell me about NLP'
|
||||||
|
print(f"Searching cognee for insights with query: '{query_text}'")
|
||||||
# Query cognee for insights on the added text
|
# Query cognee for insights on the added text
|
||||||
search_results = await cognee.search(
|
search_results = await cognee.search(
|
||||||
SearchType.INSIGHTS, query_text='Tell me about NLP'
|
SearchType.INSIGHTS, query_text=query_text
|
||||||
)
|
)
|
||||||
|
|
||||||
# Display search results
|
print("Search results:")
|
||||||
|
# Display results
|
||||||
for result_text in search_results:
|
for result_text in search_results:
|
||||||
print(result_text)
|
print(result_text)
|
||||||
|
|
||||||
|
# Example output:
|
||||||
|
# ({'id': UUID('bc338a39-64d6-549a-acec-da60846dd90d'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 1, 211808, tzinfo=datetime.timezone.utc), 'name': 'natural language processing', 'description': 'An interdisciplinary subfield of computer science and information retrieval.'}, {'relationship_name': 'is_a_subfield_of', 'source_node_id': UUID('bc338a39-64d6-549a-acec-da60846dd90d'), 'target_node_id': UUID('6218dbab-eb6a-5759-a864-b3419755ffe0'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 15, 473137, tzinfo=datetime.timezone.utc)}, {'id': UUID('6218dbab-eb6a-5759-a864-b3419755ffe0'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 1, 211808, tzinfo=datetime.timezone.utc), 'name': 'computer science', 'description': 'The study of computation and information processing.'})
|
||||||
|
# (...)
|
||||||
|
# It represents nodes and relationships in the knowledge graph:
|
||||||
|
# - The first element is the source node (e.g., 'natural language processing').
|
||||||
|
# - The second element is the relationship between nodes (e.g., 'is_a_subfield_of').
|
||||||
|
# - The third element is the target node (e.g., 'computer science').
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|
|
||||||
|
|
@ -265,7 +265,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 1,
|
||||||
"id": "df16431d0f48b006",
|
"id": "df16431d0f48b006",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"ExecuteTime": {
|
"ExecuteTime": {
|
||||||
|
|
@ -304,7 +304,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 2,
|
||||||
"id": "9086abf3af077ab4",
|
"id": "9086abf3af077ab4",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"ExecuteTime": {
|
"ExecuteTime": {
|
||||||
|
|
@ -349,7 +349,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 3,
|
||||||
"id": "a9de0cc07f798b7f",
|
"id": "a9de0cc07f798b7f",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"ExecuteTime": {
|
"ExecuteTime": {
|
||||||
|
|
@ -393,7 +393,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 4,
|
||||||
"id": "185ff1c102d06111",
|
"id": "185ff1c102d06111",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"ExecuteTime": {
|
"ExecuteTime": {
|
||||||
|
|
@ -437,7 +437,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 5,
|
||||||
"id": "d55ce4c58f8efb67",
|
"id": "d55ce4c58f8efb67",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"ExecuteTime": {
|
"ExecuteTime": {
|
||||||
|
|
@ -479,7 +479,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 6,
|
||||||
"id": "ca4ecc32721ad332",
|
"id": "ca4ecc32721ad332",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"ExecuteTime": {
|
"ExecuteTime": {
|
||||||
|
|
@ -529,14 +529,14 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 7,
|
||||||
"id": "bce39dc6",
|
"id": "bce39dc6",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import os\n",
|
"import os\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# # Setting environment variables\n",
|
"# Setting environment variables\n",
|
||||||
"if \"GRAPHISTRY_USERNAME\" not in os.environ: \n",
|
"if \"GRAPHISTRY_USERNAME\" not in os.environ: \n",
|
||||||
" os.environ[\"GRAPHISTRY_USERNAME\"] = \"\"\n",
|
" os.environ[\"GRAPHISTRY_USERNAME\"] = \"\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|
@ -546,24 +546,26 @@
|
||||||
"if \"LLM_API_KEY\" not in os.environ:\n",
|
"if \"LLM_API_KEY\" not in os.environ:\n",
|
||||||
" os.environ[\"LLM_API_KEY\"] = \"\"\n",
|
" os.environ[\"LLM_API_KEY\"] = \"\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"os.environ[\"GRAPH_DATABASE_PROVIDER\"]=\"networkx\" # \"neo4j\" or \"networkx\"\n",
|
"# \"neo4j\" or \"networkx\"\n",
|
||||||
|
"os.environ[\"GRAPH_DATABASE_PROVIDER\"]=\"networkx\" \n",
|
||||||
"# Not needed if using networkx\n",
|
"# Not needed if using networkx\n",
|
||||||
"#GRAPH_DATABASE_URL=\"\"\n",
|
"#os.environ[\"GRAPH_DATABASE_URL\"]=\"\"\n",
|
||||||
"#GRAPH_DATABASE_USERNAME=\"\"\n",
|
"#os.environ[\"GRAPH_DATABASE_USERNAME\"]=\"\"\n",
|
||||||
"#GRAPH_DATABASE_PASSWORD=\"\"\n",
|
"#os.environ[\"GRAPH_DATABASE_PASSWORD\"]=\"\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"os.environ[\"VECTOR_DB_PROVIDER\"]=\"lancedb\" # \"qdrant\", \"weaviate\" or \"lancedb\"\n",
|
"# \"pgvector\", \"qdrant\", \"weaviate\" or \"lancedb\"\n",
|
||||||
"# Not needed if using \"lancedb\"\n",
|
"os.environ[\"VECTOR_DB_PROVIDER\"]=\"lancedb\" \n",
|
||||||
|
"# Not needed if using \"lancedb\" or \"pgvector\"\n",
|
||||||
"# os.environ[\"VECTOR_DB_URL\"]=\"\"\n",
|
"# os.environ[\"VECTOR_DB_URL\"]=\"\"\n",
|
||||||
"# os.environ[\"VECTOR_DB_KEY\"]=\"\"\n",
|
"# os.environ[\"VECTOR_DB_KEY\"]=\"\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Database provider\n",
|
"# Relational Database provider \"sqlite\" or \"postgres\"\n",
|
||||||
"os.environ[\"DB_PROVIDER\"]=\"sqlite\" # or \"postgres\"\n",
|
"os.environ[\"DB_PROVIDER\"]=\"sqlite\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Database name\n",
|
"# Database name\n",
|
||||||
"os.environ[\"DB_NAME\"]=\"cognee_db\"\n",
|
"os.environ[\"DB_NAME\"]=\"cognee_db\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Postgres specific parameters (Only if Postgres is run)\n",
|
"# Postgres specific parameters (Only if Postgres or PGVector is used)\n",
|
||||||
"# os.environ[\"DB_HOST\"]=\"127.0.0.1\"\n",
|
"# os.environ[\"DB_HOST\"]=\"127.0.0.1\"\n",
|
||||||
"# os.environ[\"DB_PORT\"]=\"5432\"\n",
|
"# os.environ[\"DB_PORT\"]=\"5432\"\n",
|
||||||
"# os.environ[\"DB_USERNAME\"]=\"cognee\"\n",
|
"# os.environ[\"DB_USERNAME\"]=\"cognee\"\n",
|
||||||
|
|
@ -620,7 +622,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 10,
|
||||||
"id": "7c431fdef4921ae0",
|
"id": "7c431fdef4921ae0",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"ExecuteTime": {
|
"ExecuteTime": {
|
||||||
|
|
|
||||||
|
|
@ -52,7 +52,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 9,
|
"execution_count": 3,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
|
@ -71,7 +71,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 10,
|
"execution_count": 4,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
|
@ -90,23 +90,23 @@
|
||||||
"# \"neo4j\" or \"networkx\"\n",
|
"# \"neo4j\" or \"networkx\"\n",
|
||||||
"os.environ[\"GRAPH_DATABASE_PROVIDER\"]=\"networkx\" \n",
|
"os.environ[\"GRAPH_DATABASE_PROVIDER\"]=\"networkx\" \n",
|
||||||
"# Not needed if using networkx\n",
|
"# Not needed if using networkx\n",
|
||||||
"#GRAPH_DATABASE_URL=\"\"\n",
|
"#os.environ[\"GRAPH_DATABASE_URL\"]=\"\"\n",
|
||||||
"#GRAPH_DATABASE_USERNAME=\"\"\n",
|
"#os.environ[\"GRAPH_DATABASE_USERNAME\"]=\"\"\n",
|
||||||
"#GRAPH_DATABASE_PASSWORD=\"\"\n",
|
"#os.environ[\"GRAPH_DATABASE_PASSWORD\"]=\"\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# \"qdrant\", \"weaviate\" or \"lancedb\"\n",
|
"# \"pgvector\", \"qdrant\", \"weaviate\" or \"lancedb\"\n",
|
||||||
"os.environ[\"VECTOR_DB_PROVIDER\"]=\"lancedb\" \n",
|
"os.environ[\"VECTOR_DB_PROVIDER\"]=\"lancedb\" \n",
|
||||||
"# Not needed if using \"lancedb\"\n",
|
"# Not needed if using \"lancedb\" or \"pgvector\"\n",
|
||||||
"# os.environ[\"VECTOR_DB_URL\"]=\"\"\n",
|
"# os.environ[\"VECTOR_DB_URL\"]=\"\"\n",
|
||||||
"# os.environ[\"VECTOR_DB_KEY\"]=\"\"\n",
|
"# os.environ[\"VECTOR_DB_KEY\"]=\"\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Database provider\n",
|
"# Relational Database provider \"sqlite\" or \"postgres\"\n",
|
||||||
"os.environ[\"DB_PROVIDER\"]=\"sqlite\" # or \"postgres\"\n",
|
"os.environ[\"DB_PROVIDER\"]=\"sqlite\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Database name\n",
|
"# Database name\n",
|
||||||
"os.environ[\"DB_NAME\"]=\"cognee_db\"\n",
|
"os.environ[\"DB_NAME\"]=\"cognee_db\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Postgres specific parameters (Only if Postgres is run)\n",
|
"# Postgres specific parameters (Only if Postgres or PGVector is used)\n",
|
||||||
"# os.environ[\"DB_HOST\"]=\"127.0.0.1\"\n",
|
"# os.environ[\"DB_HOST\"]=\"127.0.0.1\"\n",
|
||||||
"# os.environ[\"DB_PORT\"]=\"5432\"\n",
|
"# os.environ[\"DB_PORT\"]=\"5432\"\n",
|
||||||
"# os.environ[\"DB_USERNAME\"]=\"cognee\"\n",
|
"# os.environ[\"DB_USERNAME\"]=\"cognee\"\n",
|
||||||
|
|
@ -130,8 +130,6 @@
|
||||||
"\n",
|
"\n",
|
||||||
"from cognee.infrastructure.databases.vector.pgvector import create_db_and_tables as create_pgvector_db_and_tables\n",
|
"from cognee.infrastructure.databases.vector.pgvector import create_db_and_tables as create_pgvector_db_and_tables\n",
|
||||||
"from cognee.infrastructure.databases.relational import create_db_and_tables as create_relational_db_and_tables\n",
|
"from cognee.infrastructure.databases.relational import create_db_and_tables as create_relational_db_and_tables\n",
|
||||||
"from cognee.infrastructure.databases.graph import get_graph_engine\n",
|
|
||||||
"from cognee.shared.utils import render_graph\n",
|
|
||||||
"from cognee.modules.users.models import User\n",
|
"from cognee.modules.users.models import User\n",
|
||||||
"from cognee.modules.users.methods import get_default_user\n",
|
"from cognee.modules.users.methods import get_default_user\n",
|
||||||
"from cognee.tasks.ingestion.ingest_data_with_metadata import ingest_data_with_metadata\n",
|
"from cognee.tasks.ingestion.ingest_data_with_metadata import ingest_data_with_metadata\n",
|
||||||
|
|
@ -196,6 +194,9 @@
|
||||||
"source": [
|
"source": [
|
||||||
"import graphistry\n",
|
"import graphistry\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"from cognee.infrastructure.databases.graph import get_graph_engine\n",
|
||||||
|
"from cognee.shared.utils import render_graph\n",
|
||||||
|
"\n",
|
||||||
"# Get graph\n",
|
"# Get graph\n",
|
||||||
"graphistry.login(username=os.getenv(\"GRAPHISTRY_USERNAME\"), password=os.getenv(\"GRAPHISTRY_PASSWORD\"))\n",
|
"graphistry.login(username=os.getenv(\"GRAPHISTRY_USERNAME\"), password=os.getenv(\"GRAPHISTRY_PASSWORD\"))\n",
|
||||||
"graph_engine = await get_graph_engine()\n",
|
"graph_engine = await get_graph_engine()\n",
|
||||||
|
|
|
||||||
169
notebooks/cognee_multimedia_demo.ipynb
Normal file
169
notebooks/cognee_multimedia_demo.ipynb
Normal file
|
|
@ -0,0 +1,169 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Cognee GraphRAG with Multimedia files"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"vscode": {
|
||||||
|
"languageId": "plaintext"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"## Load Data\n",
|
||||||
|
"\n",
|
||||||
|
"We will use a few sample multimedia files which we have on GitHub for easy access."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 23,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"import pathlib\n",
|
||||||
|
"\n",
|
||||||
|
"# cognee knowledge graph will be created based on the text\n",
|
||||||
|
"# and description of these files\n",
|
||||||
|
"mp3_file_path = os.path.join(\n",
|
||||||
|
" os.path.abspath(''), \"../\",\n",
|
||||||
|
" \".data/multimedia/text_to_speech.mp3\",\n",
|
||||||
|
")\n",
|
||||||
|
"png_file_path = os.path.join(\n",
|
||||||
|
" os.path.abspath(''), \"../\",\n",
|
||||||
|
" \".data/multimedia/example.png\",\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Set environment variables"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 24,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"\n",
|
||||||
|
"# Setting environment variables\n",
|
||||||
|
"if \"GRAPHISTRY_USERNAME\" not in os.environ: \n",
|
||||||
|
" os.environ[\"GRAPHISTRY_USERNAME\"] = \"\"\n",
|
||||||
|
"\n",
|
||||||
|
"if \"GRAPHISTRY_PASSWORD\" not in os.environ: \n",
|
||||||
|
" os.environ[\"GRAPHISTRY_PASSWORD\"] = \"\"\n",
|
||||||
|
"\n",
|
||||||
|
"if \"LLM_API_KEY\" not in os.environ:\n",
|
||||||
|
" os.environ[\"LLM_API_KEY\"] = \"\"\n",
|
||||||
|
"\n",
|
||||||
|
"# \"neo4j\" or \"networkx\"\n",
|
||||||
|
"os.environ[\"GRAPH_DATABASE_PROVIDER\"]=\"networkx\" \n",
|
||||||
|
"# Not needed if using networkx\n",
|
||||||
|
"#os.environ[\"GRAPH_DATABASE_URL\"]=\"\"\n",
|
||||||
|
"#os.environ[\"GRAPH_DATABASE_USERNAME\"]=\"\"\n",
|
||||||
|
"#os.environ[\"GRAPH_DATABASE_PASSWORD\"]=\"\"\n",
|
||||||
|
"\n",
|
||||||
|
"# \"pgvector\", \"qdrant\", \"weaviate\" or \"lancedb\"\n",
|
||||||
|
"os.environ[\"VECTOR_DB_PROVIDER\"]=\"lancedb\" \n",
|
||||||
|
"# Not needed if using \"lancedb\" or \"pgvector\"\n",
|
||||||
|
"# os.environ[\"VECTOR_DB_URL\"]=\"\"\n",
|
||||||
|
"# os.environ[\"VECTOR_DB_KEY\"]=\"\"\n",
|
||||||
|
"\n",
|
||||||
|
"# Relational Database provider \"sqlite\" or \"postgres\"\n",
|
||||||
|
"os.environ[\"DB_PROVIDER\"]=\"sqlite\"\n",
|
||||||
|
"\n",
|
||||||
|
"# Database name\n",
|
||||||
|
"os.environ[\"DB_NAME\"]=\"cognee_db\"\n",
|
||||||
|
"\n",
|
||||||
|
"# Postgres specific parameters (Only if Postgres or PGVector is used)\n",
|
||||||
|
"# os.environ[\"DB_HOST\"]=\"127.0.0.1\"\n",
|
||||||
|
"# os.environ[\"DB_PORT\"]=\"5432\"\n",
|
||||||
|
"# os.environ[\"DB_USERNAME\"]=\"cognee\"\n",
|
||||||
|
"# os.environ[\"DB_PASSWORD\"]=\"cognee\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Run Cognee with multimedia files"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import cognee\n",
|
||||||
|
"\n",
|
||||||
|
"# Create a clean slate for cognee -- reset data and system state\n",
|
||||||
|
"await cognee.prune.prune_data()\n",
|
||||||
|
"await cognee.prune.prune_system(metadata=True)\n",
|
||||||
|
"\n",
|
||||||
|
"# Add multimedia files and make them available for cognify\n",
|
||||||
|
"await cognee.add([mp3_file_path, png_file_path])\n",
|
||||||
|
"\n",
|
||||||
|
"# Create knowledge graph with cognee\n",
|
||||||
|
"await cognee.cognify()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Query Cognee for summaries related to multimedia files"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from cognee.api.v1.search import SearchType\n",
|
||||||
|
"\n",
|
||||||
|
"# Query cognee for summaries of the data in the multimedia files\n",
|
||||||
|
"search_results = await cognee.search(\n",
|
||||||
|
" SearchType.SUMMARIES,\n",
|
||||||
|
" query_text=\"What is in the multimedia files?\",\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"# Display search results\n",
|
||||||
|
"for result_text in search_results:\n",
|
||||||
|
" print(result_text)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": ".venv",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.9.6"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
||||||
Loading…
Add table
Reference in a new issue