feat: pipelines and tasks (#119)

* feat: simple graph pipeline

* feat: implement incremental graph generation

* fix: various bug fixes

* fix: upgrade weaviate-client

---------

Co-authored-by: Vasilije <8619304+Vasilije1990@users.noreply.github.com>
This commit is contained in:
Boris 2024-07-20 16:49:00 +02:00 committed by GitHub
parent 9a57659266
commit 14555a25d0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
121 changed files with 4409 additions and 1779 deletions

View file

@ -30,9 +30,9 @@ jobs:
# Test all python versions on ubuntu only
include:
- python-version: "3.9.x"
os: "ubuntu-latest"
os: "ubuntu-22.04"
- python-version: "3.10.x"
os: "ubuntu-latest"
os: "ubuntu-22.04"
# - python-version: "3.12.x"
# os: "ubuntu-latest"
@ -90,6 +90,12 @@ jobs:
ENV: 'dev'
run: poetry run python ./cognee/tests/test_library.py
- name: Clean up disk space
run: |
sudo rm -rf ~/.cache
sudo rm -rf /tmp/*
df -h
- name: Build with Poetry
run: poetry build

View file

@ -15,7 +15,6 @@ export interface Data {
name: string;
filePath: string;
mimeType: string;
keywords: string[];
}
interface DatasetLike {
@ -113,9 +112,6 @@ export default function DataView({ datasetId, data, onClose, onDataAdd }: DataVi
<td>
<Text>{dataItem.mimeType}</Text>
</td>
<td>
<Text>{dataItem.keywords.join(", ")}</Text>
</td>
</tr>
))}
</tbody>

View file

@ -1,8 +1,9 @@
export default function addData(dataset: { id: string }, files: File[]) {
const formData = new FormData();
files.forEach((file) => {
formData.append('data', file, file.name);
})
formData.append('datasetId', dataset.id);
const file = files[0];
formData.append('data', file, file.name);
return fetch('http://0.0.0.0:8000/add', {
method: 'POST',

View file

@ -33,8 +33,8 @@ export default function SearchView() {
value: 'ADJACENT',
label: 'Look for graph node\'s neighbors',
}, {
value: 'CATEGORIES',
label: 'Search by categories (Comma separated categories)',
value: 'TRAVERSE',
label: 'Traverse through the graph and get knowledge',
}];
const [searchType, setSearchType] = useState(searchOptions[0]);

View file

@ -1,6 +1,6 @@
from .api.v1.config.config import config
from .api.v1.add.add import add
from .api.v1.cognify.cognify import cognify
from .api.v1.cognify.cognify_v2 import cognify
from .api.v1.datasets.datasets import datasets
from .api.v1.search.search import search, SearchType
from .api.v1.prune import prune

View file

@ -2,9 +2,10 @@
import os
import aiohttp
import uvicorn
import asyncio
import json
import asyncio
import logging
import sentry_sdk
from typing import Dict, Any, List, Union, Optional, Literal
from typing_extensions import Annotated
from fastapi import FastAPI, HTTPException, Form, File, UploadFile, Query
@ -19,7 +20,14 @@ logging.basicConfig(
)
logger = logging.getLogger(__name__)
app = FastAPI(debug = True)
if os.getenv("ENV") == "prod":
sentry_sdk.init(
dsn = os.getenv("SENTRY_REPORTING_URL"),
traces_sample_rate = 1.0,
profiles_sample_rate = 1.0,
)
app = FastAPI(debug = os.getenv("ENV") != "prod")
origins = [
"http://frontend:3000",
@ -69,10 +77,10 @@ async def delete_dataset(dataset_id: str):
@app.get("/datasets/{dataset_id}/graph", response_model=list)
async def get_dataset_graph(dataset_id: str):
from cognee.shared.utils import render_graph
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
from cognee.infrastructure.databases.graph import get_graph_engine
try:
graph_client = await get_graph_client()
graph_client = await get_graph_engine()
graph_url = await render_graph(graph_client.graph)
return JSONResponse(
@ -95,7 +103,6 @@ async def get_dataset_data(dataset_id: str):
dict(
id=data["id"],
name=f"{data['name']}.{data['extension']}",
keywords=data["keywords"].split("|"),
filePath=data["file_path"],
mimeType=data["mime_type"],
)
@ -129,8 +136,8 @@ class AddPayload(BaseModel):
@app.post("/add", response_model=dict)
async def add(
data: List[UploadFile],
datasetId: str = Form(...),
data: List[UploadFile] = File(...),
):
""" This endpoint is responsible for adding data to the graph."""
from cognee.api.v1.add import add as cognee_add
@ -177,17 +184,17 @@ class CognifyPayload(BaseModel):
@app.post("/cognify", response_model=dict)
async def cognify(payload: CognifyPayload):
""" This endpoint is responsible for the cognitive processing of the content."""
from cognee.api.v1.cognify.cognify import cognify as cognee_cognify
from cognee.api.v1.cognify.cognify_v2 import cognify as cognee_cognify
try:
await cognee_cognify(payload.datasets)
return JSONResponse(
status_code=200,
content="OK"
status_code = 200,
content = "OK"
)
except Exception as error:
return JSONResponse(
status_code=409,
content={"error": str(error)}
status_code = 409,
content = {"error": str(error)}
)
class SearchPayload(BaseModel):
@ -255,30 +262,13 @@ def start_api_server(host: str = "0.0.0.0", port: int = 8000):
try:
logger.info("Starting server at %s:%s", host, port)
from cognee.base_config import get_base_config
from cognee.infrastructure.databases.relational import get_relationaldb_config
from cognee.infrastructure.databases.vector import get_vectordb_config
from cognee.infrastructure.databases.graph import get_graph_config
cognee_directory_path = os.path.abspath(".cognee_system")
databases_directory_path = os.path.join(cognee_directory_path, "databases")
relational_config = get_relationaldb_config()
relational_config.db_path = databases_directory_path
relational_config.create_engine()
vector_config = get_vectordb_config()
vector_config.vector_db_url = os.path.join(databases_directory_path, "cognee.lancedb")
graph_config = get_graph_config()
graph_config.graph_file_path = os.path.join(databases_directory_path, "cognee.graph")
base_config = get_base_config()
data_directory_path = os.path.abspath(".data_storage")
base_config.data_root_directory = data_directory_path
from cognee.modules.data.deletion import prune_system
asyncio.run(prune_system())
from cognee.modules.data.deletion import prune_system, prune_data
asyncio.run(prune_data())
asyncio.run(prune_system(metadata = True))
uvicorn.run(app, host = host, port = port)
except Exception as e:

View file

@ -97,7 +97,6 @@ async def add_files(file_paths: List[str], dataset_name: str):
"file_path": file_metadata["file_path"],
"extension": file_metadata["extension"],
"mime_type": file_metadata["mime_type"],
"keywords": "|".join(file_metadata["keywords"]),
}
run_info = pipeline.run(

View file

@ -6,21 +6,19 @@ import nltk
from asyncio import Lock
from nltk.corpus import stopwords
from cognee.infrastructure.data.chunking.LangchainChunkingEngine import LangchainChunkEngine
from cognee.infrastructure.data.chunking.get_chunking_engine import get_chunk_engine
from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.infrastructure.databases.vector.embeddings.LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine
from cognee.modules.cognify.graph.add_node_connections import group_nodes_by_layer, \
graph_ready_output, connect_nodes_in_graph
from cognee.modules.cognify.graph.add_data_chunks import add_data_chunks, add_data_chunks_basic_rag
from cognee.modules.cognify.graph.add_data_chunks import add_data_chunks
from cognee.modules.cognify.graph.add_document_node import add_document_node
from cognee.modules.cognify.graph.add_classification_nodes import add_classification_nodes
from cognee.modules.cognify.graph.add_cognitive_layer_graphs import add_cognitive_layer_graphs
from cognee.modules.cognify.graph.add_summary_nodes import add_summary_nodes
from cognee.modules.cognify.llm.resolve_cross_graph_references import resolve_cross_graph_references
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.cognify.graph.add_cognitive_layers import add_cognitive_layers
# from cognee.modules.cognify.graph.initialize_graph import initialize_graph
from cognee.infrastructure.files.utils.guess_file_type import guess_file_type, FileTypeException
from cognee.infrastructure.files.utils.extract_text_from_file import extract_text_from_file
from cognee.modules.data.get_content_categories import get_content_categories
@ -49,9 +47,7 @@ async def cognify(datasets: Union[str, List[str]] = None):
stopwords.ensure_loaded()
create_task_status_table()
# graph_config = get_graph_config()
# graph_db_type = graph_config.graph_engine
graph_client = await get_graph_client()
graph_client = await get_graph_engine()
relational_config = get_relationaldb_config()
db_engine = relational_config.database_engine
@ -89,8 +85,8 @@ async def cognify(datasets: Union[str, List[str]] = None):
added_datasets = db_engine.get_datasets()
dataset_files = []
# datasets is a dataset name string
dataset_files = []
dataset_name = datasets.replace(".", "_").replace(" ", "_")
for added_dataset in added_datasets:
@ -145,10 +141,8 @@ async def cognify(datasets: Union[str, List[str]] = None):
batch_size = 20
file_count = 0
files_batch = []
from cognee.infrastructure.databases.graph.config import get_graph_config
graph_config = get_graph_config()
graph_topology = graph_config.graph_model
graph_config = get_graph_config()
if graph_config.infer_graph_topology and graph_config.graph_topology_task:
from cognee.modules.topology.topology import TopologyEngine
@ -173,8 +167,8 @@ async def cognify(datasets: Union[str, List[str]] = None):
else:
document_id = await add_document_node(
graph_client,
parent_node_id=file_metadata['id'],
document_metadata=file_metadata,
parent_node_id = file_metadata['id'],
document_metadata = file_metadata,
)
files_batch.append((dataset_name, file_metadata, document_id))
@ -196,7 +190,7 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
print(f"Processing chunk ({chunk_id}) from document ({file_metadata['id']}).")
graph_config = get_graph_config()
graph_client = await get_graph_client()
graph_client = await get_graph_engine()
graph_topology = graph_config.graph_model
if graph_topology == SourceCodeGraph:
@ -206,8 +200,6 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
else:
classified_categories = [{"data_type": "text", "category_name": "Unclassified text"}]
# await add_label_nodes(graph_client, document_id, chunk_id, file_metadata["keywords"].split("|"))
await add_classification_nodes(
graph_client,
parent_node_id = document_id,
@ -271,10 +263,10 @@ if __name__ == "__main__":
# await prune.prune_system()
# #
# from cognee.api.v1.add import add
# data_directory_path = os.path.abspath("../../../.data")
# data_directory_path = os.path.abspath("../../.data")
# # print(data_directory_path)
# # config.data_root_directory(data_directory_path)
# # cognee_directory_path = os.path.abspath("../.cognee_system")
# # cognee_directory_path = os.path.abspath(".cognee_system")
# # config.system_root_directory(cognee_directory_path)
#
# await add("data://" +data_directory_path, "example")

View file

@ -0,0 +1,139 @@
import asyncio
import logging
from typing import Union
from cognee.infrastructure.databases.graph import get_graph_config
from cognee.modules.cognify.config import get_cognify_config
from cognee.infrastructure.databases.relational.config import get_relationaldb_config
from cognee.modules.data.processing.document_types.AudioDocument import AudioDocument
from cognee.modules.data.processing.document_types.ImageDocument import ImageDocument
from cognee.shared.data_models import KnowledgeGraph
from cognee.modules.data.processing.document_types import PdfDocument, TextDocument
from cognee.modules.cognify.vector import save_data_chunks
from cognee.modules.data.processing.process_documents import process_documents
from cognee.modules.classification.classify_text_chunks import classify_text_chunks
from cognee.modules.data.extraction.data_summary.summarize_text_chunks import summarize_text_chunks
from cognee.modules.data.processing.filter_affected_chunks import filter_affected_chunks
from cognee.modules.data.processing.remove_obsolete_chunks import remove_obsolete_chunks
from cognee.modules.data.extraction.knowledge_graph.expand_knowledge_graph import expand_knowledge_graph
from cognee.modules.data.extraction.knowledge_graph.establish_graph_topology import establish_graph_topology
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.pipelines import run_tasks, run_tasks_parallel
from cognee.modules.tasks import create_task_status_table, update_task_status, get_task_status
logger = logging.getLogger("cognify.v2")
update_status_lock = asyncio.Lock()
async def cognify(datasets: Union[str, list[str]] = None, root_node_id: str = None):
    """Run the v2 cognify pipeline over the given datasets.

    Parameters:
        datasets: a dataset name or a list of dataset names; when None or
            empty, every dataset known to the relational engine is processed.
        root_node_id: optional id of the graph root node documents attach to.

    Returns:
        The gathered results of each per-dataset pipeline run.
    """
    relational_config = get_relationaldb_config()
    db_engine = relational_config.database_engine

    create_task_status_table()

    if datasets is None or len(datasets) == 0:
        # Nothing specific requested: process every known dataset.
        return await cognify(db_engine.get_datasets())

    async def run_cognify_pipeline(dataset_name: str, files: list[dict]):
        # Serialize the status check-and-set so two concurrent calls cannot
        # both start processing the same dataset.
        async with update_status_lock:
            task_status = get_task_status([dataset_name])

            if dataset_name in task_status and task_status[dataset_name] == "DATASET_PROCESSING_STARTED":
                logger.info(f"Dataset {dataset_name} is being processed.")
                return

            update_task_status(dataset_name, "DATASET_PROCESSING_STARTED")

        try:
            cognee_config = get_cognify_config()
            graph_config = get_graph_config()

            # NOTE(review): this rebinding shadows the outer `root_node_id`
            # parameter, so a caller-supplied root node id is ignored here —
            # confirm whether the parameter should seed this value instead.
            root_node_id = None

            if graph_config.infer_graph_topology and graph_config.graph_topology_task:
                from cognee.modules.topology.topology import TopologyEngine
                topology_engine = TopologyEngine(infer = graph_config.infer_graph_topology)
                root_node_id = await topology_engine.add_graph_topology(files = files)
            elif graph_config.infer_graph_topology and not graph_config.graph_topology_task:
                # Bug fix: this condition previously read
                # `infer_graph_topology and not infer_graph_topology`, which is
                # always False and made the file-based topology branch dead code.
                from cognee.modules.topology.topology import TopologyEngine
                topology_engine = TopologyEngine(infer = graph_config.infer_graph_topology)
                await topology_engine.add_graph_topology(graph_config.topology_file_path)
            elif not graph_config.graph_topology_task:
                root_node_id = "ROOT"

            tasks = [
                # Classify documents, save them as nodes in the graph db, and
                # extract text chunks based on the document type.
                Task(process_documents, parent_node_id = root_node_id, task_config = { "batch_size": 10 }),
                # Set the graph topology for the document chunk data.
                Task(establish_graph_topology, topology_model = KnowledgeGraph),
                # Generate knowledge graphs from the document chunks and attach
                # them to the chunk nodes.
                Task(expand_knowledge_graph, graph_model = KnowledgeGraph),
                # Find all affected chunks, so unchanged chunks are not reprocessed.
                Task(filter_affected_chunks, collection_name = "chunks"),
                # Save the document chunks in the vector db and as nodes in the
                # graph db (connected to the document node and between each other).
                Task(save_data_chunks, collection_name = "chunks"),
                run_tasks_parallel([
                    # Summarize the document chunks.
                    Task(
                        summarize_text_chunks,
                        summarization_model = cognee_config.summarization_model,
                        collection_name = "chunk_summaries",
                    ),
                    Task(
                        classify_text_chunks,
                        classification_model = cognee_config.classification_model,
                    ),
                ]),
                Task(remove_obsolete_chunks),  # Remove the obsolete document chunks.
            ]

            def make_document(file: dict):
                # Wrap each file's metadata in the document type matching its extension.
                title = f"{file['name']}.{file['extension']}"
                if file["extension"] == "pdf":
                    return PdfDocument(title = title, file_path = file["file_path"])
                if file["extension"] == "audio":
                    return AudioDocument(title = title, file_path = file["file_path"])
                if file["extension"] == "image":
                    return ImageDocument(title = title, file_path = file["file_path"])
                return TextDocument(title = title, file_path = file["file_path"])

            pipeline = run_tasks(tasks, [make_document(file) for file in files])

            async for result in pipeline:
                print(result)

            update_task_status(dataset_name, "DATASET_PROCESSING_FINISHED")
        except Exception as error:
            update_task_status(dataset_name, "DATASET_PROCESSING_ERROR")
            raise error

    existing_datasets = db_engine.get_datasets()
    awaitables = []

    # NOTE(review): per the type hint `datasets` may be a plain string;
    # iterating a string yields characters — confirm single-name callers
    # always pass a list.
    for dataset in datasets:
        if dataset in existing_datasets:
            awaitables.append(run_cognify_pipeline(dataset, db_engine.get_files_metadata(dataset)))

    return await asyncio.gather(*awaitables)
#
# if __name__ == "__main__":
# from cognee.api.v1.add import add
# from cognee.api.v1.datasets.datasets import datasets
#
#
# async def aa():
# await add("TEXT ABOUT NLP AND MONKEYS")
#
# print(datasets.discover_datasets())
#
# return
# asyncio.run(cognify())

View file

@ -16,6 +16,9 @@ class config():
relational_config.db_path = databases_directory_path
relational_config.create_engine()
graph_config = get_graph_config()
graph_config.graph_file_path = os.path.join(databases_directory_path, "cognee.graph")
vector_config = get_vectordb_config()
if vector_config.vector_engine_provider == "lancedb":
vector_config.vector_db_url = os.path.join(databases_directory_path, "cognee.lancedb")

View file

@ -1,17 +1,13 @@
from cognee.modules.data.deletion import prune_system
from cognee.base_config import get_base_config
from cognee.infrastructure.files.storage import LocalStorage
from cognee.modules.data.deletion import prune_system, prune_data
class prune():
@staticmethod
async def prune_data():
base_config = get_base_config()
data_root_directory = base_config.data_root_directory
LocalStorage.remove_all(data_root_directory)
await prune_data()
@staticmethod
async def prune_system(graph = True, vector = True):
await prune_system(graph, vector)
async def prune_system(graph = True, vector = True, metadata = False):
await prune_system(graph, vector, metadata)
if __name__ == "__main__":
import asyncio

View file

@ -6,19 +6,16 @@ from pydantic import BaseModel, field_validator
from cognee.modules.search.graph import search_cypher
from cognee.modules.search.graph.search_adjacent import search_adjacent
from cognee.modules.search.vector.search_similarity import search_similarity
from cognee.modules.search.graph.search_categories import search_categories
from cognee.modules.search.graph.search_neighbour import search_neighbour
from cognee.modules.search.vector.search_traverse import search_traverse
from cognee.modules.search.graph.search_summary import search_summary
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
from cognee.modules.search.graph.search_similarity import search_similarity
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from cognee.shared.utils import send_telemetry
from cognee.infrastructure.databases.graph.config import get_graph_config
class SearchType(Enum):
ADJACENT = "ADJACENT"
TRAVERSE = "TRAVERSE"
SIMILARITY = "SIMILARITY"
CATEGORIES = "CATEGORIES"
NEIGHBOR = "NEIGHBOR"
SUMMARY = "SUMMARY"
SUMMARY_CLASSIFICATION = "SUMMARY_CLASSIFICATION"
NODE_CLASSIFICATION = "NODE_CLASSIFICATION"
@ -49,18 +46,15 @@ async def search(search_type: str, params: Dict[str, Any]) -> List:
async def specific_search(query_params: List[SearchParameters]) -> List:
graph_config = get_graph_config()
graph_client = await get_graph_client(graph_config.graph_database_provider)
graph_client = await get_graph_engine()
graph = graph_client.graph
search_functions: Dict[SearchType, Callable] = {
SearchType.ADJACENT: search_adjacent,
SearchType.SIMILARITY: search_similarity,
SearchType.CATEGORIES: search_categories,
SearchType.NEIGHBOR: search_neighbour,
SearchType.SUMMARY: search_summary,
SearchType.CYPHER: search_cypher
SearchType.CYPHER: search_cypher,
SearchType.TRAVERSE: search_traverse,
SearchType.SIMILARITY: search_similarity,
}
results = []
@ -103,7 +97,7 @@ if __name__ == "__main__":
# SearchType.SIMILARITY: {'query': 'your search query here'}
# }
# async def main():
# graph_client = get_graph_client(GraphDBType.NETWORKX)
# graph_client = get_graph_engine()
# await graph_client.load_graph_from_file()
# graph = graph_client.graph

View file

@ -5,7 +5,7 @@ from cognee.root_dir import get_absolute_path
from cognee.shared.data_models import MonitoringTool
class BaseConfig(BaseSettings):
data_root_directory: str = get_absolute_path(".data")
data_root_directory: str = get_absolute_path(".data_storage")
monitoring_tool: object = MonitoringTool.LANGFUSE
graphistry_username: Optional[str] = None
graphistry_password: Optional[str] = None

View file

@ -16,17 +16,15 @@ def create_chunking_engine(config: ChunkingConfig):
chunk_size=config["chunk_size"],
chunk_overlap=config["chunk_overlap"],
chunk_strategy=config["chunk_strategy"],
)
elif config["chunk_engine"] == ChunkEngine.DEFAULT_ENGINE:
from cognee.infrastructure.data.chunking.DefaultChunkEngine import DefaultChunkEngine
from cognee.infrastructure.data.chunking.DefaultChunkEngine import DefaultChunkEngine
return DefaultChunkEngine(
chunk_size=config["chunk_size"],
chunk_overlap=config["chunk_overlap"],
chunk_strategy=config["chunk_strategy"],
)
return DefaultChunkEngine(
chunk_size=config["chunk_size"],
chunk_overlap=config["chunk_overlap"],
chunk_strategy=config["chunk_strategy"],
)
elif config["chunk_engine"] == ChunkEngine.HAYSTACK_ENGINE:
from cognee.infrastructure.data.chunking.HaystackChunkEngine import HaystackChunkEngine

View file

@ -3,4 +3,4 @@ from .config import get_chunk_config
from .create_chunking_engine import create_chunking_engine
def get_chunk_engine():
return create_chunking_engine(get_chunk_config().to_dict())
return create_chunking_engine(get_chunk_config().to_dict())

View file

@ -1 +1,2 @@
from .config import get_graph_config
from .get_graph_engine import get_graph_engine

View file

@ -21,7 +21,7 @@ class GraphConfig(BaseSettings):
graph_model: object = KnowledgeGraph
graph_topology_task: bool = False
graph_topology: object = KnowledgeGraph
infer_graph_topology: bool = True
infer_graph_topology: bool = False
topology_file_path: str = os.path.join(
os.path.join(get_absolute_path(".cognee_system"), "databases"),
"graph_topology.json"

View file

@ -70,8 +70,6 @@ class FalcorDBAdapter(GraphDBInterface):
return await self.query(query, params)
async def add_nodes(self, nodes: list[tuple[str, dict[str, Any]]]) -> None:
# nodes_data = []
for node in nodes:
node_id, node_properties = node
node_id = node_id.replace(":", "_")
@ -112,18 +110,27 @@ class FalcorDBAdapter(GraphDBInterface):
query = """MATCH (node) WHERE node.layer_id IS NOT NULL
RETURN node"""
return [result['node'] for result in (await self.query(query))]
return [result["node"] for result in (await self.query(query))]
async def extract_node(self, node_id: str):
query= """
MATCH(node {id: $node_id})
RETURN node
"""
results = [node['node'] for node in (await self.query(query, dict(node_id = node_id)))]
results = self.extract_nodes([node_id])
return results[0] if len(results) > 0 else None
async def extract_nodes(self, node_ids: List[str]):
query = """
UNWIND $node_ids AS id
MATCH (node {id: id})
RETURN node"""
params = {
"node_ids": node_ids
}
results = await self.query(query, params)
return results
async def delete_node(self, node_id: str):
node_id = id.replace(":", "_")

View file

@ -1,12 +1,11 @@
"""Factory function to get the appropriate graph client based on the graph type."""
from cognee.shared.data_models import GraphDBType
from .config import get_graph_config
from .graph_db_interface import GraphDBInterface
from .networkx.adapter import NetworkXAdapter
async def get_graph_client(graph_type: GraphDBType=None, graph_file_name: str = None) -> GraphDBInterface :
async def get_graph_engine() -> GraphDBInterface :
"""Factory function to get the appropriate graph client based on the graph type."""
config = get_graph_config()
@ -34,8 +33,10 @@ async def get_graph_client(graph_type: GraphDBType=None, graph_file_name: str =
)
except:
pass
graph_client = NetworkXAdapter(filename = config.graph_file_path)
if (graph_client.graph is None):
if graph_client.graph is None:
await graph_client.load_graph_from_file()
return graph_client

View file

@ -25,12 +25,24 @@ class GraphDBInterface(Protocol):
node_id: str
): raise NotImplementedError
@abstractmethod
async def delete_nodes(
self,
node_ids: list[str]
): raise NotImplementedError
@abstractmethod
async def extract_node(
self,
node_id: str
): raise NotImplementedError
@abstractmethod
async def extract_nodes(
self,
node_ids: list[str]
): raise NotImplementedError
@abstractmethod
async def add_edge(
self,

View file

@ -1,6 +1,7 @@
""" Neo4j Adapter for Graph Database"""
import json
import logging
import asyncio
from typing import Optional, Any, List, Dict
from contextlib import asynccontextmanager
from neo4j import AsyncSession
@ -56,6 +57,7 @@ class Neo4jAdapter(GraphDBInterface):
if "name" not in serialized_properties:
serialized_properties["name"] = node_id
query = f"""MERGE (node:`{node_id}` {{id: $node_id}})
ON CREATE SET node += $properties
RETURN ID(node) AS internal_id, node.id AS nodeId"""
@ -68,16 +70,22 @@ class Neo4jAdapter(GraphDBInterface):
return await self.query(query, params)
async def add_nodes(self, nodes: list[tuple[str, dict[str, Any]]]) -> None:
# nodes_data = []
query = """
UNWIND $nodes AS node
MERGE (n {id: node.node_id})
ON CREATE SET n += node.properties
WITH n, node.node_id AS label
CALL apoc.create.addLabels(n, [label]) YIELD node AS labeledNode
RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId
"""
for node in nodes:
node_id, node_properties = node
node_id = node_id.replace(":", "_")
nodes = [{
"node_id": node_id,
"properties": self.serialize_properties(node_properties),
} for (node_id, node_properties) in nodes]
await self.add_node(
node_id = node_id,
node_properties = node_properties,
)
results = await self.query(query, dict(nodes = nodes))
return results
async def extract_node_description(self, node_id: str):
query = """MATCH (n)-[r]->(m)
@ -111,15 +119,24 @@ class Neo4jAdapter(GraphDBInterface):
return [result["node"] for result in (await self.query(query))]
async def extract_node(self, node_id: str):
query= """
MATCH(node {id: $node_id})
RETURN node
"""
results = [node["node"] for node in (await self.query(query, dict(node_id = node_id)))]
results = await self.extract_nodes([node_id])
return results[0] if len(results) > 0 else None
async def extract_nodes(self, node_ids: List[str]):
    """Fetch every node whose `id` property appears in `node_ids`."""
    # One round-trip: UNWIND expands the id list server-side.
    query = """
    UNWIND $node_ids AS id
    MATCH (node {id: id})
    RETURN node"""
    return await self.query(query, dict(node_ids = node_ids))
async def delete_node(self, node_id: str):
node_id = id.replace(":", "_")
@ -128,6 +145,18 @@ class Neo4jAdapter(GraphDBInterface):
return await self.query(query, params)
async def delete_nodes(self, node_ids: list[str]) -> None:
    """Detach-delete every node whose `id` property appears in `node_ids`."""
    # DETACH DELETE removes each node together with all of its relationships.
    query = """
    UNWIND $node_ids AS id
    MATCH (node {id: id})
    DETACH DELETE node"""
    return await self.query(query, dict(node_ids = node_ids))
async def add_edge(self, from_node: str, to_node: str, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}):
serialized_properties = self.serialize_properties(edge_properties)
from_node = from_node.replace(":", "_")
@ -150,19 +179,75 @@ class Neo4jAdapter(GraphDBInterface):
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
# edges_data = []
query = """
UNWIND $edges AS edge
MATCH (from_node {id: edge.from_node})
MATCH (to_node {id: edge.to_node})
CALL apoc.create.relationship(from_node, edge.relationship_name, edge.properties, to_node) YIELD rel
RETURN rel
"""
for edge in edges:
from_node, to_node, relationship_name, edge_properties = edge
from_node = from_node.replace(":", "_")
to_node = to_node.replace(":", "_")
edges = [{
"from_node": edge[0],
"to_node": edge[1],
"relationship_name": edge[2],
"properties": {
**(edge[3] if edge[3] else {}),
"source_node_id": edge[0],
"target_node_id": edge[1],
},
} for edge in edges]
await self.add_edge(
from_node = from_node,
to_node = to_node,
relationship_name = relationship_name,
edge_properties = edge_properties
)
results = await self.query(query, dict(edges = edges))
return results
# Return all edges touching `node_id` (either direction) as
# (source_id, target_id, {"relationship_name": type}) tuples.
async def get_edges(self, node_id: str):
query = """
MATCH (n {id: $node_id})-[r]-(m)
RETURN n, r, m
"""
results = await self.query(query, dict(node_id = node_id))
# NOTE(review): result["r"][1] assumes the driver exposes a relationship as a
# (start_node, type, end_node) triple — confirm against the neo4j driver
# version in use before relying on this shape.
return [(result["n"]["id"], result["m"]["id"], {"relationship_name": result["r"][1]}) for result in results]
async def get_disconnected_nodes(self) -> list[str]:
# return await self.query(
# "MATCH (node) WHERE NOT (node)<-[:*]-() RETURN node.id as id",
# )
query = """
// Step 1: Collect all nodes
MATCH (n)
WITH COLLECT(n) AS nodes
// Step 2: Find all connected components
WITH nodes
CALL {
WITH nodes
UNWIND nodes AS startNode
MATCH path = (startNode)-[*]-(connectedNode)
WITH startNode, COLLECT(DISTINCT connectedNode) AS component
RETURN component
}
// Step 3: Aggregate components
WITH COLLECT(component) AS components
// Step 4: Identify the largest connected component
UNWIND components AS component
WITH component
ORDER BY SIZE(component) DESC
LIMIT 1
WITH component AS largestComponent
// Step 5: Find nodes not in the largest connected component
MATCH (n)
WHERE NOT n IN largestComponent
RETURN COLLECT(ID(n)) AS ids
"""
results = await self.query(query)
return results[0]["ids"] if len(results) > 0 else []
async def filter_nodes(self, search_criteria):
@ -170,10 +255,99 @@ class Neo4jAdapter(GraphDBInterface):
WHERE node.id CONTAINS '{search_criteria}'
RETURN node"""
return await self.query(query)
async def get_predecessor_ids(self, node_id: str, edge_label: str = None) -> list[str]:
    """Return ids of nodes reached from `node_id` over outgoing edges.

    Parameters:
        node_id: label of the source node (this adapter labels nodes by id).
        edge_label: optional relationship type to restrict the traversal.

    NOTE(review): the MATCH follows outgoing edges, so these are the node's
    targets; confirm the predecessor/successor naming convention with callers.
    """
    # Bug fix: the query was a plain (non-f) string, so the literal text
    # "{node_id}" / "{edge_label}" was sent to Neo4j and the params dict was
    # silently unused — labels and relationship types cannot be query
    # parameters, so interpolate them with an f-string instead.
    if edge_label is not None:
        query = f"""
        MATCH (node:`{node_id}`)-[r:`{edge_label}`]->(predecessor)
        RETURN predecessor.id AS id
        """
    else:
        query = f"""
        MATCH (node:`{node_id}`)-[r]->(predecessor)
        RETURN predecessor.id AS id
        """
    results = await self.query(query)
    return [result["id"] for result in results]
async def get_successor_ids(self, node_id: str, edge_label: str = None) -> list[str]:
if edge_label is not None:
query = """
MATCH (node:`{node_id}`)<-[r:`{edge_label}`]-(successor)
RETURN successor.id AS id
"""
results = await self.query(
query,
dict(
node_id = node_id,
edge_label = edge_label,
),
)
return [result["id"] for result in results]
else:
query = """
MATCH (node:`{node_id}`)<-[r]-(successor)
RETURN successor.id AS id
"""
results = await self.query(
query,
dict(
node_id = node_id,
)
)
return [result["id"] for result in results]
async def get_neighbours(self, node_id: str) -> list[str]:
    """Return ids of all nodes adjacent to `node_id`, in both directions."""
    # Fetch both directions concurrently, then concatenate in a fixed order.
    predecessor_ids, successor_ids = await asyncio.gather(
        self.get_predecessor_ids(node_id),
        self.get_successor_ids(node_id),
    )
    return [*predecessor_ids, *successor_ids]
async def remove_connection_to_predecessors_of(self, node_ids: list[str], edge_label: str) -> None:
    """Delete all outgoing `edge_label` edges of each node in `node_ids`."""
    # Bug fix: the f-string previously interpolated the Python builtin `id`
    # (rendering "<built-in function id>" into the query) instead of matching
    # nodes by their id property via the UNWIND-bound variable, as the other
    # batch queries in this adapter do. Relationship types cannot be query
    # parameters, so `edge_label` stays interpolated.
    query = f"""
    UNWIND $node_ids AS id
    MATCH (node {{id: id}})-[r:{edge_label}]->(predecessor)
    DELETE r;
    """
    params = { "node_ids": node_ids }
    return await self.query(query, params)
async def remove_connection_to_successors_of(self, node_ids: list[str], edge_label: str) -> None:
    """Delete all incoming `edge_label` edges of each node in `node_ids`."""
    # Bug fix: the f-string previously interpolated the Python builtin `id`
    # (rendering "<built-in function id>" into the query) instead of matching
    # nodes by their id property via the UNWIND-bound variable, as the other
    # batch queries in this adapter do. Relationship types cannot be query
    # parameters, so `edge_label` stays interpolated.
    query = f"""
    UNWIND $node_ids AS id
    MATCH (node {{id: id}})<-[r:{edge_label}]-(successor)
    DELETE r;
    """
    params = { "node_ids": node_ids }
    return await self.query(query, params)
async def delete_graph(self):
query = """MATCH (node)
DETACH DELETE node;"""
@ -186,3 +360,25 @@ class Neo4jAdapter(GraphDBInterface):
if isinstance(property_value, (dict, list))
else property_value for property_key, property_value in properties.items()
}
# Fetch the whole graph as a (nodes, edges) pair:
#   nodes — list of (node_id, properties) tuples
#   edges — list of (source_id, target_id, type, properties) tuples
async def get_graph_data(self):
query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties"
result = await self.query(query)
# Node identity comes from the stored `id` property, not Neo4j's internal ID.
nodes = [(
record["properties"]["id"],
record["properties"],
) for record in result]
query = """
MATCH (n)-[r]->(m)
RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
"""
result = await self.query(query)
# NOTE(review): relies on every relationship carrying source_node_id /
# target_node_id properties (set by add_edges) — edges created any other way
# would raise KeyError here; confirm no other write path exists.
edges = [(
record["properties"]["source_node_id"],
record["properties"]["target_node_id"],
record["type"],
record["properties"],
) for record in result]
return (nodes, edges)

View file

@ -44,7 +44,7 @@ class NetworkXAdapter(GraphDBInterface):
async def get_graph(self):
    """Return the in-memory networkx graph instance backing this adapter."""
    return self.graph
async def add_edge(
self,
from_node: str,
@ -62,12 +62,31 @@ class NetworkXAdapter(GraphDBInterface):
self.graph.add_edges_from(edges)
await self.save_graph_to_file(self.filename)
async def get_edges(self, node_id: str):
    """Return all edges touching `node_id` — incoming first, then outgoing — with edge data."""
    incoming = self.graph.in_edges(node_id, data = True)
    outgoing = self.graph.out_edges(node_id, data = True)
    return [*incoming, *outgoing]
async def delete_node(self, node_id: str) -> None:
    """Asynchronously delete a node from the graph if it exists.

    Bug fix: the original tested and removed the Python builtin `id` instead
    of the `node_id` argument, so `has_node(id)` was always False and no node
    was ever deleted.
    """
    if self.graph.has_node(node_id):
        self.graph.remove_node(node_id)
        await self.save_graph_to_file(self.filename)
async def delete_nodes(self, node_ids: List[str]) -> None:
    """Remove all of `node_ids` (and their incident edges) and persist the graph."""
    self.graph.remove_nodes_from(node_ids)
    await self.save_graph_to_file(self.filename)
async def get_disconnected_nodes(self) -> List[str]:
    """Return ids of every node outside the largest weakly connected component.

    Fix: guard against an empty graph — `max()` on an empty component list
    raised ValueError in the original.
    """
    connected_components = list(nx.weakly_connected_components(self.graph))

    if not connected_components:
        return []

    disconnected_nodes = []
    biggest_subgraph = max(connected_components, key = len)

    for component in connected_components:
        # Components are disjoint node sets, so equality only matches the
        # component chosen as the biggest one.
        if component != biggest_subgraph:
            disconnected_nodes.extend(list(component))

    return disconnected_nodes
async def extract_node_description(self, node_id: str) -> Dict[str, Any]:
descriptions = []
@ -98,11 +117,69 @@ class NetworkXAdapter(GraphDBInterface):
async def extract_node(self, node_id: str) -> dict:
    """Return the attribute dict of `node_id`, or None when the node is absent."""
    node_exists = self.graph.has_node(node_id)
    return self.graph.nodes[node_id] if node_exists else None
async def extract_nodes(self, node_ids: List[str]) -> List[dict]:
    """Return attribute dicts for each existing id in `node_ids`; missing ids are skipped."""
    found_nodes = []
    for node_id in node_ids:
        if self.graph.has_node(node_id):
            found_nodes.append(self.graph.nodes[node_id])
    return found_nodes
async def get_predecessor_ids(self, node_id: str, edge_label: str = None) -> list:
    """Return ids of direct predecessors of `node_id`.

    When `edge_label` is given, only predecessors connected through an edge
    with that key count. Fix: returns an empty list for an unknown node
    instead of the original implicit None, which crashed iterating callers.

    NOTE(review): the third argument to has_edge is a Multi(Di)Graph edge
    key — assumes self.graph is a MultiDiGraph; confirm at construction.
    """
    if not self.graph.has_node(node_id):
        return []

    if edge_label is None:
        return list(self.graph.predecessors(node_id))

    return [
        predecessor_id
        for predecessor_id in self.graph.predecessors(node_id)
        if self.graph.has_edge(predecessor_id, node_id, edge_label)
    ]
async def get_successor_ids(self, node_id: str, edge_label: str = None) -> list:
    """Return ids of direct successors of `node_id`.

    When `edge_label` is given, only successors connected through an edge
    with that key count. Fix: returns an empty list for an unknown node
    instead of the original implicit None, which crashed iterating callers.

    NOTE(review): the third argument to has_edge is a Multi(Di)Graph edge
    key — assumes self.graph is a MultiDiGraph; confirm at construction.
    """
    if not self.graph.has_node(node_id):
        return []

    if edge_label is None:
        return list(self.graph.successors(node_id))

    return [
        successor_id
        for successor_id in self.graph.successors(node_id)
        if self.graph.has_edge(node_id, successor_id, edge_label)
    ]
async def get_neighbours(self, node_id: str) -> list:
    """Return the attribute dicts of `node_id`'s direct neighbours (empty list when none)."""
    if not self.graph.has_node(node_id):
        return []

    neighbour_ids = list(self.graph.neighbors(node_id))
    if not neighbour_ids:
        return []

    return await self.extract_nodes(neighbour_ids)
async def remove_connection_to_predecessors_of(self, node_ids: list[str], edge_label: str) -> None:
    """Remove every `edge_label` edge pointing from a predecessor into any node in `node_ids`, then persist."""
    for target_id in node_ids:
        if not self.graph.has_node(target_id):
            continue
        # Materialize the predecessor list so removal doesn't disturb iteration.
        for source_id in list(self.graph.predecessors(target_id)):
            if self.graph.has_edge(source_id, target_id, edge_label):
                self.graph.remove_edge(source_id, target_id, edge_label)

    await self.save_graph_to_file(self.filename)
async def remove_connection_to_successors_of(self, node_ids: list[str], edge_label: str) -> None:
    """Remove every `edge_label` edge pointing from any node in `node_ids` to a successor, then persist."""
    for source_id in node_ids:
        if not self.graph.has_node(source_id):
            continue
        # Materialize the successor list so removal doesn't disturb iteration.
        for target_id in list(self.graph.successors(source_id)):
            if self.graph.has_edge(source_id, target_id, edge_label):
                self.graph.remove_edge(source_id, target_id, edge_label)

    await self.save_graph_to_file(self.filename)
async def save_graph_to_file(self, file_path: str=None) -> None:
"""Asynchronously save the graph to a file in JSON format."""
@ -114,6 +191,7 @@ class NetworkXAdapter(GraphDBInterface):
async with aiofiles.open(file_path, "w") as file:
await file.write(json.dumps(graph_data))
async def load_graph_from_file(self, file_path: str = None):
"""Asynchronously load the graph from a file in JSON format."""
if file_path == self.filename:

View file

@ -2,10 +2,9 @@ import duckdb
import os
class DuckDBAdapter():
def __init__(self, db_path: str, db_name: str):
self.db_location = os.path.abspath(os.path.join(db_path, db_name))
db_location = os.path.abspath(os.path.join(db_path, db_name))
self.get_connection = lambda: duckdb.connect(db_location)
self.get_connection = lambda: duckdb.connect(self.db_location)
def get_datasets(self):
with self.get_connection() as connection:
@ -20,7 +19,7 @@ class DuckDBAdapter():
def get_files_metadata(self, dataset_name: str):
with self.get_connection() as connection:
return connection.sql(f"SELECT id, name, file_path, extension, mime_type, keywords FROM {dataset_name}.file_metadata;").to_df().to_dict("records")
return connection.sql(f"SELECT id, name, file_path, extension, mime_type FROM {dataset_name}.file_metadata;").to_df().to_dict("records")
def create_table(self, schema_name: str, table_name: str, table_config: list[dict]):
fields_query_parts = []
@ -163,3 +162,11 @@ class DuckDBAdapter():
connection.sql(select_data_sql)
drop_data_sql = "DROP TABLE cognify;"
connection.sql(drop_data_sql)
def delete_database(self):
    """Remove the DuckDB database file (and its WAL sidecar, if present) from disk."""
    from cognee.infrastructure.files.storage import LocalStorage

    LocalStorage.remove(self.db_location)

    # DuckDB keeps a write-ahead log alongside the database file.
    if LocalStorage.file_exists(self.db_location + ".wal"):
        LocalStorage.remove(self.db_location + ".wal")

View file

@ -40,7 +40,7 @@ class FalcorDBAdapter(VectorDBInterface):
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
pass
async def retrieve(self, collection_name: str, data_point_id: str):
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
pass
async def search(
@ -51,4 +51,7 @@ class FalcorDBAdapter(VectorDBInterface):
limit: int = 10,
with_vector: bool = False,
):
pass
pass
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
pass

View file

@ -33,7 +33,7 @@ class LanceDBAdapter(VectorDBInterface):
async def embed_data(self, data: list[str]) -> list[list[float]]:
return await self.embedding_engine.embed_text(data)
async def collection_exists(self, collection_name: str) -> bool:
async def has_collection(self, collection_name: str) -> bool:
connection = await self.get_connection()
collection_names = await connection.table_names()
return collection_name in collection_names
@ -47,7 +47,7 @@ class LanceDBAdapter(VectorDBInterface):
vector: Vector(vector_size)
payload: payload_schema
if not await self.collection_exists(collection_name):
if not await self.has_collection(collection_name):
connection = await self.get_connection()
return await connection.create_table(
name = collection_name,
@ -58,7 +58,7 @@ class LanceDBAdapter(VectorDBInterface):
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
connection = await self.get_connection()
if not await self.collection_exists(collection_name):
if not await self.has_collection(collection_name):
await self.create_collection(
collection_name,
payload_schema = type(data_points[0].payload),
@ -89,17 +89,20 @@ class LanceDBAdapter(VectorDBInterface):
await collection.add(lance_data_points)
async def retrieve(self, collection_name: str, data_point_id: str):
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
connection = await self.get_connection()
collection = await connection.open_table(collection_name)
results = await collection.query().where(f"id = '{data_point_id}'").to_pandas()
result = results.to_dict("index")[0]
return ScoredResult(
if len(data_point_ids) == 1:
results = await collection.query().where(f"id = '{data_point_ids[0]}'").to_pandas()
else:
results = await collection.query().where(f"id IN {tuple(data_point_ids)}").to_pandas()
return [ScoredResult(
id = result["id"],
payload = result["payload"],
score = 1,
)
) for result in results.to_dict("index").values()]
async def search(
self,
@ -122,8 +125,8 @@ class LanceDBAdapter(VectorDBInterface):
return [ScoredResult(
id = str(result["id"]),
score = float(result["_distance"]),
payload = result["payload"],
score = float(result["_distance"]),
) for result in results.to_dict("index").values()]
async def batch_search(
@ -144,6 +147,12 @@ class LanceDBAdapter(VectorDBInterface):
) for query_vector in query_vectors]
)
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
    """Delete the rows whose `id` is in `data_point_ids` from the LanceDB table.

    Bug fix: `tuple(ids)` with exactly one element renders as ("x",) — the
    trailing comma is not valid SQL for the delete predicate. Mirror the
    single-id branch already used by `retrieve`.
    """
    connection = await self.get_connection()
    collection = await connection.open_table(collection_name)

    if len(data_point_ids) == 1:
        results = await collection.delete(f"id = '{data_point_ids[0]}'")
    else:
        results = await collection.delete(f"id IN {tuple(data_point_ids)}")

    return results
async def prune(self):
# Clean up the database if it was set up as temporary
if self.url.startswith("/"):

View file

@ -59,7 +59,7 @@ class QDrantAdapter(VectorDBInterface):
async def embed_data(self, data: List[str]) -> List[float]:
return await self.embedding_engine.embed_text(data)
async def collection_exists(self, collection_name: str) -> bool:
async def has_collection(self, collection_name: str) -> bool:
client = self.get_qdrant_client()
result = await client.collection_exists(collection_name)
await client.close()
@ -111,11 +111,11 @@ class QDrantAdapter(VectorDBInterface):
return result
async def retrieve(self, collection_name: str, data_point_id: str):
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
client = self.get_qdrant_client()
results = await client.retrieve(collection_name, [data_point_id], with_payload = True)
results = await client.retrieve(collection_name, data_point_ids, with_payload = True)
await client.close()
return results[0] if len(results) > 0 else None
return results
async def search(
self,
@ -185,6 +185,11 @@ class QDrantAdapter(VectorDBInterface):
return [filter(lambda result: result.score > 0.9, result_group) for result_group in results]
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
    """Delete the given points from a Qdrant collection.

    NOTE(review): qdrant-client's `delete` takes a points selector as its
    second argument; passing a plain id list relies on client-side coercion —
    confirm against the installed client version. Also note the client is not
    closed here, unlike the other methods of this adapter — verify whether
    that is intentional.
    """
    client = self.get_qdrant_client()
    results = await client.delete(collection_name, data_point_ids)
    return results
async def prune(self):
client = self.get_qdrant_client()

View file

@ -6,7 +6,7 @@ from .models.PayloadSchema import PayloadSchema
class VectorDBInterface(Protocol):
""" Collections """
@abstractmethod
async def collection_exists(self, collection_name: str) -> bool:
async def has_collection(self, collection_name: str) -> bool:
raise NotImplementedError
@abstractmethod
@ -28,7 +28,7 @@ class VectorDBInterface(Protocol):
async def retrieve(
self,
collection_name: str,
data_point_id: str
data_point_ids: list[str]
): raise NotImplementedError
""" Search """
@ -51,3 +51,13 @@ class VectorDBInterface(Protocol):
limit: int,
with_vectors: bool = False
): raise NotImplementedError
@abstractmethod
async def delete_data_points(
self,
collection_name: str,
data_point_ids: list[str]
): raise NotImplementedError
@abstractmethod
async def prune(self): raise NotImplementedError

View file

@ -1,5 +1,4 @@
import asyncio
from uuid import UUID
from typing import List, Optional
from ..vector_db_interface import VectorDBInterface
from ..models.DataPoint import DataPoint
@ -31,7 +30,7 @@ class WeaviateAdapter(VectorDBInterface):
async def embed_data(self, data: List[str]) -> List[float]:
return await self.embedding_engine.embed_text(data)
async def collection_exists(self, collection_name: str) -> bool:
async def has_collection(self, collection_name: str) -> bool:
future = asyncio.Future()
future.set_result(self.client.collections.exists(collection_name))
@ -72,25 +71,41 @@ class WeaviateAdapter(VectorDBInterface):
list(map(lambda data_point: data_point.get_embeddable_data(), data_points)))
def convert_to_weaviate_data_points(data_point: DataPoint):
vector = data_vectors[data_points.index(data_point)]
return DataObject(
uuid = data_point.id,
properties = data_point.payload.dict(),
vector = data_vectors[data_points.index(data_point)]
vector = vector
)
objects = list(map(convert_to_weaviate_data_points, data_points))
return self.get_collection(collection_name).data.insert_many(objects)
collection = self.get_collection(collection_name)
async def retrieve(self, collection_name: str, data_point_id: str):
with collection.batch.dynamic() as batch:
for data_row in objects:
batch.add_object(
properties = data_row.properties,
vector = data_row.vector
)
return
# return self.get_collection(collection_name).data.insert_many(objects)
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
from weaviate.classes.query import Filter
future = asyncio.Future()
data_point = self.get_collection(collection_name).query.fetch_object_by_id(UUID(data_point_id))
data_points = self.get_collection(collection_name).query.fetch_objects(
filters = Filter.by_id().contains_any(data_point_ids)
)
data_point.payload = data_point.properties
del data_point.properties
for data_point in data_points:
data_point.payload = data_point.properties
del data_point.properties
future.set_result(data_point)
future.set_result(data_points)
return await future
@ -131,6 +146,17 @@ class WeaviateAdapter(VectorDBInterface):
return self.search(collection_name, query_vector=query_vector, limit=limit, with_vector=with_vectors)
return [await query_search(query_vector) for query_vector in await self.embed_data(query_texts)]
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
    """Delete the objects with the given ids from a Weaviate collection."""
    from weaviate.classes.query import Filter

    id_filter = Filter.by_id().contains_any(data_point_ids)
    deletion_result = self.get_collection(collection_name).data.delete_many(filters = id_filter)

    # The client call is synchronous; hand the result back through a resolved
    # future to keep the coroutine shape consistent with the rest of the adapter.
    result_future = asyncio.Future()
    result_future.set_result(deletion_result)

    return await result_future
async def prune(self):
    """Drop every collection in the Weaviate instance (destructive; used for cleanup)."""
    self.client.collections.delete_all()

View file

@ -32,13 +32,19 @@ class LocalStorage(Storage):
f.seek(0)
return f.read()
@staticmethod
def file_exists(file_path: str):
    """Return True if `file_path` exists on the local filesystem."""
    return os.path.exists(file_path)
@staticmethod
def ensure_directory_exists(file_path: str):
    """Create the directory `file_path` (including parents) if it does not exist.

    The original `os.path.exists` pre-check was redundant and racy (TOCTOU):
    `makedirs(..., exist_ok=True)` already handles the existing-directory
    case on its own.
    """
    os.makedirs(file_path, exist_ok = True)
def remove(self, file_path: str):
os.remove(self.storage_path + "/" + file_path)
@staticmethod
def remove(file_path: str):
if os.path.exists(file_path):
os.remove(file_path)
@staticmethod
def copy_file(source_file_path: str, destination_file_path: str):

View file

@ -7,7 +7,8 @@ class Storage(Protocol):
def retrieve(self, file_path: str):
pass
def remove(self, file_path: str):
@staticmethod
def remove(file_path: str):
pass
class StorageManager():

View file

@ -1,6 +1,4 @@
from typing import BinaryIO, TypedDict
from cognee.infrastructure.data.utils.extract_keywords import extract_keywords
from .extract_text_from_file import extract_text_from_file
from .guess_file_type import guess_file_type
@ -8,24 +6,12 @@ class FileMetadata(TypedDict):
name: str
mime_type: str
extension: str
keywords: list[str]
def get_file_metadata(file: BinaryIO) -> FileMetadata:
"""Get metadata from a file"""
file.seek(0)
file_type = guess_file_type(file)
file.seek(0)
file_text = extract_text_from_file(file, file_type)
import uuid
try:
keywords = extract_keywords(file_text)
except:
keywords = ["no keywords detected" + str(uuid.uuid4())]
file_path = file.name
file_name = file_path.split("/")[-1].split(".")[0] if file_path else None
@ -34,5 +20,4 @@ def get_file_metadata(file: BinaryIO) -> FileMetadata:
file_path = file_path,
mime_type = file_type.mime,
extension = file_type.extension,
keywords = keywords
)

View file

@ -19,7 +19,6 @@ class AnthropicAdapter(LLMInterface):
)
self.model = model
@retry(stop = stop_after_attempt(5))
async def acreate_structured_output(
self,
text_input: str,
@ -31,7 +30,7 @@ class AnthropicAdapter(LLMInterface):
return await self.aclient(
model = self.model,
max_tokens = 4096,
max_retries = 0,
max_retries = 5,
messages = [{
"role": "user",
"content": f"""Use the given format to extract information

View file

@ -9,6 +9,7 @@ class LLMConfig(BaseSettings):
llm_api_key: Optional[str] = None
llm_temperature: float = 0.0
llm_streaming: bool = False
transcription_model: str = "whisper-1"
model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
@ -18,6 +19,9 @@ class LLMConfig(BaseSettings):
"model": self.llm_model,
"endpoint": self.llm_endpoint,
"apiKey": self.llm_api_key,
"temperature": self.llm_temperature,
"streaming": self.llm_stream,
"transcriptionModel": self.transcription_model
}
@lru_cache

View file

@ -20,7 +20,7 @@ def get_llm_client():
raise ValueError("LLM API key is not set.")
from .openai.adapter import OpenAIAdapter
return OpenAIAdapter(llm_config.llm_api_key, llm_config.llm_model, llm_config.llm_streaming)
return OpenAIAdapter(api_key=llm_config.llm_api_key, model=llm_config.llm_model, transcription_model=llm_config.transcription_model, streaming=llm_config.llm_streaming)
elif provider == LLMProvider.OLLAMA:
if llm_config.llm_api_key is None:
raise ValueError("LLM API key is not set.")

View file

@ -1,5 +1,10 @@
import asyncio
import base64
import os
from pathlib import Path
from typing import List, Type
import aiofiles
import openai
import instructor
from pydantic import BaseModel
@ -9,6 +14,8 @@ from cognee.base_config import get_base_config
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.shared.data_models import MonitoringTool
import logging
logging.basicConfig(level=logging.DEBUG)
class OpenAIAdapter(LLMInterface):
name = "OpenAI"
@ -16,20 +23,22 @@ class OpenAIAdapter(LLMInterface):
api_key: str
"""Adapter for OpenAI's GPT-3, GPT=4 API"""
def __init__(self, api_key: str, model: str, streaming: bool = False):
def __init__(self, api_key: str, model: str, transcription_model:str, streaming: bool = False):
base_config = get_base_config()
if base_config.monitoring_tool == MonitoringTool.LANGFUSE:
from langfuse.openai import AsyncOpenAI, OpenAI
elif base_config.monitoring_tool == MonitoringTool.LANGSMITH:
from langsmith import wrappers
from openai import AsyncOpenAI
AsyncOpenAI = wrappers.wrap_openai(AsyncOpenAI())
else:
from openai import AsyncOpenAI, OpenAI
# if base_config.monitoring_tool == MonitoringTool.LANGFUSE:
# from langfuse.openai import AsyncOpenAI, OpenAI
# elif base_config.monitoring_tool == MonitoringTool.LANGSMITH:
# from langsmith import wrappers
# from openai import AsyncOpenAI
# AsyncOpenAI = wrappers.wrap_openai(AsyncOpenAI())
# else:
from openai import AsyncOpenAI, OpenAI
self.aclient = instructor.from_openai(AsyncOpenAI(api_key = api_key))
self.client = instructor.from_openai(OpenAI(api_key = api_key))
self.base_openai_client = OpenAI(api_key = api_key)
self.transcription_model = "whisper-1"
self.model = model
self.api_key = api_key
self.streaming = streaming
@ -120,6 +129,49 @@ class OpenAIAdapter(LLMInterface):
response_model = response_model,
)
@retry(stop = stop_after_attempt(5))
def create_transcript(self, input):
    """Generate an audio transcript for the audio file at path `input`.

    Raises FileNotFoundError when the path does not point to a file.

    Bug fix: the original read the entire file into an unused `audio_data`
    variable and then passed a Path object; stream the open binary file
    handle to the transcription API instead (the OpenAI client accepts a
    file-like object).
    """
    if not os.path.isfile(input):
        raise FileNotFoundError(f"The file {input} does not exist.")

    with open(input, "rb") as audio_file:
        transcription = self.base_openai_client.audio.transcriptions.create(
            model = self.transcription_model,
            file = audio_file,
        )

    return transcription
@retry(stop = stop_after_attempt(5))
def transcribe_image(self, input) -> BaseModel:
    """Describe the image at path `input` using the chat completions API.

    The image is sent inline as a base64-encoded JPEG data URL with the
    fixed prompt "Whats in this image?"; the raw completion is returned.
    """
    with open(input, "rb") as image_file:
        encoded_image = base64.b64encode(image_file.read()).decode('utf-8')

    return self.base_openai_client.chat.completions.create(
        model=self.model,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Whats in this image?"},
                    {
                        "type": "image_url",
                        "image_url": {
                            # assumes the input is JPEG-compatible — TODO confirm
                            "url": f"data:image/jpeg;base64,{encoded_image}",
                        },
                    },
                ],
            }
        ],
        max_tokens=300,
    )
def show_prompt(self, text_input: str, system_prompt: str) -> str:
"""Format and display the prompt for a user query."""
if not text_input:

View file

@ -1,2 +1,2 @@
Chose the summary that is the most relevant to the query`{{ query }}`
Here are the summaries:`{{ summaries }}`
Choose the summaries that are relevant to the following query: `{{ query }}`
Here are all the summaries: `{{ summaries }}`

View file

@ -1,36 +1,26 @@
You are a top-tier algorithm
designed for extracting information in structured formats to build a knowledge graph.
- **Nodes** represent entities and concepts. They're akin to Wikipedia nodes.
- **Edges** represent relationships between concepts. They're akin to Wikipedia links.
- The aim is to achieve simplicity and clarity in the
knowledge graph, making it accessible for a vast audience.
YOU ARE ONLY EXTRACTING DATA FOR COGNITIVE LAYER `{{ layer }}`
## 1. Labeling Nodes
- **Consistency**: Ensure you use basic or elementary types for node labels.
- For example, when you identify an entity representing a person,
always label it as **"Person"**.
Avoid using more specific terms like "mathematician" or "scientist".
- Include event, entity, time, or action nodes to the category.
- Classify the memory type as episodic or semantic.
- **Node IDs**: Never utilize integers as node IDs.
Node IDs should be names or human-readable identifiers found in the text.
## 2. Handling Numerical Data and Dates
- Numerical data, like age or other related information,
should be incorporated as attributes or properties of the respective nodes.
- **No Separate Nodes for Dates/Numbers**:
Do not create separate nodes for dates or numerical values.
Always attach them as attributes or properties of nodes.
- **Property Format**: Properties must be in a key-value format.
- **Quotation Marks**: Never use escaped single or double quotes within property values.
- **Naming Convention**: Use snake_case for relationship names, e.g., `acted_in`.
## 3. Coreference Resolution
- **Maintain Entity Consistency**:
When extracting entities, it's vital to ensure consistency.
If an entity, such as "John Doe", is mentioned multiple times
in the text but is referred to by different names or pronouns (e.g., "Joe", "he"),
always use the most complete identifier for that entity throughout the knowledge graph.
In this example, use "John Doe" as the entity ID.
Remember, the knowledge graph should be coherent and easily understandable,
so maintaining consistency in entity references is crucial.
## 4. Strict Compliance
Adhere to the rules strictly. Non-compliance will result in termination"""
You are a top-tier algorithm designed for extracting information in structured formats to build a knowledge graph.
**Nodes** represent entities and concepts. They're akin to Wikipedia nodes.
**Edges** represent relationships between concepts. They're akin to Wikipedia links.
The aim is to achieve simplicity and clarity in the knowledge graph, making it accessible for a vast audience.
# 1. Labeling Nodes
**Consistency**: Ensure you use basic or elementary types for node labels.
- For example, when you identify an entity representing a person, always label it as **"Person"**.
- Avoid using more specific terms like "Mathematician" or "Scientist".
- Don't use too generic terms like "Entity".
**Node IDs**: Never utilize integers as node IDs.
- Node IDs should be names or human-readable identifiers found in the text.
# 2. Handling Numerical Data and Dates
- For example, when you identify an entity representing a date, always label it as **"Date"**.
- Extract the date in the format "YYYY-MM-DD"
- If not possible to extract the whole date, extract month or year, or both if available.
- **Property Format**: Properties must be in a key-value format.
- **Quotation Marks**: Never use escaped single or double quotes within property values.
- **Naming Convention**: Use snake_case for relationship names, e.g., `acted_in`.
# 3. Coreference Resolution
- **Maintain Entity Consistency**: When extracting entities, it's vital to ensure consistency.
If an entity, such as "John Doe", is mentioned multiple times in the text but is referred to by different names or pronouns (e.g., "Joe", "he"),
always use the most complete identifier for that entity throughout the knowledge graph. In this example, use "John Doe" as the Person's ID.
Remember, the knowledge graph should be coherent and easily understandable, so maintaining consistency in entity references is crucial.
# 4. Strict Compliance
Adhere to the rules strictly. Non-compliance will result in termination"""

View file

@ -1 +1,3 @@
You are a summarization engine and you should sumamarize content. Be brief and concise
You are a top-tier summarization engine. Your task is to summarize text and make it versatile.
Be brief and concise, but keep the important information and the subject.
Use synonym words where possible in order to change the wording but keep the meaning.

View file

@ -0,0 +1,152 @@
import asyncio
from uuid import uuid5, NAMESPACE_OID
from typing import Type
from pydantic import BaseModel
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine, DataPoint
from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk
from ..data.extraction.extract_categories import extract_categories
async def classify_text_chunks(data_chunks: list[DocumentChunk], classification_model: Type[BaseModel]):
    """Classify each chunk with an LLM and persist the labels.

    For every chunk, `extract_categories` yields a classification whose
    type/subclass labels are stored as vector data points (collection
    "classification") and as graph nodes, linked to the chunk via
    "is_media_type" / "is_classified_as" edges and to each other via
    "contains" edges. Returns the chunks unchanged.

    NOTE(review): indentation reconstructed from a diff view — confirm the
    edge-appends sit outside the de-duplication `if` blocks as written here.
    """
    if len(data_chunks) == 0:
        return data_chunks

    # One classification per chunk, requested concurrently.
    chunk_classifications = await asyncio.gather(
        *[extract_categories(chunk.text, classification_model) for chunk in data_chunks],
    )

    # Deterministic ids for every label about to be written.
    classification_data_points = []

    for chunk_index, chunk in enumerate(data_chunks):
        chunk_classification = chunk_classifications[chunk_index]
        classification_data_points.append(uuid5(NAMESPACE_OID, chunk_classification.label.type))
        # NOTE(review): duplicated append — harmless because of the set()
        # below, but likely unintended.
        classification_data_points.append(uuid5(NAMESPACE_OID, chunk_classification.label.type))

        for classification_subclass in chunk_classification.label.subclass:
            classification_data_points.append(uuid5(NAMESPACE_OID, classification_subclass.value))

    vector_engine = get_vector_engine()

    # Payload schema for the "classification" vector collection.
    class Keyword(BaseModel):
        id: str
        text: str
        chunk_id: str
        document_id: str

    collection_name = "classification"

    if await vector_engine.has_collection(collection_name):
        # Fetch already-stored labels so they are not re-inserted.
        existing_data_points = await vector_engine.retrieve(
            collection_name,
            list(set(classification_data_points)),
        ) if len(classification_data_points) > 0 else []

        # NOTE(review): point.id is presumably a string while the uuid5 ids
        # compared against this map below are UUID objects — if so the
        # membership checks never hit; verify the retrieve() id type.
        existing_points_map = {point.id: True for point in existing_data_points}
    else:
        existing_points_map = {}
        await vector_engine.create_collection(collection_name, payload_schema = Keyword)

    data_points = []
    nodes = []
    edges = []

    for (chunk_index, data_chunk) in enumerate(data_chunks):
        chunk_classification = chunk_classifications[chunk_index]
        classification_type_label = chunk_classification.label.type
        classification_type_id = uuid5(NAMESPACE_OID, classification_type_label)

        # Create the type-level data point and node only once.
        if classification_type_id not in existing_points_map:
            data_points.append(
                DataPoint[Keyword](
                    id = str(classification_type_id),
                    payload = Keyword.parse_obj({
                        "id": str(classification_type_id),
                        "text": classification_type_label,
                        "chunk_id": str(data_chunk.chunk_id),
                        "document_id": str(data_chunk.document_id),
                    }),
                    embed_field = "text",
                )
            )

            nodes.append((
                str(classification_type_id),
                dict(
                    id = str(classification_type_id),
                    name = classification_type_label,
                    type = classification_type_label,
                )
            ))
            existing_points_map[classification_type_id] = True

        # Every chunk links to its media type, even a pre-existing one.
        edges.append((
            str(data_chunk.chunk_id),
            str(classification_type_id),
            "is_media_type",
            dict(
                relationship_name = "is_media_type",
                source_node_id = str(data_chunk.chunk_id),
                target_node_id = str(classification_type_id),
            ),
        ))

        for classification_subclass in chunk_classification.label.subclass:
            classification_subtype_label = classification_subclass.value
            classification_subtype_id = uuid5(NAMESPACE_OID, classification_subtype_label)

            # Create the subtype data point/node (and its "contains" edge
            # from the parent type) only once.
            if classification_subtype_id not in existing_points_map:
                data_points.append(
                    DataPoint[Keyword](
                        id = str(classification_subtype_id),
                        payload = Keyword.parse_obj({
                            "id": str(classification_subtype_id),
                            "text": classification_subtype_label,
                            "chunk_id": str(data_chunk.chunk_id),
                            "document_id": str(data_chunk.document_id),
                        }),
                        embed_field = "text",
                    )
                )

                nodes.append((
                    str(classification_subtype_id),
                    dict(
                        id = str(classification_subtype_id),
                        name = classification_subtype_label,
                        type = classification_subtype_label,
                    )
                ))

                edges.append((
                    str(classification_type_id),
                    str(classification_subtype_id),
                    "contains",
                    dict(
                        relationship_name = "contains",
                        source_node_id = str(classification_type_id),
                        target_node_id = str(classification_subtype_id),
                    ),
                ))
                existing_points_map[classification_subtype_id] = True

            # Every chunk links to each of its subtypes.
            edges.append((
                str(data_chunk.chunk_id),
                str(classification_subtype_id),
                "is_classified_as",
                dict(
                    relationship_name = "is_classified_as",
                    source_node_id = str(data_chunk.chunk_id),
                    target_node_id = str(classification_subtype_id),
                ),
            ))

    if len(nodes) > 0 or len(edges) > 0:
        await vector_engine.create_data_points(collection_name, data_points)

        graph_engine = await get_graph_engine()
        await graph_engine.add_nodes(nodes)
        await graph_engine.add_edges(edges)

    return data_chunks

View file

@ -1,6 +1,6 @@
import uuid
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from cognee.shared.data_models import GraphDBType
from cognee.infrastructure.databases.graph.config import get_graph_config
@ -92,7 +92,7 @@ def graph_ready_output(results):
if __name__ == "__main__":
async def main():
graph_client = await get_graph_client()
graph_client = await get_graph_engine()
graph = graph_client.graph
# for nodes, attr in graph.nodes(data=True):

View file

@ -0,0 +1,18 @@
from uuid import UUID
from cognee.shared.data_models import Document
from cognee.infrastructure.databases.graph import get_graph_engine
async def save_document_node(document: Document, parent_node_id: UUID = None):
    """Store `document` as a graph node; optionally link it under `parent_node_id`.

    Bug fix: `get_graph_engine` is awaited everywhere else in the codebase;
    the original bound the coroutine object itself instead of the engine,
    making every subsequent call fail.
    """
    graph_engine = await get_graph_engine()

    await graph_engine.add_node(document.id, document.model_dump())

    if parent_node_id:
        await graph_engine.add_edge(
            parent_node_id,
            document.id,
            "has_document",
            dict(relationship_name = "has_document"),
        )

    return document

View file

@ -0,0 +1 @@
from .save_data_chunks import save_data_chunks

View file

@ -0,0 +1,97 @@
from cognee.infrastructure.databases.vector import DataPoint, get_vector_engine
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk
async def save_data_chunks(data_chunks: list[DocumentChunk], collection_name: str):
    """Persist document chunks to both vector and graph storage.

    Existing versions of the same chunks are deleted from the vector store
    and unlinked in the graph first; then each chunk is embedded and stored,
    and graph nodes plus edges ("has_chunk" from the document, "next_chunk"
    between consecutive chunks) are written. Returns the chunks unchanged.
    """
    if len(data_chunks) == 0:
        return data_chunks

    vector_engine = get_vector_engine()
    graph_engine = await get_graph_engine()

    # Remove and unlink existing chunks
    if await vector_engine.has_collection(collection_name):
        existing_chunks = [DocumentChunk.parse_obj(chunk.payload) for chunk in (await vector_engine.retrieve(
            collection_name,
            [str(chunk.chunk_id) for chunk in data_chunks],
        ))]

        if len(existing_chunks) > 0:
            await vector_engine.delete_data_points(collection_name, [str(chunk.chunk_id) for chunk in existing_chunks])

            # Drop the ordering edge out of each stale chunk and the
            # ownership edge into it, so re-inserted chunks relink cleanly.
            await graph_engine.remove_connection_to_successors_of([chunk.chunk_id for chunk in existing_chunks], "next_chunk")
            await graph_engine.remove_connection_to_predecessors_of([chunk.chunk_id for chunk in existing_chunks], "has_chunk")
    else:
        await vector_engine.create_collection(collection_name, payload_schema = DocumentChunk)

    # Add to vector storage
    await vector_engine.create_data_points(
        collection_name,
        [
            DataPoint[DocumentChunk](
                id = str(chunk.chunk_id),
                payload = chunk,
                embed_field = "text",
            ) for chunk in data_chunks
        ],
    )

    # Add to graph storage
    chunk_nodes = []
    chunk_edges = []

    for chunk in data_chunks:
        chunk_nodes.append((
            str(chunk.chunk_id),
            dict(
                id = str(chunk.chunk_id),
                chunk_id = str(chunk.chunk_id),
                document_id = str(chunk.document_id),
                word_count = chunk.word_count,
                chunk_index = chunk.chunk_index,
                cut_type = chunk.cut_type,
                pages = chunk.pages,
            )
        ))

        chunk_edges.append((
            str(chunk.document_id),
            str(chunk.chunk_id),
            "has_chunk",
            dict(
                relationship_name = "has_chunk",
                source_node_id = str(chunk.document_id),
                target_node_id = str(chunk.chunk_id),
            ),
        ))

        # First chunk links from the document itself; later chunks link
        # from the previous chunk in the same document.
        previous_chunk_id = get_previous_chunk_id(data_chunks, chunk)

        if previous_chunk_id is not None:
            chunk_edges.append((
                str(previous_chunk_id),
                str(chunk.chunk_id),
                "next_chunk",
                dict(
                    relationship_name = "next_chunk",
                    source_node_id = str(previous_chunk_id),
                    target_node_id = str(chunk.chunk_id),
                ),
            ))

    await graph_engine.add_nodes(chunk_nodes)
    await graph_engine.add_edges(chunk_edges)

    return data_chunks
def get_previous_chunk_id(document_chunks: "list[DocumentChunk]", current_chunk: "DocumentChunk"):
    """Return the id of the node that precedes *current_chunk* in its document.

    For the first chunk of a document (chunk_index == 0) the document id itself
    is returned, so the chunk can be linked straight to its document node.
    Otherwise the chunk with ``chunk_index - 1`` from the same document is looked
    up in *document_chunks* and its chunk id is returned. Returns None when no
    predecessor is found.

    NOTE: the original signature annotated the return type as ``DocumentChunk``,
    but the function actually returns an id (or None); the annotation was
    corrected to reflect the real contract.
    """
    # The first chunk of a document is linked to the document node itself.
    if current_chunk.chunk_index == 0:
        return current_chunk.document_id

    for chunk in document_chunks:
        is_same_document = str(chunk.document_id) == str(current_chunk.document_id)
        if is_same_document and chunk.chunk_index == current_chunk.chunk_index - 1:
            return chunk.chunk_id

    return None

View file

@ -0,0 +1,3 @@
from .chunk_by_word import chunk_by_word
from .chunk_by_sentence import chunk_by_sentence
from .chunk_by_paragraph import chunk_by_paragraph

View file

@ -0,0 +1,53 @@
from cognee.modules.data.chunking import chunk_by_paragraph
if __name__ == "__main__":
    def run_chunking_test(test_text, expected_chunks):
        """Chunk *test_text* with a 12-word limit and compare each chunk's
        text, word_count and cut_type against *expected_chunks*."""
        chunks = []
        for chunk_data in chunk_by_paragraph(test_text, 12, batch_paragraphs = False):
            chunks.append(chunk_data)

        assert len(chunks) == len(expected_chunks)

        for expected, chunk in zip(expected_chunks, chunks):
            for key in ["text", "word_count", "cut_type"]:
                assert chunk[key] == expected[key], f"{key}: {chunk[key]} != {expected[key]}"

    def test_chunking_on_whole_text():
        # All three paragraphs end with a terminating dot.
        run_chunking_test("""This is example text. It contains multiple sentences.
This is a second paragraph. First two paragraphs are whole.
Third paragraph is a bit longer and is finished with a dot.""", [
            dict(text = "This is example text. It contains multiple sentences.", word_count = 8, cut_type = "paragraph_end"),
            dict(text = "This is a second paragraph. First two paragraphs are whole.", word_count = 10, cut_type = "paragraph_end"),
            dict(text = "Third paragraph is a bit longer and is finished with a dot.", word_count = 12, cut_type = "sentence_end"),
        ])

    def test_chunking_on_cut_text():
        # The last paragraph is missing its terminating dot.
        run_chunking_test("""This is example text. It contains multiple sentences.
This is a second paragraph. First two paragraphs are whole.
Third paragraph is cut and is missing the dot at the end""", [
            dict(text = "This is example text. It contains multiple sentences.", word_count = 8, cut_type = "paragraph_end"),
            dict(text = "This is a second paragraph. First two paragraphs are whole.", word_count = 10, cut_type = "paragraph_end"),
            dict(text = "Third paragraph is cut and is missing the dot at the end", word_count = 12, cut_type = "sentence_cut"),
        ])

    test_chunking_on_whole_text()
    test_chunking_on_cut_text()

View file

@ -0,0 +1,69 @@
from uuid import uuid5, NAMESPACE_OID
from .chunk_by_sentence import chunk_by_sentence
def chunk_by_paragraph(data: str, paragraph_length: int = 1024, batch_paragraphs = True):
    """Split *data* into paragraph-based chunks of at most *paragraph_length* words.

    With batch_paragraphs=True, consecutive paragraphs are merged into a single
    chunk until adding the next sentence would exceed the word limit. With
    batch_paragraphs=False, every finished paragraph (and any cut-off trailing
    sentence) is yielded on its own. Each yielded dict carries at least text,
    word_count, chunk_id, chunk_index and cut_type.

    NOTE(review): the identifier key is not uniform across the yield sites
    ("id" vs "paragraph_id") — confirm which one consumers rely on.
    """
    paragraph = ""
    last_cut_type = None
    last_paragraph_id = None
    paragraph_word_count = 0
    paragraph_chunk_index = 0

    for (paragraph_id, __, sentence, word_count, end_type) in chunk_by_sentence(data):
        # Flush the accumulated text before the word limit would be exceeded.
        if paragraph_word_count > 0 and paragraph_word_count + word_count > paragraph_length:
            if batch_paragraphs is True:
                chunk_id = uuid5(NAMESPACE_OID, paragraph)
                yield dict(
                    text = paragraph.strip(),
                    word_count = paragraph_word_count,
                    id = chunk_id, # When batching paragraphs, the paragraph_id is the same as chunk_id.
                    # paragraph_id doesn't mean anything since multiple paragraphs are merged.
                    chunk_id = chunk_id,
                    chunk_index = paragraph_chunk_index,
                    cut_type = last_cut_type,
                )
            else:
                yield dict(
                    text = paragraph.strip(),
                    word_count = paragraph_word_count,
                    id = last_paragraph_id,
                    chunk_id = uuid5(NAMESPACE_OID, paragraph),
                    chunk_index = paragraph_chunk_index,
                    cut_type = last_cut_type,
                )

            paragraph_chunk_index += 1
            paragraph_word_count = 0
            paragraph = ""

        paragraph += (" " if len(paragraph) > 0 else "") + sentence
        paragraph_word_count += word_count

        if end_type == "paragraph_end" or end_type == "sentence_cut":
            if batch_paragraphs is True:
                # Keep accumulating; mark the paragraph break in the text itself.
                paragraph += "\n\n" if end_type == "paragraph_end" else ""
            else:
                # Non-batched mode: every paragraph becomes its own chunk.
                yield dict(
                    text = paragraph.strip(),
                    word_count = paragraph_word_count,
                    paragraph_id = paragraph_id,
                    chunk_id = uuid5(NAMESPACE_OID, paragraph),
                    chunk_index = paragraph_chunk_index,
                    cut_type = end_type,
                )
                paragraph_chunk_index = 0
                paragraph_word_count = 0
                paragraph = ""

        last_cut_type = end_type
        last_paragraph_id = paragraph_id

    # Flush whatever is left when the text runs out.
    if len(paragraph) > 0:
        yield dict(
            chunk_id = uuid5(NAMESPACE_OID, paragraph),
            text = paragraph,
            word_count = paragraph_word_count,
            paragraph_id = last_paragraph_id,
            chunk_index = paragraph_chunk_index,
            cut_type = last_cut_type,
        )

View file

@ -0,0 +1,28 @@
from uuid import uuid4
from .chunk_by_word import chunk_by_word
def chunk_by_sentence(data: str):
    """Group the words of *data* into sentences.

    Yields (paragraph id, sentence index within the paragraph, sentence text,
    word count, end type) tuples. A fresh paragraph id is generated at every
    paragraph boundary; a trailing unfinished sentence is emitted with the
    "sentence_cut" end type.
    """
    current_sentence = ""
    current_word_count = 0
    paragraph_id = uuid4()
    sentence_index = 0

    for word, word_type in chunk_by_word(data):
        if current_sentence:
            current_sentence += " "
        current_sentence += word
        current_word_count += 1

        if word_type not in ("paragraph_end", "sentence_end"):
            continue

        yield (paragraph_id, sentence_index, current_sentence, current_word_count, word_type)

        current_sentence = ""
        current_word_count = 0

        if word_type == "paragraph_end":
            # New paragraph: new id, restart the sentence numbering.
            paragraph_id = uuid4()
            sentence_index = 0
        else:
            sentence_index += 1

    # Flush a sentence that was cut off by the end of the text.
    if current_sentence:
        yield (
            paragraph_id,
            sentence_index,
            current_sentence,
            current_word_count,
            "sentence_cut",
        )

View file

@ -0,0 +1,60 @@
import re
def chunk_by_word(data: str):
    """Split *data* into words, tagging each with how it terminates.

    Yields (word, word_type) tuples where word_type is one of:
    - "word": a regular word (or a newline judged not to end a paragraph),
    - "sentence_end": the word closes a sentence (".;!?…"),
    - "paragraph_end": a paragraph boundary (newline after a finished sentence,
      or a newline whose next non-whitespace character is uppercase).
    """
    sentence_endings = r"[.;!?…]"
    paragraph_endings = r"[\n\r]"
    last_processed_character = ""

    word = ""
    i = 0

    while i < len(data):
        character = data[i]

        # Skip whitespace between words.
        if word == "" and (re.match(paragraph_endings, character) or character == " "):
            i = i + 1
            continue

        def is_real_paragraph_end():
            # A newline ends a paragraph when the sentence was already closed,
            # or when the next non-whitespace character starts a new sentence
            # (i.e. is uppercase).
            if re.match(sentence_endings, last_processed_character):
                return True

            j = i + 1
            next_character = data[j] if j < len(data) else None
            while next_character is not None and (re.match(paragraph_endings, next_character) or next_character == " "):
                j += 1
                next_character = data[j] if j < len(data) else None

            # Fix: guard against running off the end of the text (the original
            # called None.isupper() when the data ended in whitespace).
            if next_character is not None and next_character.isupper():
                return True

            return False

        if re.match(paragraph_endings, character):
            yield (word, "paragraph_end" if is_real_paragraph_end() else "word")
            word = ""
            i = i + 1
            continue

        if character == " ":
            # Fix: yield a tuple like every other branch (was a list).
            yield (word, "word")
            word = ""
            i = i + 1
            continue

        word += character
        last_processed_character = character

        if re.match(sentence_endings, character):
            # Check for ellipses ("..."), consuming all three dots as one token.
            # Fix: the bound must leave room to index data[i + 2]
            # (was `<=`, which raised IndexError on text ending in "..").
            if i + 2 < len(data) and data[i] == "." and data[i + 1] == "." and data[i + 2] == ".":
                word += ".."
                i = i + 2

            is_paragraph_end = i + 1 < len(data) and re.match(paragraph_endings, data[i + 1])
            yield (word, "paragraph_end" if is_paragraph_end else "sentence_end")
            word = ""

        i += 1

    # Flush a trailing word with no terminator.
    if len(word) > 0:
        yield (word, "word")

View file

@ -1 +1,2 @@
from .prune_data import prune_data
from .prune_system import prune_system

View file

@ -0,0 +1,7 @@
from cognee.base_config import get_base_config
from cognee.infrastructure.files.storage import LocalStorage
async def prune_data():
    """Delete all stored data files by wiping the configured data root directory."""
    data_root_directory = get_base_config().data_root_directory
    LocalStorage.remove_all(data_root_directory)

View file

@ -1,13 +1,17 @@
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from cognee.infrastructure.databases.relational import get_relationaldb_config
async def prune_system(graph = True, vector = True):
async def prune_system(graph = True, vector = True, metadata = False):
if graph:
graph_config = get_graph_config()
graph_client = await get_graph_client()
await graph_client.delete_graph()
graph_engine = await get_graph_engine()
await graph_engine.delete_graph()
if vector:
vector_engine = get_vector_engine()
await vector_engine.prune()
if metadata:
db_config = get_relationaldb_config()
db_engine = db_config.database_engine
db_engine.delete_database()

View file

@ -0,0 +1 @@
from .extract_topics import extract_topics_yake, extract_topics_keybert

View file

@ -0,0 +1,5 @@
from pydantic import BaseModel
class TextSummary(BaseModel):
    """Payload model for a summary generated from a single document chunk."""
    # The summary text itself; embedded when stored in the vector collection.
    text: str
    # Id of the DocumentChunk the summary was produced from (stringified).
    chunk_id: str

View file

@ -0,0 +1,36 @@
import asyncio
from typing import Type
from pydantic import BaseModel
from cognee.infrastructure.databases.vector import get_vector_engine, DataPoint
from ...processing.chunk_types.DocumentChunk import DocumentChunk
from ...extraction.extract_summary import extract_summary
from .models.TextSummary import TextSummary
async def summarize_text_chunks(data_chunks: list[DocumentChunk], summarization_model: Type[BaseModel], collection_name: str = "summaries"):
    """Summarize every chunk with the LLM and store the summaries as vector data points.

    Returns the chunks unchanged so the task can be composed in a pipeline.
    """
    if not data_chunks:
        return data_chunks

    # Run all summary extractions concurrently.
    summarization_tasks = [extract_summary(chunk.text, summarization_model) for chunk in data_chunks]
    chunk_summaries = await asyncio.gather(*summarization_tasks)

    summary_data_points = [
        DataPoint[TextSummary](
            id = str(chunk.chunk_id),
            payload = dict(
                chunk_id = str(chunk.chunk_id),
                text = summary.summary,
            ),
            embed_field = "text",
        ) for (chunk, summary) in zip(data_chunks, chunk_summaries)
    ]

    vector_engine = get_vector_engine()
    await vector_engine.create_collection(collection_name, payload_schema = TextSummary)
    await vector_engine.create_data_points(collection_name, summary_data_points)

    return data_chunks

View file

@ -11,16 +11,4 @@ async def extract_categories(content: str, response_model: Type[BaseModel]):
llm_output = await llm_client.acreate_structured_output(content, system_prompt, response_model)
return process_categories(llm_output.model_dump())
def process_categories(llm_output) -> List[dict]:
# Extract the first subclass from the list (assuming there could be more)
data_category = llm_output["label"]["subclass"][0] if len(llm_output["label"]["subclass"]) > 0 else None
data_type = llm_output["label"]["type"].lower()
return [{
"data_type": data_type,
# The data_category is the value of the Enum member (e.g., "News stories and blog posts")
"category_name": data_category.value if data_category else "Other types of text data",
}]
return llm_output

View file

@ -10,4 +10,4 @@ async def extract_summary(content: str, response_model: Type[BaseModel]):
llm_output = await llm_client.acreate_structured_output(content, system_prompt, response_model)
return llm_output.model_dump()
return llm_output

View file

@ -0,0 +1,113 @@
import re
import nltk
from nltk.tag import pos_tag
from nltk.corpus import stopwords, wordnet
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
def extract_topics_yake(texts: list[str]):
    """Yield, for each input text, its top 3 keyword phrases extracted with YAKE."""
    from yake import KeywordExtractor

    # A single extractor instance is reused across all texts.
    extractor = KeywordExtractor(
        top = 3,
        n = 2,
        dedupLim = 0.2,
        dedupFunc = "levenshtein", # "seqm" | "levenshtein"
        windowsSize = 1,
    )

    for text in texts:
        scored_keywords = extractor.extract_keywords(preprocess_text(text))
        # Drop the scores, keep the keyword strings.
        yield [pair[0] for pair in scored_keywords]
def extract_topics_keybert(texts: list[str]):
    """Yield, for each input text, its top 3 keyword phrases extracted with KeyBERT."""
    from keybert import KeyBERT

    model = KeyBERT()

    for text in texts:
        scored_keywords = model.extract_keywords(
            preprocess_text(text),
            keyphrase_ngram_range = (1, 2),
            top_n = 3,
            # use_mmr = True,
            # diversity = 0.9,
        )
        # Drop the scores, keep the keyword strings.
        yield [pair[0] for pair in scored_keywords]
def preprocess_text(text: str):
    """Normalize *text* for keyword extraction.

    Lowercases, strips punctuation, tokenizes, keeps nouns/adjectives
    (NNP/NN/JJ), removes English stop words and lemmatizes. Required NLTK
    resources are downloaded on first use. Returns the processed tokens joined
    into a single space-separated string.
    """
    try:
        # Used for stopwords removal.
        stopwords.ensure_loaded()
    except LookupError:
        nltk.download("stopwords", quiet = True)
        stopwords.ensure_loaded()

    try:
        # Used in WordNetLemmatizer.
        wordnet.ensure_loaded()
    except LookupError:
        nltk.download("wordnet", quiet = True)
        wordnet.ensure_loaded()

    try:
        # Used in word_tokenize.
        nltk.data.find("tokenizers/punkt")
    except LookupError:
        nltk.download("punkt", quiet = True)

    try:
        # Used in pos_tag. Fix: this resource was never downloaded, so the
        # first pos_tag call on a clean machine raised a LookupError.
        nltk.data.find("taggers/averaged_perceptron_tagger")
    except LookupError:
        nltk.download("averaged_perceptron_tagger", quiet = True)

    text = text.lower()

    # Remove punctuation (keep word characters, whitespace and hyphens).
    text = re.sub(r"[^\w\s-]", "", text)

    # Tokenize the text.
    tokens = word_tokenize(text)

    # Keep proper nouns, common nouns and adjectives only.
    tagged_tokens = pos_tag(tokens)
    tokens = [word for word, tag in tagged_tokens if tag in ["NNP", "NN", "JJ"]]

    # Remove stop words.
    stop_words = set(stopwords.words("english"))
    tokens = [word for word in tokens if word not in stop_words]

    # Lemmatize the remaining tokens.
    lemmatizer = WordNetLemmatizer()
    tokens = [lemmatizer.lemmatize(word) for word in tokens]

    # Join tokens back to a single string.
    processed_text = " ".join(tokens)

    return processed_text
# def clean_text(text: str):
# text = re.sub(r"[ \t]{2,}|[\n\r]", " ", text.lower())
# # text = re.sub(r"[`\"'.,;!?…]", "", text).strip()
# return text
# def remove_stop_words(text: str):
# try:
# stopwords.ensure_loaded()
# except LookupError:
# download("stopwords")
# stopwords.ensure_loaded()
# stop_words = set(stopwords.words("english"))
# text = text.split()
# text = [word for word in text if not word in stop_words]
# return " ".join(text)
if __name__ == "__main__":
    # Ad-hoc manual check: run YAKE topic extraction over the bundled sample
    # texts and print the extracted topics for visual inspection.
    import os

    file_dir = os.path.dirname(os.path.realpath(__file__))

    with open(os.path.join(file_dir, "texts.json"), "r", encoding = "utf-8") as file:
        import json
        texts = json.load(file)

    for topics in extract_topics_yake(texts):
        print(topics)
        print("\n")

View file

@ -0,0 +1,66 @@
import re
from nltk.downloader import download
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import sent_tokenize, word_tokenize
from nltk.corpus import stopwords, wordnet
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import TruncatedSVD
def extract_topics(text: str):
    """Extract up to 10 topic terms from *text*.

    Sentences are lemmatized and vectorized with tf-idf; a truncated SVD is
    fitted on the result. Terms are then ranked (highest first) by their
    tf-idf weight in the first sentence.
    """
    sentences = sent_tokenize(text)

    try:
        wordnet.ensure_loaded()
    except LookupError:
        download("wordnet")
        wordnet.ensure_loaded()

    lemmatizer = WordNetLemmatizer()
    base_notation_sentences = [lemmatizer.lemmatize(sentence) for sentence in sentences]

    tf_vectorizer = TfidfVectorizer(tokenizer = word_tokenize, token_pattern = None)
    transformed_corpus = tf_vectorizer.fit_transform(base_notation_sentences)

    svd = TruncatedSVD(n_components = 10)
    svd_corpus = svd.fit(transformed_corpus)

    # Fix: vocabulary_ is a {term: column_index} dict whose iteration order is
    # insertion order, NOT column order, so zipping it against components_[0]
    # paired terms with the wrong SVD weights. Index the component explicitly.
    feature_scores = {
        term: svd_corpus.components_[0][column_index]
        for term, column_index in tf_vectorizer.vocabulary_.items()
    }

    topics = sorted(
        feature_scores,
        # key = feature_scores.get,
        key = lambda term: transformed_corpus[0, tf_vectorizer.vocabulary_[term]],
        reverse = True,
    )[:10]

    return topics
def clean_text(text: str):
    """Lowercase *text*, collapse whitespace runs and newlines, drop punctuation."""
    lowered = text.lower()
    # Replace any run of 2+ spaces/tabs, or a single newline/carriage return,
    # with one space.
    normalized_whitespace = re.sub(r"[ \t]{2,}|[\n\r]", " ", lowered)
    # Strip quote and sentence punctuation characters entirely.
    without_punctuation = re.sub(r"[`\"'.,;!?…]", "", normalized_whitespace)
    return without_punctuation.strip()
def remove_stop_words(text: str):
    """Drop English stop words from *text*, downloading the NLTK corpus if needed."""
    try:
        stopwords.ensure_loaded()
    except LookupError:
        download("stopwords")
        stopwords.ensure_loaded()

    stop_words = set(stopwords.words("english"))
    kept_words = [word for word in text.split() if word not in stop_words]
    return " ".join(kept_words)
if __name__ == "__main__":
    # Ad-hoc manual check: run the full pipeline (clean -> remove stop words ->
    # extract topics) over a sample text and print the result.
    text = """Lorem Ipsum is simply dummy text of the printing and typesetting industry... Lorem Ipsum has been the industry's standard dummy text ever since the 1500s, when an unknown printer took a galley of type and scrambled it to make a type specimen book… It has survived not only five centuries, but also the leap into electronic typesetting, remaining essentially unchanged. It was popularised in the 1960s with the release of Letraset sheets containing Lorem Ipsum passages, and more recently with desktop publishing software like Aldus PageMaker including versions of Lorem Ipsum.
Why do we use it?
It is a long established fact that a reader will be distracted by the readable content of a page when looking at its layout! The point of using Lorem Ipsum is that it has a more-or-less normal distribution of letters, as opposed to using 'Content here, content here', making it look like readable English. Many desktop publishing packages and web page editors now use Lorem Ipsum as their default model text, and a search for 'lorem ipsum' will uncover many web sites still in their infancy. Various versions have evolved over the years, sometimes by accident, sometimes on purpose (injected humour and the like).
"""
    print(extract_topics(remove_stop_words(clean_text(text))))

View file

@ -0,0 +1,70 @@
from typing import Type, Optional, get_args, get_origin
from pydantic import BaseModel
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
async def add_model_class_to_graph(
    model_class: Type[BaseModel],
    graph: GraphDBInterface,
    parent: Optional[str] = None,
    relationship: Optional[str] = None,
):
    """Recursively mirror a Pydantic model class (and its field types) into the graph.

    Each model class becomes a "model" node named after the class; scalar field
    types become "value" nodes linked by an edge named after the field. When
    *parent* and *relationship* are given, the class node is linked to its
    parent. Already-present class nodes are skipped, which also terminates
    recursion on cyclic models.
    """
    model_name = model_class.__name__

    # Skip (and stop recursing) if this model node already exists.
    if await graph.extract_node(model_name):
        return

    await graph.add_node(model_name, dict(type = "model"))

    if parent and relationship:
        await graph.add_edge(
            parent,
            model_name,
            relationship,
            dict(
                relationship_name = relationship,
                source_node_id = parent,
                target_node_id = model_name,
            ),
        )

    for field_name, field in model_class.model_fields.items():
        # Unwrap generic annotations (e.g. Optional[X], list[X]) to the first
        # inner type argument.
        original_types = get_args(field.annotation)
        field_type = original_types[0] if len(original_types) > 0 else None

        # NOTE(review): fields with a plain, non-generic annotation have no
        # type args and are silently skipped here — confirm this is intended.
        if field_type is None:
            continue

        if hasattr(field_type, "model_fields"): # Check if field type is a Pydantic model
            await add_model_class_to_graph(field_type, graph, model_name, field_name)
        elif get_origin(field.annotation) == list:
            # list[...] field: recurse into the item type's own type args.
            list_types = get_args(field_type)
            for item_type in list_types:
                await add_model_class_to_graph(item_type, graph, model_name, field_name)
        elif isinstance(field_type, list):
            # NOTE(review): field_type is a typing object, not a list instance,
            # so this branch looks unreachable — confirm before removing.
            item_type = get_args(field_type)[0]
            if hasattr(item_type, "model_fields"):
                await add_model_class_to_graph(item_type, graph, model_name, field_name)
            else:
                await graph.add_node(str(item_type), dict(type = "value"))
                await graph.add_edge(
                    model_name,
                    str(item_type),
                    field_name,
                    dict(
                        relationship_name = field_name,
                        source_node_id = model_name,
                        target_node_id = str(item_type),
                    ),
                )
        else:
            # Scalar field type: add it as a value node linked by the field name.
            await graph.add_node(str(field_type), dict(type = "value"))
            await graph.add_edge(
                model_name,
                str(field_type),
                field_name,
                dict(
                    relationship_name = field_name,
                    source_node_id = model_name,
                    target_node_id = str(field_type),
                ),
            )

View file

@ -0,0 +1,20 @@
from typing import Type
from pydantic import BaseModel
from cognee.shared.data_models import KnowledgeGraph
from cognee.infrastructure.databases.graph import get_graph_engine
from ...processing.chunk_types.DocumentChunk import DocumentChunk
from .add_model_class_to_graph import add_model_class_to_graph
async def establish_graph_topology(data_chunks: list[DocumentChunk], topology_model: Type[BaseModel]):
    """Mirror *topology_model*'s class structure into the graph store.

    The default KnowledgeGraph model carries no predefined topology and is
    skipped. The chunks are returned untouched for the next pipeline task.
    """
    if topology_model != KnowledgeGraph:
        graph_engine = await get_graph_engine()
        await add_model_class_to_graph(topology_model, graph_engine)

    return data_chunks
def generate_node_id(node_id: str) -> str:
    """Normalize a node id: uppercase, spaces to underscores, apostrophes dropped."""
    normalized = node_id.upper()
    normalized = normalized.replace(" ", "_")
    return normalized.replace("'", "")

View file

@ -0,0 +1,117 @@
import asyncio
from datetime import datetime
from typing import Type
from pydantic import BaseModel
from cognee.infrastructure.databases.graph import get_graph_engine
from ...processing.chunk_types.DocumentChunk import DocumentChunk
from .extract_knowledge_graph import extract_content_graph
async def expand_knowledge_graph(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel]):
    """Extract a knowledge graph from every chunk and merge it into graph storage.

    For each chunk, an LLM extracts nodes and edges. Every extracted node is
    linked to its chunk ("contains") and to its entity-type node
    ("contains_entity_type" / "is_entity_type"); edges from the extracted
    graphs are added as-is. Returns the chunks unchanged for the next task.
    """
    chunk_graphs = await asyncio.gather(
        *[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
    )

    graph_engine = await get_graph_engine()

    # Fix: extraction can yield None for a chunk (handled below with
    # `if graph is None: continue`), so guard the type-id scan the same way —
    # the original dereferenced chunk_graph.nodes unconditionally.
    type_ids = [
        generate_node_id(node.type)
        for chunk_graph in chunk_graphs if chunk_graph is not None
        for node in chunk_graph.nodes
    ]
    graph_type_node_ids = list(set(type_ids))

    graph_type_nodes = await graph_engine.extract_nodes(graph_type_node_ids)
    existing_type_nodes_map = {node["id"]: node for node in graph_type_nodes}

    # One timestamp for the whole batch; per-node values only differed by
    # fractions of a second anyway.
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    graph_nodes = []
    graph_edges = []

    for (chunk_index, chunk) in enumerate(data_chunks):
        graph = chunk_graphs[chunk_index]

        if graph is None:
            continue

        for node in graph.nodes:
            node_id = generate_node_id(node.id)

            graph_nodes.append((
                node_id,
                dict(
                    id = node_id,
                    chunk_id = str(chunk.chunk_id),
                    document_id = str(chunk.document_id),
                    name = node.name,
                    type = node.type.lower().capitalize(),
                    description = node.description,
                    created_at = timestamp,
                    updated_at = timestamp,
                )
            ))

            # Link the chunk to the entity it mentions.
            graph_edges.append((
                str(chunk.chunk_id),
                node_id,
                "contains",
                dict(
                    relationship_name = "contains",
                    source_node_id = str(chunk.chunk_id),
                    target_node_id = node_id,
                ),
            ))

            type_node_id = generate_node_id(node.type)

            # Create the entity-type node once per distinct type.
            if type_node_id not in existing_type_nodes_map:
                node_name = node.type.lower().capitalize()

                type_node = dict(
                    id = type_node_id,
                    name = node_name,
                    type = node_name,
                    created_at = timestamp,
                    updated_at = timestamp,
                )

                graph_nodes.append((type_node_id, type_node))
                existing_type_nodes_map[type_node_id] = type_node

            # Link the chunk to the entity type.
            graph_edges.append((
                str(chunk.chunk_id),
                type_node_id,
                "contains_entity_type",
                dict(
                    relationship_name = "contains_entity_type",
                    source_node_id = str(chunk.chunk_id),
                    target_node_id = type_node_id,
                ),
            ))

            # Add relationship between entity type and entity itself: "Jake is Person"
            graph_edges.append((
                type_node_id,
                node_id,
                "is_entity_type",
                dict(
                    relationship_name = "is_entity_type",
                    source_node_id = type_node_id,
                    target_node_id = node_id,
                ),
            ))

        # Add the relationships that came from the extracted graph itself.
        for edge in graph.edges:
            graph_edges.append((
                generate_node_id(edge.source_node_id),
                generate_node_id(edge.target_node_id),
                edge.relationship_name,
                dict(
                    relationship_name = edge.relationship_name,
                    source_node_id = generate_node_id(edge.source_node_id),
                    target_node_id = generate_node_id(edge.target_node_id),
                ),
            ))

    await graph_engine.add_nodes(graph_nodes)
    await graph_engine.add_edges(graph_edges)

    return data_chunks
def generate_node_id(node_id: str) -> str:
    """Canonicalize an entity id: uppercase, underscores for spaces, no apostrophes."""
    # The two replacements touch disjoint characters, so their order is free.
    return node_id.upper().replace("'", "").replace(" ", "_")

View file

@ -3,10 +3,10 @@ from pydantic import BaseModel
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import render_prompt
async def extract_content_graph(content: str, cognitive_layer: str, response_model: Type[BaseModel]):
async def extract_content_graph(content: str, response_model: Type[BaseModel]):
llm_client = get_llm_client()
system_prompt = render_prompt("generate_graph_prompt.txt", { "layer": cognitive_layer })
output = await llm_client.acreate_structured_output(content, system_prompt, response_model)
system_prompt = render_prompt("generate_graph_prompt.txt", {})
content_graph = await llm_client.acreate_structured_output(content, system_prompt, response_model)
return output.model_dump()
return content_graph

View file

@ -0,0 +1,7 @@
[
"Lorem Ipsum is simply dummy text of the printing and typesetting industry. Lorem Ipsum has been the industry's standard dummy text ever since the 1500s, when an unknown printer took a galley of type and scrambled it to make a type specimen book.\nIt has survived not only five centuries, but also the leap into electronic typesetting, remaining essentially unchanged. It was popularised in the 1960s with the release of Letraset sheets containing Lorem Ipsum passages, and more recently with desktop publishing software like Aldus PageMaker including versions of Lorem Ipsum.",
"It is a long established fact that a reader will be distracted by the readable content of a page when looking at its layout.\n\tThe point of using Lorem Ipsum is that it has a more-or-less normal distribution of letters, as opposed to using 'Content here, content here', making it look like readable English. Many desktop publishing packages and web page editors now use Lorem Ipsum as their default model text, and a search for 'lorem ipsum' will uncover many web sites still in their infancy.\n Various versions have evolved over the years, sometimes by accident, sometimes on purpose (injected humour and the like).",
"Contrary to popular belief, Lorem Ipsum is not simply random text. It has roots in a piece of classical Latin literature from 45 BC, making it over 2000 years old. Richard McClintock, a Latin professor at Hampden-Sydney College in Virginia, looked up one of the more obscure Latin words, consectetur, from a Lorem Ipsum passage, and going through the cites of the word in classical literature, discovered the undoubtable source. Lorem Ipsum comes from sections 1.10.32 and 1.10.33 of \"de Finibus Bonorum et Malorum\" (The Extremes of Good and Evil) by Cicero, written in 45 BC. This book is a treatise on the theory of ethics, very popular during the Renaissance.\n The first line of Lorem Ipsum, \"Lorem ipsum dolor sit amet..\", comes from a line in section 1.10.32.",
"The standard chunk of Lorem Ipsum used since the 1500s is reproduced below for those interested. Sections 1.10.32 and 1.10.33 from \"de Finibus Bonorum et Malorum\" by Cicero are also reproduced in their exact original form, accompanied by English versions from the 1914 translation by H. Rackham.",
"There are many variations of passages of Lorem Ipsum available, but the majority have suffered alteration in some form, by injected humour, or randomised words which don't look even slightly believable. If you are going to use a passage of Lorem Ipsum, you need to be sure there isn't anything embarrassing hidden in the middle of text. All the Lorem Ipsum generators on the Internet tend to repeat predefined chunks as necessary, making this the first true generator on the Internet. It uses a dictionary of over 200 Latin words, combined with a handful of model sentence structures, to generate Lorem Ipsum which looks reasonable. The generated Lorem Ipsum is therefore always free from repetition, injected humour, or non-characteristic words etc."
]

View file

@ -0,0 +1,10 @@
from pydantic import BaseModel
class DocumentChunk(BaseModel):
    """A contiguous piece of a document, produced by the chunking readers."""
    # The chunk's raw text content.
    text: str
    # Number of words in `text`.
    word_count: int
    # Id of the source document (stringified).
    document_id: str
    # Id of this chunk (stringified).
    chunk_id: str
    # Zero-based position of the chunk within the document.
    chunk_index: int
    # How the chunk ended: "paragraph_end", "sentence_end" or "sentence_cut".
    cut_type: str
    # Indices of the source pages this chunk spans.
    pages: list[int]

View file

@ -0,0 +1,122 @@
from uuid import uuid5, NAMESPACE_OID
from typing import Optional, Generator
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.modules.data.chunking import chunk_by_paragraph
from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk
from cognee.modules.data.processing.document_types.Document import Document
class AudioReader:
    """Transcribes an audio file with the configured LLM and yields DocumentChunks."""
    # Id of the owning document.
    id: str
    # Path to the audio file on disk.
    file_path: str

    def __init__(self, id: str, file_path: str):
        self.id = id
        self.file_path = file_path
        # LLM client used for transcription.
        self.llm_client = get_llm_client()

    def read(self, max_chunk_size: Optional[int] = 1024):
        """Transcribe the audio and yield DocumentChunks of up to *max_chunk_size* words.

        Paragraph-sized pieces are accumulated until adding the next one would
        exceed the limit, then flushed as a single DocumentChunk.
        """
        chunk_index = 0
        chunk_size = 0
        chunked_pages = []
        paragraph_chunks = []

        # Transcribe the audio file
        result = self.llm_client.create_transcript(self.file_path)
        text = result.text

        # Simulate reading text in chunks as done in TextReader
        def read_text_chunks(text, chunk_size):
            for i in range(0, len(text), chunk_size):
                yield text[i:i + chunk_size]

        page_index = 0

        for page_text in read_text_chunks(text, max_chunk_size):
            chunked_pages.append(page_index)
            page_index += 1

            for chunk_data in chunk_by_paragraph(page_text, max_chunk_size, batch_paragraphs=True):
                if chunk_size + chunk_data["word_count"] <= max_chunk_size:
                    # Still below the limit: keep accumulating.
                    paragraph_chunks.append(chunk_data)
                    chunk_size += chunk_data["word_count"]
                else:
                    if len(paragraph_chunks) == 0:
                        # A single paragraph exceeded the limit — emit it alone.
                        yield DocumentChunk(
                            text=chunk_data["text"],
                            word_count=chunk_data["word_count"],
                            document_id=str(self.id),
                            chunk_id=str(chunk_data["chunk_id"]),
                            chunk_index=chunk_index,
                            cut_type=chunk_data["cut_type"],
                            pages=[page_index],
                        )
                        paragraph_chunks = []
                        chunk_size = 0
                    else:
                        # Flush the accumulated paragraphs as one chunk, then
                        # restart the accumulator with the current piece.
                        chunk_text = " ".join(chunk["text"] for chunk in paragraph_chunks)
                        yield DocumentChunk(
                            text=chunk_text,
                            word_count=chunk_size,
                            document_id=str(self.id),
                            chunk_id=str(uuid5(NAMESPACE_OID, f"{str(self.id)}-{chunk_index}")),
                            chunk_index=chunk_index,
                            cut_type=paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
                            pages=chunked_pages,
                        )
                        chunked_pages = [page_index]
                        paragraph_chunks = [chunk_data]
                        chunk_size = chunk_data["word_count"]

                    chunk_index += 1

        if len(paragraph_chunks) > 0:
            # Flush whatever is still accumulated at the end of the text.
            yield DocumentChunk(
                text=" ".join(chunk["text"] for chunk in paragraph_chunks),
                word_count=chunk_size,
                document_id=str(self.id),
                chunk_id=str(uuid5(NAMESPACE_OID, f"{str(self.id)}-{chunk_index}")),
                chunk_index=chunk_index,
                cut_type=paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
                pages=chunked_pages,
            )
class AudioDocument(Document):
    """An audio file registered as a document; reading it transcribes the audio."""
    type: str = "audio"
    title: str
    file_path: str

    def __init__(self, title: str, file_path: str):
        # Derive a stable document id from the title.
        self.id = uuid5(NAMESPACE_OID, title)
        self.title = title
        self.file_path = file_path
        # Fix: the original constructed an AudioReader here and discarded it,
        # needlessly initializing an LLM client on every instantiation.

    def get_reader(self) -> AudioReader:
        """Return a fresh reader that transcribes and chunks this document."""
        return AudioReader(self.id, self.file_path)

    def to_dict(self) -> dict:
        """Serialize the document's metadata to a plain dict."""
        return dict(
            id=str(self.id),
            type=self.type,
            title=self.title,
            file_path=self.file_path,
        )
# if __name__ == "__main__":
# # Sample usage of AudioDocument
# audio_document = AudioDocument("sample_audio", "/Users/vasa/Projects/cognee/cognee/modules/data/processing/document_types/preamble10.wav")
# audio_reader = audio_document.get_reader()
# for chunk in audio_reader.read():
# print(chunk.text)
# print(chunk.word_count)
# print(chunk.document_id)
# print(chunk.chunk_id)
# print(chunk.chunk_index)
# print(chunk.cut_type)
# print(chunk.pages)
# print("----")

View file

@ -0,0 +1,8 @@
from uuid import UUID
from typing import Protocol
class Document(Protocol):
    """Structural interface shared by all document types (pdf, audio, image, ...)."""
    # Stable document id (implementations here derive it from the title).
    id: UUID
    # Document kind discriminator, e.g. "audio", "image".
    type: str
    # Human-readable document title.
    title: str
    # Location of the underlying file on disk.
    file_path: str

View file

@ -0,0 +1,124 @@
from uuid import uuid5, NAMESPACE_OID
from typing import Optional, Generator
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.modules.data.chunking import chunk_by_paragraph
from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk
from cognee.modules.data.processing.document_types.Document import Document
class ImageReader:
    """Transcribes an image with the configured LLM and yields DocumentChunks."""
    # Id of the owning document.
    id: str
    # Path to the image file on disk.
    file_path: str

    def __init__(self, id: str, file_path: str):
        self.id = id
        self.file_path = file_path
        # LLM client used for image transcription.
        self.llm_client = get_llm_client()

    def read(self, max_chunk_size: Optional[int] = 1024):
        """Transcribe the image and yield DocumentChunks of up to *max_chunk_size* words.

        Paragraph-sized pieces are accumulated until adding the next one would
        exceed the limit, then flushed as a single DocumentChunk.
        """
        chunk_index = 0
        chunk_size = 0
        chunked_pages = []
        paragraph_chunks = []

        # Transcribe the image file.
        # Fix: removed leftover debug print of the transcription result.
        result = self.llm_client.transcribe_image(self.file_path)
        text = result.choices[0].message.content

        # Simulate reading text in chunks as done in TextReader
        def read_text_chunks(text, chunk_size):
            for i in range(0, len(text), chunk_size):
                yield text[i:i + chunk_size]

        page_index = 0

        for page_text in read_text_chunks(text, max_chunk_size):
            chunked_pages.append(page_index)
            page_index += 1

            for chunk_data in chunk_by_paragraph(page_text, max_chunk_size, batch_paragraphs=True):
                if chunk_size + chunk_data["word_count"] <= max_chunk_size:
                    # Still below the limit: keep accumulating.
                    paragraph_chunks.append(chunk_data)
                    chunk_size += chunk_data["word_count"]
                else:
                    if len(paragraph_chunks) == 0:
                        # A single paragraph exceeded the limit — emit it alone.
                        yield DocumentChunk(
                            text=chunk_data["text"],
                            word_count=chunk_data["word_count"],
                            document_id=str(self.id),
                            chunk_id=str(chunk_data["chunk_id"]),
                            chunk_index=chunk_index,
                            cut_type=chunk_data["cut_type"],
                            pages=[page_index],
                        )
                        paragraph_chunks = []
                        chunk_size = 0
                    else:
                        # Flush the accumulated paragraphs as one chunk, then
                        # restart the accumulator with the current piece.
                        chunk_text = " ".join(chunk["text"] for chunk in paragraph_chunks)
                        yield DocumentChunk(
                            text=chunk_text,
                            word_count=chunk_size,
                            document_id=str(self.id),
                            chunk_id=str(uuid5(NAMESPACE_OID, f"{str(self.id)}-{chunk_index}")),
                            chunk_index=chunk_index,
                            cut_type=paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
                            pages=chunked_pages,
                        )
                        chunked_pages = [page_index]
                        paragraph_chunks = [chunk_data]
                        chunk_size = chunk_data["word_count"]

                    chunk_index += 1

        if len(paragraph_chunks) > 0:
            # Flush whatever is still accumulated at the end of the text.
            yield DocumentChunk(
                text=" ".join(chunk["text"] for chunk in paragraph_chunks),
                word_count=chunk_size,
                document_id=str(self.id),
                chunk_id=str(uuid5(NAMESPACE_OID, f"{str(self.id)}-{chunk_index}")),
                chunk_index=chunk_index,
                cut_type=paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
                pages=chunked_pages,
            )
class ImageDocument(Document):
    """An image file registered as a document; reading it transcribes the image."""
    type: str = "image"
    title: str
    file_path: str

    def __init__(self, title: str, file_path: str):
        # Derive a stable document id from the title.
        self.id = uuid5(NAMESPACE_OID, title)
        self.title = title
        self.file_path = file_path
        # Fix: the original constructed an ImageReader here and discarded it,
        # needlessly initializing an LLM client on every instantiation.

    def get_reader(self) -> ImageReader:
        """Return a fresh reader that transcribes and chunks this document."""
        return ImageReader(self.id, self.file_path)

    def to_dict(self) -> dict:
        """Serialize the document's metadata to a plain dict."""
        return dict(
            id=str(self.id),
            type=self.type,
            title=self.title,
            file_path=self.file_path,
        )
# if __name__ == "__main__":
# # Sample usage of AudioDocument
# audio_document = ImageDocument("sample_audio", "/Users/vasa/Projects/cognee/assets/architecture.png")
# audio_reader = audio_document.get_reader()
# for chunk in audio_reader.read():
# print(chunk.text)
# print(chunk.word_count)
# print(chunk.document_id)
# print(chunk.chunk_id)
# print(chunk.chunk_index)
# print(chunk.cut_type)
# print(chunk.pages)
# print("----")

View file

@ -0,0 +1,109 @@
# import pdfplumber
import logging
from uuid import uuid5, NAMESPACE_OID
from typing import Optional
from pypdf import PdfReader as pypdf_PdfReader
from cognee.modules.data.chunking import chunk_by_paragraph
from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk
from .Document import Document
class PdfReader():
    """Streams a PDF file as DocumentChunk batches.

    Fix: the underlying pypdf stream is now closed in ``finally`` blocks, so
    it is released even when the consumer abandons the generator early or an
    exception is raised mid-iteration (previously the close call was only
    reached after full exhaustion).
    """
    id: str
    file_path: str

    def __init__(self, id: str, file_path: str):
        self.id = id
        self.file_path = file_path

    def get_number_of_pages(self):
        """Return the PDF's page count, closing the stream afterwards."""
        file = pypdf_PdfReader(self.file_path)
        try:
            return file.get_num_pages()
        finally:
            file.stream.close()

    def read(self, max_chunk_size: Optional[int] = 1024):
        """Yield DocumentChunk objects, batching paragraphs per page.

        Paragraph chunks from chunk_by_paragraph are accumulated until the
        next one would exceed ``max_chunk_size`` words; then a DocumentChunk
        covering the accumulated pages is emitted and a new batch starts.
        """
        chunk_index = 0
        chunk_size = 0
        chunked_pages = []
        paragraph_chunks = []

        file = pypdf_PdfReader(self.file_path)
        try:
            for (page_index, page) in enumerate(file.pages):
                page_text = page.extract_text()
                chunked_pages.append(page_index)

                for chunk_data in chunk_by_paragraph(page_text, max_chunk_size, batch_paragraphs = True):
                    if chunk_size + chunk_data["word_count"] <= max_chunk_size:
                        paragraph_chunks.append(chunk_data)
                        chunk_size += chunk_data["word_count"]
                    else:
                        if len(paragraph_chunks) == 0:
                            # Oversized single paragraph: emit it on its own.
                            yield DocumentChunk(
                                text = chunk_data["text"],
                                word_count = chunk_data["word_count"],
                                document_id = str(self.id),
                                chunk_id = str(chunk_data["chunk_id"]),
                                chunk_index = chunk_index,
                                cut_type = chunk_data["cut_type"],
                                pages = [page_index],
                            )
                            paragraph_chunks = []
                            chunk_size = 0
                        else:
                            # Emit the accumulated batch, then start a new one
                            # seeded with the chunk that did not fit.
                            chunk_text = " ".join(chunk["text"] for chunk in paragraph_chunks)
                            yield DocumentChunk(
                                text = chunk_text,
                                word_count = chunk_size,
                                document_id = str(self.id),
                                chunk_id = str(uuid5(NAMESPACE_OID, f"{str(self.id)}-{chunk_index}")),
                                chunk_index = chunk_index,
                                cut_type = paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
                                pages = chunked_pages,
                            )
                            chunked_pages = [page_index]
                            paragraph_chunks = [chunk_data]
                            chunk_size = chunk_data["word_count"]

                        chunk_index += 1

            if len(paragraph_chunks) > 0:
                # Flush the final partial batch.
                yield DocumentChunk(
                    text = " ".join(chunk["text"] for chunk in paragraph_chunks),
                    word_count = chunk_size,
                    document_id = str(self.id),
                    chunk_id = str(uuid5(NAMESPACE_OID, f"{str(self.id)}-{chunk_index}")),
                    chunk_index = chunk_index,
                    cut_type = paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
                    pages = chunked_pages,
                )
        finally:
            file.stream.close()
class PdfDocument(Document):
    """PDF-backed document; the page count is resolved eagerly on creation."""
    type: str = "pdf"
    title: str
    num_pages: int
    file_path: str

    def __init__(self, title: str, file_path: str):
        # Deterministic id so the same title always maps to the same node.
        self.id = uuid5(NAMESPACE_OID, title)
        self.title = title
        self.file_path = file_path
        logging.debug("file_path: %s", self.file_path)
        self.num_pages = PdfReader(self.id, self.file_path).get_number_of_pages()

    def get_reader(self) -> PdfReader:
        """Return a fresh reader over this document's file."""
        logging.debug("file_path: %s", self.file_path)
        return PdfReader(self.id, self.file_path)

    def to_dict(self) -> dict:
        """Graph-node payload for this document."""
        return {
            "id": str(self.id),
            "type": self.type,
            "title": self.title,
            "num_pages": self.num_pages,
            "file_path": self.file_path,
        }

View file

@ -0,0 +1,112 @@
from uuid import uuid5, NAMESPACE_OID
from typing import Optional
from cognee.modules.data.chunking import chunk_by_paragraph
from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk
from .Document import Document
class TextReader():
    """Streams a plain-text file as DocumentChunk batches.

    Fix: the 1024-character window reader previously stopped at the first
    window that was *whitespace-only*, silently truncating the rest of the
    file; it now stops only at end of file and merely skips whitespace-only
    windows.
    """
    id: str
    file_path: str

    def __init__(self, id: str, file_path: str):
        self.id = id
        self.file_path = file_path

    def get_number_of_pages(self):
        # Plain text has no pagination; treat the whole file as one page.
        num_pages = 1
        return num_pages

    def read(self, max_chunk_size: Optional[int] = 1024):
        """Yield DocumentChunk objects, batching paragraphs across windows.

        The file is read in 1024-character windows; each window is chunked
        by paragraph and paragraph chunks are accumulated until the next one
        would exceed ``max_chunk_size`` words.
        """
        chunk_index = 0
        chunk_size = 0
        chunked_pages = []
        paragraph_chunks = []

        def read_text_chunks(file_path):
            with open(file_path, mode = "r", encoding = "utf-8") as file:
                while True:
                    text = file.read(1024)
                    if not text:
                        # Empty read means end of file.
                        break
                    if text.strip():
                        yield text

        page_index = 0

        for page_text in read_text_chunks(self.file_path):
            chunked_pages.append(page_index)
            page_index += 1

            for chunk_data in chunk_by_paragraph(page_text, max_chunk_size, batch_paragraphs = True):
                if chunk_size + chunk_data["word_count"] <= max_chunk_size:
                    paragraph_chunks.append(chunk_data)
                    chunk_size += chunk_data["word_count"]
                else:
                    if len(paragraph_chunks) == 0:
                        # Oversized single paragraph: emit it on its own.
                        yield DocumentChunk(
                            text = chunk_data["text"],
                            word_count = chunk_data["word_count"],
                            document_id = str(self.id),
                            chunk_id = str(chunk_data["chunk_id"]),
                            chunk_index = chunk_index,
                            cut_type = chunk_data["cut_type"],
                            pages = [page_index],
                        )
                        paragraph_chunks = []
                        chunk_size = 0
                    else:
                        # Emit the accumulated batch, then start a new one
                        # seeded with the chunk that did not fit.
                        chunk_text = " ".join(chunk["text"] for chunk in paragraph_chunks)
                        yield DocumentChunk(
                            text = chunk_text,
                            word_count = chunk_size,
                            document_id = str(self.id),
                            chunk_id = str(uuid5(NAMESPACE_OID, f"{str(self.id)}-{chunk_index}")),
                            chunk_index = chunk_index,
                            cut_type = paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
                            pages = chunked_pages,
                        )
                        chunked_pages = [page_index]
                        paragraph_chunks = [chunk_data]
                        chunk_size = chunk_data["word_count"]

                    chunk_index += 1

        if len(paragraph_chunks) > 0:
            # Flush the final partial batch.
            yield DocumentChunk(
                text = " ".join(chunk["text"] for chunk in paragraph_chunks),
                word_count = chunk_size,
                document_id = str(self.id),
                chunk_id = str(uuid5(NAMESPACE_OID, f"{str(self.id)}-{chunk_index}")),
                chunk_index = chunk_index,
                cut_type = paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
                pages = chunked_pages,
            )
class TextDocument(Document):
    """Plain-text document; always reported as a single page."""
    type: str = "text"
    title: str
    num_pages: int
    file_path: str

    def __init__(self, title: str, file_path: str):
        # Deterministic id so the same title always maps to the same node.
        self.id = uuid5(NAMESPACE_OID, title)
        self.title = title
        self.file_path = file_path
        self.num_pages = TextReader(self.id, self.file_path).get_number_of_pages()

    def get_reader(self) -> TextReader:
        """Return a fresh reader over this document's file."""
        return TextReader(self.id, self.file_path)

    def to_dict(self) -> dict:
        """Graph-node payload for this document."""
        return {
            "id": str(self.id),
            "type": self.type,
            "title": self.title,
            "num_pages": self.num_pages,
            "file_path": self.file_path,
        }

View file

@ -0,0 +1,2 @@
from .PdfDocument import PdfDocument
from .TextDocument import TextDocument

View file

@ -0,0 +1,13 @@
import os
from cognee.modules.data.processing.document_types.PdfDocument import PdfDocument
if __name__ == "__main__":
    # Manual smoke test: chunk a sample PDF and print each chunk.
    test_file_path = os.path.join(os.path.dirname(__file__), "artificial-inteligence.pdf")

    pdf_doc = PdfDocument("Test document.pdf", test_file_path)
    reader = pdf_doc.get_reader()

    for paragraph_data in reader.read():
        # read() yields DocumentChunk objects, so use attribute access;
        # the previous dict-style subscripting would raise TypeError.
        print(paragraph_data.word_count)
        print(paragraph_data.text)
        print(paragraph_data.cut_type)
        print("\n")

View file

@ -0,0 +1,25 @@
from cognee.infrastructure.databases.vector import get_vector_engine
from .chunk_types import DocumentChunk
async def filter_affected_chunks(data_chunks: list[DocumentChunk], collection_name: str) -> list[DocumentChunk]:
    """Return only the chunks that are new or whose text has changed.

    Looks up the chunks by id in the vector collection; a chunk is
    "affected" if it has no stored counterpart or its text differs from
    the stored payload.
    """
    vector_engine = get_vector_engine()

    if not await vector_engine.has_collection(collection_name):
        # If the collection doesn't exist, every chunk is new.
        return data_chunks

    existing_chunks = await vector_engine.retrieve(
        collection_name,
        [str(chunk.chunk_id) for chunk in data_chunks],
    )
    # Normalize ids to str on both sides: retrieval was keyed on
    # str(chunk_id), so a non-string id must not silently miss the map.
    existing_chunks_map = {str(chunk.id): chunk.payload for chunk in existing_chunks}

    affected_data_chunks = []
    for chunk in data_chunks:
        existing_payload = existing_chunks_map.get(str(chunk.chunk_id))
        if existing_payload is None or chunk.text != existing_payload["text"]:
            affected_data_chunks.append(chunk)

    return affected_data_chunks

View file

@ -0,0 +1,30 @@
from cognee.infrastructure.databases.vector import get_vector_engine
from .chunk_types import DocumentChunk
async def has_new_chunks(data_chunks: list[DocumentChunk], collection_name: str) -> bool:
    """Return True if any chunk is absent from the collection or changed.

    Fix: the return annotation previously claimed ``list[DocumentChunk]``
    although the function has always returned a boolean.
    """
    vector_engine = get_vector_engine()

    if not await vector_engine.has_collection(collection_name):
        # There is no collection created,
        # so no existing chunks, all chunks are new.
        return True

    existing_chunks = await vector_engine.retrieve(
        collection_name,
        [str(chunk.chunk_id) for chunk in data_chunks],
    )

    if len(existing_chunks) == 0:
        # If we don't find any existing chunk,
        # all chunks are new.
        return True

    # Normalize ids to str on both sides — retrieval was keyed on str(chunk_id).
    existing_chunks_map = {str(chunk.id): chunk.payload for chunk in existing_chunks}

    return any(
        str(chunk.chunk_id) not in existing_chunks_map
        or chunk.text != existing_chunks_map[str(chunk.chunk_id)]["text"]
        for chunk in data_chunks
    )

View file

@ -0,0 +1,41 @@
from cognee.infrastructure.databases.graph import get_graph_engine
from .document_types import Document
async def process_documents(documents: list[Document], parent_node_id: str = None):
    """Register documents as graph nodes, then stream their chunks.

    Adds a node per document that is not already in the graph and, when
    ``parent_node_id`` is given, links each new document to it with a
    "has_document" edge. Finally yields every DocumentChunk produced by
    each document's reader (max_chunk_size = 1024).
    """
    graph_engine = await get_graph_engine()

    nodes = []
    edges = []

    # Ensure the parent node exists before attaching documents to it.
    if parent_node_id and await graph_engine.extract_node(parent_node_id) is None:
        nodes.append((parent_node_id, {}))

    document_nodes = await graph_engine.extract_nodes([str(document.id) for document in documents])

    for (document_index, document) in enumerate(documents):
        # NOTE(review): `document_index in document_nodes` tests membership of
        # the integer index among the container's *values*, not whether the
        # index is in range — correct only if extract_nodes returns an
        # index-keyed mapping. Verify against the graph engine's contract.
        document_node = document_nodes[document_index] if document_index in document_nodes else None

        if document_node is None:
            nodes.append((str(document.id), document.to_dict()))

            if parent_node_id:
                edges.append((
                    parent_node_id,
                    str(document.id),
                    "has_document",
                    dict(
                        relationship_name = "has_document",
                        source_node_id = parent_node_id,
                        target_node_id = str(document.id),
                    ),
                ))

    if len(nodes) > 0:
        await graph_engine.add_nodes(nodes)
        await graph_engine.add_edges(edges)

    for document in documents:
        document_reader = document.get_reader()

        for document_chunk in document_reader.read(max_chunk_size = 1024):
            yield document_chunk

View file

@ -0,0 +1,29 @@
from cognee.infrastructure.databases.graph import get_graph_engine
# from cognee.infrastructure.databases.vector import get_vector_engine
from .chunk_types import DocumentChunk
async def remove_obsolete_chunks(data_chunks: list[DocumentChunk]) -> list[DocumentChunk]:
    """Delete orphaned chunk nodes from the graph and return the input.

    A chunk node is obsolete when it has no incoming "next_chunk" edge;
    afterwards any fully disconnected nodes are removed as well.
    """
    graph_engine = await get_graph_engine()

    stale_chunk_ids = []
    for document_id in {chunk.document_id for chunk in data_chunks}:
        for chunk_id in await graph_engine.get_successor_ids(document_id, edge_label = "has_chunk"):
            predecessors = await graph_engine.get_predecessor_ids(chunk_id, edge_label = "next_chunk")
            if not predecessors:
                stale_chunk_ids.append(chunk_id)

    if stale_chunk_ids:
        await graph_engine.delete_nodes(stale_chunk_ids)

    # Sweep anything the deletions left without edges.
    disconnected_nodes = await graph_engine.get_disconnected_nodes()
    if disconnected_nodes:
        await graph_engine.delete_nodes(disconnected_nodes)

    return data_chunks

View file

@ -8,7 +8,7 @@ def classify(data: Union[str, BinaryIO], filename: str = None):
return TextData(data)
if isinstance(data, BufferedReader):
return BinaryData(data)
return BinaryData(data, data.name.split("/")[-1] if data.name else filename)
if hasattr(data, "file"):
return BinaryData(data.file, filename)

View file

@ -17,7 +17,7 @@ class BinaryData(IngestionData):
def get_identifier(self):
metadata = self.get_metadata()
return metadata["mime_type"] + "_" + "|".join(metadata["keywords"])
return self.name + "_" + metadata["mime_type"]
def get_metadata(self):
self.ensure_metadata()

View file

@ -13,7 +13,7 @@ class TextData(IngestionData):
self.data = data
def get_identifier(self):
keywords = self.get_metadata()["keywords"]
keywords = extract_keywords(self.data)
return "text/plain" + "_" + "|".join(keywords)
@ -24,7 +24,7 @@ class TextData(IngestionData):
def ensure_metadata(self):
if self.metadata is None:
self.metadata = dict(keywords = extract_keywords(self.data))
self.metadata = {}
def get_data(self):
return self.data

View file

@ -17,8 +17,8 @@ def save_data_to_file(data: Union[str, BinaryIO], dataset_name: str, filename: s
file_metadata = classified_data.get_metadata()
if "name" not in file_metadata or file_metadata["name"] is None:
letters = string.ascii_lowercase
random_string = ''.join(random.choice(letters) for _ in range(32))
file_metadata["name"] = "file" + random_string
random_string = "".join(random.choice(letters) for _ in range(32))
file_metadata["name"] = "text_" + random_string + ".txt"
file_name = file_metadata["name"]
LocalStorage(storage_path).store(file_name, classified_data.get_data())

View file

@ -0,0 +1,18 @@
from uuid import UUID, uuid4
from typing import Optional
from pydantic import BaseModel
from .models.Task import Task
class PipelineConfig(BaseModel):
    """Configuration for constructing a Pipeline."""
    batch_count: int = 10
    # Default to None so the field is optional under both pydantic v1 and v2
    # (v2 treats `Optional[str]` without a default as a *required* field).
    description: Optional[str] = None
class Pipeline():
    """In-memory pipeline: a named, described, ordered list of tasks."""
    id: UUID
    name: str
    description: str
    tasks: list[Task]

    def __init__(self, name: str, pipeline_config: PipelineConfig):
        # Per-instance id and task list: the previous class-level defaults
        # evaluated uuid4() once (every pipeline shared the same id) and
        # shared one mutable task list across all instances.
        self.id = uuid4()
        self.name = name
        self.description = pipeline_config.description
        self.tasks = []

View file

@ -0,0 +1,2 @@
from .operations.run_tasks import run_tasks
from .operations.run_parallel import run_tasks_parallel

View file

@ -0,0 +1,22 @@
from typing import List
from uuid import uuid4
from datetime import datetime, timezone
from sqlalchemy import Column, UUID, DateTime, String, Text
from sqlalchemy.orm import relationship, Mapped
from cognee.infrastructure.databases.relational import ModelBase
from .PipelineTask import PipelineTask
class Pipeline(ModelBase):
    """ORM model for a pipeline (table ``pipelines``)."""
    __tablename__ = "pipelines"

    # Pass callables: `uuid4()` / `datetime.now(...)` would be evaluated once
    # at import time and the same value reused for every row.
    id = Column(UUID, primary_key = True, default = uuid4)
    name = Column(String)
    description = Column(Text, nullable = True)
    created_at = Column(DateTime, default = lambda: datetime.now(timezone.utc))
    updated_at = Column(DateTime, onupdate = lambda: datetime.now(timezone.utc))

    # `tasks = Mapped[...] = relationship(...)` was a chained assignment that
    # attempted `Mapped.__setitem__` at class creation; an annotation (`:`)
    # is what's intended.
    tasks: Mapped[List["Task"]] = relationship(
        secondary = PipelineTask.__tablename__,
        # back_populates must name the attribute on the Task class, which is
        # currently `datasets`. NOTE(review): confirm the pairing.
        back_populates = "datasets",
    )

View file

@ -0,0 +1,14 @@
from uuid import uuid4
from datetime import datetime, timezone
from sqlalchemy import Column, DateTime, UUID, ForeignKey
from cognee.infrastructure.databases.relational import ModelBase
class PipelineTask(ModelBase):
    """Association table linking pipelines to tasks (many-to-many)."""
    __tablename__ = "pipeline_task"

    # Pass callables: `uuid4()` / `datetime.now(...)` would be evaluated once
    # at import time and the same value reused for every row.
    id = Column(UUID, primary_key = True, default = uuid4)
    created_at = Column(DateTime, default = lambda: datetime.now(timezone.utc))

    # Foreign keys must reference the actual table names declared on the
    # related models ("pipelines" / "tasks"), not singular class-style names.
    pipeline_id = Column("pipeline", UUID, ForeignKey("pipelines.id"), primary_key = True)
    task_id = Column("task", UUID, ForeignKey("tasks.id"), primary_key = True)

View file

@ -0,0 +1,24 @@
from uuid import uuid4
from typing import List
from datetime import datetime, timezone
from sqlalchemy.orm import relationship, Mapped
from sqlalchemy import Column, String, DateTime, UUID, Text
from cognee.infrastructure.databases.relational import ModelBase
from .PipelineTask import PipelineTask
class Task(ModelBase):
    """ORM model for a task (table ``tasks``)."""
    __tablename__ = "tasks"

    # Pass callables: `uuid4()` / `datetime.now(...)` would be evaluated once
    # at import time and the same value reused for every row.
    id = Column(UUID, primary_key = True, default = uuid4)
    name = Column(String)
    description = Column(Text, nullable = True)
    executable = Column(Text)
    created_at = Column(DateTime, default = lambda: datetime.now(timezone.utc))
    updated_at = Column(DateTime, onupdate = lambda: datetime.now(timezone.utc))

    # NOTE(review): attribute is named `datasets` but holds Pipeline rows;
    # kept as-is for interface compatibility. back_populates must name the
    # attribute on Pipeline, which is `tasks` ("task" named nothing).
    datasets: Mapped[List["Pipeline"]] = relationship(
        secondary = PipelineTask.__tablename__,
        back_populates = "tasks",
    )

View file

@ -0,0 +1,14 @@
import asyncio
from cognee.shared.utils import render_graph
from cognee.infrastructure.databases.graph import get_graph_engine
if __name__ == "__main__":
    async def main():
        # Render the current knowledge graph and print a shareable URL.
        graph_client = await get_graph_engine()
        graph_url = await render_graph(graph_client.graph)
        print(graph_url)

    asyncio.run(main())

View file

@ -0,0 +1,34 @@
import asyncio
from cognee.modules.pipelines.operations.run_tasks import run_tasks
from cognee.modules.pipelines.tasks.Task import Task
async def main():
    """Demo: run a four-stage pipeline over the number 10 and print results."""

    # Inner function names are kept as-is: run_tasks logs executable.__name__.
    def number_generator(num):
        yield from range(1, num + 1)

    async def add_one(num):
        yield num + 1

    async def multiply_by_two(nums):
        for value in nums:
            yield value * 2

    async def add_one_to_batched_data(num):
        yield num + 1

    pipeline = run_tasks([
        Task(number_generator, task_config = {"batch_size": 1}),
        Task(add_one, task_config = {"batch_size": 5}),
        Task(multiply_by_two, task_config = {"batch_size": 1}),
        Task(add_one_to_batched_data),
    ], 10)

    async for result in pipeline:
        print("\n")
        print(result)
        print("\n")

if __name__ == "__main__":
    asyncio.run(main())

View file

@ -0,0 +1,4 @@
from ..models import Pipeline, Task
def add_task(pipeline: Pipeline, task: Task):
    """Register *task* on *pipeline* by appending it to the task list."""
    pipeline_tasks = pipeline.tasks
    pipeline_tasks.append(task)

View file

@ -0,0 +1,12 @@
from typing import Any, Callable, Generator
import asyncio
from ..tasks.Task import Task
def run_tasks_parallel(tasks: list) -> Task:
    """Wrap *tasks* into a single Task that runs them all concurrently.

    The returned Task's executable schedules every wrapped task with
    asyncio.create_task, awaits them via asyncio.gather, and returns the
    last result.

    Fixes: the parameter annotation was the list literal ``[Task]`` (not a
    valid type hint) and the return annotation claimed a Callable although
    a Task has always been returned.
    """
    async def parallel_run(*args, **kwargs):
        parallel_tasks = [asyncio.create_task(task.run(*args, **kwargs)) for task in tasks]
        results = await asyncio.gather(*parallel_tasks)
        # NOTE(review): with exactly one task the sole result is discarded
        # (`> 1`); `results[-1] if results else []` may be the intended
        # behavior — confirm before changing.
        return results[len(results) - 1] if len(results) > 1 else []

    return Task(parallel_run)

View file

@ -0,0 +1,90 @@
import inspect
import logging
from ..tasks.Task import Task
logger = logging.getLogger("run_tasks(tasks: [Task], data)")

async def run_tasks(tasks: list, data):
    """Recursively execute a chain of tasks over *data*, yielding results.

    The head task runs on *data*. Generator-style tasks accumulate their
    outputs into batches of ``task_config["batch_size"]`` before the
    remaining tasks are applied to each batch; a batch of size 1 is
    unwrapped and passed as a single value. Coroutine and plain-function
    tasks pass their single result straight through.

    Fixes vs. previous version: the sync-generator branch logged "Running"
    instead of "Finished" on completion; dead `next_task` computation (with
    an off-by-one `> 1` condition) removed; callables that are not plain
    functions (e.g. functools.partial, callable objects) are no longer
    silently dropped.
    """
    if len(tasks) == 0:
        # No tasks left: emit the data as a pipeline result.
        yield data
        return

    running_task = tasks[0]
    batch_size = running_task.task_config["batch_size"]
    leftover_tasks = tasks[1:]

    if inspect.isasyncgenfunction(running_task.executable):
        logger.info("Running async generator task: `%s`", running_task.executable.__name__)
        try:
            results = []
            async for partial_result in running_task.run(data):
                results.append(partial_result)
                if len(results) == batch_size:
                    # Unwrap single-item batches so downstream tasks see a value.
                    async for result in run_tasks(leftover_tasks, results[0] if batch_size == 1 else results):
                        yield result
                    results = []
            if len(results) > 0:
                # Flush the final, possibly short batch.
                async for result in run_tasks(leftover_tasks, results):
                    yield result
                results = []
            logger.info("Finished async generator task: `%s`", running_task.executable.__name__)
        except Exception as error:
            logger.error(
                "Error occurred while running async generator task: `%s`\n%s\n",
                running_task.executable.__name__,
                str(error),
                exc_info = True,
            )
            raise error
    elif inspect.isgeneratorfunction(running_task.executable):
        logger.info("Running generator task: `%s`", running_task.executable.__name__)
        try:
            results = []
            for partial_result in running_task.run(data):
                results.append(partial_result)
                if len(results) == batch_size:
                    async for result in run_tasks(leftover_tasks, results[0] if batch_size == 1 else results):
                        yield result
                    results = []
            if len(results) > 0:
                # Flush the final, possibly short batch.
                async for result in run_tasks(leftover_tasks, results):
                    yield result
                results = []
            logger.info("Finished generator task: `%s`", running_task.executable.__name__)
        except Exception as error:
            logger.error(
                "Error occurred while running generator task: `%s`\n%s\n",
                running_task.executable.__name__,
                str(error),
                exc_info = True,
            )
            raise error
    elif inspect.iscoroutinefunction(running_task.executable):
        task_result = await running_task.run(data)
        async for result in run_tasks(leftover_tasks, task_result):
            yield result
    elif inspect.isfunction(running_task.executable):
        task_result = running_task.run(data)
        async for result in run_tasks(leftover_tasks, task_result):
            yield result
    else:
        # Fall back to treating any other callable as a plain synchronous
        # function instead of silently dropping the rest of the pipeline.
        task_result = running_task.run(data)
        async for result in run_tasks(leftover_tasks, task_result):
            yield result

View file

@ -0,0 +1,32 @@
from typing import Union, Callable, Any, Coroutine, Generator, AsyncGenerator
class Task():
    """A unit of work in a pipeline.

    Wraps a callable (plain function, coroutine function, generator or
    async generator) together with default positional/keyword arguments
    and per-task configuration. ``task_config["batch_size"]`` controls how
    many results accumulate before downstream tasks run.
    """
    executable: Union[
        Callable[..., Any],
        Callable[..., Coroutine[Any, Any, Any]],
        Generator[Any, Any, Any],
        AsyncGenerator[Any, Any],
    ]
    task_config: dict[str, Any]
    default_params: dict[str, Any]

    def __init__(self, executable, *args, task_config = None, **kwargs):
        self.executable = executable
        self.default_params = {
            "args": args,
            "kwargs": kwargs,
        }
        # Copy the caller's dict: previously all instances without an
        # explicit task_config shared one mutable class-level dict, and a
        # provided dict was mutated in place when "batch_size" was missing.
        self.task_config = dict(task_config) if task_config is not None else {}
        self.task_config.setdefault("batch_size", 1)

    def run(self, *args, **kwargs):
        """Invoke the executable with defaults merged with call-site args.

        Call-site keyword arguments override the constructor defaults.
        """
        combined_args = self.default_params["args"] + args
        combined_kwargs = { **self.default_params["kwargs"], **kwargs }
        return self.executable(*combined_args, **combined_kwargs)

Some files were not shown because too many files have changed in this diff Show more