Merge branch 'main' into COG-575-remove-graph-overwrite-on-error

This commit is contained in:
Vasilije 2024-11-12 10:18:09 +01:00 committed by GitHub
commit be792a7ba6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
127 changed files with 3122 additions and 2951 deletions

View file

@ -5,7 +5,7 @@ on:
pull_request: pull_request:
branches: branches:
- main - main
types: [labeled] types: [labeled, synchronize]
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@ -22,7 +22,7 @@ jobs:
run_neo4j_integration_test: run_neo4j_integration_test:
name: test name: test
needs: get_docs_changes needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }} if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest runs-on: ubuntu-latest
defaults: defaults:

View file

@ -5,7 +5,7 @@ on:
pull_request: pull_request:
branches: branches:
- main - main
types: [labeled] types: [labeled, synchronize]
concurrency: concurrency:
@ -23,7 +23,7 @@ jobs:
run_notebook_test: run_notebook_test:
name: test name: test
needs: get_docs_changes needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }} if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest runs-on: ubuntu-latest
defaults: defaults:
run: run:

View file

@ -5,7 +5,7 @@ on:
pull_request: pull_request:
branches: branches:
- main - main
types: [labeled] types: [labeled, synchronize]
concurrency: concurrency:
@ -23,7 +23,7 @@ jobs:
run_pgvector_integration_test: run_pgvector_integration_test:
name: test name: test
needs: get_docs_changes needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }} if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest runs-on: ubuntu-latest
defaults: defaults:
run: run:

View file

@ -5,10 +5,10 @@ on:
pull_request: pull_request:
branches: branches:
- main - main
types: [labeled] types: [labeled, synchronize]
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} | ${{ github.event.label.name == 'run-checks' }} group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true cancel-in-progress: true
env: env:
@ -22,7 +22,7 @@ jobs:
run_common: run_common:
name: test name: test
needs: get_docs_changes needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
fail-fast: false fail-fast: false

View file

@ -5,10 +5,10 @@ on:
pull_request: pull_request:
branches: branches:
- main - main
types: [labeled] types: [labeled, synchronize]
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} | ${{ github.event.label.name == 'run-checks' }} group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true cancel-in-progress: true
env: env:
@ -22,7 +22,7 @@ jobs:
run_common: run_common:
name: test name: test
needs: get_docs_changes needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
fail-fast: false fail-fast: false

View file

@ -5,10 +5,10 @@ on:
pull_request: pull_request:
branches: branches:
- main - main
types: [labeled] types: [labeled, synchronize]
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} | ${{ github.event.label.name == 'run-checks' }} group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true cancel-in-progress: true
env: env:
@ -22,7 +22,7 @@ jobs:
run_common: run_common:
name: test name: test
needs: get_docs_changes needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
fail-fast: false fail-fast: false

View file

@ -5,7 +5,7 @@ on:
pull_request: pull_request:
branches: branches:
- main - main
types: [labeled] types: [labeled, synchronize]
concurrency: concurrency:
@ -23,7 +23,7 @@ jobs:
run_qdrant_integration_test: run_qdrant_integration_test:
name: test name: test
needs: get_docs_changes needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }} if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest runs-on: ubuntu-latest
defaults: defaults:

View file

@ -5,7 +5,7 @@ on:
pull_request: pull_request:
branches: branches:
- main - main
types: [labeled] types: [labeled, synchronize]
concurrency: concurrency:
@ -23,7 +23,7 @@ jobs:
run_weaviate_integration_test: run_weaviate_integration_test:
name: test name: test
needs: get_docs_changes needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }} if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest runs-on: ubuntu-latest
defaults: defaults:

View file

@ -109,24 +109,34 @@ import asyncio
from cognee.api.v1.search import SearchType from cognee.api.v1.search import SearchType
async def main(): async def main():
await cognee.prune.prune_data() # Reset cognee data # Reset cognee data
await cognee.prune.prune_system(metadata=True) # Reset cognee system state await cognee.prune.prune_data()
# Reset cognee system state
await cognee.prune.prune_system(metadata=True)
text = """ text = """
Natural language processing (NLP) is an interdisciplinary Natural language processing (NLP) is an interdisciplinary
subfield of computer science and information retrieval. subfield of computer science and information retrieval.
""" """
await cognee.add(text) # Add text to cognee # Add text to cognee
await cognee.cognify() # Use LLMs and cognee to create knowledge graph await cognee.add(text)
search_results = await cognee.search( # Search cognee for insights # Use LLMs and cognee to create knowledge graph
await cognee.cognify()
# Search cognee for insights
search_results = await cognee.search(
SearchType.INSIGHTS, SearchType.INSIGHTS,
{'query': 'Tell me about NLP'} "Tell me about NLP",
) )
for result_text in search_results: # Display results # Display results
for result_text in search_results:
print(result_text) print(result_text)
# natural_language_processing is_a field
# natural_language_processing is_subfield_of computer_science
# natural_language_processing is_subfield_of information_retrieval
asyncio.run(main()) asyncio.run(main())
``` ```

View file

@ -0,0 +1,110 @@
import asyncio
import logging
from typing import Union
from cognee.shared.SourceCodeGraph import SourceCodeGraph
from cognee.shared.utils import send_telemetry
from cognee.modules.data.models import Dataset, Data
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
from cognee.modules.data.methods import get_datasets, get_datasets_by_name
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.pipelines import run_tasks
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
from cognee.modules.pipelines.models import PipelineRunStatus
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
from cognee.modules.pipelines.operations.log_pipeline_status import log_pipeline_status
from cognee.tasks.documents import classify_documents, check_permissions_on_documents, extract_chunks_from_documents
from cognee.tasks.graph import extract_graph_from_code
from cognee.tasks.storage import add_data_points
logger = logging.getLogger("code_graph_pipeline")
update_status_lock = asyncio.Lock()
class PermissionDeniedException(Exception):
    """Raised when a user lacks the permissions required by the pipeline."""

    def __init__(self, message: str):
        # Initialize the base Exception first, then keep the message on the
        # instance for programmatic access by callers.
        super().__init__(message)
        self.message = message
async def code_graph_pipeline(datasets: Union[str, list[str]] = None, user: User = None):
    """Run the code graph pipeline for the given datasets.

    Args:
        datasets: Dataset name(s) to process. When None or empty, all of the
            user's existing datasets are processed.
        user: Owner of the datasets; defaults to the default user.

    Returns:
        A list with one entry per dataset pipeline that was run
        (the gathered results of `run_pipeline`).
    """
    if user is None:
        user = await get_default_user()

    existing_datasets = await get_datasets(user.id)

    if datasets is None or len(datasets) == 0:
        # If no datasets are provided, cognify all existing datasets.
        datasets = existing_datasets

    # Fix: previously `datasets[0]` raised IndexError when the user had no
    # datasets at all; there is simply nothing to run in that case.
    if len(datasets) == 0:
        return []

    if isinstance(datasets[0], str):
        # Names were passed in — resolve them to Dataset objects.
        datasets = await get_datasets_by_name(datasets, user.id)

    # Normalized names of datasets the user actually owns; used to filter
    # out requested datasets that do not exist.
    existing_datasets_map = {
        generate_dataset_name(dataset.name): True for dataset in existing_datasets
    }

    awaitables = []

    for dataset in datasets:
        dataset_name = generate_dataset_name(dataset.name)

        # Only run pipelines for datasets that exist for this user.
        if dataset_name in existing_datasets_map:
            awaitables.append(run_pipeline(dataset, user))

    return await asyncio.gather(*awaitables)
async def run_pipeline(dataset: Dataset, user: User):
    """Execute the code-graph task pipeline for a single dataset.

    Skips the dataset if another run has already marked it as processing;
    otherwise logs DATASET_PROCESSING_STARTED, runs the tasks, and logs
    COMPLETED or ERRORED (re-raising) accordingly.
    """
    # All data items belonging to this dataset.
    data_documents: list[Data] = await get_dataset_data(dataset_id = dataset.id)

    document_ids_str = [str(document.id) for document in data_documents]

    dataset_id = dataset.id
    dataset_name = generate_dataset_name(dataset.name)

    send_telemetry("code_graph_pipeline EXECUTION STARTED", user.id)

    # The lock makes the status check-then-set atomic, so two concurrent
    # calls cannot both start processing the same dataset.
    async with update_status_lock:
        task_status = await get_pipeline_status([dataset_id])

        if dataset_id in task_status and task_status[dataset_id] == PipelineRunStatus.DATASET_PROCESSING_STARTED:
            logger.info("Dataset %s is already being processed.", dataset_name)
            return

        await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_STARTED, {
            "dataset_name": dataset_name,
            "files": document_ids_str,
        })
    try:
        tasks = [
            Task(classify_documents),
            Task(check_permissions_on_documents, user = user, permissions = ["write"]),
            Task(extract_chunks_from_documents), # Extract text chunks based on the document type.
            Task(add_data_points, task_config = { "batch_size": 10 }),
            Task(extract_graph_from_code, graph_model = SourceCodeGraph, task_config = { "batch_size": 10 }), # Generate knowledge graphs from the document chunks.
        ]

        # run_tasks yields intermediate results; the loop drains the pipeline.
        pipeline = run_tasks(tasks, data_documents, "code_graph_pipeline")

        async for result in pipeline:
            print(result)

        send_telemetry("code_graph_pipeline EXECUTION COMPLETED", user.id)

        await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_COMPLETED, {
            "dataset_name": dataset_name,
            "files": document_ids_str,
        })
    except Exception as error:
        send_telemetry("code_graph_pipeline EXECUTION ERRORED", user.id)

        # Record the failure before propagating so the status store never
        # stays stuck in DATASET_PROCESSING_STARTED.
        await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_ERRORED, {
            "dataset_name": dataset_name,
            "files": document_ids_str,
        })
        raise error
def generate_dataset_name(dataset_name: str) -> str:
    """Normalize a dataset name: map "." and " " to "_" in a single pass."""
    return dataset_name.translate(str.maketrans(". ", "__"))

View file

@ -9,21 +9,15 @@ from cognee.modules.data.models import Dataset, Data
from cognee.modules.data.methods.get_dataset_data import get_dataset_data from cognee.modules.data.methods.get_dataset_data import get_dataset_data
from cognee.modules.data.methods import get_datasets, get_datasets_by_name from cognee.modules.data.methods import get_datasets, get_datasets_by_name
from cognee.modules.pipelines.tasks.Task import Task from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.pipelines import run_tasks, run_tasks_parallel from cognee.modules.pipelines import run_tasks
from cognee.modules.users.models import User from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user from cognee.modules.users.methods import get_default_user
from cognee.modules.pipelines.models import PipelineRunStatus from cognee.modules.pipelines.models import PipelineRunStatus
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
from cognee.modules.pipelines.operations.log_pipeline_status import log_pipeline_status from cognee.modules.pipelines.operations.log_pipeline_status import log_pipeline_status
from cognee.tasks import chunk_naive_llm_classifier, \ from cognee.tasks.documents import classify_documents, check_permissions_on_documents, extract_chunks_from_documents
chunk_remove_disconnected, \ from cognee.tasks.graph import extract_graph_from_data
infer_data_ontology, \ from cognee.tasks.storage import add_data_points
save_chunks_to_store, \
chunk_update_check, \
chunks_into_graph, \
source_documents_to_chunks, \
check_permissions_on_documents, \
classify_documents
from cognee.tasks.summarization import summarize_text from cognee.tasks.summarization import summarize_text
logger = logging.getLogger("cognify.v2") logger = logging.getLogger("cognify.v2")
@ -87,31 +81,17 @@ async def run_cognify_pipeline(dataset: Dataset, user: User):
try: try:
cognee_config = get_cognify_config() cognee_config = get_cognify_config()
root_node_id = None
tasks = [ tasks = [
Task(classify_documents), Task(classify_documents),
Task(check_permissions_on_documents, user = user, permissions = ["write"]), Task(check_permissions_on_documents, user = user, permissions = ["write"]),
Task(infer_data_ontology, root_node_id = root_node_id, ontology_model = KnowledgeGraph), Task(extract_chunks_from_documents), # Extract text chunks based on the document type.
Task(source_documents_to_chunks, parent_node_id = root_node_id), # Classify documents and save them as a nodes in graph db, extract text chunks based on the document type Task(add_data_points, task_config = { "batch_size": 10 }),
Task(chunks_into_graph, graph_model = KnowledgeGraph, collection_name = "entities", task_config = { "batch_size": 10 }), # Generate knowledge graphs from the document chunks and attach it to chunk nodes Task(extract_graph_from_data, graph_model = KnowledgeGraph, task_config = { "batch_size": 10 }), # Generate knowledge graphs from the document chunks.
Task(chunk_update_check, collection_name = "chunks"), # Find all affected chunks, so we don't process unchanged chunks
Task( Task(
save_chunks_to_store, summarize_text,
collection_name = "chunks", summarization_model = cognee_config.summarization_model,
), # Save the document chunks in vector db and as nodes in graph db (connected to the document node and between each other) task_config = { "batch_size": 10 }
run_tasks_parallel([ ),
Task(
summarize_text,
summarization_model = cognee_config.summarization_model,
collection_name = "summaries",
),
Task(
chunk_naive_llm_classifier,
classification_model = cognee_config.classification_model,
),
]),
Task(chunk_remove_disconnected), # Remove the obsolete document chunks.
] ]
pipeline = run_tasks(tasks, data_documents, "cognify_pipeline") pipeline = run_tasks(tasks, data_documents, "cognify_pipeline")

View file

@ -5,7 +5,7 @@ from cognee.shared.utils import send_telemetry
from cognee.modules.users.models import User from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user from cognee.modules.users.methods import get_default_user
from cognee.modules.users.permissions.methods import get_document_ids_for_user from cognee.modules.users.permissions.methods import get_document_ids_for_user
from cognee.tasks.chunking import query_chunks from cognee.tasks.chunks import query_chunks
from cognee.tasks.graph import query_graph_connections from cognee.tasks.graph import query_graph_connections
from cognee.tasks.summarization import query_summaries from cognee.tasks.summarization import query_summaries

View file

@ -1,198 +0,0 @@
""" FalcorDB Adapter for Graph Database"""
import json
import logging
from typing import Optional, Any, List, Dict
from contextlib import asynccontextmanager
from falkordb.asyncio import FalkorDB
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
logger = logging.getLogger("FalcorDBAdapter")
class FalcorDBAdapter(GraphDBInterface):
    """FalkorDB implementation of the graph database interface.

    Node ids double as node labels; ":" is replaced with "_" in ids because
    ":" is the label separator in Cypher.
    """

    def __init__(
        self,
        graph_database_url: str,
        graph_database_username: str,
        graph_database_password: str,
        graph_database_port: int,
        driver: Optional[Any] = None,
        graph_name: str = "DefaultGraph",
    ):
        # NOTE(review): username, password and driver are accepted for
        # interface parity but unused — the client connects via host/port
        # only; confirm authentication is not required for this deployment.
        self.driver = FalkorDB(
            host = graph_database_url,
            port = graph_database_port)
        self.graph_name = graph_name

    async def query(
        self,
        query: str,
        params: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Run a Cypher query on the selected graph and return its result set.

        Logs and re-raises any driver error.
        """
        try:
            selected_graph = self.driver.select_graph(self.graph_name)
            # Fix: forward the bound parameters — they were previously
            # accepted but silently dropped, so every `$param` used by the
            # methods below was undefined at execution time.
            result = await selected_graph.query(query, params)
            return result.result_set
        except Exception as error:
            logger.error("Falkor query error: %s", error, exc_info = True)
            raise error

    async def graph(self):
        """Return the underlying driver instance."""
        return self.driver

    async def add_node(self, node_id: str, node_properties: Dict[str, Any] = None):
        """Create the node labeled and keyed by `node_id` if it does not exist."""
        node_id = node_id.replace(":", "_")
        # Fix: tolerate the documented default of None (previously
        # serialize_properties(None) raised AttributeError on .items()).
        serialized_properties = self.serialize_properties(node_properties or {})

        # Default the display name to the id when none is provided.
        if "name" not in serialized_properties:
            serialized_properties["name"] = node_id

        query = f"""MERGE (node:`{node_id}` {{id: $node_id}})
                ON CREATE SET node += $properties
                RETURN ID(node) AS internal_id, node.id AS nodeId"""

        params = {
            "node_id": node_id,
            "properties": serialized_properties,
        }

        return await self.query(query, params)

    async def add_nodes(self, nodes: list[tuple[str, dict[str, Any]]]) -> None:
        """Add each (id, properties) pair as a node, one query per node."""
        for node in nodes:
            node_id, node_properties = node
            node_id = node_id.replace(":", "_")

            await self.add_node(
                node_id = node_id,
                node_properties = node_properties,
            )

    async def extract_node_description(self, node_id: str):
        """Return id/layer_id/description for each neighbour of `node_id`."""
        query = """MATCH (n)-[r]->(m)
                WHERE n.id = $node_id
                AND NOT m.id CONTAINS 'DefaultGraphModel'
                RETURN m
                """

        result = await self.query(query, dict(node_id = node_id))
        descriptions = []

        for node in result:
            # Assuming 'm' is a consistent key in your data structure
            attributes = node.get("m", {})

            # Only keep neighbours carrying the full expected attribute set.
            if all(key in attributes for key in ["id", "layer_id", "description"]):
                descriptions.append({
                    "id": attributes["id"],
                    "layer_id": attributes["layer_id"],
                    "description": attributes["description"],
                })

        return descriptions

    async def get_layer_nodes(self):
        """Return all nodes that have a layer_id property."""
        query = """MATCH (node) WHERE node.layer_id IS NOT NULL
                RETURN node"""
        return [result["node"] for result in (await self.query(query))]

    async def extract_node(self, node_id: str):
        """Return a single node by id, or None when it does not exist."""
        # Fix: the coroutine was not awaited, so `results` was a coroutine
        # object and len(results) raised TypeError.
        results = await self.extract_nodes([node_id])
        return results[0] if len(results) > 0 else None

    async def extract_nodes(self, node_ids: List[str]):
        """Return all nodes whose id is in `node_ids`."""
        query = """
        UNWIND $node_ids AS id
        MATCH (node {id: id})
        RETURN node"""

        params = {
            "node_ids": node_ids
        }

        results = await self.query(query, params)
        return results

    async def delete_node(self, node_id: str):
        """Detach-delete the node labeled and keyed by `node_id`."""
        # Fix: was `id.replace(...)` (that is the builtin `id`, an
        # AttributeError at runtime) and the Cypher deleted the undefined
        # variable `n` instead of the bound `node`.
        node_id = node_id.replace(":", "_")
        query = f"MATCH (node:`{node_id}` {{id: $node_id}}) DETACH DELETE node"
        params = { "node_id": node_id }
        return await self.query(query, params)

    async def add_edge(self, from_node: str, to_node: str, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}):
        """Merge a relationship between two existing nodes and set its properties."""
        serialized_properties = self.serialize_properties(edge_properties)
        from_node = from_node.replace(":", "_")
        to_node = to_node.replace(":", "_")

        query = f"""MATCH (from_node:`{from_node}` {{id: $from_node}}), (to_node:`{to_node}` {{id: $to_node}})
                MERGE (from_node)-[r:`{relationship_name}`]->(to_node)
                SET r += $properties
                RETURN r"""

        params = {
            "from_node": from_node,
            "to_node": to_node,
            "properties": serialized_properties
        }

        return await self.query(query, params)

    async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
        """Add each (from, to, relationship, properties) edge, one query per edge."""
        for edge in edges:
            from_node, to_node, relationship_name, edge_properties = edge
            from_node = from_node.replace(":", "_")
            to_node = to_node.replace(":", "_")

            await self.add_edge(
                from_node = from_node,
                to_node = to_node,
                relationship_name = relationship_name,
                edge_properties = edge_properties
            )

    async def filter_nodes(self, search_criteria):
        """Return nodes whose id contains `search_criteria`.

        NOTE(review): the criteria is interpolated directly into the query
        string — do not pass untrusted input here.
        """
        query = f"""MATCH (node)
                WHERE node.id CONTAINS '{search_criteria}'
                RETURN node"""

        return await self.query(query)

    async def delete_graph(self):
        """Detach-delete every node in the graph."""
        query = """MATCH (node)
                DETACH DELETE node;"""

        return await self.query(query)

    def serialize_properties(self, properties = dict()):
        """JSON-encode dict/list values so they can be stored as flat properties."""
        # NOTE: the shared mutable default is safe here because `properties`
        # is only read, never mutated.
        return {
            property_key: json.dumps(property_value)
            if isinstance(property_value, (dict, list))
            else property_value for property_key, property_value in properties.items()
        }

View file

@ -2,7 +2,6 @@
from .config import get_graph_config from .config import get_graph_config
from .graph_db_interface import GraphDBInterface from .graph_db_interface import GraphDBInterface
from .networkx.adapter import NetworkXAdapter
async def get_graph_engine() -> GraphDBInterface : async def get_graph_engine() -> GraphDBInterface :
@ -21,19 +20,19 @@ async def get_graph_engine() -> GraphDBInterface :
except: except:
pass pass
elif config.graph_database_provider == "falkorb": elif config.graph_database_provider == "falkordb":
try: from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
from .falkordb.adapter import FalcorDBAdapter from cognee.infrastructure.databases.hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
return FalcorDBAdapter( embedding_engine = get_embedding_engine()
graph_database_url = config.graph_database_url,
graph_database_username = config.graph_database_username,
graph_database_password = config.graph_database_password,
graph_database_port = config.graph_database_port
)
except:
pass
return FalkorDBAdapter(
database_url = config.graph_database_url,
database_port = config.graph_database_port,
embedding_engine = embedding_engine,
)
from .networkx.adapter import NetworkXAdapter
graph_client = NetworkXAdapter(filename = config.graph_file_path) graph_client = NetworkXAdapter(filename = config.graph_file_path)
if graph_client.graph is None: if graph_client.graph is None:

View file

@ -3,7 +3,7 @@ from abc import abstractmethod
class GraphDBInterface(Protocol): class GraphDBInterface(Protocol):
@abstractmethod @abstractmethod
async def graph(self): async def query(self, query: str, params: dict):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod

View file

@ -1,13 +1,14 @@
""" Neo4j Adapter for Graph Database""" """ Neo4j Adapter for Graph Database"""
import json
import logging import logging
import asyncio import asyncio
from textwrap import dedent
from typing import Optional, Any, List, Dict from typing import Optional, Any, List, Dict
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from uuid import UUID
from neo4j import AsyncSession from neo4j import AsyncSession
from neo4j import AsyncGraphDatabase from neo4j import AsyncGraphDatabase
from neo4j.exceptions import Neo4jError from neo4j.exceptions import Neo4jError
from networkx import predecessor from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
logger = logging.getLogger("Neo4jAdapter") logger = logging.getLogger("Neo4jAdapter")
@ -41,17 +42,13 @@ class Neo4jAdapter(GraphDBInterface):
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
try: try:
async with self.get_session() as session: async with self.get_session() as session:
result = await session.run(query, parameters=params) result = await session.run(query, parameters = params)
data = await result.data() data = await result.data()
await self.close()
return data return data
except Neo4jError as error: except Neo4jError as error:
logger.error("Neo4j query error: %s", error, exc_info = True) logger.error("Neo4j query error: %s", error, exc_info = True)
raise error raise error
async def graph(self):
return await self.get_session()
async def has_node(self, node_id: str) -> bool: async def has_node(self, node_id: str) -> bool:
results = self.query( results = self.query(
""" """
@ -63,73 +60,40 @@ class Neo4jAdapter(GraphDBInterface):
) )
return results[0]["node_exists"] if len(results) > 0 else False return results[0]["node_exists"] if len(results) > 0 else False
async def add_node(self, node_id: str, node_properties: Dict[str, Any] = None): async def add_node(self, node: DataPoint):
node_id = node_id.replace(":", "_") serialized_properties = self.serialize_properties(node.model_dump())
serialized_properties = self.serialize_properties(node_properties) query = dedent("""MERGE (node {id: $node_id})
ON CREATE SET node += $properties, node.updated_at = timestamp()
if "name" not in serialized_properties: ON MATCH SET node += $properties, node.updated_at = timestamp()
serialized_properties["name"] = node_id RETURN ID(node) AS internal_id, node.id AS nodeId""")
query = f"""MERGE (node:`{node_id}` {{id: $node_id}})
ON CREATE SET node += $properties
RETURN ID(node) AS internal_id, node.id AS nodeId"""
params = { params = {
"node_id": node_id, "node_id": str(node.id),
"properties": serialized_properties, "properties": serialized_properties,
} }
return await self.query(query, params) return await self.query(query, params)
async def add_nodes(self, nodes: list[tuple[str, dict[str, Any]]]) -> None: async def add_nodes(self, nodes: list[DataPoint]) -> None:
query = """ query = """
UNWIND $nodes AS node UNWIND $nodes AS node
MERGE (n {id: node.node_id}) MERGE (n {id: node.node_id})
ON CREATE SET n += node.properties ON CREATE SET n += node.properties, n.updated_at = timestamp()
ON MATCH SET n += node.properties, n.updated_at = timestamp()
WITH n, node.node_id AS label WITH n, node.node_id AS label
CALL apoc.create.addLabels(n, [label]) YIELD node AS labeledNode CALL apoc.create.addLabels(n, [label]) YIELD node AS labeledNode
RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId
""" """
nodes = [{ nodes = [{
"node_id": node_id, "node_id": str(node.id),
"properties": self.serialize_properties(node_properties), "properties": self.serialize_properties(node.model_dump()),
} for (node_id, node_properties) in nodes] } for node in nodes]
results = await self.query(query, dict(nodes = nodes)) results = await self.query(query, dict(nodes = nodes))
return results return results
async def extract_node_description(self, node_id: str):
query = """MATCH (n)-[r]->(m)
WHERE n.id = $node_id
AND NOT m.id CONTAINS 'DefaultGraphModel'
RETURN m
"""
result = await self.query(query, dict(node_id = node_id))
descriptions = []
for node in result:
# Assuming 'm' is a consistent key in your data structure
attributes = node.get("m", {})
# Ensure all required attributes are present
if all(key in attributes for key in ["id", "layer_id", "description"]):
descriptions.append({
"id": attributes["id"],
"layer_id": attributes["layer_id"],
"description": attributes["description"],
})
return descriptions
async def get_layer_nodes(self):
query = """MATCH (node) WHERE node.layer_id IS NOT NULL
RETURN node"""
return [result["node"] for result in (await self.query(query))]
async def extract_node(self, node_id: str): async def extract_node(self, node_id: str):
results = await self.extract_nodes([node_id]) results = await self.extract_nodes([node_id])
@ -170,13 +134,20 @@ class Neo4jAdapter(GraphDBInterface):
return await self.query(query, params) return await self.query(query, params)
async def has_edge(self, from_node: str, to_node: str, edge_label: str) -> bool: async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool:
query = f""" query = """
MATCH (from_node:`{from_node}`)-[relationship:`{edge_label}`]->(to_node:`{to_node}`) MATCH (from_node)-[relationship]->(to_node)
WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id AND type(relationship) = $edge_label
RETURN COUNT(relationship) > 0 AS edge_exists RETURN COUNT(relationship) > 0 AS edge_exists
""" """
edge_exists = await self.query(query) params = {
"from_node_id": str(from_node),
"to_node_id": str(to_node),
"edge_label": edge_label,
}
edge_exists = await self.query(query, params)
return edge_exists return edge_exists
async def has_edges(self, edges): async def has_edges(self, edges):
@ -190,8 +161,8 @@ class Neo4jAdapter(GraphDBInterface):
try: try:
params = { params = {
"edges": [{ "edges": [{
"from_node": edge[0], "from_node": str(edge[0]),
"to_node": edge[1], "to_node": str(edge[1]),
"relationship_name": edge[2], "relationship_name": edge[2],
} for edge in edges], } for edge in edges],
} }
@ -203,21 +174,21 @@ class Neo4jAdapter(GraphDBInterface):
raise error raise error
async def add_edge(self, from_node: str, to_node: str, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}): async def add_edge(self, from_node: UUID, to_node: UUID, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}):
serialized_properties = self.serialize_properties(edge_properties) serialized_properties = self.serialize_properties(edge_properties)
from_node = from_node.replace(":", "_")
to_node = to_node.replace(":", "_")
query = f"""MATCH (from_node:`{from_node}` query = dedent("""MATCH (from_node {id: $from_node}),
{{id: $from_node}}), (to_node {id: $to_node})
(to_node:`{to_node}` {{id: $to_node}}) MERGE (from_node)-[r]->(to_node)
MERGE (from_node)-[r:`{relationship_name}`]->(to_node) ON CREATE SET r += $properties, r.updated_at = timestamp(), r.type = $relationship_name
SET r += $properties ON MATCH SET r += $properties, r.updated_at = timestamp()
RETURN r""" RETURN r
""")
params = { params = {
"from_node": from_node, "from_node": str(from_node),
"to_node": to_node, "to_node": str(to_node),
"relationship_name": relationship_name,
"properties": serialized_properties "properties": serialized_properties
} }
@ -234,13 +205,13 @@ class Neo4jAdapter(GraphDBInterface):
""" """
edges = [{ edges = [{
"from_node": edge[0], "from_node": str(edge[0]),
"to_node": edge[1], "to_node": str(edge[1]),
"relationship_name": edge[2], "relationship_name": edge[2],
"properties": { "properties": {
**(edge[3] if edge[3] else {}), **(edge[3] if edge[3] else {}),
"source_node_id": edge[0], "source_node_id": str(edge[0]),
"target_node_id": edge[1], "target_node_id": str(edge[1]),
}, },
} for edge in edges] } for edge in edges]
@ -300,14 +271,6 @@ class Neo4jAdapter(GraphDBInterface):
return results[0]["ids"] if len(results) > 0 else [] return results[0]["ids"] if len(results) > 0 else []
async def filter_nodes(self, search_criteria):
query = f"""MATCH (node)
WHERE node.id CONTAINS '{search_criteria}'
RETURN node"""
return await self.query(query)
async def get_predecessors(self, node_id: str, edge_label: str = None) -> list[str]: async def get_predecessors(self, node_id: str, edge_label: str = None) -> list[str]:
if edge_label is not None: if edge_label is not None:
query = """ query = """
@ -379,7 +342,7 @@ class Neo4jAdapter(GraphDBInterface):
return predecessors + successors return predecessors + successors
async def get_connections(self, node_id: str) -> list: async def get_connections(self, node_id: UUID) -> list:
predecessors_query = """ predecessors_query = """
MATCH (node)<-[relation]-(neighbour) MATCH (node)<-[relation]-(neighbour)
WHERE node.id = $node_id WHERE node.id = $node_id
@ -392,8 +355,8 @@ class Neo4jAdapter(GraphDBInterface):
""" """
predecessors, successors = await asyncio.gather( predecessors, successors = await asyncio.gather(
self.query(predecessors_query, dict(node_id = node_id)), self.query(predecessors_query, dict(node_id = str(node_id))),
self.query(successors_query, dict(node_id = node_id)), self.query(successors_query, dict(node_id = str(node_id))),
) )
connections = [] connections = []
@ -438,15 +401,22 @@ class Neo4jAdapter(GraphDBInterface):
return await self.query(query) return await self.query(query)
def serialize_properties(self, properties = dict()): def serialize_properties(self, properties = dict()):
return { serialized_properties = {}
property_key: json.dumps(property_value)
if isinstance(property_value, (dict, list)) for property_key, property_value in properties.items():
else property_value for property_key, property_value in properties.items() if isinstance(property_value, UUID):
} serialized_properties[property_key] = str(property_value)
continue
serialized_properties[property_key] = property_value
return serialized_properties
async def get_graph_data(self): async def get_graph_data(self):
query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties" query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties"
result = await self.query(query) result = await self.query(query)
nodes = [( nodes = [(
record["properties"]["id"], record["properties"]["id"],
record["properties"], record["properties"],

View file

@ -1,14 +1,19 @@
"""Adapter for NetworkX graph database.""" """Adapter for NetworkX graph database."""
from datetime import datetime, timezone
import os import os
import json import json
import asyncio import asyncio
import logging import logging
from re import A
from typing import Dict, Any, List from typing import Dict, Any, List
from uuid import UUID
import aiofiles import aiofiles
import aiofiles.os as aiofiles_os import aiofiles.os as aiofiles_os
import networkx as nx import networkx as nx
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import JSONEncoder
logger = logging.getLogger("NetworkXAdapter") logger = logging.getLogger("NetworkXAdapter")
@ -25,29 +30,38 @@ class NetworkXAdapter(GraphDBInterface):
def __init__(self, filename = "cognee_graph.pkl"): def __init__(self, filename = "cognee_graph.pkl"):
self.filename = filename self.filename = filename
async def get_graph_data(self):
await self.load_graph_from_file()
return (list(self.graph.nodes(data = True)), list(self.graph.edges(data = True, keys = True)))
async def query(self, query: str, params: dict):
pass
async def has_node(self, node_id: str) -> bool: async def has_node(self, node_id: str) -> bool:
return self.graph.has_node(node_id) return self.graph.has_node(node_id)
async def add_node( async def add_node(
self, self,
node_id: str, node: DataPoint,
node_properties,
) -> None: ) -> None:
if not self.graph.has_node(id): self.graph.add_node(node.id, **node.model_dump())
self.graph.add_node(node_id, **node_properties)
await self.save_graph_to_file(self.filename) await self.save_graph_to_file(self.filename)
async def add_nodes( async def add_nodes(
self, self,
nodes: List[tuple[str, dict]], nodes: list[DataPoint],
) -> None: ) -> None:
nodes = [(node.id, node.model_dump()) for node in nodes]
self.graph.add_nodes_from(nodes) self.graph.add_nodes_from(nodes)
await self.save_graph_to_file(self.filename) await self.save_graph_to_file(self.filename)
async def get_graph(self): async def get_graph(self):
return self.graph return self.graph
async def has_edge(self, from_node: str, to_node: str, edge_label: str) -> bool: async def has_edge(self, from_node: str, to_node: str, edge_label: str) -> bool:
return self.graph.has_edge(from_node, to_node, key = edge_label) return self.graph.has_edge(from_node, to_node, key = edge_label)
@ -55,18 +69,20 @@ class NetworkXAdapter(GraphDBInterface):
result = [] result = []
for (from_node, to_node, edge_label) in edges: for (from_node, to_node, edge_label) in edges:
if await self.has_edge(from_node, to_node, edge_label): if self.graph.has_edge(from_node, to_node, edge_label):
result.append((from_node, to_node, edge_label)) result.append((from_node, to_node, edge_label))
return result return result
async def add_edge( async def add_edge(
self, self,
from_node: str, from_node: str,
to_node: str, to_node: str,
relationship_name: str, relationship_name: str,
edge_properties: Dict[str, Any] = None, edge_properties: Dict[str, Any] = {},
) -> None: ) -> None:
edge_properties["updated_at"] = datetime.now(timezone.utc)
self.graph.add_edge(from_node, to_node, key = relationship_name, **(edge_properties if edge_properties else {})) self.graph.add_edge(from_node, to_node, key = relationship_name, **(edge_properties if edge_properties else {}))
await self.save_graph_to_file(self.filename) await self.save_graph_to_file(self.filename)
@ -74,22 +90,29 @@ class NetworkXAdapter(GraphDBInterface):
self, self,
edges: tuple[str, str, str, dict], edges: tuple[str, str, str, dict],
) -> None: ) -> None:
edges = [(edge[0], edge[1], edge[2], {
**(edge[3] if len(edge) == 4 else {}),
"updated_at": datetime.now(timezone.utc),
}) for edge in edges]
self.graph.add_edges_from(edges) self.graph.add_edges_from(edges)
await self.save_graph_to_file(self.filename) await self.save_graph_to_file(self.filename)
async def get_edges(self, node_id: str): async def get_edges(self, node_id: str):
return list(self.graph.in_edges(node_id, data = True)) + list(self.graph.out_edges(node_id, data = True)) return list(self.graph.in_edges(node_id, data = True)) + list(self.graph.out_edges(node_id, data = True))
async def delete_node(self, node_id: str) -> None: async def delete_node(self, node_id: str) -> None:
"""Asynchronously delete a node from the graph if it exists.""" """Asynchronously delete a node from the graph if it exists."""
if self.graph.has_node(id): if self.graph.has_node(node_id):
self.graph.remove_node(id) self.graph.remove_node(node_id)
await self.save_graph_to_file(self.filename) await self.save_graph_to_file(self.filename)
async def delete_nodes(self, node_ids: List[str]) -> None: async def delete_nodes(self, node_ids: List[str]) -> None:
self.graph.remove_nodes_from(node_ids) self.graph.remove_nodes_from(node_ids)
await self.save_graph_to_file(self.filename) await self.save_graph_to_file(self.filename)
async def get_disconnected_nodes(self) -> List[str]: async def get_disconnected_nodes(self) -> List[str]:
connected_components = list(nx.weakly_connected_components(self.graph)) connected_components = list(nx.weakly_connected_components(self.graph))
@ -102,33 +125,6 @@ class NetworkXAdapter(GraphDBInterface):
return disconnected_nodes return disconnected_nodes
async def extract_node_description(self, node_id: str) -> Dict[str, Any]:
descriptions = []
if self.graph.has_node(node_id):
# Get the attributes of the node
for neighbor in self.graph.neighbors(node_id):
# Get the attributes of the neighboring node
attributes = self.graph.nodes[neighbor]
# Ensure all required attributes are present before extracting description
if all(key in attributes for key in ["id", "layer_id", "description"]):
descriptions.append({
"id": attributes["id"],
"layer_id": attributes["layer_id"],
"description": attributes["description"],
})
return descriptions
async def get_layer_nodes(self):
layer_nodes = []
for _, data in self.graph.nodes(data = True):
if "layer_id" in data:
layer_nodes.append(data)
return layer_nodes
async def extract_node(self, node_id: str) -> dict: async def extract_node(self, node_id: str) -> dict:
if self.graph.has_node(node_id): if self.graph.has_node(node_id):
@ -139,7 +135,7 @@ class NetworkXAdapter(GraphDBInterface):
async def extract_nodes(self, node_ids: List[str]) -> List[dict]: async def extract_nodes(self, node_ids: List[str]) -> List[dict]:
return [self.graph.nodes[node_id] for node_id in node_ids if self.graph.has_node(node_id)] return [self.graph.nodes[node_id] for node_id in node_ids if self.graph.has_node(node_id)]
async def get_predecessors(self, node_id: str, edge_label: str = None) -> list: async def get_predecessors(self, node_id: UUID, edge_label: str = None) -> list:
if self.graph.has_node(node_id): if self.graph.has_node(node_id):
if edge_label is None: if edge_label is None:
return [ return [
@ -155,7 +151,7 @@ class NetworkXAdapter(GraphDBInterface):
return nodes return nodes
async def get_successors(self, node_id: str, edge_label: str = None) -> list: async def get_successors(self, node_id: UUID, edge_label: str = None) -> list:
if self.graph.has_node(node_id): if self.graph.has_node(node_id):
if edge_label is None: if edge_label is None:
return [ return [
@ -184,13 +180,13 @@ class NetworkXAdapter(GraphDBInterface):
return neighbours return neighbours
async def get_connections(self, node_id: str) -> list: async def get_connections(self, node_id: UUID) -> list:
if not self.graph.has_node(node_id): if not self.graph.has_node(node_id):
return [] return []
node = self.graph.nodes[node_id] node = self.graph.nodes[node_id]
if "uuid" not in node: if "id" not in node:
return [] return []
predecessors, successors = await asyncio.gather( predecessors, successors = await asyncio.gather(
@ -201,14 +197,14 @@ class NetworkXAdapter(GraphDBInterface):
connections = [] connections = []
for neighbor in predecessors: for neighbor in predecessors:
if "uuid" in neighbor: if "id" in neighbor:
edge_data = self.graph.get_edge_data(neighbor["uuid"], node["uuid"]) edge_data = self.graph.get_edge_data(neighbor["id"], node["id"])
for edge_properties in edge_data.values(): for edge_properties in edge_data.values():
connections.append((neighbor, edge_properties, node)) connections.append((neighbor, edge_properties, node))
for neighbor in successors: for neighbor in successors:
if "uuid" in neighbor: if "id" in neighbor:
edge_data = self.graph.get_edge_data(node["uuid"], neighbor["uuid"]) edge_data = self.graph.get_edge_data(node["id"], neighbor["id"])
for edge_properties in edge_data.values(): for edge_properties in edge_data.values():
connections.append((node, edge_properties, neighbor)) connections.append((node, edge_properties, neighbor))
@ -240,7 +236,7 @@ class NetworkXAdapter(GraphDBInterface):
graph_data = nx.readwrite.json_graph.node_link_data(self.graph) graph_data = nx.readwrite.json_graph.node_link_data(self.graph)
async with aiofiles.open(file_path, "w") as file: async with aiofiles.open(file_path, "w") as file:
await file.write(json.dumps(graph_data)) await file.write(json.dumps(graph_data, cls = JSONEncoder))
async def load_graph_from_file(self, file_path: str = None): async def load_graph_from_file(self, file_path: str = None):
@ -254,6 +250,29 @@ class NetworkXAdapter(GraphDBInterface):
if os.path.exists(file_path): if os.path.exists(file_path):
async with aiofiles.open(file_path, "r") as file: async with aiofiles.open(file_path, "r") as file:
graph_data = json.loads(await file.read()) graph_data = json.loads(await file.read())
for node in graph_data["nodes"]:
try:
node["id"] = UUID(node["id"])
except:
pass
if "updated_at" in node:
node["updated_at"] = datetime.strptime(node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
for edge in graph_data["links"]:
try:
source_id = UUID(edge["source"])
target_id = UUID(edge["target"])
edge["source"] = source_id
edge["target"] = target_id
edge["source_node_id"] = source_id
edge["target_node_id"] = target_id
except:
pass
if "updated_at" in edge:
edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
self.graph = nx.readwrite.json_graph.node_link_graph(graph_data) self.graph = nx.readwrite.json_graph.node_link_graph(graph_data)
else: else:
# Log that the file does not exist and an empty graph is initialized # Log that the file does not exist and an empty graph is initialized
@ -265,9 +284,11 @@ class NetworkXAdapter(GraphDBInterface):
os.makedirs(file_dir, exist_ok = True) os.makedirs(file_dir, exist_ok = True)
await self.save_graph_to_file(file_path) await self.save_graph_to_file(file_path)
except Exception: except Exception:
logger.error("Failed to load graph from file: %s", file_path) logger.error("Failed to load graph from file: %s", file_path)
async def delete_graph(self, file_path: str = None): async def delete_graph(self, file_path: str = None):
"""Asynchronously delete the graph file from the filesystem.""" """Asynchronously delete the graph file from the filesystem."""
if file_path is None: if file_path is None:

View file

@ -0,0 +1,267 @@
import asyncio
from textwrap import dedent
from typing import Any
from uuid import UUID
from falkordb import FalkorDB
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
from cognee.infrastructure.databases.vector.embeddings import EmbeddingEngine
from cognee.infrastructure.databases.vector.vector_db_interface import VectorDBInterface
class IndexSchema(DataPoint):
    """Data point carrying a single embeddable ``text`` field.

    Used by the FalkorDB adapter to index plain-text payloads.
    """

    # The only stored field; listed in ``index_fields`` so the adapter
    # vectorizes it when the node is written (see create_data_point_query).
    text: str

    _metadata: dict = {
        "index_fields": ["text"]
    }
class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
    """Hybrid adapter exposing FalkorDB as both a graph store and a vector store.

    All data lives in a single FalkorDB graph named ``cognee_graph``; vector
    search is served by FalkorDB node vector indices.
    """

    def __init__(
        self,
        database_url: str,
        database_port: int,
        embedding_engine = EmbeddingEngine,
    ):
        # NOTE(review): the default binds the EmbeddingEngine *class*, not an
        # instance — callers are expected to inject a configured engine.
        self.driver = FalkorDB(
            host = database_url,
            port = database_port,
        )
        self.embedding_engine = embedding_engine
        self.graph_name = "cognee_graph"

    def query(self, query: str, params: dict = None):
        """Execute a (synchronous) query against the graph and return the raw result.

        Re-raises driver errors after printing them.
        """
        # A fresh dict per call avoids the shared-mutable-default pitfall.
        params = {} if params is None else params
        graph = self.driver.select_graph(self.graph_name)

        try:
            return graph.query(query, params)
        except Exception as e:
            print(f"Error executing query: {e}")
            raise e

    async def embed_data(self, data: list[str]) -> list[list[float]]:
        """Embed a batch of texts with the configured embedding engine."""
        return await self.embedding_engine.embed_text(data)

    async def stringify_properties(self, properties: dict, vectorize_fields = ()) -> str:
        """Serialize a property dict into a Cypher map body (``key:value,...``).

        Values of keys listed in *vectorize_fields* are embedded and rendered
        as ``vecf32(...)`` literals; every other value is single-quoted.
        """
        async def get_value(key, value):
            if key in vectorize_fields:
                return await self.get_vectorized_value(value)
            # NOTE(review): values are interpolated, not parameterized — a
            # value containing a single quote breaks the query. TODO: escape.
            return f"'{value}'"

        return ",".join([f"{key}:{await get_value(key, value)}" for key, value in properties.items()])

    async def get_vectorized_value(self, value: Any) -> str:
        """Embed *value* and render it as a FalkorDB ``vecf32`` literal."""
        vector = (await self.embed_data([value]))[0]
        return f"vecf32({vector})"

    async def create_data_point_query(self, data_point: DataPoint):
        """Build a MERGE query that upserts *data_point* as a labelled node."""
        node_label = type(data_point).__name__
        node_properties = await self.stringify_properties(
            data_point.model_dump(),
            data_point._metadata["index_fields"],
        )

        return dedent(f"""
            MERGE (node:{node_label} {{id: '{str(data_point.id)}'}})
            ON CREATE SET node += ({{{node_properties}}})
            ON CREATE SET node.updated_at = timestamp()
            ON MATCH SET node += ({{{node_properties}}})
            ON MATCH SET node.updated_at = timestamp()
        """).strip()

    async def create_edge_query(self, edge: tuple[str, str, str, dict]) -> str:
        """Build a MERGE query that upserts one (from, to, name, properties) edge."""
        properties = await self.stringify_properties(edge[3])
        properties = f"{{{properties}}}"

        return dedent(f"""
            MERGE (source {{id:'{edge[0]}'}})
            MERGE (target {{id: '{edge[1]}'}})
            MERGE (source)-[edge:{edge[2]} {properties}]->(target)
            ON MATCH SET edge.updated_at = timestamp()
            ON CREATE SET edge.updated_at = timestamp()
        """).strip()

    async def create_collection(self, collection_name: str):
        """No-op: collections are implicit in the single-graph layout."""
        pass

    async def has_collection(self, collection_name: str) -> bool:
        """Return True if a graph named *collection_name* exists on the server."""
        collections = self.driver.list_graphs()
        return collection_name in collections

    async def create_data_points(self, data_points: list[DataPoint]):
        """Upsert *data_points* as nodes, one query per point."""
        queries = [await self.create_data_point_query(data_point) for data_point in data_points]

        for query in queries:
            self.query(query)

    async def create_vector_index(self, index_name: str, index_property_name: str):
        """Create a node vector index on (*index_name*, *index_property_name*) if missing."""
        graph = self.driver.select_graph(self.graph_name)

        if not self.has_vector_index(graph, index_name, index_property_name):
            graph.create_node_vector_index(index_name, index_property_name, dim = self.embedding_engine.get_vector_size())

    def has_vector_index(self, graph, index_name: str, index_property_name: str) -> bool:
        """Return True if *graph* already indexes *index_property_name* under *index_name*."""
        try:
            indices = graph.list_indices()
            return any((index[0] == index_name and index_property_name in index[1]) for index in indices.result_set)
        except Exception:
            # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
            # propagate; any driver error is treated as "no index".
            return False

    async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
        """No-op: index fields are vectorized when nodes are written."""
        pass

    async def add_node(self, node: DataPoint):
        """Upsert a single node."""
        await self.create_data_points([node])

    async def add_nodes(self, nodes: list[DataPoint]):
        """Upsert a batch of nodes."""
        await self.create_data_points(nodes)

    async def add_edge(self, edge: tuple[str, str, str, dict]):
        """Upsert a single edge."""
        query = await self.create_edge_query(edge)

        self.query(query)

    async def add_edges(self, edges: list[tuple[str, str, str, dict]]):
        """Upsert a batch of edges, one query per edge."""
        queries = [await self.create_edge_query(edge) for edge in edges]

        for query in queries:
            self.query(query)

    async def has_edges(self, edges):
        """Return one boolean per input edge, telling whether it exists in the graph."""
        query = dedent("""
            UNWIND $edges AS edge
            MATCH (a)-[r]->(b)
            WHERE id(a) = edge.from_node AND id(b) = edge.to_node AND type(r) = edge.relationship_name
            RETURN edge.from_node AS from_node, edge.to_node AS to_node, edge.relationship_name AS relationship_name, count(r) > 0 AS edge_exists
        """).strip()

        params = {
            "edges": [{
                "from_node": str(edge[0]),
                "to_node": str(edge[1]),
                "relationship_name": edge[2],
            } for edge in edges],
        }

        results = self.query(query, params).result_set

        # result_set rows are positional lists, not dicts; ``edge_exists`` is
        # the fourth RETURN column. TODO(review): confirm against the
        # installed falkordb client version.
        return [result[3] for result in results]

    async def retrieve(self, data_point_ids: list[str]):
        """Fetch all nodes whose ``id`` is in *data_point_ids*."""
        return self.query(
            "MATCH (node) WHERE node.id IN $node_ids RETURN node",
            {
                "node_ids": data_point_ids,
            },
        )

    async def extract_node(self, data_point_id: str):
        """Fetch a single node by id (thin wrapper over ``retrieve``)."""
        return await self.retrieve([data_point_id])

    async def extract_nodes(self, data_point_ids: list[str]):
        """Fetch multiple nodes by id (thin wrapper over ``retrieve``)."""
        return await self.retrieve(data_point_ids)

    async def get_connections(self, node_id: UUID) -> list:
        """Return (source, {"relationship_name": ...}, target) triples touching *node_id*."""
        predecessors_query = """
        MATCH (node)<-[relation]-(neighbour)
        WHERE node.id = $node_id
        RETURN neighbour, relation, node
        """
        successors_query = """
        MATCH (node)-[relation]->(neighbour)
        WHERE node.id = $node_id
        RETURN node, relation, neighbour
        """

        # ``self.query`` is synchronous, so it must not be handed to
        # asyncio.gather (gather requires awaitables and would raise a
        # TypeError); run the two queries back to back instead. Node ids are
        # stored as strings (see create_data_point_query), so the UUID is
        # stringified for the parameter.
        params = dict(node_id = str(node_id))
        predecessors = self.query(predecessors_query, params)
        successors = self.query(successors_query, params)

        connections = []

        # Rows are positional per the RETURN clause. ``relation.relation`` is
        # the edge's type name in the falkordb client — TODO(review): confirm.
        for row in predecessors.result_set:
            neighbour_node, relation, node = row
            connections.append((neighbour_node, { "relationship_name": relation.relation }, node))

        for row in successors.result_set:
            node, relation, neighbour_node = row
            connections.append((node, { "relationship_name": relation.relation }, neighbour_node))

        return connections

    async def search(
        self,
        collection_name: str,
        query_text: str = None,
        query_vector: list[float] = None,
        limit: int = 10,
        with_vector: bool = False,
    ):
        """Vector-search *collection_name* by text or by a precomputed vector.

        Raises ValueError when neither query_text nor query_vector is given.
        """
        if query_text is None and query_vector is None:
            raise ValueError("One of query_text or query_vector must be provided!")

        if query_text and not query_vector:
            query_vector = (await self.embed_data([query_text]))[0]

        # The index (label) name must be a string literal in the procedure
        # call; interpolated unquoted it would be parsed as an identifier.
        query = dedent(f"""
            CALL db.idx.vector.queryNodes(
                '{collection_name}',
                'text',
                {limit},
                vecf32({query_vector})
            ) YIELD node, score
        """).strip()

        return self.query(query)

    async def batch_search(
        self,
        collection_name: str,
        query_texts: list[str],
        limit: int = None,
        with_vectors: bool = False,
    ):
        """Run ``search`` once per query text, concurrently."""
        query_vectors = await self.embedding_engine.embed_text(query_texts)

        # A ``None`` limit would be interpolated verbatim into the Cypher
        # query and break it; fall back to search()'s own default of 10.
        effective_limit = 10 if limit is None else limit

        return await asyncio.gather(
            *[self.search(
                collection_name = collection_name,
                query_vector = query_vector,
                limit = effective_limit,
                with_vector = with_vectors,
            ) for query_vector in query_vectors]
        )

    async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
        """Detach-delete all nodes whose ``id`` is in *data_point_ids*."""
        return self.query(
            "MATCH (node) WHERE node.id IN $node_ids DETACH DELETE node",
            {
                "node_ids": data_point_ids,
            },
        )

    async def delete_node(self, collection_name: str, data_point_id: str):
        """Delete a single node by id."""
        # Forward collection_name — the previous call dropped it and passed
        # the id list as the first positional argument.
        return await self.delete_data_points(collection_name, [data_point_id])

    async def delete_nodes(self, collection_name: str, data_point_ids: list[str]):
        """Delete multiple nodes by id."""
        # Must be awaited and must forward collection_name; previously the
        # coroutine was built with the wrong arguments and never awaited.
        return await self.delete_data_points(collection_name, data_point_ids)

    async def delete_graph(self):
        """Drop all vector indices, then delete the whole graph.

        Errors are printed and swallowed (best-effort cleanup).
        """
        try:
            graph = self.driver.select_graph(self.graph_name)

            indices = graph.list_indices()
            for index in indices.result_set:
                for field in index[1]:
                    graph.drop_node_vector_index(index[0], field)

            graph.delete()
        except Exception as e:
            print(f"Error deleting graph: {e}")

    async def prune(self):
        """Delete the entire graph (cleanup hook)."""
        # Must be awaited — previously the coroutine was created and
        # discarded, so prune() silently did nothing.
        await self.delete_graph()

View file

@ -1,4 +1,3 @@
from .models.DataPoint import DataPoint
from .models.VectorConfig import VectorConfig from .models.VectorConfig import VectorConfig
from .models.CollectionConfig import CollectionConfig from .models.CollectionConfig import CollectionConfig
from .vector_db_interface import VectorDBInterface from .vector_db_interface import VectorDBInterface

View file

@ -8,6 +8,7 @@ class VectorConfig(BaseSettings):
os.path.join(get_absolute_path(".cognee_system"), "databases"), os.path.join(get_absolute_path(".cognee_system"), "databases"),
"cognee.lancedb" "cognee.lancedb"
) )
vector_db_port: int = 1234
vector_db_key: str = "" vector_db_key: str = ""
vector_db_provider: str = "lancedb" vector_db_provider: str = "lancedb"
@ -16,6 +17,7 @@ class VectorConfig(BaseSettings):
def to_dict(self) -> dict: def to_dict(self) -> dict:
return { return {
"vector_db_url": self.vector_db_url, "vector_db_url": self.vector_db_url,
"vector_db_port": self.vector_db_port,
"vector_db_key": self.vector_db_key, "vector_db_key": self.vector_db_key,
"vector_db_provider": self.vector_db_provider, "vector_db_provider": self.vector_db_provider,
} }

View file

@ -1,9 +1,8 @@
from typing import Dict from typing import Dict
from ..relational.config import get_relational_config
class VectorConfig(Dict): class VectorConfig(Dict):
vector_db_url: str vector_db_url: str
vector_db_port: str
vector_db_key: str vector_db_key: str
vector_db_provider: str vector_db_provider: str
@ -29,6 +28,7 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
embedding_engine = embedding_engine embedding_engine = embedding_engine
) )
elif config["vector_db_provider"] == "pgvector": elif config["vector_db_provider"] == "pgvector":
from cognee.infrastructure.databases.relational import get_relational_config
from .pgvector.PGVectorAdapter import PGVectorAdapter from .pgvector.PGVectorAdapter import PGVectorAdapter
# Get configuration for postgres database # Get configuration for postgres database
@ -43,9 +43,18 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}" f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
) )
return PGVectorAdapter(connection_string, return PGVectorAdapter(
config["vector_db_key"], connection_string,
embedding_engine config["vector_db_key"],
embedding_engine,
)
elif config["vector_db_provider"] == "falkordb":
from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
return FalkorDBAdapter(
database_url = config["vector_db_url"],
database_port = config["vector_db_port"],
embedding_engine = embedding_engine,
) )
else: else:
from .lancedb.LanceDBAdapter import LanceDBAdapter from .lancedb.LanceDBAdapter import LanceDBAdapter

View file

@ -1,57 +0,0 @@
from typing import List, Dict, Optional, Any
from falkordb import FalkorDB
from qdrant_client import AsyncQdrantClient, models
from ..vector_db_interface import VectorDBInterface
from ..models.DataPoint import DataPoint
from ..embeddings.EmbeddingEngine import EmbeddingEngine
class FalcorDBAdapter(VectorDBInterface):
    """Skeleton FalkorDB vector adapter.

    Only ``embed_data`` is implemented; every collection/data-point operation
    below is currently a stub that does nothing.
    """

    def __init__(
        self,
        graph_database_url: str,
        graph_database_username: str,
        graph_database_password: str,
        graph_database_port: int,
        driver: Optional[Any] = None,
        embedding_engine = EmbeddingEngine,
        graph_name: str = "DefaultGraph",
    ):
        # NOTE(review): ``graph_database_username``, ``graph_database_password``
        # and ``driver`` are accepted but never used — confirm whether they
        # should be wired into the FalkorDB connection.
        self.driver = FalkorDB(
            host = graph_database_url,
            port = graph_database_port)
        self.graph_name = graph_name
        self.embedding_engine = embedding_engine

    async def embed_data(self, data: list[str]) -> list[list[float]]:
        """Embed a batch of texts with the configured embedding engine."""
        return await self.embedding_engine.embed_text(data)

    async def create_collection(self, collection_name: str, payload_schema = None):
        # Not implemented.
        pass

    async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
        # Not implemented.
        pass

    async def retrieve(self, collection_name: str, data_point_ids: list[str]):
        # Not implemented.
        pass

    async def search(
        self,
        collection_name: str,
        query_text: str = None,
        query_vector: List[float] = None,
        limit: int = 10,
        with_vector: bool = False,
    ):
        # Not implemented.
        pass

    async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
        # Not implemented.
        pass

View file

@ -1,12 +1,25 @@
import inspect
from typing import List, Optional, get_type_hints, Generic, TypeVar from typing import List, Optional, get_type_hints, Generic, TypeVar
import asyncio import asyncio
from uuid import UUID
import lancedb import lancedb
from pydantic import BaseModel
from lancedb.pydantic import Vector, LanceModel from lancedb.pydantic import Vector, LanceModel
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.files.storage import LocalStorage from cognee.infrastructure.files.storage import LocalStorage
from cognee.modules.storage.utils import copy_model, get_own_properties
from ..models.ScoredResult import ScoredResult from ..models.ScoredResult import ScoredResult
from ..vector_db_interface import VectorDBInterface, DataPoint from ..vector_db_interface import VectorDBInterface
from ..embeddings.EmbeddingEngine import EmbeddingEngine from ..embeddings.EmbeddingEngine import EmbeddingEngine
class IndexSchema(DataPoint):
id: str
text: str
_metadata: dict = {
"index_fields": ["text"]
}
class LanceDBAdapter(VectorDBInterface): class LanceDBAdapter(VectorDBInterface):
name = "LanceDB" name = "LanceDB"
url: str url: str
@ -38,10 +51,12 @@ class LanceDBAdapter(VectorDBInterface):
collection_names = await connection.table_names() collection_names = await connection.table_names()
return collection_name in collection_names return collection_name in collection_names
async def create_collection(self, collection_name: str, payload_schema = None): async def create_collection(self, collection_name: str, payload_schema: BaseModel):
data_point_types = get_type_hints(DataPoint)
vector_size = self.embedding_engine.get_vector_size() vector_size = self.embedding_engine.get_vector_size()
payload_schema = self.get_data_point_schema(payload_schema)
data_point_types = get_type_hints(payload_schema)
class LanceDataPoint(LanceModel): class LanceDataPoint(LanceModel):
id: data_point_types["id"] id: data_point_types["id"]
vector: Vector(vector_size) vector: Vector(vector_size)
@ -55,13 +70,16 @@ class LanceDBAdapter(VectorDBInterface):
exist_ok = True, exist_ok = True,
) )
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]): async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
connection = await self.get_connection() connection = await self.get_connection()
payload_schema = type(data_points[0])
payload_schema = self.get_data_point_schema(payload_schema)
if not await self.has_collection(collection_name): if not await self.has_collection(collection_name):
await self.create_collection( await self.create_collection(
collection_name, collection_name,
payload_schema = type(data_points[0].payload), payload_schema,
) )
collection = await connection.open_table(collection_name) collection = await connection.open_table(collection_name)
@ -79,15 +97,26 @@ class LanceDBAdapter(VectorDBInterface):
vector: Vector(vector_size) vector: Vector(vector_size)
payload: PayloadSchema payload: PayloadSchema
def create_lance_data_point(data_point: DataPoint, vector: list[float]) -> LanceDataPoint:
properties = get_own_properties(data_point)
properties["id"] = str(properties["id"])
return LanceDataPoint[str, self.get_data_point_schema(type(data_point))](
id = str(data_point.id),
vector = vector,
payload = properties,
)
lance_data_points = [ lance_data_points = [
LanceDataPoint[type(data_point.id), type(data_point.payload)]( create_lance_data_point(data_point, data_vectors[data_point_index])
id = data_point.id, for (data_point_index, data_point) in enumerate(data_points)
vector = data_vectors[data_index],
payload = data_point.payload,
) for (data_index, data_point) in enumerate(data_points)
] ]
await collection.add(lance_data_points) await collection.merge_insert("id") \
.when_matched_update_all() \
.when_not_matched_insert_all() \
.execute(lance_data_points)
async def retrieve(self, collection_name: str, data_point_ids: list[str]): async def retrieve(self, collection_name: str, data_point_ids: list[str]):
connection = await self.get_connection() connection = await self.get_connection()
@ -99,7 +128,7 @@ class LanceDBAdapter(VectorDBInterface):
results = await collection.query().where(f"id IN {tuple(data_point_ids)}").to_pandas() results = await collection.query().where(f"id IN {tuple(data_point_ids)}").to_pandas()
return [ScoredResult( return [ScoredResult(
id = result["id"], id = UUID(result["id"]),
payload = result["payload"], payload = result["payload"],
score = 0, score = 0,
) for result in results.to_dict("index").values()] ) for result in results.to_dict("index").values()]
@ -135,10 +164,19 @@ class LanceDBAdapter(VectorDBInterface):
if value < min_value: if value < min_value:
min_value = value min_value = value
normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in result_values] normalized_values = []
min_value = min(result["_distance"] for result in result_values)
max_value = max(result["_distance"] for result in result_values)
if max_value == min_value:
# Avoid division by zero: Assign all normalized values to 0 (or any constant value like 1)
normalized_values = [0 for _ in result_values]
else:
normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in
result_values]
return [ScoredResult( return [ScoredResult(
id = str(result["id"]), id = UUID(result["id"]),
payload = result["payload"], payload = result["payload"],
score = normalized_values[value_index], score = normalized_values[value_index],
) for value_index, result in enumerate(result_values)] ) for value_index, result in enumerate(result_values)]
@ -170,7 +208,27 @@ class LanceDBAdapter(VectorDBInterface):
results = await collection.delete(f"id IN {tuple(data_point_ids)}") results = await collection.delete(f"id IN {tuple(data_point_ids)}")
return results return results
async def create_vector_index(self, index_name: str, index_property_name: str):
await self.create_collection(f"{index_name}_{index_property_name}", payload_schema = IndexSchema)
async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
await self.create_data_points(f"{index_name}_{index_property_name}", [
IndexSchema(
id = str(data_point.id),
text = getattr(data_point, data_point._metadata["index_fields"][0]),
) for data_point in data_points
])
async def prune(self): async def prune(self):
# Clean up the database if it was set up as temporary # Clean up the database if it was set up as temporary
if self.url.startswith("/"): if self.url.startswith("/"):
LocalStorage.remove_all(self.url) # Remove the temporary directory and files inside LocalStorage.remove_all(self.url) # Remove the temporary directory and files inside
def get_data_point_schema(self, model_type):
return copy_model(
model_type,
include_fields = {
"id": (str, ...),
},
exclude_fields = ["_metadata"],
)

View file

@ -1,13 +0,0 @@
from typing import Generic, TypeVar
from pydantic import BaseModel
PayloadSchema = TypeVar("PayloadSchema", bound = BaseModel)
class DataPoint(BaseModel, Generic[PayloadSchema]):
    """Generic vector-store record: an id plus a typed payload.

    ``embed_field`` names the payload attribute whose value should be embedded.
    """

    # Identifier of the record within its collection.
    id: str
    # Arbitrary pydantic payload; its type parameterizes the record.
    payload: PayloadSchema
    # Name of the payload attribute to embed; defaults to "value".
    embed_field: str = "value"

    def get_embeddable_data(self):
        """Return the payload attribute named by ``embed_field``.

        Falls through (returning None implicitly) when the payload has no
        such attribute.
        """
        if hasattr(self.payload, self.embed_field):
            return getattr(self.payload, self.embed_field)

View file

@ -1,7 +1,8 @@
from typing import Any, Dict from typing import Any, Dict
from uuid import UUID
from pydantic import BaseModel from pydantic import BaseModel
class ScoredResult(BaseModel): class ScoredResult(BaseModel):
id: str id: UUID
score: float # Lower score is better score: float # Lower score is better
payload: Dict[str, Any] payload: Dict[str, Any]

View file

@ -1,17 +1,26 @@
import asyncio import asyncio
from uuid import UUID
from pgvector.sqlalchemy import Vector from pgvector.sqlalchemy import Vector
from typing import List, Optional, get_type_hints from typing import List, Optional, get_type_hints
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy import JSON, Column, Table, select, delete from sqlalchemy import JSON, Column, Table, select, delete
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from .serialize_datetime import serialize_datetime from cognee.infrastructure.engine import DataPoint
from .serialize_data import serialize_data
from ..models.ScoredResult import ScoredResult from ..models.ScoredResult import ScoredResult
from ..vector_db_interface import VectorDBInterface, DataPoint from ..vector_db_interface import VectorDBInterface
from ..embeddings.EmbeddingEngine import EmbeddingEngine from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
from ...relational.ModelBase import Base from ...relational.ModelBase import Base
class IndexSchema(DataPoint):
    """Minimal DataPoint used for vector-index collections: one embeddable `text` field."""
    text: str

    # Marks `text` as the field the embedding machinery should index.
    _metadata: dict = {
        "index_fields": ["text"]
    }
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
@ -45,7 +54,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
vector_size = self.embedding_engine.get_vector_size() vector_size = self.embedding_engine.get_vector_size()
if not await self.has_collection(collection_name): if not await self.has_collection(collection_name):
class PGVectorDataPoint(Base): class PGVectorDataPoint(Base):
__tablename__ = collection_name __tablename__ = collection_name
__table_args__ = {"extend_existing": True} __table_args__ = {"extend_existing": True}
@ -71,47 +79,58 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
async def create_data_points( async def create_data_points(
self, collection_name: str, data_points: List[DataPoint] self, collection_name: str, data_points: List[DataPoint]
): ):
async with self.get_async_session() as session: if not await self.has_collection(collection_name):
if not await self.has_collection(collection_name): await self.create_collection(
await self.create_collection( collection_name = collection_name,
collection_name=collection_name, payload_schema = type(data_points[0]),
payload_schema=type(data_points[0].payload),
)
data_vectors = await self.embed_data(
[data_point.get_embeddable_data() for data_point in data_points]
) )
vector_size = self.embedding_engine.get_vector_size() data_vectors = await self.embed_data(
[data_point.get_embeddable_data() for data_point in data_points]
)
class PGVectorDataPoint(Base): vector_size = self.embedding_engine.get_vector_size()
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key
primary_key: Mapped[int] = mapped_column(
primary_key=True, autoincrement=True
)
id: Mapped[type(data_points[0].id)]
payload = Column(JSON)
vector = Column(Vector(vector_size))
def __init__(self, id, payload, vector): class PGVectorDataPoint(Base):
self.id = id __tablename__ = collection_name
self.payload = payload __table_args__ = {"extend_existing": True}
self.vector = vector # PGVector requires one column to be the primary key
primary_key: Mapped[int] = mapped_column(
primary_key=True, autoincrement=True
)
id: Mapped[type(data_points[0].id)]
payload = Column(JSON)
vector = Column(Vector(vector_size))
pgvector_data_points = [ def __init__(self, id, payload, vector):
PGVectorDataPoint( self.id = id
id=data_point.id, self.payload = payload
vector=data_vectors[data_index], self.vector = vector
payload=serialize_datetime(data_point.payload.dict()),
)
for (data_index, data_point) in enumerate(data_points)
]
pgvector_data_points = [
PGVectorDataPoint(
id = data_point.id,
vector = data_vectors[data_index],
payload = serialize_data(data_point.model_dump()),
)
for (data_index, data_point) in enumerate(data_points)
]
async with self.get_async_session() as session:
session.add_all(pgvector_data_points) session.add_all(pgvector_data_points)
await session.commit() await session.commit()
async def create_vector_index(self, index_name: str, index_property_name: str):
    """Create the `{index_name}_{index_property_name}` collection that backs a vector index."""
    await self.create_collection(f"{index_name}_{index_property_name}")
async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
    """Store `data_points` in the `{index_name}_{index_property_name}` index collection.

    Each point is wrapped in an IndexSchema carrying only its id and its
    embeddable text (the value of its first configured index field).
    """
    await self.create_data_points(f"{index_name}_{index_property_name}", [
        IndexSchema(
            id = data_point.id,
            text = data_point.get_embeddable_data(),
        ) for data_point in data_points
    ])
async def get_table(self, collection_name: str) -> Table: async def get_table(self, collection_name: str) -> Table:
""" """
Dynamically loads a table using the given collection name Dynamically loads a table using the given collection name
@ -126,18 +145,21 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
raise ValueError(f"Table '{collection_name}' not found.") raise ValueError(f"Table '{collection_name}' not found.")
async def retrieve(self, collection_name: str, data_point_ids: List[str]): async def retrieve(self, collection_name: str, data_point_ids: List[str]):
async with self.get_async_session() as session: # Get PGVectorDataPoint Table from database
# Get PGVectorDataPoint Table from database PGVectorDataPoint = await self.get_table(collection_name)
PGVectorDataPoint = await self.get_table(collection_name)
async with self.get_async_session() as session:
results = await session.execute( results = await session.execute(
select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids)) select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
) )
results = results.all() results = results.all()
return [ return [
ScoredResult(id=result.id, payload=result.payload, score=0) ScoredResult(
for result in results id = UUID(result.id),
payload = result.payload,
score = 0
) for result in results
] ]
async def search( async def search(
@ -154,11 +176,13 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
if query_text and not query_vector: if query_text and not query_vector:
query_vector = (await self.embedding_engine.embed_text([query_text]))[0] query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
closest_items = []
# Use async session to connect to the database # Use async session to connect to the database
async with self.get_async_session() as session: async with self.get_async_session() as session:
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
# Find closest vectors to query_vector # Find closest vectors to query_vector
closest_items = await session.execute( closest_items = await session.execute(
select( select(
@ -171,19 +195,21 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
.limit(limit) .limit(limit)
) )
vector_list = [] vector_list = []
# Extract distances and find min/max for normalization
for vector in closest_items:
# TODO: Add normalization of similarity score
vector_list.append(vector)
# Create and return ScoredResult objects # Extract distances and find min/max for normalization
return [ for vector in closest_items:
ScoredResult( # TODO: Add normalization of similarity score
id=str(row.id), payload=row.payload, score=row.similarity vector_list.append(vector)
)
for row in vector_list # Create and return ScoredResult objects
] return [
ScoredResult(
id = UUID(str(row.id)),
payload = row.payload,
score = row.similarity
) for row in vector_list
]
async def batch_search( async def batch_search(
self, self,

View file

@ -1,12 +1,15 @@
from datetime import datetime from datetime import datetime
from uuid import UUID
def serialize_datetime(data): def serialize_data(data):
"""Recursively convert datetime objects in dictionaries/lists to ISO format.""" """Recursively convert datetime objects in dictionaries/lists to ISO format."""
if isinstance(data, dict): if isinstance(data, dict):
return {key: serialize_datetime(value) for key, value in data.items()} return {key: serialize_data(value) for key, value in data.items()}
elif isinstance(data, list): elif isinstance(data, list):
return [serialize_datetime(item) for item in data] return [serialize_data(item) for item in data]
elif isinstance(data, datetime): elif isinstance(data, datetime):
return data.isoformat() # Convert datetime to ISO 8601 string return data.isoformat() # Convert datetime to ISO 8601 string
elif isinstance(data, UUID):
return str(data)
else: else:
return data return data

View file

@ -1,12 +1,22 @@
import logging import logging
from uuid import UUID
from typing import List, Dict, Optional from typing import List, Dict, Optional
from qdrant_client import AsyncQdrantClient, models from qdrant_client import AsyncQdrantClient, models
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
from cognee.infrastructure.engine import DataPoint
from ..vector_db_interface import VectorDBInterface from ..vector_db_interface import VectorDBInterface
from ..models.DataPoint import DataPoint
from ..embeddings.EmbeddingEngine import EmbeddingEngine from ..embeddings.EmbeddingEngine import EmbeddingEngine
logger = logging.getLogger("QDrantAdapter") logger = logging.getLogger("QDrantAdapter")
class IndexSchema(DataPoint):
    """Minimal DataPoint used for vector-index collections: one embeddable `text` field."""
    text: str

    # Marks `text` as the field the embedding machinery should index.
    _metadata: dict = {
        "index_fields": ["text"]
    }
# class CollectionConfig(BaseModel, extra = "forbid"): # class CollectionConfig(BaseModel, extra = "forbid"):
# vector_config: Dict[str, models.VectorParams] = Field(..., description="Vectors configuration" ) # vector_config: Dict[str, models.VectorParams] = Field(..., description="Vectors configuration" )
# hnsw_config: Optional[models.HnswConfig] = Field(default = None, description="HNSW vector index configuration") # hnsw_config: Optional[models.HnswConfig] = Field(default = None, description="HNSW vector index configuration")
@ -75,20 +85,19 @@ class QDrantAdapter(VectorDBInterface):
): ):
client = self.get_qdrant_client() client = self.get_qdrant_client()
result = await client.create_collection( if not await client.collection_exists(collection_name):
collection_name = collection_name, await client.create_collection(
vectors_config = { collection_name = collection_name,
"text": models.VectorParams( vectors_config = {
size = self.embedding_engine.get_vector_size(), "text": models.VectorParams(
distance = "Cosine" size = self.embedding_engine.get_vector_size(),
) distance = "Cosine"
} )
) }
)
await client.close() await client.close()
return result
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]): async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
client = self.get_qdrant_client() client = self.get_qdrant_client()
@ -96,8 +105,8 @@ class QDrantAdapter(VectorDBInterface):
def convert_to_qdrant_point(data_point: DataPoint): def convert_to_qdrant_point(data_point: DataPoint):
return models.PointStruct( return models.PointStruct(
id = data_point.id, id = str(data_point.id),
payload = data_point.payload.dict(), payload = data_point.model_dump(),
vector = { vector = {
"text": data_vectors[data_points.index(data_point)] "text": data_vectors[data_points.index(data_point)]
} }
@ -116,6 +125,17 @@ class QDrantAdapter(VectorDBInterface):
finally: finally:
await client.close() await client.close()
async def create_vector_index(self, index_name: str, index_property_name: str):
    """Create the `{index_name}_{index_property_name}` collection that backs a vector index."""
    await self.create_collection(f"{index_name}_{index_property_name}")
async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
    """Store `data_points` in the `{index_name}_{index_property_name}` index collection,
    wrapping each one in an IndexSchema (id + embeddable text)."""
    await self.create_data_points(f"{index_name}_{index_property_name}", [
        IndexSchema(
            id = data_point.id,
            # Embed the value of the point's first configured index field.
            text = getattr(data_point, data_point._metadata["index_fields"][0]),
        ) for data_point in data_points
    ])
async def retrieve(self, collection_name: str, data_point_ids: list[str]): async def retrieve(self, collection_name: str, data_point_ids: list[str]):
client = self.get_qdrant_client() client = self.get_qdrant_client()
results = await client.retrieve(collection_name, data_point_ids, with_payload = True) results = await client.retrieve(collection_name, data_point_ids, with_payload = True)
@ -135,7 +155,7 @@ class QDrantAdapter(VectorDBInterface):
client = self.get_qdrant_client() client = self.get_qdrant_client()
result = await client.search( results = await client.search(
collection_name = collection_name, collection_name = collection_name,
query_vector = models.NamedVector( query_vector = models.NamedVector(
name = "text", name = "text",
@ -147,7 +167,16 @@ class QDrantAdapter(VectorDBInterface):
await client.close() await client.close()
return result return [
ScoredResult(
id = UUID(result.id),
payload = {
**result.payload,
"id": UUID(result.id),
},
score = 1 - result.score,
) for result in results
]
async def batch_search(self, collection_name: str, query_texts: List[str], limit: int = None, with_vectors: bool = False): async def batch_search(self, collection_name: str, query_texts: List[str], limit: int = None, with_vectors: bool = False):

View file

@ -1,6 +1,6 @@
from typing import List, Protocol, Optional from typing import List, Protocol, Optional
from abc import abstractmethod from abc import abstractmethod
from .models.DataPoint import DataPoint from cognee.infrastructure.engine import DataPoint
from .models.PayloadSchema import PayloadSchema from .models.PayloadSchema import PayloadSchema
class VectorDBInterface(Protocol): class VectorDBInterface(Protocol):

View file

@ -1,13 +1,22 @@
import asyncio import asyncio
import logging import logging
from typing import List, Optional from typing import List, Optional
from uuid import UUID
from cognee.infrastructure.engine import DataPoint
from ..vector_db_interface import VectorDBInterface from ..vector_db_interface import VectorDBInterface
from ..models.DataPoint import DataPoint
from ..models.ScoredResult import ScoredResult from ..models.ScoredResult import ScoredResult
from ..embeddings.EmbeddingEngine import EmbeddingEngine from ..embeddings.EmbeddingEngine import EmbeddingEngine
logger = logging.getLogger("WeaviateAdapter") logger = logging.getLogger("WeaviateAdapter")
class IndexSchema(DataPoint):
    """Minimal DataPoint used for vector-index collections: one embeddable `text` field."""
    text: str

    # Marks `text` as the field the embedding machinery should index.
    _metadata: dict = {
        "index_fields": ["text"]
    }
class WeaviateAdapter(VectorDBInterface): class WeaviateAdapter(VectorDBInterface):
name = "Weaviate" name = "Weaviate"
url: str url: str
@ -48,18 +57,21 @@ class WeaviateAdapter(VectorDBInterface):
future = asyncio.Future() future = asyncio.Future()
future.set_result( if not self.client.collections.exists(collection_name):
self.client.collections.create( future.set_result(
name=collection_name, self.client.collections.create(
properties=[ name = collection_name,
wvcc.Property( properties = [
name="text", wvcc.Property(
data_type=wvcc.DataType.TEXT, name = "text",
skip_vectorization=True data_type = wvcc.DataType.TEXT,
) skip_vectorization = True
] )
]
)
) )
) else:
future.set_result(self.get_collection(collection_name))
return await future return await future
@ -70,36 +82,60 @@ class WeaviateAdapter(VectorDBInterface):
from weaviate.classes.data import DataObject from weaviate.classes.data import DataObject
data_vectors = await self.embed_data( data_vectors = await self.embed_data(
list(map(lambda data_point: data_point.get_embeddable_data(), data_points))) [data_point.get_embeddable_data() for data_point in data_points]
)
def convert_to_weaviate_data_points(data_point: DataPoint): def convert_to_weaviate_data_points(data_point: DataPoint):
vector = data_vectors[data_points.index(data_point)] vector = data_vectors[data_points.index(data_point)]
properties = data_point.model_dump()
if "id" in properties:
properties["uuid"] = str(data_point.id)
del properties["id"]
return DataObject( return DataObject(
uuid = data_point.id, uuid = data_point.id,
properties = data_point.payload.dict(), properties = properties,
vector = vector vector = vector
) )
data_points = list(map(convert_to_weaviate_data_points, data_points)) data_points = [convert_to_weaviate_data_points(data_point) for data_point in data_points]
collection = self.get_collection(collection_name) collection = self.get_collection(collection_name)
try: try:
if len(data_points) > 1: if len(data_points) > 1:
return collection.data.insert_many(data_points) with collection.batch.dynamic() as batch:
for data_point in data_points:
batch.add_object(
uuid = data_point.uuid,
vector = data_point.vector,
properties = data_point.properties,
references = data_point.references,
)
else: else:
return collection.data.insert(data_points[0]) data_point: DataObject = data_points[0]
# with collection.batch.dynamic() as batch: return collection.data.update(
# for point in data_points: uuid = data_point.uuid,
# batch.add_object( vector = data_point.vector,
# uuid = point.uuid, properties = data_point.properties,
# properties = point.properties, references = data_point.references,
# vector = point.vector )
# )
except Exception as error: except Exception as error:
logger.error("Error creating data points: %s", str(error)) logger.error("Error creating data points: %s", str(error))
raise error raise error
async def create_vector_index(self, index_name: str, index_property_name: str):
    """Create the `{index_name}_{index_property_name}` collection that backs a vector index."""
    await self.create_collection(f"{index_name}_{index_property_name}")
async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
    """Store `data_points` in the `{index_name}_{index_property_name}` index collection.

    Each point is wrapped in an IndexSchema carrying only its id and its
    embeddable text (the value of its first configured index field).
    """
    await self.create_data_points(f"{index_name}_{index_property_name}", [
        IndexSchema(
            id = data_point.id,
            text = data_point.get_embeddable_data(),
        ) for data_point in data_points
    ])
async def retrieve(self, collection_name: str, data_point_ids: list[str]): async def retrieve(self, collection_name: str, data_point_ids: list[str]):
from weaviate.classes.query import Filter from weaviate.classes.query import Filter
future = asyncio.Future() future = asyncio.Future()
@ -143,9 +179,9 @@ class WeaviateAdapter(VectorDBInterface):
return [ return [
ScoredResult( ScoredResult(
id=str(result.uuid), id = UUID(str(result.uuid)),
payload=result.properties, payload = result.properties,
score=float(result.metadata.score) score = 1 - float(result.metadata.score)
) for result in search_result.objects ) for result in search_result.objects
] ]

View file

@ -0,0 +1 @@
from .models.DataPoint import DataPoint

View file

@ -0,0 +1,72 @@
from enum import Enum
from typing import Optional
from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.utils import get_graph_from_model, get_model_instance_from_graph
if __name__ == "__main__":
    # Manual smoke test: build a small object graph of project DataPoint
    # subclasses, convert it to (nodes, edges) with get_graph_from_model,
    # reconstruct a model instance from the graph, and print the results.

    class CarTypeName(Enum):
        # Closed set of supported car body types.
        Pickup = "Pickup"
        Sedan = "Sedan"
        SUV = "SUV"
        Coupe = "Coupe"
        Convertible = "Convertible"
        Hatchback = "Hatchback"
        Wagon = "Wagon"
        Minivan = "Minivan"
        Van = "Van"

    class CarType(DataPoint):
        id: str  # overrides the base UUID id with a plain string key
        name: CarTypeName
        # `name` is the field the embedding machinery indexes.
        _metadata: dict = dict(index_fields = ["name"])

    class Car(DataPoint):
        id: str
        brand: str
        model: str
        year: int
        color: str
        is_type: CarType

    class Person(DataPoint):
        id: str
        name: str
        age: int
        owns_car: list[Car]
        driving_licence: Optional[dict]
        _metadata: dict = dict(index_fields = ["name"])

    boris = Person(
        id = "boris",
        name = "Boris",
        age = 30,
        owns_car = [
            Car(
                id = "car1",
                brand = "Toyota",
                model = "Camry",
                year = 2020,
                color = "Blue",
                is_type = CarType(id = "sedan", name = CarTypeName.Sedan),
            ),
        ],
        driving_licence = {
            "issued_by": "PU Vrsac",
            "issued_on": "2025-11-06",
            "number": "1234567890",
            "expires_on": "2025-11-06",
        },
    )

    nodes, edges = get_graph_from_model(boris)

    print(nodes)
    print(edges)

    # NOTE(review): `person_data` is assigned but never used — presumably kept
    # for debugging; confirm before removing.
    person_data = nodes[len(nodes) - 1]

    parsed_person = get_model_instance_from_graph(nodes, edges, 'boris')

    print(parsed_person)

View file

@ -0,0 +1,24 @@
from typing_extensions import TypedDict
from uuid import UUID, uuid4
from typing import Optional
from datetime import datetime, timezone
from pydantic import BaseModel, Field
class MetaData(TypedDict):
    """Typed shape of the `_metadata` attribute on DataPoint."""
    # Names of the model fields whose values should be embedded/indexed.
    index_fields: list[str]
class DataPoint(BaseModel):
    """Base pydantic model for every record stored by the engine.

    Carries a generated UUID, an update timestamp, and private `_metadata`
    naming which fields are embeddable.
    """
    id: UUID = Field(default_factory = uuid4)
    # Use default_factory so each instance gets its own timestamp; a plain
    # `datetime.now(timezone.utc)` default is evaluated once at import time
    # and would be shared by every instance ever created.
    updated_at: Optional[datetime] = Field(default_factory = lambda: datetime.now(timezone.utc))
    # NOTE(review): this dict default is declared at class level; pydantic
    # treats leading-underscore names as private attributes — confirm that
    # instances do not mutate a shared dict.
    _metadata: Optional[MetaData] = {
        "index_fields": []
    }

    # class Config:
    #     underscore_attrs_are_private = True

    def get_embeddable_data(self):
        """Return the value of the first configured index field, or None when
        no index field is configured or the attribute is absent."""
        if self._metadata and len(self._metadata["index_fields"]) > 0 \
            and hasattr(self, self._metadata["index_fields"][0]):
            return getattr(self, self._metadata["index_fields"][0])

View file

@ -1,18 +1,18 @@
from uuid import UUID, uuid5, NAMESPACE_OID from uuid import uuid5, NAMESPACE_OID
from cognee.modules.chunking import DocumentChunk from .models.DocumentChunk import DocumentChunk
from cognee.tasks.chunking import chunk_by_paragraph from cognee.tasks.chunks import chunk_by_paragraph
class TextChunker(): class TextChunker():
id: UUID document = None
max_chunk_size: int max_chunk_size: int
chunk_index = 0 chunk_index = 0
chunk_size = 0 chunk_size = 0
paragraph_chunks = [] paragraph_chunks = []
def __init__(self, id: UUID, get_text: callable, chunk_size: int = 1024): def __init__(self, document, get_text: callable, chunk_size: int = 1024):
self.id = id self.document = document
self.max_chunk_size = chunk_size self.max_chunk_size = chunk_size
self.get_text = get_text self.get_text = get_text
@ -29,10 +29,10 @@ class TextChunker():
else: else:
if len(self.paragraph_chunks) == 0: if len(self.paragraph_chunks) == 0:
yield DocumentChunk( yield DocumentChunk(
id = chunk_data["chunk_id"],
text = chunk_data["text"], text = chunk_data["text"],
word_count = chunk_data["word_count"], word_count = chunk_data["word_count"],
document_id = str(self.id), is_part_of = self.document,
chunk_id = str(chunk_data["chunk_id"]),
chunk_index = self.chunk_index, chunk_index = self.chunk_index,
cut_type = chunk_data["cut_type"], cut_type = chunk_data["cut_type"],
) )
@ -40,25 +40,31 @@ class TextChunker():
self.chunk_size = 0 self.chunk_size = 0
else: else:
chunk_text = " ".join(chunk["text"] for chunk in self.paragraph_chunks) chunk_text = " ".join(chunk["text"] for chunk in self.paragraph_chunks)
yield DocumentChunk( try:
text = chunk_text, yield DocumentChunk(
word_count = self.chunk_size, id = uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
document_id = str(self.id), text = chunk_text,
chunk_id = str(uuid5(NAMESPACE_OID, f"{str(self.id)}-{self.chunk_index}")), word_count = self.chunk_size,
chunk_index = self.chunk_index, is_part_of = self.document,
cut_type = self.paragraph_chunks[len(self.paragraph_chunks) - 1]["cut_type"], chunk_index = self.chunk_index,
) cut_type = self.paragraph_chunks[len(self.paragraph_chunks) - 1]["cut_type"],
)
except Exception as e:
print(e)
self.paragraph_chunks = [chunk_data] self.paragraph_chunks = [chunk_data]
self.chunk_size = chunk_data["word_count"] self.chunk_size = chunk_data["word_count"]
self.chunk_index += 1 self.chunk_index += 1
if len(self.paragraph_chunks) > 0: if len(self.paragraph_chunks) > 0:
yield DocumentChunk( try:
text = " ".join(chunk["text"] for chunk in self.paragraph_chunks), yield DocumentChunk(
word_count = self.chunk_size, id = uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
document_id = str(self.id), text = " ".join(chunk["text"] for chunk in self.paragraph_chunks),
chunk_id = str(uuid5(NAMESPACE_OID, f"{str(self.id)}-{self.chunk_index}")), word_count = self.chunk_size,
chunk_index = self.chunk_index, is_part_of = self.document,
cut_type = self.paragraph_chunks[len(self.paragraph_chunks) - 1]["cut_type"], chunk_index = self.chunk_index,
) cut_type = self.paragraph_chunks[len(self.paragraph_chunks) - 1]["cut_type"],
)
except Exception as e:
print(e)

View file

@ -1,2 +0,0 @@
from .models.DocumentChunk import DocumentChunk
from .TextChunker import TextChunker

View file

@ -1,9 +1,14 @@
from pydantic import BaseModel from typing import Optional
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.processing.document_types import Document
class DocumentChunk(BaseModel): class DocumentChunk(DataPoint):
text: str text: str
word_count: int word_count: int
document_id: str
chunk_id: str
chunk_index: int chunk_index: int
cut_type: str cut_type: str
is_part_of: Document
_metadata: Optional[dict] = {
"index_fields": ["text"],
}

View file

@ -0,0 +1 @@
from .knowledge_graph.extract_content_graph import extract_content_graph

View file

@ -0,0 +1 @@
from .extract_content_graph import extract_content_graph

View file

@ -1,36 +1,36 @@
import logging import logging
logger = logging.getLogger(__name__)
async def detect_language(text: str):
async def detect_language(data:str):
""" """
Detect the language of the given text and return its ISO 639-1 language code. Detect the language of the given text and return its ISO 639-1 language code.
If the detected language is Croatian ('hr'), it maps to Serbian ('sr'). If the detected language is Croatian ("hr"), it maps to Serbian ("sr").
The text is trimmed to the first 100 characters for efficient processing. The text is trimmed to the first 100 characters for efficient processing.
Parameters: Parameters:
text (str): The text for language detection. text (str): The text for language detection.
Returns: Returns:
str: The ISO 639-1 language code of the detected language, or 'None' in case of an error. str: The ISO 639-1 language code of the detected language, or "None" in case of an error.
""" """
# Trim the text to the first 100 characters
from langdetect import detect, LangDetectException from langdetect import detect, LangDetectException
trimmed_text = data[:100] # Trim the text to the first 100 characters
trimmed_text = text[:100]
try: try:
# Detect the language using langdetect # Detect the language using langdetect
detected_lang_iso639_1 = detect(trimmed_text) detected_lang_iso639_1 = detect(trimmed_text)
logging.info(f"Detected ISO 639-1 code: {detected_lang_iso639_1}")
# Special case: map 'hr' (Croatian) to 'sr' (Serbian ISO 639-2) # Special case: map "hr" (Croatian) to "sr" (Serbian ISO 639-2)
if detected_lang_iso639_1 == 'hr': if detected_lang_iso639_1 == "hr":
yield 'sr' return "sr"
yield detected_lang_iso639_1
return detected_lang_iso639_1
except LangDetectException as e: except LangDetectException as e:
logging.error(f"Language detection error: {e}") logger.error(f"Language detection error: {e}")
except Exception as e:
logging.error(f"Unexpected error: {e}")
yield None except Exception as e:
logger.error(f"Unexpected error: {e}")
return None

View file

@ -0,0 +1,41 @@
import logging
logger = logging.getLogger(__name__)
async def translate_text(text, source_language: str = "sr", target_language: str = "en", region_name = "eu-west-1"):
    """
    Translate text from source language to target language using AWS Translate.

    Parameters:
    text (str): The text to be translated.
    source_language (str): The source language code (e.g., "sr" for Serbian). ISO 639-2 Code https://www.loc.gov/standards/iso639-2/php/code_list.php
    target_language (str): The target language code (e.g., "en" for English). ISO 639-2 Code https://www.loc.gov/standards/iso639-2/php/code_list.php
    region_name (str): AWS region name.

    Returns:
    str: Translated text or an error message.
    """
    # Imported lazily so the module can be loaded without boto3 installed.
    import boto3
    from botocore.exceptions import BotoCoreError, ClientError

    if not text:
        raise ValueError("No text to translate.")

    if not source_language or not target_language:
        raise ValueError("Source and target language codes are required.")

    try:
        translate = boto3.client(service_name = "translate", region_name = region_name, use_ssl = True)
        result = translate.translate_text(
            Text = text,
            SourceLanguageCode = source_language,
            TargetLanguageCode = target_language,
        )
        # Async-generator style: the single translation result is yielded.
        yield result.get("TranslatedText", "No translation found.")
    except BotoCoreError as e:
        logger.error(f"BotoCoreError occurred: {e}")
        # Best-effort: AWS transport errors are logged and None is yielded
        # instead of propagating the exception.
        yield None
    except ClientError as e:
        logger.error(f"ClientError occurred: {e}")
        yield None

View file

@ -1,34 +1,15 @@
from uuid import UUID, uuid5, NAMESPACE_OID
from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.modules.chunking.TextChunker import TextChunker from cognee.modules.chunking.TextChunker import TextChunker
from .Document import Document from .Document import Document
class AudioDocument(Document): class AudioDocument(Document):
type: str = "audio" type: str = "audio"
title: str
raw_data_location: str
chunking_strategy: str
def __init__(self, id: UUID, title: str, raw_data_location: str, chunking_strategy:str="paragraph"):
self.id = id or uuid5(NAMESPACE_OID, title)
self.title = title
self.raw_data_location = raw_data_location
self.chunking_strategy = chunking_strategy
def read(self, chunk_size: int): def read(self, chunk_size: int):
# Transcribe the audio file # Transcribe the audio file
result = get_llm_client().create_transcript(self.raw_data_location) result = get_llm_client().create_transcript(self.raw_data_location)
text = result.text text = result.text
chunker = TextChunker(self.id, chunk_size = chunk_size, get_text = lambda: text) chunker = TextChunker(self, chunk_size = chunk_size, get_text = lambda: text)
yield from chunker.read() yield from chunker.read()
def to_dict(self) -> dict:
return dict(
id=str(self.id),
type=self.type,
title=self.title,
raw_data_location=self.raw_data_location,
)

View file

@ -1,10 +1,8 @@
from uuid import UUID from cognee.infrastructure.engine import DataPoint
from typing import Protocol
class Document(Protocol): class Document(DataPoint):
id: UUID
type: str type: str
title: str name: str
raw_data_location: str raw_data_location: str
def read(self, chunk_size: int) -> str: def read(self, chunk_size: int) -> str:

View file

@ -1,33 +1,15 @@
from uuid import UUID, uuid5, NAMESPACE_OID
from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.modules.chunking.TextChunker import TextChunker from cognee.modules.chunking.TextChunker import TextChunker
from .Document import Document from .Document import Document
class ImageDocument(Document): class ImageDocument(Document):
type: str = "image" type: str = "image"
title: str
raw_data_location: str
def __init__(self, id: UUID, title: str, raw_data_location: str):
self.id = id or uuid5(NAMESPACE_OID, title)
self.title = title
self.raw_data_location = raw_data_location
def read(self, chunk_size: int): def read(self, chunk_size: int):
# Transcribe the image file # Transcribe the image file
result = get_llm_client().transcribe_image(self.raw_data_location) result = get_llm_client().transcribe_image(self.raw_data_location)
text = result.choices[0].message.content text = result.choices[0].message.content
chunker = TextChunker(self.id, chunk_size = chunk_size, get_text = lambda: text) chunker = TextChunker(self, chunk_size = chunk_size, get_text = lambda: text)
yield from chunker.read() yield from chunker.read()
def to_dict(self) -> dict:
return dict(
id=str(self.id),
type=self.type,
title=self.title,
raw_data_location=self.raw_data_location,
)

View file

@ -1,19 +1,11 @@
from uuid import UUID, uuid5, NAMESPACE_OID
from pypdf import PdfReader from pypdf import PdfReader
from cognee.modules.chunking.TextChunker import TextChunker from cognee.modules.chunking.TextChunker import TextChunker
from .Document import Document from .Document import Document
class PdfDocument(Document): class PdfDocument(Document):
type: str = "pdf" type: str = "pdf"
title: str
raw_data_location: str
def __init__(self, id: UUID, title: str, raw_data_location: str): def read(self, chunk_size: int):
self.id = id or uuid5(NAMESPACE_OID, title)
self.title = title
self.raw_data_location = raw_data_location
def read(self, chunk_size: int) -> PdfReader:
file = PdfReader(self.raw_data_location) file = PdfReader(self.raw_data_location)
def get_text(): def get_text():
@ -21,16 +13,8 @@ class PdfDocument(Document):
page_text = page.extract_text() page_text = page.extract_text()
yield page_text yield page_text
chunker = TextChunker(self.id, chunk_size = chunk_size, get_text = get_text) chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text)
yield from chunker.read() yield from chunker.read()
file.stream.close() file.stream.close()
def to_dict(self) -> dict:
return dict(
id = str(self.id),
type = self.type,
title = self.title,
raw_data_location = self.raw_data_location,
)

View file

@ -1,16 +1,8 @@
from uuid import UUID, uuid5, NAMESPACE_OID
from cognee.modules.chunking.TextChunker import TextChunker from cognee.modules.chunking.TextChunker import TextChunker
from .Document import Document from .Document import Document
class TextDocument(Document): class TextDocument(Document):
type: str = "text" type: str = "text"
title: str
raw_data_location: str
def __init__(self, id: UUID, title: str, raw_data_location: str):
self.id = id or uuid5(NAMESPACE_OID, title)
self.title = title
self.raw_data_location = raw_data_location
def read(self, chunk_size: int): def read(self, chunk_size: int):
def get_text(): def get_text():
@ -23,16 +15,6 @@ class TextDocument(Document):
yield text yield text
chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text)
chunker = TextChunker(self.id,chunk_size = chunk_size, get_text = get_text)
yield from chunker.read() yield from chunker.read()
def to_dict(self) -> dict:
return dict(
id = str(self.id),
type = self.type,
title = self.title,
raw_data_location = self.raw_data_location,
)

View file

@ -1,3 +1,4 @@
from .Document import Document
from .PdfDocument import PdfDocument from .PdfDocument import PdfDocument
from .TextDocument import TextDocument from .TextDocument import TextDocument
from .ImageDocument import ImageDocument from .ImageDocument import ImageDocument

View file

@ -0,0 +1,12 @@
from cognee.infrastructure.engine import DataPoint
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from .EntityType import EntityType
class Entity(DataPoint):
    # A named entity extracted from a document chunk, linked to its type
    # and to the chunk it was mentioned in.
    name: str
    is_a: EntityType
    description: str
    mentioned_in: DocumentChunk
    # index_fields: the fields the vector engine embeds for this data point.
    _metadata: dict = {
        "index_fields": ["name"],
    }

View file

@ -0,0 +1,11 @@
from cognee.infrastructure.engine import DataPoint
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
class EntityType(DataPoint):
    # The classification of an Entity (e.g. "person", "location"),
    # anchored to the chunk where this type was observed.
    name: str
    type: str
    description: str
    exists_in: DocumentChunk
    # index_fields: the fields the vector engine embeds for this data point.
    _metadata: dict = {
        "index_fields": ["name"],
    }

View file

@ -0,0 +1,2 @@
from .Entity import Entity
from .EntityType import EntityType

View file

@ -0,0 +1,3 @@
from .generate_node_id import generate_node_id
from .generate_node_name import generate_node_name
from .generate_edge_name import generate_edge_name

View file

@ -0,0 +1,2 @@
def generate_edge_name(name: str) -> str:
    """Normalize an edge label: lowercase, spaces to underscores, apostrophes removed."""
    normalized = name.lower()
    normalized = normalized.replace(" ", "_")
    return normalized.replace("'", "")

View file

@ -0,0 +1,4 @@
from uuid import NAMESPACE_OID, UUID, uuid5


def generate_node_id(node_id: str) -> UUID:
    """Derive a deterministic UUID for a node from its normalized string id.

    The id is lowercased, spaces become underscores and apostrophes are
    dropped before hashing, so equivalent spellings map to the same node.
    (Bug fix: the annotation previously said ``str`` although ``uuid5``
    returns a ``UUID``.)
    """
    return uuid5(NAMESPACE_OID, node_id.lower().replace(" ", "_").replace("'", ""))

View file

@ -0,0 +1,2 @@
def generate_node_name(name: str) -> str:
    """Normalize a display name: lowercase with apostrophes removed (spaces kept)."""
    lowered = name.lower()
    return lowered.replace("'", "")

View file

@ -1,5 +0,0 @@
def generate_node_name(name: str) -> str:
    """Normalize a node name: lowercase, spaces to underscores, apostrophes removed."""
    result = name.lower().replace("'", "")
    return result.replace(" ", "_")
def generate_node_id(node_id: str) -> str:
    """Normalize a node id string: lowercase, spaces to underscores, apostrophes removed."""
    normalized = node_id.lower().replace("'", "")
    return normalized.replace(" ", "_")

View file

@ -0,0 +1,2 @@
from .get_graph_from_model import get_graph_from_model
from .get_model_instance_from_graph import get_model_instance_from_graph

View file

@ -0,0 +1,107 @@
from datetime import datetime, timezone
from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import copy_model
def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = None, added_edges = None):
    """Recursively flatten a DataPoint (and its nested data points) into graph nodes and edges.

    Parameters:
    - data_point: the root data point to convert.
    - include_root: when True, a node for *data_point* itself is appended.
    - added_nodes / added_edges: de-duplication registries shared across the
      recursion; leave as None at the top level.

    Returns:
    - (nodes, edges): a list of DataPoint nodes and a list of
      (source_id, target_id, relationship_name, properties) edge tuples.
    """
    # Bug fix: the original used mutable default arguments ({}), which are
    # shared between unrelated top-level calls — nodes/edges already seen in
    # a previous graph would be silently dropped from later ones.
    if added_nodes is None:
        added_nodes = {}
    if added_edges is None:
        added_edges = {}

    nodes = []
    edges = []

    data_point_properties = {}
    excluded_properties = set()

    for field_name, field_value in data_point:
        if field_name == "_metadata":
            continue

        if isinstance(field_value, DataPoint):
            excluded_properties.add(field_name)

            property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges)

            for node in property_nodes:
                if str(node.id) not in added_nodes:
                    nodes.append(node)
                    added_nodes[str(node.id)] = True

            for edge in property_edges:
                edge_key = str(edge[0]) + str(edge[1]) + edge[2]

                if edge_key not in added_edges:
                    edges.append(edge)
                    added_edges[edge_key] = True

            # Link the root to property nodes that are not themselves the
            # target of another edge (i.e. the direct children of this field).
            for property_node in get_own_properties(property_nodes, property_edges):
                edge_key = str(data_point.id) + str(property_node.id) + field_name

                if edge_key not in added_edges:
                    edges.append((data_point.id, property_node.id, field_name, {
                        "source_node_id": data_point.id,
                        "target_node_id": property_node.id,
                        "relationship_name": field_name,
                        "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
                    }))
                    added_edges[edge_key] = True
            continue

        if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
            excluded_properties.add(field_name)

            for item in field_value:
                property_nodes, property_edges = get_graph_from_model(item, True, added_nodes, added_edges)

                for node in property_nodes:
                    if str(node.id) not in added_nodes:
                        nodes.append(node)
                        added_nodes[str(node.id)] = True

                for edge in property_edges:
                    edge_key = str(edge[0]) + str(edge[1]) + edge[2]

                    if edge_key not in added_edges:
                        edges.append(edge)
                        added_edges[edge_key] = True

                for property_node in get_own_properties(property_nodes, property_edges):
                    edge_key = str(data_point.id) + str(property_node.id) + field_name

                    if edge_key not in added_edges:
                        # "metadata.type = list" marks one-to-many edges so the
                        # graph can be folded back into a list-valued field.
                        edges.append((data_point.id, property_node.id, field_name, {
                            "source_node_id": data_point.id,
                            "target_node_id": property_node.id,
                            "relationship_name": field_name,
                            "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
                            "metadata": {
                                "type": "list"
                            },
                        }))
                        added_edges[edge_key] = True
            continue

        data_point_properties[field_name] = field_value

    # Build a slimmed-down model for the root node: nested data-point fields
    # are stripped (they became edges), scalar properties and _metadata stay.
    SimpleDataPointModel = copy_model(
        type(data_point),
        include_fields = {
            "_metadata": (dict, data_point._metadata),
        },
        exclude_fields = excluded_properties,
    )

    if include_root:
        nodes.append(SimpleDataPointModel(**data_point_properties))

    return nodes, edges
def get_own_properties(property_nodes, property_edges):
    """Return the nodes that are not the target of any edge (roots of the property sub-graph)."""
    edge_targets = {str(property_edge[1]) for property_edge in property_edges}
    return [node for node in property_nodes if str(node.id) not in edge_targets]

View file

@ -0,0 +1,29 @@
from pydantic_core import PydanticUndefined
from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import copy_model
def get_model_instance_from_graph(nodes: list[DataPoint], edges: list, entity_id: str):
    """Rebuild a nested model instance from a flat (nodes, edges) graph.

    Walks every edge, attaching each target node to its source node as a new
    field named after the edge label, then returns the (re-nested) node with
    id *entity_id*.
    """
    node_map = {}

    for node in nodes:
        node_map[node.id] = node

    for edge in edges:
        source_node = node_map[edge[0]]
        target_node = node_map[edge[1]]
        edge_label = edge[2]
        # Edges may be 3-tuples (no properties) or 4-tuples with a property dict.
        edge_properties = edge[3] if len(edge) == 4 else {}
        edge_metadata = edge_properties.get("metadata", {})
        edge_type = edge_metadata.get("type")

        if edge_type == "list":
            # NOTE(review): several "list" edges sharing the same label on one
            # source node each rebuild the field as [target_node], so only the
            # last target survives — confirm whether accumulation was intended.
            NewModel = copy_model(type(source_node), { edge_label: (list[type(target_node)], PydanticUndefined) })
            node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: [target_node] })
        else:
            NewModel = copy_model(type(source_node), { edge_label: (type(target_node), PydanticUndefined) })
            node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: target_node })

    return node_map[entity_id]

View file

@ -7,7 +7,7 @@ from ..tasks.Task import Task
logger = logging.getLogger("run_tasks(tasks: [Task], data)") logger = logging.getLogger("run_tasks(tasks: [Task], data)")
async def run_tasks_base(tasks: [Task], data = None, user: User = None): async def run_tasks_base(tasks: list[Task], data = None, user: User = None):
if len(tasks) == 0: if len(tasks) == 0:
yield data yield data
return return
@ -16,7 +16,7 @@ async def run_tasks_base(tasks: [Task], data = None, user: User = None):
running_task = tasks[0] running_task = tasks[0]
leftover_tasks = tasks[1:] leftover_tasks = tasks[1:]
next_task = leftover_tasks[0] if len(leftover_tasks) > 1 else None next_task = leftover_tasks[0] if len(leftover_tasks) > 0 else None
next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1 next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1
if inspect.isasyncgenfunction(running_task.executable): if inspect.isasyncgenfunction(running_task.executable):

View file

@ -1,33 +0,0 @@
import asyncio
import nest_asyncio
import dspy
from cognee.modules.search.vector.search_similarity import search_similarity
nest_asyncio.apply()
# DSPy signature: generate an answer to a question from retrieved context.
# Intentionally no docstring — dspy.Signature docstrings become LM instructions.
class AnswerFromContext(dspy.Signature):
    question: str = dspy.InputField()
    context: str = dspy.InputField(desc = "Context to use for answer generation.")
    answer: str = dspy.OutputField()


# LLM used only for answer generation (scoped via dspy.context in forward()).
question_answer_llm = dspy.OpenAI(model = "gpt-3.5-turbo-instruct")
# DSPy module: retrieve chunks similar to the question, then answer from them.
class CogneeSearch(dspy.Module):
    def __init__(self, ):
        super().__init__()
        self.generate_answer = dspy.TypedChainOfThought(AnswerFromContext)

    def forward(self, question):
        # Retrieval is async; nest_asyncio (applied at module import) makes
        # asyncio.run usable even from inside a running event loop.
        context = asyncio.run(search_similarity(question))
        # assumes search_similarity yields plain strings here — TODO confirm,
        # other variants of it return dict payloads.
        context_text = "\n".join(context)
        print(f"Context: {context_text}")
        with dspy.context(lm = question_answer_llm):
            answer_prediction = self.generate_answer(context = context_text, question = question)
            answer = answer_prediction.answer
            print(f"Question: {question}")
            print(f"Answer: {answer}")
        return dspy.Prediction(context = context_text, answer = answer)

View file

@ -1,43 +0,0 @@
import asyncio
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
async def search_adjacent(query: str) -> list[tuple[str, str]]:
    """
    Find the neighbours of a given node in the graph and return their ids and descriptions.

    Parameters:
    - query (str): The query string to filter nodes by.

    Returns:
    - list[tuple[str, str]]: A list containing the unique identifiers and names of the neighbours of the given node.
    """
    node_id = query

    if node_id is None:
        # Bug fix: previously returned {} although the function is annotated
        # (and used) as returning a list.
        return []

    graph_engine = await get_graph_engine()

    exact_node = await graph_engine.extract_node(node_id)

    if exact_node is not None and "uuid" in exact_node:
        neighbours = await graph_engine.get_neighbours(exact_node["uuid"])
    else:
        # No exact match — fall back to a semantic search over both indexes.
        vector_engine = get_vector_engine()
        results = await asyncio.gather(
            vector_engine.search("entities", query_text = query, limit = 10),
            vector_engine.search("classification", query_text = query, limit = 10),
        )
        results = [*results[0], *results[1]]

        # Lower score = more similar; keep only close matches, capped at 5.
        relevant_results = [result for result in results if result.score < 0.5][:5]

        if len(relevant_results) == 0:
            return []

        node_neighbours = await asyncio.gather(*[graph_engine.get_neighbours(result.id) for result in relevant_results])

        neighbours = []
        for neighbour_ids in node_neighbours:
            neighbours.extend(neighbour_ids)

    return neighbours

View file

@ -1,15 +0,0 @@
from cognee.infrastructure.databases.graph import get_graph_engine, get_graph_config
async def search_cypher(query: str):
    """
    Use a Cypher query to search the graph and return the results.

    Raises ValueError when the configured graph provider does not speak Cypher.
    """
    graph_config = get_graph_config()

    # Guard clause: only neo4j supports raw Cypher queries.
    if graph_config.graph_database_provider != "neo4j":
        raise ValueError("Unsupported search type for the used graph engine.")

    graph_engine = await get_graph_engine()
    return await graph_engine.graph().run(query)

View file

@ -1,27 +0,0 @@
from cognee.infrastructure.databases.vector import get_vector_engine
async def search_similarity(query: str) -> list[dict]:
    """
    Find chunks semantically similar to the query.

    Parameters:
    - query (str): The query string to filter nodes by.

    Returns:
    - list[dict]: Objects describing the chunks related to the query.
      (Bug fix: the annotation was ``list[str, str]``, which is not a valid
      generic — ``list`` takes a single type parameter.)
    """
    vector_engine = get_vector_engine()

    similar_results = await vector_engine.search("chunks", query, limit = 5)

    results = [
        parse_payload(result.payload) for result in similar_results
    ]

    return results
def parse_payload(payload: dict) -> dict:
    """Project a chunk payload down to the fields exposed in search results."""
    exposed_keys = ("text", "chunk_id", "document_id")
    return {key: payload[key] for key in exposed_keys}

View file

@ -1,17 +0,0 @@
from cognee.infrastructure.databases.vector import get_vector_engine
async def search_summary(query: str) -> list:
    """
    Parameters:
    - query (str): The query string to filter summaries by.

    Returns:
    - list: Payloads of the summaries related to the query.
    """
    engine = get_vector_engine()

    search_results = await engine.search("summaries", query, limit = 5)

    return [result.payload for result in search_results]

View file

@ -1,16 +0,0 @@
from typing import Type
from pydantic import BaseModel
from cognee.infrastructure.llm.prompts import render_prompt
from cognee.infrastructure.llm.get_llm_client import get_llm_client
async def categorize_relevant_category(query: str, summary, response_model: Type[BaseModel]):
    """Ask the LLM to pick the categories relevant to *query*, returned as a plain dict."""
    llm_client = get_llm_client()

    enriched_query = render_prompt("categorize_categories.txt", {"query": query, "categories": summary})
    system_prompt = " Choose the relevant categories and return appropriate output based on the model"

    structured_output = await llm_client.acreate_structured_output(enriched_query, system_prompt, response_model)

    return structured_output.model_dump()

View file

@ -1,15 +0,0 @@
from typing import Type
from pydantic import BaseModel
from cognee.infrastructure.llm.prompts import render_prompt
from cognee.infrastructure.llm.get_llm_client import get_llm_client
async def categorize_relevant_summary(query: str, summaries, response_model: Type[BaseModel]):
    """Ask the LLM to pick the summaries relevant to *query*; returns the structured model instance."""
    client = get_llm_client()

    prompt = render_prompt("categorize_summary.txt", {"query": query, "summaries": summaries})
    system_prompt = "Choose the relevant summaries and return appropriate output based on the model"

    return await client.acreate_structured_output(prompt, system_prompt, response_model)

View file

@ -1,17 +0,0 @@
import logging
from typing import List, Dict
from cognee.modules.cognify.config import get_cognify_config
from .extraction.categorize_relevant_summary import categorize_relevant_summary
logger = logging.getLogger(__name__)
async def get_cognitive_layers(content: str, categories: List[Dict]):
    """Ask the LLM for the cognitive layers of *content*, using its top-ranked category.

    Raises whatever the underlying LLM call raises, after logging it.
    """
    try:
        cognify_config = get_cognify_config()
        return (await categorize_relevant_summary(
            content,
            categories[0],  # only the top-ranked category is consulted
            cognify_config.summarization_model,
        )).cognitive_layers
    except Exception as error:
        logger.error("Error extracting cognitive layers from content: %s", error, exc_info = True)
        # Bare raise (instead of `raise error`) preserves the original traceback.
        raise

View file

@ -1 +0,0 @@
""" Placeholder for BM25 implementation"""

View file

@ -1 +0,0 @@
"""Placeholder for fusions search implementation"""

View file

@ -1,36 +0,0 @@
import asyncio
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
async def search_traverse(query: str):
    """Collect human-readable relationship rules around the node matching the query."""
    rules = set()

    graph_engine = await get_graph_engine()
    vector_engine = get_vector_engine()

    def collect_rules(edge_list):
        # Each edge renders as "<source> <relationship_name> <target>".
        for edge in edge_list:
            rules.add(f"{edge[0]} {edge[2]['relationship_name']} {edge[1]}")

    exact_node = await graph_engine.extract_node(query)

    if exact_node is not None and "uuid" in exact_node:
        collect_rules(await graph_engine.get_edges(exact_node["uuid"]))
    else:
        # No exact node — fall back to semantic search over both indexes.
        entity_hits, classification_hits = await asyncio.gather(
            vector_engine.search("entities", query_text = query, limit = 10),
            vector_engine.search("classification", query_text = query, limit = 10),
        )
        # Lower score = more similar; keep only close matches, capped at 5.
        candidates = [hit for hit in [*entity_hits, *classification_hits] if hit.score < 0.5][:5]

        for candidate in candidates:
            collect_rules(await graph_engine.get_edges(candidate.id))

    return list(rules)

View file

@ -0,0 +1,46 @@
import json
from uuid import UUID
from datetime import datetime
from pydantic_core import PydanticUndefined
from cognee.infrastructure.engine import DataPoint
class JSONEncoder(json.JSONEncoder):
    """JSON encoder that also understands datetime (ISO 8601) and UUID values."""

    def default(self, obj):
        if isinstance(obj, UUID):
            # UUIDs serialize as their canonical string form.
            return str(obj)
        if isinstance(obj, datetime):
            # Datetimes serialize as ISO 8601 strings.
            return obj.isoformat()
        # Everything else falls through to the base class (raises TypeError).
        return super().default(obj)
from pydantic import create_model
def copy_model(model: DataPoint, include_fields: dict = None, exclude_fields: list = None):
    """Create a new pydantic model with *model*'s fields, minus *exclude_fields*, plus *include_fields*.

    Parameters:
    - model: the pydantic model class to copy.
    - include_fields: extra fields to add, as {name: (annotation, default)}.
    - exclude_fields: field names to drop from the copy.
    """
    # Idiom fix: the original used mutable default arguments ({} / []);
    # neither is mutated here, but None-defaults avoid the shared-state trap.
    include_fields = include_fields if include_fields is not None else {}
    exclude_fields = exclude_fields if exclude_fields is not None else []

    # NOTE(review): a legitimate default of None is replaced by
    # PydanticUndefined, turning an optional field into a required one —
    # confirm this is intended for fields declared `= None`.
    fields = {
        name: (field.annotation, field.default if field.default is not None else PydanticUndefined)
        for name, field in model.model_fields.items()
        if name not in exclude_fields
    }

    final_fields = {
        **fields,
        **include_fields
    }

    return create_model(model.__name__, **final_fields)
def get_own_properties(data_point: DataPoint):
    """Collect the scalar fields of a data point, skipping metadata, dicts and nested data points."""
    properties = {}

    for field_name, field_value in data_point:
        # Bug fix: the original indexed field_value[0] without a length guard,
        # raising IndexError for any empty-list field. An empty list carries
        # no nested data points, so it is kept as a plain property.
        if field_name == "_metadata" \
            or isinstance(field_value, dict) \
            or isinstance(field_value, DataPoint) \
            or (isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint)):
            continue

        properties[field_name] = field_value

    return properties

View file

@ -1,84 +1,95 @@
from typing import List, Union, Literal, Optional from typing import Any, List, Union, Literal, Optional
from pydantic import BaseModel from cognee.infrastructure.engine import DataPoint
class BaseClass(BaseModel): class Variable(DataPoint):
id: str
name: str
type: Literal["Class"] = "Class"
description: str
constructor_parameters: Optional[List[str]] = None
class Class(BaseModel):
id: str
name: str
type: Literal["Class"] = "Class"
description: str
constructor_parameters: Optional[List[str]] = None
from_class: Optional[BaseClass] = None
class ClassInstance(BaseModel):
id: str
name: str
type: Literal["ClassInstance"] = "ClassInstance"
description: str
from_class: Class
class Function(BaseModel):
id: str
name: str
type: Literal["Function"] = "Function"
description: str
parameters: Optional[List[str]] = None
return_type: str
is_static: Optional[bool] = False
class Variable(BaseModel):
id: str id: str
name: str name: str
type: Literal["Variable"] = "Variable" type: Literal["Variable"] = "Variable"
description: str description: str
is_static: Optional[bool] = False is_static: Optional[bool] = False
default_value: Optional[str] = None default_value: Optional[str] = None
data_type: str
class Operator(BaseModel): _metadata = {
"index_fields": ["name"]
}
class Operator(DataPoint):
id: str id: str
name: str name: str
type: Literal["Operator"] = "Operator" type: Literal["Operator"] = "Operator"
description: str description: str
return_type: str return_type: str
class ExpressionPart(BaseModel): class Class(DataPoint):
id: str
name: str
type: Literal["Class"] = "Class"
description: str
constructor_parameters: List[Variable]
extended_from_class: Optional["Class"] = None
has_methods: list["Function"]
_metadata = {
"index_fields": ["name"]
}
class ClassInstance(DataPoint):
id: str
name: str
type: Literal["ClassInstance"] = "ClassInstance"
description: str
from_class: Class
instantiated_by: Union["Function"]
instantiation_arguments: List[Variable]
_metadata = {
"index_fields": ["name"]
}
class Function(DataPoint):
id: str
name: str
type: Literal["Function"] = "Function"
description: str
parameters: List[Variable]
return_type: str
is_static: Optional[bool] = False
_metadata = {
"index_fields": ["name"]
}
class FunctionCall(DataPoint):
id: str
type: Literal["FunctionCall"] = "FunctionCall"
called_by: Union[Function, Literal["main"]]
function_called: Function
function_arguments: List[Any]
class Expression(DataPoint):
id: str id: str
name: str name: str
type: Literal["Expression"] = "Expression" type: Literal["Expression"] = "Expression"
description: str description: str
expression: str expression: str
members: List[Union[Variable, Function, Operator]] members: List[Union[Variable, Function, Operator, "Expression"]]
class Expression(BaseModel): class SourceCodeGraph(DataPoint):
id: str
name: str
type: Literal["Expression"] = "Expression"
description: str
expression: str
members: List[Union[Variable, Function, Operator, ExpressionPart]]
class Edge(BaseModel):
source_node_id: str
target_node_id: str
relationship_name: Literal["called in", "stored in", "defined in", "returned by", "instantiated in", "uses", "updates"]
class SourceCodeGraph(BaseModel):
id: str id: str
name: str name: str
description: str description: str
language: str language: str
nodes: List[Union[ nodes: List[Union[
Class, Class,
ClassInstance,
Function, Function,
FunctionCall,
Variable, Variable,
Operator, Operator,
Expression, Expression,
ClassInstance,
]] ]]
edges: List[Edge]
Class.model_rebuild()
ClassInstance.model_rebuild()
Expression.model_rebuild()

View file

@ -1,6 +1,6 @@
""" This module contains utility functions for the cognee. """ """ This module contains utility functions for the cognee. """
import os import os
import datetime from datetime import datetime, timezone
import graphistry import graphistry
import networkx as nx import networkx as nx
import numpy as np import numpy as np
@ -45,7 +45,7 @@ def send_telemetry(event_name: str, user_id, additional_properties: dict = {}):
host = "https://eu.i.posthog.com" host = "https://eu.i.posthog.com"
) )
current_time = datetime.datetime.now() current_time = datetime.now(timezone.utc)
properties = { properties = {
"time": current_time.strftime("%m/%d/%Y"), "time": current_time.strftime("%m/%d/%Y"),
"user_id": user_id, "user_id": user_id,
@ -110,30 +110,36 @@ async def register_graphistry():
graphistry.register(api = 3, username = config.graphistry_username, password = config.graphistry_password) graphistry.register(api = 3, username = config.graphistry_username, password = config.graphistry_password)
def prepare_edges(graph): def prepare_edges(graph, source, target, edge_key):
return nx.to_pandas_edgelist(graph) edge_list = [{
source: str(edge[0]),
target: str(edge[1]),
edge_key: str(edge[2]),
} for edge in graph.edges(keys = True, data = True)]
return pd.DataFrame(edge_list)
def prepare_nodes(graph, include_size=False): def prepare_nodes(graph, include_size=False):
nodes_data = [] nodes_data = []
for node in graph.nodes: for node in graph.nodes:
node_info = graph.nodes[node] node_info = graph.nodes[node]
description = node_info.get("layer_description", {}).get("layer", "Default Layer") if isinstance(
node_info.get("layer_description"), dict) else node_info.get("layer_description", "Default Layer")
# description = node_info['layer_description']['layer'] if isinstance(node_info.get('layer_description'), dict) and 'layer' in node_info['layer_description'] else node_info.get('layer_description', node)
# if isinstance(node_info.get('layer_description'), dict) and 'layer' in node_info.get('layer_description'):
# description = node_info['layer_description']['layer']
# # Use 'layer_description' directly if it's not a dictionary, otherwise default to node ID
# else:
# description = node_info.get('layer_description', node)
node_data = {"id": node, "layer_description": description} if not node_info:
continue
node_data = {
"id": str(node),
"name": node_info["name"] if "name" in node_info else str(node),
}
if include_size: if include_size:
default_size = 10 # Default node size default_size = 10 # Default node size
larger_size = 20 # Size for nodes with specific keywords in their ID larger_size = 20 # Size for nodes with specific keywords in their ID
keywords = ["DOCUMENT", "User", "LAYER"] keywords = ["DOCUMENT", "User"]
node_size = larger_size if any(keyword in str(node) for keyword in keywords) else default_size node_size = larger_size if any(keyword in str(node) for keyword in keywords) else default_size
node_data["size"] = node_size node_data["size"] = node_size
nodes_data.append(node_data) nodes_data.append(node_data)
return pd.DataFrame(nodes_data) return pd.DataFrame(nodes_data)
@ -153,28 +159,28 @@ async def render_graph(graph, include_nodes=False, include_color=False, include_
graph = networkx_graph graph = networkx_graph
edges = prepare_edges(graph) edges = prepare_edges(graph, "source_node", "target_node", "relationship_name")
plotter = graphistry.edges(edges, "source", "target") plotter = graphistry.edges(edges, "source_node", "target_node")
plotter = plotter.bind(edge_label = "relationship_name")
if include_nodes: if include_nodes:
nodes = prepare_nodes(graph, include_size=include_size) nodes = prepare_nodes(graph, include_size = include_size)
plotter = plotter.nodes(nodes, "id") plotter = plotter.nodes(nodes, "id")
if include_size: if include_size:
plotter = plotter.bind(point_size="size") plotter = plotter.bind(point_size = "size")
if include_color: if include_color:
unique_layers = nodes["layer_description"].unique() pass
color_palette = generate_color_palette(unique_layers) # unique_layers = nodes["layer_description"].unique()
plotter = plotter.encode_point_color("layer_description", categorical_mapping=color_palette, # color_palette = generate_color_palette(unique_layers)
default_mapping="silver") # plotter = plotter.encode_point_color("layer_description", categorical_mapping=color_palette,
# default_mapping="silver")
if include_labels: if include_labels:
plotter = plotter.bind(point_label = "layer_description") plotter = plotter.bind(point_label = "name")
# Visualization # Visualization

View file

@ -1,10 +0,0 @@
from .summarization.summarize_text import summarize_text
from .chunk_naive_llm_classifier.chunk_naive_llm_classifier import chunk_naive_llm_classifier
from .chunk_remove_disconnected.chunk_remove_disconnected import chunk_remove_disconnected
from .chunk_update_check.chunk_update_check import chunk_update_check
from .save_chunks_to_store.save_chunks_to_store import save_chunks_to_store
from .source_documents_to_chunks.source_documents_to_chunks import source_documents_to_chunks
from .infer_data_ontology.infer_data_ontology import infer_data_ontology
from .check_permissions_on_documents.check_permissions_on_documents import check_permissions_on_documents
from .classify_documents.classify_documents import classify_documents
from .graph.chunks_into_graph import chunks_into_graph

View file

@ -5,7 +5,7 @@ from pydantic import BaseModel
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine, DataPoint from cognee.infrastructure.databases.vector import get_vector_engine, DataPoint
from cognee.modules.data.extraction.extract_categories import extract_categories from cognee.modules.data.extraction.extract_categories import extract_categories
from cognee.modules.chunking import DocumentChunk from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
async def chunk_naive_llm_classifier(data_chunks: list[DocumentChunk], classification_model: Type[BaseModel]): async def chunk_naive_llm_classifier(data_chunks: list[DocumentChunk], classification_model: Type[BaseModel]):
@ -65,7 +65,7 @@ async def chunk_naive_llm_classifier(data_chunks: list[DocumentChunk], classific
"chunk_id": str(data_chunk.chunk_id), "chunk_id": str(data_chunk.chunk_id),
"document_id": str(data_chunk.document_id), "document_id": str(data_chunk.document_id),
}), }),
embed_field="text", index_fields=["text"],
) )
) )
@ -104,7 +104,7 @@ async def chunk_naive_llm_classifier(data_chunks: list[DocumentChunk], classific
"chunk_id": str(data_chunk.chunk_id), "chunk_id": str(data_chunk.chunk_id),
"document_id": str(data_chunk.document_id), "document_id": str(data_chunk.document_id),
}), }),
embed_field="text", index_fields=["text"],
) )
) )

View file

@ -1,39 +0,0 @@
import logging
from cognee.base_config import get_base_config
BaseConfig = get_base_config()
async def translate_text(data, source_language: str = 'sr', target_language: str = 'en', region_name = 'eu-west-1'):
    """
    Translate text from source language to target language using AWS Translate.

    Parameters:
    data (str): The text to be translated.
    source_language (str): The source language code (e.g., 'sr' for Serbian). ISO 639-2 Code https://www.loc.gov/standards/iso639-2/php/code_list.php
    target_language (str): The target language code (e.g., 'en' for English). ISO 639-2 Code https://www.loc.gov/standards/iso639-2/php/code_list.php
    region_name (str): AWS region name.

    Yields:
    str: Translated text or an error message.
    """
    if not data:
        yield "No text provided for translation."
        return  # bug fix: previously fell through and still called AWS with empty text

    if not source_language or not target_language:
        yield "Both source and target language codes are required."
        return  # bug fix: previously fell through despite the invalid arguments

    # Imported lazily, and after validation, so the error paths above do not
    # require boto3 to be installed.
    import boto3
    from botocore.exceptions import BotoCoreError, ClientError

    try:
        translate = boto3.client(service_name = 'translate', region_name = region_name, use_ssl = True)
        result = translate.translate_text(Text = data, SourceLanguageCode = source_language, TargetLanguageCode = target_language)
        yield result.get('TranslatedText', 'No translation found.')
    except BotoCoreError as e:
        logging.info(f"BotoCoreError occurred: {e}")
        yield "Error with AWS Translate service configuration or request."
    except ClientError as e:
        logging.info(f"ClientError occurred: {e}")
        yield "Error with AWS client or network issue."

View file

@ -1,26 +0,0 @@
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.chunking import DocumentChunk
async def chunk_update_check(data_chunks: list[DocumentChunk], collection_name: str) -> list[DocumentChunk]:
    """Return only the chunks that are new or whose text changed versus the stored collection."""
    vector_engine = get_vector_engine()

    if not await vector_engine.has_collection(collection_name):
        # If collection doesn't exist, all data_chunks are new
        return data_chunks

    existing_chunks = await vector_engine.retrieve(
        collection_name,
        [str(chunk.chunk_id) for chunk in data_chunks],
    )
    existing_chunks_map = {str(chunk.id): chunk.payload for chunk in existing_chunks}

    affected_data_chunks = []

    for chunk in data_chunks:
        # Bug fix: the map keys are strings (built with str(...) above), but the
        # lookup used the raw chunk_id — a UUID would never match and every
        # chunk would be treated as new.
        chunk_id = str(chunk.chunk_id)
        if chunk_id not in existing_chunks_map or \
            chunk.text != existing_chunks_map[chunk_id]["text"]:
            affected_data_chunks.append(chunk)

    return affected_data_chunks

View file

@ -2,3 +2,4 @@ from .query_chunks import query_chunks
from .chunk_by_word import chunk_by_word from .chunk_by_word import chunk_by_word
from .chunk_by_sentence import chunk_by_sentence from .chunk_by_sentence import chunk_by_sentence
from .chunk_by_paragraph import chunk_by_paragraph from .chunk_by_paragraph import chunk_by_paragraph
from .remove_disconnected_chunks import remove_disconnected_chunks

View file

@ -1,4 +1,4 @@
from cognee.tasks.chunking import chunk_by_paragraph from cognee.tasks.chunks import chunk_by_paragraph
if __name__ == "__main__": if __name__ == "__main__":
def test_chunking_on_whole_text(): def test_chunking_on_whole_text():

View file

@ -10,7 +10,7 @@ async def query_chunks(query: str) -> list[dict]:
""" """
vector_engine = get_vector_engine() vector_engine = get_vector_engine()
found_chunks = await vector_engine.search("chunks", query, limit = 5) found_chunks = await vector_engine.search("DocumentChunk_text", query, limit = 5)
chunks = [result.payload for result in found_chunks] chunks = [result.payload for result in found_chunks]

View file

@ -1,7 +1,7 @@
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.chunking import DocumentChunk from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
async def chunk_remove_disconnected(data_chunks: list[DocumentChunk]) -> list[DocumentChunk]: async def remove_disconnected_chunks(data_chunks: list[DocumentChunk]) -> list[DocumentChunk]:
graph_engine = await get_graph_engine() graph_engine = await get_graph_engine()
document_ids = set((data_chunk.document_id for data_chunk in data_chunks)) document_ids = set((data_chunk.document_id for data_chunk in data_chunks))

View file

@ -1,13 +0,0 @@
from cognee.modules.data.models import Data
from cognee.modules.data.processing.document_types import Document, PdfDocument, AudioDocument, ImageDocument, TextDocument
def classify_documents(data_documents: list[Data]) -> list[Document]:
    """Wrap each data record in the Document subclass matching its extension.

    Extensions "pdf", "audio" and "image" select their dedicated document
    type; every other extension falls back to TextDocument.
    """
    document_class_by_extension = {
        "pdf": PdfDocument,
        "audio": AudioDocument,
        "image": ImageDocument,
    }

    documents = []

    for data_item in data_documents:
        document_class = document_class_by_extension.get(data_item.extension, TextDocument)
        documents.append(
            document_class(
                id = data_item.id,
                title = f"{data_item.name}.{data_item.extension}",
                raw_data_location = data_item.raw_data_location,
            )
        )

    return documents

View file

@ -0,0 +1,3 @@
from .classify_documents import classify_documents
from .extract_chunks_from_documents import extract_chunks_from_documents
from .check_permissions_on_documents import check_permissions_on_documents

View file

@ -0,0 +1,13 @@
from cognee.modules.data.models import Data
from cognee.modules.data.processing.document_types import Document, PdfDocument, AudioDocument, ImageDocument, TextDocument
def classify_documents(data_documents: list[Data]) -> list[Document]:
    """Convert raw data records into typed Document instances.

    The document class is chosen by the record's extension: "pdf", "audio"
    and "image" get their dedicated types; anything else becomes a
    TextDocument.
    """
    def _document_class_for(extension):
        # Unrecognized extensions default to plain text documents.
        if extension == "pdf":
            return PdfDocument
        if extension == "audio":
            return AudioDocument
        if extension == "image":
            return ImageDocument
        return TextDocument

    return [
        _document_class_for(data_item.extension)(
            id = data_item.id,
            name = f"{data_item.name}.{data_item.extension}",
            raw_data_location = data_item.raw_data_location,
        )
        for data_item in data_documents
    ]

View file

@ -0,0 +1,7 @@
from cognee.modules.data.processing.document_types.Document import Document
async def extract_chunks_from_documents(documents: list[Document], chunk_size: int = 1024):
    """Asynchronously yield every chunk read from each document, in order.

    Documents are consumed one at a time; each document's reader is drained
    completely before moving on to the next.
    """
    for current_document in documents:
        chunk_stream = current_document.read(chunk_size = chunk_size)
        for chunk in chunk_stream:
            yield chunk

View file

@ -1,2 +1,3 @@
from .chunks_into_graph import chunks_into_graph from .extract_graph_from_data import extract_graph_from_data
from .extract_graph_from_code import extract_graph_from_code
from .query_graph_connections import query_graph_connections from .query_graph_connections import query_graph_connections

View file

@ -1,213 +0,0 @@
import json
import asyncio
from uuid import uuid5, NAMESPACE_OID
from datetime import datetime, timezone
from typing import Type
from pydantic import BaseModel
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import DataPoint, get_vector_engine
from cognee.modules.data.extraction.knowledge_graph.extract_content_graph import extract_content_graph
from cognee.modules.chunking import DocumentChunk
from cognee.modules.graph.utils import generate_node_id, generate_node_name
class EntityNode(BaseModel):
    """Vector-store payload schema for entity and entity-type nodes."""
    # Deterministic node id (produced via generate_node_id in the caller).
    uuid: str
    # Display name; this is the field that gets embedded (embed_field = "name").
    name: str
    type: str
    description: str
    # Populated with "%Y-%m-%d %H:%M:%S" strings by the caller; pydantic
    # coerces them to datetime on validation.
    created_at: datetime
    updated_at: datetime
async def chunks_into_graph(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel], collection_name: str):
    """Extract a knowledge graph from each chunk's text and persist it.

    For every chunk an LLM content graph (shaped by ``graph_model``) is
    extracted; new entity and entity-type nodes are batched into the graph
    store and, as ``EntityNode`` payloads, into the ``collection_name``
    vector collection. Edges chunk→entity, chunk→entity-type,
    entity→entity-type, and the extracted relations are added if missing.

    Parameters:
        data_chunks: chunks whose text is mined for entities and relations.
        graph_model: pydantic model describing the structure to extract.
        collection_name: vector collection for entity payloads.

    Returns:
        The input ``data_chunks``, unchanged.
    """
    # Extract all chunk graphs concurrently.
    chunk_graphs = await asyncio.gather(
        *[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
    )
    vector_engine = get_vector_engine()
    graph_engine = await get_graph_engine()
    has_collection = await vector_engine.has_collection(collection_name)
    if not has_collection:
        await vector_engine.create_collection(collection_name, payload_schema = EntityNode)
    # First pass: gather candidate edges so their existence can be checked in
    # one batched graph query instead of per edge.
    processed_nodes = {}
    type_node_edges = []
    entity_node_edges = []
    type_entity_edges = []
    for (chunk_index, chunk) in enumerate(data_chunks):
        chunk_graph = chunk_graphs[chunk_index]
        for node in chunk_graph.nodes:
            type_node_id = generate_node_id(node.type)
            entity_node_id = generate_node_id(node.id)
            if type_node_id not in processed_nodes:
                type_node_edges.append((str(chunk.chunk_id), type_node_id, "contains_entity_type"))
                processed_nodes[type_node_id] = True
            if entity_node_id not in processed_nodes:
                entity_node_edges.append((str(chunk.chunk_id), entity_node_id, "contains_entity"))
                type_entity_edges.append((entity_node_id, type_node_id, "is_entity_type"))
                processed_nodes[entity_node_id] = True
        # NOTE(review): reassigned on every iteration, so only the LAST chunk's
        # extracted edges take part in the existence check below — looks
        # unintended; confirm.
        graph_node_edges = [
            (edge.target_node_id, edge.source_node_id, edge.relationship_name) \
            for edge in chunk_graph.edges
        ]
    existing_edges = await graph_engine.has_edges([
        *type_node_edges,
        *entity_node_edges,
        *type_entity_edges,
        *graph_node_edges,
    ])
    # Index existing edges by concatenated source+target+relationship; the
    # source of any existing edge is also treated as an existing node.
    existing_edges_map = {}
    existing_nodes_map = {}
    for edge in existing_edges:
        existing_edges_map[edge[0] + edge[1] + edge[2]] = True
        existing_nodes_map[edge[0]] = True
    # Second pass: build the node / edge / vector-data-point batches, skipping
    # anything the maps say already exists.
    graph_nodes = []
    graph_edges = []
    data_points = []
    for (chunk_index, chunk) in enumerate(data_chunks):
        graph = chunk_graphs[chunk_index]
        if graph is None:
            continue
        for node in graph.nodes:
            node_id = generate_node_id(node.id)
            node_name = generate_node_name(node.name)
            type_node_id = generate_node_id(node.type)
            type_node_name = generate_node_name(node.type)
            if node_id not in existing_nodes_map:
                node_data = dict(
                    uuid = node_id,
                    name = node_name,
                    # NOTE(review): `type` is set to the node's *name*, not
                    # `type_node_name` — possibly intentional, but confirm.
                    type = node_name,
                    description = node.description,
                    created_at = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
                    updated_at = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
                )
                graph_nodes.append((
                    node_id,
                    dict(
                        **node_data,
                        properties = json.dumps(node.properties),
                    )
                ))
                # Vector id is derived deterministically from the node id.
                data_points.append(DataPoint[EntityNode](
                    id = str(uuid5(NAMESPACE_OID, node_id)),
                    payload = node_data,
                    embed_field = "name",
                ))
                existing_nodes_map[node_id] = True
            edge_key = str(chunk.chunk_id) + node_id + "contains_entity"
            if edge_key not in existing_edges_map:
                graph_edges.append((
                    str(chunk.chunk_id),
                    node_id,
                    "contains_entity",
                    dict(
                        relationship_name = "contains_entity",
                        source_node_id = str(chunk.chunk_id),
                        target_node_id = node_id,
                    ),
                ))
                # Add relationship between entity type and entity itself: "Jake is Person"
                graph_edges.append((
                    node_id,
                    type_node_id,
                    "is_entity_type",
                    dict(
                        relationship_name = "is_entity_type",
                        source_node_id = type_node_id,
                        target_node_id = node_id,
                    ),
                ))
                existing_edges_map[edge_key] = True
            if type_node_id not in existing_nodes_map:
                type_node_data = dict(
                    uuid = type_node_id,
                    name = type_node_name,
                    type = type_node_id,
                    description = type_node_name,
                    created_at = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
                    updated_at = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
                )
                graph_nodes.append((type_node_id, dict(
                    **type_node_data,
                    properties = json.dumps(node.properties)
                )))
                data_points.append(DataPoint[EntityNode](
                    id = str(uuid5(NAMESPACE_OID, type_node_id)),
                    payload = type_node_data,
                    embed_field = "name",
                ))
                existing_nodes_map[type_node_id] = True
            edge_key = str(chunk.chunk_id) + type_node_id + "contains_entity_type"
            if edge_key not in existing_edges_map:
                graph_edges.append((
                    str(chunk.chunk_id),
                    type_node_id,
                    "contains_entity_type",
                    dict(
                        relationship_name = "contains_entity_type",
                        source_node_id = str(chunk.chunk_id),
                        target_node_id = type_node_id,
                    ),
                ))
                existing_edges_map[edge_key] = True
        # Add relationship that came from graphs.
        for edge in graph.edges:
            source_node_id = generate_node_id(edge.source_node_id)
            target_node_id = generate_node_id(edge.target_node_id)
            relationship_name = generate_node_name(edge.relationship_name)
            edge_key = source_node_id + target_node_id + relationship_name
            if edge_key not in existing_edges_map:
                graph_edges.append((
                    generate_node_id(edge.source_node_id),
                    generate_node_id(edge.target_node_id),
                    edge.relationship_name,
                    dict(
                        relationship_name = generate_node_name(edge.relationship_name),
                        source_node_id = generate_node_id(edge.source_node_id),
                        target_node_id = generate_node_id(edge.target_node_id),
                        properties = json.dumps(edge.properties),
                    ),
                ))
                existing_edges_map[edge_key] = True
    # Flush the batches; skip empty ones to avoid needless store calls.
    if len(data_points) > 0:
        await vector_engine.create_data_points(collection_name, data_points)
    if len(graph_nodes) > 0:
        await graph_engine.add_nodes(graph_nodes)
    if len(graph_edges) > 0:
        await graph_engine.add_edges(graph_edges)
    return data_chunks

Some files were not shown because too many files have changed in this diff Show more