Merge branch 'main' into COG-575-remove-graph-overwrite-on-error
This commit is contained in:
commit
be792a7ba6
127 changed files with 3122 additions and 2951 deletions
4
.github/workflows/test_neo4j.yml
vendored
4
.github/workflows/test_neo4j.yml
vendored
|
|
@ -5,7 +5,7 @@ on:
|
||||||
pull_request:
|
pull_request:
|
||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
types: [labeled]
|
types: [labeled, synchronize]
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||||
|
|
@ -22,7 +22,7 @@ jobs:
|
||||||
run_neo4j_integration_test:
|
run_neo4j_integration_test:
|
||||||
name: test
|
name: test
|
||||||
needs: get_docs_changes
|
needs: get_docs_changes
|
||||||
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
|
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
defaults:
|
defaults:
|
||||||
|
|
|
||||||
4
.github/workflows/test_notebook.yml
vendored
4
.github/workflows/test_notebook.yml
vendored
|
|
@ -5,7 +5,7 @@ on:
|
||||||
pull_request:
|
pull_request:
|
||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
types: [labeled]
|
types: [labeled, synchronize]
|
||||||
|
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
|
|
@ -23,7 +23,7 @@ jobs:
|
||||||
run_notebook_test:
|
run_notebook_test:
|
||||||
name: test
|
name: test
|
||||||
needs: get_docs_changes
|
needs: get_docs_changes
|
||||||
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
|
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
defaults:
|
defaults:
|
||||||
run:
|
run:
|
||||||
|
|
|
||||||
4
.github/workflows/test_pgvector.yml
vendored
4
.github/workflows/test_pgvector.yml
vendored
|
|
@ -5,7 +5,7 @@ on:
|
||||||
pull_request:
|
pull_request:
|
||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
types: [labeled]
|
types: [labeled, synchronize]
|
||||||
|
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
|
|
@ -23,7 +23,7 @@ jobs:
|
||||||
run_pgvector_integration_test:
|
run_pgvector_integration_test:
|
||||||
name: test
|
name: test
|
||||||
needs: get_docs_changes
|
needs: get_docs_changes
|
||||||
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
|
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
defaults:
|
defaults:
|
||||||
run:
|
run:
|
||||||
|
|
|
||||||
6
.github/workflows/test_python_3_10.yml
vendored
6
.github/workflows/test_python_3_10.yml
vendored
|
|
@ -5,10 +5,10 @@ on:
|
||||||
pull_request:
|
pull_request:
|
||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
types: [labeled]
|
types: [labeled, synchronize]
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} | ${{ github.event.label.name == 'run-checks' }}
|
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
env:
|
env:
|
||||||
|
|
@ -22,7 +22,7 @@ jobs:
|
||||||
run_common:
|
run_common:
|
||||||
name: test
|
name: test
|
||||||
needs: get_docs_changes
|
needs: get_docs_changes
|
||||||
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true'
|
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
|
||||||
6
.github/workflows/test_python_3_11.yml
vendored
6
.github/workflows/test_python_3_11.yml
vendored
|
|
@ -5,10 +5,10 @@ on:
|
||||||
pull_request:
|
pull_request:
|
||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
types: [labeled]
|
types: [labeled, synchronize]
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} | ${{ github.event.label.name == 'run-checks' }}
|
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
env:
|
env:
|
||||||
|
|
@ -22,7 +22,7 @@ jobs:
|
||||||
run_common:
|
run_common:
|
||||||
name: test
|
name: test
|
||||||
needs: get_docs_changes
|
needs: get_docs_changes
|
||||||
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true'
|
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
|
||||||
6
.github/workflows/test_python_3_9.yml
vendored
6
.github/workflows/test_python_3_9.yml
vendored
|
|
@ -5,10 +5,10 @@ on:
|
||||||
pull_request:
|
pull_request:
|
||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
types: [labeled]
|
types: [labeled, synchronize]
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} | ${{ github.event.label.name == 'run-checks' }}
|
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
env:
|
env:
|
||||||
|
|
@ -22,7 +22,7 @@ jobs:
|
||||||
run_common:
|
run_common:
|
||||||
name: test
|
name: test
|
||||||
needs: get_docs_changes
|
needs: get_docs_changes
|
||||||
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true'
|
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
|
||||||
4
.github/workflows/test_qdrant.yml
vendored
4
.github/workflows/test_qdrant.yml
vendored
|
|
@ -5,7 +5,7 @@ on:
|
||||||
pull_request:
|
pull_request:
|
||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
types: [labeled]
|
types: [labeled, synchronize]
|
||||||
|
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
|
|
@ -23,7 +23,7 @@ jobs:
|
||||||
run_qdrant_integration_test:
|
run_qdrant_integration_test:
|
||||||
name: test
|
name: test
|
||||||
needs: get_docs_changes
|
needs: get_docs_changes
|
||||||
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
|
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
defaults:
|
defaults:
|
||||||
|
|
|
||||||
4
.github/workflows/test_weaviate.yml
vendored
4
.github/workflows/test_weaviate.yml
vendored
|
|
@ -5,7 +5,7 @@ on:
|
||||||
pull_request:
|
pull_request:
|
||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
types: [labeled]
|
types: [labeled, synchronize]
|
||||||
|
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
|
|
@ -23,7 +23,7 @@ jobs:
|
||||||
run_weaviate_integration_test:
|
run_weaviate_integration_test:
|
||||||
name: test
|
name: test
|
||||||
needs: get_docs_changes
|
needs: get_docs_changes
|
||||||
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
|
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
defaults:
|
defaults:
|
||||||
|
|
|
||||||
24
README.md
24
README.md
|
|
@ -109,24 +109,34 @@ import asyncio
|
||||||
from cognee.api.v1.search import SearchType
|
from cognee.api.v1.search import SearchType
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
await cognee.prune.prune_data() # Reset cognee data
|
# Reset cognee data
|
||||||
await cognee.prune.prune_system(metadata=True) # Reset cognee system state
|
await cognee.prune.prune_data()
|
||||||
|
# Reset cognee system state
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
|
||||||
text = """
|
text = """
|
||||||
Natural language processing (NLP) is an interdisciplinary
|
Natural language processing (NLP) is an interdisciplinary
|
||||||
subfield of computer science and information retrieval.
|
subfield of computer science and information retrieval.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
await cognee.add(text) # Add text to cognee
|
# Add text to cognee
|
||||||
await cognee.cognify() # Use LLMs and cognee to create knowledge graph
|
await cognee.add(text)
|
||||||
|
|
||||||
search_results = await cognee.search( # Search cognee for insights
|
# Use LLMs and cognee to create knowledge graph
|
||||||
|
await cognee.cognify()
|
||||||
|
|
||||||
|
# Search cognee for insights
|
||||||
|
search_results = await cognee.search(
|
||||||
SearchType.INSIGHTS,
|
SearchType.INSIGHTS,
|
||||||
{'query': 'Tell me about NLP'}
|
"Tell me about NLP",
|
||||||
)
|
)
|
||||||
|
|
||||||
for result_text in search_results: # Display results
|
# Display results
|
||||||
|
for result_text in search_results:
|
||||||
print(result_text)
|
print(result_text)
|
||||||
|
# natural_language_processing is_a field
|
||||||
|
# natural_language_processing is_subfield_of computer_science
|
||||||
|
# natural_language_processing is_subfield_of information_retrieval
|
||||||
|
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
```
|
```
|
||||||
|
|
|
||||||
110
cognee/api/v1/cognify/code_graph_pipeline.py
Normal file
110
cognee/api/v1/cognify/code_graph_pipeline.py
Normal file
|
|
@ -0,0 +1,110 @@
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from cognee.shared.SourceCodeGraph import SourceCodeGraph
|
||||||
|
from cognee.shared.utils import send_telemetry
|
||||||
|
from cognee.modules.data.models import Dataset, Data
|
||||||
|
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
|
||||||
|
from cognee.modules.data.methods import get_datasets, get_datasets_by_name
|
||||||
|
from cognee.modules.pipelines.tasks.Task import Task
|
||||||
|
from cognee.modules.pipelines import run_tasks
|
||||||
|
from cognee.modules.users.models import User
|
||||||
|
from cognee.modules.users.methods import get_default_user
|
||||||
|
from cognee.modules.pipelines.models import PipelineRunStatus
|
||||||
|
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
|
||||||
|
from cognee.modules.pipelines.operations.log_pipeline_status import log_pipeline_status
|
||||||
|
from cognee.tasks.documents import classify_documents, check_permissions_on_documents, extract_chunks_from_documents
|
||||||
|
from cognee.tasks.graph import extract_graph_from_code
|
||||||
|
from cognee.tasks.storage import add_data_points
|
||||||
|
|
||||||
|
logger = logging.getLogger("code_graph_pipeline")
|
||||||
|
|
||||||
|
update_status_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
class PermissionDeniedException(Exception):
|
||||||
|
def __init__(self, message: str):
|
||||||
|
self.message = message
|
||||||
|
super().__init__(self.message)
|
||||||
|
|
||||||
|
async def code_graph_pipeline(datasets: Union[str, list[str]] = None, user: User = None):
|
||||||
|
if user is None:
|
||||||
|
user = await get_default_user()
|
||||||
|
|
||||||
|
existing_datasets = await get_datasets(user.id)
|
||||||
|
|
||||||
|
if datasets is None or len(datasets) == 0:
|
||||||
|
# If no datasets are provided, cognify all existing datasets.
|
||||||
|
datasets = existing_datasets
|
||||||
|
|
||||||
|
if type(datasets[0]) == str:
|
||||||
|
datasets = await get_datasets_by_name(datasets, user.id)
|
||||||
|
|
||||||
|
existing_datasets_map = {
|
||||||
|
generate_dataset_name(dataset.name): True for dataset in existing_datasets
|
||||||
|
}
|
||||||
|
|
||||||
|
awaitables = []
|
||||||
|
|
||||||
|
for dataset in datasets:
|
||||||
|
dataset_name = generate_dataset_name(dataset.name)
|
||||||
|
|
||||||
|
if dataset_name in existing_datasets_map:
|
||||||
|
awaitables.append(run_pipeline(dataset, user))
|
||||||
|
|
||||||
|
return await asyncio.gather(*awaitables)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_pipeline(dataset: Dataset, user: User):
|
||||||
|
data_documents: list[Data] = await get_dataset_data(dataset_id = dataset.id)
|
||||||
|
|
||||||
|
document_ids_str = [str(document.id) for document in data_documents]
|
||||||
|
|
||||||
|
dataset_id = dataset.id
|
||||||
|
dataset_name = generate_dataset_name(dataset.name)
|
||||||
|
|
||||||
|
send_telemetry("code_graph_pipeline EXECUTION STARTED", user.id)
|
||||||
|
|
||||||
|
async with update_status_lock:
|
||||||
|
task_status = await get_pipeline_status([dataset_id])
|
||||||
|
|
||||||
|
if dataset_id in task_status and task_status[dataset_id] == PipelineRunStatus.DATASET_PROCESSING_STARTED:
|
||||||
|
logger.info("Dataset %s is already being processed.", dataset_name)
|
||||||
|
return
|
||||||
|
|
||||||
|
await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_STARTED, {
|
||||||
|
"dataset_name": dataset_name,
|
||||||
|
"files": document_ids_str,
|
||||||
|
})
|
||||||
|
try:
|
||||||
|
tasks = [
|
||||||
|
Task(classify_documents),
|
||||||
|
Task(check_permissions_on_documents, user = user, permissions = ["write"]),
|
||||||
|
Task(extract_chunks_from_documents), # Extract text chunks based on the document type.
|
||||||
|
Task(add_data_points, task_config = { "batch_size": 10 }),
|
||||||
|
Task(extract_graph_from_code, graph_model = SourceCodeGraph, task_config = { "batch_size": 10 }), # Generate knowledge graphs from the document chunks.
|
||||||
|
]
|
||||||
|
|
||||||
|
pipeline = run_tasks(tasks, data_documents, "code_graph_pipeline")
|
||||||
|
|
||||||
|
async for result in pipeline:
|
||||||
|
print(result)
|
||||||
|
|
||||||
|
send_telemetry("code_graph_pipeline EXECUTION COMPLETED", user.id)
|
||||||
|
|
||||||
|
await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_COMPLETED, {
|
||||||
|
"dataset_name": dataset_name,
|
||||||
|
"files": document_ids_str,
|
||||||
|
})
|
||||||
|
except Exception as error:
|
||||||
|
send_telemetry("code_graph_pipeline EXECUTION ERRORED", user.id)
|
||||||
|
|
||||||
|
await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_ERRORED, {
|
||||||
|
"dataset_name": dataset_name,
|
||||||
|
"files": document_ids_str,
|
||||||
|
})
|
||||||
|
raise error
|
||||||
|
|
||||||
|
|
||||||
|
def generate_dataset_name(dataset_name: str) -> str:
|
||||||
|
return dataset_name.replace(".", "_").replace(" ", "_")
|
||||||
|
|
@ -9,21 +9,15 @@ from cognee.modules.data.models import Dataset, Data
|
||||||
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
|
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
|
||||||
from cognee.modules.data.methods import get_datasets, get_datasets_by_name
|
from cognee.modules.data.methods import get_datasets, get_datasets_by_name
|
||||||
from cognee.modules.pipelines.tasks.Task import Task
|
from cognee.modules.pipelines.tasks.Task import Task
|
||||||
from cognee.modules.pipelines import run_tasks, run_tasks_parallel
|
from cognee.modules.pipelines import run_tasks
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
from cognee.modules.users.methods import get_default_user
|
from cognee.modules.users.methods import get_default_user
|
||||||
from cognee.modules.pipelines.models import PipelineRunStatus
|
from cognee.modules.pipelines.models import PipelineRunStatus
|
||||||
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
|
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
|
||||||
from cognee.modules.pipelines.operations.log_pipeline_status import log_pipeline_status
|
from cognee.modules.pipelines.operations.log_pipeline_status import log_pipeline_status
|
||||||
from cognee.tasks import chunk_naive_llm_classifier, \
|
from cognee.tasks.documents import classify_documents, check_permissions_on_documents, extract_chunks_from_documents
|
||||||
chunk_remove_disconnected, \
|
from cognee.tasks.graph import extract_graph_from_data
|
||||||
infer_data_ontology, \
|
from cognee.tasks.storage import add_data_points
|
||||||
save_chunks_to_store, \
|
|
||||||
chunk_update_check, \
|
|
||||||
chunks_into_graph, \
|
|
||||||
source_documents_to_chunks, \
|
|
||||||
check_permissions_on_documents, \
|
|
||||||
classify_documents
|
|
||||||
from cognee.tasks.summarization import summarize_text
|
from cognee.tasks.summarization import summarize_text
|
||||||
|
|
||||||
logger = logging.getLogger("cognify.v2")
|
logger = logging.getLogger("cognify.v2")
|
||||||
|
|
@ -87,31 +81,17 @@ async def run_cognify_pipeline(dataset: Dataset, user: User):
|
||||||
try:
|
try:
|
||||||
cognee_config = get_cognify_config()
|
cognee_config = get_cognify_config()
|
||||||
|
|
||||||
root_node_id = None
|
|
||||||
|
|
||||||
tasks = [
|
tasks = [
|
||||||
Task(classify_documents),
|
Task(classify_documents),
|
||||||
Task(check_permissions_on_documents, user = user, permissions = ["write"]),
|
Task(check_permissions_on_documents, user = user, permissions = ["write"]),
|
||||||
Task(infer_data_ontology, root_node_id = root_node_id, ontology_model = KnowledgeGraph),
|
Task(extract_chunks_from_documents), # Extract text chunks based on the document type.
|
||||||
Task(source_documents_to_chunks, parent_node_id = root_node_id), # Classify documents and save them as a nodes in graph db, extract text chunks based on the document type
|
Task(add_data_points, task_config = { "batch_size": 10 }),
|
||||||
Task(chunks_into_graph, graph_model = KnowledgeGraph, collection_name = "entities", task_config = { "batch_size": 10 }), # Generate knowledge graphs from the document chunks and attach it to chunk nodes
|
Task(extract_graph_from_data, graph_model = KnowledgeGraph, task_config = { "batch_size": 10 }), # Generate knowledge graphs from the document chunks.
|
||||||
Task(chunk_update_check, collection_name = "chunks"), # Find all affected chunks, so we don't process unchanged chunks
|
|
||||||
Task(
|
Task(
|
||||||
save_chunks_to_store,
|
summarize_text,
|
||||||
collection_name = "chunks",
|
summarization_model = cognee_config.summarization_model,
|
||||||
), # Save the document chunks in vector db and as nodes in graph db (connected to the document node and between each other)
|
task_config = { "batch_size": 10 }
|
||||||
run_tasks_parallel([
|
),
|
||||||
Task(
|
|
||||||
summarize_text,
|
|
||||||
summarization_model = cognee_config.summarization_model,
|
|
||||||
collection_name = "summaries",
|
|
||||||
),
|
|
||||||
Task(
|
|
||||||
chunk_naive_llm_classifier,
|
|
||||||
classification_model = cognee_config.classification_model,
|
|
||||||
),
|
|
||||||
]),
|
|
||||||
Task(chunk_remove_disconnected), # Remove the obsolete document chunks.
|
|
||||||
]
|
]
|
||||||
|
|
||||||
pipeline = run_tasks(tasks, data_documents, "cognify_pipeline")
|
pipeline = run_tasks(tasks, data_documents, "cognify_pipeline")
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from cognee.shared.utils import send_telemetry
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
from cognee.modules.users.methods import get_default_user
|
from cognee.modules.users.methods import get_default_user
|
||||||
from cognee.modules.users.permissions.methods import get_document_ids_for_user
|
from cognee.modules.users.permissions.methods import get_document_ids_for_user
|
||||||
from cognee.tasks.chunking import query_chunks
|
from cognee.tasks.chunks import query_chunks
|
||||||
from cognee.tasks.graph import query_graph_connections
|
from cognee.tasks.graph import query_graph_connections
|
||||||
from cognee.tasks.summarization import query_summaries
|
from cognee.tasks.summarization import query_summaries
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,198 +0,0 @@
|
||||||
""" FalcorDB Adapter for Graph Database"""
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from typing import Optional, Any, List, Dict
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
|
|
||||||
|
|
||||||
from falkordb.asyncio import FalkorDB
|
|
||||||
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
|
||||||
|
|
||||||
logger = logging.getLogger("FalcorDBAdapter")
|
|
||||||
|
|
||||||
class FalcorDBAdapter(GraphDBInterface):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
graph_database_url: str,
|
|
||||||
graph_database_username: str,
|
|
||||||
graph_database_password: str,
|
|
||||||
graph_database_port: int,
|
|
||||||
driver: Optional[Any] = None,
|
|
||||||
graph_name: str = "DefaultGraph",
|
|
||||||
):
|
|
||||||
self.driver = FalkorDB(
|
|
||||||
host = graph_database_url,
|
|
||||||
port = graph_database_port)
|
|
||||||
self.graph_name = graph_name
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def query(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
params: Optional[Dict[str, Any]] = None,
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
try:
|
|
||||||
selected_graph = self.driver.select_graph(self.graph_name)
|
|
||||||
|
|
||||||
result = await selected_graph.query(query)
|
|
||||||
return result.result_set
|
|
||||||
|
|
||||||
except Exception as error:
|
|
||||||
logger.error("Falkor query error: %s", error, exc_info = True)
|
|
||||||
raise error
|
|
||||||
|
|
||||||
async def graph(self):
|
|
||||||
return self.driver
|
|
||||||
|
|
||||||
async def add_node(self, node_id: str, node_properties: Dict[str, Any] = None):
|
|
||||||
node_id = node_id.replace(":", "_")
|
|
||||||
|
|
||||||
serialized_properties = self.serialize_properties(node_properties)
|
|
||||||
|
|
||||||
if "name" not in serialized_properties:
|
|
||||||
serialized_properties["name"] = node_id
|
|
||||||
|
|
||||||
# serialized_properties["created_at"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
# serialized_properties["updated_at"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
|
|
||||||
# properties = ", ".join(f"{property_name}: ${property_name}" for property_name in serialized_properties.keys())
|
|
||||||
|
|
||||||
query = f"""MERGE (node:`{node_id}` {{id: $node_id}})
|
|
||||||
ON CREATE SET node += $properties
|
|
||||||
RETURN ID(node) AS internal_id, node.id AS nodeId"""
|
|
||||||
|
|
||||||
params = {
|
|
||||||
"node_id": node_id,
|
|
||||||
"properties": serialized_properties,
|
|
||||||
}
|
|
||||||
|
|
||||||
return await self.query(query, params)
|
|
||||||
|
|
||||||
async def add_nodes(self, nodes: list[tuple[str, dict[str, Any]]]) -> None:
|
|
||||||
for node in nodes:
|
|
||||||
node_id, node_properties = node
|
|
||||||
node_id = node_id.replace(":", "_")
|
|
||||||
|
|
||||||
await self.add_node(
|
|
||||||
node_id = node_id,
|
|
||||||
node_properties = node_properties,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def extract_node_description(self, node_id: str):
|
|
||||||
query = """MATCH (n)-[r]->(m)
|
|
||||||
WHERE n.id = $node_id
|
|
||||||
AND NOT m.id CONTAINS 'DefaultGraphModel'
|
|
||||||
RETURN m
|
|
||||||
"""
|
|
||||||
|
|
||||||
result = await self.query(query, dict(node_id = node_id))
|
|
||||||
|
|
||||||
descriptions = []
|
|
||||||
|
|
||||||
for node in result:
|
|
||||||
# Assuming 'm' is a consistent key in your data structure
|
|
||||||
attributes = node.get("m", {})
|
|
||||||
|
|
||||||
# Ensure all required attributes are present
|
|
||||||
if all(key in attributes for key in ["id", "layer_id", "description"]):
|
|
||||||
descriptions.append({
|
|
||||||
"id": attributes["id"],
|
|
||||||
"layer_id": attributes["layer_id"],
|
|
||||||
"description": attributes["description"],
|
|
||||||
})
|
|
||||||
|
|
||||||
return descriptions
|
|
||||||
|
|
||||||
async def get_layer_nodes(self):
|
|
||||||
query = """MATCH (node) WHERE node.layer_id IS NOT NULL
|
|
||||||
RETURN node"""
|
|
||||||
|
|
||||||
return [result["node"] for result in (await self.query(query))]
|
|
||||||
|
|
||||||
async def extract_node(self, node_id: str):
|
|
||||||
results = self.extract_nodes([node_id])
|
|
||||||
|
|
||||||
return results[0] if len(results) > 0 else None
|
|
||||||
|
|
||||||
async def extract_nodes(self, node_ids: List[str]):
|
|
||||||
query = """
|
|
||||||
UNWIND $node_ids AS id
|
|
||||||
MATCH (node {id: id})
|
|
||||||
RETURN node"""
|
|
||||||
|
|
||||||
params = {
|
|
||||||
"node_ids": node_ids
|
|
||||||
}
|
|
||||||
|
|
||||||
results = await self.query(query, params)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
async def delete_node(self, node_id: str):
|
|
||||||
node_id = id.replace(":", "_")
|
|
||||||
|
|
||||||
query = f"MATCH (node:`{node_id}` {{id: $node_id}}) DETACH DELETE n"
|
|
||||||
params = { "node_id": node_id }
|
|
||||||
|
|
||||||
return await self.query(query, params)
|
|
||||||
|
|
||||||
async def add_edge(self, from_node: str, to_node: str, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}):
|
|
||||||
serialized_properties = self.serialize_properties(edge_properties)
|
|
||||||
from_node = from_node.replace(":", "_")
|
|
||||||
to_node = to_node.replace(":", "_")
|
|
||||||
|
|
||||||
query = f"""MATCH (from_node:`{from_node}` {{id: $from_node}}), (to_node:`{to_node}` {{id: $to_node}})
|
|
||||||
MERGE (from_node)-[r:`{relationship_name}`]->(to_node)
|
|
||||||
SET r += $properties
|
|
||||||
RETURN r"""
|
|
||||||
|
|
||||||
params = {
|
|
||||||
"from_node": from_node,
|
|
||||||
"to_node": to_node,
|
|
||||||
"properties": serialized_properties
|
|
||||||
}
|
|
||||||
|
|
||||||
return await self.query(query, params)
|
|
||||||
|
|
||||||
|
|
||||||
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
|
|
||||||
# edges_data = []
|
|
||||||
|
|
||||||
for edge in edges:
|
|
||||||
from_node, to_node, relationship_name, edge_properties = edge
|
|
||||||
from_node = from_node.replace(":", "_")
|
|
||||||
to_node = to_node.replace(":", "_")
|
|
||||||
|
|
||||||
await self.add_edge(
|
|
||||||
from_node = from_node,
|
|
||||||
to_node = to_node,
|
|
||||||
relationship_name = relationship_name,
|
|
||||||
edge_properties = edge_properties
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def filter_nodes(self, search_criteria):
|
|
||||||
query = f"""MATCH (node)
|
|
||||||
WHERE node.id CONTAINS '{search_criteria}'
|
|
||||||
RETURN node"""
|
|
||||||
|
|
||||||
|
|
||||||
return await self.query(query)
|
|
||||||
|
|
||||||
|
|
||||||
async def delete_graph(self):
|
|
||||||
query = """MATCH (node)
|
|
||||||
DETACH DELETE node;"""
|
|
||||||
|
|
||||||
return await self.query(query)
|
|
||||||
|
|
||||||
def serialize_properties(self, properties = dict()):
|
|
||||||
return {
|
|
||||||
property_key: json.dumps(property_value)
|
|
||||||
if isinstance(property_value, (dict, list))
|
|
||||||
else property_value for property_key, property_value in properties.items()
|
|
||||||
}
|
|
||||||
|
|
@ -2,7 +2,6 @@
|
||||||
|
|
||||||
from .config import get_graph_config
|
from .config import get_graph_config
|
||||||
from .graph_db_interface import GraphDBInterface
|
from .graph_db_interface import GraphDBInterface
|
||||||
from .networkx.adapter import NetworkXAdapter
|
|
||||||
|
|
||||||
|
|
||||||
async def get_graph_engine() -> GraphDBInterface :
|
async def get_graph_engine() -> GraphDBInterface :
|
||||||
|
|
@ -21,19 +20,19 @@ async def get_graph_engine() -> GraphDBInterface :
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
elif config.graph_database_provider == "falkorb":
|
elif config.graph_database_provider == "falkordb":
|
||||||
try:
|
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
|
||||||
from .falkordb.adapter import FalcorDBAdapter
|
from cognee.infrastructure.databases.hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
|
||||||
|
|
||||||
return FalcorDBAdapter(
|
embedding_engine = get_embedding_engine()
|
||||||
graph_database_url = config.graph_database_url,
|
|
||||||
graph_database_username = config.graph_database_username,
|
|
||||||
graph_database_password = config.graph_database_password,
|
|
||||||
graph_database_port = config.graph_database_port
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
return FalkorDBAdapter(
|
||||||
|
database_url = config.graph_database_url,
|
||||||
|
database_port = config.graph_database_port,
|
||||||
|
embedding_engine = embedding_engine,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .networkx.adapter import NetworkXAdapter
|
||||||
graph_client = NetworkXAdapter(filename = config.graph_file_path)
|
graph_client = NetworkXAdapter(filename = config.graph_file_path)
|
||||||
|
|
||||||
if graph_client.graph is None:
|
if graph_client.graph is None:
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from abc import abstractmethod
|
||||||
|
|
||||||
class GraphDBInterface(Protocol):
|
class GraphDBInterface(Protocol):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def graph(self):
|
async def query(self, query: str, params: dict):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,14 @@
|
||||||
""" Neo4j Adapter for Graph Database"""
|
""" Neo4j Adapter for Graph Database"""
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from textwrap import dedent
|
||||||
from typing import Optional, Any, List, Dict
|
from typing import Optional, Any, List, Dict
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from uuid import UUID
|
||||||
from neo4j import AsyncSession
|
from neo4j import AsyncSession
|
||||||
from neo4j import AsyncGraphDatabase
|
from neo4j import AsyncGraphDatabase
|
||||||
from neo4j.exceptions import Neo4jError
|
from neo4j.exceptions import Neo4jError
|
||||||
from networkx import predecessor
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
||||||
|
|
||||||
logger = logging.getLogger("Neo4jAdapter")
|
logger = logging.getLogger("Neo4jAdapter")
|
||||||
|
|
@ -41,17 +42,13 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
try:
|
try:
|
||||||
async with self.get_session() as session:
|
async with self.get_session() as session:
|
||||||
result = await session.run(query, parameters=params)
|
result = await session.run(query, parameters = params)
|
||||||
data = await result.data()
|
data = await result.data()
|
||||||
await self.close()
|
|
||||||
return data
|
return data
|
||||||
except Neo4jError as error:
|
except Neo4jError as error:
|
||||||
logger.error("Neo4j query error: %s", error, exc_info = True)
|
logger.error("Neo4j query error: %s", error, exc_info = True)
|
||||||
raise error
|
raise error
|
||||||
|
|
||||||
async def graph(self):
|
|
||||||
return await self.get_session()
|
|
||||||
|
|
||||||
async def has_node(self, node_id: str) -> bool:
|
async def has_node(self, node_id: str) -> bool:
|
||||||
results = self.query(
|
results = self.query(
|
||||||
"""
|
"""
|
||||||
|
|
@ -63,73 +60,40 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
)
|
)
|
||||||
return results[0]["node_exists"] if len(results) > 0 else False
|
return results[0]["node_exists"] if len(results) > 0 else False
|
||||||
|
|
||||||
async def add_node(self, node_id: str, node_properties: Dict[str, Any] = None):
|
async def add_node(self, node: DataPoint):
|
||||||
node_id = node_id.replace(":", "_")
|
serialized_properties = self.serialize_properties(node.model_dump())
|
||||||
|
|
||||||
serialized_properties = self.serialize_properties(node_properties)
|
query = dedent("""MERGE (node {id: $node_id})
|
||||||
|
ON CREATE SET node += $properties, node.updated_at = timestamp()
|
||||||
if "name" not in serialized_properties:
|
ON MATCH SET node += $properties, node.updated_at = timestamp()
|
||||||
serialized_properties["name"] = node_id
|
RETURN ID(node) AS internal_id, node.id AS nodeId""")
|
||||||
|
|
||||||
query = f"""MERGE (node:`{node_id}` {{id: $node_id}})
|
|
||||||
ON CREATE SET node += $properties
|
|
||||||
RETURN ID(node) AS internal_id, node.id AS nodeId"""
|
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"node_id": node_id,
|
"node_id": str(node.id),
|
||||||
"properties": serialized_properties,
|
"properties": serialized_properties,
|
||||||
}
|
}
|
||||||
|
|
||||||
return await self.query(query, params)
|
return await self.query(query, params)
|
||||||
|
|
||||||
async def add_nodes(self, nodes: list[tuple[str, dict[str, Any]]]) -> None:
|
async def add_nodes(self, nodes: list[DataPoint]) -> None:
|
||||||
query = """
|
query = """
|
||||||
UNWIND $nodes AS node
|
UNWIND $nodes AS node
|
||||||
MERGE (n {id: node.node_id})
|
MERGE (n {id: node.node_id})
|
||||||
ON CREATE SET n += node.properties
|
ON CREATE SET n += node.properties, n.updated_at = timestamp()
|
||||||
|
ON MATCH SET n += node.properties, n.updated_at = timestamp()
|
||||||
WITH n, node.node_id AS label
|
WITH n, node.node_id AS label
|
||||||
CALL apoc.create.addLabels(n, [label]) YIELD node AS labeledNode
|
CALL apoc.create.addLabels(n, [label]) YIELD node AS labeledNode
|
||||||
RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId
|
RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId
|
||||||
"""
|
"""
|
||||||
|
|
||||||
nodes = [{
|
nodes = [{
|
||||||
"node_id": node_id,
|
"node_id": str(node.id),
|
||||||
"properties": self.serialize_properties(node_properties),
|
"properties": self.serialize_properties(node.model_dump()),
|
||||||
} for (node_id, node_properties) in nodes]
|
} for node in nodes]
|
||||||
|
|
||||||
results = await self.query(query, dict(nodes = nodes))
|
results = await self.query(query, dict(nodes = nodes))
|
||||||
return results
|
return results
|
||||||
|
|
||||||
async def extract_node_description(self, node_id: str):
|
|
||||||
query = """MATCH (n)-[r]->(m)
|
|
||||||
WHERE n.id = $node_id
|
|
||||||
AND NOT m.id CONTAINS 'DefaultGraphModel'
|
|
||||||
RETURN m
|
|
||||||
"""
|
|
||||||
|
|
||||||
result = await self.query(query, dict(node_id = node_id))
|
|
||||||
|
|
||||||
descriptions = []
|
|
||||||
|
|
||||||
for node in result:
|
|
||||||
# Assuming 'm' is a consistent key in your data structure
|
|
||||||
attributes = node.get("m", {})
|
|
||||||
|
|
||||||
# Ensure all required attributes are present
|
|
||||||
if all(key in attributes for key in ["id", "layer_id", "description"]):
|
|
||||||
descriptions.append({
|
|
||||||
"id": attributes["id"],
|
|
||||||
"layer_id": attributes["layer_id"],
|
|
||||||
"description": attributes["description"],
|
|
||||||
})
|
|
||||||
|
|
||||||
return descriptions
|
|
||||||
|
|
||||||
async def get_layer_nodes(self):
|
|
||||||
query = """MATCH (node) WHERE node.layer_id IS NOT NULL
|
|
||||||
RETURN node"""
|
|
||||||
|
|
||||||
return [result["node"] for result in (await self.query(query))]
|
|
||||||
|
|
||||||
async def extract_node(self, node_id: str):
|
async def extract_node(self, node_id: str):
|
||||||
results = await self.extract_nodes([node_id])
|
results = await self.extract_nodes([node_id])
|
||||||
|
|
@ -170,13 +134,20 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
|
|
||||||
return await self.query(query, params)
|
return await self.query(query, params)
|
||||||
|
|
||||||
async def has_edge(self, from_node: str, to_node: str, edge_label: str) -> bool:
|
async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool:
|
||||||
query = f"""
|
query = """
|
||||||
MATCH (from_node:`{from_node}`)-[relationship:`{edge_label}`]->(to_node:`{to_node}`)
|
MATCH (from_node)-[relationship]->(to_node)
|
||||||
|
WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id AND type(relationship) = $edge_label
|
||||||
RETURN COUNT(relationship) > 0 AS edge_exists
|
RETURN COUNT(relationship) > 0 AS edge_exists
|
||||||
"""
|
"""
|
||||||
|
|
||||||
edge_exists = await self.query(query)
|
params = {
|
||||||
|
"from_node_id": str(from_node),
|
||||||
|
"to_node_id": str(to_node),
|
||||||
|
"edge_label": edge_label,
|
||||||
|
}
|
||||||
|
|
||||||
|
edge_exists = await self.query(query, params)
|
||||||
return edge_exists
|
return edge_exists
|
||||||
|
|
||||||
async def has_edges(self, edges):
|
async def has_edges(self, edges):
|
||||||
|
|
@ -190,8 +161,8 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
try:
|
try:
|
||||||
params = {
|
params = {
|
||||||
"edges": [{
|
"edges": [{
|
||||||
"from_node": edge[0],
|
"from_node": str(edge[0]),
|
||||||
"to_node": edge[1],
|
"to_node": str(edge[1]),
|
||||||
"relationship_name": edge[2],
|
"relationship_name": edge[2],
|
||||||
} for edge in edges],
|
} for edge in edges],
|
||||||
}
|
}
|
||||||
|
|
@ -203,21 +174,21 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
raise error
|
raise error
|
||||||
|
|
||||||
|
|
||||||
async def add_edge(self, from_node: str, to_node: str, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}):
|
async def add_edge(self, from_node: UUID, to_node: UUID, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}):
|
||||||
serialized_properties = self.serialize_properties(edge_properties)
|
serialized_properties = self.serialize_properties(edge_properties)
|
||||||
from_node = from_node.replace(":", "_")
|
|
||||||
to_node = to_node.replace(":", "_")
|
|
||||||
|
|
||||||
query = f"""MATCH (from_node:`{from_node}`
|
query = dedent("""MATCH (from_node {id: $from_node}),
|
||||||
{{id: $from_node}}),
|
(to_node {id: $to_node})
|
||||||
(to_node:`{to_node}` {{id: $to_node}})
|
MERGE (from_node)-[r]->(to_node)
|
||||||
MERGE (from_node)-[r:`{relationship_name}`]->(to_node)
|
ON CREATE SET r += $properties, r.updated_at = timestamp(), r.type = $relationship_name
|
||||||
SET r += $properties
|
ON MATCH SET r += $properties, r.updated_at = timestamp()
|
||||||
RETURN r"""
|
RETURN r
|
||||||
|
""")
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"from_node": from_node,
|
"from_node": str(from_node),
|
||||||
"to_node": to_node,
|
"to_node": str(to_node),
|
||||||
|
"relationship_name": relationship_name,
|
||||||
"properties": serialized_properties
|
"properties": serialized_properties
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -234,13 +205,13 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
edges = [{
|
edges = [{
|
||||||
"from_node": edge[0],
|
"from_node": str(edge[0]),
|
||||||
"to_node": edge[1],
|
"to_node": str(edge[1]),
|
||||||
"relationship_name": edge[2],
|
"relationship_name": edge[2],
|
||||||
"properties": {
|
"properties": {
|
||||||
**(edge[3] if edge[3] else {}),
|
**(edge[3] if edge[3] else {}),
|
||||||
"source_node_id": edge[0],
|
"source_node_id": str(edge[0]),
|
||||||
"target_node_id": edge[1],
|
"target_node_id": str(edge[1]),
|
||||||
},
|
},
|
||||||
} for edge in edges]
|
} for edge in edges]
|
||||||
|
|
||||||
|
|
@ -300,14 +271,6 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
return results[0]["ids"] if len(results) > 0 else []
|
return results[0]["ids"] if len(results) > 0 else []
|
||||||
|
|
||||||
|
|
||||||
async def filter_nodes(self, search_criteria):
|
|
||||||
query = f"""MATCH (node)
|
|
||||||
WHERE node.id CONTAINS '{search_criteria}'
|
|
||||||
RETURN node"""
|
|
||||||
|
|
||||||
return await self.query(query)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_predecessors(self, node_id: str, edge_label: str = None) -> list[str]:
|
async def get_predecessors(self, node_id: str, edge_label: str = None) -> list[str]:
|
||||||
if edge_label is not None:
|
if edge_label is not None:
|
||||||
query = """
|
query = """
|
||||||
|
|
@ -379,7 +342,7 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
|
|
||||||
return predecessors + successors
|
return predecessors + successors
|
||||||
|
|
||||||
async def get_connections(self, node_id: str) -> list:
|
async def get_connections(self, node_id: UUID) -> list:
|
||||||
predecessors_query = """
|
predecessors_query = """
|
||||||
MATCH (node)<-[relation]-(neighbour)
|
MATCH (node)<-[relation]-(neighbour)
|
||||||
WHERE node.id = $node_id
|
WHERE node.id = $node_id
|
||||||
|
|
@ -392,8 +355,8 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
predecessors, successors = await asyncio.gather(
|
predecessors, successors = await asyncio.gather(
|
||||||
self.query(predecessors_query, dict(node_id = node_id)),
|
self.query(predecessors_query, dict(node_id = str(node_id))),
|
||||||
self.query(successors_query, dict(node_id = node_id)),
|
self.query(successors_query, dict(node_id = str(node_id))),
|
||||||
)
|
)
|
||||||
|
|
||||||
connections = []
|
connections = []
|
||||||
|
|
@ -438,15 +401,22 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
return await self.query(query)
|
return await self.query(query)
|
||||||
|
|
||||||
def serialize_properties(self, properties = dict()):
|
def serialize_properties(self, properties = dict()):
|
||||||
return {
|
serialized_properties = {}
|
||||||
property_key: json.dumps(property_value)
|
|
||||||
if isinstance(property_value, (dict, list))
|
for property_key, property_value in properties.items():
|
||||||
else property_value for property_key, property_value in properties.items()
|
if isinstance(property_value, UUID):
|
||||||
}
|
serialized_properties[property_key] = str(property_value)
|
||||||
|
continue
|
||||||
|
|
||||||
|
serialized_properties[property_key] = property_value
|
||||||
|
|
||||||
|
return serialized_properties
|
||||||
|
|
||||||
async def get_graph_data(self):
|
async def get_graph_data(self):
|
||||||
query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties"
|
query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties"
|
||||||
|
|
||||||
result = await self.query(query)
|
result = await self.query(query)
|
||||||
|
|
||||||
nodes = [(
|
nodes = [(
|
||||||
record["properties"]["id"],
|
record["properties"]["id"],
|
||||||
record["properties"],
|
record["properties"],
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,19 @@
|
||||||
"""Adapter for NetworkX graph database."""
|
"""Adapter for NetworkX graph database."""
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
from re import A
|
||||||
from typing import Dict, Any, List
|
from typing import Dict, Any, List
|
||||||
|
from uuid import UUID
|
||||||
import aiofiles
|
import aiofiles
|
||||||
import aiofiles.os as aiofiles_os
|
import aiofiles.os as aiofiles_os
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
||||||
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
from cognee.modules.storage.utils import JSONEncoder
|
||||||
|
|
||||||
logger = logging.getLogger("NetworkXAdapter")
|
logger = logging.getLogger("NetworkXAdapter")
|
||||||
|
|
||||||
|
|
@ -25,29 +30,38 @@ class NetworkXAdapter(GraphDBInterface):
|
||||||
def __init__(self, filename = "cognee_graph.pkl"):
|
def __init__(self, filename = "cognee_graph.pkl"):
|
||||||
self.filename = filename
|
self.filename = filename
|
||||||
|
|
||||||
|
async def get_graph_data(self):
|
||||||
|
await self.load_graph_from_file()
|
||||||
|
return (list(self.graph.nodes(data = True)), list(self.graph.edges(data = True, keys = True)))
|
||||||
|
|
||||||
|
async def query(self, query: str, params: dict):
|
||||||
|
pass
|
||||||
|
|
||||||
async def has_node(self, node_id: str) -> bool:
|
async def has_node(self, node_id: str) -> bool:
|
||||||
return self.graph.has_node(node_id)
|
return self.graph.has_node(node_id)
|
||||||
|
|
||||||
async def add_node(
|
async def add_node(
|
||||||
self,
|
self,
|
||||||
node_id: str,
|
node: DataPoint,
|
||||||
node_properties,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
if not self.graph.has_node(id):
|
self.graph.add_node(node.id, **node.model_dump())
|
||||||
self.graph.add_node(node_id, **node_properties)
|
|
||||||
await self.save_graph_to_file(self.filename)
|
await self.save_graph_to_file(self.filename)
|
||||||
|
|
||||||
async def add_nodes(
|
async def add_nodes(
|
||||||
self,
|
self,
|
||||||
nodes: List[tuple[str, dict]],
|
nodes: list[DataPoint],
|
||||||
) -> None:
|
) -> None:
|
||||||
|
nodes = [(node.id, node.model_dump()) for node in nodes]
|
||||||
|
|
||||||
self.graph.add_nodes_from(nodes)
|
self.graph.add_nodes_from(nodes)
|
||||||
await self.save_graph_to_file(self.filename)
|
await self.save_graph_to_file(self.filename)
|
||||||
|
|
||||||
|
|
||||||
async def get_graph(self):
|
async def get_graph(self):
|
||||||
return self.graph
|
return self.graph
|
||||||
|
|
||||||
|
|
||||||
async def has_edge(self, from_node: str, to_node: str, edge_label: str) -> bool:
|
async def has_edge(self, from_node: str, to_node: str, edge_label: str) -> bool:
|
||||||
return self.graph.has_edge(from_node, to_node, key = edge_label)
|
return self.graph.has_edge(from_node, to_node, key = edge_label)
|
||||||
|
|
||||||
|
|
@ -55,18 +69,20 @@ class NetworkXAdapter(GraphDBInterface):
|
||||||
result = []
|
result = []
|
||||||
|
|
||||||
for (from_node, to_node, edge_label) in edges:
|
for (from_node, to_node, edge_label) in edges:
|
||||||
if await self.has_edge(from_node, to_node, edge_label):
|
if self.graph.has_edge(from_node, to_node, edge_label):
|
||||||
result.append((from_node, to_node, edge_label))
|
result.append((from_node, to_node, edge_label))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def add_edge(
|
async def add_edge(
|
||||||
self,
|
self,
|
||||||
from_node: str,
|
from_node: str,
|
||||||
to_node: str,
|
to_node: str,
|
||||||
relationship_name: str,
|
relationship_name: str,
|
||||||
edge_properties: Dict[str, Any] = None,
|
edge_properties: Dict[str, Any] = {},
|
||||||
) -> None:
|
) -> None:
|
||||||
|
edge_properties["updated_at"] = datetime.now(timezone.utc)
|
||||||
self.graph.add_edge(from_node, to_node, key = relationship_name, **(edge_properties if edge_properties else {}))
|
self.graph.add_edge(from_node, to_node, key = relationship_name, **(edge_properties if edge_properties else {}))
|
||||||
await self.save_graph_to_file(self.filename)
|
await self.save_graph_to_file(self.filename)
|
||||||
|
|
||||||
|
|
@ -74,22 +90,29 @@ class NetworkXAdapter(GraphDBInterface):
|
||||||
self,
|
self,
|
||||||
edges: tuple[str, str, str, dict],
|
edges: tuple[str, str, str, dict],
|
||||||
) -> None:
|
) -> None:
|
||||||
|
edges = [(edge[0], edge[1], edge[2], {
|
||||||
|
**(edge[3] if len(edge) == 4 else {}),
|
||||||
|
"updated_at": datetime.now(timezone.utc),
|
||||||
|
}) for edge in edges]
|
||||||
|
|
||||||
self.graph.add_edges_from(edges)
|
self.graph.add_edges_from(edges)
|
||||||
await self.save_graph_to_file(self.filename)
|
await self.save_graph_to_file(self.filename)
|
||||||
|
|
||||||
async def get_edges(self, node_id: str):
|
async def get_edges(self, node_id: str):
|
||||||
return list(self.graph.in_edges(node_id, data = True)) + list(self.graph.out_edges(node_id, data = True))
|
return list(self.graph.in_edges(node_id, data = True)) + list(self.graph.out_edges(node_id, data = True))
|
||||||
|
|
||||||
|
|
||||||
async def delete_node(self, node_id: str) -> None:
|
async def delete_node(self, node_id: str) -> None:
|
||||||
"""Asynchronously delete a node from the graph if it exists."""
|
"""Asynchronously delete a node from the graph if it exists."""
|
||||||
if self.graph.has_node(id):
|
if self.graph.has_node(node_id):
|
||||||
self.graph.remove_node(id)
|
self.graph.remove_node(node_id)
|
||||||
await self.save_graph_to_file(self.filename)
|
await self.save_graph_to_file(self.filename)
|
||||||
|
|
||||||
async def delete_nodes(self, node_ids: List[str]) -> None:
|
async def delete_nodes(self, node_ids: List[str]) -> None:
|
||||||
self.graph.remove_nodes_from(node_ids)
|
self.graph.remove_nodes_from(node_ids)
|
||||||
await self.save_graph_to_file(self.filename)
|
await self.save_graph_to_file(self.filename)
|
||||||
|
|
||||||
|
|
||||||
async def get_disconnected_nodes(self) -> List[str]:
|
async def get_disconnected_nodes(self) -> List[str]:
|
||||||
connected_components = list(nx.weakly_connected_components(self.graph))
|
connected_components = list(nx.weakly_connected_components(self.graph))
|
||||||
|
|
||||||
|
|
@ -102,33 +125,6 @@ class NetworkXAdapter(GraphDBInterface):
|
||||||
|
|
||||||
return disconnected_nodes
|
return disconnected_nodes
|
||||||
|
|
||||||
async def extract_node_description(self, node_id: str) -> Dict[str, Any]:
|
|
||||||
descriptions = []
|
|
||||||
|
|
||||||
if self.graph.has_node(node_id):
|
|
||||||
# Get the attributes of the node
|
|
||||||
for neighbor in self.graph.neighbors(node_id):
|
|
||||||
# Get the attributes of the neighboring node
|
|
||||||
attributes = self.graph.nodes[neighbor]
|
|
||||||
|
|
||||||
# Ensure all required attributes are present before extracting description
|
|
||||||
if all(key in attributes for key in ["id", "layer_id", "description"]):
|
|
||||||
descriptions.append({
|
|
||||||
"id": attributes["id"],
|
|
||||||
"layer_id": attributes["layer_id"],
|
|
||||||
"description": attributes["description"],
|
|
||||||
})
|
|
||||||
|
|
||||||
return descriptions
|
|
||||||
|
|
||||||
async def get_layer_nodes(self):
|
|
||||||
layer_nodes = []
|
|
||||||
|
|
||||||
for _, data in self.graph.nodes(data = True):
|
|
||||||
if "layer_id" in data:
|
|
||||||
layer_nodes.append(data)
|
|
||||||
|
|
||||||
return layer_nodes
|
|
||||||
|
|
||||||
async def extract_node(self, node_id: str) -> dict:
|
async def extract_node(self, node_id: str) -> dict:
|
||||||
if self.graph.has_node(node_id):
|
if self.graph.has_node(node_id):
|
||||||
|
|
@ -139,7 +135,7 @@ class NetworkXAdapter(GraphDBInterface):
|
||||||
async def extract_nodes(self, node_ids: List[str]) -> List[dict]:
|
async def extract_nodes(self, node_ids: List[str]) -> List[dict]:
|
||||||
return [self.graph.nodes[node_id] for node_id in node_ids if self.graph.has_node(node_id)]
|
return [self.graph.nodes[node_id] for node_id in node_ids if self.graph.has_node(node_id)]
|
||||||
|
|
||||||
async def get_predecessors(self, node_id: str, edge_label: str = None) -> list:
|
async def get_predecessors(self, node_id: UUID, edge_label: str = None) -> list:
|
||||||
if self.graph.has_node(node_id):
|
if self.graph.has_node(node_id):
|
||||||
if edge_label is None:
|
if edge_label is None:
|
||||||
return [
|
return [
|
||||||
|
|
@ -155,7 +151,7 @@ class NetworkXAdapter(GraphDBInterface):
|
||||||
|
|
||||||
return nodes
|
return nodes
|
||||||
|
|
||||||
async def get_successors(self, node_id: str, edge_label: str = None) -> list:
|
async def get_successors(self, node_id: UUID, edge_label: str = None) -> list:
|
||||||
if self.graph.has_node(node_id):
|
if self.graph.has_node(node_id):
|
||||||
if edge_label is None:
|
if edge_label is None:
|
||||||
return [
|
return [
|
||||||
|
|
@ -184,13 +180,13 @@ class NetworkXAdapter(GraphDBInterface):
|
||||||
|
|
||||||
return neighbours
|
return neighbours
|
||||||
|
|
||||||
async def get_connections(self, node_id: str) -> list:
|
async def get_connections(self, node_id: UUID) -> list:
|
||||||
if not self.graph.has_node(node_id):
|
if not self.graph.has_node(node_id):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
node = self.graph.nodes[node_id]
|
node = self.graph.nodes[node_id]
|
||||||
|
|
||||||
if "uuid" not in node:
|
if "id" not in node:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
predecessors, successors = await asyncio.gather(
|
predecessors, successors = await asyncio.gather(
|
||||||
|
|
@ -201,14 +197,14 @@ class NetworkXAdapter(GraphDBInterface):
|
||||||
connections = []
|
connections = []
|
||||||
|
|
||||||
for neighbor in predecessors:
|
for neighbor in predecessors:
|
||||||
if "uuid" in neighbor:
|
if "id" in neighbor:
|
||||||
edge_data = self.graph.get_edge_data(neighbor["uuid"], node["uuid"])
|
edge_data = self.graph.get_edge_data(neighbor["id"], node["id"])
|
||||||
for edge_properties in edge_data.values():
|
for edge_properties in edge_data.values():
|
||||||
connections.append((neighbor, edge_properties, node))
|
connections.append((neighbor, edge_properties, node))
|
||||||
|
|
||||||
for neighbor in successors:
|
for neighbor in successors:
|
||||||
if "uuid" in neighbor:
|
if "id" in neighbor:
|
||||||
edge_data = self.graph.get_edge_data(node["uuid"], neighbor["uuid"])
|
edge_data = self.graph.get_edge_data(node["id"], neighbor["id"])
|
||||||
for edge_properties in edge_data.values():
|
for edge_properties in edge_data.values():
|
||||||
connections.append((node, edge_properties, neighbor))
|
connections.append((node, edge_properties, neighbor))
|
||||||
|
|
||||||
|
|
@ -240,7 +236,7 @@ class NetworkXAdapter(GraphDBInterface):
|
||||||
graph_data = nx.readwrite.json_graph.node_link_data(self.graph)
|
graph_data = nx.readwrite.json_graph.node_link_data(self.graph)
|
||||||
|
|
||||||
async with aiofiles.open(file_path, "w") as file:
|
async with aiofiles.open(file_path, "w") as file:
|
||||||
await file.write(json.dumps(graph_data))
|
await file.write(json.dumps(graph_data, cls = JSONEncoder))
|
||||||
|
|
||||||
|
|
||||||
async def load_graph_from_file(self, file_path: str = None):
|
async def load_graph_from_file(self, file_path: str = None):
|
||||||
|
|
@ -254,6 +250,29 @@ class NetworkXAdapter(GraphDBInterface):
|
||||||
if os.path.exists(file_path):
|
if os.path.exists(file_path):
|
||||||
async with aiofiles.open(file_path, "r") as file:
|
async with aiofiles.open(file_path, "r") as file:
|
||||||
graph_data = json.loads(await file.read())
|
graph_data = json.loads(await file.read())
|
||||||
|
for node in graph_data["nodes"]:
|
||||||
|
try:
|
||||||
|
node["id"] = UUID(node["id"])
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
if "updated_at" in node:
|
||||||
|
node["updated_at"] = datetime.strptime(node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
|
||||||
|
|
||||||
|
for edge in graph_data["links"]:
|
||||||
|
try:
|
||||||
|
source_id = UUID(edge["source"])
|
||||||
|
target_id = UUID(edge["target"])
|
||||||
|
|
||||||
|
edge["source"] = source_id
|
||||||
|
edge["target"] = target_id
|
||||||
|
edge["source_node_id"] = source_id
|
||||||
|
edge["target_node_id"] = target_id
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if "updated_at" in edge:
|
||||||
|
edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
|
||||||
|
|
||||||
self.graph = nx.readwrite.json_graph.node_link_graph(graph_data)
|
self.graph = nx.readwrite.json_graph.node_link_graph(graph_data)
|
||||||
else:
|
else:
|
||||||
# Log that the file does not exist and an empty graph is initialized
|
# Log that the file does not exist and an empty graph is initialized
|
||||||
|
|
@ -265,9 +284,11 @@ class NetworkXAdapter(GraphDBInterface):
|
||||||
os.makedirs(file_dir, exist_ok = True)
|
os.makedirs(file_dir, exist_ok = True)
|
||||||
|
|
||||||
await self.save_graph_to_file(file_path)
|
await self.save_graph_to_file(file_path)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.error("Failed to load graph from file: %s", file_path)
|
logger.error("Failed to load graph from file: %s", file_path)
|
||||||
|
|
||||||
|
|
||||||
async def delete_graph(self, file_path: str = None):
|
async def delete_graph(self, file_path: str = None):
|
||||||
"""Asynchronously delete the graph file from the filesystem."""
|
"""Asynchronously delete the graph file from the filesystem."""
|
||||||
if file_path is None:
|
if file_path is None:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,267 @@
|
||||||
|
import asyncio
|
||||||
|
from textwrap import dedent
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
from falkordb import FalkorDB
|
||||||
|
|
||||||
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
||||||
|
from cognee.infrastructure.databases.vector.embeddings import EmbeddingEngine
|
||||||
|
from cognee.infrastructure.databases.vector.vector_db_interface import VectorDBInterface
|
||||||
|
|
||||||
|
class IndexSchema(DataPoint):
|
||||||
|
text: str
|
||||||
|
|
||||||
|
_metadata: dict = {
|
||||||
|
"index_fields": ["text"]
|
||||||
|
}
|
||||||
|
|
||||||
|
class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
database_url: str,
|
||||||
|
database_port: int,
|
||||||
|
embedding_engine = EmbeddingEngine,
|
||||||
|
):
|
||||||
|
self.driver = FalkorDB(
|
||||||
|
host = database_url,
|
||||||
|
port = database_port,
|
||||||
|
)
|
||||||
|
self.embedding_engine = embedding_engine
|
||||||
|
self.graph_name = "cognee_graph"
|
||||||
|
|
||||||
|
def query(self, query: str, params: dict = {}):
|
||||||
|
graph = self.driver.select_graph(self.graph_name)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = graph.query(query, params)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error executing query: {e}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
async def embed_data(self, data: list[str]) -> list[list[float]]:
|
||||||
|
return await self.embedding_engine.embed_text(data)
|
||||||
|
|
||||||
|
async def stringify_properties(self, properties: dict, vectorize_fields = []) -> str:
|
||||||
|
async def get_value(key, value):
|
||||||
|
return f"'{value}'" if key not in vectorize_fields else await self.get_vectorized_value(value)
|
||||||
|
|
||||||
|
return ",".join([f"{key}:{await get_value(key, value)}" for key, value in properties.items()])
|
||||||
|
|
||||||
|
async def get_vectorized_value(self, value: Any) -> str:
|
||||||
|
vector = (await self.embed_data([value]))[0]
|
||||||
|
return f"vecf32({vector})"
|
||||||
|
|
||||||
|
async def create_data_point_query(self, data_point: DataPoint):
|
||||||
|
node_label = type(data_point).__name__
|
||||||
|
node_properties = await self.stringify_properties(
|
||||||
|
data_point.model_dump(),
|
||||||
|
data_point._metadata["index_fields"],
|
||||||
|
# data_point._metadata["index_fields"] if hasattr(data_point, "_metadata") else [],
|
||||||
|
)
|
||||||
|
|
||||||
|
return dedent(f"""
|
||||||
|
MERGE (node:{node_label} {{id: '{str(data_point.id)}'}})
|
||||||
|
ON CREATE SET node += ({{{node_properties}}})
|
||||||
|
ON CREATE SET node.updated_at = timestamp()
|
||||||
|
ON MATCH SET node += ({{{node_properties}}})
|
||||||
|
ON MATCH SET node.updated_at = timestamp()
|
||||||
|
""").strip()
|
||||||
|
|
||||||
|
async def create_edge_query(self, edge: tuple[str, str, str, dict]) -> str:
|
||||||
|
properties = await self.stringify_properties(edge[3])
|
||||||
|
properties = f"{{{properties}}}"
|
||||||
|
|
||||||
|
return dedent(f"""
|
||||||
|
MERGE (source {{id:'{edge[0]}'}})
|
||||||
|
MERGE (target {{id: '{edge[1]}'}})
|
||||||
|
MERGE (source)-[edge:{edge[2]} {properties}]->(target)
|
||||||
|
ON MATCH SET edge.updated_at = timestamp()
|
||||||
|
ON CREATE SET edge.updated_at = timestamp()
|
||||||
|
""").strip()
|
||||||
|
|
||||||
|
async def create_collection(self, collection_name: str):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def has_collection(self, collection_name: str) -> bool:
|
||||||
|
collections = self.driver.list_graphs()
|
||||||
|
|
||||||
|
return collection_name in collections
|
||||||
|
|
||||||
|
async def create_data_points(self, data_points: list[DataPoint]):
|
||||||
|
queries = [await self.create_data_point_query(data_point) for data_point in data_points]
|
||||||
|
for query in queries:
|
||||||
|
self.query(query)
|
||||||
|
|
||||||
|
async def create_vector_index(self, index_name: str, index_property_name: str):
|
||||||
|
graph = self.driver.select_graph(self.graph_name)
|
||||||
|
|
||||||
|
if not self.has_vector_index(graph, index_name, index_property_name):
|
||||||
|
graph.create_node_vector_index(index_name, index_property_name, dim = self.embedding_engine.get_vector_size())
|
||||||
|
|
||||||
|
def has_vector_index(self, graph, index_name: str, index_property_name: str) -> bool:
|
||||||
|
try:
|
||||||
|
indices = graph.list_indices()
|
||||||
|
|
||||||
|
return any([(index[0] == index_name and index_property_name in index[1]) for index in indices.result_set])
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def add_node(self, node: DataPoint):
|
||||||
|
await self.create_data_points([node])
|
||||||
|
|
||||||
|
async def add_nodes(self, nodes: list[DataPoint]):
|
||||||
|
await self.create_data_points(nodes)
|
||||||
|
|
||||||
|
async def add_edge(self, edge: tuple[str, str, str, dict]):
|
||||||
|
query = await self.create_edge_query(edge)
|
||||||
|
|
||||||
|
self.query(query)
|
||||||
|
|
||||||
|
async def add_edges(self, edges: list[tuple[str, str, str, dict]]):
|
||||||
|
queries = [await self.create_edge_query(edge) for edge in edges]
|
||||||
|
|
||||||
|
for query in queries:
|
||||||
|
self.query(query)
|
||||||
|
|
||||||
|
async def has_edges(self, edges):
|
||||||
|
query = dedent("""
|
||||||
|
UNWIND $edges AS edge
|
||||||
|
MATCH (a)-[r]->(b)
|
||||||
|
WHERE id(a) = edge.from_node AND id(b) = edge.to_node AND type(r) = edge.relationship_name
|
||||||
|
RETURN edge.from_node AS from_node, edge.to_node AS to_node, edge.relationship_name AS relationship_name, count(r) > 0 AS edge_exists
|
||||||
|
""").strip()
|
||||||
|
|
||||||
|
params = {
|
||||||
|
"edges": [{
|
||||||
|
"from_node": str(edge[0]),
|
||||||
|
"to_node": str(edge[1]),
|
||||||
|
"relationship_name": edge[2],
|
||||||
|
} for edge in edges],
|
||||||
|
}
|
||||||
|
|
||||||
|
results = self.query(query, params).result_set
|
||||||
|
|
||||||
|
return [result["edge_exists"] for result in results]
|
||||||
|
|
||||||
|
async def retrieve(self, data_point_ids: list[str]):
|
||||||
|
return self.query(
|
||||||
|
f"MATCH (node) WHERE node.id IN $node_ids RETURN node",
|
||||||
|
{
|
||||||
|
"node_ids": data_point_ids,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def extract_node(self, data_point_id: str):
|
||||||
|
return await self.retrieve([data_point_id])
|
||||||
|
|
||||||
|
async def extract_nodes(self, data_point_ids: list[str]):
|
||||||
|
return await self.retrieve(data_point_ids)
|
||||||
|
|
||||||
|
async def get_connections(self, node_id: UUID) -> list:
    """Return every (source, {relationship_name}, target) triple touching the node.

    Runs two queries — one for incoming edges (predecessors) and one for
    outgoing edges (successors) — concurrently, then flattens both result
    sets into a single list of 3-tuples.
    """
    predecessors_query = """
    MATCH (node)<-[relation]-(neighbour)
    WHERE node.id = $node_id
    RETURN neighbour, relation, node
    """
    successors_query = """
    MATCH (node)-[relation]->(neighbour)
    WHERE node.id = $node_id
    RETURN node, relation, neighbour
    """

    predecessors, successors = await asyncio.gather(
        self.query(predecessors_query, dict(node_id = node_id)),
        self.query(successors_query, dict(node_id = node_id)),
    )

    connections = []

    for neighbour in predecessors:
        # NOTE(review): each row is indexed by the "relation" column and then
        # subscripted [0]/[1]/[2] — presumably (source, relationship_name,
        # target) as stored on the relation record; verify against the driver's
        # row format before restructuring.
        neighbour = neighbour["relation"]
        connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2]))

    for neighbour in successors:
        neighbour = neighbour["relation"]
        connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2]))

    return connections
async def search(
    self,
    collection_name: str,
    query_text: str = None,
    query_vector: list[float] = None,
    limit: int = 10,
    with_vector: bool = False,
):
    """Vector-similarity search against the named collection's index.

    Exactly one of `query_text` or `query_vector` must be supplied; text is
    embedded via embed_data() first. Returns the raw query result.

    NOTE(review): `with_vector` is accepted but never used in this body.
    """
    if query_text is None and query_vector is None:
        raise ValueError("One of query_text or query_vector must be provided!")

    if query_text and not query_vector:
        query_vector = (await self.embed_data([query_text]))[0]

    # NOTE(review): the query is built by f-string interpolation rather than
    # parameters: `collection_name` is inserted unquoted (confirm the callers
    # pass a value the procedure accepts), and if a caller forwards
    # limit=None (as batch_search can), the literal text "None" is
    # interpolated — verify these against the FalkorDB procedure signature.
    query = dedent(f"""
    CALL db.idx.vector.queryNodes(
    {collection_name},
    'text',
    {limit},
    vecf32({query_vector})
    ) YIELD node, score
    """).strip()

    # NOTE(review): not awaited — same sync/async ambiguity as add_edge().
    result = self.query(query)

    return result
async def batch_search(
    self,
    collection_name: str,
    query_texts: list[str],
    limit: int = None,
    with_vectors: bool = False,
):
    """Embed all query texts in one call, then run the searches concurrently.

    Returns the per-query search results in the same order as `query_texts`.
    """
    query_vectors = await self.embedding_engine.embed_text(query_texts)

    searches = [
        self.search(
            collection_name = collection_name,
            query_vector = vector,
            limit = limit,
            with_vector = with_vectors,
        )
        for vector in query_vectors
    ]

    return await asyncio.gather(*searches)
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
    """Detach-delete every node whose `id` is in `data_point_ids`.

    `collection_name` is accepted for interface parity but not used by the
    underlying query.
    """
    delete_query = "MATCH (node) WHERE node.id IN $node_ids DETACH DELETE node"
    return self.query(delete_query, {"node_ids": data_point_ids})
async def delete_node(self, collection_name: str, data_point_id: str):
    """Delete a single node by id.

    Fix: delete_data_points takes (collection_name, data_point_ids) — the
    previous call passed only the id list, which would raise a TypeError.
    """
    return await self.delete_data_points(collection_name, [data_point_id])
async def delete_nodes(self, collection_name: str, data_point_ids: list[str]):
    """Delete multiple nodes by id.

    Fixes two defects in the original: the required `collection_name` first
    argument was omitted (TypeError), and the coroutine was never awaited,
    so the deletion never actually executed.
    """
    return await self.delete_data_points(collection_name, data_point_ids)
async def delete_graph(self):
    """Drop the graph and all of its vector indices.

    Indices must be dropped explicitly before deleting the graph; failures
    are swallowed and reported to stdout (best-effort cleanup).
    """
    try:
        graph = self.driver.select_graph(self.graph_name)

        # Drop every vector index first: each result row is presumably
        # (label, [fields, ...]) — one drop call per indexed field.
        indices = graph.list_indices()
        for index in indices.result_set:
            for field in index[1]:
                graph.drop_node_vector_index(index[0], field)

        graph.delete()
    except Exception as e:
        # Deliberate best-effort: deletion failures are logged, not raised.
        print(f"Error deleting graph: {e}")
async def prune(self):
    """Remove the entire graph, including its vector indices.

    Fix: delete_graph() is a coroutine — without `await` it was never
    scheduled, making prune() a silent no-op.
    """
    await self.delete_graph()
@ -1,4 +1,3 @@
|
||||||
from .models.DataPoint import DataPoint
|
|
||||||
from .models.VectorConfig import VectorConfig
|
from .models.VectorConfig import VectorConfig
|
||||||
from .models.CollectionConfig import CollectionConfig
|
from .models.CollectionConfig import CollectionConfig
|
||||||
from .vector_db_interface import VectorDBInterface
|
from .vector_db_interface import VectorDBInterface
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ class VectorConfig(BaseSettings):
|
||||||
os.path.join(get_absolute_path(".cognee_system"), "databases"),
|
os.path.join(get_absolute_path(".cognee_system"), "databases"),
|
||||||
"cognee.lancedb"
|
"cognee.lancedb"
|
||||||
)
|
)
|
||||||
|
vector_db_port: int = 1234
|
||||||
vector_db_key: str = ""
|
vector_db_key: str = ""
|
||||||
vector_db_provider: str = "lancedb"
|
vector_db_provider: str = "lancedb"
|
||||||
|
|
||||||
|
|
@ -16,6 +17,7 @@ class VectorConfig(BaseSettings):
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
return {
|
return {
|
||||||
"vector_db_url": self.vector_db_url,
|
"vector_db_url": self.vector_db_url,
|
||||||
|
"vector_db_port": self.vector_db_port,
|
||||||
"vector_db_key": self.vector_db_key,
|
"vector_db_key": self.vector_db_key,
|
||||||
"vector_db_provider": self.vector_db_provider,
|
"vector_db_provider": self.vector_db_provider,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,8 @@
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from ..relational.config import get_relational_config
|
|
||||||
|
|
||||||
class VectorConfig(Dict):
|
class VectorConfig(Dict):
|
||||||
vector_db_url: str
|
vector_db_url: str
|
||||||
|
vector_db_port: str
|
||||||
vector_db_key: str
|
vector_db_key: str
|
||||||
vector_db_provider: str
|
vector_db_provider: str
|
||||||
|
|
||||||
|
|
@ -29,6 +28,7 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
|
||||||
embedding_engine = embedding_engine
|
embedding_engine = embedding_engine
|
||||||
)
|
)
|
||||||
elif config["vector_db_provider"] == "pgvector":
|
elif config["vector_db_provider"] == "pgvector":
|
||||||
|
from cognee.infrastructure.databases.relational import get_relational_config
|
||||||
from .pgvector.PGVectorAdapter import PGVectorAdapter
|
from .pgvector.PGVectorAdapter import PGVectorAdapter
|
||||||
|
|
||||||
# Get configuration for postgres database
|
# Get configuration for postgres database
|
||||||
|
|
@ -43,9 +43,18 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
|
||||||
f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
|
f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return PGVectorAdapter(connection_string,
|
return PGVectorAdapter(
|
||||||
config["vector_db_key"],
|
connection_string,
|
||||||
embedding_engine
|
config["vector_db_key"],
|
||||||
|
embedding_engine,
|
||||||
|
)
|
||||||
|
elif config["vector_db_provider"] == "falkordb":
|
||||||
|
from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
|
||||||
|
|
||||||
|
return FalkorDBAdapter(
|
||||||
|
database_url = config["vector_db_url"],
|
||||||
|
database_port = config["vector_db_port"],
|
||||||
|
embedding_engine = embedding_engine,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
from .lancedb.LanceDBAdapter import LanceDBAdapter
|
from .lancedb.LanceDBAdapter import LanceDBAdapter
|
||||||
|
|
|
||||||
|
|
@ -1,57 +0,0 @@
|
||||||
|
|
||||||
from typing import List, Dict, Optional, Any
|
|
||||||
|
|
||||||
from falkordb import FalkorDB
|
|
||||||
from qdrant_client import AsyncQdrantClient, models
|
|
||||||
from ..vector_db_interface import VectorDBInterface
|
|
||||||
from ..models.DataPoint import DataPoint
|
|
||||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class FalcorDBAdapter(VectorDBInterface):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
graph_database_url: str,
|
|
||||||
graph_database_username: str,
|
|
||||||
graph_database_password: str,
|
|
||||||
graph_database_port: int,
|
|
||||||
driver: Optional[Any] = None,
|
|
||||||
embedding_engine = EmbeddingEngine,
|
|
||||||
graph_name: str = "DefaultGraph",
|
|
||||||
):
|
|
||||||
self.driver = FalkorDB(
|
|
||||||
host = graph_database_url,
|
|
||||||
port = graph_database_port)
|
|
||||||
self.graph_name = graph_name
|
|
||||||
self.embedding_engine = embedding_engine
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def embed_data(self, data: list[str]) -> list[list[float]]:
|
|
||||||
return await self.embedding_engine.embed_text(data)
|
|
||||||
|
|
||||||
|
|
||||||
async def create_collection(self, collection_name: str, payload_schema = None):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def search(
|
|
||||||
self,
|
|
||||||
collection_name: str,
|
|
||||||
query_text: str = None,
|
|
||||||
query_vector: List[float] = None,
|
|
||||||
limit: int = 10,
|
|
||||||
with_vector: bool = False,
|
|
||||||
):
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
|
||||||
pass
|
|
||||||
|
|
@ -1,12 +1,25 @@
|
||||||
|
import inspect
|
||||||
from typing import List, Optional, get_type_hints, Generic, TypeVar
|
from typing import List, Optional, get_type_hints, Generic, TypeVar
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from uuid import UUID
|
||||||
import lancedb
|
import lancedb
|
||||||
|
from pydantic import BaseModel
|
||||||
from lancedb.pydantic import Vector, LanceModel
|
from lancedb.pydantic import Vector, LanceModel
|
||||||
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.infrastructure.files.storage import LocalStorage
|
from cognee.infrastructure.files.storage import LocalStorage
|
||||||
|
from cognee.modules.storage.utils import copy_model, get_own_properties
|
||||||
from ..models.ScoredResult import ScoredResult
|
from ..models.ScoredResult import ScoredResult
|
||||||
from ..vector_db_interface import VectorDBInterface, DataPoint
|
from ..vector_db_interface import VectorDBInterface
|
||||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
|
|
||||||
|
class IndexSchema(DataPoint):
|
||||||
|
id: str
|
||||||
|
text: str
|
||||||
|
|
||||||
|
_metadata: dict = {
|
||||||
|
"index_fields": ["text"]
|
||||||
|
}
|
||||||
|
|
||||||
class LanceDBAdapter(VectorDBInterface):
|
class LanceDBAdapter(VectorDBInterface):
|
||||||
name = "LanceDB"
|
name = "LanceDB"
|
||||||
url: str
|
url: str
|
||||||
|
|
@ -38,10 +51,12 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
collection_names = await connection.table_names()
|
collection_names = await connection.table_names()
|
||||||
return collection_name in collection_names
|
return collection_name in collection_names
|
||||||
|
|
||||||
async def create_collection(self, collection_name: str, payload_schema = None):
|
async def create_collection(self, collection_name: str, payload_schema: BaseModel):
|
||||||
data_point_types = get_type_hints(DataPoint)
|
|
||||||
vector_size = self.embedding_engine.get_vector_size()
|
vector_size = self.embedding_engine.get_vector_size()
|
||||||
|
|
||||||
|
payload_schema = self.get_data_point_schema(payload_schema)
|
||||||
|
data_point_types = get_type_hints(payload_schema)
|
||||||
|
|
||||||
class LanceDataPoint(LanceModel):
|
class LanceDataPoint(LanceModel):
|
||||||
id: data_point_types["id"]
|
id: data_point_types["id"]
|
||||||
vector: Vector(vector_size)
|
vector: Vector(vector_size)
|
||||||
|
|
@ -55,13 +70,16 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
exist_ok = True,
|
exist_ok = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
|
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
|
||||||
connection = await self.get_connection()
|
connection = await self.get_connection()
|
||||||
|
|
||||||
|
payload_schema = type(data_points[0])
|
||||||
|
payload_schema = self.get_data_point_schema(payload_schema)
|
||||||
|
|
||||||
if not await self.has_collection(collection_name):
|
if not await self.has_collection(collection_name):
|
||||||
await self.create_collection(
|
await self.create_collection(
|
||||||
collection_name,
|
collection_name,
|
||||||
payload_schema = type(data_points[0].payload),
|
payload_schema,
|
||||||
)
|
)
|
||||||
|
|
||||||
collection = await connection.open_table(collection_name)
|
collection = await connection.open_table(collection_name)
|
||||||
|
|
@ -79,15 +97,26 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
vector: Vector(vector_size)
|
vector: Vector(vector_size)
|
||||||
payload: PayloadSchema
|
payload: PayloadSchema
|
||||||
|
|
||||||
|
def create_lance_data_point(data_point: DataPoint, vector: list[float]) -> LanceDataPoint:
|
||||||
|
properties = get_own_properties(data_point)
|
||||||
|
properties["id"] = str(properties["id"])
|
||||||
|
|
||||||
|
return LanceDataPoint[str, self.get_data_point_schema(type(data_point))](
|
||||||
|
id = str(data_point.id),
|
||||||
|
vector = vector,
|
||||||
|
payload = properties,
|
||||||
|
)
|
||||||
|
|
||||||
lance_data_points = [
|
lance_data_points = [
|
||||||
LanceDataPoint[type(data_point.id), type(data_point.payload)](
|
create_lance_data_point(data_point, data_vectors[data_point_index])
|
||||||
id = data_point.id,
|
for (data_point_index, data_point) in enumerate(data_points)
|
||||||
vector = data_vectors[data_index],
|
|
||||||
payload = data_point.payload,
|
|
||||||
) for (data_index, data_point) in enumerate(data_points)
|
|
||||||
]
|
]
|
||||||
|
|
||||||
await collection.add(lance_data_points)
|
await collection.merge_insert("id") \
|
||||||
|
.when_matched_update_all() \
|
||||||
|
.when_not_matched_insert_all() \
|
||||||
|
.execute(lance_data_points)
|
||||||
|
|
||||||
|
|
||||||
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
||||||
connection = await self.get_connection()
|
connection = await self.get_connection()
|
||||||
|
|
@ -99,7 +128,7 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
results = await collection.query().where(f"id IN {tuple(data_point_ids)}").to_pandas()
|
results = await collection.query().where(f"id IN {tuple(data_point_ids)}").to_pandas()
|
||||||
|
|
||||||
return [ScoredResult(
|
return [ScoredResult(
|
||||||
id = result["id"],
|
id = UUID(result["id"]),
|
||||||
payload = result["payload"],
|
payload = result["payload"],
|
||||||
score = 0,
|
score = 0,
|
||||||
) for result in results.to_dict("index").values()]
|
) for result in results.to_dict("index").values()]
|
||||||
|
|
@ -135,10 +164,19 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
if value < min_value:
|
if value < min_value:
|
||||||
min_value = value
|
min_value = value
|
||||||
|
|
||||||
normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in result_values]
|
normalized_values = []
|
||||||
|
min_value = min(result["_distance"] for result in result_values)
|
||||||
|
max_value = max(result["_distance"] for result in result_values)
|
||||||
|
|
||||||
|
if max_value == min_value:
|
||||||
|
# Avoid division by zero: Assign all normalized values to 0 (or any constant value like 1)
|
||||||
|
normalized_values = [0 for _ in result_values]
|
||||||
|
else:
|
||||||
|
normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in
|
||||||
|
result_values]
|
||||||
|
|
||||||
return [ScoredResult(
|
return [ScoredResult(
|
||||||
id = str(result["id"]),
|
id = UUID(result["id"]),
|
||||||
payload = result["payload"],
|
payload = result["payload"],
|
||||||
score = normalized_values[value_index],
|
score = normalized_values[value_index],
|
||||||
) for value_index, result in enumerate(result_values)]
|
) for value_index, result in enumerate(result_values)]
|
||||||
|
|
@ -170,7 +208,27 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
results = await collection.delete(f"id IN {tuple(data_point_ids)}")
|
results = await collection.delete(f"id IN {tuple(data_point_ids)}")
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
async def create_vector_index(self, index_name: str, index_property_name: str):
|
||||||
|
await self.create_collection(f"{index_name}_{index_property_name}", payload_schema = IndexSchema)
|
||||||
|
|
||||||
|
async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
|
||||||
|
await self.create_data_points(f"{index_name}_{index_property_name}", [
|
||||||
|
IndexSchema(
|
||||||
|
id = str(data_point.id),
|
||||||
|
text = getattr(data_point, data_point._metadata["index_fields"][0]),
|
||||||
|
) for data_point in data_points
|
||||||
|
])
|
||||||
|
|
||||||
async def prune(self):
|
async def prune(self):
|
||||||
# Clean up the database if it was set up as temporary
|
# Clean up the database if it was set up as temporary
|
||||||
if self.url.startswith("/"):
|
if self.url.startswith("/"):
|
||||||
LocalStorage.remove_all(self.url) # Remove the temporary directory and files inside
|
LocalStorage.remove_all(self.url) # Remove the temporary directory and files inside
|
||||||
|
|
||||||
|
def get_data_point_schema(self, model_type):
|
||||||
|
return copy_model(
|
||||||
|
model_type,
|
||||||
|
include_fields = {
|
||||||
|
"id": (str, ...),
|
||||||
|
},
|
||||||
|
exclude_fields = ["_metadata"],
|
||||||
|
)
|
||||||
|
|
@ -1,13 +0,0 @@
|
||||||
from typing import Generic, TypeVar
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
PayloadSchema = TypeVar("PayloadSchema", bound = BaseModel)
|
|
||||||
|
|
||||||
class DataPoint(BaseModel, Generic[PayloadSchema]):
|
|
||||||
id: str
|
|
||||||
payload: PayloadSchema
|
|
||||||
embed_field: str = "value"
|
|
||||||
|
|
||||||
def get_embeddable_data(self):
|
|
||||||
if hasattr(self.payload, self.embed_field):
|
|
||||||
return getattr(self.payload, self.embed_field)
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
from uuid import UUID
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
class ScoredResult(BaseModel):
|
class ScoredResult(BaseModel):
|
||||||
id: str
|
id: UUID
|
||||||
score: float # Lower score is better
|
score: float # Lower score is better
|
||||||
payload: Dict[str, Any]
|
payload: Dict[str, Any]
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,26 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from uuid import UUID
|
||||||
from pgvector.sqlalchemy import Vector
|
from pgvector.sqlalchemy import Vector
|
||||||
from typing import List, Optional, get_type_hints
|
from typing import List, Optional, get_type_hints
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
from sqlalchemy import JSON, Column, Table, select, delete
|
from sqlalchemy import JSON, Column, Table, select, delete
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||||
|
|
||||||
from .serialize_datetime import serialize_datetime
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
|
||||||
|
from .serialize_data import serialize_data
|
||||||
from ..models.ScoredResult import ScoredResult
|
from ..models.ScoredResult import ScoredResult
|
||||||
from ..vector_db_interface import VectorDBInterface, DataPoint
|
from ..vector_db_interface import VectorDBInterface
|
||||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
|
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
|
||||||
from ...relational.ModelBase import Base
|
from ...relational.ModelBase import Base
|
||||||
|
|
||||||
|
class IndexSchema(DataPoint):
|
||||||
|
text: str
|
||||||
|
|
||||||
|
_metadata: dict = {
|
||||||
|
"index_fields": ["text"]
|
||||||
|
}
|
||||||
|
|
||||||
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
|
|
||||||
|
|
@ -45,7 +54,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
vector_size = self.embedding_engine.get_vector_size()
|
vector_size = self.embedding_engine.get_vector_size()
|
||||||
|
|
||||||
if not await self.has_collection(collection_name):
|
if not await self.has_collection(collection_name):
|
||||||
|
|
||||||
class PGVectorDataPoint(Base):
|
class PGVectorDataPoint(Base):
|
||||||
__tablename__ = collection_name
|
__tablename__ = collection_name
|
||||||
__table_args__ = {"extend_existing": True}
|
__table_args__ = {"extend_existing": True}
|
||||||
|
|
@ -71,47 +79,58 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
async def create_data_points(
|
async def create_data_points(
|
||||||
self, collection_name: str, data_points: List[DataPoint]
|
self, collection_name: str, data_points: List[DataPoint]
|
||||||
):
|
):
|
||||||
async with self.get_async_session() as session:
|
if not await self.has_collection(collection_name):
|
||||||
if not await self.has_collection(collection_name):
|
await self.create_collection(
|
||||||
await self.create_collection(
|
collection_name = collection_name,
|
||||||
collection_name=collection_name,
|
payload_schema = type(data_points[0]),
|
||||||
payload_schema=type(data_points[0].payload),
|
|
||||||
)
|
|
||||||
|
|
||||||
data_vectors = await self.embed_data(
|
|
||||||
[data_point.get_embeddable_data() for data_point in data_points]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
vector_size = self.embedding_engine.get_vector_size()
|
data_vectors = await self.embed_data(
|
||||||
|
[data_point.get_embeddable_data() for data_point in data_points]
|
||||||
|
)
|
||||||
|
|
||||||
class PGVectorDataPoint(Base):
|
vector_size = self.embedding_engine.get_vector_size()
|
||||||
__tablename__ = collection_name
|
|
||||||
__table_args__ = {"extend_existing": True}
|
|
||||||
# PGVector requires one column to be the primary key
|
|
||||||
primary_key: Mapped[int] = mapped_column(
|
|
||||||
primary_key=True, autoincrement=True
|
|
||||||
)
|
|
||||||
id: Mapped[type(data_points[0].id)]
|
|
||||||
payload = Column(JSON)
|
|
||||||
vector = Column(Vector(vector_size))
|
|
||||||
|
|
||||||
def __init__(self, id, payload, vector):
|
class PGVectorDataPoint(Base):
|
||||||
self.id = id
|
__tablename__ = collection_name
|
||||||
self.payload = payload
|
__table_args__ = {"extend_existing": True}
|
||||||
self.vector = vector
|
# PGVector requires one column to be the primary key
|
||||||
|
primary_key: Mapped[int] = mapped_column(
|
||||||
|
primary_key=True, autoincrement=True
|
||||||
|
)
|
||||||
|
id: Mapped[type(data_points[0].id)]
|
||||||
|
payload = Column(JSON)
|
||||||
|
vector = Column(Vector(vector_size))
|
||||||
|
|
||||||
pgvector_data_points = [
|
def __init__(self, id, payload, vector):
|
||||||
PGVectorDataPoint(
|
self.id = id
|
||||||
id=data_point.id,
|
self.payload = payload
|
||||||
vector=data_vectors[data_index],
|
self.vector = vector
|
||||||
payload=serialize_datetime(data_point.payload.dict()),
|
|
||||||
)
|
|
||||||
for (data_index, data_point) in enumerate(data_points)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
pgvector_data_points = [
|
||||||
|
PGVectorDataPoint(
|
||||||
|
id = data_point.id,
|
||||||
|
vector = data_vectors[data_index],
|
||||||
|
payload = serialize_data(data_point.model_dump()),
|
||||||
|
)
|
||||||
|
for (data_index, data_point) in enumerate(data_points)
|
||||||
|
]
|
||||||
|
|
||||||
|
async with self.get_async_session() as session:
|
||||||
session.add_all(pgvector_data_points)
|
session.add_all(pgvector_data_points)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
|
async def create_vector_index(self, index_name: str, index_property_name: str):
|
||||||
|
await self.create_collection(f"{index_name}_{index_property_name}")
|
||||||
|
|
||||||
|
async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
|
||||||
|
await self.create_data_points(f"{index_name}_{index_property_name}", [
|
||||||
|
IndexSchema(
|
||||||
|
id = data_point.id,
|
||||||
|
text = data_point.get_embeddable_data(),
|
||||||
|
) for data_point in data_points
|
||||||
|
])
|
||||||
|
|
||||||
async def get_table(self, collection_name: str) -> Table:
|
async def get_table(self, collection_name: str) -> Table:
|
||||||
"""
|
"""
|
||||||
Dynamically loads a table using the given collection name
|
Dynamically loads a table using the given collection name
|
||||||
|
|
@ -126,18 +145,21 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
raise ValueError(f"Table '{collection_name}' not found.")
|
raise ValueError(f"Table '{collection_name}' not found.")
|
||||||
|
|
||||||
async def retrieve(self, collection_name: str, data_point_ids: List[str]):
|
async def retrieve(self, collection_name: str, data_point_ids: List[str]):
|
||||||
async with self.get_async_session() as session:
|
# Get PGVectorDataPoint Table from database
|
||||||
# Get PGVectorDataPoint Table from database
|
PGVectorDataPoint = await self.get_table(collection_name)
|
||||||
PGVectorDataPoint = await self.get_table(collection_name)
|
|
||||||
|
|
||||||
|
async with self.get_async_session() as session:
|
||||||
results = await session.execute(
|
results = await session.execute(
|
||||||
select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
|
select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
|
||||||
)
|
)
|
||||||
results = results.all()
|
results = results.all()
|
||||||
|
|
||||||
return [
|
return [
|
||||||
ScoredResult(id=result.id, payload=result.payload, score=0)
|
ScoredResult(
|
||||||
for result in results
|
id = UUID(result.id),
|
||||||
|
payload = result.payload,
|
||||||
|
score = 0
|
||||||
|
) for result in results
|
||||||
]
|
]
|
||||||
|
|
||||||
async def search(
|
async def search(
|
||||||
|
|
@ -154,11 +176,13 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
if query_text and not query_vector:
|
if query_text and not query_vector:
|
||||||
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
|
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
|
||||||
|
|
||||||
|
# Get PGVectorDataPoint Table from database
|
||||||
|
PGVectorDataPoint = await self.get_table(collection_name)
|
||||||
|
|
||||||
|
closest_items = []
|
||||||
|
|
||||||
# Use async session to connect to the database
|
# Use async session to connect to the database
|
||||||
async with self.get_async_session() as session:
|
async with self.get_async_session() as session:
|
||||||
# Get PGVectorDataPoint Table from database
|
|
||||||
PGVectorDataPoint = await self.get_table(collection_name)
|
|
||||||
|
|
||||||
# Find closest vectors to query_vector
|
# Find closest vectors to query_vector
|
||||||
closest_items = await session.execute(
|
closest_items = await session.execute(
|
||||||
select(
|
select(
|
||||||
|
|
@ -171,19 +195,21 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
.limit(limit)
|
.limit(limit)
|
||||||
)
|
)
|
||||||
|
|
||||||
vector_list = []
|
vector_list = []
|
||||||
# Extract distances and find min/max for normalization
|
|
||||||
for vector in closest_items:
|
|
||||||
# TODO: Add normalization of similarity score
|
|
||||||
vector_list.append(vector)
|
|
||||||
|
|
||||||
# Create and return ScoredResult objects
|
# Extract distances and find min/max for normalization
|
||||||
return [
|
for vector in closest_items:
|
||||||
ScoredResult(
|
# TODO: Add normalization of similarity score
|
||||||
id=str(row.id), payload=row.payload, score=row.similarity
|
vector_list.append(vector)
|
||||||
)
|
|
||||||
for row in vector_list
|
# Create and return ScoredResult objects
|
||||||
]
|
return [
|
||||||
|
ScoredResult(
|
||||||
|
id = UUID(str(row.id)),
|
||||||
|
payload = row.payload,
|
||||||
|
score = row.similarity
|
||||||
|
) for row in vector_list
|
||||||
|
]
|
||||||
|
|
||||||
async def batch_search(
|
async def batch_search(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,15 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
def serialize_datetime(data):
|
def serialize_data(data):
|
||||||
"""Recursively convert datetime objects in dictionaries/lists to ISO format."""
|
"""Recursively convert datetime objects in dictionaries/lists to ISO format."""
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
return {key: serialize_datetime(value) for key, value in data.items()}
|
return {key: serialize_data(value) for key, value in data.items()}
|
||||||
elif isinstance(data, list):
|
elif isinstance(data, list):
|
||||||
return [serialize_datetime(item) for item in data]
|
return [serialize_data(item) for item in data]
|
||||||
elif isinstance(data, datetime):
|
elif isinstance(data, datetime):
|
||||||
return data.isoformat() # Convert datetime to ISO 8601 string
|
return data.isoformat() # Convert datetime to ISO 8601 string
|
||||||
|
elif isinstance(data, UUID):
|
||||||
|
return str(data)
|
||||||
else:
|
else:
|
||||||
return data
|
return data
|
||||||
|
|
@ -1,12 +1,22 @@
|
||||||
import logging
|
import logging
|
||||||
|
from uuid import UUID
|
||||||
from typing import List, Dict, Optional
|
from typing import List, Dict, Optional
|
||||||
from qdrant_client import AsyncQdrantClient, models
|
from qdrant_client import AsyncQdrantClient, models
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
|
||||||
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from ..vector_db_interface import VectorDBInterface
|
from ..vector_db_interface import VectorDBInterface
|
||||||
from ..models.DataPoint import DataPoint
|
|
||||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
|
|
||||||
logger = logging.getLogger("QDrantAdapter")
|
logger = logging.getLogger("QDrantAdapter")
|
||||||
|
|
||||||
|
class IndexSchema(DataPoint):
|
||||||
|
text: str
|
||||||
|
|
||||||
|
_metadata: dict = {
|
||||||
|
"index_fields": ["text"]
|
||||||
|
}
|
||||||
|
|
||||||
# class CollectionConfig(BaseModel, extra = "forbid"):
|
# class CollectionConfig(BaseModel, extra = "forbid"):
|
||||||
# vector_config: Dict[str, models.VectorParams] = Field(..., description="Vectors configuration" )
|
# vector_config: Dict[str, models.VectorParams] = Field(..., description="Vectors configuration" )
|
||||||
# hnsw_config: Optional[models.HnswConfig] = Field(default = None, description="HNSW vector index configuration")
|
# hnsw_config: Optional[models.HnswConfig] = Field(default = None, description="HNSW vector index configuration")
|
||||||
|
|
@ -75,20 +85,19 @@ class QDrantAdapter(VectorDBInterface):
|
||||||
):
|
):
|
||||||
client = self.get_qdrant_client()
|
client = self.get_qdrant_client()
|
||||||
|
|
||||||
result = await client.create_collection(
|
if not await client.collection_exists(collection_name):
|
||||||
collection_name = collection_name,
|
await client.create_collection(
|
||||||
vectors_config = {
|
collection_name = collection_name,
|
||||||
"text": models.VectorParams(
|
vectors_config = {
|
||||||
size = self.embedding_engine.get_vector_size(),
|
"text": models.VectorParams(
|
||||||
distance = "Cosine"
|
size = self.embedding_engine.get_vector_size(),
|
||||||
)
|
distance = "Cosine"
|
||||||
}
|
)
|
||||||
)
|
}
|
||||||
|
)
|
||||||
|
|
||||||
await client.close()
|
await client.close()
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
|
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
|
||||||
client = self.get_qdrant_client()
|
client = self.get_qdrant_client()
|
||||||
|
|
||||||
|
|
@ -96,8 +105,8 @@ class QDrantAdapter(VectorDBInterface):
|
||||||
|
|
||||||
def convert_to_qdrant_point(data_point: DataPoint):
|
def convert_to_qdrant_point(data_point: DataPoint):
|
||||||
return models.PointStruct(
|
return models.PointStruct(
|
||||||
id = data_point.id,
|
id = str(data_point.id),
|
||||||
payload = data_point.payload.dict(),
|
payload = data_point.model_dump(),
|
||||||
vector = {
|
vector = {
|
||||||
"text": data_vectors[data_points.index(data_point)]
|
"text": data_vectors[data_points.index(data_point)]
|
||||||
}
|
}
|
||||||
|
|
@ -116,6 +125,17 @@ class QDrantAdapter(VectorDBInterface):
|
||||||
finally:
|
finally:
|
||||||
await client.close()
|
await client.close()
|
||||||
|
|
||||||
|
async def create_vector_index(self, index_name: str, index_property_name: str):
|
||||||
|
await self.create_collection(f"{index_name}_{index_property_name}")
|
||||||
|
|
||||||
|
async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
|
||||||
|
await self.create_data_points(f"{index_name}_{index_property_name}", [
|
||||||
|
IndexSchema(
|
||||||
|
id = data_point.id,
|
||||||
|
text = getattr(data_point, data_point._metadata["index_fields"][0]),
|
||||||
|
) for data_point in data_points
|
||||||
|
])
|
||||||
|
|
||||||
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
||||||
client = self.get_qdrant_client()
|
client = self.get_qdrant_client()
|
||||||
results = await client.retrieve(collection_name, data_point_ids, with_payload = True)
|
results = await client.retrieve(collection_name, data_point_ids, with_payload = True)
|
||||||
|
|
@ -135,7 +155,7 @@ class QDrantAdapter(VectorDBInterface):
|
||||||
|
|
||||||
client = self.get_qdrant_client()
|
client = self.get_qdrant_client()
|
||||||
|
|
||||||
result = await client.search(
|
results = await client.search(
|
||||||
collection_name = collection_name,
|
collection_name = collection_name,
|
||||||
query_vector = models.NamedVector(
|
query_vector = models.NamedVector(
|
||||||
name = "text",
|
name = "text",
|
||||||
|
|
@ -147,7 +167,16 @@ class QDrantAdapter(VectorDBInterface):
|
||||||
|
|
||||||
await client.close()
|
await client.close()
|
||||||
|
|
||||||
return result
|
return [
|
||||||
|
ScoredResult(
|
||||||
|
id = UUID(result.id),
|
||||||
|
payload = {
|
||||||
|
**result.payload,
|
||||||
|
"id": UUID(result.id),
|
||||||
|
},
|
||||||
|
score = 1 - result.score,
|
||||||
|
) for result in results
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
async def batch_search(self, collection_name: str, query_texts: List[str], limit: int = None, with_vectors: bool = False):
|
async def batch_search(self, collection_name: str, query_texts: List[str], limit: int = None, with_vectors: bool = False):
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from typing import List, Protocol, Optional
|
from typing import List, Protocol, Optional
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from .models.DataPoint import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from .models.PayloadSchema import PayloadSchema
|
from .models.PayloadSchema import PayloadSchema
|
||||||
|
|
||||||
class VectorDBInterface(Protocol):
|
class VectorDBInterface(Protocol):
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,22 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from ..vector_db_interface import VectorDBInterface
|
from ..vector_db_interface import VectorDBInterface
|
||||||
from ..models.DataPoint import DataPoint
|
|
||||||
from ..models.ScoredResult import ScoredResult
|
from ..models.ScoredResult import ScoredResult
|
||||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
|
|
||||||
logger = logging.getLogger("WeaviateAdapter")
|
logger = logging.getLogger("WeaviateAdapter")
|
||||||
|
|
||||||
|
class IndexSchema(DataPoint):
|
||||||
|
text: str
|
||||||
|
|
||||||
|
_metadata: dict = {
|
||||||
|
"index_fields": ["text"]
|
||||||
|
}
|
||||||
|
|
||||||
class WeaviateAdapter(VectorDBInterface):
|
class WeaviateAdapter(VectorDBInterface):
|
||||||
name = "Weaviate"
|
name = "Weaviate"
|
||||||
url: str
|
url: str
|
||||||
|
|
@ -48,18 +57,21 @@ class WeaviateAdapter(VectorDBInterface):
|
||||||
|
|
||||||
future = asyncio.Future()
|
future = asyncio.Future()
|
||||||
|
|
||||||
future.set_result(
|
if not self.client.collections.exists(collection_name):
|
||||||
self.client.collections.create(
|
future.set_result(
|
||||||
name=collection_name,
|
self.client.collections.create(
|
||||||
properties=[
|
name = collection_name,
|
||||||
wvcc.Property(
|
properties = [
|
||||||
name="text",
|
wvcc.Property(
|
||||||
data_type=wvcc.DataType.TEXT,
|
name = "text",
|
||||||
skip_vectorization=True
|
data_type = wvcc.DataType.TEXT,
|
||||||
)
|
skip_vectorization = True
|
||||||
]
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
else:
|
||||||
|
future.set_result(self.get_collection(collection_name))
|
||||||
|
|
||||||
return await future
|
return await future
|
||||||
|
|
||||||
|
|
@ -70,36 +82,60 @@ class WeaviateAdapter(VectorDBInterface):
|
||||||
from weaviate.classes.data import DataObject
|
from weaviate.classes.data import DataObject
|
||||||
|
|
||||||
data_vectors = await self.embed_data(
|
data_vectors = await self.embed_data(
|
||||||
list(map(lambda data_point: data_point.get_embeddable_data(), data_points)))
|
[data_point.get_embeddable_data() for data_point in data_points]
|
||||||
|
)
|
||||||
|
|
||||||
def convert_to_weaviate_data_points(data_point: DataPoint):
|
def convert_to_weaviate_data_points(data_point: DataPoint):
|
||||||
vector = data_vectors[data_points.index(data_point)]
|
vector = data_vectors[data_points.index(data_point)]
|
||||||
|
properties = data_point.model_dump()
|
||||||
|
|
||||||
|
if "id" in properties:
|
||||||
|
properties["uuid"] = str(data_point.id)
|
||||||
|
del properties["id"]
|
||||||
|
|
||||||
return DataObject(
|
return DataObject(
|
||||||
uuid = data_point.id,
|
uuid = data_point.id,
|
||||||
properties = data_point.payload.dict(),
|
properties = properties,
|
||||||
vector = vector
|
vector = vector
|
||||||
)
|
)
|
||||||
|
|
||||||
data_points = list(map(convert_to_weaviate_data_points, data_points))
|
data_points = [convert_to_weaviate_data_points(data_point) for data_point in data_points]
|
||||||
|
|
||||||
collection = self.get_collection(collection_name)
|
collection = self.get_collection(collection_name)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if len(data_points) > 1:
|
if len(data_points) > 1:
|
||||||
return collection.data.insert_many(data_points)
|
with collection.batch.dynamic() as batch:
|
||||||
|
for data_point in data_points:
|
||||||
|
batch.add_object(
|
||||||
|
uuid = data_point.uuid,
|
||||||
|
vector = data_point.vector,
|
||||||
|
properties = data_point.properties,
|
||||||
|
references = data_point.references,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return collection.data.insert(data_points[0])
|
data_point: DataObject = data_points[0]
|
||||||
# with collection.batch.dynamic() as batch:
|
return collection.data.update(
|
||||||
# for point in data_points:
|
uuid = data_point.uuid,
|
||||||
# batch.add_object(
|
vector = data_point.vector,
|
||||||
# uuid = point.uuid,
|
properties = data_point.properties,
|
||||||
# properties = point.properties,
|
references = data_point.references,
|
||||||
# vector = point.vector
|
)
|
||||||
# )
|
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
logger.error("Error creating data points: %s", str(error))
|
logger.error("Error creating data points: %s", str(error))
|
||||||
raise error
|
raise error
|
||||||
|
|
||||||
|
async def create_vector_index(self, index_name: str, index_property_name: str):
|
||||||
|
await self.create_collection(f"{index_name}_{index_property_name}")
|
||||||
|
|
||||||
|
async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
|
||||||
|
await self.create_data_points(f"{index_name}_{index_property_name}", [
|
||||||
|
IndexSchema(
|
||||||
|
id = data_point.id,
|
||||||
|
text = data_point.get_embeddable_data(),
|
||||||
|
) for data_point in data_points
|
||||||
|
])
|
||||||
|
|
||||||
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
||||||
from weaviate.classes.query import Filter
|
from weaviate.classes.query import Filter
|
||||||
future = asyncio.Future()
|
future = asyncio.Future()
|
||||||
|
|
@ -143,9 +179,9 @@ class WeaviateAdapter(VectorDBInterface):
|
||||||
|
|
||||||
return [
|
return [
|
||||||
ScoredResult(
|
ScoredResult(
|
||||||
id=str(result.uuid),
|
id = UUID(str(result.uuid)),
|
||||||
payload=result.properties,
|
payload = result.properties,
|
||||||
score=float(result.metadata.score)
|
score = 1 - float(result.metadata.score)
|
||||||
) for result in search_result.objects
|
) for result in search_result.objects
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
1
cognee/infrastructure/engine/__init__.py
Normal file
1
cognee/infrastructure/engine/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
from .models.DataPoint import DataPoint
|
||||||
|
|
@ -0,0 +1,72 @@
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
from cognee.modules.graph.utils import get_graph_from_model, get_model_instance_from_graph
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
class CarTypeName(Enum):
|
||||||
|
Pickup = "Pickup"
|
||||||
|
Sedan = "Sedan"
|
||||||
|
SUV = "SUV"
|
||||||
|
Coupe = "Coupe"
|
||||||
|
Convertible = "Convertible"
|
||||||
|
Hatchback = "Hatchback"
|
||||||
|
Wagon = "Wagon"
|
||||||
|
Minivan = "Minivan"
|
||||||
|
Van = "Van"
|
||||||
|
|
||||||
|
class CarType(DataPoint):
|
||||||
|
id: str
|
||||||
|
name: CarTypeName
|
||||||
|
_metadata: dict = dict(index_fields = ["name"])
|
||||||
|
|
||||||
|
class Car(DataPoint):
|
||||||
|
id: str
|
||||||
|
brand: str
|
||||||
|
model: str
|
||||||
|
year: int
|
||||||
|
color: str
|
||||||
|
is_type: CarType
|
||||||
|
|
||||||
|
class Person(DataPoint):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
age: int
|
||||||
|
owns_car: list[Car]
|
||||||
|
driving_licence: Optional[dict]
|
||||||
|
_metadata: dict = dict(index_fields = ["name"])
|
||||||
|
|
||||||
|
boris = Person(
|
||||||
|
id = "boris",
|
||||||
|
name = "Boris",
|
||||||
|
age = 30,
|
||||||
|
owns_car = [
|
||||||
|
Car(
|
||||||
|
id = "car1",
|
||||||
|
brand = "Toyota",
|
||||||
|
model = "Camry",
|
||||||
|
year = 2020,
|
||||||
|
color = "Blue",
|
||||||
|
is_type = CarType(id = "sedan", name = CarTypeName.Sedan),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
driving_licence = {
|
||||||
|
"issued_by": "PU Vrsac",
|
||||||
|
"issued_on": "2025-11-06",
|
||||||
|
"number": "1234567890",
|
||||||
|
"expires_on": "2025-11-06",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
nodes, edges = get_graph_from_model(boris)
|
||||||
|
|
||||||
|
print(nodes)
|
||||||
|
print(edges)
|
||||||
|
|
||||||
|
person_data = nodes[len(nodes) - 1]
|
||||||
|
|
||||||
|
parsed_person = get_model_instance_from_graph(nodes, edges, 'boris')
|
||||||
|
|
||||||
|
print(parsed_person)
|
||||||
24
cognee/infrastructure/engine/models/DataPoint.py
Normal file
24
cognee/infrastructure/engine/models/DataPoint.py
Normal file
|
|
@ -0,0 +1,24 @@
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
from uuid import UUID, uuid4
|
||||||
|
from typing import Optional
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
class MetaData(TypedDict):
|
||||||
|
index_fields: list[str]
|
||||||
|
|
||||||
|
class DataPoint(BaseModel):
|
||||||
|
id: UUID = Field(default_factory = uuid4)
|
||||||
|
updated_at: Optional[datetime] = datetime.now(timezone.utc)
|
||||||
|
_metadata: Optional[MetaData] = {
|
||||||
|
"index_fields": []
|
||||||
|
}
|
||||||
|
|
||||||
|
# class Config:
|
||||||
|
# underscore_attrs_are_private = True
|
||||||
|
|
||||||
|
def get_embeddable_data(self):
|
||||||
|
if self._metadata and len(self._metadata["index_fields"]) > 0 \
|
||||||
|
and hasattr(self, self._metadata["index_fields"][0]):
|
||||||
|
|
||||||
|
return getattr(self, self._metadata["index_fields"][0])
|
||||||
|
|
@ -1,18 +1,18 @@
|
||||||
from uuid import UUID, uuid5, NAMESPACE_OID
|
from uuid import uuid5, NAMESPACE_OID
|
||||||
|
|
||||||
from cognee.modules.chunking import DocumentChunk
|
from .models.DocumentChunk import DocumentChunk
|
||||||
from cognee.tasks.chunking import chunk_by_paragraph
|
from cognee.tasks.chunks import chunk_by_paragraph
|
||||||
|
|
||||||
class TextChunker():
|
class TextChunker():
|
||||||
id: UUID
|
document = None
|
||||||
max_chunk_size: int
|
max_chunk_size: int
|
||||||
|
|
||||||
chunk_index = 0
|
chunk_index = 0
|
||||||
chunk_size = 0
|
chunk_size = 0
|
||||||
paragraph_chunks = []
|
paragraph_chunks = []
|
||||||
|
|
||||||
def __init__(self, id: UUID, get_text: callable, chunk_size: int = 1024):
|
def __init__(self, document, get_text: callable, chunk_size: int = 1024):
|
||||||
self.id = id
|
self.document = document
|
||||||
self.max_chunk_size = chunk_size
|
self.max_chunk_size = chunk_size
|
||||||
self.get_text = get_text
|
self.get_text = get_text
|
||||||
|
|
||||||
|
|
@ -29,10 +29,10 @@ class TextChunker():
|
||||||
else:
|
else:
|
||||||
if len(self.paragraph_chunks) == 0:
|
if len(self.paragraph_chunks) == 0:
|
||||||
yield DocumentChunk(
|
yield DocumentChunk(
|
||||||
|
id = chunk_data["chunk_id"],
|
||||||
text = chunk_data["text"],
|
text = chunk_data["text"],
|
||||||
word_count = chunk_data["word_count"],
|
word_count = chunk_data["word_count"],
|
||||||
document_id = str(self.id),
|
is_part_of = self.document,
|
||||||
chunk_id = str(chunk_data["chunk_id"]),
|
|
||||||
chunk_index = self.chunk_index,
|
chunk_index = self.chunk_index,
|
||||||
cut_type = chunk_data["cut_type"],
|
cut_type = chunk_data["cut_type"],
|
||||||
)
|
)
|
||||||
|
|
@ -40,25 +40,31 @@ class TextChunker():
|
||||||
self.chunk_size = 0
|
self.chunk_size = 0
|
||||||
else:
|
else:
|
||||||
chunk_text = " ".join(chunk["text"] for chunk in self.paragraph_chunks)
|
chunk_text = " ".join(chunk["text"] for chunk in self.paragraph_chunks)
|
||||||
yield DocumentChunk(
|
try:
|
||||||
text = chunk_text,
|
yield DocumentChunk(
|
||||||
word_count = self.chunk_size,
|
id = uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
|
||||||
document_id = str(self.id),
|
text = chunk_text,
|
||||||
chunk_id = str(uuid5(NAMESPACE_OID, f"{str(self.id)}-{self.chunk_index}")),
|
word_count = self.chunk_size,
|
||||||
chunk_index = self.chunk_index,
|
is_part_of = self.document,
|
||||||
cut_type = self.paragraph_chunks[len(self.paragraph_chunks) - 1]["cut_type"],
|
chunk_index = self.chunk_index,
|
||||||
)
|
cut_type = self.paragraph_chunks[len(self.paragraph_chunks) - 1]["cut_type"],
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
self.paragraph_chunks = [chunk_data]
|
self.paragraph_chunks = [chunk_data]
|
||||||
self.chunk_size = chunk_data["word_count"]
|
self.chunk_size = chunk_data["word_count"]
|
||||||
|
|
||||||
self.chunk_index += 1
|
self.chunk_index += 1
|
||||||
|
|
||||||
if len(self.paragraph_chunks) > 0:
|
if len(self.paragraph_chunks) > 0:
|
||||||
yield DocumentChunk(
|
try:
|
||||||
text = " ".join(chunk["text"] for chunk in self.paragraph_chunks),
|
yield DocumentChunk(
|
||||||
word_count = self.chunk_size,
|
id = uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
|
||||||
document_id = str(self.id),
|
text = " ".join(chunk["text"] for chunk in self.paragraph_chunks),
|
||||||
chunk_id = str(uuid5(NAMESPACE_OID, f"{str(self.id)}-{self.chunk_index}")),
|
word_count = self.chunk_size,
|
||||||
chunk_index = self.chunk_index,
|
is_part_of = self.document,
|
||||||
cut_type = self.paragraph_chunks[len(self.paragraph_chunks) - 1]["cut_type"],
|
chunk_index = self.chunk_index,
|
||||||
)
|
cut_type = self.paragraph_chunks[len(self.paragraph_chunks) - 1]["cut_type"],
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
|
|
||||||
|
|
@ -1,2 +0,0 @@
|
||||||
from .models.DocumentChunk import DocumentChunk
|
|
||||||
from .TextChunker import TextChunker
|
|
||||||
|
|
@ -1,9 +1,14 @@
|
||||||
from pydantic import BaseModel
|
from typing import Optional
|
||||||
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
from cognee.modules.data.processing.document_types import Document
|
||||||
|
|
||||||
class DocumentChunk(BaseModel):
|
class DocumentChunk(DataPoint):
|
||||||
text: str
|
text: str
|
||||||
word_count: int
|
word_count: int
|
||||||
document_id: str
|
|
||||||
chunk_id: str
|
|
||||||
chunk_index: int
|
chunk_index: int
|
||||||
cut_type: str
|
cut_type: str
|
||||||
|
is_part_of: Document
|
||||||
|
|
||||||
|
_metadata: Optional[dict] = {
|
||||||
|
"index_fields": ["text"],
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
from .knowledge_graph.extract_content_graph import extract_content_graph
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
from .extract_content_graph import extract_content_graph
|
||||||
|
|
@ -1,36 +1,36 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
async def detect_language(text: str):
|
||||||
async def detect_language(data:str):
|
|
||||||
"""
|
"""
|
||||||
Detect the language of the given text and return its ISO 639-1 language code.
|
Detect the language of the given text and return its ISO 639-1 language code.
|
||||||
If the detected language is Croatian ('hr'), it maps to Serbian ('sr').
|
If the detected language is Croatian ("hr"), it maps to Serbian ("sr").
|
||||||
The text is trimmed to the first 100 characters for efficient processing.
|
The text is trimmed to the first 100 characters for efficient processing.
|
||||||
Parameters:
|
Parameters:
|
||||||
text (str): The text for language detection.
|
text (str): The text for language detection.
|
||||||
Returns:
|
Returns:
|
||||||
str: The ISO 639-1 language code of the detected language, or 'None' in case of an error.
|
str: The ISO 639-1 language code of the detected language, or "None" in case of an error.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Trim the text to the first 100 characters
|
|
||||||
from langdetect import detect, LangDetectException
|
from langdetect import detect, LangDetectException
|
||||||
trimmed_text = data[:100]
|
# Trim the text to the first 100 characters
|
||||||
|
trimmed_text = text[:100]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Detect the language using langdetect
|
# Detect the language using langdetect
|
||||||
detected_lang_iso639_1 = detect(trimmed_text)
|
detected_lang_iso639_1 = detect(trimmed_text)
|
||||||
logging.info(f"Detected ISO 639-1 code: {detected_lang_iso639_1}")
|
|
||||||
|
|
||||||
# Special case: map 'hr' (Croatian) to 'sr' (Serbian ISO 639-2)
|
# Special case: map "hr" (Croatian) to "sr" (Serbian ISO 639-2)
|
||||||
if detected_lang_iso639_1 == 'hr':
|
if detected_lang_iso639_1 == "hr":
|
||||||
yield 'sr'
|
return "sr"
|
||||||
yield detected_lang_iso639_1
|
|
||||||
|
return detected_lang_iso639_1
|
||||||
|
|
||||||
except LangDetectException as e:
|
except LangDetectException as e:
|
||||||
logging.error(f"Language detection error: {e}")
|
logger.error(f"Language detection error: {e}")
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Unexpected error: {e}")
|
|
||||||
|
|
||||||
yield None
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error: {e}")
|
||||||
|
|
||||||
|
return None
|
||||||
41
cognee/modules/data/operations/translate_text.py
Normal file
41
cognee/modules/data/operations/translate_text.py
Normal file
|
|
@ -0,0 +1,41 @@
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
async def translate_text(text, source_language: str = "sr", target_language: str = "en", region_name = "eu-west-1"):
|
||||||
|
"""
|
||||||
|
Translate text from source language to target language using AWS Translate.
|
||||||
|
Parameters:
|
||||||
|
text (str): The text to be translated.
|
||||||
|
source_language (str): The source language code (e.g., "sr" for Serbian). ISO 639-2 Code https://www.loc.gov/standards/iso639-2/php/code_list.php
|
||||||
|
target_language (str): The target language code (e.g., "en" for English). ISO 639-2 Code https://www.loc.gov/standards/iso639-2/php/code_list.php
|
||||||
|
region_name (str): AWS region name.
|
||||||
|
Returns:
|
||||||
|
str: Translated text or an error message.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
from botocore.exceptions import BotoCoreError, ClientError
|
||||||
|
|
||||||
|
if not text:
|
||||||
|
raise ValueError("No text to translate.")
|
||||||
|
|
||||||
|
if not source_language or not target_language:
|
||||||
|
raise ValueError("Source and target language codes are required.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
translate = boto3.client(service_name = "translate", region_name = region_name, use_ssl = True)
|
||||||
|
result = translate.translate_text(
|
||||||
|
Text = text,
|
||||||
|
SourceLanguageCode = source_language,
|
||||||
|
TargetLanguageCode = target_language,
|
||||||
|
)
|
||||||
|
yield result.get("TranslatedText", "No translation found.")
|
||||||
|
|
||||||
|
except BotoCoreError as e:
|
||||||
|
logger.error(f"BotoCoreError occurred: {e}")
|
||||||
|
yield None
|
||||||
|
|
||||||
|
except ClientError as e:
|
||||||
|
logger.error(f"ClientError occurred: {e}")
|
||||||
|
yield None
|
||||||
|
|
@ -1,34 +1,15 @@
|
||||||
from uuid import UUID, uuid5, NAMESPACE_OID
|
|
||||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||||
from cognee.modules.chunking.TextChunker import TextChunker
|
from cognee.modules.chunking.TextChunker import TextChunker
|
||||||
from .Document import Document
|
from .Document import Document
|
||||||
|
|
||||||
class AudioDocument(Document):
|
class AudioDocument(Document):
|
||||||
type: str = "audio"
|
type: str = "audio"
|
||||||
title: str
|
|
||||||
raw_data_location: str
|
|
||||||
chunking_strategy: str
|
|
||||||
|
|
||||||
def __init__(self, id: UUID, title: str, raw_data_location: str, chunking_strategy:str="paragraph"):
|
|
||||||
self.id = id or uuid5(NAMESPACE_OID, title)
|
|
||||||
self.title = title
|
|
||||||
self.raw_data_location = raw_data_location
|
|
||||||
self.chunking_strategy = chunking_strategy
|
|
||||||
|
|
||||||
def read(self, chunk_size: int):
|
def read(self, chunk_size: int):
|
||||||
# Transcribe the audio file
|
# Transcribe the audio file
|
||||||
result = get_llm_client().create_transcript(self.raw_data_location)
|
result = get_llm_client().create_transcript(self.raw_data_location)
|
||||||
text = result.text
|
text = result.text
|
||||||
|
|
||||||
chunker = TextChunker(self.id, chunk_size = chunk_size, get_text = lambda: text)
|
chunker = TextChunker(self, chunk_size = chunk_size, get_text = lambda: text)
|
||||||
|
|
||||||
yield from chunker.read()
|
yield from chunker.read()
|
||||||
|
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
|
||||||
return dict(
|
|
||||||
id=str(self.id),
|
|
||||||
type=self.type,
|
|
||||||
title=self.title,
|
|
||||||
raw_data_location=self.raw_data_location,
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,8 @@
|
||||||
from uuid import UUID
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from typing import Protocol
|
|
||||||
|
|
||||||
class Document(Protocol):
|
class Document(DataPoint):
|
||||||
id: UUID
|
|
||||||
type: str
|
type: str
|
||||||
title: str
|
name: str
|
||||||
raw_data_location: str
|
raw_data_location: str
|
||||||
|
|
||||||
def read(self, chunk_size: int) -> str:
|
def read(self, chunk_size: int) -> str:
|
||||||
|
|
|
||||||
|
|
@ -1,33 +1,15 @@
|
||||||
from uuid import UUID, uuid5, NAMESPACE_OID
|
|
||||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||||
from cognee.modules.chunking.TextChunker import TextChunker
|
from cognee.modules.chunking.TextChunker import TextChunker
|
||||||
from .Document import Document
|
from .Document import Document
|
||||||
|
|
||||||
|
|
||||||
class ImageDocument(Document):
|
class ImageDocument(Document):
|
||||||
type: str = "image"
|
type: str = "image"
|
||||||
title: str
|
|
||||||
raw_data_location: str
|
|
||||||
|
|
||||||
def __init__(self, id: UUID, title: str, raw_data_location: str):
|
|
||||||
self.id = id or uuid5(NAMESPACE_OID, title)
|
|
||||||
self.title = title
|
|
||||||
self.raw_data_location = raw_data_location
|
|
||||||
|
|
||||||
def read(self, chunk_size: int):
|
def read(self, chunk_size: int):
|
||||||
# Transcribe the image file
|
# Transcribe the image file
|
||||||
result = get_llm_client().transcribe_image(self.raw_data_location)
|
result = get_llm_client().transcribe_image(self.raw_data_location)
|
||||||
text = result.choices[0].message.content
|
text = result.choices[0].message.content
|
||||||
|
|
||||||
chunker = TextChunker(self.id, chunk_size = chunk_size, get_text = lambda: text)
|
chunker = TextChunker(self, chunk_size = chunk_size, get_text = lambda: text)
|
||||||
|
|
||||||
yield from chunker.read()
|
yield from chunker.read()
|
||||||
|
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
|
||||||
return dict(
|
|
||||||
id=str(self.id),
|
|
||||||
type=self.type,
|
|
||||||
title=self.title,
|
|
||||||
raw_data_location=self.raw_data_location,
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,11 @@
|
||||||
from uuid import UUID, uuid5, NAMESPACE_OID
|
|
||||||
from pypdf import PdfReader
|
from pypdf import PdfReader
|
||||||
from cognee.modules.chunking.TextChunker import TextChunker
|
from cognee.modules.chunking.TextChunker import TextChunker
|
||||||
from .Document import Document
|
from .Document import Document
|
||||||
|
|
||||||
class PdfDocument(Document):
|
class PdfDocument(Document):
|
||||||
type: str = "pdf"
|
type: str = "pdf"
|
||||||
title: str
|
|
||||||
raw_data_location: str
|
|
||||||
|
|
||||||
def __init__(self, id: UUID, title: str, raw_data_location: str):
|
def read(self, chunk_size: int):
|
||||||
self.id = id or uuid5(NAMESPACE_OID, title)
|
|
||||||
self.title = title
|
|
||||||
self.raw_data_location = raw_data_location
|
|
||||||
|
|
||||||
def read(self, chunk_size: int) -> PdfReader:
|
|
||||||
file = PdfReader(self.raw_data_location)
|
file = PdfReader(self.raw_data_location)
|
||||||
|
|
||||||
def get_text():
|
def get_text():
|
||||||
|
|
@ -21,16 +13,8 @@ class PdfDocument(Document):
|
||||||
page_text = page.extract_text()
|
page_text = page.extract_text()
|
||||||
yield page_text
|
yield page_text
|
||||||
|
|
||||||
chunker = TextChunker(self.id, chunk_size = chunk_size, get_text = get_text)
|
chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text)
|
||||||
|
|
||||||
yield from chunker.read()
|
yield from chunker.read()
|
||||||
|
|
||||||
file.stream.close()
|
file.stream.close()
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
|
||||||
return dict(
|
|
||||||
id = str(self.id),
|
|
||||||
type = self.type,
|
|
||||||
title = self.title,
|
|
||||||
raw_data_location = self.raw_data_location,
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,8 @@
|
||||||
from uuid import UUID, uuid5, NAMESPACE_OID
|
|
||||||
from cognee.modules.chunking.TextChunker import TextChunker
|
from cognee.modules.chunking.TextChunker import TextChunker
|
||||||
from .Document import Document
|
from .Document import Document
|
||||||
|
|
||||||
class TextDocument(Document):
|
class TextDocument(Document):
|
||||||
type: str = "text"
|
type: str = "text"
|
||||||
title: str
|
|
||||||
raw_data_location: str
|
|
||||||
|
|
||||||
def __init__(self, id: UUID, title: str, raw_data_location: str):
|
|
||||||
self.id = id or uuid5(NAMESPACE_OID, title)
|
|
||||||
self.title = title
|
|
||||||
self.raw_data_location = raw_data_location
|
|
||||||
|
|
||||||
def read(self, chunk_size: int):
|
def read(self, chunk_size: int):
|
||||||
def get_text():
|
def get_text():
|
||||||
|
|
@ -23,16 +15,6 @@ class TextDocument(Document):
|
||||||
|
|
||||||
yield text
|
yield text
|
||||||
|
|
||||||
|
chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text)
|
||||||
chunker = TextChunker(self.id,chunk_size = chunk_size, get_text = get_text)
|
|
||||||
|
|
||||||
yield from chunker.read()
|
yield from chunker.read()
|
||||||
|
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
|
||||||
return dict(
|
|
||||||
id = str(self.id),
|
|
||||||
type = self.type,
|
|
||||||
title = self.title,
|
|
||||||
raw_data_location = self.raw_data_location,
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
from .Document import Document
|
||||||
from .PdfDocument import PdfDocument
|
from .PdfDocument import PdfDocument
|
||||||
from .TextDocument import TextDocument
|
from .TextDocument import TextDocument
|
||||||
from .ImageDocument import ImageDocument
|
from .ImageDocument import ImageDocument
|
||||||
|
|
|
||||||
12
cognee/modules/engine/models/Entity.py
Normal file
12
cognee/modules/engine/models/Entity.py
Normal file
|
|
@ -0,0 +1,12 @@
|
||||||
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
||||||
|
from .EntityType import EntityType
|
||||||
|
|
||||||
|
class Entity(DataPoint):
|
||||||
|
name: str
|
||||||
|
is_a: EntityType
|
||||||
|
description: str
|
||||||
|
mentioned_in: DocumentChunk
|
||||||
|
_metadata: dict = {
|
||||||
|
"index_fields": ["name"],
|
||||||
|
}
|
||||||
11
cognee/modules/engine/models/EntityType.py
Normal file
11
cognee/modules/engine/models/EntityType.py
Normal file
|
|
@ -0,0 +1,11 @@
|
||||||
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
||||||
|
|
||||||
|
class EntityType(DataPoint):
|
||||||
|
name: str
|
||||||
|
type: str
|
||||||
|
description: str
|
||||||
|
exists_in: DocumentChunk
|
||||||
|
_metadata: dict = {
|
||||||
|
"index_fields": ["name"],
|
||||||
|
}
|
||||||
2
cognee/modules/engine/models/__init__.py
Normal file
2
cognee/modules/engine/models/__init__.py
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
from .Entity import Entity
|
||||||
|
from .EntityType import EntityType
|
||||||
3
cognee/modules/engine/utils/__init__.py
Normal file
3
cognee/modules/engine/utils/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
from .generate_node_id import generate_node_id
|
||||||
|
from .generate_node_name import generate_node_name
|
||||||
|
from .generate_edge_name import generate_edge_name
|
||||||
2
cognee/modules/engine/utils/generate_edge_name.py
Normal file
2
cognee/modules/engine/utils/generate_edge_name.py
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
def generate_edge_name(name: str) -> str:
|
||||||
|
return name.lower().replace(" ", "_").replace("'", "")
|
||||||
4
cognee/modules/engine/utils/generate_node_id.py
Normal file
4
cognee/modules/engine/utils/generate_node_id.py
Normal file
|
|
@ -0,0 +1,4 @@
|
||||||
|
from uuid import NAMESPACE_OID, uuid5
|
||||||
|
|
||||||
|
def generate_node_id(node_id: str) -> str:
|
||||||
|
return uuid5(NAMESPACE_OID, node_id.lower().replace(" ", "_").replace("'", ""))
|
||||||
2
cognee/modules/engine/utils/generate_node_name.py
Normal file
2
cognee/modules/engine/utils/generate_node_name.py
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
def generate_node_name(name: str) -> str:
|
||||||
|
return name.lower().replace("'", "")
|
||||||
|
|
@ -1,5 +0,0 @@
|
||||||
def generate_node_name(name: str) -> str:
|
|
||||||
return name.lower().replace(" ", "_").replace("'", "")
|
|
||||||
|
|
||||||
def generate_node_id(node_id: str) -> str:
|
|
||||||
return node_id.lower().replace(" ", "_").replace("'", "")
|
|
||||||
2
cognee/modules/graph/utils/__init__.py
Normal file
2
cognee/modules/graph/utils/__init__.py
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
from .get_graph_from_model import get_graph_from_model
|
||||||
|
from .get_model_instance_from_graph import get_model_instance_from_graph
|
||||||
107
cognee/modules/graph/utils/get_graph_from_model.py
Normal file
107
cognee/modules/graph/utils/get_graph_from_model.py
Normal file
|
|
@ -0,0 +1,107 @@
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
from cognee.modules.storage.utils import copy_model
|
||||||
|
|
||||||
|
def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = {}, added_edges = {}):
|
||||||
|
nodes = []
|
||||||
|
edges = []
|
||||||
|
|
||||||
|
data_point_properties = {}
|
||||||
|
excluded_properties = set()
|
||||||
|
|
||||||
|
for field_name, field_value in data_point:
|
||||||
|
if field_name == "_metadata":
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(field_value, DataPoint):
|
||||||
|
excluded_properties.add(field_name)
|
||||||
|
|
||||||
|
property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges)
|
||||||
|
|
||||||
|
for node in property_nodes:
|
||||||
|
if str(node.id) not in added_nodes:
|
||||||
|
nodes.append(node)
|
||||||
|
added_nodes[str(node.id)] = True
|
||||||
|
|
||||||
|
for edge in property_edges:
|
||||||
|
edge_key = str(edge[0]) + str(edge[1]) + edge[2]
|
||||||
|
|
||||||
|
if str(edge_key) not in added_edges:
|
||||||
|
edges.append(edge)
|
||||||
|
added_edges[str(edge_key)] = True
|
||||||
|
|
||||||
|
for property_node in get_own_properties(property_nodes, property_edges):
|
||||||
|
edge_key = str(data_point.id) + str(property_node.id) + field_name
|
||||||
|
|
||||||
|
if str(edge_key) not in added_edges:
|
||||||
|
edges.append((data_point.id, property_node.id, field_name, {
|
||||||
|
"source_node_id": data_point.id,
|
||||||
|
"target_node_id": property_node.id,
|
||||||
|
"relationship_name": field_name,
|
||||||
|
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
|
}))
|
||||||
|
added_edges[str(edge_key)] = True
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
|
||||||
|
excluded_properties.add(field_name)
|
||||||
|
|
||||||
|
for item in field_value:
|
||||||
|
property_nodes, property_edges = get_graph_from_model(item, True, added_nodes, added_edges)
|
||||||
|
|
||||||
|
for node in property_nodes:
|
||||||
|
if str(node.id) not in added_nodes:
|
||||||
|
nodes.append(node)
|
||||||
|
added_nodes[str(node.id)] = True
|
||||||
|
|
||||||
|
for edge in property_edges:
|
||||||
|
edge_key = str(edge[0]) + str(edge[1]) + edge[2]
|
||||||
|
|
||||||
|
if str(edge_key) not in added_edges:
|
||||||
|
edges.append(edge)
|
||||||
|
added_edges[edge_key] = True
|
||||||
|
|
||||||
|
for property_node in get_own_properties(property_nodes, property_edges):
|
||||||
|
edge_key = str(data_point.id) + str(property_node.id) + field_name
|
||||||
|
|
||||||
|
if str(edge_key) not in added_edges:
|
||||||
|
edges.append((data_point.id, property_node.id, field_name, {
|
||||||
|
"source_node_id": data_point.id,
|
||||||
|
"target_node_id": property_node.id,
|
||||||
|
"relationship_name": field_name,
|
||||||
|
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
|
"metadata": {
|
||||||
|
"type": "list"
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
added_edges[edge_key] = True
|
||||||
|
continue
|
||||||
|
|
||||||
|
data_point_properties[field_name] = field_value
|
||||||
|
|
||||||
|
SimpleDataPointModel = copy_model(
|
||||||
|
type(data_point),
|
||||||
|
include_fields = {
|
||||||
|
"_metadata": (dict, data_point._metadata),
|
||||||
|
},
|
||||||
|
exclude_fields = excluded_properties,
|
||||||
|
)
|
||||||
|
|
||||||
|
if include_root:
|
||||||
|
nodes.append(SimpleDataPointModel(**data_point_properties))
|
||||||
|
|
||||||
|
return nodes, edges
|
||||||
|
|
||||||
|
|
||||||
|
def get_own_properties(property_nodes, property_edges):
|
||||||
|
own_properties = []
|
||||||
|
|
||||||
|
destination_nodes = [str(property_edge[1]) for property_edge in property_edges]
|
||||||
|
|
||||||
|
for node in property_nodes:
|
||||||
|
if str(node.id) in destination_nodes:
|
||||||
|
continue
|
||||||
|
|
||||||
|
own_properties.append(node)
|
||||||
|
|
||||||
|
return own_properties
|
||||||
29
cognee/modules/graph/utils/get_model_instance_from_graph.py
Normal file
29
cognee/modules/graph/utils/get_model_instance_from_graph.py
Normal file
|
|
@ -0,0 +1,29 @@
|
||||||
|
from pydantic_core import PydanticUndefined
|
||||||
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
from cognee.modules.storage.utils import copy_model
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_instance_from_graph(nodes: list[DataPoint], edges: list, entity_id: str):
|
||||||
|
node_map = {}
|
||||||
|
|
||||||
|
for node in nodes:
|
||||||
|
node_map[node.id] = node
|
||||||
|
|
||||||
|
for edge in edges:
|
||||||
|
source_node = node_map[edge[0]]
|
||||||
|
target_node = node_map[edge[1]]
|
||||||
|
edge_label = edge[2]
|
||||||
|
edge_properties = edge[3] if len(edge) == 4 else {}
|
||||||
|
edge_metadata = edge_properties.get("metadata", {})
|
||||||
|
edge_type = edge_metadata.get("type")
|
||||||
|
|
||||||
|
if edge_type == "list":
|
||||||
|
NewModel = copy_model(type(source_node), { edge_label: (list[type(target_node)], PydanticUndefined) })
|
||||||
|
|
||||||
|
node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: [target_node] })
|
||||||
|
else:
|
||||||
|
NewModel = copy_model(type(source_node), { edge_label: (type(target_node), PydanticUndefined) })
|
||||||
|
|
||||||
|
node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: target_node })
|
||||||
|
|
||||||
|
return node_map[entity_id]
|
||||||
|
|
@ -7,7 +7,7 @@ from ..tasks.Task import Task
|
||||||
|
|
||||||
logger = logging.getLogger("run_tasks(tasks: [Task], data)")
|
logger = logging.getLogger("run_tasks(tasks: [Task], data)")
|
||||||
|
|
||||||
async def run_tasks_base(tasks: [Task], data = None, user: User = None):
|
async def run_tasks_base(tasks: list[Task], data = None, user: User = None):
|
||||||
if len(tasks) == 0:
|
if len(tasks) == 0:
|
||||||
yield data
|
yield data
|
||||||
return
|
return
|
||||||
|
|
@ -16,7 +16,7 @@ async def run_tasks_base(tasks: [Task], data = None, user: User = None):
|
||||||
|
|
||||||
running_task = tasks[0]
|
running_task = tasks[0]
|
||||||
leftover_tasks = tasks[1:]
|
leftover_tasks = tasks[1:]
|
||||||
next_task = leftover_tasks[0] if len(leftover_tasks) > 1 else None
|
next_task = leftover_tasks[0] if len(leftover_tasks) > 0 else None
|
||||||
next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1
|
next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1
|
||||||
|
|
||||||
if inspect.isasyncgenfunction(running_task.executable):
|
if inspect.isasyncgenfunction(running_task.executable):
|
||||||
|
|
|
||||||
|
|
@ -1,33 +0,0 @@
|
||||||
import asyncio
|
|
||||||
import nest_asyncio
|
|
||||||
import dspy
|
|
||||||
from cognee.modules.search.vector.search_similarity import search_similarity
|
|
||||||
|
|
||||||
nest_asyncio.apply()
|
|
||||||
|
|
||||||
class AnswerFromContext(dspy.Signature):
|
|
||||||
question: str = dspy.InputField()
|
|
||||||
context: str = dspy.InputField(desc = "Context to use for answer generation.")
|
|
||||||
answer: str = dspy.OutputField()
|
|
||||||
|
|
||||||
question_answer_llm = dspy.OpenAI(model = "gpt-3.5-turbo-instruct")
|
|
||||||
|
|
||||||
class CogneeSearch(dspy.Module):
|
|
||||||
def __init__(self, ):
|
|
||||||
super().__init__()
|
|
||||||
self.generate_answer = dspy.TypedChainOfThought(AnswerFromContext)
|
|
||||||
|
|
||||||
def forward(self, question):
|
|
||||||
context = asyncio.run(search_similarity(question))
|
|
||||||
|
|
||||||
context_text = "\n".join(context)
|
|
||||||
print(f"Context: {context_text}")
|
|
||||||
|
|
||||||
with dspy.context(lm = question_answer_llm):
|
|
||||||
answer_prediction = self.generate_answer(context = context_text, question = question)
|
|
||||||
answer = answer_prediction.answer
|
|
||||||
|
|
||||||
print(f"Question: {question}")
|
|
||||||
print(f"Answer: {answer}")
|
|
||||||
|
|
||||||
return dspy.Prediction(context = context_text, answer = answer)
|
|
||||||
|
|
@ -1,43 +0,0 @@
|
||||||
import asyncio
|
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
|
||||||
|
|
||||||
async def search_adjacent(query: str) -> list[(str, str)]:
|
|
||||||
"""
|
|
||||||
Find the neighbours of a given node in the graph and return their ids and descriptions.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- query (str): The query string to filter nodes by.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- list[(str, str)]: A list containing the unique identifiers and names of the neighbours of the given node.
|
|
||||||
"""
|
|
||||||
node_id = query
|
|
||||||
|
|
||||||
if node_id is None:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
graph_engine = await get_graph_engine()
|
|
||||||
|
|
||||||
exact_node = await graph_engine.extract_node(node_id)
|
|
||||||
|
|
||||||
if exact_node is not None and "uuid" in exact_node:
|
|
||||||
neighbours = await graph_engine.get_neighbours(exact_node["uuid"])
|
|
||||||
else:
|
|
||||||
vector_engine = get_vector_engine()
|
|
||||||
results = await asyncio.gather(
|
|
||||||
vector_engine.search("entities", query_text = query, limit = 10),
|
|
||||||
vector_engine.search("classification", query_text = query, limit = 10),
|
|
||||||
)
|
|
||||||
results = [*results[0], *results[1]]
|
|
||||||
relevant_results = [result for result in results if result.score < 0.5][:5]
|
|
||||||
|
|
||||||
if len(relevant_results) == 0:
|
|
||||||
return []
|
|
||||||
|
|
||||||
node_neighbours = await asyncio.gather(*[graph_engine.get_neighbours(result.id) for result in relevant_results])
|
|
||||||
neighbours = []
|
|
||||||
for neighbour_ids in node_neighbours:
|
|
||||||
neighbours.extend(neighbour_ids)
|
|
||||||
|
|
||||||
return neighbours
|
|
||||||
|
|
@ -1,15 +0,0 @@
|
||||||
|
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine, get_graph_config
|
|
||||||
|
|
||||||
async def search_cypher(query: str):
|
|
||||||
"""
|
|
||||||
Use a Cypher query to search the graph and return the results.
|
|
||||||
"""
|
|
||||||
graph_config = get_graph_config()
|
|
||||||
|
|
||||||
if graph_config.graph_database_provider == "neo4j":
|
|
||||||
graph_engine = await get_graph_engine()
|
|
||||||
result = await graph_engine.graph().run(query)
|
|
||||||
return result
|
|
||||||
else:
|
|
||||||
raise ValueError("Unsupported search type for the used graph engine.")
|
|
||||||
|
|
@ -1,27 +0,0 @@
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
|
||||||
|
|
||||||
async def search_similarity(query: str) -> list[str, str]:
|
|
||||||
"""
|
|
||||||
Parameters:
|
|
||||||
- query (str): The query string to filter nodes by.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- list(chunk): A list of objects providing information about the chunks related to query.
|
|
||||||
"""
|
|
||||||
vector_engine = get_vector_engine()
|
|
||||||
|
|
||||||
similar_results = await vector_engine.search("chunks", query, limit = 5)
|
|
||||||
|
|
||||||
results = [
|
|
||||||
parse_payload(result.payload) for result in similar_results
|
|
||||||
]
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def parse_payload(payload: dict) -> dict:
|
|
||||||
return {
|
|
||||||
"text": payload["text"],
|
|
||||||
"chunk_id": payload["chunk_id"],
|
|
||||||
"document_id": payload["document_id"],
|
|
||||||
}
|
|
||||||
|
|
@ -1,17 +0,0 @@
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
|
||||||
|
|
||||||
async def search_summary(query: str) -> list:
|
|
||||||
"""
|
|
||||||
Parameters:
|
|
||||||
- query (str): The query string to filter summaries by.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- list[str, UUID]: A list of objects providing information about the summaries related to query.
|
|
||||||
"""
|
|
||||||
vector_engine = get_vector_engine()
|
|
||||||
|
|
||||||
summaries_results = await vector_engine.search("summaries", query, limit = 5)
|
|
||||||
|
|
||||||
summaries = [summary.payload for summary in summaries_results]
|
|
||||||
|
|
||||||
return summaries
|
|
||||||
|
|
@ -1,16 +0,0 @@
|
||||||
from typing import Type
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from cognee.infrastructure.llm.prompts import render_prompt
|
|
||||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
|
||||||
|
|
||||||
async def categorize_relevant_category(query: str, summary, response_model: Type[BaseModel]):
|
|
||||||
llm_client = get_llm_client()
|
|
||||||
|
|
||||||
enriched_query= render_prompt("categorize_categories.txt", {"query": query, "categories": summary})
|
|
||||||
|
|
||||||
|
|
||||||
system_prompt = " Choose the relevant categories and return appropriate output based on the model"
|
|
||||||
|
|
||||||
llm_output = await llm_client.acreate_structured_output(enriched_query, system_prompt, response_model)
|
|
||||||
|
|
||||||
return llm_output.model_dump()
|
|
||||||
|
|
@ -1,15 +0,0 @@
|
||||||
from typing import Type
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from cognee.infrastructure.llm.prompts import render_prompt
|
|
||||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
|
||||||
|
|
||||||
async def categorize_relevant_summary(query: str, summaries, response_model: Type[BaseModel]):
|
|
||||||
llm_client = get_llm_client()
|
|
||||||
|
|
||||||
enriched_query= render_prompt("categorize_summary.txt", {"query": query, "summaries": summaries})
|
|
||||||
|
|
||||||
system_prompt = "Choose the relevant summaries and return appropriate output based on the model"
|
|
||||||
|
|
||||||
llm_output = await llm_client.acreate_structured_output(enriched_query, system_prompt, response_model)
|
|
||||||
|
|
||||||
return llm_output
|
|
||||||
|
|
@ -1,17 +0,0 @@
|
||||||
import logging
|
|
||||||
from typing import List, Dict
|
|
||||||
from cognee.modules.cognify.config import get_cognify_config
|
|
||||||
from .extraction.categorize_relevant_summary import categorize_relevant_summary
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
async def get_cognitive_layers(content: str, categories: List[Dict]):
|
|
||||||
try:
|
|
||||||
cognify_config = get_cognify_config()
|
|
||||||
return (await categorize_relevant_summary(
|
|
||||||
content,
|
|
||||||
categories[0],
|
|
||||||
cognify_config.summarization_model,
|
|
||||||
)).cognitive_layers
|
|
||||||
except Exception as error:
|
|
||||||
logger.error("Error extracting cognitive layers from content: %s", error, exc_info = True)
|
|
||||||
raise error
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
""" Placeholder for BM25 implementation"""
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
"""Placeholder for fusions search implementation"""
|
|
||||||
|
|
@ -1,36 +0,0 @@
|
||||||
import asyncio
|
|
||||||
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
|
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
|
||||||
|
|
||||||
async def search_traverse(query: str):
|
|
||||||
node_id = query
|
|
||||||
rules = set()
|
|
||||||
|
|
||||||
graph_engine = await get_graph_engine()
|
|
||||||
vector_engine = get_vector_engine()
|
|
||||||
|
|
||||||
exact_node = await graph_engine.extract_node(node_id)
|
|
||||||
|
|
||||||
if exact_node is not None and "uuid" in exact_node:
|
|
||||||
edges = await graph_engine.get_edges(exact_node["uuid"])
|
|
||||||
|
|
||||||
for edge in edges:
|
|
||||||
rules.add(f"{edge[0]} {edge[2]['relationship_name']} {edge[1]}")
|
|
||||||
else:
|
|
||||||
results = await asyncio.gather(
|
|
||||||
vector_engine.search("entities", query_text = query, limit = 10),
|
|
||||||
vector_engine.search("classification", query_text = query, limit = 10),
|
|
||||||
)
|
|
||||||
results = [*results[0], *results[1]]
|
|
||||||
relevant_results = [result for result in results if result.score < 0.5][:5]
|
|
||||||
|
|
||||||
if len(relevant_results) > 0:
|
|
||||||
for result in relevant_results:
|
|
||||||
graph_node_id = result.id
|
|
||||||
|
|
||||||
edges = await graph_engine.get_edges(graph_node_id)
|
|
||||||
|
|
||||||
for edge in edges:
|
|
||||||
rules.add(f"{edge[0]} {edge[2]['relationship_name']} {edge[1]}")
|
|
||||||
|
|
||||||
return list(rules)
|
|
||||||
46
cognee/modules/storage/utils/__init__.py
Normal file
46
cognee/modules/storage/utils/__init__.py
Normal file
|
|
@ -0,0 +1,46 @@
|
||||||
|
import json
|
||||||
|
from uuid import UUID
|
||||||
|
from datetime import datetime
|
||||||
|
from pydantic_core import PydanticUndefined
|
||||||
|
|
||||||
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
|
||||||
|
class JSONEncoder(json.JSONEncoder):
|
||||||
|
def default(self, obj):
|
||||||
|
if isinstance(obj, datetime):
|
||||||
|
return obj.isoformat() # Convert datetime to ISO 8601 string
|
||||||
|
elif isinstance(obj, UUID):
|
||||||
|
# if the obj is uuid, we simply return the value of uuid
|
||||||
|
return str(obj)
|
||||||
|
return json.JSONEncoder.default(self, obj)
|
||||||
|
|
||||||
|
|
||||||
|
from pydantic import create_model
|
||||||
|
|
||||||
|
def copy_model(model: DataPoint, include_fields: dict = {}, exclude_fields: list = []):
|
||||||
|
fields = {
|
||||||
|
name: (field.annotation, field.default if field.default is not None else PydanticUndefined)
|
||||||
|
for name, field in model.model_fields.items()
|
||||||
|
if name not in exclude_fields
|
||||||
|
}
|
||||||
|
|
||||||
|
final_fields = {
|
||||||
|
**fields,
|
||||||
|
**include_fields
|
||||||
|
}
|
||||||
|
|
||||||
|
return create_model(model.__name__, **final_fields)
|
||||||
|
|
||||||
|
def get_own_properties(data_point: DataPoint):
|
||||||
|
properties = {}
|
||||||
|
|
||||||
|
for field_name, field_value in data_point:
|
||||||
|
if field_name == "_metadata" \
|
||||||
|
or isinstance(field_value, dict) \
|
||||||
|
or isinstance(field_value, DataPoint) \
|
||||||
|
or (isinstance(field_value, list) and isinstance(field_value[0], DataPoint)):
|
||||||
|
continue
|
||||||
|
|
||||||
|
properties[field_name] = field_value
|
||||||
|
|
||||||
|
return properties
|
||||||
|
|
@ -1,84 +1,95 @@
|
||||||
from typing import List, Union, Literal, Optional
|
from typing import Any, List, Union, Literal, Optional
|
||||||
from pydantic import BaseModel
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
|
||||||
class BaseClass(BaseModel):
|
class Variable(DataPoint):
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
type: Literal["Class"] = "Class"
|
|
||||||
description: str
|
|
||||||
constructor_parameters: Optional[List[str]] = None
|
|
||||||
|
|
||||||
class Class(BaseModel):
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
type: Literal["Class"] = "Class"
|
|
||||||
description: str
|
|
||||||
constructor_parameters: Optional[List[str]] = None
|
|
||||||
from_class: Optional[BaseClass] = None
|
|
||||||
|
|
||||||
class ClassInstance(BaseModel):
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
type: Literal["ClassInstance"] = "ClassInstance"
|
|
||||||
description: str
|
|
||||||
from_class: Class
|
|
||||||
|
|
||||||
class Function(BaseModel):
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
type: Literal["Function"] = "Function"
|
|
||||||
description: str
|
|
||||||
parameters: Optional[List[str]] = None
|
|
||||||
return_type: str
|
|
||||||
is_static: Optional[bool] = False
|
|
||||||
|
|
||||||
class Variable(BaseModel):
|
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
type: Literal["Variable"] = "Variable"
|
type: Literal["Variable"] = "Variable"
|
||||||
description: str
|
description: str
|
||||||
is_static: Optional[bool] = False
|
is_static: Optional[bool] = False
|
||||||
default_value: Optional[str] = None
|
default_value: Optional[str] = None
|
||||||
|
data_type: str
|
||||||
|
|
||||||
class Operator(BaseModel):
|
_metadata = {
|
||||||
|
"index_fields": ["name"]
|
||||||
|
}
|
||||||
|
|
||||||
|
class Operator(DataPoint):
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
type: Literal["Operator"] = "Operator"
|
type: Literal["Operator"] = "Operator"
|
||||||
description: str
|
description: str
|
||||||
return_type: str
|
return_type: str
|
||||||
|
|
||||||
class ExpressionPart(BaseModel):
|
class Class(DataPoint):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
type: Literal["Class"] = "Class"
|
||||||
|
description: str
|
||||||
|
constructor_parameters: List[Variable]
|
||||||
|
extended_from_class: Optional["Class"] = None
|
||||||
|
has_methods: list["Function"]
|
||||||
|
|
||||||
|
_metadata = {
|
||||||
|
"index_fields": ["name"]
|
||||||
|
}
|
||||||
|
|
||||||
|
class ClassInstance(DataPoint):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
type: Literal["ClassInstance"] = "ClassInstance"
|
||||||
|
description: str
|
||||||
|
from_class: Class
|
||||||
|
instantiated_by: Union["Function"]
|
||||||
|
instantiation_arguments: List[Variable]
|
||||||
|
|
||||||
|
_metadata = {
|
||||||
|
"index_fields": ["name"]
|
||||||
|
}
|
||||||
|
|
||||||
|
class Function(DataPoint):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
type: Literal["Function"] = "Function"
|
||||||
|
description: str
|
||||||
|
parameters: List[Variable]
|
||||||
|
return_type: str
|
||||||
|
is_static: Optional[bool] = False
|
||||||
|
|
||||||
|
_metadata = {
|
||||||
|
"index_fields": ["name"]
|
||||||
|
}
|
||||||
|
|
||||||
|
class FunctionCall(DataPoint):
|
||||||
|
id: str
|
||||||
|
type: Literal["FunctionCall"] = "FunctionCall"
|
||||||
|
called_by: Union[Function, Literal["main"]]
|
||||||
|
function_called: Function
|
||||||
|
function_arguments: List[Any]
|
||||||
|
|
||||||
|
class Expression(DataPoint):
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
type: Literal["Expression"] = "Expression"
|
type: Literal["Expression"] = "Expression"
|
||||||
description: str
|
description: str
|
||||||
expression: str
|
expression: str
|
||||||
members: List[Union[Variable, Function, Operator]]
|
members: List[Union[Variable, Function, Operator, "Expression"]]
|
||||||
|
|
||||||
class Expression(BaseModel):
|
class SourceCodeGraph(DataPoint):
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
type: Literal["Expression"] = "Expression"
|
|
||||||
description: str
|
|
||||||
expression: str
|
|
||||||
members: List[Union[Variable, Function, Operator, ExpressionPart]]
|
|
||||||
|
|
||||||
class Edge(BaseModel):
|
|
||||||
source_node_id: str
|
|
||||||
target_node_id: str
|
|
||||||
relationship_name: Literal["called in", "stored in", "defined in", "returned by", "instantiated in", "uses", "updates"]
|
|
||||||
|
|
||||||
class SourceCodeGraph(BaseModel):
|
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
language: str
|
language: str
|
||||||
nodes: List[Union[
|
nodes: List[Union[
|
||||||
Class,
|
Class,
|
||||||
|
ClassInstance,
|
||||||
Function,
|
Function,
|
||||||
|
FunctionCall,
|
||||||
Variable,
|
Variable,
|
||||||
Operator,
|
Operator,
|
||||||
Expression,
|
Expression,
|
||||||
ClassInstance,
|
|
||||||
]]
|
]]
|
||||||
edges: List[Edge]
|
|
||||||
|
Class.model_rebuild()
|
||||||
|
ClassInstance.model_rebuild()
|
||||||
|
Expression.model_rebuild()
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
""" This module contains utility functions for the cognee. """
|
""" This module contains utility functions for the cognee. """
|
||||||
import os
|
import os
|
||||||
import datetime
|
from datetime import datetime, timezone
|
||||||
import graphistry
|
import graphistry
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -45,7 +45,7 @@ def send_telemetry(event_name: str, user_id, additional_properties: dict = {}):
|
||||||
host = "https://eu.i.posthog.com"
|
host = "https://eu.i.posthog.com"
|
||||||
)
|
)
|
||||||
|
|
||||||
current_time = datetime.datetime.now()
|
current_time = datetime.now(timezone.utc)
|
||||||
properties = {
|
properties = {
|
||||||
"time": current_time.strftime("%m/%d/%Y"),
|
"time": current_time.strftime("%m/%d/%Y"),
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
|
|
@ -110,30 +110,36 @@ async def register_graphistry():
|
||||||
graphistry.register(api = 3, username = config.graphistry_username, password = config.graphistry_password)
|
graphistry.register(api = 3, username = config.graphistry_username, password = config.graphistry_password)
|
||||||
|
|
||||||
|
|
||||||
def prepare_edges(graph):
|
def prepare_edges(graph, source, target, edge_key):
|
||||||
return nx.to_pandas_edgelist(graph)
|
edge_list = [{
|
||||||
|
source: str(edge[0]),
|
||||||
|
target: str(edge[1]),
|
||||||
|
edge_key: str(edge[2]),
|
||||||
|
} for edge in graph.edges(keys = True, data = True)]
|
||||||
|
|
||||||
|
return pd.DataFrame(edge_list)
|
||||||
|
|
||||||
|
|
||||||
def prepare_nodes(graph, include_size=False):
|
def prepare_nodes(graph, include_size=False):
|
||||||
nodes_data = []
|
nodes_data = []
|
||||||
for node in graph.nodes:
|
for node in graph.nodes:
|
||||||
node_info = graph.nodes[node]
|
node_info = graph.nodes[node]
|
||||||
description = node_info.get("layer_description", {}).get("layer", "Default Layer") if isinstance(
|
|
||||||
node_info.get("layer_description"), dict) else node_info.get("layer_description", "Default Layer")
|
|
||||||
# description = node_info['layer_description']['layer'] if isinstance(node_info.get('layer_description'), dict) and 'layer' in node_info['layer_description'] else node_info.get('layer_description', node)
|
|
||||||
# if isinstance(node_info.get('layer_description'), dict) and 'layer' in node_info.get('layer_description'):
|
|
||||||
# description = node_info['layer_description']['layer']
|
|
||||||
# # Use 'layer_description' directly if it's not a dictionary, otherwise default to node ID
|
|
||||||
# else:
|
|
||||||
# description = node_info.get('layer_description', node)
|
|
||||||
|
|
||||||
node_data = {"id": node, "layer_description": description}
|
if not node_info:
|
||||||
|
continue
|
||||||
|
|
||||||
|
node_data = {
|
||||||
|
"id": str(node),
|
||||||
|
"name": node_info["name"] if "name" in node_info else str(node),
|
||||||
|
}
|
||||||
|
|
||||||
if include_size:
|
if include_size:
|
||||||
default_size = 10 # Default node size
|
default_size = 10 # Default node size
|
||||||
larger_size = 20 # Size for nodes with specific keywords in their ID
|
larger_size = 20 # Size for nodes with specific keywords in their ID
|
||||||
keywords = ["DOCUMENT", "User", "LAYER"]
|
keywords = ["DOCUMENT", "User"]
|
||||||
node_size = larger_size if any(keyword in str(node) for keyword in keywords) else default_size
|
node_size = larger_size if any(keyword in str(node) for keyword in keywords) else default_size
|
||||||
node_data["size"] = node_size
|
node_data["size"] = node_size
|
||||||
|
|
||||||
nodes_data.append(node_data)
|
nodes_data.append(node_data)
|
||||||
|
|
||||||
return pd.DataFrame(nodes_data)
|
return pd.DataFrame(nodes_data)
|
||||||
|
|
@ -153,28 +159,28 @@ async def render_graph(graph, include_nodes=False, include_color=False, include_
|
||||||
|
|
||||||
graph = networkx_graph
|
graph = networkx_graph
|
||||||
|
|
||||||
edges = prepare_edges(graph)
|
edges = prepare_edges(graph, "source_node", "target_node", "relationship_name")
|
||||||
plotter = graphistry.edges(edges, "source", "target")
|
plotter = graphistry.edges(edges, "source_node", "target_node")
|
||||||
|
plotter = plotter.bind(edge_label = "relationship_name")
|
||||||
|
|
||||||
if include_nodes:
|
if include_nodes:
|
||||||
nodes = prepare_nodes(graph, include_size=include_size)
|
nodes = prepare_nodes(graph, include_size = include_size)
|
||||||
plotter = plotter.nodes(nodes, "id")
|
plotter = plotter.nodes(nodes, "id")
|
||||||
|
|
||||||
|
|
||||||
if include_size:
|
if include_size:
|
||||||
plotter = plotter.bind(point_size="size")
|
plotter = plotter.bind(point_size = "size")
|
||||||
|
|
||||||
|
|
||||||
if include_color:
|
if include_color:
|
||||||
unique_layers = nodes["layer_description"].unique()
|
pass
|
||||||
color_palette = generate_color_palette(unique_layers)
|
# unique_layers = nodes["layer_description"].unique()
|
||||||
plotter = plotter.encode_point_color("layer_description", categorical_mapping=color_palette,
|
# color_palette = generate_color_palette(unique_layers)
|
||||||
default_mapping="silver")
|
# plotter = plotter.encode_point_color("layer_description", categorical_mapping=color_palette,
|
||||||
|
# default_mapping="silver")
|
||||||
|
|
||||||
|
|
||||||
if include_labels:
|
if include_labels:
|
||||||
plotter = plotter.bind(point_label = "layer_description")
|
plotter = plotter.bind(point_label = "name")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Visualization
|
# Visualization
|
||||||
|
|
|
||||||
|
|
@ -1,10 +0,0 @@
|
||||||
from .summarization.summarize_text import summarize_text
|
|
||||||
from .chunk_naive_llm_classifier.chunk_naive_llm_classifier import chunk_naive_llm_classifier
|
|
||||||
from .chunk_remove_disconnected.chunk_remove_disconnected import chunk_remove_disconnected
|
|
||||||
from .chunk_update_check.chunk_update_check import chunk_update_check
|
|
||||||
from .save_chunks_to_store.save_chunks_to_store import save_chunks_to_store
|
|
||||||
from .source_documents_to_chunks.source_documents_to_chunks import source_documents_to_chunks
|
|
||||||
from .infer_data_ontology.infer_data_ontology import infer_data_ontology
|
|
||||||
from .check_permissions_on_documents.check_permissions_on_documents import check_permissions_on_documents
|
|
||||||
from .classify_documents.classify_documents import classify_documents
|
|
||||||
from .graph.chunks_into_graph import chunks_into_graph
|
|
||||||
|
|
@ -5,7 +5,7 @@ from pydantic import BaseModel
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine, DataPoint
|
from cognee.infrastructure.databases.vector import get_vector_engine, DataPoint
|
||||||
from cognee.modules.data.extraction.extract_categories import extract_categories
|
from cognee.modules.data.extraction.extract_categories import extract_categories
|
||||||
from cognee.modules.chunking import DocumentChunk
|
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
||||||
|
|
||||||
|
|
||||||
async def chunk_naive_llm_classifier(data_chunks: list[DocumentChunk], classification_model: Type[BaseModel]):
|
async def chunk_naive_llm_classifier(data_chunks: list[DocumentChunk], classification_model: Type[BaseModel]):
|
||||||
|
|
@ -65,7 +65,7 @@ async def chunk_naive_llm_classifier(data_chunks: list[DocumentChunk], classific
|
||||||
"chunk_id": str(data_chunk.chunk_id),
|
"chunk_id": str(data_chunk.chunk_id),
|
||||||
"document_id": str(data_chunk.document_id),
|
"document_id": str(data_chunk.document_id),
|
||||||
}),
|
}),
|
||||||
embed_field="text",
|
index_fields=["text"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -104,7 +104,7 @@ async def chunk_naive_llm_classifier(data_chunks: list[DocumentChunk], classific
|
||||||
"chunk_id": str(data_chunk.chunk_id),
|
"chunk_id": str(data_chunk.chunk_id),
|
||||||
"document_id": str(data_chunk.document_id),
|
"document_id": str(data_chunk.document_id),
|
||||||
}),
|
}),
|
||||||
embed_field="text",
|
index_fields=["text"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,39 +0,0 @@
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from cognee.base_config import get_base_config
|
|
||||||
|
|
||||||
BaseConfig = get_base_config()
|
|
||||||
|
|
||||||
async def translate_text(data, source_language:str='sr', target_language:str='en', region_name='eu-west-1'):
|
|
||||||
"""
|
|
||||||
Translate text from source language to target language using AWS Translate.
|
|
||||||
Parameters:
|
|
||||||
data (str): The text to be translated.
|
|
||||||
source_language (str): The source language code (e.g., 'sr' for Serbian). ISO 639-2 Code https://www.loc.gov/standards/iso639-2/php/code_list.php
|
|
||||||
target_language (str): The target language code (e.g., 'en' for English). ISO 639-2 Code https://www.loc.gov/standards/iso639-2/php/code_list.php
|
|
||||||
region_name (str): AWS region name.
|
|
||||||
Returns:
|
|
||||||
str: Translated text or an error message.
|
|
||||||
"""
|
|
||||||
import boto3
|
|
||||||
from botocore.exceptions import BotoCoreError, ClientError
|
|
||||||
|
|
||||||
if not data:
|
|
||||||
yield "No text provided for translation."
|
|
||||||
|
|
||||||
if not source_language or not target_language:
|
|
||||||
yield "Both source and target language codes are required."
|
|
||||||
|
|
||||||
try:
|
|
||||||
translate = boto3.client(service_name='translate', region_name=region_name, use_ssl=True)
|
|
||||||
result = translate.translate_text(Text=data, SourceLanguageCode=source_language, TargetLanguageCode=target_language)
|
|
||||||
yield result.get('TranslatedText', 'No translation found.')
|
|
||||||
|
|
||||||
except BotoCoreError as e:
|
|
||||||
logging.info(f"BotoCoreError occurred: {e}")
|
|
||||||
yield "Error with AWS Translate service configuration or request."
|
|
||||||
|
|
||||||
except ClientError as e:
|
|
||||||
logging.info(f"ClientError occurred: {e}")
|
|
||||||
yield "Error with AWS client or network issue."
|
|
||||||
|
|
@ -1,26 +0,0 @@
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
|
||||||
from cognee.modules.chunking import DocumentChunk
|
|
||||||
|
|
||||||
|
|
||||||
async def chunk_update_check(data_chunks: list[DocumentChunk], collection_name: str) -> list[DocumentChunk]:
|
|
||||||
vector_engine = get_vector_engine()
|
|
||||||
|
|
||||||
if not await vector_engine.has_collection(collection_name):
|
|
||||||
# If collection doesn't exist, all data_chunks are new
|
|
||||||
return data_chunks
|
|
||||||
|
|
||||||
existing_chunks = await vector_engine.retrieve(
|
|
||||||
collection_name,
|
|
||||||
[str(chunk.chunk_id) for chunk in data_chunks],
|
|
||||||
)
|
|
||||||
|
|
||||||
existing_chunks_map = {str(chunk.id): chunk.payload for chunk in existing_chunks}
|
|
||||||
|
|
||||||
affected_data_chunks = []
|
|
||||||
|
|
||||||
for chunk in data_chunks:
|
|
||||||
if chunk.chunk_id not in existing_chunks_map or \
|
|
||||||
chunk.text != existing_chunks_map[chunk.chunk_id]["text"]:
|
|
||||||
affected_data_chunks.append(chunk)
|
|
||||||
|
|
||||||
return affected_data_chunks
|
|
||||||
|
|
@ -2,3 +2,4 @@ from .query_chunks import query_chunks
|
||||||
from .chunk_by_word import chunk_by_word
|
from .chunk_by_word import chunk_by_word
|
||||||
from .chunk_by_sentence import chunk_by_sentence
|
from .chunk_by_sentence import chunk_by_sentence
|
||||||
from .chunk_by_paragraph import chunk_by_paragraph
|
from .chunk_by_paragraph import chunk_by_paragraph
|
||||||
|
from .remove_disconnected_chunks import remove_disconnected_chunks
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from cognee.tasks.chunking import chunk_by_paragraph
|
from cognee.tasks.chunks import chunk_by_paragraph
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
def test_chunking_on_whole_text():
|
def test_chunking_on_whole_text():
|
||||||
|
|
@ -10,7 +10,7 @@ async def query_chunks(query: str) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
|
|
||||||
found_chunks = await vector_engine.search("chunks", query, limit = 5)
|
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit = 5)
|
||||||
|
|
||||||
chunks = [result.payload for result in found_chunks]
|
chunks = [result.payload for result in found_chunks]
|
||||||
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from cognee.modules.chunking import DocumentChunk
|
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
||||||
|
|
||||||
async def chunk_remove_disconnected(data_chunks: list[DocumentChunk]) -> list[DocumentChunk]:
|
async def remove_disconnected_chunks(data_chunks: list[DocumentChunk]) -> list[DocumentChunk]:
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
|
|
||||||
document_ids = set((data_chunk.document_id for data_chunk in data_chunks))
|
document_ids = set((data_chunk.document_id for data_chunk in data_chunks))
|
||||||
|
|
@ -1,13 +0,0 @@
|
||||||
from cognee.modules.data.models import Data
|
|
||||||
from cognee.modules.data.processing.document_types import Document, PdfDocument, AudioDocument, ImageDocument, TextDocument
|
|
||||||
|
|
||||||
def classify_documents(data_documents: list[Data]) -> list[Document]:
|
|
||||||
documents = [
|
|
||||||
PdfDocument(id = data_item.id, title=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location) if data_item.extension == "pdf" else
|
|
||||||
AudioDocument(id = data_item.id, title=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location) if data_item.extension == "audio" else
|
|
||||||
ImageDocument(id = data_item.id, title=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location) if data_item.extension == "image" else
|
|
||||||
TextDocument(id = data_item.id, title=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location)
|
|
||||||
for data_item in data_documents
|
|
||||||
]
|
|
||||||
|
|
||||||
return documents
|
|
||||||
3
cognee/tasks/documents/__init__.py
Normal file
3
cognee/tasks/documents/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
from .classify_documents import classify_documents
|
||||||
|
from .extract_chunks_from_documents import extract_chunks_from_documents
|
||||||
|
from .check_permissions_on_documents import check_permissions_on_documents
|
||||||
13
cognee/tasks/documents/classify_documents.py
Normal file
13
cognee/tasks/documents/classify_documents.py
Normal file
|
|
@ -0,0 +1,13 @@
|
||||||
|
from cognee.modules.data.models import Data
|
||||||
|
from cognee.modules.data.processing.document_types import Document, PdfDocument, AudioDocument, ImageDocument, TextDocument
|
||||||
|
|
||||||
|
def classify_documents(data_documents: list[Data]) -> list[Document]:
|
||||||
|
documents = [
|
||||||
|
PdfDocument(id = data_item.id, name=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location) if data_item.extension == "pdf" else
|
||||||
|
AudioDocument(id = data_item.id, name=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location) if data_item.extension == "audio" else
|
||||||
|
ImageDocument(id = data_item.id, name=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location) if data_item.extension == "image" else
|
||||||
|
TextDocument(id = data_item.id, name=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location)
|
||||||
|
for data_item in data_documents
|
||||||
|
]
|
||||||
|
|
||||||
|
return documents
|
||||||
7
cognee/tasks/documents/extract_chunks_from_documents.py
Normal file
7
cognee/tasks/documents/extract_chunks_from_documents.py
Normal file
|
|
@ -0,0 +1,7 @@
|
||||||
|
from cognee.modules.data.processing.document_types.Document import Document
|
||||||
|
|
||||||
|
|
||||||
|
async def extract_chunks_from_documents(documents: list[Document], chunk_size: int = 1024):
|
||||||
|
for document in documents:
|
||||||
|
for document_chunk in document.read(chunk_size = chunk_size):
|
||||||
|
yield document_chunk
|
||||||
|
|
@ -1,2 +1,3 @@
|
||||||
from .chunks_into_graph import chunks_into_graph
|
from .extract_graph_from_data import extract_graph_from_data
|
||||||
|
from .extract_graph_from_code import extract_graph_from_code
|
||||||
from .query_graph_connections import query_graph_connections
|
from .query_graph_connections import query_graph_connections
|
||||||
|
|
|
||||||
|
|
@ -1,213 +0,0 @@
|
||||||
import json
|
|
||||||
import asyncio
|
|
||||||
from uuid import uuid5, NAMESPACE_OID
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from typing import Type
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
|
||||||
from cognee.infrastructure.databases.vector import DataPoint, get_vector_engine
|
|
||||||
from cognee.modules.data.extraction.knowledge_graph.extract_content_graph import extract_content_graph
|
|
||||||
from cognee.modules.chunking import DocumentChunk
|
|
||||||
from cognee.modules.graph.utils import generate_node_id, generate_node_name
|
|
||||||
|
|
||||||
|
|
||||||
class EntityNode(BaseModel):
|
|
||||||
uuid: str
|
|
||||||
name: str
|
|
||||||
type: str
|
|
||||||
description: str
|
|
||||||
created_at: datetime
|
|
||||||
updated_at: datetime
|
|
||||||
|
|
||||||
async def chunks_into_graph(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel], collection_name: str):
|
|
||||||
chunk_graphs = await asyncio.gather(
|
|
||||||
*[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
|
|
||||||
)
|
|
||||||
|
|
||||||
vector_engine = get_vector_engine()
|
|
||||||
graph_engine = await get_graph_engine()
|
|
||||||
|
|
||||||
has_collection = await vector_engine.has_collection(collection_name)
|
|
||||||
|
|
||||||
if not has_collection:
|
|
||||||
await vector_engine.create_collection(collection_name, payload_schema = EntityNode)
|
|
||||||
|
|
||||||
processed_nodes = {}
|
|
||||||
type_node_edges = []
|
|
||||||
entity_node_edges = []
|
|
||||||
type_entity_edges = []
|
|
||||||
|
|
||||||
for (chunk_index, chunk) in enumerate(data_chunks):
|
|
||||||
chunk_graph = chunk_graphs[chunk_index]
|
|
||||||
for node in chunk_graph.nodes:
|
|
||||||
type_node_id = generate_node_id(node.type)
|
|
||||||
entity_node_id = generate_node_id(node.id)
|
|
||||||
|
|
||||||
if type_node_id not in processed_nodes:
|
|
||||||
type_node_edges.append((str(chunk.chunk_id), type_node_id, "contains_entity_type"))
|
|
||||||
processed_nodes[type_node_id] = True
|
|
||||||
|
|
||||||
if entity_node_id not in processed_nodes:
|
|
||||||
entity_node_edges.append((str(chunk.chunk_id), entity_node_id, "contains_entity"))
|
|
||||||
type_entity_edges.append((entity_node_id, type_node_id, "is_entity_type"))
|
|
||||||
processed_nodes[entity_node_id] = True
|
|
||||||
|
|
||||||
graph_node_edges = [
|
|
||||||
(edge.target_node_id, edge.source_node_id, edge.relationship_name) \
|
|
||||||
for edge in chunk_graph.edges
|
|
||||||
]
|
|
||||||
|
|
||||||
existing_edges = await graph_engine.has_edges([
|
|
||||||
*type_node_edges,
|
|
||||||
*entity_node_edges,
|
|
||||||
*type_entity_edges,
|
|
||||||
*graph_node_edges,
|
|
||||||
])
|
|
||||||
|
|
||||||
existing_edges_map = {}
|
|
||||||
existing_nodes_map = {}
|
|
||||||
|
|
||||||
for edge in existing_edges:
|
|
||||||
existing_edges_map[edge[0] + edge[1] + edge[2]] = True
|
|
||||||
existing_nodes_map[edge[0]] = True
|
|
||||||
|
|
||||||
graph_nodes = []
|
|
||||||
graph_edges = []
|
|
||||||
data_points = []
|
|
||||||
|
|
||||||
for (chunk_index, chunk) in enumerate(data_chunks):
|
|
||||||
graph = chunk_graphs[chunk_index]
|
|
||||||
if graph is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
for node in graph.nodes:
|
|
||||||
node_id = generate_node_id(node.id)
|
|
||||||
node_name = generate_node_name(node.name)
|
|
||||||
|
|
||||||
type_node_id = generate_node_id(node.type)
|
|
||||||
type_node_name = generate_node_name(node.type)
|
|
||||||
|
|
||||||
if node_id not in existing_nodes_map:
|
|
||||||
node_data = dict(
|
|
||||||
uuid = node_id,
|
|
||||||
name = node_name,
|
|
||||||
type = node_name,
|
|
||||||
description = node.description,
|
|
||||||
created_at = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
|
|
||||||
updated_at = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
|
|
||||||
)
|
|
||||||
|
|
||||||
graph_nodes.append((
|
|
||||||
node_id,
|
|
||||||
dict(
|
|
||||||
**node_data,
|
|
||||||
properties = json.dumps(node.properties),
|
|
||||||
)
|
|
||||||
))
|
|
||||||
|
|
||||||
data_points.append(DataPoint[EntityNode](
|
|
||||||
id = str(uuid5(NAMESPACE_OID, node_id)),
|
|
||||||
payload = node_data,
|
|
||||||
embed_field = "name",
|
|
||||||
))
|
|
||||||
|
|
||||||
existing_nodes_map[node_id] = True
|
|
||||||
|
|
||||||
edge_key = str(chunk.chunk_id) + node_id + "contains_entity"
|
|
||||||
|
|
||||||
if edge_key not in existing_edges_map:
|
|
||||||
graph_edges.append((
|
|
||||||
str(chunk.chunk_id),
|
|
||||||
node_id,
|
|
||||||
"contains_entity",
|
|
||||||
dict(
|
|
||||||
relationship_name = "contains_entity",
|
|
||||||
source_node_id = str(chunk.chunk_id),
|
|
||||||
target_node_id = node_id,
|
|
||||||
),
|
|
||||||
))
|
|
||||||
|
|
||||||
# Add relationship between entity type and entity itself: "Jake is Person"
|
|
||||||
graph_edges.append((
|
|
||||||
node_id,
|
|
||||||
type_node_id,
|
|
||||||
"is_entity_type",
|
|
||||||
dict(
|
|
||||||
relationship_name = "is_entity_type",
|
|
||||||
source_node_id = type_node_id,
|
|
||||||
target_node_id = node_id,
|
|
||||||
),
|
|
||||||
))
|
|
||||||
|
|
||||||
existing_edges_map[edge_key] = True
|
|
||||||
|
|
||||||
if type_node_id not in existing_nodes_map:
|
|
||||||
type_node_data = dict(
|
|
||||||
uuid = type_node_id,
|
|
||||||
name = type_node_name,
|
|
||||||
type = type_node_id,
|
|
||||||
description = type_node_name,
|
|
||||||
created_at = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
|
|
||||||
updated_at = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
|
|
||||||
)
|
|
||||||
|
|
||||||
graph_nodes.append((type_node_id, dict(
|
|
||||||
**type_node_data,
|
|
||||||
properties = json.dumps(node.properties)
|
|
||||||
)))
|
|
||||||
|
|
||||||
data_points.append(DataPoint[EntityNode](
|
|
||||||
id = str(uuid5(NAMESPACE_OID, type_node_id)),
|
|
||||||
payload = type_node_data,
|
|
||||||
embed_field = "name",
|
|
||||||
))
|
|
||||||
|
|
||||||
existing_nodes_map[type_node_id] = True
|
|
||||||
|
|
||||||
edge_key = str(chunk.chunk_id) + type_node_id + "contains_entity_type"
|
|
||||||
|
|
||||||
if edge_key not in existing_edges_map:
|
|
||||||
graph_edges.append((
|
|
||||||
str(chunk.chunk_id),
|
|
||||||
type_node_id,
|
|
||||||
"contains_entity_type",
|
|
||||||
dict(
|
|
||||||
relationship_name = "contains_entity_type",
|
|
||||||
source_node_id = str(chunk.chunk_id),
|
|
||||||
target_node_id = type_node_id,
|
|
||||||
),
|
|
||||||
))
|
|
||||||
|
|
||||||
existing_edges_map[edge_key] = True
|
|
||||||
|
|
||||||
# Add relationship that came from graphs.
|
|
||||||
for edge in graph.edges:
|
|
||||||
source_node_id = generate_node_id(edge.source_node_id)
|
|
||||||
target_node_id = generate_node_id(edge.target_node_id)
|
|
||||||
relationship_name = generate_node_name(edge.relationship_name)
|
|
||||||
edge_key = source_node_id + target_node_id + relationship_name
|
|
||||||
|
|
||||||
if edge_key not in existing_edges_map:
|
|
||||||
graph_edges.append((
|
|
||||||
generate_node_id(edge.source_node_id),
|
|
||||||
generate_node_id(edge.target_node_id),
|
|
||||||
edge.relationship_name,
|
|
||||||
dict(
|
|
||||||
relationship_name = generate_node_name(edge.relationship_name),
|
|
||||||
source_node_id = generate_node_id(edge.source_node_id),
|
|
||||||
target_node_id = generate_node_id(edge.target_node_id),
|
|
||||||
properties = json.dumps(edge.properties),
|
|
||||||
),
|
|
||||||
))
|
|
||||||
existing_edges_map[edge_key] = True
|
|
||||||
|
|
||||||
if len(data_points) > 0:
|
|
||||||
await vector_engine.create_data_points(collection_name, data_points)
|
|
||||||
|
|
||||||
if len(graph_nodes) > 0:
|
|
||||||
await graph_engine.add_nodes(graph_nodes)
|
|
||||||
|
|
||||||
if len(graph_edges) > 0:
|
|
||||||
await graph_engine.add_edges(graph_edges)
|
|
||||||
|
|
||||||
return data_chunks
|
|
||||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue