Merge remote-tracking branch 'origin/main' into feature/cog-537-implement-retrieval-algorithm-from-research-paper

This commit is contained in:
hajdul88 2024-11-26 16:38:32 +01:00
commit 59f8ec665f
31 changed files with 729 additions and 231 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 10 KiB

Binary file not shown.

View file

@ -0,0 +1,63 @@
name: test | multimedia notebook
on:
workflow_dispatch:
pull_request:
branches:
- main
types: [labeled, synchronize]
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
env:
RUNTIME__LOG_LEVEL: ERROR
jobs:
get_docs_changes:
name: docs changes
uses: ./.github/workflows/get_docs_changes.yml
run_notebook_test:
name: test
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest
defaults:
run:
shell: bash
steps:
- name: Check out
uses: actions/checkout@master
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.11.x'
- name: Install Poetry
uses: snok/install-poetry@v1.3.2
with:
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true
- name: Install dependencies
run: |
poetry install --no-interaction
poetry add jupyter --no-interaction
- name: Execute Jupyter Notebook
env:
ENV: 'dev'
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
run: |
poetry run jupyter nbconvert \
--to notebook \
--execute notebooks/cognee_multimedia_demo.ipynb \
--output executed_notebook.ipynb \
--ExecutePreprocessor.timeout=1200

View file

@ -13,6 +13,7 @@ concurrency:
env: env:
RUNTIME__LOG_LEVEL: ERROR RUNTIME__LOG_LEVEL: ERROR
ENV: 'dev'
jobs: jobs:
get_docs_changes: get_docs_changes:

View file

@ -13,6 +13,7 @@ concurrency:
env: env:
RUNTIME__LOG_LEVEL: ERROR RUNTIME__LOG_LEVEL: ERROR
ENV: 'dev'
jobs: jobs:
get_docs_changes: get_docs_changes:

View file

@ -13,6 +13,7 @@ concurrency:
env: env:
RUNTIME__LOG_LEVEL: ERROR RUNTIME__LOG_LEVEL: ERROR
ENV: 'dev'
jobs: jobs:
get_docs_changes: get_docs_changes:

View file

@ -105,37 +105,65 @@ import asyncio
from cognee.api.v1.search import SearchType from cognee.api.v1.search import SearchType
async def main(): async def main():
# Reset cognee data # Create a clean slate for cognee -- reset data and system state
print("Resetting cognee data...")
await cognee.prune.prune_data() await cognee.prune.prune_data()
# Reset cognee system state
await cognee.prune.prune_system(metadata=True) await cognee.prune.prune_system(metadata=True)
print("Data reset complete.\n")
# cognee knowledge graph will be created based on this text
text = """ text = """
Natural language processing (NLP) is an interdisciplinary Natural language processing (NLP) is an interdisciplinary
subfield of computer science and information retrieval. subfield of computer science and information retrieval.
""" """
# Add text to cognee print("Adding text to cognee:")
print(text.strip())
# Add the text, and make it available for cognify
await cognee.add(text) await cognee.add(text)
print("Text added successfully.\n")
print("Running cognify to create knowledge graph...\n")
print("Cognify process steps:")
print("1. Classifying the document: Determining the type and category of the input text.")
print("2. Checking permissions: Ensuring the user has the necessary rights to process the text.")
print("3. Extracting text chunks: Breaking down the text into sentences or phrases for analysis.")
print("4. Adding data points: Storing the extracted chunks for processing.")
print("5. Generating knowledge graph: Extracting entities and relationships to form a knowledge graph.")
print("6. Summarizing text: Creating concise summaries of the content for quick insights.\n")
# Use LLMs and cognee to create knowledge graph # Use LLMs and cognee to create knowledge graph
await cognee.cognify() await cognee.cognify()
print("Cognify process complete.\n")
# Search cognee for insights
query_text = 'Tell me about NLP'
print(f"Searching cognee for insights with query: '{query_text}'")
# Query cognee for insights on the added text
search_results = await cognee.search( search_results = await cognee.search(
SearchType.INSIGHTS, SearchType.INSIGHTS, query_text=query_text
"Tell me about NLP",
) )
print("Search results:")
# Display results # Display results
for result_text in search_results: for result_text in search_results:
print(result_text) print(result_text)
# natural_language_processing is_a field
# natural_language_processing is_subfield_of computer_science
# natural_language_processing is_subfield_of information_retrieval
asyncio.run(main()) # Example output:
# ({'id': UUID('bc338a39-64d6-549a-acec-da60846dd90d'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 1, 211808, tzinfo=datetime.timezone.utc), 'name': 'natural language processing', 'description': 'An interdisciplinary subfield of computer science and information retrieval.'}, {'relationship_name': 'is_a_subfield_of', 'source_node_id': UUID('bc338a39-64d6-549a-acec-da60846dd90d'), 'target_node_id': UUID('6218dbab-eb6a-5759-a864-b3419755ffe0'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 15, 473137, tzinfo=datetime.timezone.utc)}, {'id': UUID('6218dbab-eb6a-5759-a864-b3419755ffe0'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 1, 211808, tzinfo=datetime.timezone.utc), 'name': 'computer science', 'description': 'The study of computation and information processing.'})
# (...)
#
# It represents nodes and relationships in the knowledge graph:
# - The first element is the source node (e.g., 'natural language processing').
# - The second element is the relationship between nodes (e.g., 'is_a_subfield_of').
# - The third element is the target node (e.g., 'computer science').
if __name__ == '__main__':
asyncio.run(main())
``` ```
When you run this script, you will see step-by-step messages in the console that help you trace the execution flow and understand what the script is doing at each stage.
A version of this example is here: `examples/python/simple_example.py` A version of this example is here: `examples/python/simple_example.py`
### Create your own memory store ### Create your own memory store

View file

@ -9,18 +9,21 @@ async def get_graph_engine() -> GraphDBInterface :
config = get_graph_config() config = get_graph_config()
if config.graph_database_provider == "neo4j": if config.graph_database_provider == "neo4j":
try: if not (config.graph_database_url and config.graph_database_username and config.graph_database_password):
from .neo4j_driver.adapter import Neo4jAdapter raise EnvironmentError("Missing required Neo4j credentials.")
from .neo4j_driver.adapter import Neo4jAdapter
return Neo4jAdapter( return Neo4jAdapter(
graph_database_url = config.graph_database_url, graph_database_url = config.graph_database_url,
graph_database_username = config.graph_database_username, graph_database_username = config.graph_database_username,
graph_database_password = config.graph_database_password graph_database_password = config.graph_database_password
) )
except:
pass
elif config.graph_database_provider == "falkordb": elif config.graph_database_provider == "falkordb":
if not (config.graph_database_url and config.graph_database_username and config.graph_database_password):
raise EnvironmentError("Missing required FalkorDB credentials.")
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
from cognee.infrastructure.databases.hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter from cognee.infrastructure.databases.hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter

View file

@ -130,6 +130,29 @@ class SQLAlchemyAdapter():
return metadata.tables[full_table_name] return metadata.tables[full_table_name]
raise ValueError(f"Table '{full_table_name}' not found.") raise ValueError(f"Table '{full_table_name}' not found.")
async def get_table_names(self) -> List[str]:
"""
Return a list of all tables names in database
"""
table_names = []
async with self.engine.begin() as connection:
if self.engine.dialect.name == "sqlite":
await connection.run_sync(Base.metadata.reflect)
for table in Base.metadata.tables:
table_names.append(str(table))
else:
schema_list = await self.get_schema_list()
# Create a MetaData instance to load table information
metadata = MetaData()
# Drop all tables from all schemas
for schema_name in schema_list:
# Load the schema information into the MetaData object
await connection.run_sync(metadata.reflect, schema=schema_name)
for table in metadata.sorted_tables:
table_names.append(str(table))
metadata.clear()
return table_names
async def get_data(self, table_name: str, filters: dict = None): async def get_data(self, table_name: str, filters: dict = None):
async with self.engine.begin() as connection: async with self.engine.begin() as connection:

View file

@ -10,26 +10,29 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
if config["vector_db_provider"] == "weaviate": if config["vector_db_provider"] == "weaviate":
from .weaviate_db import WeaviateAdapter from .weaviate_db import WeaviateAdapter
if config["vector_db_url"] is None and config["vector_db_key"] is None: if not (config["vector_db_url"] and config["vector_db_key"]):
raise EnvironmentError("Weaviate is not configured!") raise EnvironmentError("Missing requred Weaviate credentials!")
return WeaviateAdapter( return WeaviateAdapter(
config["vector_db_url"], config["vector_db_url"],
config["vector_db_key"], config["vector_db_key"],
embedding_engine = embedding_engine embedding_engine = embedding_engine
) )
elif config["vector_db_provider"] == "qdrant":
if config["vector_db_url"] and config["vector_db_key"]:
from .qdrant.QDrantAdapter import QDrantAdapter
return QDrantAdapter( elif config["vector_db_provider"] == "qdrant":
url = config["vector_db_url"], if not (config["vector_db_url"] and config["vector_db_key"]):
api_key = config["vector_db_key"], raise EnvironmentError("Missing requred Qdrant credentials!")
embedding_engine = embedding_engine
) from .qdrant.QDrantAdapter import QDrantAdapter
return QDrantAdapter(
url = config["vector_db_url"],
api_key = config["vector_db_key"],
embedding_engine = embedding_engine
)
elif config["vector_db_provider"] == "pgvector": elif config["vector_db_provider"] == "pgvector":
from cognee.infrastructure.databases.relational import get_relational_config from cognee.infrastructure.databases.relational import get_relational_config
from .pgvector.PGVectorAdapter import PGVectorAdapter
# Get configuration for postgres database # Get configuration for postgres database
relational_config = get_relational_config() relational_config = get_relational_config()
@ -39,16 +42,25 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
db_port = relational_config.db_port db_port = relational_config.db_port
db_name = relational_config.db_name db_name = relational_config.db_name
if not (db_host and db_port and db_name and db_username and db_password):
raise EnvironmentError("Missing requred pgvector credentials!")
connection_string: str = ( connection_string: str = (
f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}" f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
) )
from .pgvector.PGVectorAdapter import PGVectorAdapter
return PGVectorAdapter( return PGVectorAdapter(
connection_string, connection_string,
config["vector_db_key"], config["vector_db_key"],
embedding_engine, embedding_engine,
) )
elif config["vector_db_provider"] == "falkordb": elif config["vector_db_provider"] == "falkordb":
if not (config["vector_db_url"] and config["vector_db_key"]):
raise EnvironmentError("Missing requred FalkorDB credentials!")
from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
return FalkorDBAdapter( return FalkorDBAdapter(
@ -56,6 +68,7 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
database_port = config["vector_db_port"], database_port = config["vector_db_port"],
embedding_engine = embedding_engine, embedding_engine = embedding_engine,
) )
else: else:
from .lancedb.LanceDBAdapter import LanceDBAdapter from .lancedb.LanceDBAdapter import LanceDBAdapter
@ -64,5 +77,3 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
api_key = config["vector_db_key"], api_key = config["vector_db_key"],
embedding_engine = embedding_engine, embedding_engine = embedding_engine,
) )
raise EnvironmentError(f"Vector provider not configured correctly: {config['vector_db_provider']}")

View file

@ -1,32 +1,39 @@
import asyncio import asyncio
from typing import List, Optional from typing import List, Optional
import litellm import litellm
from litellm import aembedding
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
litellm.set_verbose = False litellm.set_verbose = False
class LiteLLMEmbeddingEngine(EmbeddingEngine): class LiteLLMEmbeddingEngine(EmbeddingEngine):
api_key: str api_key: str
embedding_model: str endpoint: str
embedding_dimensions: int api_version: str
model: str
dimensions: int
def __init__( def __init__(
self, self,
embedding_model: Optional[str] = "text-embedding-3-large", model: Optional[str] = "text-embedding-3-large",
embedding_dimensions: Optional[int] = 3072, dimensions: Optional[int] = 3072,
api_key: str = None, api_key: str = None,
endpoint: str = None,
api_version: str = None,
): ):
self.api_key = api_key self.api_key = api_key
self.embedding_model = embedding_model self.endpoint = endpoint
self.embedding_dimensions = embedding_dimensions self.api_version = api_version
self.model = model
self.dimensions = dimensions
async def embed_text(self, text: List[str]) -> List[List[float]]: async def embed_text(self, text: List[str]) -> List[List[float]]:
async def get_embedding(text_): async def get_embedding(text_):
response = await aembedding( response = await litellm.aembedding(
self.embedding_model, self.model,
input = text_, input = text_,
api_key = self.api_key api_key = self.api_key,
api_base = self.endpoint,
api_version = self.api_version
) )
return response.data[0]["embedding"] return response.data[0]["embedding"]
@ -36,4 +43,4 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
return result return result
def get_vector_size(self) -> int: def get_vector_size(self) -> int:
return self.embedding_dimensions return self.dimensions

View file

@ -1,23 +1,16 @@
from typing import Optional
from functools import lru_cache from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
class EmbeddingConfig(BaseSettings): class EmbeddingConfig(BaseSettings):
openai_embedding_model: str = "text-embedding-3-large" embedding_model: Optional[str] = "text-embedding-3-large"
openai_embedding_dimensions: int = 3072 embedding_dimensions: Optional[int] = 3072
litellm_embedding_model: str = "BAAI/bge-large-en-v1.5" embedding_endpoint: Optional[str] = None
litellm_embedding_dimensions: int = 1024 embedding_api_key: Optional[str] = None
# embedding_engine:object = DefaultEmbeddingEngine(embedding_model=litellm_embedding_model, embedding_dimensions=litellm_embedding_dimensions) embedding_api_version: Optional[str] = None
model_config = SettingsConfigDict(env_file = ".env", extra = "allow") model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
def to_dict(self) -> dict:
return {
"openai_embedding_model": self.openai_embedding_model,
"openai_embedding_dimensions": self.openai_embedding_dimensions,
"litellm_embedding_model": self.litellm_embedding_model,
"litellm_embedding_dimensions": self.litellm_embedding_dimensions,
}
@lru_cache @lru_cache
def get_embedding_config(): def get_embedding_config():
return EmbeddingConfig() return EmbeddingConfig()

View file

@ -1,7 +1,17 @@
from cognee.infrastructure.llm import get_llm_config from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config
from cognee.infrastructure.llm.config import get_llm_config
from .EmbeddingEngine import EmbeddingEngine from .EmbeddingEngine import EmbeddingEngine
from .LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine from .LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine
def get_embedding_engine() -> EmbeddingEngine: def get_embedding_engine() -> EmbeddingEngine:
config = get_embedding_config()
llm_config = get_llm_config() llm_config = get_llm_config()
return LiteLLMEmbeddingEngine(api_key = llm_config.llm_api_key)
return LiteLLMEmbeddingEngine(
# If OpenAI API is used for embeddings, litellm needs only the api_key.
api_key = config.embedding_api_key or llm_config.llm_api_key,
endpoint = config.embedding_endpoint,
api_version = config.embedding_api_version,
model = config.embedding_model,
dimensions = config.embedding_dimensions,
)

View file

@ -10,5 +10,3 @@ async def create_db_and_tables():
await vector_engine.create_database() await vector_engine.create_database()
async with vector_engine.engine.begin() as connection: async with vector_engine.engine.begin() as connection:
await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;")) await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))

View file

@ -7,6 +7,7 @@ class LLMConfig(BaseSettings):
llm_model: str = "gpt-4o-mini" llm_model: str = "gpt-4o-mini"
llm_endpoint: str = "" llm_endpoint: str = ""
llm_api_key: Optional[str] = None llm_api_key: Optional[str] = None
llm_api_version: Optional[str] = None
llm_temperature: float = 0.0 llm_temperature: float = 0.0
llm_streaming: bool = False llm_streaming: bool = False
transcription_model: str = "whisper-1" transcription_model: str = "whisper-1"
@ -19,6 +20,7 @@ class LLMConfig(BaseSettings):
"model": self.llm_model, "model": self.llm_model,
"endpoint": self.llm_endpoint, "endpoint": self.llm_endpoint,
"api_key": self.llm_api_key, "api_key": self.llm_api_key,
"api_version": self.llm_api_version,
"temperature": self.llm_temperature, "temperature": self.llm_temperature,
"streaming": self.llm_streaming, "streaming": self.llm_streaming,
"transcription_model": self.transcription_model "transcription_model": self.transcription_model

View file

@ -20,21 +20,33 @@ def get_llm_client():
raise ValueError("LLM API key is not set.") raise ValueError("LLM API key is not set.")
from .openai.adapter import OpenAIAdapter from .openai.adapter import OpenAIAdapter
return OpenAIAdapter(api_key=llm_config.llm_api_key, model=llm_config.llm_model, transcription_model=llm_config.transcription_model, streaming=llm_config.llm_streaming)
return OpenAIAdapter(
api_key = llm_config.llm_api_key,
endpoint = llm_config.llm_endpoint,
api_version = llm_config.llm_api_version,
model = llm_config.llm_model,
transcription_model = llm_config.transcription_model,
streaming = llm_config.llm_streaming,
)
elif provider == LLMProvider.OLLAMA: elif provider == LLMProvider.OLLAMA:
if llm_config.llm_api_key is None: if llm_config.llm_api_key is None:
raise ValueError("LLM API key is not set.") raise ValueError("LLM API key is not set.")
from .generic_llm_api.adapter import GenericAPIAdapter from .generic_llm_api.adapter import GenericAPIAdapter
return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama") return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama")
elif provider == LLMProvider.ANTHROPIC: elif provider == LLMProvider.ANTHROPIC:
from .anthropic.adapter import AnthropicAdapter from .anthropic.adapter import AnthropicAdapter
return AnthropicAdapter(llm_config.llm_model) return AnthropicAdapter(llm_config.llm_model)
elif provider == LLMProvider.CUSTOM: elif provider == LLMProvider.CUSTOM:
if llm_config.llm_api_key is None: if llm_config.llm_api_key is None:
raise ValueError("LLM API key is not set.") raise ValueError("LLM API key is not set.")
from .generic_llm_api.adapter import GenericAPIAdapter from .generic_llm_api.adapter import GenericAPIAdapter
return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom") return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom")
else: else:
raise ValueError(f"Unsupported LLM provider: {provider}") raise ValueError(f"Unsupported LLM provider: {provider}")

View file

@ -1,174 +1,127 @@
import asyncio
import base64
import os import os
import base64
from pathlib import Path from pathlib import Path
from typing import List, Type from typing import Type
import openai import litellm
import instructor import instructor
from pydantic import BaseModel from pydantic import BaseModel
from tenacity import retry, stop_after_attempt
from cognee.base_config import get_base_config
from cognee.infrastructure.llm.llm_interface import LLMInterface from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt from cognee.infrastructure.llm.prompts import read_query_prompt
# from cognee.shared.data_models import MonitoringTool
class OpenAIAdapter(LLMInterface): class OpenAIAdapter(LLMInterface):
name = "OpenAI" name = "OpenAI"
model: str model: str
api_key: str api_key: str
api_version: str
"""Adapter for OpenAI's GPT-3, GPT=4 API""" """Adapter for OpenAI's GPT-3, GPT=4 API"""
def __init__(self, api_key: str, model: str, transcription_model:str, streaming: bool = False): def __init__(
base_config = get_base_config() self,
api_key: str,
# if base_config.monitoring_tool == MonitoringTool.LANGFUSE: endpoint: str,
# from langfuse.openai import AsyncOpenAI, OpenAI api_version: str,
# elif base_config.monitoring_tool == MonitoringTool.LANGSMITH: model: str,
# from langsmith import wrappers transcription_model: str,
# from openai import AsyncOpenAI streaming: bool = False,
# AsyncOpenAI = wrappers.wrap_openai(AsyncOpenAI()) ):
# else: self.aclient = instructor.from_litellm(litellm.acompletion)
from openai import AsyncOpenAI, OpenAI self.client = instructor.from_litellm(litellm.completion)
self.transcription_model = transcription_model
self.aclient = instructor.from_openai(AsyncOpenAI(api_key = api_key))
self.client = instructor.from_openai(OpenAI(api_key = api_key))
self.base_openai_client = OpenAI(api_key = api_key)
self.transcription_model = "whisper-1"
self.model = model self.model = model
self.api_key = api_key self.api_key = api_key
self.endpoint = endpoint
self.api_version = api_version
self.streaming = streaming self.streaming = streaming
@retry(stop = stop_after_attempt(5))
def completions_with_backoff(self, **kwargs):
"""Wrapper around ChatCompletion.create w/ backoff"""
return openai.chat.completions.create(**kwargs)
@retry(stop = stop_after_attempt(5))
async def acompletions_with_backoff(self,**kwargs):
"""Wrapper around ChatCompletion.acreate w/ backoff"""
return await openai.chat.completions.acreate(**kwargs)
@retry(stop = stop_after_attempt(5))
async def acreate_embedding_with_backoff(self, input: List[str], model: str = "text-embedding-3-large"):
"""Wrapper around Embedding.acreate w/ backoff"""
return await self.aclient.embeddings.create(input = input, model = model)
async def async_get_embedding_with_backoff(self, text, model="text-embedding-3-large"):
"""To get text embeddings, import/call this function
It specifies defaults + handles rate-limiting + is async"""
text = text.replace("\n", " ")
response = await self.aclient.embeddings.create(input = text, model = model)
embedding = response.data[0].embedding
return embedding
@retry(stop = stop_after_attempt(5))
def create_embedding_with_backoff(self, **kwargs):
"""Wrapper around Embedding.create w/ backoff"""
return openai.embeddings.create(**kwargs)
def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-3-large"):
"""To get text embeddings, import/call this function
It specifies defaults + handles rate-limiting
:param text: str
:param model: str
"""
text = text.replace("\n", " ")
response = self.create_embedding_with_backoff(input=[text], model=model)
embedding = response.data[0].embedding
return embedding
async def async_get_batch_embeddings_with_backoff(self, texts: List[str], models: List[str]):
"""To get multiple text embeddings in parallel, import/call this function
It specifies defaults + handles rate-limiting + is async"""
# Collect all coroutines
coroutines = (self.async_get_embedding_with_backoff(text, model)
for text, model in zip(texts, models))
# Run the coroutines in parallel and gather the results
embeddings = await asyncio.gather(*coroutines)
return embeddings
@retry(stop = stop_after_attempt(5))
async def acreate_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel: async def acreate_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
"""Generate a response from a user query.""" """Generate a response from a user query."""
return await self.aclient.chat.completions.create( return await self.aclient.chat.completions.create(
model = self.model, model = self.model,
messages = [ messages = [{
{ "role": "user",
"role": "user", "content": f"""Use the given format to
"content": f"""Use the given format to extract information from the following input: {text_input}. """,
extract information from the following input: {text_input}. """, }, {
}, "role": "system",
{"role": "system", "content": system_prompt}, "content": system_prompt,
], }],
api_key = self.api_key,
api_base = self.endpoint,
api_version = self.api_version,
response_model = response_model, response_model = response_model,
max_retries = 5,
) )
@retry(stop = stop_after_attempt(5))
def create_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel: def create_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
"""Generate a response from a user query.""" """Generate a response from a user query."""
return self.client.chat.completions.create( return self.client.chat.completions.create(
model = self.model, model = self.model,
messages = [ messages = [{
{ "role": "user",
"role": "user", "content": f"""Use the given format to
"content": f"""Use the given format to extract information from the following input: {text_input}. """,
extract information from the following input: {text_input}. """, }, {
}, "role": "system",
{"role": "system", "content": system_prompt}, "content": system_prompt,
], }],
api_key = self.api_key,
api_base = self.endpoint,
api_version = self.api_version,
response_model = response_model, response_model = response_model,
max_retries = 5,
) )
@retry(stop = stop_after_attempt(5))
def create_transcript(self, input): def create_transcript(self, input):
"""Generate a audio transcript from a user query.""" """Generate a audio transcript from a user query."""
if not os.path.isfile(input): if not os.path.isfile(input):
raise FileNotFoundError(f"The file {input} does not exist.") raise FileNotFoundError(f"The file {input} does not exist.")
with open(input, 'rb') as audio_file: # with open(input, 'rb') as audio_file:
audio_data = audio_file.read() # audio_data = audio_file.read()
transcription = litellm.transcription(
model = self.transcription_model,
transcription = self.base_openai_client.audio.transcriptions.create( file = Path(input),
model=self.transcription_model , api_key=self.api_key,
file=Path(input), api_base=self.endpoint,
) api_version=self.api_version,
max_retries = 5,
)
return transcription return transcription
@retry(stop = stop_after_attempt(5))
def transcribe_image(self, input) -> BaseModel: def transcribe_image(self, input) -> BaseModel:
with open(input, "rb") as image_file: with open(input, "rb") as image_file:
encoded_image = base64.b64encode(image_file.read()).decode('utf-8') encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
return self.base_openai_client.chat.completions.create( return litellm.completion(
model=self.model, model = self.model,
messages=[ messages = [{
{ "role": "user",
"role": "user", "content": [
"content": [ {
{"type": "text", "text": "Whats in this image?"}, "type": "text",
{ "text": "Whats in this image?",
"type": "image_url", }, {
"image_url": { "type": "image_url",
"url": f"data:image/jpeg;base64,{encoded_image}", "image_url": {
}, "url": f"data:image/jpeg;base64,{encoded_image}",
}, },
], },
} ],
], }],
max_tokens=300, api_key=self.api_key,
api_base=self.endpoint,
api_version=self.api_version,
max_tokens = 300,
max_retries = 5,
) )
def show_prompt(self, text_input: str, system_prompt: str) -> str: def show_prompt(self, text_input: str, system_prompt: str) -> str:
"""Format and display the prompt for a user query.""" """Format and display the prompt for a user query."""
if not text_input: if not text_input:

View file

@ -1,5 +1,7 @@
import json
import inspect import inspect
import logging import logging
from cognee.modules.settings import get_current_settings
from cognee.shared.utils import send_telemetry from cognee.shared.utils import send_telemetry
from cognee.modules.users.models import User from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user from cognee.modules.users.methods import get_default_user
@ -157,7 +159,7 @@ async def run_tasks_base(tasks: list[Task], data = None, user: User = None):
}) })
raise error raise error
async def run_tasks(tasks: [Task], data = None, pipeline_name: str = "default_pipeline"): async def run_tasks_with_telemetry(tasks: list[Task], data, pipeline_name: str):
user = await get_default_user() user = await get_default_user()
try: try:
@ -185,3 +187,10 @@ async def run_tasks(tasks: [Task], data = None, pipeline_name: str = "default_pi
}) })
raise error raise error
async def run_tasks(tasks: list[Task], data = None, pipeline_name: str = "default_pipeline"):
config = get_current_settings()
logger.debug("\nRunning pipeline with configuration:\n%s\n", json.dumps(config, indent = 1))
async for result in run_tasks_with_telemetry(tasks, data, pipeline_name):
yield result

View file

@ -1,3 +1,4 @@
from .get_current_settings import get_current_settings
from .get_settings import get_settings, SettingsDict from .get_settings import get_settings, SettingsDict
from .save_llm_config import save_llm_config from .save_llm_config import save_llm_config
from .save_vector_db_config import save_vector_db_config from .save_vector_db_config import save_vector_db_config

View file

@ -0,0 +1,54 @@
from typing import TypedDict
from cognee.infrastructure.llm import get_llm_config
from cognee.infrastructure.databases.graph import get_graph_config
from cognee.infrastructure.databases.vector import get_vectordb_config
from cognee.infrastructure.databases.relational.config import get_relational_config
class LLMConfig(TypedDict):
model: str
provider: str
class VectorDBConfig(TypedDict):
url: str
provider: str
class GraphDBConfig(TypedDict):
url: str
provider: str
class RelationalConfig(TypedDict):
url: str
provider: str
class SettingsDict(TypedDict):
llm: LLMConfig
graph: GraphDBConfig
vector: VectorDBConfig
relational: RelationalConfig
def get_current_settings() -> SettingsDict:
llm_config = get_llm_config()
graph_config = get_graph_config()
vector_config = get_vectordb_config()
relational_config = get_relational_config()
return dict(
llm = {
"provider": llm_config.llm_provider,
"model": llm_config.llm_model,
},
graph = {
"provider": graph_config.graph_database_provider,
"url": graph_config.graph_database_url or graph_config.graph_file_path,
},
vector = {
"provider": vector_config.vector_db_provider,
"url": vector_config.vector_db_url,
},
relational = {
"provider": relational_config.db_provider,
"url": f"{relational_config.db_host}:{relational_config.db_port}" \
if relational_config.db_host \
else f"{relational_config.db_path}/{relational_config.db_name}",
},
)

View file

@ -1,16 +1,51 @@
from cognee.modules.data.models import Data from cognee.modules.data.models import Data
from cognee.modules.data.processing.document_types import Document, PdfDocument, AudioDocument, ImageDocument, TextDocument from cognee.modules.data.processing.document_types import (
Document,
PdfDocument,
AudioDocument,
ImageDocument,
TextDocument,
)
EXTENSION_TO_DOCUMENT_CLASS = { EXTENSION_TO_DOCUMENT_CLASS = {
"pdf": PdfDocument, "pdf": PdfDocument, # Text documents
"audio": AudioDocument, "txt": TextDocument,
"image": ImageDocument, "png": ImageDocument, # Image documents
"txt": TextDocument "dwg": ImageDocument,
"xcf": ImageDocument,
"jpg": ImageDocument,
"jpx": ImageDocument,
"apng": ImageDocument,
"gif": ImageDocument,
"webp": ImageDocument,
"cr2": ImageDocument,
"tif": ImageDocument,
"bmp": ImageDocument,
"jxr": ImageDocument,
"psd": ImageDocument,
"ico": ImageDocument,
"heic": ImageDocument,
"avif": ImageDocument,
"aac": AudioDocument, # Audio documents
"mid": AudioDocument,
"mp3": AudioDocument,
"m4a": AudioDocument,
"ogg": AudioDocument,
"flac": AudioDocument,
"wav": AudioDocument,
"amr": AudioDocument,
"aiff": AudioDocument,
} }
def classify_documents(data_documents: list[Data]) -> list[Document]: def classify_documents(data_documents: list[Data]) -> list[Document]:
documents = [ documents = [
EXTENSION_TO_DOCUMENT_CLASS[data_item.extension](id = data_item.id, title=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location, name=data_item.name) EXTENSION_TO_DOCUMENT_CLASS[data_item.extension](
id=data_item.id,
title=f"{data_item.name}.{data_item.extension}",
raw_data_location=data_item.raw_data_location,
name=data_item.name,
)
for data_item in data_documents for data_item in data_documents
] ]
return documents return documents

View file

@ -57,6 +57,24 @@ async def main():
assert len(history) == 6, "Search history is not correct." assert len(history) == 6, "Search history is not correct."
# Assert local data files are cleaned properly
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
# Assert relational, vector and graph databases have been cleaned properly
await cognee.prune.prune_system(metadata=True)
connection = await vector_engine.get_connection()
collection_names = await connection.table_names()
assert len(collection_names) == 0, "LanceDB vector database is not empty"
from cognee.infrastructure.databases.relational import get_relational_engine
assert not os.path.exists(get_relational_engine().db_path), "SQLite relational database is not empty"
from cognee.infrastructure.databases.graph import get_graph_config
graph_config = get_graph_config()
assert not os.path.exists(graph_config.graph_file_path), "Networkx graph database is not empty"
if __name__ == "__main__": if __name__ == "__main__":
import asyncio import asyncio
asyncio.run(main(), debug=True) asyncio.run(main(), debug=True)

View file

@ -62,6 +62,15 @@ async def main():
assert len(history) == 6, "Search history is not correct." assert len(history) == 6, "Search history is not correct."
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
await cognee.prune.prune_system(metadata=True)
from cognee.infrastructure.databases.graph import get_graph_engine
graph_engine = await get_graph_engine()
nodes, edges = await graph_engine.get_graph_data()
assert len(nodes) == 0 and len(edges) == 0, "Neo4j graph database is not empty"
if __name__ == "__main__": if __name__ == "__main__":
import asyncio import asyncio
asyncio.run(main()) asyncio.run(main())

View file

@ -87,9 +87,15 @@ async def main():
print(f"{result}\n") print(f"{result}\n")
history = await cognee.get_search_history() history = await cognee.get_search_history()
assert len(history) == 6, "Search history is not correct." assert len(history) == 6, "Search history is not correct."
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
await cognee.prune.prune_system(metadata=True)
tables_in_database = await vector_engine.get_table_names()
assert len(tables_in_database) == 0, "PostgreSQL database is not empty"
if __name__ == "__main__": if __name__ == "__main__":
import asyncio import asyncio

View file

@ -59,9 +59,16 @@ async def main():
print(f"{result}\n") print(f"{result}\n")
history = await cognee.get_search_history() history = await cognee.get_search_history()
assert len(history) == 6, "Search history is not correct." assert len(history) == 6, "Search history is not correct."
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
await cognee.prune.prune_system(metadata=True)
qdrant_client = get_vector_engine().get_qdrant_client()
collections_response = await qdrant_client.get_collections()
assert len(collections_response.collections) == 0, "QDrant vector database is not empty"
if __name__ == "__main__": if __name__ == "__main__":
import asyncio import asyncio
asyncio.run(main()) asyncio.run(main())

View file

@ -57,9 +57,15 @@ async def main():
print(f"{result}\n") print(f"{result}\n")
history = await cognee.get_search_history() history = await cognee.get_search_history()
assert len(history) == 6, "Search history is not correct." assert len(history) == 6, "Search history is not correct."
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
await cognee.prune.prune_system(metadata=True)
collections = get_vector_engine().client.collections.list_all()
assert len(collections) == 0, "Weaviate vector database is not empty"
if __name__ == "__main__": if __name__ == "__main__":
import asyncio import asyncio
asyncio.run(main()) asyncio.run(main())

View file

@ -0,0 +1,48 @@
import os
import asyncio
import pathlib
import cognee
from cognee.api.v1.search import SearchType
# Prerequisites:
# 1. Copy `.env.template` and rename it to `.env`.
# 2. Add your OpenAI API key to the `.env` file in the `LLM_API_KEY` field:
# LLM_API_KEY = "your_key_here"
async def main():
    """End-to-end multimedia demo: ingest an mp3 and a png, build a
    knowledge graph from them, then query cognee for summaries."""
    # Start from a clean slate -- wipe cognee's data and system state.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    # The knowledge graph will be built from the text and descriptions
    # extracted from these sample multimedia files.
    examples_root = pathlib.Path(__file__).parent.parent.parent
    mp3_file_path = os.path.join(examples_root, ".data/multimedia/text_to_speech.mp3")
    png_file_path = os.path.join(examples_root, ".data/multimedia/example.png")

    # Register the files so cognify can process them.
    await cognee.add([mp3_file_path, png_file_path])

    # Use LLMs and cognee to build the knowledge graph.
    await cognee.cognify()

    # Ask cognee to summarise the content of the multimedia files.
    search_results = await cognee.search(
        SearchType.SUMMARIES,
        query_text="What is in the multimedia files?",
    )

    # Print each summary result.
    for result_text in search_results:
        print(result_text)
if __name__ == "__main__":
    # Script entry point: run the async demo to completion.
    asyncio.run(main())

View file

@ -1,5 +1,4 @@
import asyncio import asyncio
import cognee import cognee
from cognee.api.v1.search import SearchType from cognee.api.v1.search import SearchType
@ -11,29 +10,57 @@ from cognee.api.v1.search import SearchType
async def main(): async def main():
# Create a clean slate for cognee -- reset data and system state # Create a clean slate for cognee -- reset data and system state
print("Resetting cognee data...")
await cognee.prune.prune_data() await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True) await cognee.prune.prune_system(metadata=True)
print("Data reset complete.\n")
# cognee knowledge graph will be created based on this text # cognee knowledge graph will be created based on this text
text = """ text = """
Natural language processing (NLP) is an interdisciplinary Natural language processing (NLP) is an interdisciplinary
subfield of computer science and information retrieval. subfield of computer science and information retrieval.
""" """
print("Adding text to cognee:")
print(text.strip())
# Add the text, and make it available for cognify # Add the text, and make it available for cognify
await cognee.add(text) await cognee.add(text)
print("Text added successfully.\n")
print("Running cognify to create knowledge graph...\n")
print("Cognify process steps:")
print("1. Classifying the document: Determining the type and category of the input text.")
print("2. Checking permissions: Ensuring the user has the necessary rights to process the text.")
print("3. Extracting text chunks: Breaking down the text into sentences or phrases for analysis.")
print("4. Adding data points: Storing the extracted chunks for processing.")
print("5. Generating knowledge graph: Extracting entities and relationships to form a knowledge graph.")
print("6. Summarizing text: Creating concise summaries of the content for quick insights.\n")
# Use LLMs and cognee to create knowledge graph # Use LLMs and cognee to create knowledge graph
await cognee.cognify() await cognee.cognify()
print("Cognify process complete.\n")
query_text = 'Tell me about NLP'
print(f"Searching cognee for insights with query: '{query_text}'")
# Query cognee for insights on the added text # Query cognee for insights on the added text
search_results = await cognee.search( search_results = await cognee.search(
SearchType.INSIGHTS, query_text='Tell me about NLP' SearchType.INSIGHTS, query_text=query_text
) )
# Display search results print("Search results:")
# Display results
for result_text in search_results: for result_text in search_results:
print(result_text) print(result_text)
# Example output:
# ({'id': UUID('bc338a39-64d6-549a-acec-da60846dd90d'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 1, 211808, tzinfo=datetime.timezone.utc), 'name': 'natural language processing', 'description': 'An interdisciplinary subfield of computer science and information retrieval.'}, {'relationship_name': 'is_a_subfield_of', 'source_node_id': UUID('bc338a39-64d6-549a-acec-da60846dd90d'), 'target_node_id': UUID('6218dbab-eb6a-5759-a864-b3419755ffe0'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 15, 473137, tzinfo=datetime.timezone.utc)}, {'id': UUID('6218dbab-eb6a-5759-a864-b3419755ffe0'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 1, 211808, tzinfo=datetime.timezone.utc), 'name': 'computer science', 'description': 'The study of computation and information processing.'})
# (...)
# It represents nodes and relationships in the knowledge graph:
# - The first element is the source node (e.g., 'natural language processing').
# - The second element is the relationship between nodes (e.g., 'is_a_subfield_of').
# - The third element is the target node (e.g., 'computer science').
if __name__ == '__main__': if __name__ == '__main__':
asyncio.run(main()) asyncio.run(main())

View file

@ -265,7 +265,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 1,
"id": "df16431d0f48b006", "id": "df16431d0f48b006",
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
@ -304,7 +304,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 2,
"id": "9086abf3af077ab4", "id": "9086abf3af077ab4",
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
@ -349,7 +349,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 3,
"id": "a9de0cc07f798b7f", "id": "a9de0cc07f798b7f",
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
@ -393,7 +393,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 4,
"id": "185ff1c102d06111", "id": "185ff1c102d06111",
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
@ -437,7 +437,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 5,
"id": "d55ce4c58f8efb67", "id": "d55ce4c58f8efb67",
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
@ -479,7 +479,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 6,
"id": "ca4ecc32721ad332", "id": "ca4ecc32721ad332",
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
@ -529,14 +529,14 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 7,
"id": "bce39dc6", "id": "bce39dc6",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import os\n", "import os\n",
"\n", "\n",
"# # Setting environment variables\n", "# Setting environment variables\n",
"if \"GRAPHISTRY_USERNAME\" not in os.environ: \n", "if \"GRAPHISTRY_USERNAME\" not in os.environ: \n",
" os.environ[\"GRAPHISTRY_USERNAME\"] = \"\"\n", " os.environ[\"GRAPHISTRY_USERNAME\"] = \"\"\n",
"\n", "\n",
@ -546,24 +546,26 @@
"if \"LLM_API_KEY\" not in os.environ:\n", "if \"LLM_API_KEY\" not in os.environ:\n",
" os.environ[\"LLM_API_KEY\"] = \"\"\n", " os.environ[\"LLM_API_KEY\"] = \"\"\n",
"\n", "\n",
"os.environ[\"GRAPH_DATABASE_PROVIDER\"]=\"networkx\" # \"neo4j\" or \"networkx\"\n", "# \"neo4j\" or \"networkx\"\n",
"os.environ[\"GRAPH_DATABASE_PROVIDER\"]=\"networkx\" \n",
"# Not needed if using networkx\n", "# Not needed if using networkx\n",
"#GRAPH_DATABASE_URL=\"\"\n", "#os.environ[\"GRAPH_DATABASE_URL\"]=\"\"\n",
"#GRAPH_DATABASE_USERNAME=\"\"\n", "#os.environ[\"GRAPH_DATABASE_USERNAME\"]=\"\"\n",
"#GRAPH_DATABASE_PASSWORD=\"\"\n", "#os.environ[\"GRAPH_DATABASE_PASSWORD\"]=\"\"\n",
"\n", "\n",
"os.environ[\"VECTOR_DB_PROVIDER\"]=\"lancedb\" # \"qdrant\", \"weaviate\" or \"lancedb\"\n", "# \"pgvector\", \"qdrant\", \"weaviate\" or \"lancedb\"\n",
"# Not needed if using \"lancedb\"\n", "os.environ[\"VECTOR_DB_PROVIDER\"]=\"lancedb\" \n",
"# Not needed if using \"lancedb\" or \"pgvector\"\n",
"# os.environ[\"VECTOR_DB_URL\"]=\"\"\n", "# os.environ[\"VECTOR_DB_URL\"]=\"\"\n",
"# os.environ[\"VECTOR_DB_KEY\"]=\"\"\n", "# os.environ[\"VECTOR_DB_KEY\"]=\"\"\n",
"\n", "\n",
"# Database provider\n", "# Relational Database provider \"sqlite\" or \"postgres\"\n",
"os.environ[\"DB_PROVIDER\"]=\"sqlite\" # or \"postgres\"\n", "os.environ[\"DB_PROVIDER\"]=\"sqlite\"\n",
"\n", "\n",
"# Database name\n", "# Database name\n",
"os.environ[\"DB_NAME\"]=\"cognee_db\"\n", "os.environ[\"DB_NAME\"]=\"cognee_db\"\n",
"\n", "\n",
"# Postgres specific parameters (Only if Postgres is run)\n", "# Postgres specific parameters (Only if Postgres or PGVector is used)\n",
"# os.environ[\"DB_HOST\"]=\"127.0.0.1\"\n", "# os.environ[\"DB_HOST\"]=\"127.0.0.1\"\n",
"# os.environ[\"DB_PORT\"]=\"5432\"\n", "# os.environ[\"DB_PORT\"]=\"5432\"\n",
"# os.environ[\"DB_USERNAME\"]=\"cognee\"\n", "# os.environ[\"DB_USERNAME\"]=\"cognee\"\n",
@ -620,7 +622,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 10,
"id": "7c431fdef4921ae0", "id": "7c431fdef4921ae0",
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {

View file

@ -52,7 +52,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 3,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -71,7 +71,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 4,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -90,23 +90,23 @@
"# \"neo4j\" or \"networkx\"\n", "# \"neo4j\" or \"networkx\"\n",
"os.environ[\"GRAPH_DATABASE_PROVIDER\"]=\"networkx\" \n", "os.environ[\"GRAPH_DATABASE_PROVIDER\"]=\"networkx\" \n",
"# Not needed if using networkx\n", "# Not needed if using networkx\n",
"#GRAPH_DATABASE_URL=\"\"\n", "#os.environ[\"GRAPH_DATABASE_URL\"]=\"\"\n",
"#GRAPH_DATABASE_USERNAME=\"\"\n", "#os.environ[\"GRAPH_DATABASE_USERNAME\"]=\"\"\n",
"#GRAPH_DATABASE_PASSWORD=\"\"\n", "#os.environ[\"GRAPH_DATABASE_PASSWORD\"]=\"\"\n",
"\n", "\n",
"# \"qdrant\", \"weaviate\" or \"lancedb\"\n", "# \"pgvector\", \"qdrant\", \"weaviate\" or \"lancedb\"\n",
"os.environ[\"VECTOR_DB_PROVIDER\"]=\"lancedb\" \n", "os.environ[\"VECTOR_DB_PROVIDER\"]=\"lancedb\" \n",
"# Not needed if using \"lancedb\"\n", "# Not needed if using \"lancedb\" or \"pgvector\"\n",
"# os.environ[\"VECTOR_DB_URL\"]=\"\"\n", "# os.environ[\"VECTOR_DB_URL\"]=\"\"\n",
"# os.environ[\"VECTOR_DB_KEY\"]=\"\"\n", "# os.environ[\"VECTOR_DB_KEY\"]=\"\"\n",
"\n", "\n",
"# Database provider\n", "# Relational Database provider \"sqlite\" or \"postgres\"\n",
"os.environ[\"DB_PROVIDER\"]=\"sqlite\" # or \"postgres\"\n", "os.environ[\"DB_PROVIDER\"]=\"sqlite\"\n",
"\n", "\n",
"# Database name\n", "# Database name\n",
"os.environ[\"DB_NAME\"]=\"cognee_db\"\n", "os.environ[\"DB_NAME\"]=\"cognee_db\"\n",
"\n", "\n",
"# Postgres specific parameters (Only if Postgres is run)\n", "# Postgres specific parameters (Only if Postgres or PGVector is used)\n",
"# os.environ[\"DB_HOST\"]=\"127.0.0.1\"\n", "# os.environ[\"DB_HOST\"]=\"127.0.0.1\"\n",
"# os.environ[\"DB_PORT\"]=\"5432\"\n", "# os.environ[\"DB_PORT\"]=\"5432\"\n",
"# os.environ[\"DB_USERNAME\"]=\"cognee\"\n", "# os.environ[\"DB_USERNAME\"]=\"cognee\"\n",
@ -130,8 +130,6 @@
"\n", "\n",
"from cognee.infrastructure.databases.vector.pgvector import create_db_and_tables as create_pgvector_db_and_tables\n", "from cognee.infrastructure.databases.vector.pgvector import create_db_and_tables as create_pgvector_db_and_tables\n",
"from cognee.infrastructure.databases.relational import create_db_and_tables as create_relational_db_and_tables\n", "from cognee.infrastructure.databases.relational import create_db_and_tables as create_relational_db_and_tables\n",
"from cognee.infrastructure.databases.graph import get_graph_engine\n",
"from cognee.shared.utils import render_graph\n",
"from cognee.modules.users.models import User\n", "from cognee.modules.users.models import User\n",
"from cognee.modules.users.methods import get_default_user\n", "from cognee.modules.users.methods import get_default_user\n",
"from cognee.tasks.ingestion.ingest_data_with_metadata import ingest_data_with_metadata\n", "from cognee.tasks.ingestion.ingest_data_with_metadata import ingest_data_with_metadata\n",
@ -196,6 +194,9 @@
"source": [ "source": [
"import graphistry\n", "import graphistry\n",
"\n", "\n",
"from cognee.infrastructure.databases.graph import get_graph_engine\n",
"from cognee.shared.utils import render_graph\n",
"\n",
"# Get graph\n", "# Get graph\n",
"graphistry.login(username=os.getenv(\"GRAPHISTRY_USERNAME\"), password=os.getenv(\"GRAPHISTRY_PASSWORD\"))\n", "graphistry.login(username=os.getenv(\"GRAPHISTRY_USERNAME\"), password=os.getenv(\"GRAPHISTRY_PASSWORD\"))\n",
"graph_engine = await get_graph_engine()\n", "graph_engine = await get_graph_engine()\n",

View file

@ -0,0 +1,169 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Cognee GraphRAG with Multimedia files"
]
},
{
"cell_type": "markdown",
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"source": [
"## Load Data\n",
"\n",
"We will use a few sample multimedia files which we have on GitHub for easy access."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import pathlib\n",
"\n",
"# cognee knowledge graph will be created based on the text\n",
"# and description of these files\n",
"mp3_file_path = os.path.join(\n",
" os.path.abspath(''), \"../\",\n",
" \".data/multimedia/text_to_speech.mp3\",\n",
")\n",
"png_file_path = os.path.join(\n",
" os.path.abspath(''), \"../\",\n",
" \".data/multimedia/example.png\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Set environment variables"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# Setting environment variables\n",
"if \"GRAPHISTRY_USERNAME\" not in os.environ: \n",
" os.environ[\"GRAPHISTRY_USERNAME\"] = \"\"\n",
"\n",
"if \"GRAPHISTRY_PASSWORD\" not in os.environ: \n",
" os.environ[\"GRAPHISTRY_PASSWORD\"] = \"\"\n",
"\n",
"if \"LLM_API_KEY\" not in os.environ:\n",
" os.environ[\"LLM_API_KEY\"] = \"\"\n",
"\n",
"# \"neo4j\" or \"networkx\"\n",
"os.environ[\"GRAPH_DATABASE_PROVIDER\"]=\"networkx\" \n",
"# Not needed if using networkx\n",
"#os.environ[\"GRAPH_DATABASE_URL\"]=\"\"\n",
"#os.environ[\"GRAPH_DATABASE_USERNAME\"]=\"\"\n",
"#os.environ[\"GRAPH_DATABASE_PASSWORD\"]=\"\"\n",
"\n",
"# \"pgvector\", \"qdrant\", \"weaviate\" or \"lancedb\"\n",
"os.environ[\"VECTOR_DB_PROVIDER\"]=\"lancedb\" \n",
"# Not needed if using \"lancedb\" or \"pgvector\"\n",
"# os.environ[\"VECTOR_DB_URL\"]=\"\"\n",
"# os.environ[\"VECTOR_DB_KEY\"]=\"\"\n",
"\n",
"# Relational Database provider \"sqlite\" or \"postgres\"\n",
"os.environ[\"DB_PROVIDER\"]=\"sqlite\"\n",
"\n",
"# Database name\n",
"os.environ[\"DB_NAME\"]=\"cognee_db\"\n",
"\n",
"# Postgres specific parameters (Only if Postgres or PGVector is used)\n",
"# os.environ[\"DB_HOST\"]=\"127.0.0.1\"\n",
"# os.environ[\"DB_PORT\"]=\"5432\"\n",
"# os.environ[\"DB_USERNAME\"]=\"cognee\"\n",
"# os.environ[\"DB_PASSWORD\"]=\"cognee\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run Cognee with multimedia files"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import cognee\n",
"\n",
"# Create a clean slate for cognee -- reset data and system state\n",
"await cognee.prune.prune_data()\n",
"await cognee.prune.prune_system(metadata=True)\n",
"\n",
"# Add multimedia files and make them available for cognify\n",
"await cognee.add([mp3_file_path, png_file_path])\n",
"\n",
"# Create knowledge graph with cognee\n",
"await cognee.cognify()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Query Cognee for summaries related to multimedia files"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from cognee.api.v1.search import SearchType\n",
"\n",
"# Query cognee for summaries of the data in the multimedia files\n",
"search_results = await cognee.search(\n",
" SearchType.SUMMARIES,\n",
" query_text=\"What is in the multimedia files?\",\n",
")\n",
"\n",
"# Display search results\n",
"for result_text in search_results:\n",
" print(result_text)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}