feat: pipelines and tasks (#119)
* feat: simple graph pipeline * feat: implement incremental graph generation * fix: various bug fixes * fix: upgrade weaviate-client --------- Co-authored-by: Vasilije <8619304+Vasilije1990@users.noreply.github.com>
This commit is contained in:
parent
9a57659266
commit
14555a25d0
121 changed files with 4409 additions and 1779 deletions
10
.github/workflows/test_common.yml
vendored
10
.github/workflows/test_common.yml
vendored
|
|
@ -30,9 +30,9 @@ jobs:
|
|||
# Test all python versions on ubuntu only
|
||||
include:
|
||||
- python-version: "3.9.x"
|
||||
os: "ubuntu-latest"
|
||||
os: "ubuntu-22.04"
|
||||
- python-version: "3.10.x"
|
||||
os: "ubuntu-latest"
|
||||
os: "ubuntu-22.04"
|
||||
# - python-version: "3.12.x"
|
||||
# os: "ubuntu-latest"
|
||||
|
||||
|
|
@ -90,6 +90,12 @@ jobs:
|
|||
ENV: 'dev'
|
||||
run: poetry run python ./cognee/tests/test_library.py
|
||||
|
||||
- name: Clean up disk space
|
||||
run: |
|
||||
sudo rm -rf ~/.cache
|
||||
sudo rm -rf /tmp/*
|
||||
df -h
|
||||
|
||||
- name: Build with Poetry
|
||||
run: poetry build
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ export interface Data {
|
|||
name: string;
|
||||
filePath: string;
|
||||
mimeType: string;
|
||||
keywords: string[];
|
||||
}
|
||||
|
||||
interface DatasetLike {
|
||||
|
|
@ -113,9 +112,6 @@ export default function DataView({ datasetId, data, onClose, onDataAdd }: DataVi
|
|||
<td>
|
||||
<Text>{dataItem.mimeType}</Text>
|
||||
</td>
|
||||
<td>
|
||||
<Text>{dataItem.keywords.join(", ")}</Text>
|
||||
</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
export default function addData(dataset: { id: string }, files: File[]) {
|
||||
const formData = new FormData();
|
||||
files.forEach((file) => {
|
||||
formData.append('data', file, file.name);
|
||||
})
|
||||
formData.append('datasetId', dataset.id);
|
||||
const file = files[0];
|
||||
formData.append('data', file, file.name);
|
||||
|
||||
return fetch('http://0.0.0.0:8000/add', {
|
||||
method: 'POST',
|
||||
|
|
|
|||
|
|
@ -33,8 +33,8 @@ export default function SearchView() {
|
|||
value: 'ADJACENT',
|
||||
label: 'Look for graph node\'s neighbors',
|
||||
}, {
|
||||
value: 'CATEGORIES',
|
||||
label: 'Search by categories (Comma separated categories)',
|
||||
value: 'TRAVERSE',
|
||||
label: 'Traverse through the graph and get knowledge',
|
||||
}];
|
||||
const [searchType, setSearchType] = useState(searchOptions[0]);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from .api.v1.config.config import config
|
||||
from .api.v1.add.add import add
|
||||
from .api.v1.cognify.cognify import cognify
|
||||
from .api.v1.cognify.cognify_v2 import cognify
|
||||
from .api.v1.datasets.datasets import datasets
|
||||
from .api.v1.search.search import search, SearchType
|
||||
from .api.v1.prune import prune
|
||||
|
|
|
|||
|
|
@ -2,9 +2,10 @@
|
|||
import os
|
||||
import aiohttp
|
||||
import uvicorn
|
||||
import asyncio
|
||||
import json
|
||||
import asyncio
|
||||
import logging
|
||||
import sentry_sdk
|
||||
from typing import Dict, Any, List, Union, Optional, Literal
|
||||
from typing_extensions import Annotated
|
||||
from fastapi import FastAPI, HTTPException, Form, File, UploadFile, Query
|
||||
|
|
@ -19,7 +20,14 @@ logging.basicConfig(
|
|||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI(debug = True)
|
||||
if os.getenv("ENV") == "prod":
|
||||
sentry_sdk.init(
|
||||
dsn = os.getenv("SENTRY_REPORTING_URL"),
|
||||
traces_sample_rate = 1.0,
|
||||
profiles_sample_rate = 1.0,
|
||||
)
|
||||
|
||||
app = FastAPI(debug = os.getenv("ENV") != "prod")
|
||||
|
||||
origins = [
|
||||
"http://frontend:3000",
|
||||
|
|
@ -69,10 +77,10 @@ async def delete_dataset(dataset_id: str):
|
|||
@app.get("/datasets/{dataset_id}/graph", response_model=list)
|
||||
async def get_dataset_graph(dataset_id: str):
|
||||
from cognee.shared.utils import render_graph
|
||||
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
|
||||
try:
|
||||
graph_client = await get_graph_client()
|
||||
graph_client = await get_graph_engine()
|
||||
graph_url = await render_graph(graph_client.graph)
|
||||
|
||||
return JSONResponse(
|
||||
|
|
@ -95,7 +103,6 @@ async def get_dataset_data(dataset_id: str):
|
|||
dict(
|
||||
id=data["id"],
|
||||
name=f"{data['name']}.{data['extension']}",
|
||||
keywords=data["keywords"].split("|"),
|
||||
filePath=data["file_path"],
|
||||
mimeType=data["mime_type"],
|
||||
)
|
||||
|
|
@ -129,8 +136,8 @@ class AddPayload(BaseModel):
|
|||
|
||||
@app.post("/add", response_model=dict)
|
||||
async def add(
|
||||
data: List[UploadFile],
|
||||
datasetId: str = Form(...),
|
||||
data: List[UploadFile] = File(...),
|
||||
):
|
||||
""" This endpoint is responsible for adding data to the graph."""
|
||||
from cognee.api.v1.add import add as cognee_add
|
||||
|
|
@ -177,17 +184,17 @@ class CognifyPayload(BaseModel):
|
|||
@app.post("/cognify", response_model=dict)
|
||||
async def cognify(payload: CognifyPayload):
|
||||
""" This endpoint is responsible for the cognitive processing of the content."""
|
||||
from cognee.api.v1.cognify.cognify import cognify as cognee_cognify
|
||||
from cognee.api.v1.cognify.cognify_v2 import cognify as cognee_cognify
|
||||
try:
|
||||
await cognee_cognify(payload.datasets)
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content="OK"
|
||||
status_code = 200,
|
||||
content = "OK"
|
||||
)
|
||||
except Exception as error:
|
||||
return JSONResponse(
|
||||
status_code=409,
|
||||
content={"error": str(error)}
|
||||
status_code = 409,
|
||||
content = {"error": str(error)}
|
||||
)
|
||||
|
||||
class SearchPayload(BaseModel):
|
||||
|
|
@ -255,30 +262,13 @@ def start_api_server(host: str = "0.0.0.0", port: int = 8000):
|
|||
try:
|
||||
logger.info("Starting server at %s:%s", host, port)
|
||||
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.infrastructure.databases.relational import get_relationaldb_config
|
||||
from cognee.infrastructure.databases.vector import get_vectordb_config
|
||||
from cognee.infrastructure.databases.graph import get_graph_config
|
||||
|
||||
cognee_directory_path = os.path.abspath(".cognee_system")
|
||||
databases_directory_path = os.path.join(cognee_directory_path, "databases")
|
||||
|
||||
relational_config = get_relationaldb_config()
|
||||
relational_config.db_path = databases_directory_path
|
||||
relational_config.create_engine()
|
||||
|
||||
vector_config = get_vectordb_config()
|
||||
vector_config.vector_db_url = os.path.join(databases_directory_path, "cognee.lancedb")
|
||||
|
||||
graph_config = get_graph_config()
|
||||
graph_config.graph_file_path = os.path.join(databases_directory_path, "cognee.graph")
|
||||
|
||||
base_config = get_base_config()
|
||||
data_directory_path = os.path.abspath(".data_storage")
|
||||
base_config.data_root_directory = data_directory_path
|
||||
|
||||
from cognee.modules.data.deletion import prune_system
|
||||
asyncio.run(prune_system())
|
||||
from cognee.modules.data.deletion import prune_system, prune_data
|
||||
asyncio.run(prune_data())
|
||||
asyncio.run(prune_system(metadata = True))
|
||||
|
||||
uvicorn.run(app, host = host, port = port)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -97,7 +97,6 @@ async def add_files(file_paths: List[str], dataset_name: str):
|
|||
"file_path": file_metadata["file_path"],
|
||||
"extension": file_metadata["extension"],
|
||||
"mime_type": file_metadata["mime_type"],
|
||||
"keywords": "|".join(file_metadata["keywords"]),
|
||||
}
|
||||
|
||||
run_info = pipeline.run(
|
||||
|
|
|
|||
|
|
@ -6,21 +6,19 @@ import nltk
|
|||
from asyncio import Lock
|
||||
from nltk.corpus import stopwords
|
||||
|
||||
from cognee.infrastructure.data.chunking.LangchainChunkingEngine import LangchainChunkEngine
|
||||
from cognee.infrastructure.data.chunking.get_chunking_engine import get_chunk_engine
|
||||
from cognee.infrastructure.databases.graph.config import get_graph_config
|
||||
from cognee.infrastructure.databases.vector.embeddings.LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine
|
||||
from cognee.modules.cognify.graph.add_node_connections import group_nodes_by_layer, \
|
||||
graph_ready_output, connect_nodes_in_graph
|
||||
from cognee.modules.cognify.graph.add_data_chunks import add_data_chunks, add_data_chunks_basic_rag
|
||||
from cognee.modules.cognify.graph.add_data_chunks import add_data_chunks
|
||||
from cognee.modules.cognify.graph.add_document_node import add_document_node
|
||||
from cognee.modules.cognify.graph.add_classification_nodes import add_classification_nodes
|
||||
from cognee.modules.cognify.graph.add_cognitive_layer_graphs import add_cognitive_layer_graphs
|
||||
from cognee.modules.cognify.graph.add_summary_nodes import add_summary_nodes
|
||||
from cognee.modules.cognify.llm.resolve_cross_graph_references import resolve_cross_graph_references
|
||||
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.modules.cognify.graph.add_cognitive_layers import add_cognitive_layers
|
||||
# from cognee.modules.cognify.graph.initialize_graph import initialize_graph
|
||||
from cognee.infrastructure.files.utils.guess_file_type import guess_file_type, FileTypeException
|
||||
from cognee.infrastructure.files.utils.extract_text_from_file import extract_text_from_file
|
||||
from cognee.modules.data.get_content_categories import get_content_categories
|
||||
|
|
@ -49,9 +47,7 @@ async def cognify(datasets: Union[str, List[str]] = None):
|
|||
stopwords.ensure_loaded()
|
||||
create_task_status_table()
|
||||
|
||||
# graph_config = get_graph_config()
|
||||
# graph_db_type = graph_config.graph_engine
|
||||
graph_client = await get_graph_client()
|
||||
graph_client = await get_graph_engine()
|
||||
|
||||
relational_config = get_relationaldb_config()
|
||||
db_engine = relational_config.database_engine
|
||||
|
|
@ -89,8 +85,8 @@ async def cognify(datasets: Union[str, List[str]] = None):
|
|||
|
||||
added_datasets = db_engine.get_datasets()
|
||||
|
||||
dataset_files = []
|
||||
# datasets is a dataset name string
|
||||
dataset_files = []
|
||||
dataset_name = datasets.replace(".", "_").replace(" ", "_")
|
||||
|
||||
for added_dataset in added_datasets:
|
||||
|
|
@ -145,10 +141,8 @@ async def cognify(datasets: Union[str, List[str]] = None):
|
|||
batch_size = 20
|
||||
file_count = 0
|
||||
files_batch = []
|
||||
from cognee.infrastructure.databases.graph.config import get_graph_config
|
||||
graph_config = get_graph_config()
|
||||
graph_topology = graph_config.graph_model
|
||||
|
||||
graph_config = get_graph_config()
|
||||
|
||||
if graph_config.infer_graph_topology and graph_config.graph_topology_task:
|
||||
from cognee.modules.topology.topology import TopologyEngine
|
||||
|
|
@ -173,8 +167,8 @@ async def cognify(datasets: Union[str, List[str]] = None):
|
|||
else:
|
||||
document_id = await add_document_node(
|
||||
graph_client,
|
||||
parent_node_id=file_metadata['id'],
|
||||
document_metadata=file_metadata,
|
||||
parent_node_id = file_metadata['id'],
|
||||
document_metadata = file_metadata,
|
||||
)
|
||||
|
||||
files_batch.append((dataset_name, file_metadata, document_id))
|
||||
|
|
@ -196,7 +190,7 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
|
|||
print(f"Processing chunk ({chunk_id}) from document ({file_metadata['id']}).")
|
||||
|
||||
graph_config = get_graph_config()
|
||||
graph_client = await get_graph_client()
|
||||
graph_client = await get_graph_engine()
|
||||
graph_topology = graph_config.graph_model
|
||||
|
||||
if graph_topology == SourceCodeGraph:
|
||||
|
|
@ -206,8 +200,6 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
|
|||
else:
|
||||
classified_categories = [{"data_type": "text", "category_name": "Unclassified text"}]
|
||||
|
||||
# await add_label_nodes(graph_client, document_id, chunk_id, file_metadata["keywords"].split("|"))
|
||||
|
||||
await add_classification_nodes(
|
||||
graph_client,
|
||||
parent_node_id = document_id,
|
||||
|
|
@ -271,10 +263,10 @@ if __name__ == "__main__":
|
|||
# await prune.prune_system()
|
||||
# #
|
||||
# from cognee.api.v1.add import add
|
||||
# data_directory_path = os.path.abspath("../../../.data")
|
||||
# data_directory_path = os.path.abspath("../../.data")
|
||||
# # print(data_directory_path)
|
||||
# # config.data_root_directory(data_directory_path)
|
||||
# # cognee_directory_path = os.path.abspath("../.cognee_system")
|
||||
# # cognee_directory_path = os.path.abspath(".cognee_system")
|
||||
# # config.system_root_directory(cognee_directory_path)
|
||||
#
|
||||
# await add("data://" +data_directory_path, "example")
|
||||
|
|
|
|||
139
cognee/api/v1/cognify/cognify_v2.py
Normal file
139
cognee/api/v1/cognify/cognify_v2.py
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from typing import Union
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_config
|
||||
from cognee.modules.cognify.config import get_cognify_config
|
||||
from cognee.infrastructure.databases.relational.config import get_relationaldb_config
|
||||
from cognee.modules.data.processing.document_types.AudioDocument import AudioDocument
|
||||
from cognee.modules.data.processing.document_types.ImageDocument import ImageDocument
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
from cognee.modules.data.processing.document_types import PdfDocument, TextDocument
|
||||
from cognee.modules.cognify.vector import save_data_chunks
|
||||
from cognee.modules.data.processing.process_documents import process_documents
|
||||
from cognee.modules.classification.classify_text_chunks import classify_text_chunks
|
||||
from cognee.modules.data.extraction.data_summary.summarize_text_chunks import summarize_text_chunks
|
||||
from cognee.modules.data.processing.filter_affected_chunks import filter_affected_chunks
|
||||
from cognee.modules.data.processing.remove_obsolete_chunks import remove_obsolete_chunks
|
||||
from cognee.modules.data.extraction.knowledge_graph.expand_knowledge_graph import expand_knowledge_graph
|
||||
from cognee.modules.data.extraction.knowledge_graph.establish_graph_topology import establish_graph_topology
|
||||
from cognee.modules.pipelines.tasks.Task import Task
|
||||
from cognee.modules.pipelines import run_tasks, run_tasks_parallel
|
||||
from cognee.modules.tasks import create_task_status_table, update_task_status, get_task_status
|
||||
|
||||
logger = logging.getLogger("cognify.v2")
|
||||
|
||||
update_status_lock = asyncio.Lock()
|
||||
|
||||
async def cognify(datasets: Union[str, list[str]] = None, root_node_id: str = None):
|
||||
relational_config = get_relationaldb_config()
|
||||
db_engine = relational_config.database_engine
|
||||
create_task_status_table()
|
||||
|
||||
if datasets is None or len(datasets) == 0:
|
||||
return await cognify(db_engine.get_datasets())
|
||||
|
||||
|
||||
async def run_cognify_pipeline(dataset_name: str, files: list[dict]):
|
||||
async with update_status_lock:
|
||||
task_status = get_task_status([dataset_name])
|
||||
|
||||
if dataset_name in task_status and task_status[dataset_name] == "DATASET_PROCESSING_STARTED":
|
||||
logger.info(f"Dataset {dataset_name} is being processed.")
|
||||
return
|
||||
|
||||
update_task_status(dataset_name, "DATASET_PROCESSING_STARTED")
|
||||
try:
|
||||
cognee_config = get_cognify_config()
|
||||
graph_config = get_graph_config()
|
||||
root_node_id = None
|
||||
|
||||
if graph_config.infer_graph_topology and graph_config.graph_topology_task:
|
||||
from cognee.modules.topology.topology import TopologyEngine
|
||||
topology_engine = TopologyEngine(infer=graph_config.infer_graph_topology)
|
||||
root_node_id = await topology_engine.add_graph_topology(files = files)
|
||||
elif graph_config.infer_graph_topology and not graph_config.infer_graph_topology:
|
||||
from cognee.modules.topology.topology import TopologyEngine
|
||||
topology_engine = TopologyEngine(infer=graph_config.infer_graph_topology)
|
||||
await topology_engine.add_graph_topology(graph_config.topology_file_path)
|
||||
elif not graph_config.graph_topology_task:
|
||||
root_node_id = "ROOT"
|
||||
|
||||
tasks = [
|
||||
Task(process_documents, parent_node_id = root_node_id, task_config = { "batch_size": 10 }), # Classify documents and save them as a nodes in graph db, extract text chunks based on the document type
|
||||
Task(establish_graph_topology, topology_model = KnowledgeGraph), # Set the graph topology for the document chunk data
|
||||
Task(expand_knowledge_graph, graph_model = KnowledgeGraph), # Generate knowledge graphs from the document chunks and attach it to chunk nodes
|
||||
Task(filter_affected_chunks, collection_name = "chunks"), # Find all affected chunks, so we don't process unchanged chunks
|
||||
Task(
|
||||
save_data_chunks,
|
||||
collection_name = "chunks",
|
||||
), # Save the document chunks in vector db and as nodes in graph db (connected to the document node and between each other)
|
||||
run_tasks_parallel([
|
||||
Task(
|
||||
summarize_text_chunks,
|
||||
summarization_model = cognee_config.summarization_model,
|
||||
collection_name = "chunk_summaries",
|
||||
), # Summarize the document chunks
|
||||
Task(
|
||||
classify_text_chunks,
|
||||
classification_model = cognee_config.classification_model,
|
||||
),
|
||||
]),
|
||||
Task(remove_obsolete_chunks), # Remove the obsolete document chunks.
|
||||
]
|
||||
|
||||
pipeline = run_tasks(tasks, [
|
||||
PdfDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "pdf" else
|
||||
AudioDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "audio" else
|
||||
ImageDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "image" else
|
||||
TextDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"])
|
||||
for file in files
|
||||
])
|
||||
|
||||
async for result in pipeline:
|
||||
print(result)
|
||||
|
||||
update_task_status(dataset_name, "DATASET_PROCESSING_FINISHED")
|
||||
except Exception as error:
|
||||
update_task_status(dataset_name, "DATASET_PROCESSING_ERROR")
|
||||
raise error
|
||||
|
||||
|
||||
existing_datasets = db_engine.get_datasets()
|
||||
|
||||
awaitables = []
|
||||
|
||||
|
||||
# dataset_files = []
|
||||
# dataset_name = datasets.replace(".", "_").replace(" ", "_")
|
||||
|
||||
# for added_dataset in existing_datasets:
|
||||
# if dataset_name in added_dataset:
|
||||
# dataset_files.append((added_dataset, db_engine.get_files_metadata(added_dataset)))
|
||||
|
||||
for dataset in datasets:
|
||||
if dataset in existing_datasets:
|
||||
# for file_metadata in files:
|
||||
# if root_node_id is None:
|
||||
# root_node_id=file_metadata['id']
|
||||
awaitables.append(run_cognify_pipeline(dataset, db_engine.get_files_metadata(dataset)))
|
||||
|
||||
return await asyncio.gather(*awaitables)
|
||||
|
||||
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# from cognee.api.v1.add import add
|
||||
# from cognee.api.v1.datasets.datasets import datasets
|
||||
#
|
||||
#
|
||||
# async def aa():
|
||||
# await add("TEXT ABOUT NLP AND MONKEYS")
|
||||
#
|
||||
# print(datasets.discover_datasets())
|
||||
#
|
||||
# return
|
||||
|
||||
|
||||
|
||||
# asyncio.run(cognify())
|
||||
|
|
@ -16,6 +16,9 @@ class config():
|
|||
relational_config.db_path = databases_directory_path
|
||||
relational_config.create_engine()
|
||||
|
||||
graph_config = get_graph_config()
|
||||
graph_config.graph_file_path = os.path.join(databases_directory_path, "cognee.graph")
|
||||
|
||||
vector_config = get_vectordb_config()
|
||||
if vector_config.vector_engine_provider == "lancedb":
|
||||
vector_config.vector_db_url = os.path.join(databases_directory_path, "cognee.lancedb")
|
||||
|
|
|
|||
|
|
@ -1,17 +1,13 @@
|
|||
from cognee.modules.data.deletion import prune_system
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.infrastructure.files.storage import LocalStorage
|
||||
from cognee.modules.data.deletion import prune_system, prune_data
|
||||
|
||||
class prune():
|
||||
@staticmethod
|
||||
async def prune_data():
|
||||
base_config = get_base_config()
|
||||
data_root_directory = base_config.data_root_directory
|
||||
LocalStorage.remove_all(data_root_directory)
|
||||
await prune_data()
|
||||
|
||||
@staticmethod
|
||||
async def prune_system(graph = True, vector = True):
|
||||
await prune_system(graph, vector)
|
||||
async def prune_system(graph = True, vector = True, metadata = False):
|
||||
await prune_system(graph, vector, metadata)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
|
|
|||
|
|
@ -6,19 +6,16 @@ from pydantic import BaseModel, field_validator
|
|||
|
||||
from cognee.modules.search.graph import search_cypher
|
||||
from cognee.modules.search.graph.search_adjacent import search_adjacent
|
||||
from cognee.modules.search.vector.search_similarity import search_similarity
|
||||
from cognee.modules.search.graph.search_categories import search_categories
|
||||
from cognee.modules.search.graph.search_neighbour import search_neighbour
|
||||
from cognee.modules.search.vector.search_traverse import search_traverse
|
||||
from cognee.modules.search.graph.search_summary import search_summary
|
||||
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
||||
from cognee.modules.search.graph.search_similarity import search_similarity
|
||||
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.infrastructure.databases.graph.config import get_graph_config
|
||||
|
||||
class SearchType(Enum):
|
||||
ADJACENT = "ADJACENT"
|
||||
TRAVERSE = "TRAVERSE"
|
||||
SIMILARITY = "SIMILARITY"
|
||||
CATEGORIES = "CATEGORIES"
|
||||
NEIGHBOR = "NEIGHBOR"
|
||||
SUMMARY = "SUMMARY"
|
||||
SUMMARY_CLASSIFICATION = "SUMMARY_CLASSIFICATION"
|
||||
NODE_CLASSIFICATION = "NODE_CLASSIFICATION"
|
||||
|
|
@ -49,18 +46,15 @@ async def search(search_type: str, params: Dict[str, Any]) -> List:
|
|||
|
||||
|
||||
async def specific_search(query_params: List[SearchParameters]) -> List:
|
||||
graph_config = get_graph_config()
|
||||
graph_client = await get_graph_client(graph_config.graph_database_provider)
|
||||
graph_client = await get_graph_engine()
|
||||
graph = graph_client.graph
|
||||
|
||||
search_functions: Dict[SearchType, Callable] = {
|
||||
SearchType.ADJACENT: search_adjacent,
|
||||
SearchType.SIMILARITY: search_similarity,
|
||||
SearchType.CATEGORIES: search_categories,
|
||||
SearchType.NEIGHBOR: search_neighbour,
|
||||
SearchType.SUMMARY: search_summary,
|
||||
SearchType.CYPHER: search_cypher
|
||||
|
||||
SearchType.CYPHER: search_cypher,
|
||||
SearchType.TRAVERSE: search_traverse,
|
||||
SearchType.SIMILARITY: search_similarity,
|
||||
}
|
||||
|
||||
results = []
|
||||
|
|
@ -103,7 +97,7 @@ if __name__ == "__main__":
|
|||
# SearchType.SIMILARITY: {'query': 'your search query here'}
|
||||
# }
|
||||
# async def main():
|
||||
# graph_client = get_graph_client(GraphDBType.NETWORKX)
|
||||
# graph_client = get_graph_engine()
|
||||
|
||||
# await graph_client.load_graph_from_file()
|
||||
# graph = graph_client.graph
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from cognee.root_dir import get_absolute_path
|
|||
from cognee.shared.data_models import MonitoringTool
|
||||
|
||||
class BaseConfig(BaseSettings):
|
||||
data_root_directory: str = get_absolute_path(".data")
|
||||
data_root_directory: str = get_absolute_path(".data_storage")
|
||||
monitoring_tool: object = MonitoringTool.LANGFUSE
|
||||
graphistry_username: Optional[str] = None
|
||||
graphistry_password: Optional[str] = None
|
||||
|
|
|
|||
|
|
@ -16,17 +16,15 @@ def create_chunking_engine(config: ChunkingConfig):
|
|||
chunk_size=config["chunk_size"],
|
||||
chunk_overlap=config["chunk_overlap"],
|
||||
chunk_strategy=config["chunk_strategy"],
|
||||
|
||||
|
||||
)
|
||||
elif config["chunk_engine"] == ChunkEngine.DEFAULT_ENGINE:
|
||||
from cognee.infrastructure.data.chunking.DefaultChunkEngine import DefaultChunkEngine
|
||||
from cognee.infrastructure.data.chunking.DefaultChunkEngine import DefaultChunkEngine
|
||||
|
||||
return DefaultChunkEngine(
|
||||
chunk_size=config["chunk_size"],
|
||||
chunk_overlap=config["chunk_overlap"],
|
||||
chunk_strategy=config["chunk_strategy"],
|
||||
)
|
||||
return DefaultChunkEngine(
|
||||
chunk_size=config["chunk_size"],
|
||||
chunk_overlap=config["chunk_overlap"],
|
||||
chunk_strategy=config["chunk_strategy"],
|
||||
)
|
||||
elif config["chunk_engine"] == ChunkEngine.HAYSTACK_ENGINE:
|
||||
from cognee.infrastructure.data.chunking.HaystackChunkEngine import HaystackChunkEngine
|
||||
|
||||
|
|
|
|||
|
|
@ -3,4 +3,4 @@ from .config import get_chunk_config
|
|||
from .create_chunking_engine import create_chunking_engine
|
||||
|
||||
def get_chunk_engine():
|
||||
return create_chunking_engine(get_chunk_config().to_dict())
|
||||
return create_chunking_engine(get_chunk_config().to_dict())
|
||||
|
|
|
|||
|
|
@ -1 +1,2 @@
|
|||
from .config import get_graph_config
|
||||
from .get_graph_engine import get_graph_engine
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ class GraphConfig(BaseSettings):
|
|||
graph_model: object = KnowledgeGraph
|
||||
graph_topology_task: bool = False
|
||||
graph_topology: object = KnowledgeGraph
|
||||
infer_graph_topology: bool = True
|
||||
infer_graph_topology: bool = False
|
||||
topology_file_path: str = os.path.join(
|
||||
os.path.join(get_absolute_path(".cognee_system"), "databases"),
|
||||
"graph_topology.json"
|
||||
|
|
|
|||
|
|
@ -70,8 +70,6 @@ class FalcorDBAdapter(GraphDBInterface):
|
|||
return await self.query(query, params)
|
||||
|
||||
async def add_nodes(self, nodes: list[tuple[str, dict[str, Any]]]) -> None:
|
||||
# nodes_data = []
|
||||
|
||||
for node in nodes:
|
||||
node_id, node_properties = node
|
||||
node_id = node_id.replace(":", "_")
|
||||
|
|
@ -112,18 +110,27 @@ class FalcorDBAdapter(GraphDBInterface):
|
|||
query = """MATCH (node) WHERE node.layer_id IS NOT NULL
|
||||
RETURN node"""
|
||||
|
||||
return [result['node'] for result in (await self.query(query))]
|
||||
return [result["node"] for result in (await self.query(query))]
|
||||
|
||||
async def extract_node(self, node_id: str):
|
||||
query= """
|
||||
MATCH(node {id: $node_id})
|
||||
RETURN node
|
||||
"""
|
||||
|
||||
results = [node['node'] for node in (await self.query(query, dict(node_id = node_id)))]
|
||||
results = self.extract_nodes([node_id])
|
||||
|
||||
return results[0] if len(results) > 0 else None
|
||||
|
||||
async def extract_nodes(self, node_ids: List[str]):
|
||||
query = """
|
||||
UNWIND $node_ids AS id
|
||||
MATCH (node {id: id})
|
||||
RETURN node"""
|
||||
|
||||
params = {
|
||||
"node_ids": node_ids
|
||||
}
|
||||
|
||||
results = await self.query(query, params)
|
||||
|
||||
return results
|
||||
|
||||
async def delete_node(self, node_id: str):
|
||||
node_id = id.replace(":", "_")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,11 @@
|
|||
"""Factory function to get the appropriate graph client based on the graph type."""
|
||||
|
||||
from cognee.shared.data_models import GraphDBType
|
||||
from .config import get_graph_config
|
||||
from .graph_db_interface import GraphDBInterface
|
||||
from .networkx.adapter import NetworkXAdapter
|
||||
|
||||
|
||||
async def get_graph_client(graph_type: GraphDBType=None, graph_file_name: str = None) -> GraphDBInterface :
|
||||
async def get_graph_engine() -> GraphDBInterface :
|
||||
"""Factory function to get the appropriate graph client based on the graph type."""
|
||||
config = get_graph_config()
|
||||
|
||||
|
|
@ -34,8 +33,10 @@ async def get_graph_client(graph_type: GraphDBType=None, graph_file_name: str =
|
|||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
graph_client = NetworkXAdapter(filename = config.graph_file_path)
|
||||
if (graph_client.graph is None):
|
||||
|
||||
if graph_client.graph is None:
|
||||
await graph_client.load_graph_from_file()
|
||||
|
||||
return graph_client
|
||||
|
|
@ -25,12 +25,24 @@ class GraphDBInterface(Protocol):
|
|||
node_id: str
|
||||
): raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def delete_nodes(
|
||||
self,
|
||||
node_ids: list[str]
|
||||
): raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def extract_node(
|
||||
self,
|
||||
node_id: str
|
||||
): raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def extract_nodes(
|
||||
self,
|
||||
node_ids: list[str]
|
||||
): raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def add_edge(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
""" Neo4j Adapter for Graph Database"""
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import Optional, Any, List, Dict
|
||||
from contextlib import asynccontextmanager
|
||||
from neo4j import AsyncSession
|
||||
|
|
@ -56,6 +57,7 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
if "name" not in serialized_properties:
|
||||
serialized_properties["name"] = node_id
|
||||
|
||||
query = f"""MERGE (node:`{node_id}` {{id: $node_id}})
|
||||
ON CREATE SET node += $properties
|
||||
RETURN ID(node) AS internal_id, node.id AS nodeId"""
|
||||
|
|
@ -68,16 +70,22 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
return await self.query(query, params)
|
||||
|
||||
async def add_nodes(self, nodes: list[tuple[str, dict[str, Any]]]) -> None:
|
||||
# nodes_data = []
|
||||
query = """
|
||||
UNWIND $nodes AS node
|
||||
MERGE (n {id: node.node_id})
|
||||
ON CREATE SET n += node.properties
|
||||
WITH n, node.node_id AS label
|
||||
CALL apoc.create.addLabels(n, [label]) YIELD node AS labeledNode
|
||||
RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId
|
||||
"""
|
||||
|
||||
for node in nodes:
|
||||
node_id, node_properties = node
|
||||
node_id = node_id.replace(":", "_")
|
||||
nodes = [{
|
||||
"node_id": node_id,
|
||||
"properties": self.serialize_properties(node_properties),
|
||||
} for (node_id, node_properties) in nodes]
|
||||
|
||||
await self.add_node(
|
||||
node_id = node_id,
|
||||
node_properties = node_properties,
|
||||
)
|
||||
results = await self.query(query, dict(nodes = nodes))
|
||||
return results
|
||||
|
||||
async def extract_node_description(self, node_id: str):
|
||||
query = """MATCH (n)-[r]->(m)
|
||||
|
|
@ -111,15 +119,24 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
return [result["node"] for result in (await self.query(query))]
|
||||
|
||||
async def extract_node(self, node_id: str):
|
||||
query= """
|
||||
MATCH(node {id: $node_id})
|
||||
RETURN node
|
||||
"""
|
||||
|
||||
results = [node["node"] for node in (await self.query(query, dict(node_id = node_id)))]
|
||||
results = await self.extract_nodes([node_id])
|
||||
|
||||
return results[0] if len(results) > 0 else None
|
||||
|
||||
async def extract_nodes(self, node_ids: List[str]):
|
||||
query = """
|
||||
UNWIND $node_ids AS id
|
||||
MATCH (node {id: id})
|
||||
RETURN node"""
|
||||
|
||||
params = {
|
||||
"node_ids": node_ids
|
||||
}
|
||||
|
||||
results = await self.query(query, params)
|
||||
|
||||
return [result["node"] for result in results]
|
||||
|
||||
async def delete_node(self, node_id: str):
|
||||
node_id = id.replace(":", "_")
|
||||
|
||||
|
|
@ -128,6 +145,18 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
return await self.query(query, params)
|
||||
|
||||
async def delete_nodes(self, node_ids: list[str]) -> None:
|
||||
query = """
|
||||
UNWIND $node_ids AS id
|
||||
MATCH (node {id: id})
|
||||
DETACH DELETE node"""
|
||||
|
||||
params = {
|
||||
"node_ids": node_ids
|
||||
}
|
||||
|
||||
return await self.query(query, params)
|
||||
|
||||
async def add_edge(self, from_node: str, to_node: str, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}):
|
||||
serialized_properties = self.serialize_properties(edge_properties)
|
||||
from_node = from_node.replace(":", "_")
|
||||
|
|
@ -150,19 +179,75 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
|
||||
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
|
||||
# edges_data = []
|
||||
query = """
|
||||
UNWIND $edges AS edge
|
||||
MATCH (from_node {id: edge.from_node})
|
||||
MATCH (to_node {id: edge.to_node})
|
||||
CALL apoc.create.relationship(from_node, edge.relationship_name, edge.properties, to_node) YIELD rel
|
||||
RETURN rel
|
||||
"""
|
||||
|
||||
for edge in edges:
|
||||
from_node, to_node, relationship_name, edge_properties = edge
|
||||
from_node = from_node.replace(":", "_")
|
||||
to_node = to_node.replace(":", "_")
|
||||
edges = [{
|
||||
"from_node": edge[0],
|
||||
"to_node": edge[1],
|
||||
"relationship_name": edge[2],
|
||||
"properties": {
|
||||
**(edge[3] if edge[3] else {}),
|
||||
"source_node_id": edge[0],
|
||||
"target_node_id": edge[1],
|
||||
},
|
||||
} for edge in edges]
|
||||
|
||||
await self.add_edge(
|
||||
from_node = from_node,
|
||||
to_node = to_node,
|
||||
relationship_name = relationship_name,
|
||||
edge_properties = edge_properties
|
||||
)
|
||||
results = await self.query(query, dict(edges = edges))
|
||||
return results
|
||||
|
||||
async def get_edges(self, node_id: str):
|
||||
query = """
|
||||
MATCH (n {id: $node_id})-[r]-(m)
|
||||
RETURN n, r, m
|
||||
"""
|
||||
|
||||
results = await self.query(query, dict(node_id = node_id))
|
||||
|
||||
return [(result["n"]["id"], result["m"]["id"], {"relationship_name": result["r"][1]}) for result in results]
|
||||
|
||||
async def get_disconnected_nodes(self) -> list[str]:
|
||||
# return await self.query(
|
||||
# "MATCH (node) WHERE NOT (node)<-[:*]-() RETURN node.id as id",
|
||||
# )
|
||||
query = """
|
||||
// Step 1: Collect all nodes
|
||||
MATCH (n)
|
||||
WITH COLLECT(n) AS nodes
|
||||
|
||||
// Step 2: Find all connected components
|
||||
WITH nodes
|
||||
CALL {
|
||||
WITH nodes
|
||||
UNWIND nodes AS startNode
|
||||
MATCH path = (startNode)-[*]-(connectedNode)
|
||||
WITH startNode, COLLECT(DISTINCT connectedNode) AS component
|
||||
RETURN component
|
||||
}
|
||||
|
||||
// Step 3: Aggregate components
|
||||
WITH COLLECT(component) AS components
|
||||
|
||||
// Step 4: Identify the largest connected component
|
||||
UNWIND components AS component
|
||||
WITH component
|
||||
ORDER BY SIZE(component) DESC
|
||||
LIMIT 1
|
||||
WITH component AS largestComponent
|
||||
|
||||
// Step 5: Find nodes not in the largest connected component
|
||||
MATCH (n)
|
||||
WHERE NOT n IN largestComponent
|
||||
RETURN COLLECT(ID(n)) AS ids
|
||||
"""
|
||||
|
||||
results = await self.query(query)
|
||||
return results[0]["ids"] if len(results) > 0 else []
|
||||
|
||||
|
||||
async def filter_nodes(self, search_criteria):
|
||||
|
|
@ -170,10 +255,99 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
WHERE node.id CONTAINS '{search_criteria}'
|
||||
RETURN node"""
|
||||
|
||||
|
||||
return await self.query(query)
|
||||
|
||||
|
||||
async def get_predecessor_ids(self, node_id: str, edge_label: str = None) -> list[str]:
|
||||
if edge_label is not None:
|
||||
query = """
|
||||
MATCH (node:`{node_id}`)-[r:`{edge_label}`]->(predecessor)
|
||||
RETURN predecessor.id AS id
|
||||
"""
|
||||
|
||||
results = await self.query(
|
||||
query,
|
||||
dict(
|
||||
node_id = node_id,
|
||||
edge_label = edge_label,
|
||||
)
|
||||
)
|
||||
|
||||
return [result["id"] for result in results]
|
||||
else:
|
||||
query = """
|
||||
MATCH (node:`{node_id}`)-[r]->(predecessor)
|
||||
RETURN predecessor.id AS id
|
||||
"""
|
||||
|
||||
results = await self.query(
|
||||
query,
|
||||
dict(
|
||||
node_id = node_id,
|
||||
)
|
||||
)
|
||||
|
||||
return [result["id"] for result in results]
|
||||
|
||||
async def get_successor_ids(self, node_id: str, edge_label: str = None) -> list[str]:
|
||||
if edge_label is not None:
|
||||
query = """
|
||||
MATCH (node:`{node_id}`)<-[r:`{edge_label}`]-(successor)
|
||||
RETURN successor.id AS id
|
||||
"""
|
||||
|
||||
results = await self.query(
|
||||
query,
|
||||
dict(
|
||||
node_id = node_id,
|
||||
edge_label = edge_label,
|
||||
),
|
||||
)
|
||||
|
||||
return [result["id"] for result in results]
|
||||
else:
|
||||
query = """
|
||||
MATCH (node:`{node_id}`)<-[r]-(successor)
|
||||
RETURN successor.id AS id
|
||||
"""
|
||||
|
||||
results = await self.query(
|
||||
query,
|
||||
dict(
|
||||
node_id = node_id,
|
||||
)
|
||||
)
|
||||
|
||||
return [result["id"] for result in results]
|
||||
|
||||
async def get_neighbours(self, node_id: str) -> list[str]:
|
||||
results = await asyncio.gather(*[self.get_predecessor_ids(node_id)], self.get_successor_ids(node_id))
|
||||
|
||||
return [*results[0], *results[1]]
|
||||
|
||||
async def remove_connection_to_predecessors_of(self, node_ids: list[str], edge_label: str) -> None:
|
||||
query = f"""
|
||||
UNWIND $node_ids AS id
|
||||
MATCH (node:`{id}`)-[r:{edge_label}]->(predecessor)
|
||||
DELETE r;
|
||||
"""
|
||||
|
||||
params = { "node_ids": node_ids }
|
||||
|
||||
return await self.query(query, params)
|
||||
|
||||
async def remove_connection_to_successors_of(self, node_ids: list[str], edge_label: str) -> None:
|
||||
query = f"""
|
||||
UNWIND $node_ids AS id
|
||||
MATCH (node:`{id}`)<-[r:{edge_label}]-(successor)
|
||||
DELETE r;
|
||||
"""
|
||||
|
||||
params = { "node_ids": node_ids }
|
||||
|
||||
return await self.query(query, params)
|
||||
|
||||
|
||||
async def delete_graph(self):
|
||||
query = """MATCH (node)
|
||||
DETACH DELETE node;"""
|
||||
|
|
@ -186,3 +360,25 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
if isinstance(property_value, (dict, list))
|
||||
else property_value for property_key, property_value in properties.items()
|
||||
}
|
||||
|
||||
async def get_graph_data(self):
|
||||
query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties"
|
||||
result = await self.query(query)
|
||||
nodes = [(
|
||||
record["properties"]["id"],
|
||||
record["properties"],
|
||||
) for record in result]
|
||||
|
||||
query = """
|
||||
MATCH (n)-[r]->(m)
|
||||
RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
|
||||
"""
|
||||
result = await self.query(query)
|
||||
edges = [(
|
||||
record["properties"]["source_node_id"],
|
||||
record["properties"]["target_node_id"],
|
||||
record["type"],
|
||||
record["properties"],
|
||||
) for record in result]
|
||||
|
||||
return (nodes, edges)
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
|
||||
async def get_graph(self):
|
||||
return self.graph
|
||||
|
||||
|
||||
async def add_edge(
|
||||
self,
|
||||
from_node: str,
|
||||
|
|
@ -62,12 +62,31 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
self.graph.add_edges_from(edges)
|
||||
await self.save_graph_to_file(self.filename)
|
||||
|
||||
async def get_edges(self, node_id: str):
|
||||
return list(self.graph.in_edges(node_id, data = True)) + list(self.graph.out_edges(node_id, data = True))
|
||||
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
"""Asynchronously delete a node from the graph if it exists."""
|
||||
if self.graph.has_node(id):
|
||||
self.graph.remove_node(id)
|
||||
await self.save_graph_to_file(self.filename)
|
||||
|
||||
async def delete_nodes(self, node_ids: List[str]) -> None:
|
||||
self.graph.remove_nodes_from(node_ids)
|
||||
await self.save_graph_to_file(self.filename)
|
||||
|
||||
async def get_disconnected_nodes(self) -> List[str]:
|
||||
connected_components = list(nx.weakly_connected_components(self.graph))
|
||||
|
||||
disconnected_nodes = []
|
||||
biggest_subgraph = max(connected_components, key = len)
|
||||
|
||||
for component in connected_components:
|
||||
if component != biggest_subgraph:
|
||||
disconnected_nodes.extend(list(component))
|
||||
|
||||
return disconnected_nodes
|
||||
|
||||
async def extract_node_description(self, node_id: str) -> Dict[str, Any]:
|
||||
descriptions = []
|
||||
|
||||
|
|
@ -98,11 +117,69 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
|
||||
async def extract_node(self, node_id: str) -> dict:
|
||||
if self.graph.has_node(node_id):
|
||||
|
||||
return self.graph.nodes[node_id]
|
||||
|
||||
return None
|
||||
|
||||
async def extract_nodes(self, node_ids: List[str]) -> List[dict]:
|
||||
return [self.graph.nodes[node_id] for node_id in node_ids if self.graph.has_node(node_id)]
|
||||
|
||||
async def get_predecessor_ids(self, node_id: str, edge_label: str = None) -> list:
|
||||
if self.graph.has_node(node_id):
|
||||
if edge_label is None:
|
||||
return list(self.graph.predecessors(node_id))
|
||||
|
||||
nodes = []
|
||||
|
||||
for predecessor_id in list(self.graph.predecessors(node_id)):
|
||||
if self.graph.has_edge(predecessor_id, node_id, edge_label):
|
||||
nodes.append(predecessor_id)
|
||||
|
||||
return nodes
|
||||
|
||||
async def get_successor_ids(self, node_id: str, edge_label: str = None) -> list:
|
||||
if self.graph.has_node(node_id):
|
||||
if edge_label is None:
|
||||
return list(self.graph.successors(node_id))
|
||||
|
||||
nodes = []
|
||||
|
||||
for successor_id in list(self.graph.successors(node_id)):
|
||||
if self.graph.has_edge(node_id, successor_id, edge_label):
|
||||
nodes.append(successor_id)
|
||||
|
||||
return nodes
|
||||
|
||||
async def get_neighbours(self, node_id: str) -> list:
|
||||
if not self.graph.has_node(node_id):
|
||||
return []
|
||||
|
||||
neighbour_ids = list(self.graph.neighbors(node_id))
|
||||
|
||||
if len(neighbour_ids) == 0:
|
||||
return []
|
||||
|
||||
nodes = await self.extract_nodes(neighbour_ids)
|
||||
|
||||
return nodes
|
||||
|
||||
async def remove_connection_to_predecessors_of(self, node_ids: list[str], edge_label: str) -> None:
|
||||
for node_id in node_ids:
|
||||
if self.graph.has_node(node_id):
|
||||
for predecessor_id in list(self.graph.predecessors(node_id)):
|
||||
if self.graph.has_edge(predecessor_id, node_id, edge_label):
|
||||
self.graph.remove_edge(predecessor_id, node_id, edge_label)
|
||||
|
||||
await self.save_graph_to_file(self.filename)
|
||||
|
||||
async def remove_connection_to_successors_of(self, node_ids: list[str], edge_label: str) -> None:
|
||||
for node_id in node_ids:
|
||||
if self.graph.has_node(node_id):
|
||||
for successor_id in list(self.graph.successors(node_id)):
|
||||
if self.graph.has_edge(node_id, successor_id, edge_label):
|
||||
self.graph.remove_edge(node_id, successor_id, edge_label)
|
||||
|
||||
await self.save_graph_to_file(self.filename)
|
||||
|
||||
async def save_graph_to_file(self, file_path: str=None) -> None:
|
||||
"""Asynchronously save the graph to a file in JSON format."""
|
||||
|
|
@ -114,6 +191,7 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
async with aiofiles.open(file_path, "w") as file:
|
||||
await file.write(json.dumps(graph_data))
|
||||
|
||||
|
||||
async def load_graph_from_file(self, file_path: str = None):
|
||||
"""Asynchronously load the graph from a file in JSON format."""
|
||||
if file_path == self.filename:
|
||||
|
|
|
|||
|
|
@ -2,10 +2,9 @@ import duckdb
|
|||
import os
|
||||
class DuckDBAdapter():
|
||||
def __init__(self, db_path: str, db_name: str):
|
||||
self.db_location = os.path.abspath(os.path.join(db_path, db_name))
|
||||
|
||||
db_location = os.path.abspath(os.path.join(db_path, db_name))
|
||||
|
||||
self.get_connection = lambda: duckdb.connect(db_location)
|
||||
self.get_connection = lambda: duckdb.connect(self.db_location)
|
||||
|
||||
def get_datasets(self):
|
||||
with self.get_connection() as connection:
|
||||
|
|
@ -20,7 +19,7 @@ class DuckDBAdapter():
|
|||
|
||||
def get_files_metadata(self, dataset_name: str):
|
||||
with self.get_connection() as connection:
|
||||
return connection.sql(f"SELECT id, name, file_path, extension, mime_type, keywords FROM {dataset_name}.file_metadata;").to_df().to_dict("records")
|
||||
return connection.sql(f"SELECT id, name, file_path, extension, mime_type FROM {dataset_name}.file_metadata;").to_df().to_dict("records")
|
||||
|
||||
def create_table(self, schema_name: str, table_name: str, table_config: list[dict]):
|
||||
fields_query_parts = []
|
||||
|
|
@ -163,3 +162,11 @@ class DuckDBAdapter():
|
|||
connection.sql(select_data_sql)
|
||||
drop_data_sql = "DROP TABLE cognify;"
|
||||
connection.sql(drop_data_sql)
|
||||
|
||||
def delete_database(self):
|
||||
from cognee.infrastructure.files.storage import LocalStorage
|
||||
|
||||
LocalStorage.remove(self.db_location)
|
||||
|
||||
if LocalStorage.file_exists(self.db_location + ".wal"):
|
||||
LocalStorage.remove(self.db_location + ".wal")
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ class FalcorDBAdapter(VectorDBInterface):
|
|||
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
|
||||
pass
|
||||
|
||||
async def retrieve(self, collection_name: str, data_point_id: str):
|
||||
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
||||
pass
|
||||
|
||||
async def search(
|
||||
|
|
@ -51,4 +51,7 @@ class FalcorDBAdapter(VectorDBInterface):
|
|||
limit: int = 10,
|
||||
with_vector: bool = False,
|
||||
):
|
||||
pass
|
||||
pass
|
||||
|
||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
async def embed_data(self, data: list[str]) -> list[list[float]]:
|
||||
return await self.embedding_engine.embed_text(data)
|
||||
|
||||
async def collection_exists(self, collection_name: str) -> bool:
|
||||
async def has_collection(self, collection_name: str) -> bool:
|
||||
connection = await self.get_connection()
|
||||
collection_names = await connection.table_names()
|
||||
return collection_name in collection_names
|
||||
|
|
@ -47,7 +47,7 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
vector: Vector(vector_size)
|
||||
payload: payload_schema
|
||||
|
||||
if not await self.collection_exists(collection_name):
|
||||
if not await self.has_collection(collection_name):
|
||||
connection = await self.get_connection()
|
||||
return await connection.create_table(
|
||||
name = collection_name,
|
||||
|
|
@ -58,7 +58,7 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
|
||||
connection = await self.get_connection()
|
||||
|
||||
if not await self.collection_exists(collection_name):
|
||||
if not await self.has_collection(collection_name):
|
||||
await self.create_collection(
|
||||
collection_name,
|
||||
payload_schema = type(data_points[0].payload),
|
||||
|
|
@ -89,17 +89,20 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
|
||||
await collection.add(lance_data_points)
|
||||
|
||||
async def retrieve(self, collection_name: str, data_point_id: str):
|
||||
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
||||
connection = await self.get_connection()
|
||||
collection = await connection.open_table(collection_name)
|
||||
results = await collection.query().where(f"id = '{data_point_id}'").to_pandas()
|
||||
result = results.to_dict("index")[0]
|
||||
|
||||
return ScoredResult(
|
||||
if len(data_point_ids) == 1:
|
||||
results = await collection.query().where(f"id = '{data_point_ids[0]}'").to_pandas()
|
||||
else:
|
||||
results = await collection.query().where(f"id IN {tuple(data_point_ids)}").to_pandas()
|
||||
|
||||
return [ScoredResult(
|
||||
id = result["id"],
|
||||
payload = result["payload"],
|
||||
score = 1,
|
||||
)
|
||||
) for result in results.to_dict("index").values()]
|
||||
|
||||
async def search(
|
||||
self,
|
||||
|
|
@ -122,8 +125,8 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
|
||||
return [ScoredResult(
|
||||
id = str(result["id"]),
|
||||
score = float(result["_distance"]),
|
||||
payload = result["payload"],
|
||||
score = float(result["_distance"]),
|
||||
) for result in results.to_dict("index").values()]
|
||||
|
||||
async def batch_search(
|
||||
|
|
@ -144,6 +147,12 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
) for query_vector in query_vectors]
|
||||
)
|
||||
|
||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
||||
connection = await self.get_connection()
|
||||
collection = await connection.open_table(collection_name)
|
||||
results = await collection.delete(f"id IN {tuple(data_point_ids)}")
|
||||
return results
|
||||
|
||||
async def prune(self):
|
||||
# Clean up the database if it was set up as temporary
|
||||
if self.url.startswith("/"):
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ class QDrantAdapter(VectorDBInterface):
|
|||
async def embed_data(self, data: List[str]) -> List[float]:
|
||||
return await self.embedding_engine.embed_text(data)
|
||||
|
||||
async def collection_exists(self, collection_name: str) -> bool:
|
||||
async def has_collection(self, collection_name: str) -> bool:
|
||||
client = self.get_qdrant_client()
|
||||
result = await client.collection_exists(collection_name)
|
||||
await client.close()
|
||||
|
|
@ -111,11 +111,11 @@ class QDrantAdapter(VectorDBInterface):
|
|||
|
||||
return result
|
||||
|
||||
async def retrieve(self, collection_name: str, data_point_id: str):
|
||||
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
||||
client = self.get_qdrant_client()
|
||||
results = await client.retrieve(collection_name, [data_point_id], with_payload = True)
|
||||
results = await client.retrieve(collection_name, data_point_ids, with_payload = True)
|
||||
await client.close()
|
||||
return results[0] if len(results) > 0 else None
|
||||
return results
|
||||
|
||||
async def search(
|
||||
self,
|
||||
|
|
@ -185,6 +185,11 @@ class QDrantAdapter(VectorDBInterface):
|
|||
|
||||
return [filter(lambda result: result.score > 0.9, result_group) for result_group in results]
|
||||
|
||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
||||
client = self.get_qdrant_client()
|
||||
results = await client.delete(collection_name, data_point_ids)
|
||||
return results
|
||||
|
||||
async def prune(self):
|
||||
client = self.get_qdrant_client()
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from .models.PayloadSchema import PayloadSchema
|
|||
class VectorDBInterface(Protocol):
|
||||
""" Collections """
|
||||
@abstractmethod
|
||||
async def collection_exists(self, collection_name: str) -> bool:
|
||||
async def has_collection(self, collection_name: str) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
|
|
@ -28,7 +28,7 @@ class VectorDBInterface(Protocol):
|
|||
async def retrieve(
|
||||
self,
|
||||
collection_name: str,
|
||||
data_point_id: str
|
||||
data_point_ids: list[str]
|
||||
): raise NotImplementedError
|
||||
|
||||
""" Search """
|
||||
|
|
@ -51,3 +51,13 @@ class VectorDBInterface(Protocol):
|
|||
limit: int,
|
||||
with_vectors: bool = False
|
||||
): raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def delete_data_points(
|
||||
self,
|
||||
collection_name: str,
|
||||
data_point_ids: list[str]
|
||||
): raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def prune(self): raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import asyncio
|
||||
from uuid import UUID
|
||||
from typing import List, Optional
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
from ..models.DataPoint import DataPoint
|
||||
|
|
@ -31,7 +30,7 @@ class WeaviateAdapter(VectorDBInterface):
|
|||
async def embed_data(self, data: List[str]) -> List[float]:
|
||||
return await self.embedding_engine.embed_text(data)
|
||||
|
||||
async def collection_exists(self, collection_name: str) -> bool:
|
||||
async def has_collection(self, collection_name: str) -> bool:
|
||||
future = asyncio.Future()
|
||||
|
||||
future.set_result(self.client.collections.exists(collection_name))
|
||||
|
|
@ -72,25 +71,41 @@ class WeaviateAdapter(VectorDBInterface):
|
|||
list(map(lambda data_point: data_point.get_embeddable_data(), data_points)))
|
||||
|
||||
def convert_to_weaviate_data_points(data_point: DataPoint):
|
||||
vector = data_vectors[data_points.index(data_point)]
|
||||
return DataObject(
|
||||
uuid = data_point.id,
|
||||
properties = data_point.payload.dict(),
|
||||
vector = data_vectors[data_points.index(data_point)]
|
||||
vector = vector
|
||||
)
|
||||
|
||||
|
||||
objects = list(map(convert_to_weaviate_data_points, data_points))
|
||||
|
||||
return self.get_collection(collection_name).data.insert_many(objects)
|
||||
collection = self.get_collection(collection_name)
|
||||
|
||||
async def retrieve(self, collection_name: str, data_point_id: str):
|
||||
with collection.batch.dynamic() as batch:
|
||||
for data_row in objects:
|
||||
batch.add_object(
|
||||
properties = data_row.properties,
|
||||
vector = data_row.vector
|
||||
)
|
||||
|
||||
return
|
||||
# return self.get_collection(collection_name).data.insert_many(objects)
|
||||
|
||||
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
||||
from weaviate.classes.query import Filter
|
||||
future = asyncio.Future()
|
||||
|
||||
data_point = self.get_collection(collection_name).query.fetch_object_by_id(UUID(data_point_id))
|
||||
data_points = self.get_collection(collection_name).query.fetch_objects(
|
||||
filters = Filter.by_id().contains_any(data_point_ids)
|
||||
)
|
||||
|
||||
data_point.payload = data_point.properties
|
||||
del data_point.properties
|
||||
for data_point in data_points:
|
||||
data_point.payload = data_point.properties
|
||||
del data_point.properties
|
||||
|
||||
future.set_result(data_point)
|
||||
future.set_result(data_points)
|
||||
|
||||
return await future
|
||||
|
||||
|
|
@ -131,6 +146,17 @@ class WeaviateAdapter(VectorDBInterface):
|
|||
return self.search(collection_name, query_vector=query_vector, limit=limit, with_vector=with_vectors)
|
||||
|
||||
return [await query_search(query_vector) for query_vector in await self.embed_data(query_texts)]
|
||||
|
||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
||||
from weaviate.classes.query import Filter
|
||||
future = asyncio.Future()
|
||||
|
||||
result = self.get_collection(collection_name).data.delete_many(
|
||||
filters = Filter.by_id().contains_any(data_point_ids)
|
||||
)
|
||||
future.set_result(result)
|
||||
|
||||
return await future
|
||||
|
||||
async def prune(self):
|
||||
self.client.collections.delete_all()
|
||||
|
|
|
|||
|
|
@ -32,13 +32,19 @@ class LocalStorage(Storage):
|
|||
f.seek(0)
|
||||
return f.read()
|
||||
|
||||
@staticmethod
|
||||
def file_exists(file_path: str):
|
||||
return os.path.exists(file_path)
|
||||
|
||||
@staticmethod
|
||||
def ensure_directory_exists(file_path: str):
|
||||
if not os.path.exists(file_path):
|
||||
os.makedirs(file_path, exist_ok = True)
|
||||
|
||||
def remove(self, file_path: str):
|
||||
os.remove(self.storage_path + "/" + file_path)
|
||||
@staticmethod
|
||||
def remove(file_path: str):
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
|
||||
@staticmethod
|
||||
def copy_file(source_file_path: str, destination_file_path: str):
|
||||
|
|
|
|||
|
|
@ -7,7 +7,8 @@ class Storage(Protocol):
|
|||
def retrieve(self, file_path: str):
|
||||
pass
|
||||
|
||||
def remove(self, file_path: str):
|
||||
@staticmethod
|
||||
def remove(file_path: str):
|
||||
pass
|
||||
|
||||
class StorageManager():
|
||||
|
|
|
|||
|
|
@ -1,6 +1,4 @@
|
|||
from typing import BinaryIO, TypedDict
|
||||
from cognee.infrastructure.data.utils.extract_keywords import extract_keywords
|
||||
from .extract_text_from_file import extract_text_from_file
|
||||
from .guess_file_type import guess_file_type
|
||||
|
||||
|
||||
|
|
@ -8,24 +6,12 @@ class FileMetadata(TypedDict):
|
|||
name: str
|
||||
mime_type: str
|
||||
extension: str
|
||||
keywords: list[str]
|
||||
|
||||
def get_file_metadata(file: BinaryIO) -> FileMetadata:
|
||||
"""Get metadata from a file"""
|
||||
file.seek(0)
|
||||
file_type = guess_file_type(file)
|
||||
|
||||
file.seek(0)
|
||||
file_text = extract_text_from_file(file, file_type)
|
||||
|
||||
import uuid
|
||||
|
||||
try:
|
||||
keywords = extract_keywords(file_text)
|
||||
except:
|
||||
keywords = ["no keywords detected" + str(uuid.uuid4())]
|
||||
|
||||
|
||||
file_path = file.name
|
||||
file_name = file_path.split("/")[-1].split(".")[0] if file_path else None
|
||||
|
||||
|
|
@ -34,5 +20,4 @@ def get_file_metadata(file: BinaryIO) -> FileMetadata:
|
|||
file_path = file_path,
|
||||
mime_type = file_type.mime,
|
||||
extension = file_type.extension,
|
||||
keywords = keywords
|
||||
)
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ class AnthropicAdapter(LLMInterface):
|
|||
)
|
||||
self.model = model
|
||||
|
||||
@retry(stop = stop_after_attempt(5))
|
||||
async def acreate_structured_output(
|
||||
self,
|
||||
text_input: str,
|
||||
|
|
@ -31,7 +30,7 @@ class AnthropicAdapter(LLMInterface):
|
|||
return await self.aclient(
|
||||
model = self.model,
|
||||
max_tokens = 4096,
|
||||
max_retries = 0,
|
||||
max_retries = 5,
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": f"""Use the given format to extract information
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ class LLMConfig(BaseSettings):
|
|||
llm_api_key: Optional[str] = None
|
||||
llm_temperature: float = 0.0
|
||||
llm_streaming: bool = False
|
||||
transcription_model: str = "whisper-1"
|
||||
|
||||
model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
|
||||
|
||||
|
|
@ -18,6 +19,9 @@ class LLMConfig(BaseSettings):
|
|||
"model": self.llm_model,
|
||||
"endpoint": self.llm_endpoint,
|
||||
"apiKey": self.llm_api_key,
|
||||
"temperature": self.llm_temperature,
|
||||
"streaming": self.llm_stream,
|
||||
"transcriptionModel": self.transcription_model
|
||||
}
|
||||
|
||||
@lru_cache
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ def get_llm_client():
|
|||
raise ValueError("LLM API key is not set.")
|
||||
|
||||
from .openai.adapter import OpenAIAdapter
|
||||
return OpenAIAdapter(llm_config.llm_api_key, llm_config.llm_model, llm_config.llm_streaming)
|
||||
return OpenAIAdapter(api_key=llm_config.llm_api_key, model=llm_config.llm_model, transcription_model=llm_config.transcription_model, streaming=llm_config.llm_streaming)
|
||||
elif provider == LLMProvider.OLLAMA:
|
||||
if llm_config.llm_api_key is None:
|
||||
raise ValueError("LLM API key is not set.")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,10 @@
|
|||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Type
|
||||
|
||||
import aiofiles
|
||||
import openai
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -9,6 +14,8 @@ from cognee.base_config import get_base_config
|
|||
from cognee.infrastructure.llm.llm_interface import LLMInterface
|
||||
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||
from cognee.shared.data_models import MonitoringTool
|
||||
import logging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
class OpenAIAdapter(LLMInterface):
|
||||
name = "OpenAI"
|
||||
|
|
@ -16,20 +23,22 @@ class OpenAIAdapter(LLMInterface):
|
|||
api_key: str
|
||||
|
||||
"""Adapter for OpenAI's GPT-3, GPT=4 API"""
|
||||
def __init__(self, api_key: str, model: str, streaming: bool = False):
|
||||
def __init__(self, api_key: str, model: str, transcription_model:str, streaming: bool = False):
|
||||
base_config = get_base_config()
|
||||
|
||||
if base_config.monitoring_tool == MonitoringTool.LANGFUSE:
|
||||
from langfuse.openai import AsyncOpenAI, OpenAI
|
||||
elif base_config.monitoring_tool == MonitoringTool.LANGSMITH:
|
||||
from langsmith import wrappers
|
||||
from openai import AsyncOpenAI
|
||||
AsyncOpenAI = wrappers.wrap_openai(AsyncOpenAI())
|
||||
else:
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
# if base_config.monitoring_tool == MonitoringTool.LANGFUSE:
|
||||
# from langfuse.openai import AsyncOpenAI, OpenAI
|
||||
# elif base_config.monitoring_tool == MonitoringTool.LANGSMITH:
|
||||
# from langsmith import wrappers
|
||||
# from openai import AsyncOpenAI
|
||||
# AsyncOpenAI = wrappers.wrap_openai(AsyncOpenAI())
|
||||
# else:
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
|
||||
self.aclient = instructor.from_openai(AsyncOpenAI(api_key = api_key))
|
||||
self.client = instructor.from_openai(OpenAI(api_key = api_key))
|
||||
self.base_openai_client = OpenAI(api_key = api_key)
|
||||
self.transcription_model = "whisper-1"
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.streaming = streaming
|
||||
|
|
@ -120,6 +129,49 @@ class OpenAIAdapter(LLMInterface):
|
|||
response_model = response_model,
|
||||
)
|
||||
|
||||
@retry(stop = stop_after_attempt(5))
|
||||
def create_transcript(self, input):
|
||||
"""Generate a audio transcript from a user query."""
|
||||
|
||||
if not os.path.isfile(input):
|
||||
raise FileNotFoundError(f"The file {input} does not exist.")
|
||||
|
||||
with open(input, 'rb') as audio_file:
|
||||
audio_data = audio_file.read()
|
||||
|
||||
|
||||
|
||||
transcription = self.base_openai_client.audio.transcriptions.create(
|
||||
model=self.transcription_model ,
|
||||
file=Path(input),
|
||||
)
|
||||
|
||||
return transcription
|
||||
|
||||
|
||||
@retry(stop = stop_after_attempt(5))
|
||||
def transcribe_image(self, input) -> BaseModel:
|
||||
with open(input, "rb") as image_file:
|
||||
encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
|
||||
|
||||
return self.base_openai_client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What’s in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{encoded_image}",
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
max_tokens=300,
|
||||
)
|
||||
def show_prompt(self, text_input: str, system_prompt: str) -> str:
|
||||
"""Format and display the prompt for a user query."""
|
||||
if not text_input:
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
Chose the summary that is the most relevant to the query`{{ query }}`
|
||||
Here are the summaries:`{{ summaries }}`
|
||||
Chose the summaries that are relevant to the following query: `{{ query }}`
|
||||
Here are the all summaries: `{{ summaries }}`
|
||||
|
|
|
|||
|
|
@ -1,36 +1,26 @@
|
|||
You are a top-tier algorithm
|
||||
designed for extracting information in structured formats to build a knowledge graph.
|
||||
- **Nodes** represent entities and concepts. They're akin to Wikipedia nodes.
|
||||
- **Edges** represent relationships between concepts. They're akin to Wikipedia links.
|
||||
- The aim is to achieve simplicity and clarity in the
|
||||
knowledge graph, making it accessible for a vast audience.
|
||||
YOU ARE ONLY EXTRACTING DATA FOR COGNITIVE LAYER `{{ layer }}`
|
||||
## 1. Labeling Nodes
|
||||
- **Consistency**: Ensure you use basic or elementary types for node labels.
|
||||
- For example, when you identify an entity representing a person,
|
||||
always label it as **"Person"**.
|
||||
Avoid using more specific terms like "mathematician" or "scientist".
|
||||
- Include event, entity, time, or action nodes to the category.
|
||||
- Classify the memory type as episodic or semantic.
|
||||
- **Node IDs**: Never utilize integers as node IDs.
|
||||
Node IDs should be names or human-readable identifiers found in the text.
|
||||
## 2. Handling Numerical Data and Dates
|
||||
- Numerical data, like age or other related information,
|
||||
should be incorporated as attributes or properties of the respective nodes.
|
||||
- **No Separate Nodes for Dates/Numbers**:
|
||||
Do not create separate nodes for dates or numerical values.
|
||||
Always attach them as attributes or properties of nodes.
|
||||
- **Property Format**: Properties must be in a key-value format.
|
||||
- **Quotation Marks**: Never use escaped single or double quotes within property values.
|
||||
- **Naming Convention**: Use snake_case for relationship names, e.g., `acted_in`.
|
||||
## 3. Coreference Resolution
|
||||
- **Maintain Entity Consistency**:
|
||||
When extracting entities, it's vital to ensure consistency.
|
||||
If an entity, such as "John Doe", is mentioned multiple times
|
||||
in the text but is referred to by different names or pronouns (e.g., "Joe", "he"),
|
||||
always use the most complete identifier for that entity throughout the knowledge graph.
|
||||
In this example, use "John Doe" as the entity ID.
|
||||
Remember, the knowledge graph should be coherent and easily understandable,
|
||||
so maintaining consistency in entity references is crucial.
|
||||
## 4. Strict Compliance
|
||||
Adhere to the rules strictly. Non-compliance will result in termination"""
|
||||
You are a top-tier algorithm designed for extracting information in structured formats to build a knowledge graph.
|
||||
**Nodes** represent entities and concepts. They're akin to Wikipedia nodes.
|
||||
**Edges** represent relationships between concepts. They're akin to Wikipedia links.
|
||||
|
||||
The aim is to achieve simplicity and clarity in the knowledge graph, making it accessible for a vast audience.
|
||||
# 1. Labeling Nodes
|
||||
**Consistency**: Ensure you use basic or elementary types for node labels.
|
||||
- For example, when you identify an entity representing a person, always label it as **"Person"**.
|
||||
- Avoid using more specific terms like "Mathematician" or "Scientist".
|
||||
- Don't use too generic terms like "Entity".
|
||||
**Node IDs**: Never utilize integers as node IDs.
|
||||
- Node IDs should be names or human-readable identifiers found in the text.
|
||||
# 2. Handling Numerical Data and Dates
|
||||
- For example, when you identify an entity representing a date, always label it as **"Date"**.
|
||||
- Extract the date in the format "YYYY-MM-DD"
|
||||
- If not possible to extract the whole date, extract month or year, or both if available.
|
||||
- **Property Format**: Properties must be in a key-value format.
|
||||
- **Quotation Marks**: Never use escaped single or double quotes within property values.
|
||||
- **Naming Convention**: Use snake_case for relationship names, e.g., `acted_in`.
|
||||
# 3. Coreference Resolution
|
||||
- **Maintain Entity Consistency**: When extracting entities, it's vital to ensure consistency.
|
||||
If an entity, such as "John Doe", is mentioned multiple times in the text but is referred to by different names or pronouns (e.g., "Joe", "he"),
|
||||
always use the most complete identifier for that entity throughout the knowledge graph. In this example, use "John Doe" as the Persons ID.
|
||||
Remember, the knowledge graph should be coherent and easily understandable, so maintaining consistency in entity references is crucial.
|
||||
# 4. Strict Compliance
|
||||
Adhere to the rules strictly. Non-compliance will result in termination"""
|
||||
|
|
|
|||
|
|
@ -1 +1,3 @@
|
|||
You are a summarization engine and you should sumamarize content. Be brief and concise
|
||||
You are a top-tier summarization engine. Your task is to summarize text and make it versatile.
|
||||
Be brief and concise, but keep the important information and the subject.
|
||||
Use synonim words where possible in order to change the wording but keep the meaning.
|
||||
|
|
|
|||
0
cognee/modules/classification/__init__.py
Normal file
0
cognee/modules/classification/__init__.py
Normal file
152
cognee/modules/classification/classify_text_chunks.py
Normal file
152
cognee/modules/classification/classify_text_chunks.py
Normal file
|
|
@ -0,0 +1,152 @@
|
|||
|
||||
import asyncio
|
||||
from uuid import uuid5, NAMESPACE_OID
|
||||
from typing import Type
|
||||
from pydantic import BaseModel
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine, DataPoint
|
||||
from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk
|
||||
from ..data.extraction.extract_categories import extract_categories
|
||||
|
||||
async def classify_text_chunks(data_chunks: list[DocumentChunk], classification_model: Type[BaseModel]):
|
||||
if len(data_chunks) == 0:
|
||||
return data_chunks
|
||||
|
||||
chunk_classifications = await asyncio.gather(
|
||||
*[extract_categories(chunk.text, classification_model) for chunk in data_chunks],
|
||||
)
|
||||
|
||||
classification_data_points = []
|
||||
|
||||
for chunk_index, chunk in enumerate(data_chunks):
|
||||
chunk_classification = chunk_classifications[chunk_index]
|
||||
classification_data_points.append(uuid5(NAMESPACE_OID, chunk_classification.label.type))
|
||||
classification_data_points.append(uuid5(NAMESPACE_OID, chunk_classification.label.type))
|
||||
|
||||
for classification_subclass in chunk_classification.label.subclass:
|
||||
classification_data_points.append(uuid5(NAMESPACE_OID, classification_subclass.value))
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
|
||||
class Keyword(BaseModel):
|
||||
id: str
|
||||
text: str
|
||||
chunk_id: str
|
||||
document_id: str
|
||||
|
||||
collection_name = "classification"
|
||||
|
||||
if await vector_engine.has_collection(collection_name):
|
||||
existing_data_points = await vector_engine.retrieve(
|
||||
collection_name,
|
||||
list(set(classification_data_points)),
|
||||
) if len(classification_data_points) > 0 else []
|
||||
|
||||
existing_points_map = {point.id: True for point in existing_data_points}
|
||||
else:
|
||||
existing_points_map = {}
|
||||
await vector_engine.create_collection(collection_name, payload_schema = Keyword)
|
||||
|
||||
data_points = []
|
||||
nodes = []
|
||||
edges = []
|
||||
|
||||
for (chunk_index, data_chunk) in enumerate(data_chunks):
|
||||
chunk_classification = chunk_classifications[chunk_index]
|
||||
classification_type_label = chunk_classification.label.type
|
||||
classification_type_id = uuid5(NAMESPACE_OID, classification_type_label)
|
||||
|
||||
if classification_type_id not in existing_points_map:
|
||||
data_points.append(
|
||||
DataPoint[Keyword](
|
||||
id = str(classification_type_id),
|
||||
payload = Keyword.parse_obj({
|
||||
"id": str(classification_type_id),
|
||||
"text": classification_type_label,
|
||||
"chunk_id": str(data_chunk.chunk_id),
|
||||
"document_id": str(data_chunk.document_id),
|
||||
}),
|
||||
embed_field = "text",
|
||||
)
|
||||
)
|
||||
|
||||
nodes.append((
|
||||
str(classification_type_id),
|
||||
dict(
|
||||
id = str(classification_type_id),
|
||||
name = classification_type_label,
|
||||
type = classification_type_label,
|
||||
)
|
||||
))
|
||||
existing_points_map[classification_type_id] = True
|
||||
|
||||
edges.append((
|
||||
str(data_chunk.chunk_id),
|
||||
str(classification_type_id),
|
||||
"is_media_type",
|
||||
dict(
|
||||
relationship_name = "is_media_type",
|
||||
source_node_id = str(data_chunk.chunk_id),
|
||||
target_node_id = str(classification_type_id),
|
||||
),
|
||||
))
|
||||
|
||||
for classification_subclass in chunk_classification.label.subclass:
|
||||
classification_subtype_label = classification_subclass.value
|
||||
classification_subtype_id = uuid5(NAMESPACE_OID, classification_subtype_label)
|
||||
|
||||
if classification_subtype_id not in existing_points_map:
|
||||
data_points.append(
|
||||
DataPoint[Keyword](
|
||||
id = str(classification_subtype_id),
|
||||
payload = Keyword.parse_obj({
|
||||
"id": str(classification_subtype_id),
|
||||
"text": classification_subtype_label,
|
||||
"chunk_id": str(data_chunk.chunk_id),
|
||||
"document_id": str(data_chunk.document_id),
|
||||
}),
|
||||
embed_field = "text",
|
||||
)
|
||||
)
|
||||
|
||||
nodes.append((
|
||||
str(classification_subtype_id),
|
||||
dict(
|
||||
id = str(classification_subtype_id),
|
||||
name = classification_subtype_label,
|
||||
type = classification_subtype_label,
|
||||
)
|
||||
))
|
||||
edges.append((
|
||||
str(classification_type_id),
|
||||
str(classification_subtype_id),
|
||||
"contains",
|
||||
dict(
|
||||
relationship_name = "contains",
|
||||
source_node_id = str(classification_type_id),
|
||||
target_node_id = str(classification_subtype_id),
|
||||
),
|
||||
))
|
||||
|
||||
existing_points_map[classification_subtype_id] = True
|
||||
|
||||
edges.append((
|
||||
str(data_chunk.chunk_id),
|
||||
str(classification_subtype_id),
|
||||
"is_classified_as",
|
||||
dict(
|
||||
relationship_name = "is_classified_as",
|
||||
source_node_id = str(data_chunk.chunk_id),
|
||||
target_node_id = str(classification_subtype_id),
|
||||
),
|
||||
))
|
||||
|
||||
if len(nodes) > 0 or len(edges) > 0:
|
||||
await vector_engine.create_data_points(collection_name, data_points)
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
|
||||
await graph_engine.add_nodes(nodes)
|
||||
await graph_engine.add_edges(edges)
|
||||
|
||||
return data_chunks
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
import uuid
|
||||
|
||||
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
||||
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
|
||||
from cognee.shared.data_models import GraphDBType
|
||||
from cognee.infrastructure.databases.graph.config import get_graph_config
|
||||
|
||||
|
|
@ -92,7 +92,7 @@ def graph_ready_output(results):
|
|||
if __name__ == "__main__":
|
||||
|
||||
async def main():
|
||||
graph_client = await get_graph_client()
|
||||
graph_client = await get_graph_engine()
|
||||
graph = graph_client.graph
|
||||
|
||||
# for nodes, attr in graph.nodes(data=True):
|
||||
|
|
|
|||
0
cognee/modules/cognify/graph/save_chunk_relationships.py
Normal file
0
cognee/modules/cognify/graph/save_chunk_relationships.py
Normal file
18
cognee/modules/cognify/graph/save_document_node.py
Normal file
18
cognee/modules/cognify/graph/save_document_node.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
from uuid import UUID
|
||||
from cognee.shared.data_models import Document
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
|
||||
async def save_document_node(document: Document, parent_node_id: UUID = None):
|
||||
graph_engine = get_graph_engine()
|
||||
|
||||
await graph_engine.add_node(document.id, document.model_dump())
|
||||
|
||||
if parent_node_id:
|
||||
await graph_engine.add_edge(
|
||||
parent_node_id,
|
||||
document.id,
|
||||
"has_document",
|
||||
dict(relationship_name = "has_document"),
|
||||
)
|
||||
|
||||
return document
|
||||
1
cognee/modules/cognify/vector/__init__.py
Normal file
1
cognee/modules/cognify/vector/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
from .save_data_chunks import save_data_chunks
|
||||
97
cognee/modules/cognify/vector/save_data_chunks.py
Normal file
97
cognee/modules/cognify/vector/save_data_chunks.py
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
from cognee.infrastructure.databases.vector import DataPoint, get_vector_engine
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk
|
||||
|
||||
async def save_data_chunks(data_chunks: list[DocumentChunk], collection_name: str):
|
||||
if len(data_chunks) == 0:
|
||||
return data_chunks
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
graph_engine = await get_graph_engine()
|
||||
|
||||
# Remove and unlink existing chunks
|
||||
if await vector_engine.has_collection(collection_name):
|
||||
existing_chunks = [DocumentChunk.parse_obj(chunk.payload) for chunk in (await vector_engine.retrieve(
|
||||
collection_name,
|
||||
[str(chunk.chunk_id) for chunk in data_chunks],
|
||||
))]
|
||||
|
||||
if len(existing_chunks) > 0:
|
||||
await vector_engine.delete_data_points(collection_name, [str(chunk.chunk_id) for chunk in existing_chunks])
|
||||
|
||||
await graph_engine.remove_connection_to_successors_of([chunk.chunk_id for chunk in existing_chunks], "next_chunk")
|
||||
await graph_engine.remove_connection_to_predecessors_of([chunk.chunk_id for chunk in existing_chunks], "has_chunk")
|
||||
else:
|
||||
await vector_engine.create_collection(collection_name, payload_schema = DocumentChunk)
|
||||
|
||||
# Add to vector storage
|
||||
await vector_engine.create_data_points(
|
||||
collection_name,
|
||||
[
|
||||
DataPoint[DocumentChunk](
|
||||
id = str(chunk.chunk_id),
|
||||
payload = chunk,
|
||||
embed_field = "text",
|
||||
) for chunk in data_chunks
|
||||
],
|
||||
)
|
||||
|
||||
# Add to graph storage
|
||||
chunk_nodes = []
|
||||
chunk_edges = []
|
||||
|
||||
for chunk in data_chunks:
|
||||
chunk_nodes.append((
|
||||
str(chunk.chunk_id),
|
||||
dict(
|
||||
id = str(chunk.chunk_id),
|
||||
chunk_id = str(chunk.chunk_id),
|
||||
document_id = str(chunk.document_id),
|
||||
word_count = chunk.word_count,
|
||||
chunk_index = chunk.chunk_index,
|
||||
cut_type = chunk.cut_type,
|
||||
pages = chunk.pages,
|
||||
)
|
||||
))
|
||||
|
||||
chunk_edges.append((
|
||||
str(chunk.document_id),
|
||||
str(chunk.chunk_id),
|
||||
"has_chunk",
|
||||
dict(
|
||||
relationship_name = "has_chunk",
|
||||
source_node_id = str(chunk.document_id),
|
||||
target_node_id = str(chunk.chunk_id),
|
||||
),
|
||||
))
|
||||
|
||||
previous_chunk_id = get_previous_chunk_id(data_chunks, chunk)
|
||||
|
||||
if previous_chunk_id is not None:
|
||||
chunk_edges.append((
|
||||
str(previous_chunk_id),
|
||||
str(chunk.chunk_id),
|
||||
"next_chunk",
|
||||
dict(
|
||||
relationship_name = "next_chunk",
|
||||
source_node_id = str(previous_chunk_id),
|
||||
target_node_id = str(chunk.chunk_id),
|
||||
),
|
||||
))
|
||||
|
||||
await graph_engine.add_nodes(chunk_nodes)
|
||||
await graph_engine.add_edges(chunk_edges)
|
||||
|
||||
return data_chunks
|
||||
|
||||
|
||||
def get_previous_chunk_id(document_chunks: list[DocumentChunk], current_chunk: DocumentChunk) -> DocumentChunk:
|
||||
if current_chunk.chunk_index == 0:
|
||||
return current_chunk.document_id
|
||||
|
||||
for chunk in document_chunks:
|
||||
if str(chunk.document_id) == str(current_chunk.document_id) \
|
||||
and chunk.chunk_index == current_chunk.chunk_index - 1:
|
||||
return chunk.chunk_id
|
||||
|
||||
return None
|
||||
3
cognee/modules/data/chunking/__init__.py
Normal file
3
cognee/modules/data/chunking/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from .chunk_by_word import chunk_by_word
|
||||
from .chunk_by_sentence import chunk_by_sentence
|
||||
from .chunk_by_paragraph import chunk_by_paragraph
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
from cognee.modules.data.chunking import chunk_by_paragraph
|
||||
|
||||
if __name__ == "__main__":
|
||||
def test_chunking_on_whole_text():
|
||||
test_text = """This is example text. It contains multiple sentences.
|
||||
This is a second paragraph. First two paragraphs are whole.
|
||||
Third paragraph is a bit longer and is finished with a dot."""
|
||||
|
||||
chunks = []
|
||||
|
||||
for chunk_data in chunk_by_paragraph(test_text, 12, batch_paragraphs = False):
|
||||
chunks.append(chunk_data)
|
||||
|
||||
assert len(chunks) == 3
|
||||
|
||||
assert chunks[0]["text"] == "This is example text. It contains multiple sentences."
|
||||
assert chunks[0]["word_count"] == 8
|
||||
assert chunks[0]["cut_type"] == "paragraph_end"
|
||||
|
||||
assert chunks[1]["text"] == "This is a second paragraph. First two paragraphs are whole."
|
||||
assert chunks[1]["word_count"] == 10
|
||||
assert chunks[1]["cut_type"] == "paragraph_end"
|
||||
|
||||
assert chunks[2]["text"] == "Third paragraph is a bit longer and is finished with a dot."
|
||||
assert chunks[2]["word_count"] == 12
|
||||
assert chunks[2]["cut_type"] == "sentence_end"
|
||||
|
||||
def test_chunking_on_cut_text():
|
||||
test_text = """This is example text. It contains multiple sentences.
|
||||
This is a second paragraph. First two paragraphs are whole.
|
||||
Third paragraph is cut and is missing the dot at the end"""
|
||||
|
||||
chunks = []
|
||||
|
||||
for chunk_data in chunk_by_paragraph(test_text, 12, batch_paragraphs = False):
|
||||
chunks.append(chunk_data)
|
||||
|
||||
assert len(chunks) == 3
|
||||
|
||||
assert chunks[0]["text"] == "This is example text. It contains multiple sentences."
|
||||
assert chunks[0]["word_count"] == 8
|
||||
assert chunks[0]["cut_type"] == "paragraph_end"
|
||||
|
||||
assert chunks[1]["text"] == "This is a second paragraph. First two paragraphs are whole."
|
||||
assert chunks[1]["word_count"] == 10
|
||||
assert chunks[1]["cut_type"] == "paragraph_end"
|
||||
|
||||
assert chunks[2]["text"] == "Third paragraph is cut and is missing the dot at the end"
|
||||
assert chunks[2]["word_count"] == 12
|
||||
assert chunks[2]["cut_type"] == "sentence_cut"
|
||||
|
||||
test_chunking_on_whole_text()
|
||||
test_chunking_on_cut_text()
|
||||
69
cognee/modules/data/chunking/chunk_by_paragraph.py
Normal file
69
cognee/modules/data/chunking/chunk_by_paragraph.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
from uuid import uuid5, NAMESPACE_OID
|
||||
from .chunk_by_sentence import chunk_by_sentence
|
||||
|
||||
def chunk_by_paragraph(data: str, paragraph_length: int = 1024, batch_paragraphs = True):
|
||||
paragraph = ""
|
||||
last_cut_type = None
|
||||
last_paragraph_id = None
|
||||
paragraph_word_count = 0
|
||||
paragraph_chunk_index = 0
|
||||
|
||||
for (paragraph_id, __, sentence, word_count, end_type) in chunk_by_sentence(data):
|
||||
if paragraph_word_count > 0 and paragraph_word_count + word_count > paragraph_length:
|
||||
if batch_paragraphs is True:
|
||||
chunk_id = uuid5(NAMESPACE_OID, paragraph)
|
||||
yield dict(
|
||||
text = paragraph.strip(),
|
||||
word_count = paragraph_word_count,
|
||||
id = chunk_id, # When batching paragraphs, the paragraph_id is the same as chunk_id.
|
||||
# paragraph_id doens't mean anything since multiple paragraphs are merged.
|
||||
chunk_id = chunk_id,
|
||||
chunk_index = paragraph_chunk_index,
|
||||
cut_type = last_cut_type,
|
||||
)
|
||||
else:
|
||||
yield dict(
|
||||
text = paragraph.strip(),
|
||||
word_count = paragraph_word_count,
|
||||
id = last_paragraph_id,
|
||||
chunk_id = uuid5(NAMESPACE_OID, paragraph),
|
||||
chunk_index = paragraph_chunk_index,
|
||||
cut_type = last_cut_type,
|
||||
)
|
||||
|
||||
paragraph_chunk_index += 1
|
||||
paragraph_word_count = 0
|
||||
paragraph = ""
|
||||
|
||||
paragraph += (" " if len(paragraph) > 0 else "") + sentence
|
||||
paragraph_word_count += word_count
|
||||
|
||||
if end_type == "paragraph_end" or end_type == "sentence_cut":
|
||||
if batch_paragraphs is True:
|
||||
paragraph += "\n\n" if end_type == "paragraph_end" else ""
|
||||
else:
|
||||
yield dict(
|
||||
text = paragraph.strip(),
|
||||
word_count = paragraph_word_count,
|
||||
paragraph_id = paragraph_id,
|
||||
chunk_id = uuid5(NAMESPACE_OID, paragraph),
|
||||
chunk_index = paragraph_chunk_index,
|
||||
cut_type = end_type,
|
||||
)
|
||||
|
||||
paragraph_chunk_index = 0
|
||||
paragraph_word_count = 0
|
||||
paragraph = ""
|
||||
|
||||
last_cut_type = end_type
|
||||
last_paragraph_id = paragraph_id
|
||||
|
||||
if len(paragraph) > 0:
|
||||
yield dict(
|
||||
chunk_id = uuid5(NAMESPACE_OID, paragraph),
|
||||
text = paragraph,
|
||||
word_count = paragraph_word_count,
|
||||
paragraph_id = last_paragraph_id,
|
||||
chunk_index = paragraph_chunk_index,
|
||||
cut_type = last_cut_type,
|
||||
)
|
||||
28
cognee/modules/data/chunking/chunk_by_sentence.py
Normal file
28
cognee/modules/data/chunking/chunk_by_sentence.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
from uuid import uuid4
|
||||
from .chunk_by_word import chunk_by_word
|
||||
|
||||
def chunk_by_sentence(data: str):
|
||||
sentence = ""
|
||||
paragraph_id = uuid4()
|
||||
chunk_index = 0
|
||||
word_count = 0
|
||||
|
||||
for (word, word_type) in chunk_by_word(data):
|
||||
sentence += (" " if len(sentence) > 0 else "") + word
|
||||
word_count += 1
|
||||
|
||||
if word_type == "paragraph_end" or word_type == "sentence_end":
|
||||
yield (paragraph_id, chunk_index, sentence, word_count, word_type)
|
||||
sentence = ""
|
||||
word_count = 0
|
||||
paragraph_id = uuid4() if word_type == "paragraph_end" else paragraph_id
|
||||
chunk_index = 0 if word_type == "paragraph_end" else chunk_index + 1
|
||||
|
||||
if len(sentence) > 0:
|
||||
yield (
|
||||
paragraph_id,
|
||||
chunk_index,
|
||||
sentence,
|
||||
word_count,
|
||||
"sentence_cut",
|
||||
)
|
||||
60
cognee/modules/data/chunking/chunk_by_word.py
Normal file
60
cognee/modules/data/chunking/chunk_by_word.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
import re
|
||||
|
||||
def chunk_by_word(data: str):
|
||||
sentence_endings = r"[.;!?…]"
|
||||
paragraph_endings = r"[\n\r]"
|
||||
last_processed_character = ""
|
||||
|
||||
word = ""
|
||||
i = 0
|
||||
|
||||
while i < len(data):
|
||||
character = data[i]
|
||||
|
||||
if word == "" and (re.match(paragraph_endings, character) or character == " "):
|
||||
i = i + 1
|
||||
continue
|
||||
|
||||
def is_real_paragraph_end():
|
||||
if re.match(sentence_endings, last_processed_character):
|
||||
return True
|
||||
|
||||
j = i + 1
|
||||
next_character = data[j] if j < len(data) else None
|
||||
while next_character is not None and (re.match(paragraph_endings, next_character) or next_character == " "):
|
||||
j += 1
|
||||
next_character = data[j] if j < len(data) else None
|
||||
if next_character.isupper():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
if re.match(paragraph_endings, character):
|
||||
yield (word, "paragraph_end" if is_real_paragraph_end() else "word")
|
||||
word = ""
|
||||
i = i + 1
|
||||
continue
|
||||
|
||||
if character == " ":
|
||||
yield [word, "word"]
|
||||
word = ""
|
||||
i = i + 1
|
||||
continue
|
||||
|
||||
word += character
|
||||
last_processed_character = character
|
||||
|
||||
if re.match(sentence_endings, character):
|
||||
# Check for ellipses.
|
||||
if i + 2 <= len(data) and data[i] == "." and data[i + 1] == "." and data[i + 2] == ".":
|
||||
word += ".."
|
||||
i = i + 2
|
||||
|
||||
is_paragraph_end = i + 1 < len(data) and re.match(paragraph_endings, data[i + 1])
|
||||
yield (word, "paragraph_end" if is_paragraph_end else "sentence_end")
|
||||
word = ""
|
||||
|
||||
i += 1
|
||||
|
||||
if len(word) > 0:
|
||||
yield (word, "word")
|
||||
|
|
@ -1 +1,2 @@
|
|||
from .prune_data import prune_data
|
||||
from .prune_system import prune_system
|
||||
|
|
|
|||
7
cognee/modules/data/deletion/prune_data.py
Normal file
7
cognee/modules/data/deletion/prune_data.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
from cognee.base_config import get_base_config
|
||||
from cognee.infrastructure.files.storage import LocalStorage
|
||||
|
||||
async def prune_data():
|
||||
base_config = get_base_config()
|
||||
data_root_directory = base_config.data_root_directory
|
||||
LocalStorage.remove_all(data_root_directory)
|
||||
|
|
@ -1,13 +1,17 @@
|
|||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.databases.graph.config import get_graph_config
|
||||
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
||||
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
|
||||
from cognee.infrastructure.databases.relational import get_relationaldb_config
|
||||
|
||||
async def prune_system(graph = True, vector = True):
|
||||
async def prune_system(graph = True, vector = True, metadata = False):
|
||||
if graph:
|
||||
graph_config = get_graph_config()
|
||||
graph_client = await get_graph_client()
|
||||
await graph_client.delete_graph()
|
||||
graph_engine = await get_graph_engine()
|
||||
await graph_engine.delete_graph()
|
||||
|
||||
if vector:
|
||||
vector_engine = get_vector_engine()
|
||||
await vector_engine.prune()
|
||||
|
||||
if metadata:
|
||||
db_config = get_relationaldb_config()
|
||||
db_engine = db_config.database_engine
|
||||
db_engine.delete_database()
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
from .extract_topics import extract_topics_yake, extract_topics_keybert
|
||||
0
cognee/modules/data/extraction/data_summary/__init__.py
Normal file
0
cognee/modules/data/extraction/data_summary/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
class TextSummary(BaseModel):
|
||||
text: str
|
||||
chunk_id: str
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
|
||||
import asyncio
|
||||
from typing import Type
|
||||
from pydantic import BaseModel
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine, DataPoint
|
||||
from ...processing.chunk_types.DocumentChunk import DocumentChunk
|
||||
from ...extraction.extract_summary import extract_summary
|
||||
from .models.TextSummary import TextSummary
|
||||
|
||||
async def summarize_text_chunks(data_chunks: list[DocumentChunk], summarization_model: Type[BaseModel], collection_name: str = "summaries"):
|
||||
if len(data_chunks) == 0:
|
||||
return data_chunks
|
||||
|
||||
chunk_summaries = await asyncio.gather(
|
||||
*[extract_summary(chunk.text, summarization_model) for chunk in data_chunks]
|
||||
)
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
|
||||
await vector_engine.create_collection(collection_name, payload_schema = TextSummary)
|
||||
|
||||
await vector_engine.create_data_points(
|
||||
collection_name,
|
||||
[
|
||||
DataPoint[TextSummary](
|
||||
id = str(chunk.chunk_id),
|
||||
payload = dict(
|
||||
chunk_id = str(chunk.chunk_id),
|
||||
text = chunk_summaries[chunk_index].summary,
|
||||
),
|
||||
embed_field = "text",
|
||||
) for (chunk_index, chunk) in enumerate(data_chunks)
|
||||
],
|
||||
)
|
||||
|
||||
return data_chunks
|
||||
|
|
@ -11,16 +11,4 @@ async def extract_categories(content: str, response_model: Type[BaseModel]):
|
|||
|
||||
llm_output = await llm_client.acreate_structured_output(content, system_prompt, response_model)
|
||||
|
||||
return process_categories(llm_output.model_dump())
|
||||
|
||||
def process_categories(llm_output) -> List[dict]:
|
||||
# Extract the first subclass from the list (assuming there could be more)
|
||||
data_category = llm_output["label"]["subclass"][0] if len(llm_output["label"]["subclass"]) > 0 else None
|
||||
|
||||
data_type = llm_output["label"]["type"].lower()
|
||||
|
||||
return [{
|
||||
"data_type": data_type,
|
||||
# The data_category is the value of the Enum member (e.g., "News stories and blog posts")
|
||||
"category_name": data_category.value if data_category else "Other types of text data",
|
||||
}]
|
||||
return llm_output
|
||||
|
|
|
|||
|
|
@ -10,4 +10,4 @@ async def extract_summary(content: str, response_model: Type[BaseModel]):
|
|||
|
||||
llm_output = await llm_client.acreate_structured_output(content, system_prompt, response_model)
|
||||
|
||||
return llm_output.model_dump()
|
||||
return llm_output
|
||||
|
|
|
|||
113
cognee/modules/data/extraction/extract_topics.py
Normal file
113
cognee/modules/data/extraction/extract_topics.py
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
import re
|
||||
import nltk
|
||||
from nltk.tag import pos_tag
|
||||
from nltk.corpus import stopwords, wordnet
|
||||
from nltk.tokenize import word_tokenize
|
||||
from nltk.stem import WordNetLemmatizer
|
||||
|
||||
def extract_topics_yake(texts: list[str]):
|
||||
from yake import KeywordExtractor
|
||||
|
||||
keyword_extractor = KeywordExtractor(
|
||||
top = 3,
|
||||
n = 2,
|
||||
dedupLim = 0.2,
|
||||
dedupFunc = "levenshtein", # "seqm" | "levenshtein"
|
||||
windowsSize = 1,
|
||||
)
|
||||
|
||||
for text in texts:
|
||||
topics = keyword_extractor.extract_keywords(preprocess_text(text))
|
||||
yield [topic[0] for topic in topics]
|
||||
|
||||
def extract_topics_keybert(texts: list[str]):
|
||||
from keybert import KeyBERT
|
||||
|
||||
kw_model = KeyBERT()
|
||||
|
||||
for text in texts:
|
||||
topics = kw_model.extract_keywords(
|
||||
preprocess_text(text),
|
||||
keyphrase_ngram_range = (1, 2),
|
||||
top_n = 3,
|
||||
# use_mmr = True,
|
||||
# diversity = 0.9,
|
||||
)
|
||||
yield [topic[0] for topic in topics]
|
||||
|
||||
def preprocess_text(text: str):
|
||||
try:
|
||||
# Used for stopwords removal.
|
||||
stopwords.ensure_loaded()
|
||||
except LookupError:
|
||||
nltk.download("stopwords", quiet = True)
|
||||
stopwords.ensure_loaded()
|
||||
|
||||
try:
|
||||
# Used in WordNetLemmatizer.
|
||||
wordnet.ensure_loaded()
|
||||
except LookupError:
|
||||
nltk.download("wordnet", quiet = True)
|
||||
wordnet.ensure_loaded()
|
||||
|
||||
try:
|
||||
# Used in word_tokenize.
|
||||
nltk.data.find("tokenizers/punkt")
|
||||
except LookupError:
|
||||
nltk.download("punkt", quiet = True)
|
||||
|
||||
text = text.lower()
|
||||
|
||||
# Remove punctuation
|
||||
text = re.sub(r"[^\w\s-]", "", text)
|
||||
|
||||
# Tokenize the text
|
||||
tokens = word_tokenize(text)
|
||||
|
||||
tagged_tokens = pos_tag(tokens)
|
||||
tokens = [word for word, tag in tagged_tokens if tag in ["NNP", "NN", "JJ"]]
|
||||
|
||||
# Remove stop words
|
||||
stop_words = set(stopwords.words("english"))
|
||||
tokens = [word for word in tokens if word not in stop_words]
|
||||
|
||||
# Lemmatize the text
|
||||
lemmatizer = WordNetLemmatizer()
|
||||
tokens = [lemmatizer.lemmatize(word) for word in tokens]
|
||||
|
||||
# Join tokens back to a single string
|
||||
processed_text = " ".join(tokens)
|
||||
|
||||
return processed_text
|
||||
|
||||
|
||||
# def clean_text(text: str):
|
||||
# text = re.sub(r"[ \t]{2,}|[\n\r]", " ", text.lower())
|
||||
# # text = re.sub(r"[`\"'.,;!?…]", "", text).strip()
|
||||
# return text
|
||||
|
||||
# def remove_stop_words(text: str):
|
||||
# try:
|
||||
# stopwords.ensure_loaded()
|
||||
# except LookupError:
|
||||
# download("stopwords")
|
||||
# stopwords.ensure_loaded()
|
||||
|
||||
# stop_words = set(stopwords.words("english"))
|
||||
# text = text.split()
|
||||
# text = [word for word in text if not word in stop_words]
|
||||
# return " ".join(text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
file_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
with open(os.path.join(file_dir, "texts.json"), "r", encoding = "utf-8") as file:
|
||||
import json
|
||||
texts = json.load(file)
|
||||
|
||||
for topics in extract_topics_yake(texts):
|
||||
print(topics)
|
||||
print("\n")
|
||||
66
cognee/modules/data/extraction/extract_topics_naive.py
Normal file
66
cognee/modules/data/extraction/extract_topics_naive.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
import re
|
||||
from nltk.downloader import download
|
||||
from nltk.stem import WordNetLemmatizer
|
||||
from nltk.tokenize import sent_tokenize, word_tokenize
|
||||
from nltk.corpus import stopwords, wordnet
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.decomposition import TruncatedSVD
|
||||
|
||||
def extract_topics(text: str):
|
||||
sentences = sent_tokenize(text)
|
||||
|
||||
try:
|
||||
wordnet.ensure_loaded()
|
||||
except LookupError:
|
||||
download("wordnet")
|
||||
wordnet.ensure_loaded()
|
||||
|
||||
lemmatizer = WordNetLemmatizer()
|
||||
base_notation_sentences = [lemmatizer.lemmatize(sentence) for sentence in sentences]
|
||||
|
||||
tf_vectorizer = TfidfVectorizer(tokenizer = word_tokenize, token_pattern = None)
|
||||
transformed_corpus = tf_vectorizer.fit_transform(base_notation_sentences)
|
||||
|
||||
svd = TruncatedSVD(n_components = 10)
|
||||
svd_corpus = svd.fit(transformed_corpus)
|
||||
|
||||
feature_scores = dict(
|
||||
zip(
|
||||
tf_vectorizer.vocabulary_,
|
||||
svd_corpus.components_[0]
|
||||
)
|
||||
)
|
||||
|
||||
topics = sorted(
|
||||
feature_scores,
|
||||
# key = feature_scores.get,
|
||||
key = lambda x: transformed_corpus[0, tf_vectorizer.vocabulary_[x]],
|
||||
reverse = True,
|
||||
)[:10]
|
||||
|
||||
return topics
|
||||
|
||||
|
||||
def clean_text(text: str):
|
||||
text = re.sub(r"[ \t]{2,}|[\n\r]", " ", text.lower())
|
||||
return re.sub(r"[`\"'.,;!?…]", "", text).strip()
|
||||
|
||||
def remove_stop_words(text: str):
|
||||
try:
|
||||
stopwords.ensure_loaded()
|
||||
except LookupError:
|
||||
download("stopwords")
|
||||
stopwords.ensure_loaded()
|
||||
|
||||
stop_words = set(stopwords.words("english"))
|
||||
text = text.split()
|
||||
text = [word for word in text if not word in stop_words]
|
||||
return " ".join(text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
text = """Lorem Ipsum is simply dummy text of the printing and typesetting industry... Lorem Ipsum has been the industry's standard dummy text ever since the 1500s, when an unknown printer took a galley of type and scrambled it to make a type specimen book… It has survived not only five centuries, but also the leap into electronic typesetting, remaining essentially unchanged. It was popularised in the 1960s with the release of Letraset sheets containing Lorem Ipsum passages, and more recently with desktop publishing software like Aldus PageMaker including versions of Lorem Ipsum.
|
||||
Why do we use it?
|
||||
It is a long established fact that a reader will be distracted by the readable content of a page when looking at its layout! The point of using Lorem Ipsum is that it has a more-or-less normal distribution of letters, as opposed to using 'Content here, content here', making it look like readable English. Many desktop publishing packages and web page editors now use Lorem Ipsum as their default model text, and a search for 'lorem ipsum' will uncover many web sites still in their infancy. Various versions have evolved over the years, sometimes by accident, sometimes on purpose (injected humour and the like).
|
||||
"""
|
||||
print(extract_topics(remove_stop_words(clean_text(text))))
|
||||
|
|
@ -0,0 +1,70 @@
|
|||
from typing import Type, Optional, get_args, get_origin
|
||||
from pydantic import BaseModel
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
||||
|
||||
async def add_model_class_to_graph(
|
||||
model_class: Type[BaseModel],
|
||||
graph: GraphDBInterface,
|
||||
parent: Optional[str] = None,
|
||||
relationship: Optional[str] = None,
|
||||
):
|
||||
model_name = model_class.__name__
|
||||
|
||||
if await graph.extract_node(model_name):
|
||||
return
|
||||
|
||||
await graph.add_node(model_name, dict(type = "model"))
|
||||
|
||||
if parent and relationship:
|
||||
await graph.add_edge(
|
||||
parent,
|
||||
model_name,
|
||||
relationship,
|
||||
dict(
|
||||
relationship_name = relationship,
|
||||
source_node_id = parent,
|
||||
target_node_id = model_name,
|
||||
),
|
||||
)
|
||||
|
||||
for field_name, field in model_class.model_fields.items():
|
||||
original_types = get_args(field.annotation)
|
||||
field_type = original_types[0] if len(original_types) > 0 else None
|
||||
|
||||
if field_type is None:
|
||||
continue
|
||||
|
||||
if hasattr(field_type, "model_fields"): # Check if field type is a Pydantic model
|
||||
await add_model_class_to_graph(field_type, graph, model_name, field_name)
|
||||
elif get_origin(field.annotation) == list:
|
||||
list_types = get_args(field_type)
|
||||
for item_type in list_types:
|
||||
await add_model_class_to_graph(item_type, graph, model_name, field_name)
|
||||
elif isinstance(field_type, list):
|
||||
item_type = get_args(field_type)[0]
|
||||
if hasattr(item_type, "model_fields"):
|
||||
await add_model_class_to_graph(item_type, graph, model_name, field_name)
|
||||
else:
|
||||
await graph.add_node(str(item_type), dict(type = "value"))
|
||||
await graph.add_edge(
|
||||
model_name,
|
||||
str(item_type),
|
||||
field_name,
|
||||
dict(
|
||||
relationship_name = field_name,
|
||||
source_node_id = model_name,
|
||||
target_node_id = str(item_type),
|
||||
),
|
||||
)
|
||||
else:
|
||||
await graph.add_node(str(field_type), dict(type = "value"))
|
||||
await graph.add_edge(
|
||||
model_name,
|
||||
str(field_type),
|
||||
field_name,
|
||||
dict(
|
||||
relationship_name = field_name,
|
||||
source_node_id = model_name,
|
||||
target_node_id = str(field_type),
|
||||
),
|
||||
)
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
from typing import Type
|
||||
from pydantic import BaseModel
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from ...processing.chunk_types.DocumentChunk import DocumentChunk
|
||||
from .add_model_class_to_graph import add_model_class_to_graph
|
||||
|
||||
async def establish_graph_topology(data_chunks: list[DocumentChunk], topology_model: Type[BaseModel]):
|
||||
if topology_model == KnowledgeGraph:
|
||||
return data_chunks
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
|
||||
await add_model_class_to_graph(topology_model, graph_engine)
|
||||
|
||||
return data_chunks
|
||||
|
||||
|
||||
def generate_node_id(node_id: str) -> str:
|
||||
return node_id.upper().replace(" ", "_").replace("'", "")
|
||||
|
|
@ -0,0 +1,117 @@
|
|||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Type
|
||||
from pydantic import BaseModel
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from ...processing.chunk_types.DocumentChunk import DocumentChunk
|
||||
from .extract_knowledge_graph import extract_content_graph
|
||||
|
||||
async def expand_knowledge_graph(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel]):
|
||||
chunk_graphs = await asyncio.gather(
|
||||
*[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
|
||||
)
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
|
||||
type_ids = [generate_node_id(node.type) for chunk_graph in chunk_graphs for node in chunk_graph.nodes]
|
||||
graph_type_node_ids = list(set(type_ids))
|
||||
graph_type_nodes = await graph_engine.extract_nodes(graph_type_node_ids)
|
||||
existing_type_nodes_map = {node["id"]: node for node in graph_type_nodes}
|
||||
|
||||
graph_nodes = []
|
||||
graph_edges = []
|
||||
|
||||
for (chunk_index, chunk) in enumerate(data_chunks):
|
||||
graph = chunk_graphs[chunk_index]
|
||||
if graph is None:
|
||||
continue
|
||||
|
||||
for node in graph.nodes:
|
||||
node_id = generate_node_id(node.id)
|
||||
|
||||
graph_nodes.append((
|
||||
node_id,
|
||||
dict(
|
||||
id = node_id,
|
||||
chunk_id = str(chunk.chunk_id),
|
||||
document_id = str(chunk.document_id),
|
||||
name = node.name,
|
||||
type = node.type.lower().capitalize(),
|
||||
description = node.description,
|
||||
created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
)
|
||||
))
|
||||
|
||||
graph_edges.append((
|
||||
str(chunk.chunk_id),
|
||||
node_id,
|
||||
"contains",
|
||||
dict(
|
||||
relationship_name = "contains",
|
||||
source_node_id = str(chunk.chunk_id),
|
||||
target_node_id = node_id,
|
||||
),
|
||||
))
|
||||
|
||||
type_node_id = generate_node_id(node.type)
|
||||
|
||||
if type_node_id not in existing_type_nodes_map:
|
||||
node_name = node.type.lower().capitalize()
|
||||
|
||||
type_node = dict(
|
||||
id = type_node_id,
|
||||
name = node_name,
|
||||
type = node_name,
|
||||
created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
)
|
||||
|
||||
graph_nodes.append((type_node_id, type_node))
|
||||
existing_type_nodes_map[type_node_id] = type_node
|
||||
|
||||
graph_edges.append((
|
||||
str(chunk.chunk_id),
|
||||
type_node_id,
|
||||
"contains_entity_type",
|
||||
dict(
|
||||
relationship_name = "contains_entity_type",
|
||||
source_node_id = str(chunk.chunk_id),
|
||||
target_node_id = type_node_id,
|
||||
),
|
||||
))
|
||||
|
||||
# Add relationship between entity type and entity itself: "Jake is Person"
|
||||
graph_edges.append((
|
||||
type_node_id,
|
||||
node_id,
|
||||
"is_entity_type",
|
||||
dict(
|
||||
relationship_name = "is_entity_type",
|
||||
source_node_id = type_node_id,
|
||||
target_node_id = node_id,
|
||||
),
|
||||
))
|
||||
|
||||
# Add relationship that came from graphs.
|
||||
for edge in graph.edges:
|
||||
graph_edges.append((
|
||||
generate_node_id(edge.source_node_id),
|
||||
generate_node_id(edge.target_node_id),
|
||||
edge.relationship_name,
|
||||
dict(
|
||||
relationship_name = edge.relationship_name,
|
||||
source_node_id = generate_node_id(edge.source_node_id),
|
||||
target_node_id = generate_node_id(edge.target_node_id),
|
||||
),
|
||||
))
|
||||
|
||||
await graph_engine.add_nodes(graph_nodes)
|
||||
|
||||
await graph_engine.add_edges(graph_edges)
|
||||
|
||||
return data_chunks
|
||||
|
||||
|
||||
def generate_node_id(node_id: str) -> str:
|
||||
return node_id.upper().replace(" ", "_").replace("'", "")
|
||||
|
|
@ -3,10 +3,10 @@ from pydantic import BaseModel
|
|||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
from cognee.infrastructure.llm.prompts import render_prompt
|
||||
|
||||
async def extract_content_graph(content: str, cognitive_layer: str, response_model: Type[BaseModel]):
|
||||
async def extract_content_graph(content: str, response_model: Type[BaseModel]):
|
||||
llm_client = get_llm_client()
|
||||
|
||||
system_prompt = render_prompt("generate_graph_prompt.txt", { "layer": cognitive_layer })
|
||||
output = await llm_client.acreate_structured_output(content, system_prompt, response_model)
|
||||
system_prompt = render_prompt("generate_graph_prompt.txt", {})
|
||||
content_graph = await llm_client.acreate_structured_output(content, system_prompt, response_model)
|
||||
|
||||
return output.model_dump()
|
||||
return content_graph
|
||||
|
|
|
|||
7
cognee/modules/data/extraction/texts.json
Normal file
7
cognee/modules/data/extraction/texts.json
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
"Lorem Ipsum is simply dummy text of the printing and typesetting industry. Lorem Ipsum has been the industry's standard dummy text ever since the 1500s, when an unknown printer took a galley of type and scrambled it to make a type specimen book.\nIt has survived not only five centuries, but also the leap into electronic typesetting, remaining essentially unchanged. It was popularised in the 1960s with the release of Letraset sheets containing Lorem Ipsum passages, and more recently with desktop publishing software like Aldus PageMaker including versions of Lorem Ipsum.",
|
||||
"It is a long established fact that a reader will be distracted by the readable content of a page when looking at its layout.\n\tThe point of using Lorem Ipsum is that it has a more-or-less normal distribution of letters, as opposed to using 'Content here, content here', making it look like readable English. Many desktop publishing packages and web page editors now use Lorem Ipsum as their default model text, and a search for 'lorem ipsum' will uncover many web sites still in their infancy.\n Various versions have evolved over the years, sometimes by accident, sometimes on purpose (injected humour and the like).",
|
||||
"Contrary to popular belief, Lorem Ipsum is not simply random text. It has roots in a piece of classical Latin literature from 45 BC, making it over 2000 years old. Richard McClintock, a Latin professor at Hampden-Sydney College in Virginia, looked up one of the more obscure Latin words, consectetur, from a Lorem Ipsum passage, and going through the cites of the word in classical literature, discovered the undoubtable source. Lorem Ipsum comes from sections 1.10.32 and 1.10.33 of \"de Finibus Bonorum et Malorum\" (The Extremes of Good and Evil) by Cicero, written in 45 BC. This book is a treatise on the theory of ethics, very popular during the Renaissance.\n The first line of Lorem Ipsum, \"Lorem ipsum dolor sit amet..\", comes from a line in section 1.10.32.",
|
||||
"The standard chunk of Lorem Ipsum used since the 1500s is reproduced below for those interested. Sections 1.10.32 and 1.10.33 from \"de Finibus Bonorum et Malorum\" by Cicero are also reproduced in their exact original form, accompanied by English versions from the 1914 translation by H. Rackham.",
|
||||
"There are many variations of passages of Lorem Ipsum available, but the majority have suffered alteration in some form, by injected humour, or randomised words which don't look even slightly believable. If you are going to use a passage of Lorem Ipsum, you need to be sure there isn't anything embarrassing hidden in the middle of text. All the Lorem Ipsum generators on the Internet tend to repeat predefined chunks as necessary, making this the first true generator on the Internet. It uses a dictionary of over 200 Latin words, combined with a handful of model sentence structures, to generate Lorem Ipsum which looks reasonable. The generated Lorem Ipsum is therefore always free from repetition, injected humour, or non-characteristic words etc."
|
||||
]
|
||||
0
cognee/modules/data/processing/__init__.py
Normal file
0
cognee/modules/data/processing/__init__.py
Normal file
10
cognee/modules/data/processing/chunk_types/DocumentChunk.py
Normal file
10
cognee/modules/data/processing/chunk_types/DocumentChunk.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
class DocumentChunk(BaseModel):
|
||||
text: str
|
||||
word_count: int
|
||||
document_id: str
|
||||
chunk_id: str
|
||||
chunk_index: int
|
||||
cut_type: str
|
||||
pages: list[int]
|
||||
122
cognee/modules/data/processing/document_types/AudioDocument.py
Normal file
122
cognee/modules/data/processing/document_types/AudioDocument.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
from uuid import uuid5, NAMESPACE_OID
|
||||
from typing import Optional, Generator
|
||||
|
||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
from cognee.modules.data.chunking import chunk_by_paragraph
|
||||
from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk
|
||||
from cognee.modules.data.processing.document_types.Document import Document
|
||||
|
||||
class AudioReader:
|
||||
id: str
|
||||
file_path: str
|
||||
|
||||
def __init__(self, id: str, file_path: str):
|
||||
self.id = id
|
||||
self.file_path = file_path
|
||||
self.llm_client = get_llm_client() # You can choose different models like "tiny", "base", "small", etc.
|
||||
|
||||
|
||||
def read(self, max_chunk_size: Optional[int] = 1024):
|
||||
chunk_index = 0
|
||||
chunk_size = 0
|
||||
chunked_pages = []
|
||||
paragraph_chunks = []
|
||||
|
||||
# Transcribe the audio file
|
||||
result = self.llm_client.create_transcript(self.file_path)
|
||||
text = result.text
|
||||
|
||||
# Simulate reading text in chunks as done in TextReader
|
||||
def read_text_chunks(text, chunk_size):
|
||||
for i in range(0, len(text), chunk_size):
|
||||
yield text[i:i + chunk_size]
|
||||
|
||||
page_index = 0
|
||||
|
||||
for page_text in read_text_chunks(text, max_chunk_size):
|
||||
chunked_pages.append(page_index)
|
||||
page_index += 1
|
||||
|
||||
for chunk_data in chunk_by_paragraph(page_text, max_chunk_size, batch_paragraphs=True):
|
||||
if chunk_size + chunk_data["word_count"] <= max_chunk_size:
|
||||
paragraph_chunks.append(chunk_data)
|
||||
chunk_size += chunk_data["word_count"]
|
||||
else:
|
||||
if len(paragraph_chunks) == 0:
|
||||
yield DocumentChunk(
|
||||
text=chunk_data["text"],
|
||||
word_count=chunk_data["word_count"],
|
||||
document_id=str(self.id),
|
||||
chunk_id=str(chunk_data["chunk_id"]),
|
||||
chunk_index=chunk_index,
|
||||
cut_type=chunk_data["cut_type"],
|
||||
pages=[page_index],
|
||||
)
|
||||
paragraph_chunks = []
|
||||
chunk_size = 0
|
||||
else:
|
||||
chunk_text = " ".join(chunk["text"] for chunk in paragraph_chunks)
|
||||
yield DocumentChunk(
|
||||
text=chunk_text,
|
||||
word_count=chunk_size,
|
||||
document_id=str(self.id),
|
||||
chunk_id=str(uuid5(NAMESPACE_OID, f"{str(self.id)}-{chunk_index}")),
|
||||
chunk_index=chunk_index,
|
||||
cut_type=paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
|
||||
pages=chunked_pages,
|
||||
)
|
||||
chunked_pages = [page_index]
|
||||
paragraph_chunks = [chunk_data]
|
||||
chunk_size = chunk_data["word_count"]
|
||||
|
||||
chunk_index += 1
|
||||
|
||||
if len(paragraph_chunks) > 0:
|
||||
yield DocumentChunk(
|
||||
text=" ".join(chunk["text"] for chunk in paragraph_chunks),
|
||||
word_count=chunk_size,
|
||||
document_id=str(self.id),
|
||||
chunk_id=str(uuid5(NAMESPACE_OID, f"{str(self.id)}-{chunk_index}")),
|
||||
chunk_index=chunk_index,
|
||||
cut_type=paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
|
||||
pages=chunked_pages,
|
||||
)
|
||||
|
||||
class AudioDocument(Document):
|
||||
type: str = "audio"
|
||||
title: str
|
||||
file_path: str
|
||||
|
||||
def __init__(self, title: str, file_path: str):
|
||||
self.id = uuid5(NAMESPACE_OID, title)
|
||||
self.title = title
|
||||
self.file_path = file_path
|
||||
|
||||
reader = AudioReader(self.id, self.file_path)
|
||||
|
||||
def get_reader(self) -> AudioReader:
|
||||
reader = AudioReader(self.id, self.file_path)
|
||||
return reader
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return dict(
|
||||
id=str(self.id),
|
||||
type=self.type,
|
||||
title=self.title,
|
||||
file_path=self.file_path,
|
||||
)
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# # Sample usage of AudioDocument
|
||||
# audio_document = AudioDocument("sample_audio", "/Users/vasa/Projects/cognee/cognee/modules/data/processing/document_types/preamble10.wav")
|
||||
# audio_reader = audio_document.get_reader()
|
||||
# for chunk in audio_reader.read():
|
||||
# print(chunk.text)
|
||||
# print(chunk.word_count)
|
||||
# print(chunk.document_id)
|
||||
# print(chunk.chunk_id)
|
||||
# print(chunk.chunk_index)
|
||||
# print(chunk.cut_type)
|
||||
# print(chunk.pages)
|
||||
# print("----")
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
from uuid import UUID
|
||||
from typing import Protocol
|
||||
|
||||
class Document(Protocol):
|
||||
id: UUID
|
||||
type: str
|
||||
title: str
|
||||
file_path: str
|
||||
124
cognee/modules/data/processing/document_types/ImageDocument.py
Normal file
124
cognee/modules/data/processing/document_types/ImageDocument.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
from uuid import uuid5, NAMESPACE_OID
|
||||
from typing import Optional, Generator
|
||||
|
||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
from cognee.modules.data.chunking import chunk_by_paragraph
|
||||
from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk
|
||||
from cognee.modules.data.processing.document_types.Document import Document
|
||||
|
||||
class ImageReader:
|
||||
id: str
|
||||
file_path: str
|
||||
|
||||
def __init__(self, id: str, file_path: str):
|
||||
self.id = id
|
||||
self.file_path = file_path
|
||||
self.llm_client = get_llm_client() # You can choose different models like "tiny", "base", "small", etc.
|
||||
|
||||
|
||||
def read(self, max_chunk_size: Optional[int] = 1024):
|
||||
chunk_index = 0
|
||||
chunk_size = 0
|
||||
chunked_pages = []
|
||||
paragraph_chunks = []
|
||||
|
||||
# Transcribe the image file
|
||||
result = self.llm_client.transcribe_image(self.file_path)
|
||||
print("Transcription result: ", result.choices[0].message.content)
|
||||
text = result.choices[0].message.content
|
||||
|
||||
|
||||
# Simulate reading text in chunks as done in TextReader
|
||||
def read_text_chunks(text, chunk_size):
|
||||
for i in range(0, len(text), chunk_size):
|
||||
yield text[i:i + chunk_size]
|
||||
|
||||
page_index = 0
|
||||
|
||||
for page_text in read_text_chunks(text, max_chunk_size):
|
||||
chunked_pages.append(page_index)
|
||||
page_index += 1
|
||||
|
||||
for chunk_data in chunk_by_paragraph(page_text, max_chunk_size, batch_paragraphs=True):
|
||||
if chunk_size + chunk_data["word_count"] <= max_chunk_size:
|
||||
paragraph_chunks.append(chunk_data)
|
||||
chunk_size += chunk_data["word_count"]
|
||||
else:
|
||||
if len(paragraph_chunks) == 0:
|
||||
yield DocumentChunk(
|
||||
text=chunk_data["text"],
|
||||
word_count=chunk_data["word_count"],
|
||||
document_id=str(self.id),
|
||||
chunk_id=str(chunk_data["chunk_id"]),
|
||||
chunk_index=chunk_index,
|
||||
cut_type=chunk_data["cut_type"],
|
||||
pages=[page_index],
|
||||
)
|
||||
paragraph_chunks = []
|
||||
chunk_size = 0
|
||||
else:
|
||||
chunk_text = " ".join(chunk["text"] for chunk in paragraph_chunks)
|
||||
yield DocumentChunk(
|
||||
text=chunk_text,
|
||||
word_count=chunk_size,
|
||||
document_id=str(self.id),
|
||||
chunk_id=str(uuid5(NAMESPACE_OID, f"{str(self.id)}-{chunk_index}")),
|
||||
chunk_index=chunk_index,
|
||||
cut_type=paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
|
||||
pages=chunked_pages,
|
||||
)
|
||||
chunked_pages = [page_index]
|
||||
paragraph_chunks = [chunk_data]
|
||||
chunk_size = chunk_data["word_count"]
|
||||
|
||||
chunk_index += 1
|
||||
|
||||
if len(paragraph_chunks) > 0:
|
||||
yield DocumentChunk(
|
||||
text=" ".join(chunk["text"] for chunk in paragraph_chunks),
|
||||
word_count=chunk_size,
|
||||
document_id=str(self.id),
|
||||
chunk_id=str(uuid5(NAMESPACE_OID, f"{str(self.id)}-{chunk_index}")),
|
||||
chunk_index=chunk_index,
|
||||
cut_type=paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
|
||||
pages=chunked_pages,
|
||||
)
|
||||
|
||||
class ImageDocument(Document):
|
||||
type: str = "image"
|
||||
title: str
|
||||
file_path: str
|
||||
|
||||
def __init__(self, title: str, file_path: str):
|
||||
self.id = uuid5(NAMESPACE_OID, title)
|
||||
self.title = title
|
||||
self.file_path = file_path
|
||||
|
||||
reader = ImageReader(self.id, self.file_path)
|
||||
|
||||
def get_reader(self) -> ImageReader:
|
||||
reader = ImageReader(self.id, self.file_path)
|
||||
return reader
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return dict(
|
||||
id=str(self.id),
|
||||
type=self.type,
|
||||
title=self.title,
|
||||
file_path=self.file_path,
|
||||
)
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# # Sample usage of AudioDocument
|
||||
# audio_document = ImageDocument("sample_audio", "/Users/vasa/Projects/cognee/assets/architecture.png")
|
||||
# audio_reader = audio_document.get_reader()
|
||||
# for chunk in audio_reader.read():
|
||||
# print(chunk.text)
|
||||
# print(chunk.word_count)
|
||||
# print(chunk.document_id)
|
||||
# print(chunk.chunk_id)
|
||||
# print(chunk.chunk_index)
|
||||
# print(chunk.cut_type)
|
||||
# print(chunk.pages)
|
||||
# print("----")
|
||||
109
cognee/modules/data/processing/document_types/PdfDocument.py
Normal file
109
cognee/modules/data/processing/document_types/PdfDocument.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
# import pdfplumber
|
||||
import logging
|
||||
from uuid import uuid5, NAMESPACE_OID
|
||||
from typing import Optional
|
||||
from pypdf import PdfReader as pypdf_PdfReader
|
||||
from cognee.modules.data.chunking import chunk_by_paragraph
|
||||
from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk
|
||||
from .Document import Document
|
||||
|
||||
class PdfReader():
|
||||
id: str
|
||||
file_path: str
|
||||
|
||||
def __init__(self, id: str, file_path: str):
|
||||
self.id = id
|
||||
self.file_path = file_path
|
||||
|
||||
def get_number_of_pages(self):
|
||||
file = pypdf_PdfReader(self.file_path)
|
||||
num_pages = file.get_num_pages()
|
||||
file.stream.close()
|
||||
return num_pages
|
||||
|
||||
def read(self, max_chunk_size: Optional[int] = 1024):
|
||||
chunk_index = 0
|
||||
chunk_size = 0
|
||||
chunked_pages = []
|
||||
paragraph_chunks = []
|
||||
|
||||
file = pypdf_PdfReader(self.file_path)
|
||||
|
||||
for (page_index, page) in enumerate(file.pages):
|
||||
page_text = page.extract_text()
|
||||
chunked_pages.append(page_index)
|
||||
|
||||
for chunk_data in chunk_by_paragraph(page_text, max_chunk_size, batch_paragraphs = True):
|
||||
if chunk_size + chunk_data["word_count"] <= max_chunk_size:
|
||||
paragraph_chunks.append(chunk_data)
|
||||
chunk_size += chunk_data["word_count"]
|
||||
else:
|
||||
if len(paragraph_chunks) == 0:
|
||||
yield DocumentChunk(
|
||||
text = chunk_data["text"],
|
||||
word_count = chunk_data["word_count"],
|
||||
document_id = str(self.id),
|
||||
chunk_id = str(chunk_data["chunk_id"]),
|
||||
chunk_index = chunk_index,
|
||||
cut_type = chunk_data["cut_type"],
|
||||
pages = [page_index],
|
||||
)
|
||||
paragraph_chunks = []
|
||||
chunk_size = 0
|
||||
else:
|
||||
chunk_text = " ".join(chunk["text"] for chunk in paragraph_chunks)
|
||||
yield DocumentChunk(
|
||||
text = chunk_text,
|
||||
word_count = chunk_size,
|
||||
document_id = str(self.id),
|
||||
chunk_id = str(uuid5(NAMESPACE_OID, f"{str(self.id)}-{chunk_index}")),
|
||||
chunk_index = chunk_index,
|
||||
cut_type = paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
|
||||
pages = chunked_pages,
|
||||
)
|
||||
chunked_pages = [page_index]
|
||||
paragraph_chunks = [chunk_data]
|
||||
chunk_size = chunk_data["word_count"]
|
||||
|
||||
chunk_index += 1
|
||||
|
||||
if len(paragraph_chunks) > 0:
|
||||
yield DocumentChunk(
|
||||
text = " ".join(chunk["text"] for chunk in paragraph_chunks),
|
||||
word_count = chunk_size,
|
||||
document_id = str(self.id),
|
||||
chunk_id = str(uuid5(NAMESPACE_OID, f"{str(self.id)}-{chunk_index}")),
|
||||
chunk_index = chunk_index,
|
||||
cut_type = paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
|
||||
pages = chunked_pages,
|
||||
)
|
||||
|
||||
file.stream.close()
|
||||
|
||||
class PdfDocument(Document):
|
||||
type: str = "pdf"
|
||||
title: str
|
||||
num_pages: int
|
||||
file_path: str
|
||||
|
||||
def __init__(self, title: str, file_path: str):
|
||||
self.id = uuid5(NAMESPACE_OID, title)
|
||||
self.title = title
|
||||
self.file_path = file_path
|
||||
logging.debug("file_path: %s", self.file_path)
|
||||
reader = PdfReader(self.id, self.file_path)
|
||||
self.num_pages = reader.get_number_of_pages()
|
||||
|
||||
def get_reader(self) -> PdfReader:
|
||||
logging.debug("file_path: %s", self.file_path)
|
||||
reader = PdfReader(self.id, self.file_path)
|
||||
return reader
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return dict(
|
||||
id = str(self.id),
|
||||
type = self.type,
|
||||
title = self.title,
|
||||
num_pages = self.num_pages,
|
||||
file_path = self.file_path,
|
||||
)
|
||||
112
cognee/modules/data/processing/document_types/TextDocument.py
Normal file
112
cognee/modules/data/processing/document_types/TextDocument.py
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
from uuid import uuid5, NAMESPACE_OID
|
||||
from typing import Optional
|
||||
from cognee.modules.data.chunking import chunk_by_paragraph
|
||||
from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk
|
||||
from .Document import Document
|
||||
|
||||
class TextReader():
|
||||
id: str
|
||||
file_path: str
|
||||
|
||||
def __init__(self, id: str, file_path: str):
|
||||
self.id = id
|
||||
self.file_path = file_path
|
||||
|
||||
def get_number_of_pages(self):
|
||||
num_pages = 1 # Pure text is not formatted
|
||||
return num_pages
|
||||
|
||||
def read(self, max_chunk_size: Optional[int] = 1024):
|
||||
chunk_index = 0
|
||||
chunk_size = 0
|
||||
chunked_pages = []
|
||||
paragraph_chunks = []
|
||||
|
||||
def read_text_chunks(file_path):
|
||||
with open(file_path, mode = "r", encoding = "utf-8") as file:
|
||||
while True:
|
||||
text = file.read(1024)
|
||||
|
||||
if len(text.strip()) == 0:
|
||||
break
|
||||
|
||||
yield text
|
||||
|
||||
|
||||
page_index = 0
|
||||
|
||||
for page_text in read_text_chunks(self.file_path):
|
||||
chunked_pages.append(page_index)
|
||||
page_index += 1
|
||||
|
||||
for chunk_data in chunk_by_paragraph(page_text, max_chunk_size, batch_paragraphs = True):
|
||||
if chunk_size + chunk_data["word_count"] <= max_chunk_size:
|
||||
paragraph_chunks.append(chunk_data)
|
||||
chunk_size += chunk_data["word_count"]
|
||||
else:
|
||||
if len(paragraph_chunks) == 0:
|
||||
yield DocumentChunk(
|
||||
text = chunk_data["text"],
|
||||
word_count = chunk_data["word_count"],
|
||||
document_id = str(self.id),
|
||||
chunk_id = str(chunk_data["chunk_id"]),
|
||||
chunk_index = chunk_index,
|
||||
cut_type = chunk_data["cut_type"],
|
||||
pages = [page_index],
|
||||
)
|
||||
paragraph_chunks = []
|
||||
chunk_size = 0
|
||||
else:
|
||||
chunk_text = " ".join(chunk["text"] for chunk in paragraph_chunks)
|
||||
yield DocumentChunk(
|
||||
text = chunk_text,
|
||||
word_count = chunk_size,
|
||||
document_id = str(self.id),
|
||||
chunk_id = str(uuid5(NAMESPACE_OID, f"{str(self.id)}-{chunk_index}")),
|
||||
chunk_index = chunk_index,
|
||||
cut_type = paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
|
||||
pages = chunked_pages,
|
||||
)
|
||||
chunked_pages = [page_index]
|
||||
paragraph_chunks = [chunk_data]
|
||||
chunk_size = chunk_data["word_count"]
|
||||
|
||||
chunk_index += 1
|
||||
|
||||
if len(paragraph_chunks) > 0:
|
||||
yield DocumentChunk(
|
||||
text = " ".join(chunk["text"] for chunk in paragraph_chunks),
|
||||
word_count = chunk_size,
|
||||
document_id = str(self.id),
|
||||
chunk_id = str(uuid5(NAMESPACE_OID, f"{str(self.id)}-{chunk_index}")),
|
||||
chunk_index = chunk_index,
|
||||
cut_type = paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
|
||||
pages = chunked_pages,
|
||||
)
|
||||
|
||||
class TextDocument(Document):
|
||||
type: str = "text"
|
||||
title: str
|
||||
num_pages: int
|
||||
file_path: str
|
||||
|
||||
def __init__(self, title: str, file_path: str):
|
||||
self.id = uuid5(NAMESPACE_OID, title)
|
||||
self.title = title
|
||||
self.file_path = file_path
|
||||
|
||||
reader = TextReader(self.id, self.file_path)
|
||||
self.num_pages = reader.get_number_of_pages()
|
||||
|
||||
def get_reader(self) -> TextReader:
|
||||
reader = TextReader(self.id, self.file_path)
|
||||
return reader
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return dict(
|
||||
id = str(self.id),
|
||||
type = self.type,
|
||||
title = self.title,
|
||||
num_pages = self.num_pages,
|
||||
file_path = self.file_path,
|
||||
)
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
from .PdfDocument import PdfDocument
|
||||
from .TextDocument import TextDocument
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
import os
|
||||
from cognee.modules.data.processing.document_types.PdfDocument import PdfDocument
|
||||
|
||||
if __name__ == "__main__":
    # Smoke-test the PdfDocument reader against the bundled sample PDF.
    test_file_path = os.path.join(os.path.dirname(__file__), "artificial-inteligence.pdf")

    document = PdfDocument("Test document.pdf", test_file_path)

    # Print each paragraph's stats followed by a separator line.
    for paragraph_data in document.get_reader().read():
        print(paragraph_data["word_count"])
        print(paragraph_data["text"])
        print(paragraph_data["cut_type"])
        print("\n")
||||
Binary file not shown.
Binary file not shown.
25
cognee/modules/data/processing/filter_affected_chunks.py
Normal file
25
cognee/modules/data/processing/filter_affected_chunks.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from .chunk_types import DocumentChunk
|
||||
|
||||
async def filter_affected_chunks(data_chunks: list[DocumentChunk], collection_name: str) -> list[DocumentChunk]:
    """Return only the chunks that are new to, or changed in, the vector collection."""
    vector_engine = get_vector_engine()

    if not await vector_engine.has_collection(collection_name):
        # If collection doesn't exist, all data_chunks are new
        return data_chunks

    existing_chunks = await vector_engine.retrieve(
        collection_name,
        [str(chunk.chunk_id) for chunk in data_chunks],
    )

    # NOTE(review): map is keyed by the stored point id; assumes chunk.chunk_id is
    # already a string matching those ids — confirm against DocumentChunk.
    stored_payloads = {chunk.id: chunk.payload for chunk in existing_chunks}

    # A chunk is affected when it is absent from the store or its text differs.
    return [
        chunk for chunk in data_chunks
        if chunk.chunk_id not in stored_payloads
        or chunk.text != stored_payloads[chunk.chunk_id]["text"]
    ]
|
||||
30
cognee/modules/data/processing/has_new_chunks.py
Normal file
30
cognee/modules/data/processing/has_new_chunks.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from .chunk_types import DocumentChunk
|
||||
|
||||
async def has_new_chunks(data_chunks: list[DocumentChunk], collection_name: str) -> bool:
    """Report whether any chunk is missing from, or differs from, the stored collection.

    Fix: the return annotation previously claimed ``list[DocumentChunk]`` although
    every return statement yields a boolean; it is now ``bool``.
    """
    vector_engine = get_vector_engine()

    if not await vector_engine.has_collection(collection_name):
        # There is no collection created,
        # so no existing chunks, all chunks are new.
        return True

    existing_chunks = await vector_engine.retrieve(
        collection_name,
        [str(chunk.chunk_id) for chunk in data_chunks],
    )

    if len(existing_chunks) == 0:
        # If we don't find any existing chunk,
        # all chunks are new.
        return True

    existing_chunks_map = {chunk.id: chunk.payload for chunk in existing_chunks}

    # A chunk counts as new when it's absent or its text changed.
    new_data_chunks = [
        chunk for chunk in data_chunks
        if chunk.chunk_id not in existing_chunks_map
        or chunk.text != existing_chunks_map[chunk.chunk_id]["text"]
    ]

    return len(new_data_chunks) > 0
|
||||
41
cognee/modules/data/processing/process_documents.py
Normal file
41
cognee/modules/data/processing/process_documents.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from .document_types import Document
|
||||
|
||||
async def process_documents(documents: list[Document], parent_node_id: str = None):
    """Ensure graph nodes exist for each document (and the optional parent),
    then stream every document's chunks.

    Yields each chunk produced by the documents' readers (max_chunk_size 1024).
    """
    graph_engine = await get_graph_engine()

    nodes = []
    edges = []

    # Queue the parent node only if an id was given and it is not in the graph yet.
    if parent_node_id and await graph_engine.extract_node(parent_node_id) is None:
        nodes.append((parent_node_id, {}))

    document_nodes = await graph_engine.extract_nodes([str(document.id) for document in documents])

    for (document_index, document) in enumerate(documents):
        # NOTE(review): `document_index in document_nodes` tests membership among the
        # returned values, not an index/bounds check — if extract_nodes returns a
        # positional list this was probably meant to be a length check; confirm.
        document_node = document_nodes[document_index] if document_index in document_nodes else None

        if document_node is None:
            # Document not in the graph yet: queue its node, plus an edge to the parent.
            nodes.append((str(document.id), document.to_dict()))

            if parent_node_id:
                edges.append((
                    parent_node_id,
                    str(document.id),
                    "has_document",
                    dict(
                        relationship_name = "has_document",
                        source_node_id = parent_node_id,
                        target_node_id = str(document.id),
                    ),
                ))

    if len(nodes) > 0:
        # Persist queued nodes and edges in bulk before streaming chunks.
        await graph_engine.add_nodes(nodes)
        await graph_engine.add_edges(edges)

    for document in documents:
        document_reader = document.get_reader()

        for document_chunk in document_reader.read(max_chunk_size = 1024):
            yield document_chunk
|
||||
29
cognee/modules/data/processing/remove_obsolete_chunks.py
Normal file
29
cognee/modules/data/processing/remove_obsolete_chunks.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
# from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from .chunk_types import DocumentChunk
|
||||
|
||||
async def remove_obsolete_chunks(data_chunks: list[DocumentChunk]) -> list[DocumentChunk]:
    """Delete chunk nodes no longer linked into a chunk chain, then prune
    disconnected nodes. Returns ``data_chunks`` unchanged (pipeline pass-through).
    """
    graph_engine = await get_graph_engine()

    document_ids = set((data_chunk.document_id for data_chunk in data_chunks))

    obsolete_chunk_ids = []

    for document_id in document_ids:
        chunk_ids = await graph_engine.get_successor_ids(document_id, edge_label = "has_chunk")

        for chunk_id in chunk_ids:
            previous_chunks = await graph_engine.get_predecessor_ids(chunk_id, edge_label = "next_chunk")

            # NOTE(review): a chunk with no incoming "next_chunk" edge is treated as
            # obsolete — presumably chunks still in use are re-linked elsewhere;
            # confirm, since this condition also matches the first chunk of a chain.
            if len(previous_chunks) == 0:
                obsolete_chunk_ids.append(chunk_id)

    if len(obsolete_chunk_ids) > 0:
        await graph_engine.delete_nodes(obsolete_chunk_ids)

    # Deleting chunks can orphan other nodes; prune anything left without edges.
    disconnected_nodes = await graph_engine.get_disconnected_nodes()
    if len(disconnected_nodes) > 0:
        await graph_engine.delete_nodes(disconnected_nodes)

    return data_chunks
|
||||
|
|
@ -8,7 +8,7 @@ def classify(data: Union[str, BinaryIO], filename: str = None):
|
|||
return TextData(data)
|
||||
|
||||
if isinstance(data, BufferedReader):
|
||||
return BinaryData(data)
|
||||
return BinaryData(data, data.name.split("/")[-1] if data.name else filename)
|
||||
|
||||
if hasattr(data, "file"):
|
||||
return BinaryData(data.file, filename)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ class BinaryData(IngestionData):
|
|||
def get_identifier(self):
    """Return a stable identifier for this binary payload: "<name>_<mime_type>".

    Fix: the block contained two return statements (pre- and post-refactor diff
    residue); the stale keyword-based one was unreachable dead code and is removed.
    """
    metadata = self.get_metadata()

    return self.name + "_" + metadata["mime_type"]
|
||||
|
||||
def get_metadata(self):
|
||||
self.ensure_metadata()
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ class TextData(IngestionData):
|
|||
self.data = data
|
||||
|
||||
def get_identifier(self):
    """Identifier for text payloads: "text/plain_" plus extracted keywords.

    Fix: removed the stale `self.get_metadata()["keywords"]` lookup (leftover from
    before metadata stopped carrying keywords — it would raise KeyError and was
    overwritten on the next line anyway).
    """
    keywords = extract_keywords(self.data)

    return "text/plain" + "_" + "|".join(keywords)
|
||||
|
||||
|
|
@ -24,7 +24,7 @@ class TextData(IngestionData):
|
|||
|
||||
def ensure_metadata(self):
    """Lazily initialize the metadata dict.

    Fix: removed the stale keyword-extraction assignment (pre-refactor diff
    residue) that was immediately overwritten by the empty dict; text payloads
    carry no precomputed metadata.
    """
    if self.metadata is None:
        self.metadata = {}
|
||||
|
||||
def get_data(self):
    """Return the raw text payload as provided at construction."""
    return self.data
|
||||
|
|
|
|||
|
|
@ -17,8 +17,8 @@ def save_data_to_file(data: Union[str, BinaryIO], dataset_name: str, filename: s
|
|||
file_metadata = classified_data.get_metadata()
|
||||
if "name" not in file_metadata or file_metadata["name"] is None:
|
||||
letters = string.ascii_lowercase
|
||||
random_string = ''.join(random.choice(letters) for _ in range(32))
|
||||
file_metadata["name"] = "file" + random_string
|
||||
random_string = "".join(random.choice(letters) for _ in range(32))
|
||||
file_metadata["name"] = "text_" + random_string + ".txt"
|
||||
file_name = file_metadata["name"]
|
||||
LocalStorage(storage_path).store(file_name, classified_data.get_data())
|
||||
|
||||
|
|
|
|||
18
cognee/modules/pipelines/Pipeline.py
Normal file
18
cognee/modules/pipelines/Pipeline.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
from uuid import UUID, uuid4
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
from .models.Task import Task
|
||||
|
||||
class PipelineConfig(BaseModel):
    """Configuration for a Pipeline.

    Fix: ``description`` now defaults to ``None`` — under pydantic v2 an
    ``Optional[str]`` field without a default is *required*, which would force
    every caller to pass a description.
    """
    # Number of items processed per batch.
    batch_count: int = 10
    # Optional human-readable description of the pipeline.
    description: Optional[str] = None
|
||||
|
||||
class Pipeline():
    """An ordered collection of tasks with a unique id.

    Fix: ``id`` and ``tasks`` were class-level defaults (``uuid4()`` evaluated once
    at class definition, a single shared mutable list) — every Pipeline instance
    shared the same id and the same task list. Both are now set per instance.
    """
    id: UUID
    name: str
    description: str
    tasks: list[Task]

    def __init__(self, name: str, pipeline_config: PipelineConfig):
        # Fresh id per pipeline instance.
        self.id = uuid4()
        self.name = name
        self.description = pipeline_config.description
        # Fresh list per instance so pipelines don't share task lists.
        self.tasks = []
|
||||
2
cognee/modules/pipelines/__init__.py
Normal file
2
cognee/modules/pipelines/__init__.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
from .operations.run_tasks import run_tasks
|
||||
from .operations.run_parallel import run_tasks_parallel
|
||||
22
cognee/modules/pipelines/models/Pipeline.py
Normal file
22
cognee/modules/pipelines/models/Pipeline.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
from typing import List
|
||||
from uuid import uuid4
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy import Column, UUID, DateTime, String, Text
|
||||
from sqlalchemy.orm import relationship, Mapped
|
||||
from cognee.infrastructure.databases.relational import ModelBase
|
||||
from .PipelineTask import PipelineTask
|
||||
|
||||
class Pipeline(ModelBase):
    """ORM model for a pipeline; tasks link through the pipeline_task table.

    Fixes: column defaults are now callables (``uuid4()``/``datetime.now(...)``
    were evaluated once at class definition, giving every row the same value);
    ``tasks`` was a chained assignment (``tasks = Mapped[...] = relationship(...)``,
    an item-assignment on a typing construct that fails at import) and is now a
    proper annotation.
    """
    __tablename__ = "pipelines"

    id = Column(UUID, primary_key = True, default = uuid4)
    name = Column(String)
    description = Column(Text, nullable = True)

    created_at = Column(DateTime, default = lambda: datetime.now(timezone.utc))
    updated_at = Column(DateTime, onupdate = lambda: datetime.now(timezone.utc))

    # NOTE(review): back_populates must name the attribute on Task — that model's
    # relationship attribute is "datasets", not "pipeline"; confirm against Task.
    tasks: Mapped[List["Task"]] = relationship(
        secondary = PipelineTask.__tablename__,
        back_populates = "datasets",
    )
|
||||
14
cognee/modules/pipelines/models/PipelineTask.py
Normal file
14
cognee/modules/pipelines/models/PipelineTask.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
from uuid import uuid4
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy import Column, DateTime, UUID, ForeignKey
|
||||
from cognee.infrastructure.databases.relational import ModelBase
|
||||
|
||||
class PipelineTask(ModelBase):
    """Association table linking pipelines to tasks (many-to-many).

    Fixes: foreign keys now reference the actual table names ("pipelines"/"tasks"
    per those models' ``__tablename__``) instead of the nonexistent singular
    forms; defaults are callables so each row gets its own id/timestamp.
    """
    __tablename__ = "pipeline_task"

    id = Column(UUID, primary_key = True, default = uuid4)

    created_at = Column(DateTime, default = lambda: datetime.now(timezone.utc))

    pipeline_id = Column("pipeline", UUID, ForeignKey("pipelines.id"), primary_key = True)
    task_id = Column("task", UUID, ForeignKey("tasks.id"), primary_key = True)
|
||||
24
cognee/modules/pipelines/models/Task.py
Normal file
24
cognee/modules/pipelines/models/Task.py
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
from uuid import uuid4
|
||||
from typing import List
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy.orm import relationship, Mapped
|
||||
from sqlalchemy import Column, String, DateTime, UUID, Text
|
||||
from cognee.infrastructure.databases.relational import ModelBase
|
||||
from .PipelineTask import PipelineTask
|
||||
|
||||
class Task(ModelBase):
    """ORM model for a runnable task; ``executable`` stores the serialized callable.

    Fix: column defaults are now callables — ``uuid4()`` and
    ``datetime.now(timezone.utc)`` were evaluated once at class definition,
    giving every row the same id and timestamp.
    """
    __tablename__ = "tasks"

    id = Column(UUID, primary_key = True, default = uuid4)
    name = Column(String)
    description = Column(Text, nullable = True)

    executable = Column(Text)

    created_at = Column(DateTime, default = lambda: datetime.now(timezone.utc))
    updated_at = Column(DateTime, onupdate = lambda: datetime.now(timezone.utc))

    # NOTE(review): back_populates must name the attribute on Pipeline — that
    # model's relationship attribute is "tasks", not "task"; confirm.
    datasets: Mapped[List["Pipeline"]] = relationship(
        secondary = PipelineTask.__tablename__,
        back_populates = "tasks"
    )
|
||||
0
cognee/modules/pipelines/operations/__init__.py
Normal file
0
cognee/modules/pipelines/operations/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
|
|
@ -0,0 +1,14 @@
|
|||
import asyncio
|
||||
from cognee.shared.utils import render_graph
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
|
||||
if __name__ == "__main__":
    async def main():
        # Render the current in-memory graph and print the shareable URL.
        graph_client = await get_graph_engine()

        print(await render_graph(graph_client.graph))

    asyncio.run(main())
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
import asyncio
|
||||
from cognee.modules.pipelines.operations.run_tasks import run_tasks
|
||||
from cognee.modules.pipelines.tasks.Task import Task
|
||||
|
||||
|
||||
async def main():
    """Demo: chain four number-transforming tasks through run_tasks with mixed batch sizes."""
    def number_generator(num):
        # Yields 1..num.
        yield from (value + 1 for value in range(num))

    async def add_one(num):
        yield num + 1

    async def multiply_by_two(nums):
        # Upstream batch_size is 5, so this receives a batch of results at a time.
        for value in nums:
            yield value * 2

    async def add_one_to_batched_data(num):
        yield num + 1

    results_stream = run_tasks([
        Task(number_generator, task_config = {"batch_size": 1}),
        Task(add_one, task_config = {"batch_size": 5}),
        Task(multiply_by_two, task_config = {"batch_size": 1}),
        Task(add_one_to_batched_data),
    ], 10)

    # Print each pipeline result framed by blank lines.
    async for result in results_stream:
        print("\n")
        print(result)
        print("\n")


if __name__ == "__main__":
    asyncio.run(main())
|
||||
4
cognee/modules/pipelines/operations/add_task.py
Normal file
4
cognee/modules/pipelines/operations/add_task.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
from ..models import Pipeline, Task
|
||||
|
||||
def add_task(pipeline: Pipeline, task: Task):
    """Append `task` to the pipeline's task list (mutates the pipeline in place)."""
    pipeline.tasks += [task]
|
||||
12
cognee/modules/pipelines/operations/run_parallel.py
Normal file
12
cognee/modules/pipelines/operations/run_parallel.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
from typing import Any, Callable, Generator
|
||||
import asyncio
|
||||
from ..tasks.Task import Task
|
||||
|
||||
def run_tasks_parallel(tasks: list[Task]) -> Task:
    """Wrap ``tasks`` into a single Task that runs them all concurrently.

    The wrapper awaits every task with the same arguments and returns the last
    task's result, or ``[]`` when there are no results.

    Fixes: ``len(results) > 1`` silently discarded the result of a single-task
    run (now any non-empty result list returns its last element); the parameter
    annotation ``[Task]`` was not a valid type; the declared return type now
    matches the actual return value (a ``Task``).
    """
    async def parallel_run(*args, **kwargs):
        # Schedule every task immediately so they execute concurrently.
        parallel_tasks = [asyncio.create_task(task.run(*args, **kwargs)) for task in tasks]

        results = await asyncio.gather(*parallel_tasks)
        return results[-1] if results else []

    return Task(parallel_run)
|
||||
90
cognee/modules/pipelines/operations/run_tasks.py
Normal file
90
cognee/modules/pipelines/operations/run_tasks.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
import inspect
|
||||
import logging
|
||||
from ..tasks.Task import Task
|
||||
|
||||
logger = logging.getLogger("run_tasks(tasks: [Task], data)")
|
||||
|
||||
async def run_tasks(tasks: list[Task], data):
    """Recursively run a chain of tasks, batching intermediate results.

    The first task is executed with ``data``; its outputs are collected into
    batches of its configured ``batch_size`` and each batch is fed through the
    remaining tasks. Async generators, sync generators, coroutines and plain
    functions are all supported as task executables.

    Fixes: the generator branch logged "Running ..." at completion (duplicate of
    its start message) instead of "Finished ..."; removed the dead ``next_task``
    local (unused, and its ``len(leftover_tasks) > 1`` guard was off by one);
    the ``[Task]`` parameter annotation was not a valid type.
    """
    if len(tasks) == 0:
        # Nothing left to run: the data itself is the pipeline output.
        yield data
        return

    running_task = tasks[0]
    batch_size = running_task.task_config["batch_size"]
    leftover_tasks = tasks[1:]

    if inspect.isasyncgenfunction(running_task.executable):
        logger.info(f"Running async generator task: `{running_task.executable.__name__}`")
        try:
            results = []

            async_iterator = running_task.run(data)

            async for partial_result in async_iterator:
                results.append(partial_result)

                if len(results) == batch_size:
                    # A batch of one is unwrapped so downstream tasks receive a scalar.
                    async for result in run_tasks(leftover_tasks, results[0] if batch_size == 1 else results):
                        yield result

                    results = []

            if len(results) > 0:
                # Flush the final, possibly incomplete batch.
                async for result in run_tasks(leftover_tasks, results):
                    yield result

                results = []

            logger.info(f"Finished async generator task: `{running_task.executable.__name__}`")
        except Exception as error:
            logger.error(
                "Error occurred while running async generator task: `%s`\n%s\n",
                running_task.executable.__name__,
                str(error),
                exc_info = True,
            )
            raise error

    elif inspect.isgeneratorfunction(running_task.executable):
        logger.info(f"Running generator task: `{running_task.executable.__name__}`")
        try:
            results = []

            for partial_result in running_task.run(data):
                results.append(partial_result)

                if len(results) == batch_size:
                    async for result in run_tasks(leftover_tasks, results[0] if batch_size == 1 else results):
                        yield result

                    results = []

            if len(results) > 0:
                # Flush the final, possibly incomplete batch.
                async for result in run_tasks(leftover_tasks, results):
                    yield result

                results = []

            # Completion log mirrors the async-generator branch.
            logger.info(f"Finished generator task: `{running_task.executable.__name__}`")
        except Exception as error:
            logger.error(
                "Error occurred while running generator task: `%s`\n%s\n",
                running_task.executable.__name__,
                str(error),
                exc_info = True,
            )
            raise error

    elif inspect.iscoroutinefunction(running_task.executable):
        # Plain coroutine: a single awaited result flows to the remaining tasks.
        task_result = await running_task.run(data)

        async for result in run_tasks(leftover_tasks, task_result):
            yield result

    elif inspect.isfunction(running_task.executable):
        # Plain function: call synchronously and pass the result on.
        task_result = running_task.run(data)

        async for result in run_tasks(leftover_tasks, task_result):
            yield result
||||
32
cognee/modules/pipelines/tasks/Task.py
Normal file
32
cognee/modules/pipelines/tasks/Task.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
from typing import Union, Callable, Any, Coroutine, Generator, AsyncGenerator
|
||||
|
||||
class Task():
    """Wraps an executable (function, coroutine, generator or async generator)
    together with pre-bound arguments and batching configuration.

    Fix: ``task_config`` and ``default_params`` were mutable class-level dicts;
    when ``task_config`` was not supplied, every Task instance shared the same
    class dict, so mutating one task's config leaked into all others.
    ``__init__`` now always builds per-instance state.
    """
    executable: Union[
        Callable[..., Any],
        Callable[..., Coroutine[Any, Any, Any]],
        Generator[Any, Any, Any],
        AsyncGenerator[Any, Any],
    ]
    # Per-instance run configuration; "batch_size" drives result batching in run_tasks.
    task_config: dict[str, Any]
    # Arguments captured at construction time, merged into every run() call.
    default_params: dict[str, Any]

    def __init__(self, executable, *args, task_config = None, **kwargs):
        self.executable = executable
        self.default_params = {
            "args": args,
            "kwargs": kwargs,
        }

        if task_config is None:
            # Fresh dict per instance — never share a class-level default.
            self.task_config = { "batch_size": 1 }
        else:
            self.task_config = task_config

            if "batch_size" not in task_config:
                self.task_config["batch_size"] = 1

    def run(self, *args, **kwargs):
        """Invoke the executable with constructor-bound arguments merged with the
        call's; call-time keyword arguments override constructor ones."""
        combined_args = self.default_params["args"] + args
        combined_kwargs = { **self.default_params["kwargs"], **kwargs }

        return self.executable(*combined_args, **combined_kwargs)
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue