Merge branch 'main' into COG-575-remove-graph-overwrite-on-error
This commit is contained in:
commit
be792a7ba6
127 changed files with 3122 additions and 2951 deletions
4
.github/workflows/test_neo4j.yml
vendored
4
.github/workflows/test_neo4j.yml
vendored
|
|
@ -5,7 +5,7 @@ on:
|
|||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
types: [labeled]
|
||||
types: [labeled, synchronize]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||
|
|
@ -22,7 +22,7 @@ jobs:
|
|||
run_neo4j_integration_test:
|
||||
name: test
|
||||
needs: get_docs_changes
|
||||
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
|
||||
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
defaults:
|
||||
|
|
|
|||
4
.github/workflows/test_notebook.yml
vendored
4
.github/workflows/test_notebook.yml
vendored
|
|
@ -5,7 +5,7 @@ on:
|
|||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
types: [labeled]
|
||||
types: [labeled, synchronize]
|
||||
|
||||
|
||||
concurrency:
|
||||
|
|
@ -23,7 +23,7 @@ jobs:
|
|||
run_notebook_test:
|
||||
name: test
|
||||
needs: get_docs_changes
|
||||
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
|
||||
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
|
|
|
|||
4
.github/workflows/test_pgvector.yml
vendored
4
.github/workflows/test_pgvector.yml
vendored
|
|
@ -5,7 +5,7 @@ on:
|
|||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
types: [labeled]
|
||||
types: [labeled, synchronize]
|
||||
|
||||
|
||||
concurrency:
|
||||
|
|
@ -23,7 +23,7 @@ jobs:
|
|||
run_pgvector_integration_test:
|
||||
name: test
|
||||
needs: get_docs_changes
|
||||
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
|
||||
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
|
|
|
|||
6
.github/workflows/test_python_3_10.yml
vendored
6
.github/workflows/test_python_3_10.yml
vendored
|
|
@ -5,10 +5,10 @@ on:
|
|||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
types: [labeled]
|
||||
types: [labeled, synchronize]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} | ${{ github.event.label.name == 'run-checks' }}
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
|
|
@ -22,7 +22,7 @@ jobs:
|
|||
run_common:
|
||||
name: test
|
||||
needs: get_docs_changes
|
||||
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true'
|
||||
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
|
|
|||
6
.github/workflows/test_python_3_11.yml
vendored
6
.github/workflows/test_python_3_11.yml
vendored
|
|
@ -5,10 +5,10 @@ on:
|
|||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
types: [labeled]
|
||||
types: [labeled, synchronize]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} | ${{ github.event.label.name == 'run-checks' }}
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
|
|
@ -22,7 +22,7 @@ jobs:
|
|||
run_common:
|
||||
name: test
|
||||
needs: get_docs_changes
|
||||
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true'
|
||||
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
|
|
|||
6
.github/workflows/test_python_3_9.yml
vendored
6
.github/workflows/test_python_3_9.yml
vendored
|
|
@ -5,10 +5,10 @@ on:
|
|||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
types: [labeled]
|
||||
types: [labeled, synchronize]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} | ${{ github.event.label.name == 'run-checks' }}
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
|
|
@ -22,7 +22,7 @@ jobs:
|
|||
run_common:
|
||||
name: test
|
||||
needs: get_docs_changes
|
||||
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true'
|
||||
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
|
|
|||
4
.github/workflows/test_qdrant.yml
vendored
4
.github/workflows/test_qdrant.yml
vendored
|
|
@ -5,7 +5,7 @@ on:
|
|||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
types: [labeled]
|
||||
types: [labeled, synchronize]
|
||||
|
||||
|
||||
concurrency:
|
||||
|
|
@ -23,7 +23,7 @@ jobs:
|
|||
run_qdrant_integration_test:
|
||||
name: test
|
||||
needs: get_docs_changes
|
||||
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
|
||||
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
defaults:
|
||||
|
|
|
|||
4
.github/workflows/test_weaviate.yml
vendored
4
.github/workflows/test_weaviate.yml
vendored
|
|
@ -5,7 +5,7 @@ on:
|
|||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
types: [labeled]
|
||||
types: [labeled, synchronize]
|
||||
|
||||
|
||||
concurrency:
|
||||
|
|
@ -23,7 +23,7 @@ jobs:
|
|||
run_weaviate_integration_test:
|
||||
name: test
|
||||
needs: get_docs_changes
|
||||
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
|
||||
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
defaults:
|
||||
|
|
|
|||
24
README.md
24
README.md
|
|
@ -109,24 +109,34 @@ import asyncio
|
|||
from cognee.api.v1.search import SearchType
|
||||
|
||||
async def main():
|
||||
await cognee.prune.prune_data() # Reset cognee data
|
||||
await cognee.prune.prune_system(metadata=True) # Reset cognee system state
|
||||
# Reset cognee data
|
||||
await cognee.prune.prune_data()
|
||||
# Reset cognee system state
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
text = """
|
||||
Natural language processing (NLP) is an interdisciplinary
|
||||
subfield of computer science and information retrieval.
|
||||
"""
|
||||
|
||||
await cognee.add(text) # Add text to cognee
|
||||
await cognee.cognify() # Use LLMs and cognee to create knowledge graph
|
||||
# Add text to cognee
|
||||
await cognee.add(text)
|
||||
|
||||
search_results = await cognee.search( # Search cognee for insights
|
||||
# Use LLMs and cognee to create knowledge graph
|
||||
await cognee.cognify()
|
||||
|
||||
# Search cognee for insights
|
||||
search_results = await cognee.search(
|
||||
SearchType.INSIGHTS,
|
||||
{'query': 'Tell me about NLP'}
|
||||
"Tell me about NLP",
|
||||
)
|
||||
|
||||
for result_text in search_results: # Display results
|
||||
# Display results
|
||||
for result_text in search_results:
|
||||
print(result_text)
|
||||
# natural_language_processing is_a field
|
||||
# natural_language_processing is_subfield_of computer_science
|
||||
# natural_language_processing is_subfield_of information_retrieval
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
|
|
|||
110
cognee/api/v1/cognify/code_graph_pipeline.py
Normal file
110
cognee/api/v1/cognify/code_graph_pipeline.py
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from typing import Union
|
||||
|
||||
from cognee.shared.SourceCodeGraph import SourceCodeGraph
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.modules.data.models import Dataset, Data
|
||||
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
|
||||
from cognee.modules.data.methods import get_datasets, get_datasets_by_name
|
||||
from cognee.modules.pipelines.tasks.Task import Task
|
||||
from cognee.modules.pipelines import run_tasks
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.pipelines.models import PipelineRunStatus
|
||||
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
|
||||
from cognee.modules.pipelines.operations.log_pipeline_status import log_pipeline_status
|
||||
from cognee.tasks.documents import classify_documents, check_permissions_on_documents, extract_chunks_from_documents
|
||||
from cognee.tasks.graph import extract_graph_from_code
|
||||
from cognee.tasks.storage import add_data_points
|
||||
|
||||
logger = logging.getLogger("code_graph_pipeline")
|
||||
|
||||
update_status_lock = asyncio.Lock()
|
||||
|
||||
class PermissionDeniedException(Exception):
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
async def code_graph_pipeline(datasets: Union[str, list[str]] = None, user: User = None):
|
||||
if user is None:
|
||||
user = await get_default_user()
|
||||
|
||||
existing_datasets = await get_datasets(user.id)
|
||||
|
||||
if datasets is None or len(datasets) == 0:
|
||||
# If no datasets are provided, cognify all existing datasets.
|
||||
datasets = existing_datasets
|
||||
|
||||
if type(datasets[0]) == str:
|
||||
datasets = await get_datasets_by_name(datasets, user.id)
|
||||
|
||||
existing_datasets_map = {
|
||||
generate_dataset_name(dataset.name): True for dataset in existing_datasets
|
||||
}
|
||||
|
||||
awaitables = []
|
||||
|
||||
for dataset in datasets:
|
||||
dataset_name = generate_dataset_name(dataset.name)
|
||||
|
||||
if dataset_name in existing_datasets_map:
|
||||
awaitables.append(run_pipeline(dataset, user))
|
||||
|
||||
return await asyncio.gather(*awaitables)
|
||||
|
||||
|
||||
async def run_pipeline(dataset: Dataset, user: User):
|
||||
data_documents: list[Data] = await get_dataset_data(dataset_id = dataset.id)
|
||||
|
||||
document_ids_str = [str(document.id) for document in data_documents]
|
||||
|
||||
dataset_id = dataset.id
|
||||
dataset_name = generate_dataset_name(dataset.name)
|
||||
|
||||
send_telemetry("code_graph_pipeline EXECUTION STARTED", user.id)
|
||||
|
||||
async with update_status_lock:
|
||||
task_status = await get_pipeline_status([dataset_id])
|
||||
|
||||
if dataset_id in task_status and task_status[dataset_id] == PipelineRunStatus.DATASET_PROCESSING_STARTED:
|
||||
logger.info("Dataset %s is already being processed.", dataset_name)
|
||||
return
|
||||
|
||||
await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_STARTED, {
|
||||
"dataset_name": dataset_name,
|
||||
"files": document_ids_str,
|
||||
})
|
||||
try:
|
||||
tasks = [
|
||||
Task(classify_documents),
|
||||
Task(check_permissions_on_documents, user = user, permissions = ["write"]),
|
||||
Task(extract_chunks_from_documents), # Extract text chunks based on the document type.
|
||||
Task(add_data_points, task_config = { "batch_size": 10 }),
|
||||
Task(extract_graph_from_code, graph_model = SourceCodeGraph, task_config = { "batch_size": 10 }), # Generate knowledge graphs from the document chunks.
|
||||
]
|
||||
|
||||
pipeline = run_tasks(tasks, data_documents, "code_graph_pipeline")
|
||||
|
||||
async for result in pipeline:
|
||||
print(result)
|
||||
|
||||
send_telemetry("code_graph_pipeline EXECUTION COMPLETED", user.id)
|
||||
|
||||
await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_COMPLETED, {
|
||||
"dataset_name": dataset_name,
|
||||
"files": document_ids_str,
|
||||
})
|
||||
except Exception as error:
|
||||
send_telemetry("code_graph_pipeline EXECUTION ERRORED", user.id)
|
||||
|
||||
await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_ERRORED, {
|
||||
"dataset_name": dataset_name,
|
||||
"files": document_ids_str,
|
||||
})
|
||||
raise error
|
||||
|
||||
|
||||
def generate_dataset_name(dataset_name: str) -> str:
|
||||
return dataset_name.replace(".", "_").replace(" ", "_")
|
||||
|
|
@ -9,21 +9,15 @@ from cognee.modules.data.models import Dataset, Data
|
|||
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
|
||||
from cognee.modules.data.methods import get_datasets, get_datasets_by_name
|
||||
from cognee.modules.pipelines.tasks.Task import Task
|
||||
from cognee.modules.pipelines import run_tasks, run_tasks_parallel
|
||||
from cognee.modules.pipelines import run_tasks
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.pipelines.models import PipelineRunStatus
|
||||
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
|
||||
from cognee.modules.pipelines.operations.log_pipeline_status import log_pipeline_status
|
||||
from cognee.tasks import chunk_naive_llm_classifier, \
|
||||
chunk_remove_disconnected, \
|
||||
infer_data_ontology, \
|
||||
save_chunks_to_store, \
|
||||
chunk_update_check, \
|
||||
chunks_into_graph, \
|
||||
source_documents_to_chunks, \
|
||||
check_permissions_on_documents, \
|
||||
classify_documents
|
||||
from cognee.tasks.documents import classify_documents, check_permissions_on_documents, extract_chunks_from_documents
|
||||
from cognee.tasks.graph import extract_graph_from_data
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.tasks.summarization import summarize_text
|
||||
|
||||
logger = logging.getLogger("cognify.v2")
|
||||
|
|
@ -87,31 +81,17 @@ async def run_cognify_pipeline(dataset: Dataset, user: User):
|
|||
try:
|
||||
cognee_config = get_cognify_config()
|
||||
|
||||
root_node_id = None
|
||||
|
||||
tasks = [
|
||||
Task(classify_documents),
|
||||
Task(check_permissions_on_documents, user = user, permissions = ["write"]),
|
||||
Task(infer_data_ontology, root_node_id = root_node_id, ontology_model = KnowledgeGraph),
|
||||
Task(source_documents_to_chunks, parent_node_id = root_node_id), # Classify documents and save them as a nodes in graph db, extract text chunks based on the document type
|
||||
Task(chunks_into_graph, graph_model = KnowledgeGraph, collection_name = "entities", task_config = { "batch_size": 10 }), # Generate knowledge graphs from the document chunks and attach it to chunk nodes
|
||||
Task(chunk_update_check, collection_name = "chunks"), # Find all affected chunks, so we don't process unchanged chunks
|
||||
Task(extract_chunks_from_documents), # Extract text chunks based on the document type.
|
||||
Task(add_data_points, task_config = { "batch_size": 10 }),
|
||||
Task(extract_graph_from_data, graph_model = KnowledgeGraph, task_config = { "batch_size": 10 }), # Generate knowledge graphs from the document chunks.
|
||||
Task(
|
||||
save_chunks_to_store,
|
||||
collection_name = "chunks",
|
||||
), # Save the document chunks in vector db and as nodes in graph db (connected to the document node and between each other)
|
||||
run_tasks_parallel([
|
||||
Task(
|
||||
summarize_text,
|
||||
summarization_model = cognee_config.summarization_model,
|
||||
collection_name = "summaries",
|
||||
),
|
||||
Task(
|
||||
chunk_naive_llm_classifier,
|
||||
classification_model = cognee_config.classification_model,
|
||||
),
|
||||
]),
|
||||
Task(chunk_remove_disconnected), # Remove the obsolete document chunks.
|
||||
summarize_text,
|
||||
summarization_model = cognee_config.summarization_model,
|
||||
task_config = { "batch_size": 10 }
|
||||
),
|
||||
]
|
||||
|
||||
pipeline = run_tasks(tasks, data_documents, "cognify_pipeline")
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from cognee.shared.utils import send_telemetry
|
|||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.users.permissions.methods import get_document_ids_for_user
|
||||
from cognee.tasks.chunking import query_chunks
|
||||
from cognee.tasks.chunks import query_chunks
|
||||
from cognee.tasks.graph import query_graph_connections
|
||||
from cognee.tasks.summarization import query_summaries
|
||||
|
||||
|
|
|
|||
|
|
@ -1,198 +0,0 @@
|
|||
""" FalcorDB Adapter for Graph Database"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Any, List, Dict
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
|
||||
from falkordb.asyncio import FalkorDB
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
||||
|
||||
logger = logging.getLogger("FalcorDBAdapter")
|
||||
|
||||
class FalcorDBAdapter(GraphDBInterface):
|
||||
def __init__(
|
||||
self,
|
||||
graph_database_url: str,
|
||||
graph_database_username: str,
|
||||
graph_database_password: str,
|
||||
graph_database_port: int,
|
||||
driver: Optional[Any] = None,
|
||||
graph_name: str = "DefaultGraph",
|
||||
):
|
||||
self.driver = FalkorDB(
|
||||
host = graph_database_url,
|
||||
port = graph_database_port)
|
||||
self.graph_name = graph_name
|
||||
|
||||
|
||||
|
||||
async def query(
|
||||
self,
|
||||
query: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
try:
|
||||
selected_graph = self.driver.select_graph(self.graph_name)
|
||||
|
||||
result = await selected_graph.query(query)
|
||||
return result.result_set
|
||||
|
||||
except Exception as error:
|
||||
logger.error("Falkor query error: %s", error, exc_info = True)
|
||||
raise error
|
||||
|
||||
async def graph(self):
|
||||
return self.driver
|
||||
|
||||
async def add_node(self, node_id: str, node_properties: Dict[str, Any] = None):
|
||||
node_id = node_id.replace(":", "_")
|
||||
|
||||
serialized_properties = self.serialize_properties(node_properties)
|
||||
|
||||
if "name" not in serialized_properties:
|
||||
serialized_properties["name"] = node_id
|
||||
|
||||
# serialized_properties["created_at"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
# serialized_properties["updated_at"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
# properties = ", ".join(f"{property_name}: ${property_name}" for property_name in serialized_properties.keys())
|
||||
|
||||
query = f"""MERGE (node:`{node_id}` {{id: $node_id}})
|
||||
ON CREATE SET node += $properties
|
||||
RETURN ID(node) AS internal_id, node.id AS nodeId"""
|
||||
|
||||
params = {
|
||||
"node_id": node_id,
|
||||
"properties": serialized_properties,
|
||||
}
|
||||
|
||||
return await self.query(query, params)
|
||||
|
||||
async def add_nodes(self, nodes: list[tuple[str, dict[str, Any]]]) -> None:
|
||||
for node in nodes:
|
||||
node_id, node_properties = node
|
||||
node_id = node_id.replace(":", "_")
|
||||
|
||||
await self.add_node(
|
||||
node_id = node_id,
|
||||
node_properties = node_properties,
|
||||
)
|
||||
|
||||
|
||||
|
||||
async def extract_node_description(self, node_id: str):
|
||||
query = """MATCH (n)-[r]->(m)
|
||||
WHERE n.id = $node_id
|
||||
AND NOT m.id CONTAINS 'DefaultGraphModel'
|
||||
RETURN m
|
||||
"""
|
||||
|
||||
result = await self.query(query, dict(node_id = node_id))
|
||||
|
||||
descriptions = []
|
||||
|
||||
for node in result:
|
||||
# Assuming 'm' is a consistent key in your data structure
|
||||
attributes = node.get("m", {})
|
||||
|
||||
# Ensure all required attributes are present
|
||||
if all(key in attributes for key in ["id", "layer_id", "description"]):
|
||||
descriptions.append({
|
||||
"id": attributes["id"],
|
||||
"layer_id": attributes["layer_id"],
|
||||
"description": attributes["description"],
|
||||
})
|
||||
|
||||
return descriptions
|
||||
|
||||
async def get_layer_nodes(self):
|
||||
query = """MATCH (node) WHERE node.layer_id IS NOT NULL
|
||||
RETURN node"""
|
||||
|
||||
return [result["node"] for result in (await self.query(query))]
|
||||
|
||||
async def extract_node(self, node_id: str):
|
||||
results = self.extract_nodes([node_id])
|
||||
|
||||
return results[0] if len(results) > 0 else None
|
||||
|
||||
async def extract_nodes(self, node_ids: List[str]):
|
||||
query = """
|
||||
UNWIND $node_ids AS id
|
||||
MATCH (node {id: id})
|
||||
RETURN node"""
|
||||
|
||||
params = {
|
||||
"node_ids": node_ids
|
||||
}
|
||||
|
||||
results = await self.query(query, params)
|
||||
|
||||
return results
|
||||
|
||||
async def delete_node(self, node_id: str):
|
||||
node_id = id.replace(":", "_")
|
||||
|
||||
query = f"MATCH (node:`{node_id}` {{id: $node_id}}) DETACH DELETE n"
|
||||
params = { "node_id": node_id }
|
||||
|
||||
return await self.query(query, params)
|
||||
|
||||
async def add_edge(self, from_node: str, to_node: str, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}):
|
||||
serialized_properties = self.serialize_properties(edge_properties)
|
||||
from_node = from_node.replace(":", "_")
|
||||
to_node = to_node.replace(":", "_")
|
||||
|
||||
query = f"""MATCH (from_node:`{from_node}` {{id: $from_node}}), (to_node:`{to_node}` {{id: $to_node}})
|
||||
MERGE (from_node)-[r:`{relationship_name}`]->(to_node)
|
||||
SET r += $properties
|
||||
RETURN r"""
|
||||
|
||||
params = {
|
||||
"from_node": from_node,
|
||||
"to_node": to_node,
|
||||
"properties": serialized_properties
|
||||
}
|
||||
|
||||
return await self.query(query, params)
|
||||
|
||||
|
||||
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
|
||||
# edges_data = []
|
||||
|
||||
for edge in edges:
|
||||
from_node, to_node, relationship_name, edge_properties = edge
|
||||
from_node = from_node.replace(":", "_")
|
||||
to_node = to_node.replace(":", "_")
|
||||
|
||||
await self.add_edge(
|
||||
from_node = from_node,
|
||||
to_node = to_node,
|
||||
relationship_name = relationship_name,
|
||||
edge_properties = edge_properties
|
||||
)
|
||||
|
||||
|
||||
|
||||
async def filter_nodes(self, search_criteria):
|
||||
query = f"""MATCH (node)
|
||||
WHERE node.id CONTAINS '{search_criteria}'
|
||||
RETURN node"""
|
||||
|
||||
|
||||
return await self.query(query)
|
||||
|
||||
|
||||
async def delete_graph(self):
|
||||
query = """MATCH (node)
|
||||
DETACH DELETE node;"""
|
||||
|
||||
return await self.query(query)
|
||||
|
||||
def serialize_properties(self, properties = dict()):
|
||||
return {
|
||||
property_key: json.dumps(property_value)
|
||||
if isinstance(property_value, (dict, list))
|
||||
else property_value for property_key, property_value in properties.items()
|
||||
}
|
||||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
from .config import get_graph_config
|
||||
from .graph_db_interface import GraphDBInterface
|
||||
from .networkx.adapter import NetworkXAdapter
|
||||
|
||||
|
||||
async def get_graph_engine() -> GraphDBInterface :
|
||||
|
|
@ -21,19 +20,19 @@ async def get_graph_engine() -> GraphDBInterface :
|
|||
except:
|
||||
pass
|
||||
|
||||
elif config.graph_database_provider == "falkorb":
|
||||
try:
|
||||
from .falkordb.adapter import FalcorDBAdapter
|
||||
elif config.graph_database_provider == "falkordb":
|
||||
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
|
||||
from cognee.infrastructure.databases.hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
|
||||
|
||||
return FalcorDBAdapter(
|
||||
graph_database_url = config.graph_database_url,
|
||||
graph_database_username = config.graph_database_username,
|
||||
graph_database_password = config.graph_database_password,
|
||||
graph_database_port = config.graph_database_port
|
||||
)
|
||||
except:
|
||||
pass
|
||||
embedding_engine = get_embedding_engine()
|
||||
|
||||
return FalkorDBAdapter(
|
||||
database_url = config.graph_database_url,
|
||||
database_port = config.graph_database_port,
|
||||
embedding_engine = embedding_engine,
|
||||
)
|
||||
|
||||
from .networkx.adapter import NetworkXAdapter
|
||||
graph_client = NetworkXAdapter(filename = config.graph_file_path)
|
||||
|
||||
if graph_client.graph is None:
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from abc import abstractmethod
|
|||
|
||||
class GraphDBInterface(Protocol):
|
||||
@abstractmethod
|
||||
async def graph(self):
|
||||
async def query(self, query: str, params: dict):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
|||
|
|
@ -1,13 +1,14 @@
|
|||
""" Neo4j Adapter for Graph Database"""
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
from textwrap import dedent
|
||||
from typing import Optional, Any, List, Dict
|
||||
from contextlib import asynccontextmanager
|
||||
from uuid import UUID
|
||||
from neo4j import AsyncSession
|
||||
from neo4j import AsyncGraphDatabase
|
||||
from neo4j.exceptions import Neo4jError
|
||||
from networkx import predecessor
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
||||
|
||||
logger = logging.getLogger("Neo4jAdapter")
|
||||
|
|
@ -41,17 +42,13 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
) -> List[Dict[str, Any]]:
|
||||
try:
|
||||
async with self.get_session() as session:
|
||||
result = await session.run(query, parameters=params)
|
||||
result = await session.run(query, parameters = params)
|
||||
data = await result.data()
|
||||
await self.close()
|
||||
return data
|
||||
except Neo4jError as error:
|
||||
logger.error("Neo4j query error: %s", error, exc_info = True)
|
||||
raise error
|
||||
|
||||
async def graph(self):
|
||||
return await self.get_session()
|
||||
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
results = self.query(
|
||||
"""
|
||||
|
|
@ -63,73 +60,40 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
)
|
||||
return results[0]["node_exists"] if len(results) > 0 else False
|
||||
|
||||
async def add_node(self, node_id: str, node_properties: Dict[str, Any] = None):
|
||||
node_id = node_id.replace(":", "_")
|
||||
async def add_node(self, node: DataPoint):
|
||||
serialized_properties = self.serialize_properties(node.model_dump())
|
||||
|
||||
serialized_properties = self.serialize_properties(node_properties)
|
||||
|
||||
if "name" not in serialized_properties:
|
||||
serialized_properties["name"] = node_id
|
||||
|
||||
query = f"""MERGE (node:`{node_id}` {{id: $node_id}})
|
||||
ON CREATE SET node += $properties
|
||||
RETURN ID(node) AS internal_id, node.id AS nodeId"""
|
||||
query = dedent("""MERGE (node {id: $node_id})
|
||||
ON CREATE SET node += $properties, node.updated_at = timestamp()
|
||||
ON MATCH SET node += $properties, node.updated_at = timestamp()
|
||||
RETURN ID(node) AS internal_id, node.id AS nodeId""")
|
||||
|
||||
params = {
|
||||
"node_id": node_id,
|
||||
"node_id": str(node.id),
|
||||
"properties": serialized_properties,
|
||||
}
|
||||
|
||||
return await self.query(query, params)
|
||||
|
||||
async def add_nodes(self, nodes: list[tuple[str, dict[str, Any]]]) -> None:
|
||||
async def add_nodes(self, nodes: list[DataPoint]) -> None:
|
||||
query = """
|
||||
UNWIND $nodes AS node
|
||||
MERGE (n {id: node.node_id})
|
||||
ON CREATE SET n += node.properties
|
||||
ON CREATE SET n += node.properties, n.updated_at = timestamp()
|
||||
ON MATCH SET n += node.properties, n.updated_at = timestamp()
|
||||
WITH n, node.node_id AS label
|
||||
CALL apoc.create.addLabels(n, [label]) YIELD node AS labeledNode
|
||||
RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId
|
||||
"""
|
||||
|
||||
nodes = [{
|
||||
"node_id": node_id,
|
||||
"properties": self.serialize_properties(node_properties),
|
||||
} for (node_id, node_properties) in nodes]
|
||||
"node_id": str(node.id),
|
||||
"properties": self.serialize_properties(node.model_dump()),
|
||||
} for node in nodes]
|
||||
|
||||
results = await self.query(query, dict(nodes = nodes))
|
||||
return results
|
||||
|
||||
async def extract_node_description(self, node_id: str):
|
||||
query = """MATCH (n)-[r]->(m)
|
||||
WHERE n.id = $node_id
|
||||
AND NOT m.id CONTAINS 'DefaultGraphModel'
|
||||
RETURN m
|
||||
"""
|
||||
|
||||
result = await self.query(query, dict(node_id = node_id))
|
||||
|
||||
descriptions = []
|
||||
|
||||
for node in result:
|
||||
# Assuming 'm' is a consistent key in your data structure
|
||||
attributes = node.get("m", {})
|
||||
|
||||
# Ensure all required attributes are present
|
||||
if all(key in attributes for key in ["id", "layer_id", "description"]):
|
||||
descriptions.append({
|
||||
"id": attributes["id"],
|
||||
"layer_id": attributes["layer_id"],
|
||||
"description": attributes["description"],
|
||||
})
|
||||
|
||||
return descriptions
|
||||
|
||||
async def get_layer_nodes(self):
|
||||
query = """MATCH (node) WHERE node.layer_id IS NOT NULL
|
||||
RETURN node"""
|
||||
|
||||
return [result["node"] for result in (await self.query(query))]
|
||||
|
||||
async def extract_node(self, node_id: str):
|
||||
results = await self.extract_nodes([node_id])
|
||||
|
|
@ -170,13 +134,20 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
return await self.query(query, params)
|
||||
|
||||
async def has_edge(self, from_node: str, to_node: str, edge_label: str) -> bool:
|
||||
query = f"""
|
||||
MATCH (from_node:`{from_node}`)-[relationship:`{edge_label}`]->(to_node:`{to_node}`)
|
||||
async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool:
|
||||
query = """
|
||||
MATCH (from_node)-[relationship]->(to_node)
|
||||
WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id AND type(relationship) = $edge_label
|
||||
RETURN COUNT(relationship) > 0 AS edge_exists
|
||||
"""
|
||||
|
||||
edge_exists = await self.query(query)
|
||||
params = {
|
||||
"from_node_id": str(from_node),
|
||||
"to_node_id": str(to_node),
|
||||
"edge_label": edge_label,
|
||||
}
|
||||
|
||||
edge_exists = await self.query(query, params)
|
||||
return edge_exists
|
||||
|
||||
async def has_edges(self, edges):
|
||||
|
|
@ -190,8 +161,8 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
try:
|
||||
params = {
|
||||
"edges": [{
|
||||
"from_node": edge[0],
|
||||
"to_node": edge[1],
|
||||
"from_node": str(edge[0]),
|
||||
"to_node": str(edge[1]),
|
||||
"relationship_name": edge[2],
|
||||
} for edge in edges],
|
||||
}
|
||||
|
|
@ -203,21 +174,21 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
raise error
|
||||
|
||||
|
||||
async def add_edge(self, from_node: str, to_node: str, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}):
|
||||
async def add_edge(self, from_node: UUID, to_node: UUID, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}):
|
||||
serialized_properties = self.serialize_properties(edge_properties)
|
||||
from_node = from_node.replace(":", "_")
|
||||
to_node = to_node.replace(":", "_")
|
||||
|
||||
query = f"""MATCH (from_node:`{from_node}`
|
||||
{{id: $from_node}}),
|
||||
(to_node:`{to_node}` {{id: $to_node}})
|
||||
MERGE (from_node)-[r:`{relationship_name}`]->(to_node)
|
||||
SET r += $properties
|
||||
RETURN r"""
|
||||
query = dedent("""MATCH (from_node {id: $from_node}),
|
||||
(to_node {id: $to_node})
|
||||
MERGE (from_node)-[r]->(to_node)
|
||||
ON CREATE SET r += $properties, r.updated_at = timestamp(), r.type = $relationship_name
|
||||
ON MATCH SET r += $properties, r.updated_at = timestamp()
|
||||
RETURN r
|
||||
""")
|
||||
|
||||
params = {
|
||||
"from_node": from_node,
|
||||
"to_node": to_node,
|
||||
"from_node": str(from_node),
|
||||
"to_node": str(to_node),
|
||||
"relationship_name": relationship_name,
|
||||
"properties": serialized_properties
|
||||
}
|
||||
|
||||
|
|
@ -234,13 +205,13 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
"""
|
||||
|
||||
edges = [{
|
||||
"from_node": edge[0],
|
||||
"to_node": edge[1],
|
||||
"from_node": str(edge[0]),
|
||||
"to_node": str(edge[1]),
|
||||
"relationship_name": edge[2],
|
||||
"properties": {
|
||||
**(edge[3] if edge[3] else {}),
|
||||
"source_node_id": edge[0],
|
||||
"target_node_id": edge[1],
|
||||
"source_node_id": str(edge[0]),
|
||||
"target_node_id": str(edge[1]),
|
||||
},
|
||||
} for edge in edges]
|
||||
|
||||
|
|
@ -300,14 +271,6 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
return results[0]["ids"] if len(results) > 0 else []
|
||||
|
||||
|
||||
async def filter_nodes(self, search_criteria):
|
||||
query = f"""MATCH (node)
|
||||
WHERE node.id CONTAINS '{search_criteria}'
|
||||
RETURN node"""
|
||||
|
||||
return await self.query(query)
|
||||
|
||||
|
||||
async def get_predecessors(self, node_id: str, edge_label: str = None) -> list[str]:
|
||||
if edge_label is not None:
|
||||
query = """
|
||||
|
|
@ -379,7 +342,7 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
return predecessors + successors
|
||||
|
||||
async def get_connections(self, node_id: str) -> list:
|
||||
async def get_connections(self, node_id: UUID) -> list:
|
||||
predecessors_query = """
|
||||
MATCH (node)<-[relation]-(neighbour)
|
||||
WHERE node.id = $node_id
|
||||
|
|
@ -392,8 +355,8 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
"""
|
||||
|
||||
predecessors, successors = await asyncio.gather(
|
||||
self.query(predecessors_query, dict(node_id = node_id)),
|
||||
self.query(successors_query, dict(node_id = node_id)),
|
||||
self.query(predecessors_query, dict(node_id = str(node_id))),
|
||||
self.query(successors_query, dict(node_id = str(node_id))),
|
||||
)
|
||||
|
||||
connections = []
|
||||
|
|
@ -438,15 +401,22 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
return await self.query(query)
|
||||
|
||||
def serialize_properties(self, properties = dict()):
|
||||
return {
|
||||
property_key: json.dumps(property_value)
|
||||
if isinstance(property_value, (dict, list))
|
||||
else property_value for property_key, property_value in properties.items()
|
||||
}
|
||||
serialized_properties = {}
|
||||
|
||||
for property_key, property_value in properties.items():
|
||||
if isinstance(property_value, UUID):
|
||||
serialized_properties[property_key] = str(property_value)
|
||||
continue
|
||||
|
||||
serialized_properties[property_key] = property_value
|
||||
|
||||
return serialized_properties
|
||||
|
||||
async def get_graph_data(self):
|
||||
query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties"
|
||||
|
||||
result = await self.query(query)
|
||||
|
||||
nodes = [(
|
||||
record["properties"]["id"],
|
||||
record["properties"],
|
||||
|
|
|
|||
|
|
@ -1,14 +1,19 @@
|
|||
"""Adapter for NetworkX graph database."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
import logging
|
||||
from re import A
|
||||
from typing import Dict, Any, List
|
||||
from uuid import UUID
|
||||
import aiofiles
|
||||
import aiofiles.os as aiofiles_os
|
||||
import networkx as nx
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.storage.utils import JSONEncoder
|
||||
|
||||
logger = logging.getLogger("NetworkXAdapter")
|
||||
|
||||
|
|
@ -25,29 +30,38 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
def __init__(self, filename = "cognee_graph.pkl"):
|
||||
self.filename = filename
|
||||
|
||||
async def get_graph_data(self):
|
||||
await self.load_graph_from_file()
|
||||
return (list(self.graph.nodes(data = True)), list(self.graph.edges(data = True, keys = True)))
|
||||
|
||||
async def query(self, query: str, params: dict):
|
||||
pass
|
||||
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
return self.graph.has_node(node_id)
|
||||
|
||||
async def add_node(
|
||||
self,
|
||||
node_id: str,
|
||||
node_properties,
|
||||
node: DataPoint,
|
||||
) -> None:
|
||||
if not self.graph.has_node(id):
|
||||
self.graph.add_node(node_id, **node_properties)
|
||||
await self.save_graph_to_file(self.filename)
|
||||
self.graph.add_node(node.id, **node.model_dump())
|
||||
|
||||
await self.save_graph_to_file(self.filename)
|
||||
|
||||
async def add_nodes(
|
||||
self,
|
||||
nodes: List[tuple[str, dict]],
|
||||
nodes: list[DataPoint],
|
||||
) -> None:
|
||||
nodes = [(node.id, node.model_dump()) for node in nodes]
|
||||
|
||||
self.graph.add_nodes_from(nodes)
|
||||
await self.save_graph_to_file(self.filename)
|
||||
|
||||
|
||||
async def get_graph(self):
|
||||
return self.graph
|
||||
|
||||
|
||||
async def has_edge(self, from_node: str, to_node: str, edge_label: str) -> bool:
|
||||
return self.graph.has_edge(from_node, to_node, key = edge_label)
|
||||
|
||||
|
|
@ -55,18 +69,20 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
result = []
|
||||
|
||||
for (from_node, to_node, edge_label) in edges:
|
||||
if await self.has_edge(from_node, to_node, edge_label):
|
||||
if self.graph.has_edge(from_node, to_node, edge_label):
|
||||
result.append((from_node, to_node, edge_label))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def add_edge(
|
||||
self,
|
||||
from_node: str,
|
||||
to_node: str,
|
||||
relationship_name: str,
|
||||
edge_properties: Dict[str, Any] = None,
|
||||
edge_properties: Dict[str, Any] = {},
|
||||
) -> None:
|
||||
edge_properties["updated_at"] = datetime.now(timezone.utc)
|
||||
self.graph.add_edge(from_node, to_node, key = relationship_name, **(edge_properties if edge_properties else {}))
|
||||
await self.save_graph_to_file(self.filename)
|
||||
|
||||
|
|
@ -74,22 +90,29 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
self,
|
||||
edges: tuple[str, str, str, dict],
|
||||
) -> None:
|
||||
edges = [(edge[0], edge[1], edge[2], {
|
||||
**(edge[3] if len(edge) == 4 else {}),
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
}) for edge in edges]
|
||||
|
||||
self.graph.add_edges_from(edges)
|
||||
await self.save_graph_to_file(self.filename)
|
||||
|
||||
async def get_edges(self, node_id: str):
|
||||
return list(self.graph.in_edges(node_id, data = True)) + list(self.graph.out_edges(node_id, data = True))
|
||||
|
||||
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
"""Asynchronously delete a node from the graph if it exists."""
|
||||
if self.graph.has_node(id):
|
||||
self.graph.remove_node(id)
|
||||
if self.graph.has_node(node_id):
|
||||
self.graph.remove_node(node_id)
|
||||
await self.save_graph_to_file(self.filename)
|
||||
|
||||
async def delete_nodes(self, node_ids: List[str]) -> None:
|
||||
self.graph.remove_nodes_from(node_ids)
|
||||
await self.save_graph_to_file(self.filename)
|
||||
|
||||
|
||||
async def get_disconnected_nodes(self) -> List[str]:
|
||||
connected_components = list(nx.weakly_connected_components(self.graph))
|
||||
|
||||
|
|
@ -102,33 +125,6 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
|
||||
return disconnected_nodes
|
||||
|
||||
async def extract_node_description(self, node_id: str) -> Dict[str, Any]:
|
||||
descriptions = []
|
||||
|
||||
if self.graph.has_node(node_id):
|
||||
# Get the attributes of the node
|
||||
for neighbor in self.graph.neighbors(node_id):
|
||||
# Get the attributes of the neighboring node
|
||||
attributes = self.graph.nodes[neighbor]
|
||||
|
||||
# Ensure all required attributes are present before extracting description
|
||||
if all(key in attributes for key in ["id", "layer_id", "description"]):
|
||||
descriptions.append({
|
||||
"id": attributes["id"],
|
||||
"layer_id": attributes["layer_id"],
|
||||
"description": attributes["description"],
|
||||
})
|
||||
|
||||
return descriptions
|
||||
|
||||
async def get_layer_nodes(self):
|
||||
layer_nodes = []
|
||||
|
||||
for _, data in self.graph.nodes(data = True):
|
||||
if "layer_id" in data:
|
||||
layer_nodes.append(data)
|
||||
|
||||
return layer_nodes
|
||||
|
||||
async def extract_node(self, node_id: str) -> dict:
|
||||
if self.graph.has_node(node_id):
|
||||
|
|
@ -139,7 +135,7 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
async def extract_nodes(self, node_ids: List[str]) -> List[dict]:
|
||||
return [self.graph.nodes[node_id] for node_id in node_ids if self.graph.has_node(node_id)]
|
||||
|
||||
async def get_predecessors(self, node_id: str, edge_label: str = None) -> list:
|
||||
async def get_predecessors(self, node_id: UUID, edge_label: str = None) -> list:
|
||||
if self.graph.has_node(node_id):
|
||||
if edge_label is None:
|
||||
return [
|
||||
|
|
@ -155,7 +151,7 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
|
||||
return nodes
|
||||
|
||||
async def get_successors(self, node_id: str, edge_label: str = None) -> list:
|
||||
async def get_successors(self, node_id: UUID, edge_label: str = None) -> list:
|
||||
if self.graph.has_node(node_id):
|
||||
if edge_label is None:
|
||||
return [
|
||||
|
|
@ -184,13 +180,13 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
|
||||
return neighbours
|
||||
|
||||
async def get_connections(self, node_id: str) -> list:
|
||||
async def get_connections(self, node_id: UUID) -> list:
|
||||
if not self.graph.has_node(node_id):
|
||||
return []
|
||||
|
||||
node = self.graph.nodes[node_id]
|
||||
|
||||
if "uuid" not in node:
|
||||
if "id" not in node:
|
||||
return []
|
||||
|
||||
predecessors, successors = await asyncio.gather(
|
||||
|
|
@ -201,14 +197,14 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
connections = []
|
||||
|
||||
for neighbor in predecessors:
|
||||
if "uuid" in neighbor:
|
||||
edge_data = self.graph.get_edge_data(neighbor["uuid"], node["uuid"])
|
||||
if "id" in neighbor:
|
||||
edge_data = self.graph.get_edge_data(neighbor["id"], node["id"])
|
||||
for edge_properties in edge_data.values():
|
||||
connections.append((neighbor, edge_properties, node))
|
||||
|
||||
for neighbor in successors:
|
||||
if "uuid" in neighbor:
|
||||
edge_data = self.graph.get_edge_data(node["uuid"], neighbor["uuid"])
|
||||
if "id" in neighbor:
|
||||
edge_data = self.graph.get_edge_data(node["id"], neighbor["id"])
|
||||
for edge_properties in edge_data.values():
|
||||
connections.append((node, edge_properties, neighbor))
|
||||
|
||||
|
|
@ -240,7 +236,7 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
graph_data = nx.readwrite.json_graph.node_link_data(self.graph)
|
||||
|
||||
async with aiofiles.open(file_path, "w") as file:
|
||||
await file.write(json.dumps(graph_data))
|
||||
await file.write(json.dumps(graph_data, cls = JSONEncoder))
|
||||
|
||||
|
||||
async def load_graph_from_file(self, file_path: str = None):
|
||||
|
|
@ -254,6 +250,29 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
if os.path.exists(file_path):
|
||||
async with aiofiles.open(file_path, "r") as file:
|
||||
graph_data = json.loads(await file.read())
|
||||
for node in graph_data["nodes"]:
|
||||
try:
|
||||
node["id"] = UUID(node["id"])
|
||||
except:
|
||||
pass
|
||||
if "updated_at" in node:
|
||||
node["updated_at"] = datetime.strptime(node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
|
||||
|
||||
for edge in graph_data["links"]:
|
||||
try:
|
||||
source_id = UUID(edge["source"])
|
||||
target_id = UUID(edge["target"])
|
||||
|
||||
edge["source"] = source_id
|
||||
edge["target"] = target_id
|
||||
edge["source_node_id"] = source_id
|
||||
edge["target_node_id"] = target_id
|
||||
except:
|
||||
pass
|
||||
|
||||
if "updated_at" in edge:
|
||||
edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
|
||||
|
||||
self.graph = nx.readwrite.json_graph.node_link_graph(graph_data)
|
||||
else:
|
||||
# Log that the file does not exist and an empty graph is initialized
|
||||
|
|
@ -265,9 +284,11 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
os.makedirs(file_dir, exist_ok = True)
|
||||
|
||||
await self.save_graph_to_file(file_path)
|
||||
|
||||
except Exception:
|
||||
logger.error("Failed to load graph from file: %s", file_path)
|
||||
|
||||
|
||||
async def delete_graph(self, file_path: str = None):
|
||||
"""Asynchronously delete the graph file from the filesystem."""
|
||||
if file_path is None:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,267 @@
|
|||
import asyncio
|
||||
from textwrap import dedent
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
from falkordb import FalkorDB
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
||||
from cognee.infrastructure.databases.vector.embeddings import EmbeddingEngine
|
||||
from cognee.infrastructure.databases.vector.vector_db_interface import VectorDBInterface
|
||||
|
||||
class IndexSchema(DataPoint):
|
||||
text: str
|
||||
|
||||
_metadata: dict = {
|
||||
"index_fields": ["text"]
|
||||
}
|
||||
|
||||
class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
||||
def __init__(
|
||||
self,
|
||||
database_url: str,
|
||||
database_port: int,
|
||||
embedding_engine = EmbeddingEngine,
|
||||
):
|
||||
self.driver = FalkorDB(
|
||||
host = database_url,
|
||||
port = database_port,
|
||||
)
|
||||
self.embedding_engine = embedding_engine
|
||||
self.graph_name = "cognee_graph"
|
||||
|
||||
def query(self, query: str, params: dict = {}):
|
||||
graph = self.driver.select_graph(self.graph_name)
|
||||
|
||||
try:
|
||||
result = graph.query(query, params)
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f"Error executing query: {e}")
|
||||
raise e
|
||||
|
||||
async def embed_data(self, data: list[str]) -> list[list[float]]:
|
||||
return await self.embedding_engine.embed_text(data)
|
||||
|
||||
async def stringify_properties(self, properties: dict, vectorize_fields = []) -> str:
|
||||
async def get_value(key, value):
|
||||
return f"'{value}'" if key not in vectorize_fields else await self.get_vectorized_value(value)
|
||||
|
||||
return ",".join([f"{key}:{await get_value(key, value)}" for key, value in properties.items()])
|
||||
|
||||
async def get_vectorized_value(self, value: Any) -> str:
|
||||
vector = (await self.embed_data([value]))[0]
|
||||
return f"vecf32({vector})"
|
||||
|
||||
async def create_data_point_query(self, data_point: DataPoint):
|
||||
node_label = type(data_point).__name__
|
||||
node_properties = await self.stringify_properties(
|
||||
data_point.model_dump(),
|
||||
data_point._metadata["index_fields"],
|
||||
# data_point._metadata["index_fields"] if hasattr(data_point, "_metadata") else [],
|
||||
)
|
||||
|
||||
return dedent(f"""
|
||||
MERGE (node:{node_label} {{id: '{str(data_point.id)}'}})
|
||||
ON CREATE SET node += ({{{node_properties}}})
|
||||
ON CREATE SET node.updated_at = timestamp()
|
||||
ON MATCH SET node += ({{{node_properties}}})
|
||||
ON MATCH SET node.updated_at = timestamp()
|
||||
""").strip()
|
||||
|
||||
async def create_edge_query(self, edge: tuple[str, str, str, dict]) -> str:
|
||||
properties = await self.stringify_properties(edge[3])
|
||||
properties = f"{{{properties}}}"
|
||||
|
||||
return dedent(f"""
|
||||
MERGE (source {{id:'{edge[0]}'}})
|
||||
MERGE (target {{id: '{edge[1]}'}})
|
||||
MERGE (source)-[edge:{edge[2]} {properties}]->(target)
|
||||
ON MATCH SET edge.updated_at = timestamp()
|
||||
ON CREATE SET edge.updated_at = timestamp()
|
||||
""").strip()
|
||||
|
||||
async def create_collection(self, collection_name: str):
|
||||
pass
|
||||
|
||||
async def has_collection(self, collection_name: str) -> bool:
|
||||
collections = self.driver.list_graphs()
|
||||
|
||||
return collection_name in collections
|
||||
|
||||
async def create_data_points(self, data_points: list[DataPoint]):
|
||||
queries = [await self.create_data_point_query(data_point) for data_point in data_points]
|
||||
for query in queries:
|
||||
self.query(query)
|
||||
|
||||
async def create_vector_index(self, index_name: str, index_property_name: str):
|
||||
graph = self.driver.select_graph(self.graph_name)
|
||||
|
||||
if not self.has_vector_index(graph, index_name, index_property_name):
|
||||
graph.create_node_vector_index(index_name, index_property_name, dim = self.embedding_engine.get_vector_size())
|
||||
|
||||
def has_vector_index(self, graph, index_name: str, index_property_name: str) -> bool:
|
||||
try:
|
||||
indices = graph.list_indices()
|
||||
|
||||
return any([(index[0] == index_name and index_property_name in index[1]) for index in indices.result_set])
|
||||
except:
|
||||
return False
|
||||
|
||||
async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
|
||||
pass
|
||||
|
||||
async def add_node(self, node: DataPoint):
|
||||
await self.create_data_points([node])
|
||||
|
||||
async def add_nodes(self, nodes: list[DataPoint]):
|
||||
await self.create_data_points(nodes)
|
||||
|
||||
async def add_edge(self, edge: tuple[str, str, str, dict]):
|
||||
query = await self.create_edge_query(edge)
|
||||
|
||||
self.query(query)
|
||||
|
||||
async def add_edges(self, edges: list[tuple[str, str, str, dict]]):
|
||||
queries = [await self.create_edge_query(edge) for edge in edges]
|
||||
|
||||
for query in queries:
|
||||
self.query(query)
|
||||
|
||||
async def has_edges(self, edges):
|
||||
query = dedent("""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (a)-[r]->(b)
|
||||
WHERE id(a) = edge.from_node AND id(b) = edge.to_node AND type(r) = edge.relationship_name
|
||||
RETURN edge.from_node AS from_node, edge.to_node AS to_node, edge.relationship_name AS relationship_name, count(r) > 0 AS edge_exists
|
||||
""").strip()
|
||||
|
||||
params = {
|
||||
"edges": [{
|
||||
"from_node": str(edge[0]),
|
||||
"to_node": str(edge[1]),
|
||||
"relationship_name": edge[2],
|
||||
} for edge in edges],
|
||||
}
|
||||
|
||||
results = self.query(query, params).result_set
|
||||
|
||||
return [result["edge_exists"] for result in results]
|
||||
|
||||
async def retrieve(self, data_point_ids: list[str]):
|
||||
return self.query(
|
||||
f"MATCH (node) WHERE node.id IN $node_ids RETURN node",
|
||||
{
|
||||
"node_ids": data_point_ids,
|
||||
},
|
||||
)
|
||||
|
||||
async def extract_node(self, data_point_id: str):
|
||||
return await self.retrieve([data_point_id])
|
||||
|
||||
async def extract_nodes(self, data_point_ids: list[str]):
|
||||
return await self.retrieve(data_point_ids)
|
||||
|
||||
async def get_connections(self, node_id: UUID) -> list:
|
||||
predecessors_query = """
|
||||
MATCH (node)<-[relation]-(neighbour)
|
||||
WHERE node.id = $node_id
|
||||
RETURN neighbour, relation, node
|
||||
"""
|
||||
successors_query = """
|
||||
MATCH (node)-[relation]->(neighbour)
|
||||
WHERE node.id = $node_id
|
||||
RETURN node, relation, neighbour
|
||||
"""
|
||||
|
||||
predecessors, successors = await asyncio.gather(
|
||||
self.query(predecessors_query, dict(node_id = node_id)),
|
||||
self.query(successors_query, dict(node_id = node_id)),
|
||||
)
|
||||
|
||||
connections = []
|
||||
|
||||
for neighbour in predecessors:
|
||||
neighbour = neighbour["relation"]
|
||||
connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2]))
|
||||
|
||||
for neighbour in successors:
|
||||
neighbour = neighbour["relation"]
|
||||
connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2]))
|
||||
|
||||
return connections
|
||||
|
||||
async def search(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_text: str = None,
|
||||
query_vector: list[float] = None,
|
||||
limit: int = 10,
|
||||
with_vector: bool = False,
|
||||
):
|
||||
if query_text is None and query_vector is None:
|
||||
raise ValueError("One of query_text or query_vector must be provided!")
|
||||
|
||||
if query_text and not query_vector:
|
||||
query_vector = (await self.embed_data([query_text]))[0]
|
||||
|
||||
query = dedent(f"""
|
||||
CALL db.idx.vector.queryNodes(
|
||||
{collection_name},
|
||||
'text',
|
||||
{limit},
|
||||
vecf32({query_vector})
|
||||
) YIELD node, score
|
||||
""").strip()
|
||||
|
||||
result = self.query(query)
|
||||
|
||||
return result
|
||||
|
||||
async def batch_search(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_texts: list[str],
|
||||
limit: int = None,
|
||||
with_vectors: bool = False,
|
||||
):
|
||||
query_vectors = await self.embedding_engine.embed_text(query_texts)
|
||||
|
||||
return await asyncio.gather(
|
||||
*[self.search(
|
||||
collection_name = collection_name,
|
||||
query_vector = query_vector,
|
||||
limit = limit,
|
||||
with_vector = with_vectors,
|
||||
) for query_vector in query_vectors]
|
||||
)
|
||||
|
||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
||||
return self.query(
|
||||
f"MATCH (node) WHERE node.id IN $node_ids DETACH DELETE node",
|
||||
{
|
||||
"node_ids": data_point_ids,
|
||||
},
|
||||
)
|
||||
|
||||
async def delete_node(self, collection_name: str, data_point_id: str):
|
||||
return await self.delete_data_points([data_point_id])
|
||||
|
||||
async def delete_nodes(self, collection_name: str, data_point_ids: list[str]):
|
||||
self.delete_data_points(data_point_ids)
|
||||
|
||||
async def delete_graph(self):
|
||||
try:
|
||||
graph = self.driver.select_graph(self.graph_name)
|
||||
|
||||
indices = graph.list_indices()
|
||||
for index in indices.result_set:
|
||||
for field in index[1]:
|
||||
graph.drop_node_vector_index(index[0], field)
|
||||
|
||||
graph.delete()
|
||||
except Exception as e:
|
||||
print(f"Error deleting graph: {e}")
|
||||
|
||||
async def prune(self):
|
||||
self.delete_graph()
|
||||
|
|
@ -1,4 +1,3 @@
|
|||
from .models.DataPoint import DataPoint
|
||||
from .models.VectorConfig import VectorConfig
|
||||
from .models.CollectionConfig import CollectionConfig
|
||||
from .vector_db_interface import VectorDBInterface
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ class VectorConfig(BaseSettings):
|
|||
os.path.join(get_absolute_path(".cognee_system"), "databases"),
|
||||
"cognee.lancedb"
|
||||
)
|
||||
vector_db_port: int = 1234
|
||||
vector_db_key: str = ""
|
||||
vector_db_provider: str = "lancedb"
|
||||
|
||||
|
|
@ -16,6 +17,7 @@ class VectorConfig(BaseSettings):
|
|||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"vector_db_url": self.vector_db_url,
|
||||
"vector_db_port": self.vector_db_port,
|
||||
"vector_db_key": self.vector_db_key,
|
||||
"vector_db_provider": self.vector_db_provider,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
from typing import Dict
|
||||
|
||||
from ..relational.config import get_relational_config
|
||||
|
||||
class VectorConfig(Dict):
|
||||
vector_db_url: str
|
||||
vector_db_port: str
|
||||
vector_db_key: str
|
||||
vector_db_provider: str
|
||||
|
||||
|
|
@ -29,6 +28,7 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
|
|||
embedding_engine = embedding_engine
|
||||
)
|
||||
elif config["vector_db_provider"] == "pgvector":
|
||||
from cognee.infrastructure.databases.relational import get_relational_config
|
||||
from .pgvector.PGVectorAdapter import PGVectorAdapter
|
||||
|
||||
# Get configuration for postgres database
|
||||
|
|
@ -43,9 +43,18 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
|
|||
f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
|
||||
)
|
||||
|
||||
return PGVectorAdapter(connection_string,
|
||||
config["vector_db_key"],
|
||||
embedding_engine
|
||||
return PGVectorAdapter(
|
||||
connection_string,
|
||||
config["vector_db_key"],
|
||||
embedding_engine,
|
||||
)
|
||||
elif config["vector_db_provider"] == "falkordb":
|
||||
from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
|
||||
|
||||
return FalkorDBAdapter(
|
||||
database_url = config["vector_db_url"],
|
||||
database_port = config["vector_db_port"],
|
||||
embedding_engine = embedding_engine,
|
||||
)
|
||||
else:
|
||||
from .lancedb.LanceDBAdapter import LanceDBAdapter
|
||||
|
|
|
|||
|
|
@ -1,57 +0,0 @@
|
|||
|
||||
from typing import List, Dict, Optional, Any
|
||||
|
||||
from falkordb import FalkorDB
|
||||
from qdrant_client import AsyncQdrantClient, models
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
from ..models.DataPoint import DataPoint
|
||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
|
||||
|
||||
|
||||
|
||||
class FalcorDBAdapter(VectorDBInterface):
|
||||
def __init__(
|
||||
self,
|
||||
graph_database_url: str,
|
||||
graph_database_username: str,
|
||||
graph_database_password: str,
|
||||
graph_database_port: int,
|
||||
driver: Optional[Any] = None,
|
||||
embedding_engine = EmbeddingEngine,
|
||||
graph_name: str = "DefaultGraph",
|
||||
):
|
||||
self.driver = FalkorDB(
|
||||
host = graph_database_url,
|
||||
port = graph_database_port)
|
||||
self.graph_name = graph_name
|
||||
self.embedding_engine = embedding_engine
|
||||
|
||||
|
||||
|
||||
async def embed_data(self, data: list[str]) -> list[list[float]]:
|
||||
return await self.embedding_engine.embed_text(data)
|
||||
|
||||
|
||||
async def create_collection(self, collection_name: str, payload_schema = None):
|
||||
pass
|
||||
|
||||
|
||||
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
|
||||
pass
|
||||
|
||||
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
||||
pass
|
||||
|
||||
async def search(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_text: str = None,
|
||||
query_vector: List[float] = None,
|
||||
limit: int = 10,
|
||||
with_vector: bool = False,
|
||||
):
|
||||
pass
|
||||
|
||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
||||
pass
|
||||
|
|
@ -1,12 +1,25 @@
|
|||
import inspect
|
||||
from typing import List, Optional, get_type_hints, Generic, TypeVar
|
||||
import asyncio
|
||||
from uuid import UUID
|
||||
import lancedb
|
||||
from pydantic import BaseModel
|
||||
from lancedb.pydantic import Vector, LanceModel
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.files.storage import LocalStorage
|
||||
from cognee.modules.storage.utils import copy_model, get_own_properties
|
||||
from ..models.ScoredResult import ScoredResult
|
||||
from ..vector_db_interface import VectorDBInterface, DataPoint
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
|
||||
class IndexSchema(DataPoint):
|
||||
id: str
|
||||
text: str
|
||||
|
||||
_metadata: dict = {
|
||||
"index_fields": ["text"]
|
||||
}
|
||||
|
||||
class LanceDBAdapter(VectorDBInterface):
|
||||
name = "LanceDB"
|
||||
url: str
|
||||
|
|
@ -38,10 +51,12 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
collection_names = await connection.table_names()
|
||||
return collection_name in collection_names
|
||||
|
||||
async def create_collection(self, collection_name: str, payload_schema = None):
|
||||
data_point_types = get_type_hints(DataPoint)
|
||||
async def create_collection(self, collection_name: str, payload_schema: BaseModel):
|
||||
vector_size = self.embedding_engine.get_vector_size()
|
||||
|
||||
payload_schema = self.get_data_point_schema(payload_schema)
|
||||
data_point_types = get_type_hints(payload_schema)
|
||||
|
||||
class LanceDataPoint(LanceModel):
|
||||
id: data_point_types["id"]
|
||||
vector: Vector(vector_size)
|
||||
|
|
@ -55,13 +70,16 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
exist_ok = True,
|
||||
)
|
||||
|
||||
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
|
||||
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
|
||||
connection = await self.get_connection()
|
||||
|
||||
payload_schema = type(data_points[0])
|
||||
payload_schema = self.get_data_point_schema(payload_schema)
|
||||
|
||||
if not await self.has_collection(collection_name):
|
||||
await self.create_collection(
|
||||
collection_name,
|
||||
payload_schema = type(data_points[0].payload),
|
||||
payload_schema,
|
||||
)
|
||||
|
||||
collection = await connection.open_table(collection_name)
|
||||
|
|
@ -79,15 +97,26 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
vector: Vector(vector_size)
|
||||
payload: PayloadSchema
|
||||
|
||||
def create_lance_data_point(data_point: DataPoint, vector: list[float]) -> LanceDataPoint:
|
||||
properties = get_own_properties(data_point)
|
||||
properties["id"] = str(properties["id"])
|
||||
|
||||
return LanceDataPoint[str, self.get_data_point_schema(type(data_point))](
|
||||
id = str(data_point.id),
|
||||
vector = vector,
|
||||
payload = properties,
|
||||
)
|
||||
|
||||
lance_data_points = [
|
||||
LanceDataPoint[type(data_point.id), type(data_point.payload)](
|
||||
id = data_point.id,
|
||||
vector = data_vectors[data_index],
|
||||
payload = data_point.payload,
|
||||
) for (data_index, data_point) in enumerate(data_points)
|
||||
create_lance_data_point(data_point, data_vectors[data_point_index])
|
||||
for (data_point_index, data_point) in enumerate(data_points)
|
||||
]
|
||||
|
||||
await collection.add(lance_data_points)
|
||||
await collection.merge_insert("id") \
|
||||
.when_matched_update_all() \
|
||||
.when_not_matched_insert_all() \
|
||||
.execute(lance_data_points)
|
||||
|
||||
|
||||
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
||||
connection = await self.get_connection()
|
||||
|
|
@ -99,7 +128,7 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
results = await collection.query().where(f"id IN {tuple(data_point_ids)}").to_pandas()
|
||||
|
||||
return [ScoredResult(
|
||||
id = result["id"],
|
||||
id = UUID(result["id"]),
|
||||
payload = result["payload"],
|
||||
score = 0,
|
||||
) for result in results.to_dict("index").values()]
|
||||
|
|
@ -135,10 +164,19 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
if value < min_value:
|
||||
min_value = value
|
||||
|
||||
normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in result_values]
|
||||
normalized_values = []
|
||||
min_value = min(result["_distance"] for result in result_values)
|
||||
max_value = max(result["_distance"] for result in result_values)
|
||||
|
||||
if max_value == min_value:
|
||||
# Avoid division by zero: Assign all normalized values to 0 (or any constant value like 1)
|
||||
normalized_values = [0 for _ in result_values]
|
||||
else:
|
||||
normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in
|
||||
result_values]
|
||||
|
||||
return [ScoredResult(
|
||||
id = str(result["id"]),
|
||||
id = UUID(result["id"]),
|
||||
payload = result["payload"],
|
||||
score = normalized_values[value_index],
|
||||
) for value_index, result in enumerate(result_values)]
|
||||
|
|
@ -170,7 +208,27 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
results = await collection.delete(f"id IN {tuple(data_point_ids)}")
|
||||
return results
|
||||
|
||||
async def create_vector_index(self, index_name: str, index_property_name: str):
|
||||
await self.create_collection(f"{index_name}_{index_property_name}", payload_schema = IndexSchema)
|
||||
|
||||
async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
|
||||
await self.create_data_points(f"{index_name}_{index_property_name}", [
|
||||
IndexSchema(
|
||||
id = str(data_point.id),
|
||||
text = getattr(data_point, data_point._metadata["index_fields"][0]),
|
||||
) for data_point in data_points
|
||||
])
|
||||
|
||||
async def prune(self):
|
||||
# Clean up the database if it was set up as temporary
|
||||
if self.url.startswith("/"):
|
||||
LocalStorage.remove_all(self.url) # Remove the temporary directory and files inside
|
||||
|
||||
def get_data_point_schema(self, model_type):
|
||||
return copy_model(
|
||||
model_type,
|
||||
include_fields = {
|
||||
"id": (str, ...),
|
||||
},
|
||||
exclude_fields = ["_metadata"],
|
||||
)
|
||||
|
|
@ -1,13 +0,0 @@
|
|||
from typing import Generic, TypeVar
|
||||
from pydantic import BaseModel
|
||||
|
||||
PayloadSchema = TypeVar("PayloadSchema", bound = BaseModel)
|
||||
|
||||
class DataPoint(BaseModel, Generic[PayloadSchema]):
|
||||
id: str
|
||||
payload: PayloadSchema
|
||||
embed_field: str = "value"
|
||||
|
||||
def get_embeddable_data(self):
|
||||
if hasattr(self.payload, self.embed_field):
|
||||
return getattr(self.payload, self.embed_field)
|
||||
|
|
@ -1,7 +1,8 @@
|
|||
from typing import Any, Dict
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
|
||||
class ScoredResult(BaseModel):
|
||||
id: str
|
||||
id: UUID
|
||||
score: float # Lower score is better
|
||||
payload: Dict[str, Any]
|
||||
|
|
|
|||
|
|
@ -1,17 +1,26 @@
|
|||
import asyncio
|
||||
from uuid import UUID
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from typing import List, Optional, get_type_hints
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy import JSON, Column, Table, select, delete
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||
|
||||
from .serialize_datetime import serialize_datetime
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
||||
from .serialize_data import serialize_data
|
||||
from ..models.ScoredResult import ScoredResult
|
||||
from ..vector_db_interface import VectorDBInterface, DataPoint
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
|
||||
from ...relational.ModelBase import Base
|
||||
|
||||
class IndexSchema(DataPoint):
|
||||
text: str
|
||||
|
||||
_metadata: dict = {
|
||||
"index_fields": ["text"]
|
||||
}
|
||||
|
||||
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||
|
||||
|
|
@ -45,7 +54,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
vector_size = self.embedding_engine.get_vector_size()
|
||||
|
||||
if not await self.has_collection(collection_name):
|
||||
|
||||
class PGVectorDataPoint(Base):
|
||||
__tablename__ = collection_name
|
||||
__table_args__ = {"extend_existing": True}
|
||||
|
|
@ -71,47 +79,58 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
async def create_data_points(
|
||||
self, collection_name: str, data_points: List[DataPoint]
|
||||
):
|
||||
async with self.get_async_session() as session:
|
||||
if not await self.has_collection(collection_name):
|
||||
await self.create_collection(
|
||||
collection_name=collection_name,
|
||||
payload_schema=type(data_points[0].payload),
|
||||
)
|
||||
|
||||
data_vectors = await self.embed_data(
|
||||
[data_point.get_embeddable_data() for data_point in data_points]
|
||||
if not await self.has_collection(collection_name):
|
||||
await self.create_collection(
|
||||
collection_name = collection_name,
|
||||
payload_schema = type(data_points[0]),
|
||||
)
|
||||
|
||||
vector_size = self.embedding_engine.get_vector_size()
|
||||
data_vectors = await self.embed_data(
|
||||
[data_point.get_embeddable_data() for data_point in data_points]
|
||||
)
|
||||
|
||||
class PGVectorDataPoint(Base):
|
||||
__tablename__ = collection_name
|
||||
__table_args__ = {"extend_existing": True}
|
||||
# PGVector requires one column to be the primary key
|
||||
primary_key: Mapped[int] = mapped_column(
|
||||
primary_key=True, autoincrement=True
|
||||
)
|
||||
id: Mapped[type(data_points[0].id)]
|
||||
payload = Column(JSON)
|
||||
vector = Column(Vector(vector_size))
|
||||
vector_size = self.embedding_engine.get_vector_size()
|
||||
|
||||
def __init__(self, id, payload, vector):
|
||||
self.id = id
|
||||
self.payload = payload
|
||||
self.vector = vector
|
||||
class PGVectorDataPoint(Base):
|
||||
__tablename__ = collection_name
|
||||
__table_args__ = {"extend_existing": True}
|
||||
# PGVector requires one column to be the primary key
|
||||
primary_key: Mapped[int] = mapped_column(
|
||||
primary_key=True, autoincrement=True
|
||||
)
|
||||
id: Mapped[type(data_points[0].id)]
|
||||
payload = Column(JSON)
|
||||
vector = Column(Vector(vector_size))
|
||||
|
||||
pgvector_data_points = [
|
||||
PGVectorDataPoint(
|
||||
id=data_point.id,
|
||||
vector=data_vectors[data_index],
|
||||
payload=serialize_datetime(data_point.payload.dict()),
|
||||
)
|
||||
for (data_index, data_point) in enumerate(data_points)
|
||||
]
|
||||
def __init__(self, id, payload, vector):
|
||||
self.id = id
|
||||
self.payload = payload
|
||||
self.vector = vector
|
||||
|
||||
pgvector_data_points = [
|
||||
PGVectorDataPoint(
|
||||
id = data_point.id,
|
||||
vector = data_vectors[data_index],
|
||||
payload = serialize_data(data_point.model_dump()),
|
||||
)
|
||||
for (data_index, data_point) in enumerate(data_points)
|
||||
]
|
||||
|
||||
async with self.get_async_session() as session:
|
||||
session.add_all(pgvector_data_points)
|
||||
await session.commit()
|
||||
|
||||
async def create_vector_index(self, index_name: str, index_property_name: str):
|
||||
await self.create_collection(f"{index_name}_{index_property_name}")
|
||||
|
||||
async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
|
||||
await self.create_data_points(f"{index_name}_{index_property_name}", [
|
||||
IndexSchema(
|
||||
id = data_point.id,
|
||||
text = data_point.get_embeddable_data(),
|
||||
) for data_point in data_points
|
||||
])
|
||||
|
||||
async def get_table(self, collection_name: str) -> Table:
|
||||
"""
|
||||
Dynamically loads a table using the given collection name
|
||||
|
|
@ -126,18 +145,21 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
raise ValueError(f"Table '{collection_name}' not found.")
|
||||
|
||||
async def retrieve(self, collection_name: str, data_point_ids: List[str]):
|
||||
async with self.get_async_session() as session:
|
||||
# Get PGVectorDataPoint Table from database
|
||||
PGVectorDataPoint = await self.get_table(collection_name)
|
||||
# Get PGVectorDataPoint Table from database
|
||||
PGVectorDataPoint = await self.get_table(collection_name)
|
||||
|
||||
async with self.get_async_session() as session:
|
||||
results = await session.execute(
|
||||
select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
|
||||
)
|
||||
results = results.all()
|
||||
|
||||
return [
|
||||
ScoredResult(id=result.id, payload=result.payload, score=0)
|
||||
for result in results
|
||||
ScoredResult(
|
||||
id = UUID(result.id),
|
||||
payload = result.payload,
|
||||
score = 0
|
||||
) for result in results
|
||||
]
|
||||
|
||||
async def search(
|
||||
|
|
@ -154,11 +176,13 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
if query_text and not query_vector:
|
||||
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
|
||||
|
||||
# Get PGVectorDataPoint Table from database
|
||||
PGVectorDataPoint = await self.get_table(collection_name)
|
||||
|
||||
closest_items = []
|
||||
|
||||
# Use async session to connect to the database
|
||||
async with self.get_async_session() as session:
|
||||
# Get PGVectorDataPoint Table from database
|
||||
PGVectorDataPoint = await self.get_table(collection_name)
|
||||
|
||||
# Find closest vectors to query_vector
|
||||
closest_items = await session.execute(
|
||||
select(
|
||||
|
|
@ -171,19 +195,21 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
.limit(limit)
|
||||
)
|
||||
|
||||
vector_list = []
|
||||
# Extract distances and find min/max for normalization
|
||||
for vector in closest_items:
|
||||
# TODO: Add normalization of similarity score
|
||||
vector_list.append(vector)
|
||||
vector_list = []
|
||||
|
||||
# Create and return ScoredResult objects
|
||||
return [
|
||||
ScoredResult(
|
||||
id=str(row.id), payload=row.payload, score=row.similarity
|
||||
)
|
||||
for row in vector_list
|
||||
]
|
||||
# Extract distances and find min/max for normalization
|
||||
for vector in closest_items:
|
||||
# TODO: Add normalization of similarity score
|
||||
vector_list.append(vector)
|
||||
|
||||
# Create and return ScoredResult objects
|
||||
return [
|
||||
ScoredResult(
|
||||
id = UUID(str(row.id)),
|
||||
payload = row.payload,
|
||||
score = row.similarity
|
||||
) for row in vector_list
|
||||
]
|
||||
|
||||
async def batch_search(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -1,12 +1,15 @@
|
|||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
def serialize_datetime(data):
|
||||
def serialize_data(data):
|
||||
"""Recursively convert datetime objects in dictionaries/lists to ISO format."""
|
||||
if isinstance(data, dict):
|
||||
return {key: serialize_datetime(value) for key, value in data.items()}
|
||||
return {key: serialize_data(value) for key, value in data.items()}
|
||||
elif isinstance(data, list):
|
||||
return [serialize_datetime(item) for item in data]
|
||||
return [serialize_data(item) for item in data]
|
||||
elif isinstance(data, datetime):
|
||||
return data.isoformat() # Convert datetime to ISO 8601 string
|
||||
elif isinstance(data, UUID):
|
||||
return str(data)
|
||||
else:
|
||||
return data
|
||||
|
|
@ -1,12 +1,22 @@
|
|||
import logging
|
||||
from uuid import UUID
|
||||
from typing import List, Dict, Optional
|
||||
from qdrant_client import AsyncQdrantClient, models
|
||||
|
||||
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
from ..models.DataPoint import DataPoint
|
||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
|
||||
logger = logging.getLogger("QDrantAdapter")
|
||||
|
||||
class IndexSchema(DataPoint):
|
||||
text: str
|
||||
|
||||
_metadata: dict = {
|
||||
"index_fields": ["text"]
|
||||
}
|
||||
|
||||
# class CollectionConfig(BaseModel, extra = "forbid"):
|
||||
# vector_config: Dict[str, models.VectorParams] = Field(..., description="Vectors configuration" )
|
||||
# hnsw_config: Optional[models.HnswConfig] = Field(default = None, description="HNSW vector index configuration")
|
||||
|
|
@ -75,20 +85,19 @@ class QDrantAdapter(VectorDBInterface):
|
|||
):
|
||||
client = self.get_qdrant_client()
|
||||
|
||||
result = await client.create_collection(
|
||||
collection_name = collection_name,
|
||||
vectors_config = {
|
||||
"text": models.VectorParams(
|
||||
size = self.embedding_engine.get_vector_size(),
|
||||
distance = "Cosine"
|
||||
)
|
||||
}
|
||||
)
|
||||
if not await client.collection_exists(collection_name):
|
||||
await client.create_collection(
|
||||
collection_name = collection_name,
|
||||
vectors_config = {
|
||||
"text": models.VectorParams(
|
||||
size = self.embedding_engine.get_vector_size(),
|
||||
distance = "Cosine"
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
await client.close()
|
||||
|
||||
return result
|
||||
|
||||
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
|
||||
client = self.get_qdrant_client()
|
||||
|
||||
|
|
@ -96,8 +105,8 @@ class QDrantAdapter(VectorDBInterface):
|
|||
|
||||
def convert_to_qdrant_point(data_point: DataPoint):
|
||||
return models.PointStruct(
|
||||
id = data_point.id,
|
||||
payload = data_point.payload.dict(),
|
||||
id = str(data_point.id),
|
||||
payload = data_point.model_dump(),
|
||||
vector = {
|
||||
"text": data_vectors[data_points.index(data_point)]
|
||||
}
|
||||
|
|
@ -116,6 +125,17 @@ class QDrantAdapter(VectorDBInterface):
|
|||
finally:
|
||||
await client.close()
|
||||
|
||||
async def create_vector_index(self, index_name: str, index_property_name: str):
|
||||
await self.create_collection(f"{index_name}_{index_property_name}")
|
||||
|
||||
async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
|
||||
await self.create_data_points(f"{index_name}_{index_property_name}", [
|
||||
IndexSchema(
|
||||
id = data_point.id,
|
||||
text = getattr(data_point, data_point._metadata["index_fields"][0]),
|
||||
) for data_point in data_points
|
||||
])
|
||||
|
||||
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
||||
client = self.get_qdrant_client()
|
||||
results = await client.retrieve(collection_name, data_point_ids, with_payload = True)
|
||||
|
|
@ -135,7 +155,7 @@ class QDrantAdapter(VectorDBInterface):
|
|||
|
||||
client = self.get_qdrant_client()
|
||||
|
||||
result = await client.search(
|
||||
results = await client.search(
|
||||
collection_name = collection_name,
|
||||
query_vector = models.NamedVector(
|
||||
name = "text",
|
||||
|
|
@ -147,7 +167,16 @@ class QDrantAdapter(VectorDBInterface):
|
|||
|
||||
await client.close()
|
||||
|
||||
return result
|
||||
return [
|
||||
ScoredResult(
|
||||
id = UUID(result.id),
|
||||
payload = {
|
||||
**result.payload,
|
||||
"id": UUID(result.id),
|
||||
},
|
||||
score = 1 - result.score,
|
||||
) for result in results
|
||||
]
|
||||
|
||||
|
||||
async def batch_search(self, collection_name: str, query_texts: List[str], limit: int = None, with_vectors: bool = False):
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import List, Protocol, Optional
|
||||
from abc import abstractmethod
|
||||
from .models.DataPoint import DataPoint
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from .models.PayloadSchema import PayloadSchema
|
||||
|
||||
class VectorDBInterface(Protocol):
|
||||
|
|
|
|||
|
|
@ -1,13 +1,22 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
from ..models.DataPoint import DataPoint
|
||||
from ..models.ScoredResult import ScoredResult
|
||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
|
||||
logger = logging.getLogger("WeaviateAdapter")
|
||||
|
||||
class IndexSchema(DataPoint):
|
||||
text: str
|
||||
|
||||
_metadata: dict = {
|
||||
"index_fields": ["text"]
|
||||
}
|
||||
|
||||
class WeaviateAdapter(VectorDBInterface):
|
||||
name = "Weaviate"
|
||||
url: str
|
||||
|
|
@ -48,18 +57,21 @@ class WeaviateAdapter(VectorDBInterface):
|
|||
|
||||
future = asyncio.Future()
|
||||
|
||||
future.set_result(
|
||||
self.client.collections.create(
|
||||
name=collection_name,
|
||||
properties=[
|
||||
wvcc.Property(
|
||||
name="text",
|
||||
data_type=wvcc.DataType.TEXT,
|
||||
skip_vectorization=True
|
||||
)
|
||||
]
|
||||
if not self.client.collections.exists(collection_name):
|
||||
future.set_result(
|
||||
self.client.collections.create(
|
||||
name = collection_name,
|
||||
properties = [
|
||||
wvcc.Property(
|
||||
name = "text",
|
||||
data_type = wvcc.DataType.TEXT,
|
||||
skip_vectorization = True
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
future.set_result(self.get_collection(collection_name))
|
||||
|
||||
return await future
|
||||
|
||||
|
|
@ -70,36 +82,60 @@ class WeaviateAdapter(VectorDBInterface):
|
|||
from weaviate.classes.data import DataObject
|
||||
|
||||
data_vectors = await self.embed_data(
|
||||
list(map(lambda data_point: data_point.get_embeddable_data(), data_points)))
|
||||
[data_point.get_embeddable_data() for data_point in data_points]
|
||||
)
|
||||
|
||||
def convert_to_weaviate_data_points(data_point: DataPoint):
|
||||
vector = data_vectors[data_points.index(data_point)]
|
||||
properties = data_point.model_dump()
|
||||
|
||||
if "id" in properties:
|
||||
properties["uuid"] = str(data_point.id)
|
||||
del properties["id"]
|
||||
|
||||
return DataObject(
|
||||
uuid = data_point.id,
|
||||
properties = data_point.payload.dict(),
|
||||
properties = properties,
|
||||
vector = vector
|
||||
)
|
||||
|
||||
data_points = list(map(convert_to_weaviate_data_points, data_points))
|
||||
data_points = [convert_to_weaviate_data_points(data_point) for data_point in data_points]
|
||||
|
||||
collection = self.get_collection(collection_name)
|
||||
|
||||
try:
|
||||
if len(data_points) > 1:
|
||||
return collection.data.insert_many(data_points)
|
||||
with collection.batch.dynamic() as batch:
|
||||
for data_point in data_points:
|
||||
batch.add_object(
|
||||
uuid = data_point.uuid,
|
||||
vector = data_point.vector,
|
||||
properties = data_point.properties,
|
||||
references = data_point.references,
|
||||
)
|
||||
else:
|
||||
return collection.data.insert(data_points[0])
|
||||
# with collection.batch.dynamic() as batch:
|
||||
# for point in data_points:
|
||||
# batch.add_object(
|
||||
# uuid = point.uuid,
|
||||
# properties = point.properties,
|
||||
# vector = point.vector
|
||||
# )
|
||||
data_point: DataObject = data_points[0]
|
||||
return collection.data.update(
|
||||
uuid = data_point.uuid,
|
||||
vector = data_point.vector,
|
||||
properties = data_point.properties,
|
||||
references = data_point.references,
|
||||
)
|
||||
except Exception as error:
|
||||
logger.error("Error creating data points: %s", str(error))
|
||||
raise error
|
||||
|
||||
async def create_vector_index(self, index_name: str, index_property_name: str):
|
||||
await self.create_collection(f"{index_name}_{index_property_name}")
|
||||
|
||||
async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
|
||||
await self.create_data_points(f"{index_name}_{index_property_name}", [
|
||||
IndexSchema(
|
||||
id = data_point.id,
|
||||
text = data_point.get_embeddable_data(),
|
||||
) for data_point in data_points
|
||||
])
|
||||
|
||||
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
||||
from weaviate.classes.query import Filter
|
||||
future = asyncio.Future()
|
||||
|
|
@ -143,9 +179,9 @@ class WeaviateAdapter(VectorDBInterface):
|
|||
|
||||
return [
|
||||
ScoredResult(
|
||||
id=str(result.uuid),
|
||||
payload=result.properties,
|
||||
score=float(result.metadata.score)
|
||||
id = UUID(str(result.uuid)),
|
||||
payload = result.properties,
|
||||
score = 1 - float(result.metadata.score)
|
||||
) for result in search_result.objects
|
||||
]
|
||||
|
||||
|
|
|
|||
1
cognee/infrastructure/engine/__init__.py
Normal file
1
cognee/infrastructure/engine/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
from .models.DataPoint import DataPoint
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
from enum import Enum
|
||||
from typing import Optional
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.graph.utils import get_graph_from_model, get_model_instance_from_graph
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
class CarTypeName(Enum):
|
||||
Pickup = "Pickup"
|
||||
Sedan = "Sedan"
|
||||
SUV = "SUV"
|
||||
Coupe = "Coupe"
|
||||
Convertible = "Convertible"
|
||||
Hatchback = "Hatchback"
|
||||
Wagon = "Wagon"
|
||||
Minivan = "Minivan"
|
||||
Van = "Van"
|
||||
|
||||
class CarType(DataPoint):
|
||||
id: str
|
||||
name: CarTypeName
|
||||
_metadata: dict = dict(index_fields = ["name"])
|
||||
|
||||
class Car(DataPoint):
|
||||
id: str
|
||||
brand: str
|
||||
model: str
|
||||
year: int
|
||||
color: str
|
||||
is_type: CarType
|
||||
|
||||
class Person(DataPoint):
|
||||
id: str
|
||||
name: str
|
||||
age: int
|
||||
owns_car: list[Car]
|
||||
driving_licence: Optional[dict]
|
||||
_metadata: dict = dict(index_fields = ["name"])
|
||||
|
||||
boris = Person(
|
||||
id = "boris",
|
||||
name = "Boris",
|
||||
age = 30,
|
||||
owns_car = [
|
||||
Car(
|
||||
id = "car1",
|
||||
brand = "Toyota",
|
||||
model = "Camry",
|
||||
year = 2020,
|
||||
color = "Blue",
|
||||
is_type = CarType(id = "sedan", name = CarTypeName.Sedan),
|
||||
),
|
||||
],
|
||||
driving_licence = {
|
||||
"issued_by": "PU Vrsac",
|
||||
"issued_on": "2025-11-06",
|
||||
"number": "1234567890",
|
||||
"expires_on": "2025-11-06",
|
||||
},
|
||||
)
|
||||
|
||||
nodes, edges = get_graph_from_model(boris)
|
||||
|
||||
print(nodes)
|
||||
print(edges)
|
||||
|
||||
person_data = nodes[len(nodes) - 1]
|
||||
|
||||
parsed_person = get_model_instance_from_graph(nodes, edges, 'boris')
|
||||
|
||||
print(parsed_person)
|
||||
24
cognee/infrastructure/engine/models/DataPoint.py
Normal file
24
cognee/infrastructure/engine/models/DataPoint.py
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
from typing_extensions import TypedDict
|
||||
from uuid import UUID, uuid4
|
||||
from typing import Optional
|
||||
from datetime import datetime, timezone
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class MetaData(TypedDict):
|
||||
index_fields: list[str]
|
||||
|
||||
class DataPoint(BaseModel):
|
||||
id: UUID = Field(default_factory = uuid4)
|
||||
updated_at: Optional[datetime] = datetime.now(timezone.utc)
|
||||
_metadata: Optional[MetaData] = {
|
||||
"index_fields": []
|
||||
}
|
||||
|
||||
# class Config:
|
||||
# underscore_attrs_are_private = True
|
||||
|
||||
def get_embeddable_data(self):
|
||||
if self._metadata and len(self._metadata["index_fields"]) > 0 \
|
||||
and hasattr(self, self._metadata["index_fields"][0]):
|
||||
|
||||
return getattr(self, self._metadata["index_fields"][0])
|
||||
|
|
@ -1,18 +1,18 @@
|
|||
from uuid import UUID, uuid5, NAMESPACE_OID
|
||||
from uuid import uuid5, NAMESPACE_OID
|
||||
|
||||
from cognee.modules.chunking import DocumentChunk
|
||||
from cognee.tasks.chunking import chunk_by_paragraph
|
||||
from .models.DocumentChunk import DocumentChunk
|
||||
from cognee.tasks.chunks import chunk_by_paragraph
|
||||
|
||||
class TextChunker():
|
||||
id: UUID
|
||||
document = None
|
||||
max_chunk_size: int
|
||||
|
||||
chunk_index = 0
|
||||
chunk_size = 0
|
||||
paragraph_chunks = []
|
||||
|
||||
def __init__(self, id: UUID, get_text: callable, chunk_size: int = 1024):
|
||||
self.id = id
|
||||
def __init__(self, document, get_text: callable, chunk_size: int = 1024):
|
||||
self.document = document
|
||||
self.max_chunk_size = chunk_size
|
||||
self.get_text = get_text
|
||||
|
||||
|
|
@ -29,10 +29,10 @@ class TextChunker():
|
|||
else:
|
||||
if len(self.paragraph_chunks) == 0:
|
||||
yield DocumentChunk(
|
||||
id = chunk_data["chunk_id"],
|
||||
text = chunk_data["text"],
|
||||
word_count = chunk_data["word_count"],
|
||||
document_id = str(self.id),
|
||||
chunk_id = str(chunk_data["chunk_id"]),
|
||||
is_part_of = self.document,
|
||||
chunk_index = self.chunk_index,
|
||||
cut_type = chunk_data["cut_type"],
|
||||
)
|
||||
|
|
@ -40,25 +40,31 @@ class TextChunker():
|
|||
self.chunk_size = 0
|
||||
else:
|
||||
chunk_text = " ".join(chunk["text"] for chunk in self.paragraph_chunks)
|
||||
yield DocumentChunk(
|
||||
text = chunk_text,
|
||||
word_count = self.chunk_size,
|
||||
document_id = str(self.id),
|
||||
chunk_id = str(uuid5(NAMESPACE_OID, f"{str(self.id)}-{self.chunk_index}")),
|
||||
chunk_index = self.chunk_index,
|
||||
cut_type = self.paragraph_chunks[len(self.paragraph_chunks) - 1]["cut_type"],
|
||||
)
|
||||
try:
|
||||
yield DocumentChunk(
|
||||
id = uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
|
||||
text = chunk_text,
|
||||
word_count = self.chunk_size,
|
||||
is_part_of = self.document,
|
||||
chunk_index = self.chunk_index,
|
||||
cut_type = self.paragraph_chunks[len(self.paragraph_chunks) - 1]["cut_type"],
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
self.paragraph_chunks = [chunk_data]
|
||||
self.chunk_size = chunk_data["word_count"]
|
||||
|
||||
self.chunk_index += 1
|
||||
|
||||
if len(self.paragraph_chunks) > 0:
|
||||
yield DocumentChunk(
|
||||
text = " ".join(chunk["text"] for chunk in self.paragraph_chunks),
|
||||
word_count = self.chunk_size,
|
||||
document_id = str(self.id),
|
||||
chunk_id = str(uuid5(NAMESPACE_OID, f"{str(self.id)}-{self.chunk_index}")),
|
||||
chunk_index = self.chunk_index,
|
||||
cut_type = self.paragraph_chunks[len(self.paragraph_chunks) - 1]["cut_type"],
|
||||
)
|
||||
try:
|
||||
yield DocumentChunk(
|
||||
id = uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
|
||||
text = " ".join(chunk["text"] for chunk in self.paragraph_chunks),
|
||||
word_count = self.chunk_size,
|
||||
is_part_of = self.document,
|
||||
chunk_index = self.chunk_index,
|
||||
cut_type = self.paragraph_chunks[len(self.paragraph_chunks) - 1]["cut_type"],
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
|
|
|||
|
|
@ -1,2 +0,0 @@
|
|||
from .models.DocumentChunk import DocumentChunk
|
||||
from .TextChunker import TextChunker
|
||||
|
|
@ -1,9 +1,14 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.data.processing.document_types import Document
|
||||
|
||||
class DocumentChunk(BaseModel):
|
||||
class DocumentChunk(DataPoint):
|
||||
text: str
|
||||
word_count: int
|
||||
document_id: str
|
||||
chunk_id: str
|
||||
chunk_index: int
|
||||
cut_type: str
|
||||
is_part_of: Document
|
||||
|
||||
_metadata: Optional[dict] = {
|
||||
"index_fields": ["text"],
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
from .knowledge_graph.extract_content_graph import extract_content_graph
|
||||
|
|
@ -0,0 +1 @@
|
|||
from .extract_content_graph import extract_content_graph
|
||||
|
|
@ -1,36 +1,36 @@
|
|||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def detect_language(data:str):
|
||||
async def detect_language(text: str):
|
||||
"""
|
||||
Detect the language of the given text and return its ISO 639-1 language code.
|
||||
If the detected language is Croatian ('hr'), it maps to Serbian ('sr').
|
||||
If the detected language is Croatian ("hr"), it maps to Serbian ("sr").
|
||||
The text is trimmed to the first 100 characters for efficient processing.
|
||||
Parameters:
|
||||
text (str): The text for language detection.
|
||||
Returns:
|
||||
str: The ISO 639-1 language code of the detected language, or 'None' in case of an error.
|
||||
str: The ISO 639-1 language code of the detected language, or "None" in case of an error.
|
||||
"""
|
||||
|
||||
# Trim the text to the first 100 characters
|
||||
from langdetect import detect, LangDetectException
|
||||
trimmed_text = data[:100]
|
||||
# Trim the text to the first 100 characters
|
||||
trimmed_text = text[:100]
|
||||
|
||||
try:
|
||||
# Detect the language using langdetect
|
||||
detected_lang_iso639_1 = detect(trimmed_text)
|
||||
logging.info(f"Detected ISO 639-1 code: {detected_lang_iso639_1}")
|
||||
|
||||
# Special case: map 'hr' (Croatian) to 'sr' (Serbian ISO 639-2)
|
||||
if detected_lang_iso639_1 == 'hr':
|
||||
yield 'sr'
|
||||
yield detected_lang_iso639_1
|
||||
# Special case: map "hr" (Croatian) to "sr" (Serbian ISO 639-2)
|
||||
if detected_lang_iso639_1 == "hr":
|
||||
return "sr"
|
||||
|
||||
return detected_lang_iso639_1
|
||||
|
||||
except LangDetectException as e:
|
||||
logging.error(f"Language detection error: {e}")
|
||||
except Exception as e:
|
||||
logging.error(f"Unexpected error: {e}")
|
||||
logger.error(f"Language detection error: {e}")
|
||||
|
||||
yield None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error: {e}")
|
||||
|
||||
return None
|
||||
41
cognee/modules/data/operations/translate_text.py
Normal file
41
cognee/modules/data/operations/translate_text.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def translate_text(text, source_language: str = "sr", target_language: str = "en", region_name = "eu-west-1"):
|
||||
"""
|
||||
Translate text from source language to target language using AWS Translate.
|
||||
Parameters:
|
||||
text (str): The text to be translated.
|
||||
source_language (str): The source language code (e.g., "sr" for Serbian). ISO 639-2 Code https://www.loc.gov/standards/iso639-2/php/code_list.php
|
||||
target_language (str): The target language code (e.g., "en" for English). ISO 639-2 Code https://www.loc.gov/standards/iso639-2/php/code_list.php
|
||||
region_name (str): AWS region name.
|
||||
Returns:
|
||||
str: Translated text or an error message.
|
||||
"""
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import BotoCoreError, ClientError
|
||||
|
||||
if not text:
|
||||
raise ValueError("No text to translate.")
|
||||
|
||||
if not source_language or not target_language:
|
||||
raise ValueError("Source and target language codes are required.")
|
||||
|
||||
try:
|
||||
translate = boto3.client(service_name = "translate", region_name = region_name, use_ssl = True)
|
||||
result = translate.translate_text(
|
||||
Text = text,
|
||||
SourceLanguageCode = source_language,
|
||||
TargetLanguageCode = target_language,
|
||||
)
|
||||
yield result.get("TranslatedText", "No translation found.")
|
||||
|
||||
except BotoCoreError as e:
|
||||
logger.error(f"BotoCoreError occurred: {e}")
|
||||
yield None
|
||||
|
||||
except ClientError as e:
|
||||
logger.error(f"ClientError occurred: {e}")
|
||||
yield None
|
||||
|
|
@ -1,34 +1,15 @@
|
|||
from uuid import UUID, uuid5, NAMESPACE_OID
|
||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from .Document import Document
|
||||
|
||||
class AudioDocument(Document):
|
||||
type: str = "audio"
|
||||
title: str
|
||||
raw_data_location: str
|
||||
chunking_strategy: str
|
||||
|
||||
def __init__(self, id: UUID, title: str, raw_data_location: str, chunking_strategy:str="paragraph"):
|
||||
self.id = id or uuid5(NAMESPACE_OID, title)
|
||||
self.title = title
|
||||
self.raw_data_location = raw_data_location
|
||||
self.chunking_strategy = chunking_strategy
|
||||
|
||||
def read(self, chunk_size: int):
|
||||
# Transcribe the audio file
|
||||
result = get_llm_client().create_transcript(self.raw_data_location)
|
||||
text = result.text
|
||||
|
||||
chunker = TextChunker(self.id, chunk_size = chunk_size, get_text = lambda: text)
|
||||
chunker = TextChunker(self, chunk_size = chunk_size, get_text = lambda: text)
|
||||
|
||||
yield from chunker.read()
|
||||
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return dict(
|
||||
id=str(self.id),
|
||||
type=self.type,
|
||||
title=self.title,
|
||||
raw_data_location=self.raw_data_location,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,8 @@
|
|||
from uuid import UUID
|
||||
from typing import Protocol
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
||||
class Document(Protocol):
|
||||
id: UUID
|
||||
class Document(DataPoint):
|
||||
type: str
|
||||
title: str
|
||||
name: str
|
||||
raw_data_location: str
|
||||
|
||||
def read(self, chunk_size: int) -> str:
|
||||
|
|
|
|||
|
|
@ -1,33 +1,15 @@
|
|||
from uuid import UUID, uuid5, NAMESPACE_OID
|
||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from .Document import Document
|
||||
|
||||
|
||||
class ImageDocument(Document):
|
||||
type: str = "image"
|
||||
title: str
|
||||
raw_data_location: str
|
||||
|
||||
def __init__(self, id: UUID, title: str, raw_data_location: str):
|
||||
self.id = id or uuid5(NAMESPACE_OID, title)
|
||||
self.title = title
|
||||
self.raw_data_location = raw_data_location
|
||||
|
||||
def read(self, chunk_size: int):
|
||||
# Transcribe the image file
|
||||
result = get_llm_client().transcribe_image(self.raw_data_location)
|
||||
text = result.choices[0].message.content
|
||||
|
||||
chunker = TextChunker(self.id, chunk_size = chunk_size, get_text = lambda: text)
|
||||
chunker = TextChunker(self, chunk_size = chunk_size, get_text = lambda: text)
|
||||
|
||||
yield from chunker.read()
|
||||
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return dict(
|
||||
id=str(self.id),
|
||||
type=self.type,
|
||||
title=self.title,
|
||||
raw_data_location=self.raw_data_location,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,19 +1,11 @@
|
|||
from uuid import UUID, uuid5, NAMESPACE_OID
|
||||
from pypdf import PdfReader
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from .Document import Document
|
||||
|
||||
class PdfDocument(Document):
|
||||
type: str = "pdf"
|
||||
title: str
|
||||
raw_data_location: str
|
||||
|
||||
def __init__(self, id: UUID, title: str, raw_data_location: str):
|
||||
self.id = id or uuid5(NAMESPACE_OID, title)
|
||||
self.title = title
|
||||
self.raw_data_location = raw_data_location
|
||||
|
||||
def read(self, chunk_size: int) -> PdfReader:
|
||||
def read(self, chunk_size: int):
|
||||
file = PdfReader(self.raw_data_location)
|
||||
|
||||
def get_text():
|
||||
|
|
@ -21,16 +13,8 @@ class PdfDocument(Document):
|
|||
page_text = page.extract_text()
|
||||
yield page_text
|
||||
|
||||
chunker = TextChunker(self.id, chunk_size = chunk_size, get_text = get_text)
|
||||
chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text)
|
||||
|
||||
yield from chunker.read()
|
||||
|
||||
file.stream.close()
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return dict(
|
||||
id = str(self.id),
|
||||
type = self.type,
|
||||
title = self.title,
|
||||
raw_data_location = self.raw_data_location,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,16 +1,8 @@
|
|||
from uuid import UUID, uuid5, NAMESPACE_OID
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from .Document import Document
|
||||
|
||||
class TextDocument(Document):
|
||||
type: str = "text"
|
||||
title: str
|
||||
raw_data_location: str
|
||||
|
||||
def __init__(self, id: UUID, title: str, raw_data_location: str):
|
||||
self.id = id or uuid5(NAMESPACE_OID, title)
|
||||
self.title = title
|
||||
self.raw_data_location = raw_data_location
|
||||
|
||||
def read(self, chunk_size: int):
|
||||
def get_text():
|
||||
|
|
@ -23,16 +15,6 @@ class TextDocument(Document):
|
|||
|
||||
yield text
|
||||
|
||||
|
||||
chunker = TextChunker(self.id,chunk_size = chunk_size, get_text = get_text)
|
||||
chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text)
|
||||
|
||||
yield from chunker.read()
|
||||
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return dict(
|
||||
id = str(self.id),
|
||||
type = self.type,
|
||||
title = self.title,
|
||||
raw_data_location = self.raw_data_location,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from .Document import Document
|
||||
from .PdfDocument import PdfDocument
|
||||
from .TextDocument import TextDocument
|
||||
from .ImageDocument import ImageDocument
|
||||
|
|
|
|||
12
cognee/modules/engine/models/Entity.py
Normal file
12
cognee/modules/engine/models/Entity.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
||||
from .EntityType import EntityType
|
||||
|
||||
class Entity(DataPoint):
|
||||
name: str
|
||||
is_a: EntityType
|
||||
description: str
|
||||
mentioned_in: DocumentChunk
|
||||
_metadata: dict = {
|
||||
"index_fields": ["name"],
|
||||
}
|
||||
11
cognee/modules/engine/models/EntityType.py
Normal file
11
cognee/modules/engine/models/EntityType.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
||||
|
||||
class EntityType(DataPoint):
|
||||
name: str
|
||||
type: str
|
||||
description: str
|
||||
exists_in: DocumentChunk
|
||||
_metadata: dict = {
|
||||
"index_fields": ["name"],
|
||||
}
|
||||
2
cognee/modules/engine/models/__init__.py
Normal file
2
cognee/modules/engine/models/__init__.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
from .Entity import Entity
|
||||
from .EntityType import EntityType
|
||||
3
cognee/modules/engine/utils/__init__.py
Normal file
3
cognee/modules/engine/utils/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from .generate_node_id import generate_node_id
|
||||
from .generate_node_name import generate_node_name
|
||||
from .generate_edge_name import generate_edge_name
|
||||
2
cognee/modules/engine/utils/generate_edge_name.py
Normal file
2
cognee/modules/engine/utils/generate_edge_name.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
def generate_edge_name(name: str) -> str:
|
||||
return name.lower().replace(" ", "_").replace("'", "")
|
||||
4
cognee/modules/engine/utils/generate_node_id.py
Normal file
4
cognee/modules/engine/utils/generate_node_id.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
from uuid import NAMESPACE_OID, uuid5
|
||||
|
||||
def generate_node_id(node_id: str) -> str:
|
||||
return uuid5(NAMESPACE_OID, node_id.lower().replace(" ", "_").replace("'", ""))
|
||||
2
cognee/modules/engine/utils/generate_node_name.py
Normal file
2
cognee/modules/engine/utils/generate_node_name.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
def generate_node_name(name: str) -> str:
|
||||
return name.lower().replace("'", "")
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
def generate_node_name(name: str) -> str:
|
||||
return name.lower().replace(" ", "_").replace("'", "")
|
||||
|
||||
def generate_node_id(node_id: str) -> str:
|
||||
return node_id.lower().replace(" ", "_").replace("'", "")
|
||||
2
cognee/modules/graph/utils/__init__.py
Normal file
2
cognee/modules/graph/utils/__init__.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
from .get_graph_from_model import get_graph_from_model
|
||||
from .get_model_instance_from_graph import get_model_instance_from_graph
|
||||
107
cognee/modules/graph/utils/get_graph_from_model.py
Normal file
107
cognee/modules/graph/utils/get_graph_from_model.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
from datetime import datetime, timezone
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.storage.utils import copy_model
|
||||
|
||||
def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = {}, added_edges = {}):
|
||||
nodes = []
|
||||
edges = []
|
||||
|
||||
data_point_properties = {}
|
||||
excluded_properties = set()
|
||||
|
||||
for field_name, field_value in data_point:
|
||||
if field_name == "_metadata":
|
||||
continue
|
||||
|
||||
if isinstance(field_value, DataPoint):
|
||||
excluded_properties.add(field_name)
|
||||
|
||||
property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges)
|
||||
|
||||
for node in property_nodes:
|
||||
if str(node.id) not in added_nodes:
|
||||
nodes.append(node)
|
||||
added_nodes[str(node.id)] = True
|
||||
|
||||
for edge in property_edges:
|
||||
edge_key = str(edge[0]) + str(edge[1]) + edge[2]
|
||||
|
||||
if str(edge_key) not in added_edges:
|
||||
edges.append(edge)
|
||||
added_edges[str(edge_key)] = True
|
||||
|
||||
for property_node in get_own_properties(property_nodes, property_edges):
|
||||
edge_key = str(data_point.id) + str(property_node.id) + field_name
|
||||
|
||||
if str(edge_key) not in added_edges:
|
||||
edges.append((data_point.id, property_node.id, field_name, {
|
||||
"source_node_id": data_point.id,
|
||||
"target_node_id": property_node.id,
|
||||
"relationship_name": field_name,
|
||||
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
|
||||
}))
|
||||
added_edges[str(edge_key)] = True
|
||||
continue
|
||||
|
||||
if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
|
||||
excluded_properties.add(field_name)
|
||||
|
||||
for item in field_value:
|
||||
property_nodes, property_edges = get_graph_from_model(item, True, added_nodes, added_edges)
|
||||
|
||||
for node in property_nodes:
|
||||
if str(node.id) not in added_nodes:
|
||||
nodes.append(node)
|
||||
added_nodes[str(node.id)] = True
|
||||
|
||||
for edge in property_edges:
|
||||
edge_key = str(edge[0]) + str(edge[1]) + edge[2]
|
||||
|
||||
if str(edge_key) not in added_edges:
|
||||
edges.append(edge)
|
||||
added_edges[edge_key] = True
|
||||
|
||||
for property_node in get_own_properties(property_nodes, property_edges):
|
||||
edge_key = str(data_point.id) + str(property_node.id) + field_name
|
||||
|
||||
if str(edge_key) not in added_edges:
|
||||
edges.append((data_point.id, property_node.id, field_name, {
|
||||
"source_node_id": data_point.id,
|
||||
"target_node_id": property_node.id,
|
||||
"relationship_name": field_name,
|
||||
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"metadata": {
|
||||
"type": "list"
|
||||
},
|
||||
}))
|
||||
added_edges[edge_key] = True
|
||||
continue
|
||||
|
||||
data_point_properties[field_name] = field_value
|
||||
|
||||
SimpleDataPointModel = copy_model(
|
||||
type(data_point),
|
||||
include_fields = {
|
||||
"_metadata": (dict, data_point._metadata),
|
||||
},
|
||||
exclude_fields = excluded_properties,
|
||||
)
|
||||
|
||||
if include_root:
|
||||
nodes.append(SimpleDataPointModel(**data_point_properties))
|
||||
|
||||
return nodes, edges
|
||||
|
||||
|
||||
def get_own_properties(property_nodes, property_edges):
|
||||
own_properties = []
|
||||
|
||||
destination_nodes = [str(property_edge[1]) for property_edge in property_edges]
|
||||
|
||||
for node in property_nodes:
|
||||
if str(node.id) in destination_nodes:
|
||||
continue
|
||||
|
||||
own_properties.append(node)
|
||||
|
||||
return own_properties
|
||||
29
cognee/modules/graph/utils/get_model_instance_from_graph.py
Normal file
29
cognee/modules/graph/utils/get_model_instance_from_graph.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
from pydantic_core import PydanticUndefined
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.storage.utils import copy_model
|
||||
|
||||
|
||||
def get_model_instance_from_graph(nodes: list[DataPoint], edges: list, entity_id: str):
|
||||
node_map = {}
|
||||
|
||||
for node in nodes:
|
||||
node_map[node.id] = node
|
||||
|
||||
for edge in edges:
|
||||
source_node = node_map[edge[0]]
|
||||
target_node = node_map[edge[1]]
|
||||
edge_label = edge[2]
|
||||
edge_properties = edge[3] if len(edge) == 4 else {}
|
||||
edge_metadata = edge_properties.get("metadata", {})
|
||||
edge_type = edge_metadata.get("type")
|
||||
|
||||
if edge_type == "list":
|
||||
NewModel = copy_model(type(source_node), { edge_label: (list[type(target_node)], PydanticUndefined) })
|
||||
|
||||
node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: [target_node] })
|
||||
else:
|
||||
NewModel = copy_model(type(source_node), { edge_label: (type(target_node), PydanticUndefined) })
|
||||
|
||||
node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: target_node })
|
||||
|
||||
return node_map[entity_id]
|
||||
|
|
@ -7,7 +7,7 @@ from ..tasks.Task import Task
|
|||
|
||||
logger = logging.getLogger("run_tasks(tasks: [Task], data)")
|
||||
|
||||
async def run_tasks_base(tasks: [Task], data = None, user: User = None):
|
||||
async def run_tasks_base(tasks: list[Task], data = None, user: User = None):
|
||||
if len(tasks) == 0:
|
||||
yield data
|
||||
return
|
||||
|
|
@ -16,7 +16,7 @@ async def run_tasks_base(tasks: [Task], data = None, user: User = None):
|
|||
|
||||
running_task = tasks[0]
|
||||
leftover_tasks = tasks[1:]
|
||||
next_task = leftover_tasks[0] if len(leftover_tasks) > 1 else None
|
||||
next_task = leftover_tasks[0] if len(leftover_tasks) > 0 else None
|
||||
next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1
|
||||
|
||||
if inspect.isasyncgenfunction(running_task.executable):
|
||||
|
|
|
|||
|
|
@ -1,33 +0,0 @@
|
|||
import asyncio
|
||||
import nest_asyncio
|
||||
import dspy
|
||||
from cognee.modules.search.vector.search_similarity import search_similarity
|
||||
|
||||
nest_asyncio.apply()
|
||||
|
||||
class AnswerFromContext(dspy.Signature):
|
||||
question: str = dspy.InputField()
|
||||
context: str = dspy.InputField(desc = "Context to use for answer generation.")
|
||||
answer: str = dspy.OutputField()
|
||||
|
||||
question_answer_llm = dspy.OpenAI(model = "gpt-3.5-turbo-instruct")
|
||||
|
||||
class CogneeSearch(dspy.Module):
|
||||
def __init__(self, ):
|
||||
super().__init__()
|
||||
self.generate_answer = dspy.TypedChainOfThought(AnswerFromContext)
|
||||
|
||||
def forward(self, question):
|
||||
context = asyncio.run(search_similarity(question))
|
||||
|
||||
context_text = "\n".join(context)
|
||||
print(f"Context: {context_text}")
|
||||
|
||||
with dspy.context(lm = question_answer_llm):
|
||||
answer_prediction = self.generate_answer(context = context_text, question = question)
|
||||
answer = answer_prediction.answer
|
||||
|
||||
print(f"Question: {question}")
|
||||
print(f"Answer: {answer}")
|
||||
|
||||
return dspy.Prediction(context = context_text, answer = answer)
|
||||
|
|
@ -1,43 +0,0 @@
|
|||
import asyncio
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
async def search_adjacent(query: str) -> list[(str, str)]:
|
||||
"""
|
||||
Find the neighbours of a given node in the graph and return their ids and descriptions.
|
||||
|
||||
Parameters:
|
||||
- query (str): The query string to filter nodes by.
|
||||
|
||||
Returns:
|
||||
- list[(str, str)]: A list containing the unique identifiers and names of the neighbours of the given node.
|
||||
"""
|
||||
node_id = query
|
||||
|
||||
if node_id is None:
|
||||
return {}
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
|
||||
exact_node = await graph_engine.extract_node(node_id)
|
||||
|
||||
if exact_node is not None and "uuid" in exact_node:
|
||||
neighbours = await graph_engine.get_neighbours(exact_node["uuid"])
|
||||
else:
|
||||
vector_engine = get_vector_engine()
|
||||
results = await asyncio.gather(
|
||||
vector_engine.search("entities", query_text = query, limit = 10),
|
||||
vector_engine.search("classification", query_text = query, limit = 10),
|
||||
)
|
||||
results = [*results[0], *results[1]]
|
||||
relevant_results = [result for result in results if result.score < 0.5][:5]
|
||||
|
||||
if len(relevant_results) == 0:
|
||||
return []
|
||||
|
||||
node_neighbours = await asyncio.gather(*[graph_engine.get_neighbours(result.id) for result in relevant_results])
|
||||
neighbours = []
|
||||
for neighbour_ids in node_neighbours:
|
||||
neighbours.extend(neighbour_ids)
|
||||
|
||||
return neighbours
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine, get_graph_config
|
||||
|
||||
async def search_cypher(query: str):
|
||||
"""
|
||||
Use a Cypher query to search the graph and return the results.
|
||||
"""
|
||||
graph_config = get_graph_config()
|
||||
|
||||
if graph_config.graph_database_provider == "neo4j":
|
||||
graph_engine = await get_graph_engine()
|
||||
result = await graph_engine.graph().run(query)
|
||||
return result
|
||||
else:
|
||||
raise ValueError("Unsupported search type for the used graph engine.")
|
||||
|
|
@ -1,27 +0,0 @@
|
|||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
async def search_similarity(query: str) -> list[str, str]:
|
||||
"""
|
||||
Parameters:
|
||||
- query (str): The query string to filter nodes by.
|
||||
|
||||
Returns:
|
||||
- list(chunk): A list of objects providing information about the chunks related to query.
|
||||
"""
|
||||
vector_engine = get_vector_engine()
|
||||
|
||||
similar_results = await vector_engine.search("chunks", query, limit = 5)
|
||||
|
||||
results = [
|
||||
parse_payload(result.payload) for result in similar_results
|
||||
]
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def parse_payload(payload: dict) -> dict:
|
||||
return {
|
||||
"text": payload["text"],
|
||||
"chunk_id": payload["chunk_id"],
|
||||
"document_id": payload["document_id"],
|
||||
}
|
||||
|
|
@ -1,17 +0,0 @@
|
|||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
async def search_summary(query: str) -> list:
|
||||
"""
|
||||
Parameters:
|
||||
- query (str): The query string to filter summaries by.
|
||||
|
||||
Returns:
|
||||
- list[str, UUID]: A list of objects providing information about the summaries related to query.
|
||||
"""
|
||||
vector_engine = get_vector_engine()
|
||||
|
||||
summaries_results = await vector_engine.search("summaries", query, limit = 5)
|
||||
|
||||
summaries = [summary.payload for summary in summaries_results]
|
||||
|
||||
return summaries
|
||||
|
|
@ -1,16 +0,0 @@
|
|||
from typing import Type
|
||||
from pydantic import BaseModel
|
||||
from cognee.infrastructure.llm.prompts import render_prompt
|
||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
|
||||
async def categorize_relevant_category(query: str, summary, response_model: Type[BaseModel]):
|
||||
llm_client = get_llm_client()
|
||||
|
||||
enriched_query= render_prompt("categorize_categories.txt", {"query": query, "categories": summary})
|
||||
|
||||
|
||||
system_prompt = " Choose the relevant categories and return appropriate output based on the model"
|
||||
|
||||
llm_output = await llm_client.acreate_structured_output(enriched_query, system_prompt, response_model)
|
||||
|
||||
return llm_output.model_dump()
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
from typing import Type
|
||||
from pydantic import BaseModel
|
||||
from cognee.infrastructure.llm.prompts import render_prompt
|
||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
|
||||
async def categorize_relevant_summary(query: str, summaries, response_model: Type[BaseModel]):
|
||||
llm_client = get_llm_client()
|
||||
|
||||
enriched_query= render_prompt("categorize_summary.txt", {"query": query, "summaries": summaries})
|
||||
|
||||
system_prompt = "Choose the relevant summaries and return appropriate output based on the model"
|
||||
|
||||
llm_output = await llm_client.acreate_structured_output(enriched_query, system_prompt, response_model)
|
||||
|
||||
return llm_output
|
||||
|
|
@ -1,17 +0,0 @@
|
|||
import logging
|
||||
from typing import List, Dict
|
||||
from cognee.modules.cognify.config import get_cognify_config
|
||||
from .extraction.categorize_relevant_summary import categorize_relevant_summary
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
async def get_cognitive_layers(content: str, categories: List[Dict]):
|
||||
try:
|
||||
cognify_config = get_cognify_config()
|
||||
return (await categorize_relevant_summary(
|
||||
content,
|
||||
categories[0],
|
||||
cognify_config.summarization_model,
|
||||
)).cognitive_layers
|
||||
except Exception as error:
|
||||
logger.error("Error extracting cognitive layers from content: %s", error, exc_info = True)
|
||||
raise error
|
||||
|
|
@ -1 +0,0 @@
|
|||
""" Placeholder for BM25 implementation"""
|
||||
|
|
@ -1 +0,0 @@
|
|||
"""Placeholder for fusions search implementation"""
|
||||
|
|
@ -1,36 +0,0 @@
|
|||
import asyncio
|
||||
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
async def search_traverse(query: str):
|
||||
node_id = query
|
||||
rules = set()
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
vector_engine = get_vector_engine()
|
||||
|
||||
exact_node = await graph_engine.extract_node(node_id)
|
||||
|
||||
if exact_node is not None and "uuid" in exact_node:
|
||||
edges = await graph_engine.get_edges(exact_node["uuid"])
|
||||
|
||||
for edge in edges:
|
||||
rules.add(f"{edge[0]} {edge[2]['relationship_name']} {edge[1]}")
|
||||
else:
|
||||
results = await asyncio.gather(
|
||||
vector_engine.search("entities", query_text = query, limit = 10),
|
||||
vector_engine.search("classification", query_text = query, limit = 10),
|
||||
)
|
||||
results = [*results[0], *results[1]]
|
||||
relevant_results = [result for result in results if result.score < 0.5][:5]
|
||||
|
||||
if len(relevant_results) > 0:
|
||||
for result in relevant_results:
|
||||
graph_node_id = result.id
|
||||
|
||||
edges = await graph_engine.get_edges(graph_node_id)
|
||||
|
||||
for edge in edges:
|
||||
rules.add(f"{edge[0]} {edge[2]['relationship_name']} {edge[1]}")
|
||||
|
||||
return list(rules)
|
||||
46
cognee/modules/storage/utils/__init__.py
Normal file
46
cognee/modules/storage/utils/__init__.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
import json
|
||||
from uuid import UUID
|
||||
from datetime import datetime
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
||||
class JSONEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat() # Convert datetime to ISO 8601 string
|
||||
elif isinstance(obj, UUID):
|
||||
# if the obj is uuid, we simply return the value of uuid
|
||||
return str(obj)
|
||||
return json.JSONEncoder.default(self, obj)
|
||||
|
||||
|
||||
from pydantic import create_model
|
||||
|
||||
def copy_model(model: DataPoint, include_fields: dict = {}, exclude_fields: list = []):
|
||||
fields = {
|
||||
name: (field.annotation, field.default if field.default is not None else PydanticUndefined)
|
||||
for name, field in model.model_fields.items()
|
||||
if name not in exclude_fields
|
||||
}
|
||||
|
||||
final_fields = {
|
||||
**fields,
|
||||
**include_fields
|
||||
}
|
||||
|
||||
return create_model(model.__name__, **final_fields)
|
||||
|
||||
def get_own_properties(data_point: DataPoint):
|
||||
properties = {}
|
||||
|
||||
for field_name, field_value in data_point:
|
||||
if field_name == "_metadata" \
|
||||
or isinstance(field_value, dict) \
|
||||
or isinstance(field_value, DataPoint) \
|
||||
or (isinstance(field_value, list) and isinstance(field_value[0], DataPoint)):
|
||||
continue
|
||||
|
||||
properties[field_name] = field_value
|
||||
|
||||
return properties
|
||||
|
|
@ -1,84 +1,95 @@
|
|||
from typing import List, Union, Literal, Optional
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, List, Union, Literal, Optional
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
||||
class BaseClass(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
type: Literal["Class"] = "Class"
|
||||
description: str
|
||||
constructor_parameters: Optional[List[str]] = None
|
||||
|
||||
class Class(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
type: Literal["Class"] = "Class"
|
||||
description: str
|
||||
constructor_parameters: Optional[List[str]] = None
|
||||
from_class: Optional[BaseClass] = None
|
||||
|
||||
class ClassInstance(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
type: Literal["ClassInstance"] = "ClassInstance"
|
||||
description: str
|
||||
from_class: Class
|
||||
|
||||
class Function(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
type: Literal["Function"] = "Function"
|
||||
description: str
|
||||
parameters: Optional[List[str]] = None
|
||||
return_type: str
|
||||
is_static: Optional[bool] = False
|
||||
|
||||
class Variable(BaseModel):
|
||||
class Variable(DataPoint):
|
||||
id: str
|
||||
name: str
|
||||
type: Literal["Variable"] = "Variable"
|
||||
description: str
|
||||
is_static: Optional[bool] = False
|
||||
default_value: Optional[str] = None
|
||||
data_type: str
|
||||
|
||||
class Operator(BaseModel):
|
||||
_metadata = {
|
||||
"index_fields": ["name"]
|
||||
}
|
||||
|
||||
class Operator(DataPoint):
|
||||
id: str
|
||||
name: str
|
||||
type: Literal["Operator"] = "Operator"
|
||||
description: str
|
||||
return_type: str
|
||||
|
||||
class ExpressionPart(BaseModel):
|
||||
class Class(DataPoint):
|
||||
id: str
|
||||
name: str
|
||||
type: Literal["Class"] = "Class"
|
||||
description: str
|
||||
constructor_parameters: List[Variable]
|
||||
extended_from_class: Optional["Class"] = None
|
||||
has_methods: list["Function"]
|
||||
|
||||
_metadata = {
|
||||
"index_fields": ["name"]
|
||||
}
|
||||
|
||||
class ClassInstance(DataPoint):
|
||||
id: str
|
||||
name: str
|
||||
type: Literal["ClassInstance"] = "ClassInstance"
|
||||
description: str
|
||||
from_class: Class
|
||||
instantiated_by: Union["Function"]
|
||||
instantiation_arguments: List[Variable]
|
||||
|
||||
_metadata = {
|
||||
"index_fields": ["name"]
|
||||
}
|
||||
|
||||
class Function(DataPoint):
|
||||
id: str
|
||||
name: str
|
||||
type: Literal["Function"] = "Function"
|
||||
description: str
|
||||
parameters: List[Variable]
|
||||
return_type: str
|
||||
is_static: Optional[bool] = False
|
||||
|
||||
_metadata = {
|
||||
"index_fields": ["name"]
|
||||
}
|
||||
|
||||
class FunctionCall(DataPoint):
|
||||
id: str
|
||||
type: Literal["FunctionCall"] = "FunctionCall"
|
||||
called_by: Union[Function, Literal["main"]]
|
||||
function_called: Function
|
||||
function_arguments: List[Any]
|
||||
|
||||
class Expression(DataPoint):
|
||||
id: str
|
||||
name: str
|
||||
type: Literal["Expression"] = "Expression"
|
||||
description: str
|
||||
expression: str
|
||||
members: List[Union[Variable, Function, Operator]]
|
||||
members: List[Union[Variable, Function, Operator, "Expression"]]
|
||||
|
||||
class Expression(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
type: Literal["Expression"] = "Expression"
|
||||
description: str
|
||||
expression: str
|
||||
members: List[Union[Variable, Function, Operator, ExpressionPart]]
|
||||
|
||||
class Edge(BaseModel):
|
||||
source_node_id: str
|
||||
target_node_id: str
|
||||
relationship_name: Literal["called in", "stored in", "defined in", "returned by", "instantiated in", "uses", "updates"]
|
||||
|
||||
class SourceCodeGraph(BaseModel):
|
||||
class SourceCodeGraph(DataPoint):
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
language: str
|
||||
nodes: List[Union[
|
||||
Class,
|
||||
ClassInstance,
|
||||
Function,
|
||||
FunctionCall,
|
||||
Variable,
|
||||
Operator,
|
||||
Expression,
|
||||
ClassInstance,
|
||||
]]
|
||||
edges: List[Edge]
|
||||
|
||||
Class.model_rebuild()
|
||||
ClassInstance.model_rebuild()
|
||||
Expression.model_rebuild()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
""" This module contains utility functions for the cognee. """
|
||||
import os
|
||||
import datetime
|
||||
from datetime import datetime, timezone
|
||||
import graphistry
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
|
|
@ -45,7 +45,7 @@ def send_telemetry(event_name: str, user_id, additional_properties: dict = {}):
|
|||
host = "https://eu.i.posthog.com"
|
||||
)
|
||||
|
||||
current_time = datetime.datetime.now()
|
||||
current_time = datetime.now(timezone.utc)
|
||||
properties = {
|
||||
"time": current_time.strftime("%m/%d/%Y"),
|
||||
"user_id": user_id,
|
||||
|
|
@ -110,30 +110,36 @@ async def register_graphistry():
|
|||
graphistry.register(api = 3, username = config.graphistry_username, password = config.graphistry_password)
|
||||
|
||||
|
||||
def prepare_edges(graph):
|
||||
return nx.to_pandas_edgelist(graph)
|
||||
def prepare_edges(graph, source, target, edge_key):
|
||||
edge_list = [{
|
||||
source: str(edge[0]),
|
||||
target: str(edge[1]),
|
||||
edge_key: str(edge[2]),
|
||||
} for edge in graph.edges(keys = True, data = True)]
|
||||
|
||||
return pd.DataFrame(edge_list)
|
||||
|
||||
|
||||
def prepare_nodes(graph, include_size=False):
|
||||
nodes_data = []
|
||||
for node in graph.nodes:
|
||||
node_info = graph.nodes[node]
|
||||
description = node_info.get("layer_description", {}).get("layer", "Default Layer") if isinstance(
|
||||
node_info.get("layer_description"), dict) else node_info.get("layer_description", "Default Layer")
|
||||
# description = node_info['layer_description']['layer'] if isinstance(node_info.get('layer_description'), dict) and 'layer' in node_info['layer_description'] else node_info.get('layer_description', node)
|
||||
# if isinstance(node_info.get('layer_description'), dict) and 'layer' in node_info.get('layer_description'):
|
||||
# description = node_info['layer_description']['layer']
|
||||
# # Use 'layer_description' directly if it's not a dictionary, otherwise default to node ID
|
||||
# else:
|
||||
# description = node_info.get('layer_description', node)
|
||||
|
||||
node_data = {"id": node, "layer_description": description}
|
||||
if not node_info:
|
||||
continue
|
||||
|
||||
node_data = {
|
||||
"id": str(node),
|
||||
"name": node_info["name"] if "name" in node_info else str(node),
|
||||
}
|
||||
|
||||
if include_size:
|
||||
default_size = 10 # Default node size
|
||||
larger_size = 20 # Size for nodes with specific keywords in their ID
|
||||
keywords = ["DOCUMENT", "User", "LAYER"]
|
||||
keywords = ["DOCUMENT", "User"]
|
||||
node_size = larger_size if any(keyword in str(node) for keyword in keywords) else default_size
|
||||
node_data["size"] = node_size
|
||||
|
||||
nodes_data.append(node_data)
|
||||
|
||||
return pd.DataFrame(nodes_data)
|
||||
|
|
@ -153,28 +159,28 @@ async def render_graph(graph, include_nodes=False, include_color=False, include_
|
|||
|
||||
graph = networkx_graph
|
||||
|
||||
edges = prepare_edges(graph)
|
||||
plotter = graphistry.edges(edges, "source", "target")
|
||||
edges = prepare_edges(graph, "source_node", "target_node", "relationship_name")
|
||||
plotter = graphistry.edges(edges, "source_node", "target_node")
|
||||
plotter = plotter.bind(edge_label = "relationship_name")
|
||||
|
||||
if include_nodes:
|
||||
nodes = prepare_nodes(graph, include_size=include_size)
|
||||
nodes = prepare_nodes(graph, include_size = include_size)
|
||||
plotter = plotter.nodes(nodes, "id")
|
||||
|
||||
|
||||
if include_size:
|
||||
plotter = plotter.bind(point_size="size")
|
||||
plotter = plotter.bind(point_size = "size")
|
||||
|
||||
|
||||
if include_color:
|
||||
unique_layers = nodes["layer_description"].unique()
|
||||
color_palette = generate_color_palette(unique_layers)
|
||||
plotter = plotter.encode_point_color("layer_description", categorical_mapping=color_palette,
|
||||
default_mapping="silver")
|
||||
pass
|
||||
# unique_layers = nodes["layer_description"].unique()
|
||||
# color_palette = generate_color_palette(unique_layers)
|
||||
# plotter = plotter.encode_point_color("layer_description", categorical_mapping=color_palette,
|
||||
# default_mapping="silver")
|
||||
|
||||
|
||||
if include_labels:
|
||||
plotter = plotter.bind(point_label = "layer_description")
|
||||
|
||||
plotter = plotter.bind(point_label = "name")
|
||||
|
||||
|
||||
# Visualization
|
||||
|
|
|
|||
|
|
@ -1,10 +0,0 @@
|
|||
from .summarization.summarize_text import summarize_text
|
||||
from .chunk_naive_llm_classifier.chunk_naive_llm_classifier import chunk_naive_llm_classifier
|
||||
from .chunk_remove_disconnected.chunk_remove_disconnected import chunk_remove_disconnected
|
||||
from .chunk_update_check.chunk_update_check import chunk_update_check
|
||||
from .save_chunks_to_store.save_chunks_to_store import save_chunks_to_store
|
||||
from .source_documents_to_chunks.source_documents_to_chunks import source_documents_to_chunks
|
||||
from .infer_data_ontology.infer_data_ontology import infer_data_ontology
|
||||
from .check_permissions_on_documents.check_permissions_on_documents import check_permissions_on_documents
|
||||
from .classify_documents.classify_documents import classify_documents
|
||||
from .graph.chunks_into_graph import chunks_into_graph
|
||||
|
|
@ -5,7 +5,7 @@ from pydantic import BaseModel
|
|||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine, DataPoint
|
||||
from cognee.modules.data.extraction.extract_categories import extract_categories
|
||||
from cognee.modules.chunking import DocumentChunk
|
||||
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
||||
|
||||
|
||||
async def chunk_naive_llm_classifier(data_chunks: list[DocumentChunk], classification_model: Type[BaseModel]):
|
||||
|
|
@ -65,7 +65,7 @@ async def chunk_naive_llm_classifier(data_chunks: list[DocumentChunk], classific
|
|||
"chunk_id": str(data_chunk.chunk_id),
|
||||
"document_id": str(data_chunk.document_id),
|
||||
}),
|
||||
embed_field="text",
|
||||
index_fields=["text"],
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -104,7 +104,7 @@ async def chunk_naive_llm_classifier(data_chunks: list[DocumentChunk], classific
|
|||
"chunk_id": str(data_chunk.chunk_id),
|
||||
"document_id": str(data_chunk.document_id),
|
||||
}),
|
||||
embed_field="text",
|
||||
index_fields=["text"],
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,39 +0,0 @@
|
|||
|
||||
import logging
|
||||
|
||||
from cognee.base_config import get_base_config
|
||||
|
||||
BaseConfig = get_base_config()
|
||||
|
||||
async def translate_text(data, source_language:str='sr', target_language:str='en', region_name='eu-west-1'):
|
||||
"""
|
||||
Translate text from source language to target language using AWS Translate.
|
||||
Parameters:
|
||||
data (str): The text to be translated.
|
||||
source_language (str): The source language code (e.g., 'sr' for Serbian). ISO 639-2 Code https://www.loc.gov/standards/iso639-2/php/code_list.php
|
||||
target_language (str): The target language code (e.g., 'en' for English). ISO 639-2 Code https://www.loc.gov/standards/iso639-2/php/code_list.php
|
||||
region_name (str): AWS region name.
|
||||
Returns:
|
||||
str: Translated text or an error message.
|
||||
"""
|
||||
import boto3
|
||||
from botocore.exceptions import BotoCoreError, ClientError
|
||||
|
||||
if not data:
|
||||
yield "No text provided for translation."
|
||||
|
||||
if not source_language or not target_language:
|
||||
yield "Both source and target language codes are required."
|
||||
|
||||
try:
|
||||
translate = boto3.client(service_name='translate', region_name=region_name, use_ssl=True)
|
||||
result = translate.translate_text(Text=data, SourceLanguageCode=source_language, TargetLanguageCode=target_language)
|
||||
yield result.get('TranslatedText', 'No translation found.')
|
||||
|
||||
except BotoCoreError as e:
|
||||
logging.info(f"BotoCoreError occurred: {e}")
|
||||
yield "Error with AWS Translate service configuration or request."
|
||||
|
||||
except ClientError as e:
|
||||
logging.info(f"ClientError occurred: {e}")
|
||||
yield "Error with AWS client or network issue."
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.modules.chunking import DocumentChunk
|
||||
|
||||
|
||||
async def chunk_update_check(data_chunks: list[DocumentChunk], collection_name: str) -> list[DocumentChunk]:
|
||||
vector_engine = get_vector_engine()
|
||||
|
||||
if not await vector_engine.has_collection(collection_name):
|
||||
# If collection doesn't exist, all data_chunks are new
|
||||
return data_chunks
|
||||
|
||||
existing_chunks = await vector_engine.retrieve(
|
||||
collection_name,
|
||||
[str(chunk.chunk_id) for chunk in data_chunks],
|
||||
)
|
||||
|
||||
existing_chunks_map = {str(chunk.id): chunk.payload for chunk in existing_chunks}
|
||||
|
||||
affected_data_chunks = []
|
||||
|
||||
for chunk in data_chunks:
|
||||
if chunk.chunk_id not in existing_chunks_map or \
|
||||
chunk.text != existing_chunks_map[chunk.chunk_id]["text"]:
|
||||
affected_data_chunks.append(chunk)
|
||||
|
||||
return affected_data_chunks
|
||||
|
|
@ -2,3 +2,4 @@ from .query_chunks import query_chunks
|
|||
from .chunk_by_word import chunk_by_word
|
||||
from .chunk_by_sentence import chunk_by_sentence
|
||||
from .chunk_by_paragraph import chunk_by_paragraph
|
||||
from .remove_disconnected_chunks import remove_disconnected_chunks
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
from cognee.tasks.chunking import chunk_by_paragraph
|
||||
from cognee.tasks.chunks import chunk_by_paragraph
|
||||
|
||||
if __name__ == "__main__":
|
||||
def test_chunking_on_whole_text():
|
||||
|
|
@ -10,7 +10,7 @@ async def query_chunks(query: str) -> list[dict]:
|
|||
"""
|
||||
vector_engine = get_vector_engine()
|
||||
|
||||
found_chunks = await vector_engine.search("chunks", query, limit = 5)
|
||||
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit = 5)
|
||||
|
||||
chunks = [result.payload for result in found_chunks]
|
||||
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.modules.chunking import DocumentChunk
|
||||
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
||||
|
||||
async def chunk_remove_disconnected(data_chunks: list[DocumentChunk]) -> list[DocumentChunk]:
|
||||
async def remove_disconnected_chunks(data_chunks: list[DocumentChunk]) -> list[DocumentChunk]:
|
||||
graph_engine = await get_graph_engine()
|
||||
|
||||
document_ids = set((data_chunk.document_id for data_chunk in data_chunks))
|
||||
|
|
@ -1,13 +0,0 @@
|
|||
from cognee.modules.data.models import Data
|
||||
from cognee.modules.data.processing.document_types import Document, PdfDocument, AudioDocument, ImageDocument, TextDocument
|
||||
|
||||
def classify_documents(data_documents: list[Data]) -> list[Document]:
|
||||
documents = [
|
||||
PdfDocument(id = data_item.id, title=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location) if data_item.extension == "pdf" else
|
||||
AudioDocument(id = data_item.id, title=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location) if data_item.extension == "audio" else
|
||||
ImageDocument(id = data_item.id, title=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location) if data_item.extension == "image" else
|
||||
TextDocument(id = data_item.id, title=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location)
|
||||
for data_item in data_documents
|
||||
]
|
||||
|
||||
return documents
|
||||
3
cognee/tasks/documents/__init__.py
Normal file
3
cognee/tasks/documents/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from .classify_documents import classify_documents
|
||||
from .extract_chunks_from_documents import extract_chunks_from_documents
|
||||
from .check_permissions_on_documents import check_permissions_on_documents
|
||||
13
cognee/tasks/documents/classify_documents.py
Normal file
13
cognee/tasks/documents/classify_documents.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
from cognee.modules.data.models import Data
|
||||
from cognee.modules.data.processing.document_types import Document, PdfDocument, AudioDocument, ImageDocument, TextDocument
|
||||
|
||||
def classify_documents(data_documents: list[Data]) -> list[Document]:
|
||||
documents = [
|
||||
PdfDocument(id = data_item.id, name=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location) if data_item.extension == "pdf" else
|
||||
AudioDocument(id = data_item.id, name=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location) if data_item.extension == "audio" else
|
||||
ImageDocument(id = data_item.id, name=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location) if data_item.extension == "image" else
|
||||
TextDocument(id = data_item.id, name=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location)
|
||||
for data_item in data_documents
|
||||
]
|
||||
|
||||
return documents
|
||||
7
cognee/tasks/documents/extract_chunks_from_documents.py
Normal file
7
cognee/tasks/documents/extract_chunks_from_documents.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
from cognee.modules.data.processing.document_types.Document import Document
|
||||
|
||||
|
||||
async def extract_chunks_from_documents(documents: list[Document], chunk_size: int = 1024):
|
||||
for document in documents:
|
||||
for document_chunk in document.read(chunk_size = chunk_size):
|
||||
yield document_chunk
|
||||
|
|
@ -1,2 +1,3 @@
|
|||
from .chunks_into_graph import chunks_into_graph
|
||||
from .extract_graph_from_data import extract_graph_from_data
|
||||
from .extract_graph_from_code import extract_graph_from_code
|
||||
from .query_graph_connections import query_graph_connections
|
||||
|
|
|
|||
|
|
@ -1,213 +0,0 @@
|
|||
import json
|
||||
import asyncio
|
||||
from uuid import uuid5, NAMESPACE_OID
|
||||
from datetime import datetime, timezone
|
||||
from typing import Type
|
||||
from pydantic import BaseModel
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.vector import DataPoint, get_vector_engine
|
||||
from cognee.modules.data.extraction.knowledge_graph.extract_content_graph import extract_content_graph
|
||||
from cognee.modules.chunking import DocumentChunk
|
||||
from cognee.modules.graph.utils import generate_node_id, generate_node_name
|
||||
|
||||
|
||||
class EntityNode(BaseModel):
|
||||
uuid: str
|
||||
name: str
|
||||
type: str
|
||||
description: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
async def chunks_into_graph(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel], collection_name: str):
|
||||
chunk_graphs = await asyncio.gather(
|
||||
*[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
|
||||
)
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
graph_engine = await get_graph_engine()
|
||||
|
||||
has_collection = await vector_engine.has_collection(collection_name)
|
||||
|
||||
if not has_collection:
|
||||
await vector_engine.create_collection(collection_name, payload_schema = EntityNode)
|
||||
|
||||
processed_nodes = {}
|
||||
type_node_edges = []
|
||||
entity_node_edges = []
|
||||
type_entity_edges = []
|
||||
|
||||
for (chunk_index, chunk) in enumerate(data_chunks):
|
||||
chunk_graph = chunk_graphs[chunk_index]
|
||||
for node in chunk_graph.nodes:
|
||||
type_node_id = generate_node_id(node.type)
|
||||
entity_node_id = generate_node_id(node.id)
|
||||
|
||||
if type_node_id not in processed_nodes:
|
||||
type_node_edges.append((str(chunk.chunk_id), type_node_id, "contains_entity_type"))
|
||||
processed_nodes[type_node_id] = True
|
||||
|
||||
if entity_node_id not in processed_nodes:
|
||||
entity_node_edges.append((str(chunk.chunk_id), entity_node_id, "contains_entity"))
|
||||
type_entity_edges.append((entity_node_id, type_node_id, "is_entity_type"))
|
||||
processed_nodes[entity_node_id] = True
|
||||
|
||||
graph_node_edges = [
|
||||
(edge.target_node_id, edge.source_node_id, edge.relationship_name) \
|
||||
for edge in chunk_graph.edges
|
||||
]
|
||||
|
||||
existing_edges = await graph_engine.has_edges([
|
||||
*type_node_edges,
|
||||
*entity_node_edges,
|
||||
*type_entity_edges,
|
||||
*graph_node_edges,
|
||||
])
|
||||
|
||||
existing_edges_map = {}
|
||||
existing_nodes_map = {}
|
||||
|
||||
for edge in existing_edges:
|
||||
existing_edges_map[edge[0] + edge[1] + edge[2]] = True
|
||||
existing_nodes_map[edge[0]] = True
|
||||
|
||||
graph_nodes = []
|
||||
graph_edges = []
|
||||
data_points = []
|
||||
|
||||
for (chunk_index, chunk) in enumerate(data_chunks):
|
||||
graph = chunk_graphs[chunk_index]
|
||||
if graph is None:
|
||||
continue
|
||||
|
||||
for node in graph.nodes:
|
||||
node_id = generate_node_id(node.id)
|
||||
node_name = generate_node_name(node.name)
|
||||
|
||||
type_node_id = generate_node_id(node.type)
|
||||
type_node_name = generate_node_name(node.type)
|
||||
|
||||
if node_id not in existing_nodes_map:
|
||||
node_data = dict(
|
||||
uuid = node_id,
|
||||
name = node_name,
|
||||
type = node_name,
|
||||
description = node.description,
|
||||
created_at = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
|
||||
updated_at = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
|
||||
)
|
||||
|
||||
graph_nodes.append((
|
||||
node_id,
|
||||
dict(
|
||||
**node_data,
|
||||
properties = json.dumps(node.properties),
|
||||
)
|
||||
))
|
||||
|
||||
data_points.append(DataPoint[EntityNode](
|
||||
id = str(uuid5(NAMESPACE_OID, node_id)),
|
||||
payload = node_data,
|
||||
embed_field = "name",
|
||||
))
|
||||
|
||||
existing_nodes_map[node_id] = True
|
||||
|
||||
edge_key = str(chunk.chunk_id) + node_id + "contains_entity"
|
||||
|
||||
if edge_key not in existing_edges_map:
|
||||
graph_edges.append((
|
||||
str(chunk.chunk_id),
|
||||
node_id,
|
||||
"contains_entity",
|
||||
dict(
|
||||
relationship_name = "contains_entity",
|
||||
source_node_id = str(chunk.chunk_id),
|
||||
target_node_id = node_id,
|
||||
),
|
||||
))
|
||||
|
||||
# Add relationship between entity type and entity itself: "Jake is Person"
|
||||
graph_edges.append((
|
||||
node_id,
|
||||
type_node_id,
|
||||
"is_entity_type",
|
||||
dict(
|
||||
relationship_name = "is_entity_type",
|
||||
source_node_id = type_node_id,
|
||||
target_node_id = node_id,
|
||||
),
|
||||
))
|
||||
|
||||
existing_edges_map[edge_key] = True
|
||||
|
||||
if type_node_id not in existing_nodes_map:
|
||||
type_node_data = dict(
|
||||
uuid = type_node_id,
|
||||
name = type_node_name,
|
||||
type = type_node_id,
|
||||
description = type_node_name,
|
||||
created_at = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
|
||||
updated_at = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
|
||||
)
|
||||
|
||||
graph_nodes.append((type_node_id, dict(
|
||||
**type_node_data,
|
||||
properties = json.dumps(node.properties)
|
||||
)))
|
||||
|
||||
data_points.append(DataPoint[EntityNode](
|
||||
id = str(uuid5(NAMESPACE_OID, type_node_id)),
|
||||
payload = type_node_data,
|
||||
embed_field = "name",
|
||||
))
|
||||
|
||||
existing_nodes_map[type_node_id] = True
|
||||
|
||||
edge_key = str(chunk.chunk_id) + type_node_id + "contains_entity_type"
|
||||
|
||||
if edge_key not in existing_edges_map:
|
||||
graph_edges.append((
|
||||
str(chunk.chunk_id),
|
||||
type_node_id,
|
||||
"contains_entity_type",
|
||||
dict(
|
||||
relationship_name = "contains_entity_type",
|
||||
source_node_id = str(chunk.chunk_id),
|
||||
target_node_id = type_node_id,
|
||||
),
|
||||
))
|
||||
|
||||
existing_edges_map[edge_key] = True
|
||||
|
||||
# Add relationship that came from graphs.
|
||||
for edge in graph.edges:
|
||||
source_node_id = generate_node_id(edge.source_node_id)
|
||||
target_node_id = generate_node_id(edge.target_node_id)
|
||||
relationship_name = generate_node_name(edge.relationship_name)
|
||||
edge_key = source_node_id + target_node_id + relationship_name
|
||||
|
||||
if edge_key not in existing_edges_map:
|
||||
graph_edges.append((
|
||||
generate_node_id(edge.source_node_id),
|
||||
generate_node_id(edge.target_node_id),
|
||||
edge.relationship_name,
|
||||
dict(
|
||||
relationship_name = generate_node_name(edge.relationship_name),
|
||||
source_node_id = generate_node_id(edge.source_node_id),
|
||||
target_node_id = generate_node_id(edge.target_node_id),
|
||||
properties = json.dumps(edge.properties),
|
||||
),
|
||||
))
|
||||
existing_edges_map[edge_key] = True
|
||||
|
||||
if len(data_points) > 0:
|
||||
await vector_engine.create_data_points(collection_name, data_points)
|
||||
|
||||
if len(graph_nodes) > 0:
|
||||
await graph_engine.add_nodes(graph_nodes)
|
||||
|
||||
if len(graph_edges) > 0:
|
||||
await graph_engine.add_edges(graph_edges)
|
||||
|
||||
return data_chunks
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue