Merge branch 'main' into COG-575-remove-graph-overwrite-on-error

This commit is contained in:
Vasilije 2024-11-12 10:18:09 +01:00 committed by GitHub
commit be792a7ba6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
127 changed files with 3122 additions and 2951 deletions

View file

@ -5,7 +5,7 @@ on:
pull_request:
branches:
- main
types: [labeled]
types: [labeled, synchronize]
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@ -22,7 +22,7 @@ jobs:
run_neo4j_integration_test:
name: test
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest
defaults:

View file

@ -5,7 +5,7 @@ on:
pull_request:
branches:
- main
types: [labeled]
types: [labeled, synchronize]
concurrency:
@ -23,7 +23,7 @@ jobs:
run_notebook_test:
name: test
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest
defaults:
run:

View file

@ -5,7 +5,7 @@ on:
pull_request:
branches:
- main
types: [labeled]
types: [labeled, synchronize]
concurrency:
@ -23,7 +23,7 @@ jobs:
run_pgvector_integration_test:
name: test
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest
defaults:
run:

View file

@ -5,10 +5,10 @@ on:
pull_request:
branches:
- main
types: [labeled]
types: [labeled, synchronize]
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} | ${{ github.event.label.name == 'run-checks' }}
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
env:
@ -22,7 +22,7 @@ jobs:
run_common:
name: test
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true'
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest
strategy:
fail-fast: false

View file

@ -5,10 +5,10 @@ on:
pull_request:
branches:
- main
types: [labeled]
types: [labeled, synchronize]
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} | ${{ github.event.label.name == 'run-checks' }}
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
env:
@ -22,7 +22,7 @@ jobs:
run_common:
name: test
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true'
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest
strategy:
fail-fast: false

View file

@ -5,10 +5,10 @@ on:
pull_request:
branches:
- main
types: [labeled]
types: [labeled, synchronize]
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} | ${{ github.event.label.name == 'run-checks' }}
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
env:
@ -22,7 +22,7 @@ jobs:
run_common:
name: test
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true'
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest
strategy:
fail-fast: false

View file

@ -5,7 +5,7 @@ on:
pull_request:
branches:
- main
types: [labeled]
types: [labeled, synchronize]
concurrency:
@ -23,7 +23,7 @@ jobs:
run_qdrant_integration_test:
name: test
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest
defaults:

View file

@ -5,7 +5,7 @@ on:
pull_request:
branches:
- main
types: [labeled]
types: [labeled, synchronize]
concurrency:
@ -23,7 +23,7 @@ jobs:
run_weaviate_integration_test:
name: test
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest
defaults:

View file

@ -109,24 +109,34 @@ import asyncio
from cognee.api.v1.search import SearchType
async def main():
await cognee.prune.prune_data() # Reset cognee data
await cognee.prune.prune_system(metadata=True) # Reset cognee system state
# Reset cognee data
await cognee.prune.prune_data()
# Reset cognee system state
await cognee.prune.prune_system(metadata=True)
text = """
Natural language processing (NLP) is an interdisciplinary
subfield of computer science and information retrieval.
"""
await cognee.add(text) # Add text to cognee
await cognee.cognify() # Use LLMs and cognee to create knowledge graph
# Add text to cognee
await cognee.add(text)
search_results = await cognee.search( # Search cognee for insights
# Use LLMs and cognee to create knowledge graph
await cognee.cognify()
# Search cognee for insights
search_results = await cognee.search(
SearchType.INSIGHTS,
{'query': 'Tell me about NLP'}
"Tell me about NLP",
)
for result_text in search_results: # Display results
# Display results
for result_text in search_results:
print(result_text)
# natural_language_processing is_a field
# natural_language_processing is_subfield_of computer_science
# natural_language_processing is_subfield_of information_retrieval
asyncio.run(main())
```

View file

@ -0,0 +1,110 @@
import asyncio
import logging
from typing import Union
from cognee.shared.SourceCodeGraph import SourceCodeGraph
from cognee.shared.utils import send_telemetry
from cognee.modules.data.models import Dataset, Data
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
from cognee.modules.data.methods import get_datasets, get_datasets_by_name
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.pipelines import run_tasks
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
from cognee.modules.pipelines.models import PipelineRunStatus
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
from cognee.modules.pipelines.operations.log_pipeline_status import log_pipeline_status
from cognee.tasks.documents import classify_documents, check_permissions_on_documents, extract_chunks_from_documents
from cognee.tasks.graph import extract_graph_from_code
from cognee.tasks.storage import add_data_points
logger = logging.getLogger("code_graph_pipeline")
update_status_lock = asyncio.Lock()
class PermissionDeniedException(Exception):
    """Raised when a user lacks permission to operate on a dataset."""

    def __init__(self, message: str):
        # Pass the message to Exception so str(exc) renders it, and keep a
        # named attribute for callers that read it directly.
        super().__init__(message)
        self.message = message
async def code_graph_pipeline(datasets: Union[str, list[str]] = None, user: User = None):
    """Run the code graph pipeline over the given datasets.

    Parameters:
        datasets: dataset names (or already-resolved dataset objects); when
            None or empty, all of the user's existing datasets are processed.
        user: the requesting user; defaults to the system default user.

    Returns:
        The gathered results of ``run_pipeline`` for every requested dataset
        that actually exists for this user (unknown names are silently
        skipped, matching the original behavior).
    """
    if user is None:
        user = await get_default_user()

    existing_datasets = await get_datasets(user.id)

    if datasets is None or len(datasets) == 0:
        # If no datasets are provided, cognify all existing datasets.
        datasets = existing_datasets

    if len(datasets) == 0:
        # Fix: the user has no datasets at all — previously `datasets[0]`
        # below raised IndexError in this case.
        return []

    # Fix: isinstance instead of `type(x) == str` (handles str subclasses,
    # idiomatic per PEP 8).
    if isinstance(datasets[0], str):
        datasets = await get_datasets_by_name(datasets, user.id)

    # A set is the natural structure here (was a dict mapping names to True).
    existing_dataset_names = {
        generate_dataset_name(dataset.name) for dataset in existing_datasets
    }

    awaitables = [
        run_pipeline(dataset, user)
        for dataset in datasets
        if generate_dataset_name(dataset.name) in existing_dataset_names
    ]

    return await asyncio.gather(*awaitables)
async def run_pipeline(dataset: Dataset, user: User):
    """Execute the code graph pipeline for one dataset.

    Skips the dataset if another run already marked it as processing;
    otherwise logs start/completed/errored status transitions around the
    task pipeline. Re-raises any pipeline error after logging it.
    """
    data_documents: list[Data] = await get_dataset_data(dataset_id = dataset.id)

    # Stringified ids are recorded in the pipeline-status payloads below.
    document_ids_str = [str(document.id) for document in data_documents]

    dataset_id = dataset.id
    dataset_name = generate_dataset_name(dataset.name)

    send_telemetry("code_graph_pipeline EXECUTION STARTED", user.id)

    # Hold the module-level lock across check-then-set so two concurrent
    # runs cannot both claim the same dataset.
    async with update_status_lock:
        task_status = await get_pipeline_status([dataset_id])

        if dataset_id in task_status and task_status[dataset_id] == PipelineRunStatus.DATASET_PROCESSING_STARTED:
            logger.info("Dataset %s is already being processed.", dataset_name)
            return

        await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_STARTED, {
            "dataset_name": dataset_name,
            "files": document_ids_str,
        })

    try:
        tasks = [
            Task(classify_documents),
            Task(check_permissions_on_documents, user = user, permissions = ["write"]),
            Task(extract_chunks_from_documents), # Extract text chunks based on the document type.
            Task(add_data_points, task_config = { "batch_size": 10 }),
            Task(extract_graph_from_code, graph_model = SourceCodeGraph, task_config = { "batch_size": 10 }), # Generate knowledge graphs from the document chunks.
        ]

        pipeline = run_tasks(tasks, data_documents, "code_graph_pipeline")

        # Drain the pipeline; results are printed as they are produced.
        async for result in pipeline:
            print(result)

        send_telemetry("code_graph_pipeline EXECUTION COMPLETED", user.id)

        await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_COMPLETED, {
            "dataset_name": dataset_name,
            "files": document_ids_str,
        })
    except Exception as error:
        send_telemetry("code_graph_pipeline EXECUTION ERRORED", user.id)

        await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_ERRORED, {
            "dataset_name": dataset_name,
            "files": document_ids_str,
        })

        raise error
def generate_dataset_name(dataset_name: str) -> str:
    """Normalize a dataset name by mapping dots and spaces to underscores."""
    translation = str.maketrans({".": "_", " ": "_"})
    return dataset_name.translate(translation)

View file

@ -9,21 +9,15 @@ from cognee.modules.data.models import Dataset, Data
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
from cognee.modules.data.methods import get_datasets, get_datasets_by_name
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.pipelines import run_tasks, run_tasks_parallel
from cognee.modules.pipelines import run_tasks
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
from cognee.modules.pipelines.models import PipelineRunStatus
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
from cognee.modules.pipelines.operations.log_pipeline_status import log_pipeline_status
from cognee.tasks import chunk_naive_llm_classifier, \
chunk_remove_disconnected, \
infer_data_ontology, \
save_chunks_to_store, \
chunk_update_check, \
chunks_into_graph, \
source_documents_to_chunks, \
check_permissions_on_documents, \
classify_documents
from cognee.tasks.documents import classify_documents, check_permissions_on_documents, extract_chunks_from_documents
from cognee.tasks.graph import extract_graph_from_data
from cognee.tasks.storage import add_data_points
from cognee.tasks.summarization import summarize_text
logger = logging.getLogger("cognify.v2")
@ -87,31 +81,17 @@ async def run_cognify_pipeline(dataset: Dataset, user: User):
try:
cognee_config = get_cognify_config()
root_node_id = None
tasks = [
Task(classify_documents),
Task(check_permissions_on_documents, user = user, permissions = ["write"]),
Task(infer_data_ontology, root_node_id = root_node_id, ontology_model = KnowledgeGraph),
Task(source_documents_to_chunks, parent_node_id = root_node_id), # Classify documents and save them as a nodes in graph db, extract text chunks based on the document type
Task(chunks_into_graph, graph_model = KnowledgeGraph, collection_name = "entities", task_config = { "batch_size": 10 }), # Generate knowledge graphs from the document chunks and attach it to chunk nodes
Task(chunk_update_check, collection_name = "chunks"), # Find all affected chunks, so we don't process unchanged chunks
Task(extract_chunks_from_documents), # Extract text chunks based on the document type.
Task(add_data_points, task_config = { "batch_size": 10 }),
Task(extract_graph_from_data, graph_model = KnowledgeGraph, task_config = { "batch_size": 10 }), # Generate knowledge graphs from the document chunks.
Task(
save_chunks_to_store,
collection_name = "chunks",
), # Save the document chunks in vector db and as nodes in graph db (connected to the document node and between each other)
run_tasks_parallel([
Task(
summarize_text,
summarization_model = cognee_config.summarization_model,
collection_name = "summaries",
),
Task(
chunk_naive_llm_classifier,
classification_model = cognee_config.classification_model,
),
]),
Task(chunk_remove_disconnected), # Remove the obsolete document chunks.
summarize_text,
summarization_model = cognee_config.summarization_model,
task_config = { "batch_size": 10 }
),
]
pipeline = run_tasks(tasks, data_documents, "cognify_pipeline")

View file

@ -5,7 +5,7 @@ from cognee.shared.utils import send_telemetry
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.permissions.methods import get_document_ids_for_user
from cognee.tasks.chunking import query_chunks
from cognee.tasks.chunks import query_chunks
from cognee.tasks.graph import query_graph_connections
from cognee.tasks.summarization import query_summaries

View file

@ -1,198 +0,0 @@
""" FalcorDB Adapter for Graph Database"""
import json
import logging
from typing import Optional, Any, List, Dict
from contextlib import asynccontextmanager
from falkordb.asyncio import FalkorDB
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
logger = logging.getLogger("FalcorDBAdapter")
class FalcorDBAdapter(GraphDBInterface):
    """Graph database adapter backed by FalkorDB.

    Nodes are created with their sanitized id as both the node label and the
    ``id`` property; ``:`` in ids is replaced with ``_`` because ``:`` is the
    label separator in the Cypher dialect.
    """

    def __init__(
        self,
        graph_database_url: str,
        graph_database_username: str,
        graph_database_password: str,
        graph_database_port: int,
        driver: Optional[Any] = None,
        graph_name: str = "DefaultGraph",
    ):
        # NOTE(review): username/password/driver are accepted for interface
        # compatibility but are not forwarded to the FalkorDB client.
        self.driver = FalkorDB(
            host = graph_database_url,
            port = graph_database_port,
        )
        self.graph_name = graph_name

    async def query(
        self,
        query: str,
        params: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Run a Cypher query against the selected graph and return its result set.

        Fix: ``params`` was accepted but never passed to the driver, so every
        parameterized query in this class (``$node_id`` etc.) could not
        resolve its placeholders.
        """
        try:
            selected_graph = self.driver.select_graph(self.graph_name)
            result = await selected_graph.query(query, params)
            return result.result_set
        except Exception as error:
            logger.error("Falkor query error: %s", error, exc_info = True)
            raise error

    async def graph(self):
        """Return the underlying FalkorDB driver."""
        return self.driver

    async def add_node(self, node_id: str, node_properties: Dict[str, Any] = None):
        """Create (or match) a node labeled with the sanitized id and set its properties."""
        node_id = node_id.replace(":", "_")
        serialized_properties = self.serialize_properties(node_properties)

        # Default the display name to the id so every node is labeled.
        if "name" not in serialized_properties:
            serialized_properties["name"] = node_id

        query = f"""MERGE (node:`{node_id}` {{id: $node_id}})
            ON CREATE SET node += $properties
            RETURN ID(node) AS internal_id, node.id AS nodeId"""

        params = {
            "node_id": node_id,
            "properties": serialized_properties,
        }

        return await self.query(query, params)

    async def add_nodes(self, nodes: list[tuple[str, dict[str, Any]]]) -> None:
        """Add each (node_id, properties) pair one by one via add_node."""
        for node_id, node_properties in nodes:
            await self.add_node(
                node_id = node_id.replace(":", "_"),
                node_properties = node_properties,
            )

    async def extract_node_description(self, node_id: str):
        """Return id/layer_id/description dicts for neighbours of the given node."""
        query = """MATCH (n)-[r]->(m)
            WHERE n.id = $node_id
            AND NOT m.id CONTAINS 'DefaultGraphModel'
            RETURN m
        """
        result = await self.query(query, dict(node_id = node_id))

        descriptions = []
        for node in result:
            # 'm' is the neighbour node returned by the query above.
            attributes = node.get("m", {})

            # Only neighbours carrying all three required attributes count.
            if all(key in attributes for key in ["id", "layer_id", "description"]):
                descriptions.append({
                    "id": attributes["id"],
                    "layer_id": attributes["layer_id"],
                    "description": attributes["description"],
                })

        return descriptions

    async def get_layer_nodes(self):
        """Return every node that has a layer_id property."""
        query = """MATCH (node) WHERE node.layer_id IS NOT NULL
            RETURN node"""
        return [result["node"] for result in (await self.query(query))]

    async def extract_node(self, node_id: str):
        """Return the node with the given id, or None if it does not exist."""
        # Fix: the coroutine was not awaited, so the result was a coroutine
        # object and len(results) raised TypeError.
        results = await self.extract_nodes([node_id])
        return results[0] if len(results) > 0 else None

    async def extract_nodes(self, node_ids: List[str]):
        """Return all nodes whose id is in node_ids."""
        query = """
        UNWIND $node_ids AS id
        MATCH (node {id: id})
        RETURN node"""

        params = {
            "node_ids": node_ids,
        }

        return await self.query(query, params)

    async def delete_node(self, node_id: str):
        """Detach-delete the node with the given id."""
        # Fix: previously read the `id` builtin instead of `node_id`, and the
        # query deleted an undefined variable `n` instead of `node`.
        node_id = node_id.replace(":", "_")
        query = f"MATCH (node:`{node_id}` {{id: $node_id}}) DETACH DELETE node"
        params = { "node_id": node_id }
        return await self.query(query, params)

    async def add_edge(self, from_node: str, to_node: str, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}):
        """Merge a relationship between two existing nodes and set its properties."""
        serialized_properties = self.serialize_properties(edge_properties)
        from_node = from_node.replace(":", "_")
        to_node = to_node.replace(":", "_")

        query = f"""MATCH (from_node:`{from_node}` {{id: $from_node}}), (to_node:`{to_node}` {{id: $to_node}})
            MERGE (from_node)-[r:`{relationship_name}`]->(to_node)
            SET r += $properties
            RETURN r"""

        params = {
            "from_node": from_node,
            "to_node": to_node,
            "properties": serialized_properties,
        }

        return await self.query(query, params)

    async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
        """Add each (from, to, relationship, properties) edge one by one via add_edge."""
        for from_node, to_node, relationship_name, edge_properties in edges:
            await self.add_edge(
                from_node = from_node.replace(":", "_"),
                to_node = to_node.replace(":", "_"),
                relationship_name = relationship_name,
                edge_properties = edge_properties,
            )

    async def filter_nodes(self, search_criteria):
        """Return nodes whose id contains search_criteria."""
        # Fix: search_criteria was interpolated into the Cypher text
        # (injection risk); it is now passed as a bound parameter.
        query = """MATCH (node)
            WHERE node.id CONTAINS $search_criteria
            RETURN node"""
        return await self.query(query, dict(search_criteria = search_criteria))

    async def delete_graph(self):
        """Detach-delete every node in the graph."""
        query = """MATCH (node)
            DETACH DELETE node;"""
        return await self.query(query)

    def serialize_properties(self, properties = dict()):
        """JSON-encode dict/list values so they can be stored as node/edge properties."""
        return {
            property_key: json.dumps(property_value)
            if isinstance(property_value, (dict, list))
            else property_value
            for property_key, property_value in properties.items()
        }

View file

@ -2,7 +2,6 @@
from .config import get_graph_config
from .graph_db_interface import GraphDBInterface
from .networkx.adapter import NetworkXAdapter
async def get_graph_engine() -> GraphDBInterface :
@ -21,19 +20,19 @@ async def get_graph_engine() -> GraphDBInterface :
except:
pass
elif config.graph_database_provider == "falkorb":
try:
from .falkordb.adapter import FalcorDBAdapter
elif config.graph_database_provider == "falkordb":
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
from cognee.infrastructure.databases.hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
return FalcorDBAdapter(
graph_database_url = config.graph_database_url,
graph_database_username = config.graph_database_username,
graph_database_password = config.graph_database_password,
graph_database_port = config.graph_database_port
)
except:
pass
embedding_engine = get_embedding_engine()
return FalkorDBAdapter(
database_url = config.graph_database_url,
database_port = config.graph_database_port,
embedding_engine = embedding_engine,
)
from .networkx.adapter import NetworkXAdapter
graph_client = NetworkXAdapter(filename = config.graph_file_path)
if graph_client.graph is None:

View file

@ -3,7 +3,7 @@ from abc import abstractmethod
class GraphDBInterface(Protocol):
@abstractmethod
async def graph(self):
async def query(self, query: str, params: dict):
raise NotImplementedError
@abstractmethod

View file

@ -1,13 +1,14 @@
""" Neo4j Adapter for Graph Database"""
import json
import logging
import asyncio
from textwrap import dedent
from typing import Optional, Any, List, Dict
from contextlib import asynccontextmanager
from uuid import UUID
from neo4j import AsyncSession
from neo4j import AsyncGraphDatabase
from neo4j.exceptions import Neo4jError
from networkx import predecessor
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
logger = logging.getLogger("Neo4jAdapter")
@ -41,17 +42,13 @@ class Neo4jAdapter(GraphDBInterface):
) -> List[Dict[str, Any]]:
try:
async with self.get_session() as session:
result = await session.run(query, parameters=params)
result = await session.run(query, parameters = params)
data = await result.data()
await self.close()
return data
except Neo4jError as error:
logger.error("Neo4j query error: %s", error, exc_info = True)
raise error
async def graph(self):
return await self.get_session()
async def has_node(self, node_id: str) -> bool:
results = self.query(
"""
@ -63,73 +60,40 @@ class Neo4jAdapter(GraphDBInterface):
)
return results[0]["node_exists"] if len(results) > 0 else False
async def add_node(self, node_id: str, node_properties: Dict[str, Any] = None):
node_id = node_id.replace(":", "_")
async def add_node(self, node: DataPoint):
serialized_properties = self.serialize_properties(node.model_dump())
serialized_properties = self.serialize_properties(node_properties)
if "name" not in serialized_properties:
serialized_properties["name"] = node_id
query = f"""MERGE (node:`{node_id}` {{id: $node_id}})
ON CREATE SET node += $properties
RETURN ID(node) AS internal_id, node.id AS nodeId"""
query = dedent("""MERGE (node {id: $node_id})
ON CREATE SET node += $properties, node.updated_at = timestamp()
ON MATCH SET node += $properties, node.updated_at = timestamp()
RETURN ID(node) AS internal_id, node.id AS nodeId""")
params = {
"node_id": node_id,
"node_id": str(node.id),
"properties": serialized_properties,
}
return await self.query(query, params)
async def add_nodes(self, nodes: list[tuple[str, dict[str, Any]]]) -> None:
async def add_nodes(self, nodes: list[DataPoint]) -> None:
query = """
UNWIND $nodes AS node
MERGE (n {id: node.node_id})
ON CREATE SET n += node.properties
ON CREATE SET n += node.properties, n.updated_at = timestamp()
ON MATCH SET n += node.properties, n.updated_at = timestamp()
WITH n, node.node_id AS label
CALL apoc.create.addLabels(n, [label]) YIELD node AS labeledNode
RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId
"""
nodes = [{
"node_id": node_id,
"properties": self.serialize_properties(node_properties),
} for (node_id, node_properties) in nodes]
"node_id": str(node.id),
"properties": self.serialize_properties(node.model_dump()),
} for node in nodes]
results = await self.query(query, dict(nodes = nodes))
return results
async def extract_node_description(self, node_id: str):
query = """MATCH (n)-[r]->(m)
WHERE n.id = $node_id
AND NOT m.id CONTAINS 'DefaultGraphModel'
RETURN m
"""
result = await self.query(query, dict(node_id = node_id))
descriptions = []
for node in result:
# Assuming 'm' is a consistent key in your data structure
attributes = node.get("m", {})
# Ensure all required attributes are present
if all(key in attributes for key in ["id", "layer_id", "description"]):
descriptions.append({
"id": attributes["id"],
"layer_id": attributes["layer_id"],
"description": attributes["description"],
})
return descriptions
async def get_layer_nodes(self):
query = """MATCH (node) WHERE node.layer_id IS NOT NULL
RETURN node"""
return [result["node"] for result in (await self.query(query))]
async def extract_node(self, node_id: str):
results = await self.extract_nodes([node_id])
@ -170,13 +134,20 @@ class Neo4jAdapter(GraphDBInterface):
return await self.query(query, params)
async def has_edge(self, from_node: str, to_node: str, edge_label: str) -> bool:
query = f"""
MATCH (from_node:`{from_node}`)-[relationship:`{edge_label}`]->(to_node:`{to_node}`)
async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool:
query = """
MATCH (from_node)-[relationship]->(to_node)
WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id AND type(relationship) = $edge_label
RETURN COUNT(relationship) > 0 AS edge_exists
"""
edge_exists = await self.query(query)
params = {
"from_node_id": str(from_node),
"to_node_id": str(to_node),
"edge_label": edge_label,
}
edge_exists = await self.query(query, params)
return edge_exists
async def has_edges(self, edges):
@ -190,8 +161,8 @@ class Neo4jAdapter(GraphDBInterface):
try:
params = {
"edges": [{
"from_node": edge[0],
"to_node": edge[1],
"from_node": str(edge[0]),
"to_node": str(edge[1]),
"relationship_name": edge[2],
} for edge in edges],
}
@ -203,21 +174,21 @@ class Neo4jAdapter(GraphDBInterface):
raise error
async def add_edge(self, from_node: str, to_node: str, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}):
async def add_edge(self, from_node: UUID, to_node: UUID, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}):
serialized_properties = self.serialize_properties(edge_properties)
from_node = from_node.replace(":", "_")
to_node = to_node.replace(":", "_")
query = f"""MATCH (from_node:`{from_node}`
{{id: $from_node}}),
(to_node:`{to_node}` {{id: $to_node}})
MERGE (from_node)-[r:`{relationship_name}`]->(to_node)
SET r += $properties
RETURN r"""
query = dedent("""MATCH (from_node {id: $from_node}),
(to_node {id: $to_node})
MERGE (from_node)-[r]->(to_node)
ON CREATE SET r += $properties, r.updated_at = timestamp(), r.type = $relationship_name
ON MATCH SET r += $properties, r.updated_at = timestamp()
RETURN r
""")
params = {
"from_node": from_node,
"to_node": to_node,
"from_node": str(from_node),
"to_node": str(to_node),
"relationship_name": relationship_name,
"properties": serialized_properties
}
@ -234,13 +205,13 @@ class Neo4jAdapter(GraphDBInterface):
"""
edges = [{
"from_node": edge[0],
"to_node": edge[1],
"from_node": str(edge[0]),
"to_node": str(edge[1]),
"relationship_name": edge[2],
"properties": {
**(edge[3] if edge[3] else {}),
"source_node_id": edge[0],
"target_node_id": edge[1],
"source_node_id": str(edge[0]),
"target_node_id": str(edge[1]),
},
} for edge in edges]
@ -300,14 +271,6 @@ class Neo4jAdapter(GraphDBInterface):
return results[0]["ids"] if len(results) > 0 else []
async def filter_nodes(self, search_criteria):
query = f"""MATCH (node)
WHERE node.id CONTAINS '{search_criteria}'
RETURN node"""
return await self.query(query)
async def get_predecessors(self, node_id: str, edge_label: str = None) -> list[str]:
if edge_label is not None:
query = """
@ -379,7 +342,7 @@ class Neo4jAdapter(GraphDBInterface):
return predecessors + successors
async def get_connections(self, node_id: str) -> list:
async def get_connections(self, node_id: UUID) -> list:
predecessors_query = """
MATCH (node)<-[relation]-(neighbour)
WHERE node.id = $node_id
@ -392,8 +355,8 @@ class Neo4jAdapter(GraphDBInterface):
"""
predecessors, successors = await asyncio.gather(
self.query(predecessors_query, dict(node_id = node_id)),
self.query(successors_query, dict(node_id = node_id)),
self.query(predecessors_query, dict(node_id = str(node_id))),
self.query(successors_query, dict(node_id = str(node_id))),
)
connections = []
@ -438,15 +401,22 @@ class Neo4jAdapter(GraphDBInterface):
return await self.query(query)
def serialize_properties(self, properties = dict()):
return {
property_key: json.dumps(property_value)
if isinstance(property_value, (dict, list))
else property_value for property_key, property_value in properties.items()
}
serialized_properties = {}
for property_key, property_value in properties.items():
if isinstance(property_value, UUID):
serialized_properties[property_key] = str(property_value)
continue
serialized_properties[property_key] = property_value
return serialized_properties
async def get_graph_data(self):
query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties"
result = await self.query(query)
nodes = [(
record["properties"]["id"],
record["properties"],

View file

@ -1,14 +1,19 @@
"""Adapter for NetworkX graph database."""
from datetime import datetime, timezone
import os
import json
import asyncio
import logging
from re import A
from typing import Dict, Any, List
from uuid import UUID
import aiofiles
import aiofiles.os as aiofiles_os
import networkx as nx
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import JSONEncoder
logger = logging.getLogger("NetworkXAdapter")
@ -25,29 +30,38 @@ class NetworkXAdapter(GraphDBInterface):
def __init__(self, filename = "cognee_graph.pkl"):
self.filename = filename
async def get_graph_data(self):
await self.load_graph_from_file()
return (list(self.graph.nodes(data = True)), list(self.graph.edges(data = True, keys = True)))
async def query(self, query: str, params: dict):
pass
async def has_node(self, node_id: str) -> bool:
return self.graph.has_node(node_id)
async def add_node(
self,
node_id: str,
node_properties,
node: DataPoint,
) -> None:
if not self.graph.has_node(id):
self.graph.add_node(node_id, **node_properties)
await self.save_graph_to_file(self.filename)
self.graph.add_node(node.id, **node.model_dump())
await self.save_graph_to_file(self.filename)
async def add_nodes(
self,
nodes: List[tuple[str, dict]],
nodes: list[DataPoint],
) -> None:
nodes = [(node.id, node.model_dump()) for node in nodes]
self.graph.add_nodes_from(nodes)
await self.save_graph_to_file(self.filename)
async def get_graph(self):
return self.graph
async def has_edge(self, from_node: str, to_node: str, edge_label: str) -> bool:
return self.graph.has_edge(from_node, to_node, key = edge_label)
@ -55,18 +69,20 @@ class NetworkXAdapter(GraphDBInterface):
result = []
for (from_node, to_node, edge_label) in edges:
if await self.has_edge(from_node, to_node, edge_label):
if self.graph.has_edge(from_node, to_node, edge_label):
result.append((from_node, to_node, edge_label))
return result
async def add_edge(
self,
from_node: str,
to_node: str,
relationship_name: str,
edge_properties: Dict[str, Any] = None,
edge_properties: Dict[str, Any] = {},
) -> None:
edge_properties["updated_at"] = datetime.now(timezone.utc)
self.graph.add_edge(from_node, to_node, key = relationship_name, **(edge_properties if edge_properties else {}))
await self.save_graph_to_file(self.filename)
@ -74,22 +90,29 @@ class NetworkXAdapter(GraphDBInterface):
self,
edges: tuple[str, str, str, dict],
) -> None:
edges = [(edge[0], edge[1], edge[2], {
**(edge[3] if len(edge) == 4 else {}),
"updated_at": datetime.now(timezone.utc),
}) for edge in edges]
self.graph.add_edges_from(edges)
await self.save_graph_to_file(self.filename)
async def get_edges(self, node_id: str):
return list(self.graph.in_edges(node_id, data = True)) + list(self.graph.out_edges(node_id, data = True))
async def delete_node(self, node_id: str) -> None:
"""Asynchronously delete a node from the graph if it exists."""
if self.graph.has_node(id):
self.graph.remove_node(id)
if self.graph.has_node(node_id):
self.graph.remove_node(node_id)
await self.save_graph_to_file(self.filename)
async def delete_nodes(self, node_ids: List[str]) -> None:
self.graph.remove_nodes_from(node_ids)
await self.save_graph_to_file(self.filename)
async def get_disconnected_nodes(self) -> List[str]:
connected_components = list(nx.weakly_connected_components(self.graph))
@ -102,33 +125,6 @@ class NetworkXAdapter(GraphDBInterface):
return disconnected_nodes
async def extract_node_description(self, node_id: str) -> Dict[str, Any]:
descriptions = []
if self.graph.has_node(node_id):
# Get the attributes of the node
for neighbor in self.graph.neighbors(node_id):
# Get the attributes of the neighboring node
attributes = self.graph.nodes[neighbor]
# Ensure all required attributes are present before extracting description
if all(key in attributes for key in ["id", "layer_id", "description"]):
descriptions.append({
"id": attributes["id"],
"layer_id": attributes["layer_id"],
"description": attributes["description"],
})
return descriptions
async def get_layer_nodes(self):
layer_nodes = []
for _, data in self.graph.nodes(data = True):
if "layer_id" in data:
layer_nodes.append(data)
return layer_nodes
async def extract_node(self, node_id: str) -> dict:
if self.graph.has_node(node_id):
@ -139,7 +135,7 @@ class NetworkXAdapter(GraphDBInterface):
async def extract_nodes(self, node_ids: List[str]) -> List[dict]:
return [self.graph.nodes[node_id] for node_id in node_ids if self.graph.has_node(node_id)]
async def get_predecessors(self, node_id: str, edge_label: str = None) -> list:
async def get_predecessors(self, node_id: UUID, edge_label: str = None) -> list:
if self.graph.has_node(node_id):
if edge_label is None:
return [
@ -155,7 +151,7 @@ class NetworkXAdapter(GraphDBInterface):
return nodes
async def get_successors(self, node_id: str, edge_label: str = None) -> list:
async def get_successors(self, node_id: UUID, edge_label: str = None) -> list:
if self.graph.has_node(node_id):
if edge_label is None:
return [
@ -184,13 +180,13 @@ class NetworkXAdapter(GraphDBInterface):
return neighbours
async def get_connections(self, node_id: str) -> list:
async def get_connections(self, node_id: UUID) -> list:
if not self.graph.has_node(node_id):
return []
node = self.graph.nodes[node_id]
if "uuid" not in node:
if "id" not in node:
return []
predecessors, successors = await asyncio.gather(
@ -201,14 +197,14 @@ class NetworkXAdapter(GraphDBInterface):
connections = []
for neighbor in predecessors:
if "uuid" in neighbor:
edge_data = self.graph.get_edge_data(neighbor["uuid"], node["uuid"])
if "id" in neighbor:
edge_data = self.graph.get_edge_data(neighbor["id"], node["id"])
for edge_properties in edge_data.values():
connections.append((neighbor, edge_properties, node))
for neighbor in successors:
if "uuid" in neighbor:
edge_data = self.graph.get_edge_data(node["uuid"], neighbor["uuid"])
if "id" in neighbor:
edge_data = self.graph.get_edge_data(node["id"], neighbor["id"])
for edge_properties in edge_data.values():
connections.append((node, edge_properties, neighbor))
@ -240,7 +236,7 @@ class NetworkXAdapter(GraphDBInterface):
graph_data = nx.readwrite.json_graph.node_link_data(self.graph)
async with aiofiles.open(file_path, "w") as file:
await file.write(json.dumps(graph_data))
await file.write(json.dumps(graph_data, cls = JSONEncoder))
async def load_graph_from_file(self, file_path: str = None):
@ -254,6 +250,29 @@ class NetworkXAdapter(GraphDBInterface):
if os.path.exists(file_path):
async with aiofiles.open(file_path, "r") as file:
graph_data = json.loads(await file.read())
for node in graph_data["nodes"]:
try:
node["id"] = UUID(node["id"])
except:
pass
if "updated_at" in node:
node["updated_at"] = datetime.strptime(node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
for edge in graph_data["links"]:
try:
source_id = UUID(edge["source"])
target_id = UUID(edge["target"])
edge["source"] = source_id
edge["target"] = target_id
edge["source_node_id"] = source_id
edge["target_node_id"] = target_id
except:
pass
if "updated_at" in edge:
edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
self.graph = nx.readwrite.json_graph.node_link_graph(graph_data)
else:
# Log that the file does not exist and an empty graph is initialized
@ -265,9 +284,11 @@ class NetworkXAdapter(GraphDBInterface):
os.makedirs(file_dir, exist_ok = True)
await self.save_graph_to_file(file_path)
except Exception:
logger.error("Failed to load graph from file: %s", file_path)
async def delete_graph(self, file_path: str = None):
"""Asynchronously delete the graph file from the filesystem."""
if file_path is None:

View file

@ -0,0 +1,267 @@
import asyncio
from textwrap import dedent
from typing import Any
from uuid import UUID
from falkordb import FalkorDB
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
from cognee.infrastructure.databases.vector.embeddings import EmbeddingEngine
from cognee.infrastructure.databases.vector.vector_db_interface import VectorDBInterface
# Vector-index payload schema: a single embeddable "text" field.
class IndexSchema(DataPoint):
text: str
# index_fields names the attributes that get embedded when indexing this point.
_metadata: dict = {
"index_fields": ["text"]
}
class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
    """Hybrid adapter exposing one FalkorDB graph as both a graph database
    (GraphDBInterface) and a vector store (VectorDBInterface).

    All operations run against a single named graph ("cognee_graph").
    Property values are interpolated directly into Cypher text, so only
    trusted data should be passed (see stringify_properties).
    """

    def __init__(
        self,
        database_url: str,
        database_port: int,
        embedding_engine = EmbeddingEngine,
    ):
        # NOTE(review): the default binds the EmbeddingEngine *class*, not an
        # instance — callers are expected to supply a configured instance.
        self.driver = FalkorDB(
            host = database_url,
            port = database_port,
        )
        self.embedding_engine = embedding_engine
        self.graph_name = "cognee_graph"

    def query(self, query: str, params: dict = None):
        """Execute a Cypher query on the cognee graph and return the raw result.

        Re-raises any driver error after logging it.
        """
        # Fresh dict per call — the previous `params: dict = {}` default was a
        # single shared mutable object across all invocations.
        if params is None:
            params = {}

        graph = self.driver.select_graph(self.graph_name)

        try:
            return graph.query(query, params)
        except Exception as e:
            print(f"Error executing query: {e}")
            raise e

    async def embed_data(self, data: list[str]) -> list[list[float]]:
        """Embed a batch of texts with the configured embedding engine."""
        return await self.embedding_engine.embed_text(data)

    async def stringify_properties(self, properties: dict, vectorize_fields = []) -> str:
        """Render a dict as comma-separated Cypher `key:value` pairs.

        Fields named in vectorize_fields are embedded and rendered as
        vecf32(...) literals; everything else is single-quoted verbatim.
        vectorize_fields is only read here, so its mutable default is safe.

        NOTE(review): values are not escaped — a value containing a single
        quote breaks the generated Cypher (potential injection).
        """
        async def get_value(key, value):
            return f"'{value}'" if key not in vectorize_fields else await self.get_vectorized_value(value)

        return ",".join([f"{key}:{await get_value(key, value)}" for key, value in properties.items()])

    async def get_vectorized_value(self, value: Any) -> str:
        """Return the vecf32(...) Cypher literal for one embedded value."""
        vector = (await self.embed_data([value]))[0]
        return f"vecf32({vector})"

    async def create_data_point_query(self, data_point: DataPoint):
        """Build an idempotent MERGE query that upserts one data point node."""
        node_label = type(data_point).__name__
        node_properties = await self.stringify_properties(
            data_point.model_dump(),
            data_point._metadata["index_fields"],
        )

        return dedent(f"""
            MERGE (node:{node_label} {{id: '{str(data_point.id)}'}})
            ON CREATE SET node += ({{{node_properties}}})
            ON CREATE SET node.updated_at = timestamp()
            ON MATCH SET node += ({{{node_properties}}})
            ON MATCH SET node.updated_at = timestamp()
        """).strip()

    async def create_edge_query(self, edge: tuple[str, str, str, dict]) -> str:
        """Build a MERGE query creating (source)-[edge]->(target) with properties."""
        properties = await self.stringify_properties(edge[3])
        properties = f"{{{properties}}}"

        return dedent(f"""
            MERGE (source {{id:'{edge[0]}'}})
            MERGE (target {{id: '{edge[1]}'}})
            MERGE (source)-[edge:{edge[2]} {properties}]->(target)
            ON MATCH SET edge.updated_at = timestamp()
            ON CREATE SET edge.updated_at = timestamp()
        """).strip()

    async def create_collection(self, collection_name: str):
        """No-op: this adapter has no separate collection concept."""
        pass

    async def has_collection(self, collection_name: str) -> bool:
        """True when a graph with the given name exists on the server."""
        collections = self.driver.list_graphs()
        return collection_name in collections

    async def create_data_points(self, data_points: list[DataPoint]):
        """Upsert each data point as a node (one query per point)."""
        queries = [await self.create_data_point_query(data_point) for data_point in data_points]

        for query in queries:
            self.query(query)

    async def create_vector_index(self, index_name: str, index_property_name: str):
        """Create a node vector index unless one already exists."""
        graph = self.driver.select_graph(self.graph_name)

        if not self.has_vector_index(graph, index_name, index_property_name):
            graph.create_node_vector_index(index_name, index_property_name, dim = self.embedding_engine.get_vector_size())

    def has_vector_index(self, graph, index_name: str, index_property_name: str) -> bool:
        """True when `graph` has a vector index on the given label/property."""
        try:
            indices = graph.list_indices()
            return any([(index[0] == index_name and index_property_name in index[1]) for index in indices.result_set])
        except Exception:
            # Narrowed from a bare `except:`, which would also swallow
            # KeyboardInterrupt/SystemExit.
            return False

    async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
        """No-op: nodes are indexed at creation via create_data_points."""
        pass

    async def add_node(self, node: DataPoint):
        """Upsert a single node."""
        await self.create_data_points([node])

    async def add_nodes(self, nodes: list[DataPoint]):
        """Upsert multiple nodes."""
        await self.create_data_points(nodes)

    async def add_edge(self, edge: tuple[str, str, str, dict]):
        """Upsert a single (from_id, to_id, relationship, properties) edge."""
        query = await self.create_edge_query(edge)
        self.query(query)

    async def add_edges(self, edges: list[tuple[str, str, str, dict]]):
        """Upsert multiple edges, one query each."""
        queries = [await self.create_edge_query(edge) for edge in edges]

        for query in queries:
            self.query(query)

    async def has_edges(self, edges):
        """For each (from, to, relationship) tuple, report whether the edge exists.

        NOTE(review): `result["edge_exists"]` assumes keyed row access, but
        falkordb result_set rows are typically positional lists — confirm
        against the client version in use.
        """
        query = dedent("""
            UNWIND $edges AS edge
            MATCH (a)-[r]->(b)
            WHERE id(a) = edge.from_node AND id(b) = edge.to_node AND type(r) = edge.relationship_name
            RETURN edge.from_node AS from_node, edge.to_node AS to_node, edge.relationship_name AS relationship_name, count(r) > 0 AS edge_exists
        """).strip()

        params = {
            "edges": [{
                "from_node": str(edge[0]),
                "to_node": str(edge[1]),
                "relationship_name": edge[2],
            } for edge in edges],
        }

        results = self.query(query, params).result_set

        return [result["edge_exists"] for result in results]

    async def retrieve(self, data_point_ids: list[str]):
        """Fetch all nodes whose id is in data_point_ids."""
        return self.query(
            "MATCH (node) WHERE node.id IN $node_ids RETURN node",
            {
                "node_ids": data_point_ids,
            },
        )

    async def extract_node(self, data_point_id: str):
        """Fetch a single node by id."""
        return await self.retrieve([data_point_id])

    async def extract_nodes(self, data_point_ids: list[str]):
        """Fetch multiple nodes by id."""
        return await self.retrieve(data_point_ids)

    async def get_connections(self, node_id: UUID) -> list:
        """Return (node, {relationship_name}, node) triples for every edge
        touching node_id, in both directions."""
        predecessors_query = """
            MATCH (node)<-[relation]-(neighbour)
            WHERE node.id = $node_id
            RETURN neighbour, relation, node
        """
        successors_query = """
            MATCH (node)-[relation]->(neighbour)
            WHERE node.id = $node_id
            RETURN node, relation, neighbour
        """

        # `query` is synchronous, so call it directly — the previous
        # asyncio.gather(...) over its plain return values raised TypeError
        # ("An asyncio.Future, a coroutine or an awaitable is required").
        predecessors = self.query(predecessors_query, dict(node_id = node_id))
        successors = self.query(successors_query, dict(node_id = node_id))

        connections = []

        # NOTE(review): keyed access neighbour["relation"] assumes dict-like
        # rows; verify against the falkordb QueryResult shape.
        for neighbour in predecessors:
            neighbour = neighbour["relation"]
            connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2]))

        for neighbour in successors:
            neighbour = neighbour["relation"]
            connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2]))

        return connections

    async def search(
        self,
        collection_name: str,
        query_text: str = None,
        query_vector: list[float] = None,
        limit: int = 10,
        with_vector: bool = False,
    ):
        """Vector-similarity search against `collection_name`'s vector index.

        Exactly one of query_text / query_vector must be provided; text is
        embedded first. with_vector is currently ignored.
        """
        if query_text is None and query_vector is None:
            raise ValueError("One of query_text or query_vector must be provided!")

        if query_text and not query_vector:
            query_vector = (await self.embed_data([query_text]))[0]

        query = dedent(f"""
            CALL db.idx.vector.queryNodes(
                {collection_name},
                'text',
                {limit},
                vecf32({query_vector})
            ) YIELD node, score
        """).strip()

        return self.query(query)

    async def batch_search(
        self,
        collection_name: str,
        query_texts: list[str],
        limit: int = None,
        with_vectors: bool = False,
    ):
        """Embed all query_texts in one batch, then run the searches concurrently."""
        query_vectors = await self.embedding_engine.embed_text(query_texts)

        return await asyncio.gather(
            *[self.search(
                collection_name = collection_name,
                query_vector = query_vector,
                limit = limit,
                with_vector = with_vectors,
            ) for query_vector in query_vectors]
        )

    async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
        """Detach-delete all nodes whose id is in data_point_ids."""
        return self.query(
            "MATCH (node) WHERE node.id IN $node_ids DETACH DELETE node",
            {
                "node_ids": data_point_ids,
            },
        )

    async def delete_node(self, collection_name: str, data_point_id: str):
        """Delete one node by id."""
        # Pass collection_name through — the previous call omitted it, which
        # raised TypeError (delete_data_points takes two positional arguments).
        return await self.delete_data_points(collection_name, [data_point_id])

    async def delete_nodes(self, collection_name: str, data_point_ids: list[str]):
        """Delete multiple nodes by id."""
        # The previous call was missing both `await` and collection_name, so
        # the coroutine was created but never executed.
        return await self.delete_data_points(collection_name, data_point_ids)

    async def delete_graph(self):
        """Drop the graph's vector indices, then delete the graph itself."""
        try:
            graph = self.driver.select_graph(self.graph_name)

            indices = graph.list_indices()
            for index in indices.result_set:
                for field in index[1]:
                    graph.drop_node_vector_index(index[0], field)

            graph.delete()
        except Exception as e:
            print(f"Error deleting graph: {e}")

    async def prune(self):
        """Remove all data stored by this adapter."""
        # `delete_graph` is a coroutine — without `await` it never actually ran.
        await self.delete_graph()

View file

@ -1,4 +1,3 @@
from .models.DataPoint import DataPoint
from .models.VectorConfig import VectorConfig
from .models.CollectionConfig import CollectionConfig
from .vector_db_interface import VectorDBInterface

View file

@ -8,6 +8,7 @@ class VectorConfig(BaseSettings):
os.path.join(get_absolute_path(".cognee_system"), "databases"),
"cognee.lancedb"
)
vector_db_port: int = 1234
vector_db_key: str = ""
vector_db_provider: str = "lancedb"
@ -16,6 +17,7 @@ class VectorConfig(BaseSettings):
def to_dict(self) -> dict:
return {
"vector_db_url": self.vector_db_url,
"vector_db_port": self.vector_db_port,
"vector_db_key": self.vector_db_key,
"vector_db_provider": self.vector_db_provider,
}

View file

@ -1,9 +1,8 @@
from typing import Dict
from ..relational.config import get_relational_config
class VectorConfig(Dict):
vector_db_url: str
vector_db_port: str
vector_db_key: str
vector_db_provider: str
@ -29,6 +28,7 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
embedding_engine = embedding_engine
)
elif config["vector_db_provider"] == "pgvector":
from cognee.infrastructure.databases.relational import get_relational_config
from .pgvector.PGVectorAdapter import PGVectorAdapter
# Get configuration for postgres database
@ -43,9 +43,18 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
)
return PGVectorAdapter(connection_string,
config["vector_db_key"],
embedding_engine
return PGVectorAdapter(
connection_string,
config["vector_db_key"],
embedding_engine,
)
elif config["vector_db_provider"] == "falkordb":
from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
return FalkorDBAdapter(
database_url = config["vector_db_url"],
database_port = config["vector_db_port"],
embedding_engine = embedding_engine,
)
else:
from .lancedb.LanceDBAdapter import LanceDBAdapter

View file

@ -1,57 +0,0 @@
from typing import List, Dict, Optional, Any
from falkordb import FalkorDB
from qdrant_client import AsyncQdrantClient, models
from ..vector_db_interface import VectorDBInterface
from ..models.DataPoint import DataPoint
from ..embeddings.EmbeddingEngine import EmbeddingEngine
# Legacy FalkorDB vector adapter. Only the constructor and embed_data do real
# work; every VectorDBInterface operation below is an unimplemented stub.
class FalcorDBAdapter(VectorDBInterface):
def __init__(
self,
graph_database_url: str,
graph_database_username: str,
graph_database_password: str,
graph_database_port: int,
driver: Optional[Any] = None,
embedding_engine = EmbeddingEngine,
graph_name: str = "DefaultGraph",
):
# NOTE(review): username, password and the `driver` argument are accepted
# but never used; only host and port reach the FalkorDB client.
self.driver = FalkorDB(
host = graph_database_url,
port = graph_database_port)
self.graph_name = graph_name
# NOTE(review): the default binds the EmbeddingEngine class, not an instance.
self.embedding_engine = embedding_engine
# Embed a batch of texts via the configured embedding engine.
async def embed_data(self, data: list[str]) -> list[list[float]]:
return await self.embedding_engine.embed_text(data)
# Stub: collection creation is not implemented.
async def create_collection(self, collection_name: str, payload_schema = None):
pass
# Stub: data-point insertion is not implemented.
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
pass
# Stub: retrieval is not implemented (returns None).
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
pass
# Stub: similarity search is not implemented (returns None).
async def search(
self,
collection_name: str,
query_text: str = None,
query_vector: List[float] = None,
limit: int = 10,
with_vector: bool = False,
):
pass
# Stub: deletion is not implemented.
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
pass

View file

@ -1,12 +1,25 @@
import inspect
from typing import List, Optional, get_type_hints, Generic, TypeVar
import asyncio
from uuid import UUID
import lancedb
from pydantic import BaseModel
from lancedb.pydantic import Vector, LanceModel
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.files.storage import LocalStorage
from cognee.modules.storage.utils import copy_model, get_own_properties
from ..models.ScoredResult import ScoredResult
from ..vector_db_interface import VectorDBInterface, DataPoint
from ..vector_db_interface import VectorDBInterface
from ..embeddings.EmbeddingEngine import EmbeddingEngine
class IndexSchema(DataPoint):
id: str
text: str
_metadata: dict = {
"index_fields": ["text"]
}
class LanceDBAdapter(VectorDBInterface):
name = "LanceDB"
url: str
@ -38,10 +51,12 @@ class LanceDBAdapter(VectorDBInterface):
collection_names = await connection.table_names()
return collection_name in collection_names
async def create_collection(self, collection_name: str, payload_schema = None):
data_point_types = get_type_hints(DataPoint)
async def create_collection(self, collection_name: str, payload_schema: BaseModel):
vector_size = self.embedding_engine.get_vector_size()
payload_schema = self.get_data_point_schema(payload_schema)
data_point_types = get_type_hints(payload_schema)
class LanceDataPoint(LanceModel):
id: data_point_types["id"]
vector: Vector(vector_size)
@ -55,13 +70,16 @@ class LanceDBAdapter(VectorDBInterface):
exist_ok = True,
)
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
connection = await self.get_connection()
payload_schema = type(data_points[0])
payload_schema = self.get_data_point_schema(payload_schema)
if not await self.has_collection(collection_name):
await self.create_collection(
collection_name,
payload_schema = type(data_points[0].payload),
payload_schema,
)
collection = await connection.open_table(collection_name)
@ -79,15 +97,26 @@ class LanceDBAdapter(VectorDBInterface):
vector: Vector(vector_size)
payload: PayloadSchema
def create_lance_data_point(data_point: DataPoint, vector: list[float]) -> LanceDataPoint:
properties = get_own_properties(data_point)
properties["id"] = str(properties["id"])
return LanceDataPoint[str, self.get_data_point_schema(type(data_point))](
id = str(data_point.id),
vector = vector,
payload = properties,
)
lance_data_points = [
LanceDataPoint[type(data_point.id), type(data_point.payload)](
id = data_point.id,
vector = data_vectors[data_index],
payload = data_point.payload,
) for (data_index, data_point) in enumerate(data_points)
create_lance_data_point(data_point, data_vectors[data_point_index])
for (data_point_index, data_point) in enumerate(data_points)
]
await collection.add(lance_data_points)
await collection.merge_insert("id") \
.when_matched_update_all() \
.when_not_matched_insert_all() \
.execute(lance_data_points)
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
connection = await self.get_connection()
@ -99,7 +128,7 @@ class LanceDBAdapter(VectorDBInterface):
results = await collection.query().where(f"id IN {tuple(data_point_ids)}").to_pandas()
return [ScoredResult(
id = result["id"],
id = UUID(result["id"]),
payload = result["payload"],
score = 0,
) for result in results.to_dict("index").values()]
@ -135,10 +164,19 @@ class LanceDBAdapter(VectorDBInterface):
if value < min_value:
min_value = value
normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in result_values]
normalized_values = []
min_value = min(result["_distance"] for result in result_values)
max_value = max(result["_distance"] for result in result_values)
if max_value == min_value:
# Avoid division by zero: Assign all normalized values to 0 (or any constant value like 1)
normalized_values = [0 for _ in result_values]
else:
normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in
result_values]
return [ScoredResult(
id = str(result["id"]),
id = UUID(result["id"]),
payload = result["payload"],
score = normalized_values[value_index],
) for value_index, result in enumerate(result_values)]
@ -170,7 +208,27 @@ class LanceDBAdapter(VectorDBInterface):
results = await collection.delete(f"id IN {tuple(data_point_ids)}")
return results
async def create_vector_index(self, index_name: str, index_property_name: str):
await self.create_collection(f"{index_name}_{index_property_name}", payload_schema = IndexSchema)
async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
await self.create_data_points(f"{index_name}_{index_property_name}", [
IndexSchema(
id = str(data_point.id),
text = getattr(data_point, data_point._metadata["index_fields"][0]),
) for data_point in data_points
])
async def prune(self):
# Clean up the database if it was set up as temporary
if self.url.startswith("/"):
LocalStorage.remove_all(self.url) # Remove the temporary directory and files inside
def get_data_point_schema(self, model_type):
return copy_model(
model_type,
include_fields = {
"id": (str, ...),
},
exclude_fields = ["_metadata"],
)

View file

@ -1,13 +0,0 @@
from typing import Generic, TypeVar

from pydantic import BaseModel

# Any pydantic model may serve as a payload schema.
PayloadSchema = TypeVar("PayloadSchema", bound = BaseModel)


class DataPoint(BaseModel, Generic[PayloadSchema]):
    """A vector-store record: a string id plus a typed payload model.

    `embed_field` names the payload attribute whose value is embedded.
    """
    id: str
    payload: PayloadSchema
    embed_field: str = "value"

    def get_embeddable_data(self):
        """Return the payload attribute named by embed_field, or None when absent."""
        return getattr(self.payload, self.embed_field, None)

View file

@ -1,7 +1,8 @@
from typing import Any, Dict
from uuid import UUID
from pydantic import BaseModel
class ScoredResult(BaseModel):
id: str
id: UUID
score: float # Lower score is better
payload: Dict[str, Any]

View file

@ -1,17 +1,26 @@
import asyncio
from uuid import UUID
from pgvector.sqlalchemy import Vector
from typing import List, Optional, get_type_hints
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy import JSON, Column, Table, select, delete
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from .serialize_datetime import serialize_datetime
from cognee.infrastructure.engine import DataPoint
from .serialize_data import serialize_data
from ..models.ScoredResult import ScoredResult
from ..vector_db_interface import VectorDBInterface, DataPoint
from ..vector_db_interface import VectorDBInterface
from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
from ...relational.ModelBase import Base
class IndexSchema(DataPoint):
text: str
_metadata: dict = {
"index_fields": ["text"]
}
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
@ -45,7 +54,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
vector_size = self.embedding_engine.get_vector_size()
if not await self.has_collection(collection_name):
class PGVectorDataPoint(Base):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
@ -71,47 +79,58 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
async def create_data_points(
self, collection_name: str, data_points: List[DataPoint]
):
async with self.get_async_session() as session:
if not await self.has_collection(collection_name):
await self.create_collection(
collection_name=collection_name,
payload_schema=type(data_points[0].payload),
)
data_vectors = await self.embed_data(
[data_point.get_embeddable_data() for data_point in data_points]
if not await self.has_collection(collection_name):
await self.create_collection(
collection_name = collection_name,
payload_schema = type(data_points[0]),
)
vector_size = self.embedding_engine.get_vector_size()
data_vectors = await self.embed_data(
[data_point.get_embeddable_data() for data_point in data_points]
)
class PGVectorDataPoint(Base):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key
primary_key: Mapped[int] = mapped_column(
primary_key=True, autoincrement=True
)
id: Mapped[type(data_points[0].id)]
payload = Column(JSON)
vector = Column(Vector(vector_size))
vector_size = self.embedding_engine.get_vector_size()
def __init__(self, id, payload, vector):
self.id = id
self.payload = payload
self.vector = vector
class PGVectorDataPoint(Base):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key
primary_key: Mapped[int] = mapped_column(
primary_key=True, autoincrement=True
)
id: Mapped[type(data_points[0].id)]
payload = Column(JSON)
vector = Column(Vector(vector_size))
pgvector_data_points = [
PGVectorDataPoint(
id=data_point.id,
vector=data_vectors[data_index],
payload=serialize_datetime(data_point.payload.dict()),
)
for (data_index, data_point) in enumerate(data_points)
]
def __init__(self, id, payload, vector):
self.id = id
self.payload = payload
self.vector = vector
pgvector_data_points = [
PGVectorDataPoint(
id = data_point.id,
vector = data_vectors[data_index],
payload = serialize_data(data_point.model_dump()),
)
for (data_index, data_point) in enumerate(data_points)
]
async with self.get_async_session() as session:
session.add_all(pgvector_data_points)
await session.commit()
async def create_vector_index(self, index_name: str, index_property_name: str):
await self.create_collection(f"{index_name}_{index_property_name}")
async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
await self.create_data_points(f"{index_name}_{index_property_name}", [
IndexSchema(
id = data_point.id,
text = data_point.get_embeddable_data(),
) for data_point in data_points
])
async def get_table(self, collection_name: str) -> Table:
"""
Dynamically loads a table using the given collection name
@ -126,18 +145,21 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
raise ValueError(f"Table '{collection_name}' not found.")
async def retrieve(self, collection_name: str, data_point_ids: List[str]):
async with self.get_async_session() as session:
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
async with self.get_async_session() as session:
results = await session.execute(
select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
)
results = results.all()
return [
ScoredResult(id=result.id, payload=result.payload, score=0)
for result in results
ScoredResult(
id = UUID(result.id),
payload = result.payload,
score = 0
) for result in results
]
async def search(
@ -154,11 +176,13 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
if query_text and not query_vector:
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
closest_items = []
# Use async session to connect to the database
async with self.get_async_session() as session:
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
# Find closest vectors to query_vector
closest_items = await session.execute(
select(
@ -171,19 +195,21 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
.limit(limit)
)
vector_list = []
# Extract distances and find min/max for normalization
for vector in closest_items:
# TODO: Add normalization of similarity score
vector_list.append(vector)
vector_list = []
# Create and return ScoredResult objects
return [
ScoredResult(
id=str(row.id), payload=row.payload, score=row.similarity
)
for row in vector_list
]
# Extract distances and find min/max for normalization
for vector in closest_items:
# TODO: Add normalization of similarity score
vector_list.append(vector)
# Create and return ScoredResult objects
return [
ScoredResult(
id = UUID(str(row.id)),
payload = row.payload,
score = row.similarity
) for row in vector_list
]
async def batch_search(
self,

View file

@ -1,12 +1,15 @@
from datetime import datetime
from uuid import UUID
def serialize_datetime(data):
def serialize_data(data):
"""Recursively convert datetime objects in dictionaries/lists to ISO format."""
if isinstance(data, dict):
return {key: serialize_datetime(value) for key, value in data.items()}
return {key: serialize_data(value) for key, value in data.items()}
elif isinstance(data, list):
return [serialize_datetime(item) for item in data]
return [serialize_data(item) for item in data]
elif isinstance(data, datetime):
return data.isoformat() # Convert datetime to ISO 8601 string
elif isinstance(data, UUID):
return str(data)
else:
return data

View file

@ -1,12 +1,22 @@
import logging
from uuid import UUID
from typing import List, Dict, Optional
from qdrant_client import AsyncQdrantClient, models
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
from cognee.infrastructure.engine import DataPoint
from ..vector_db_interface import VectorDBInterface
from ..models.DataPoint import DataPoint
from ..embeddings.EmbeddingEngine import EmbeddingEngine
logger = logging.getLogger("QDrantAdapter")
class IndexSchema(DataPoint):
text: str
_metadata: dict = {
"index_fields": ["text"]
}
# class CollectionConfig(BaseModel, extra = "forbid"):
# vector_config: Dict[str, models.VectorParams] = Field(..., description="Vectors configuration" )
# hnsw_config: Optional[models.HnswConfig] = Field(default = None, description="HNSW vector index configuration")
@ -75,20 +85,19 @@ class QDrantAdapter(VectorDBInterface):
):
client = self.get_qdrant_client()
result = await client.create_collection(
collection_name = collection_name,
vectors_config = {
"text": models.VectorParams(
size = self.embedding_engine.get_vector_size(),
distance = "Cosine"
)
}
)
if not await client.collection_exists(collection_name):
await client.create_collection(
collection_name = collection_name,
vectors_config = {
"text": models.VectorParams(
size = self.embedding_engine.get_vector_size(),
distance = "Cosine"
)
}
)
await client.close()
return result
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
client = self.get_qdrant_client()
@ -96,8 +105,8 @@ class QDrantAdapter(VectorDBInterface):
def convert_to_qdrant_point(data_point: DataPoint):
return models.PointStruct(
id = data_point.id,
payload = data_point.payload.dict(),
id = str(data_point.id),
payload = data_point.model_dump(),
vector = {
"text": data_vectors[data_points.index(data_point)]
}
@ -116,6 +125,17 @@ class QDrantAdapter(VectorDBInterface):
finally:
await client.close()
async def create_vector_index(self, index_name: str, index_property_name: str):
await self.create_collection(f"{index_name}_{index_property_name}")
async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
await self.create_data_points(f"{index_name}_{index_property_name}", [
IndexSchema(
id = data_point.id,
text = getattr(data_point, data_point._metadata["index_fields"][0]),
) for data_point in data_points
])
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
client = self.get_qdrant_client()
results = await client.retrieve(collection_name, data_point_ids, with_payload = True)
@ -135,7 +155,7 @@ class QDrantAdapter(VectorDBInterface):
client = self.get_qdrant_client()
result = await client.search(
results = await client.search(
collection_name = collection_name,
query_vector = models.NamedVector(
name = "text",
@ -147,7 +167,16 @@ class QDrantAdapter(VectorDBInterface):
await client.close()
return result
return [
ScoredResult(
id = UUID(result.id),
payload = {
**result.payload,
"id": UUID(result.id),
},
score = 1 - result.score,
) for result in results
]
async def batch_search(self, collection_name: str, query_texts: List[str], limit: int = None, with_vectors: bool = False):

View file

@ -1,6 +1,6 @@
from typing import List, Protocol, Optional
from abc import abstractmethod
from .models.DataPoint import DataPoint
from cognee.infrastructure.engine import DataPoint
from .models.PayloadSchema import PayloadSchema
class VectorDBInterface(Protocol):

View file

@ -1,13 +1,22 @@
import asyncio
import logging
from typing import List, Optional
from uuid import UUID
from cognee.infrastructure.engine import DataPoint
from ..vector_db_interface import VectorDBInterface
from ..models.DataPoint import DataPoint
from ..models.ScoredResult import ScoredResult
from ..embeddings.EmbeddingEngine import EmbeddingEngine
logger = logging.getLogger("WeaviateAdapter")
class IndexSchema(DataPoint):
text: str
_metadata: dict = {
"index_fields": ["text"]
}
class WeaviateAdapter(VectorDBInterface):
name = "Weaviate"
url: str
@ -48,18 +57,21 @@ class WeaviateAdapter(VectorDBInterface):
future = asyncio.Future()
future.set_result(
self.client.collections.create(
name=collection_name,
properties=[
wvcc.Property(
name="text",
data_type=wvcc.DataType.TEXT,
skip_vectorization=True
)
]
if not self.client.collections.exists(collection_name):
future.set_result(
self.client.collections.create(
name = collection_name,
properties = [
wvcc.Property(
name = "text",
data_type = wvcc.DataType.TEXT,
skip_vectorization = True
)
]
)
)
)
else:
future.set_result(self.get_collection(collection_name))
return await future
@ -70,36 +82,60 @@ class WeaviateAdapter(VectorDBInterface):
from weaviate.classes.data import DataObject
data_vectors = await self.embed_data(
list(map(lambda data_point: data_point.get_embeddable_data(), data_points)))
[data_point.get_embeddable_data() for data_point in data_points]
)
def convert_to_weaviate_data_points(data_point: DataPoint):
vector = data_vectors[data_points.index(data_point)]
properties = data_point.model_dump()
if "id" in properties:
properties["uuid"] = str(data_point.id)
del properties["id"]
return DataObject(
uuid = data_point.id,
properties = data_point.payload.dict(),
properties = properties,
vector = vector
)
data_points = list(map(convert_to_weaviate_data_points, data_points))
data_points = [convert_to_weaviate_data_points(data_point) for data_point in data_points]
collection = self.get_collection(collection_name)
try:
if len(data_points) > 1:
return collection.data.insert_many(data_points)
with collection.batch.dynamic() as batch:
for data_point in data_points:
batch.add_object(
uuid = data_point.uuid,
vector = data_point.vector,
properties = data_point.properties,
references = data_point.references,
)
else:
return collection.data.insert(data_points[0])
# with collection.batch.dynamic() as batch:
# for point in data_points:
# batch.add_object(
# uuid = point.uuid,
# properties = point.properties,
# vector = point.vector
# )
data_point: DataObject = data_points[0]
return collection.data.update(
uuid = data_point.uuid,
vector = data_point.vector,
properties = data_point.properties,
references = data_point.references,
)
except Exception as error:
logger.error("Error creating data points: %s", str(error))
raise error
async def create_vector_index(self, index_name: str, index_property_name: str):
await self.create_collection(f"{index_name}_{index_property_name}")
async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
await self.create_data_points(f"{index_name}_{index_property_name}", [
IndexSchema(
id = data_point.id,
text = data_point.get_embeddable_data(),
) for data_point in data_points
])
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
from weaviate.classes.query import Filter
future = asyncio.Future()
@ -143,9 +179,9 @@ class WeaviateAdapter(VectorDBInterface):
return [
ScoredResult(
id=str(result.uuid),
payload=result.properties,
score=float(result.metadata.score)
id = UUID(str(result.uuid)),
payload = result.properties,
score = 1 - float(result.metadata.score)
) for result in search_result.objects
]

View file

@ -0,0 +1 @@
from .models.DataPoint import DataPoint

View file

@ -0,0 +1,72 @@
from enum import Enum
from typing import Optional
from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.utils import get_graph_from_model, get_model_instance_from_graph
if __name__ == "__main__":
    # Demo: build a small object graph (Person -> Car -> CarType), flatten it
    # into nodes/edges, then reconstruct the Person instance from the graph.

    class CarTypeName(Enum):
        Pickup = "Pickup"
        Sedan = "Sedan"
        SUV = "SUV"
        Coupe = "Coupe"
        Convertible = "Convertible"
        Hatchback = "Hatchback"
        Wagon = "Wagon"
        Minivan = "Minivan"
        Van = "Van"

    class CarType(DataPoint):
        id: str
        name: CarTypeName
        # "name" is the embeddable/indexed field for this node type.
        _metadata: dict = dict(index_fields = ["name"])

    class Car(DataPoint):
        id: str
        brand: str
        model: str
        year: int
        color: str
        is_type: CarType

    class Person(DataPoint):
        id: str
        name: str
        age: int
        owns_car: list[Car]
        driving_licence: Optional[dict]
        _metadata: dict = dict(index_fields = ["name"])

    boris = Person(
        id = "boris",
        name = "Boris",
        age = 30,
        owns_car = [
            Car(
                id = "car1",
                brand = "Toyota",
                model = "Camry",
                year = 2020,
                color = "Blue",
                is_type = CarType(id = "sedan", name = CarTypeName.Sedan),
            ),
        ],
        driving_licence = {
            "issued_by": "PU Vrsac",
            "issued_on": "2025-11-06",
            "number": "1234567890",
            "expires_on": "2025-11-06",
        },
    )

    nodes, edges = get_graph_from_model(boris)

    print(nodes)
    print(edges)

    # The simplified root node (the Person itself) is appended last by
    # get_graph_from_model when include_root is True.
    person_data = nodes[len(nodes) - 1]

    parsed_person = get_model_instance_from_graph(nodes, edges, 'boris')

    print(parsed_person)

View file

@ -0,0 +1,24 @@
from typing_extensions import TypedDict
from uuid import UUID, uuid4
from typing import Optional
from datetime import datetime, timezone
from pydantic import BaseModel, Field
class MetaData(TypedDict):
    """Metadata attached to a DataPoint; ``index_fields`` names the
    attributes whose values should be embedded/indexed."""
    index_fields: list[str]


class DataPoint(BaseModel):
    """Base model for every entity stored in the graph / vector stores.

    Attributes:
        id: Unique identifier, generated per instance.
        updated_at: Timestamp (UTC) of the last update.
        _metadata: Private metadata describing which fields are embeddable.
    """
    id: UUID = Field(default_factory = uuid4)
    # default_factory ensures each instance gets its own "now" timestamp.
    # A plain ``= datetime.now(...)`` default is evaluated once at
    # class-definition time and silently shared by every instance.
    updated_at: Optional[datetime] = Field(default_factory = lambda: datetime.now(timezone.utc))
    _metadata: Optional[MetaData] = {
        "index_fields": []
    }

    # class Config:
    #     underscore_attrs_are_private = True

    def get_embeddable_data(self):
        """Return the value of the first declared index field, if present;
        returns None when no index field is declared or the attribute is missing."""
        if self._metadata and len(self._metadata["index_fields"]) > 0 \
            and hasattr(self, self._metadata["index_fields"][0]):
            return getattr(self, self._metadata["index_fields"][0])

View file

@ -1,18 +1,18 @@
from uuid import UUID, uuid5, NAMESPACE_OID
from uuid import uuid5, NAMESPACE_OID
from cognee.modules.chunking import DocumentChunk
from cognee.tasks.chunking import chunk_by_paragraph
from .models.DocumentChunk import DocumentChunk
from cognee.tasks.chunks import chunk_by_paragraph
class TextChunker():
id: UUID
document = None
max_chunk_size: int
chunk_index = 0
chunk_size = 0
paragraph_chunks = []
def __init__(self, id: UUID, get_text: callable, chunk_size: int = 1024):
self.id = id
def __init__(self, document, get_text: callable, chunk_size: int = 1024):
self.document = document
self.max_chunk_size = chunk_size
self.get_text = get_text
@ -29,10 +29,10 @@ class TextChunker():
else:
if len(self.paragraph_chunks) == 0:
yield DocumentChunk(
id = chunk_data["chunk_id"],
text = chunk_data["text"],
word_count = chunk_data["word_count"],
document_id = str(self.id),
chunk_id = str(chunk_data["chunk_id"]),
is_part_of = self.document,
chunk_index = self.chunk_index,
cut_type = chunk_data["cut_type"],
)
@ -40,25 +40,31 @@ class TextChunker():
self.chunk_size = 0
else:
chunk_text = " ".join(chunk["text"] for chunk in self.paragraph_chunks)
yield DocumentChunk(
text = chunk_text,
word_count = self.chunk_size,
document_id = str(self.id),
chunk_id = str(uuid5(NAMESPACE_OID, f"{str(self.id)}-{self.chunk_index}")),
chunk_index = self.chunk_index,
cut_type = self.paragraph_chunks[len(self.paragraph_chunks) - 1]["cut_type"],
)
try:
yield DocumentChunk(
id = uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
text = chunk_text,
word_count = self.chunk_size,
is_part_of = self.document,
chunk_index = self.chunk_index,
cut_type = self.paragraph_chunks[len(self.paragraph_chunks) - 1]["cut_type"],
)
except Exception as e:
print(e)
self.paragraph_chunks = [chunk_data]
self.chunk_size = chunk_data["word_count"]
self.chunk_index += 1
if len(self.paragraph_chunks) > 0:
yield DocumentChunk(
text = " ".join(chunk["text"] for chunk in self.paragraph_chunks),
word_count = self.chunk_size,
document_id = str(self.id),
chunk_id = str(uuid5(NAMESPACE_OID, f"{str(self.id)}-{self.chunk_index}")),
chunk_index = self.chunk_index,
cut_type = self.paragraph_chunks[len(self.paragraph_chunks) - 1]["cut_type"],
)
try:
yield DocumentChunk(
id = uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
text = " ".join(chunk["text"] for chunk in self.paragraph_chunks),
word_count = self.chunk_size,
is_part_of = self.document,
chunk_index = self.chunk_index,
cut_type = self.paragraph_chunks[len(self.paragraph_chunks) - 1]["cut_type"],
)
except Exception as e:
print(e)

View file

@ -1,2 +0,0 @@
from .models.DocumentChunk import DocumentChunk
from .TextChunker import TextChunker

View file

@ -1,9 +1,14 @@
from pydantic import BaseModel
from typing import Optional
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.processing.document_types import Document
class DocumentChunk(BaseModel):
class DocumentChunk(DataPoint):
text: str
word_count: int
document_id: str
chunk_id: str
chunk_index: int
cut_type: str
is_part_of: Document
_metadata: Optional[dict] = {
"index_fields": ["text"],
}

View file

@ -0,0 +1 @@
from .knowledge_graph.extract_content_graph import extract_content_graph

View file

@ -0,0 +1 @@
from .extract_content_graph import extract_content_graph

View file

@ -1,36 +1,36 @@
import logging
logger = logging.getLogger(__name__)
async def detect_language(data:str):
async def detect_language(text: str):
"""
Detect the language of the given text and return its ISO 639-1 language code.
If the detected language is Croatian ('hr'), it maps to Serbian ('sr').
If the detected language is Croatian ("hr"), it maps to Serbian ("sr").
The text is trimmed to the first 100 characters for efficient processing.
Parameters:
text (str): The text for language detection.
Returns:
str: The ISO 639-1 language code of the detected language, or 'None' in case of an error.
str: The ISO 639-1 language code of the detected language, or "None" in case of an error.
"""
# Trim the text to the first 100 characters
from langdetect import detect, LangDetectException
trimmed_text = data[:100]
# Trim the text to the first 100 characters
trimmed_text = text[:100]
try:
# Detect the language using langdetect
detected_lang_iso639_1 = detect(trimmed_text)
logging.info(f"Detected ISO 639-1 code: {detected_lang_iso639_1}")
# Special case: map 'hr' (Croatian) to 'sr' (Serbian ISO 639-2)
if detected_lang_iso639_1 == 'hr':
yield 'sr'
yield detected_lang_iso639_1
# Special case: map "hr" (Croatian) to "sr" (Serbian ISO 639-2)
if detected_lang_iso639_1 == "hr":
return "sr"
return detected_lang_iso639_1
except LangDetectException as e:
logging.error(f"Language detection error: {e}")
except Exception as e:
logging.error(f"Unexpected error: {e}")
logger.error(f"Language detection error: {e}")
yield None
except Exception as e:
logger.error(f"Unexpected error: {e}")
return None

View file

@ -0,0 +1,41 @@
import logging
logger = logging.getLogger(__name__)
async def translate_text(text, source_language: str = "sr", target_language: str = "en", region_name = "eu-west-1"):
    """Translate text between languages using AWS Translate.

    Parameters:
        text (str): The text to be translated.
        source_language (str): ISO 639-2 source language code (e.g., "sr" for Serbian). https://www.loc.gov/standards/iso639-2/php/code_list.php
        target_language (str): ISO 639-2 target language code (e.g., "en" for English). https://www.loc.gov/standards/iso639-2/php/code_list.php
        region_name (str): AWS region name for the Translate client.

    Yields:
        str | None: The translated text, or None when AWS reports an error.
    """
    import boto3
    from botocore.exceptions import BotoCoreError, ClientError

    if not text:
        raise ValueError("No text to translate.")

    if not source_language or not target_language:
        raise ValueError("Source and target language codes are required.")

    try:
        client = boto3.client(service_name = "translate", region_name = region_name, use_ssl = True)

        response = client.translate_text(
            Text = text,
            SourceLanguageCode = source_language,
            TargetLanguageCode = target_language,
        )

        yield response.get("TranslatedText", "No translation found.")
    except BotoCoreError as error:
        logger.error(f"BotoCoreError occurred: {error}")
        yield None
    except ClientError as error:
        logger.error(f"ClientError occurred: {error}")
        yield None

View file

@ -1,34 +1,15 @@
from uuid import UUID, uuid5, NAMESPACE_OID
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.modules.chunking.TextChunker import TextChunker
from .Document import Document
class AudioDocument(Document):
type: str = "audio"
title: str
raw_data_location: str
chunking_strategy: str
def __init__(self, id: UUID, title: str, raw_data_location: str, chunking_strategy:str="paragraph"):
self.id = id or uuid5(NAMESPACE_OID, title)
self.title = title
self.raw_data_location = raw_data_location
self.chunking_strategy = chunking_strategy
def read(self, chunk_size: int):
# Transcribe the audio file
result = get_llm_client().create_transcript(self.raw_data_location)
text = result.text
chunker = TextChunker(self.id, chunk_size = chunk_size, get_text = lambda: text)
chunker = TextChunker(self, chunk_size = chunk_size, get_text = lambda: text)
yield from chunker.read()
def to_dict(self) -> dict:
return dict(
id=str(self.id),
type=self.type,
title=self.title,
raw_data_location=self.raw_data_location,
)

View file

@ -1,10 +1,8 @@
from uuid import UUID
from typing import Protocol
from cognee.infrastructure.engine import DataPoint
class Document(Protocol):
id: UUID
class Document(DataPoint):
type: str
title: str
name: str
raw_data_location: str
def read(self, chunk_size: int) -> str:

View file

@ -1,33 +1,15 @@
from uuid import UUID, uuid5, NAMESPACE_OID
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.modules.chunking.TextChunker import TextChunker
from .Document import Document
class ImageDocument(Document):
type: str = "image"
title: str
raw_data_location: str
def __init__(self, id: UUID, title: str, raw_data_location: str):
self.id = id or uuid5(NAMESPACE_OID, title)
self.title = title
self.raw_data_location = raw_data_location
def read(self, chunk_size: int):
# Transcribe the image file
result = get_llm_client().transcribe_image(self.raw_data_location)
text = result.choices[0].message.content
chunker = TextChunker(self.id, chunk_size = chunk_size, get_text = lambda: text)
chunker = TextChunker(self, chunk_size = chunk_size, get_text = lambda: text)
yield from chunker.read()
def to_dict(self) -> dict:
return dict(
id=str(self.id),
type=self.type,
title=self.title,
raw_data_location=self.raw_data_location,
)

View file

@ -1,19 +1,11 @@
from uuid import UUID, uuid5, NAMESPACE_OID
from pypdf import PdfReader
from cognee.modules.chunking.TextChunker import TextChunker
from .Document import Document
class PdfDocument(Document):
type: str = "pdf"
title: str
raw_data_location: str
def __init__(self, id: UUID, title: str, raw_data_location: str):
self.id = id or uuid5(NAMESPACE_OID, title)
self.title = title
self.raw_data_location = raw_data_location
def read(self, chunk_size: int) -> PdfReader:
def read(self, chunk_size: int):
file = PdfReader(self.raw_data_location)
def get_text():
@ -21,16 +13,8 @@ class PdfDocument(Document):
page_text = page.extract_text()
yield page_text
chunker = TextChunker(self.id, chunk_size = chunk_size, get_text = get_text)
chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text)
yield from chunker.read()
file.stream.close()
def to_dict(self) -> dict:
return dict(
id = str(self.id),
type = self.type,
title = self.title,
raw_data_location = self.raw_data_location,
)

View file

@ -1,16 +1,8 @@
from uuid import UUID, uuid5, NAMESPACE_OID
from cognee.modules.chunking.TextChunker import TextChunker
from .Document import Document
class TextDocument(Document):
type: str = "text"
title: str
raw_data_location: str
def __init__(self, id: UUID, title: str, raw_data_location: str):
self.id = id or uuid5(NAMESPACE_OID, title)
self.title = title
self.raw_data_location = raw_data_location
def read(self, chunk_size: int):
def get_text():
@ -23,16 +15,6 @@ class TextDocument(Document):
yield text
chunker = TextChunker(self.id,chunk_size = chunk_size, get_text = get_text)
chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text)
yield from chunker.read()
def to_dict(self) -> dict:
return dict(
id = str(self.id),
type = self.type,
title = self.title,
raw_data_location = self.raw_data_location,
)

View file

@ -1,3 +1,4 @@
from .Document import Document
from .PdfDocument import PdfDocument
from .TextDocument import TextDocument
from .ImageDocument import ImageDocument

View file

@ -0,0 +1,12 @@
from cognee.infrastructure.engine import DataPoint
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from .EntityType import EntityType
class Entity(DataPoint):
name: str
is_a: EntityType
description: str
mentioned_in: DocumentChunk
_metadata: dict = {
"index_fields": ["name"],
}

View file

@ -0,0 +1,11 @@
from cognee.infrastructure.engine import DataPoint
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
class EntityType(DataPoint):
name: str
type: str
description: str
exists_in: DocumentChunk
_metadata: dict = {
"index_fields": ["name"],
}

View file

@ -0,0 +1,2 @@
from .Entity import Entity
from .EntityType import EntityType

View file

@ -0,0 +1,3 @@
from .generate_node_id import generate_node_id
from .generate_node_name import generate_node_name
from .generate_edge_name import generate_edge_name

View file

@ -0,0 +1,2 @@
def generate_edge_name(name: str) -> str:
    """Normalize an edge label: lowercase, spaces become underscores, apostrophes removed."""
    normalized = name.lower()
    return normalized.replace(" ", "_").replace("'", "")

View file

@ -0,0 +1,4 @@
from uuid import NAMESPACE_OID, UUID, uuid5


def generate_node_id(node_id: str) -> UUID:
    """Derive a deterministic UUID for a node from its normalized string id.

    The id is lowercased, spaces become underscores and apostrophes are
    stripped before hashing, so equivalent names map to the same UUID.
    """
    # uuid5 returns a UUID object, not a str — the annotation reflects that.
    return uuid5(NAMESPACE_OID, node_id.lower().replace(" ", "_").replace("'", ""))

View file

@ -0,0 +1,2 @@
def generate_node_name(name: str) -> str:
    """Normalize a node display name: strip apostrophes and lowercase (spaces are kept)."""
    return name.replace("'", "").lower()

View file

@ -1,5 +0,0 @@
def generate_node_name(name: str) -> str:
return name.lower().replace(" ", "_").replace("'", "")
def generate_node_id(node_id: str) -> str:
return node_id.lower().replace(" ", "_").replace("'", "")

View file

@ -0,0 +1,2 @@
from .get_graph_from_model import get_graph_from_model
from .get_model_instance_from_graph import get_model_instance_from_graph

View file

@ -0,0 +1,107 @@
from datetime import datetime, timezone
from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import copy_model
def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = None, added_edges = None):
    """Recursively flatten a DataPoint model into graph nodes and edges.

    Parameters:
        data_point: The model instance to flatten.
        include_root: When True, append a simplified copy of ``data_point``
            itself (nested-model fields stripped) to the returned nodes.
        added_nodes: Dedup registry of node ids already emitted, shared
            across the recursion. Defaults to a fresh dict per call.
        added_edges: Dedup registry of edge keys already emitted.

    Returns:
        tuple[list, list]: Collected nodes and edge tuples of the form
        ``(source_id, target_id, relationship_name, properties)``.
    """
    # Fresh registries per top-level call. Mutable default arguments ({})
    # are evaluated once at definition time and would leak dedup state
    # between independent invocations, silently dropping nodes/edges.
    if added_nodes is None:
        added_nodes = {}
    if added_edges is None:
        added_edges = {}

    nodes = []
    edges = []

    data_point_properties = {}
    excluded_properties = set()

    for field_name, field_value in data_point:
        if field_name == "_metadata":
            continue

        if isinstance(field_value, DataPoint):
            # Single nested model: recurse, then link its root(s) to this node.
            excluded_properties.add(field_name)

            property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges)

            for node in property_nodes:
                if str(node.id) not in added_nodes:
                    nodes.append(node)
                    added_nodes[str(node.id)] = True

            for edge in property_edges:
                edge_key = str(edge[0]) + str(edge[1]) + edge[2]

                if edge_key not in added_edges:
                    edges.append(edge)
                    added_edges[edge_key] = True

            for property_node in get_own_properties(property_nodes, property_edges):
                edge_key = str(data_point.id) + str(property_node.id) + field_name

                if edge_key not in added_edges:
                    edges.append((data_point.id, property_node.id, field_name, {
                        "source_node_id": data_point.id,
                        "target_node_id": property_node.id,
                        "relationship_name": field_name,
                        "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
                    }))
                    added_edges[edge_key] = True

            continue

        if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
            # List of nested models: recurse per item; edges are tagged with
            # metadata type "list" so the model can be reconstructed later.
            excluded_properties.add(field_name)

            for item in field_value:
                property_nodes, property_edges = get_graph_from_model(item, True, added_nodes, added_edges)

                for node in property_nodes:
                    if str(node.id) not in added_nodes:
                        nodes.append(node)
                        added_nodes[str(node.id)] = True

                for edge in property_edges:
                    edge_key = str(edge[0]) + str(edge[1]) + edge[2]

                    if edge_key not in added_edges:
                        edges.append(edge)
                        added_edges[edge_key] = True

                for property_node in get_own_properties(property_nodes, property_edges):
                    edge_key = str(data_point.id) + str(property_node.id) + field_name

                    if edge_key not in added_edges:
                        edges.append((data_point.id, property_node.id, field_name, {
                            "source_node_id": data_point.id,
                            "target_node_id": property_node.id,
                            "relationship_name": field_name,
                            "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
                            "metadata": {
                                "type": "list"
                            },
                        }))
                        added_edges[edge_key] = True

            continue

        # Plain scalar/dict field — kept as a property of the root node.
        data_point_properties[field_name] = field_value

    SimpleDataPointModel = copy_model(
        type(data_point),
        include_fields = {
            "_metadata": (dict, data_point._metadata),
        },
        exclude_fields = excluded_properties,
    )

    if include_root:
        nodes.append(SimpleDataPointModel(**data_point_properties))

    return nodes, edges
def get_own_properties(property_nodes, property_edges):
    """Return the nodes that are not the target of any edge — i.e. the
    roots among ``property_nodes`` with respect to ``property_edges``."""
    targeted_ids = [str(edge[1]) for edge in property_edges]

    return [
        node for node in property_nodes
        if str(node.id) not in targeted_ids
    ]

View file

@ -0,0 +1,29 @@
from pydantic_core import PydanticUndefined
from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import copy_model
def get_model_instance_from_graph(nodes: list[DataPoint], edges: list, entity_id: str):
    """Rebuild a nested model instance from flattened graph nodes and edges.

    Each edge folds its target node into its source node under the edge
    label, producing a progressively richer model class; the node registered
    under ``entity_id`` is returned at the end.
    """
    node_map = { node.id: node for node in nodes }

    for edge in edges:
        source_node = node_map[edge[0]]
        target_node = node_map[edge[1]]
        edge_label = edge[2]
        edge_properties = edge[3] if len(edge) == 4 else {}
        edge_metadata = edge_properties.get("metadata", {})

        if edge_metadata.get("type") == "list":
            # List-typed relation: the target is wrapped into a list field.
            NewModel = copy_model(type(source_node), { edge_label: (list[type(target_node)], PydanticUndefined) })
            node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: [target_node] })
        else:
            NewModel = copy_model(type(source_node), { edge_label: (type(target_node), PydanticUndefined) })
            node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: target_node })

    return node_map[entity_id]

View file

@ -7,7 +7,7 @@ from ..tasks.Task import Task
logger = logging.getLogger("run_tasks(tasks: [Task], data)")
async def run_tasks_base(tasks: [Task], data = None, user: User = None):
async def run_tasks_base(tasks: list[Task], data = None, user: User = None):
if len(tasks) == 0:
yield data
return
@ -16,7 +16,7 @@ async def run_tasks_base(tasks: [Task], data = None, user: User = None):
running_task = tasks[0]
leftover_tasks = tasks[1:]
next_task = leftover_tasks[0] if len(leftover_tasks) > 1 else None
next_task = leftover_tasks[0] if len(leftover_tasks) > 0 else None
next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1
if inspect.isasyncgenfunction(running_task.executable):

View file

@ -1,33 +0,0 @@
import asyncio
import nest_asyncio
import dspy
from cognee.modules.search.vector.search_similarity import search_similarity
nest_asyncio.apply()
class AnswerFromContext(dspy.Signature):
question: str = dspy.InputField()
context: str = dspy.InputField(desc = "Context to use for answer generation.")
answer: str = dspy.OutputField()
question_answer_llm = dspy.OpenAI(model = "gpt-3.5-turbo-instruct")
class CogneeSearch(dspy.Module):
def __init__(self, ):
super().__init__()
self.generate_answer = dspy.TypedChainOfThought(AnswerFromContext)
def forward(self, question):
context = asyncio.run(search_similarity(question))
context_text = "\n".join(context)
print(f"Context: {context_text}")
with dspy.context(lm = question_answer_llm):
answer_prediction = self.generate_answer(context = context_text, question = question)
answer = answer_prediction.answer
print(f"Question: {question}")
print(f"Answer: {answer}")
return dspy.Prediction(context = context_text, answer = answer)

View file

@ -1,43 +0,0 @@
import asyncio
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
async def search_adjacent(query: str) -> list[(str, str)]:
"""
Find the neighbours of a given node in the graph and return their ids and descriptions.
Parameters:
- query (str): The query string to filter nodes by.
Returns:
- list[(str, str)]: A list containing the unique identifiers and names of the neighbours of the given node.
"""
node_id = query
if node_id is None:
return {}
graph_engine = await get_graph_engine()
exact_node = await graph_engine.extract_node(node_id)
if exact_node is not None and "uuid" in exact_node:
neighbours = await graph_engine.get_neighbours(exact_node["uuid"])
else:
vector_engine = get_vector_engine()
results = await asyncio.gather(
vector_engine.search("entities", query_text = query, limit = 10),
vector_engine.search("classification", query_text = query, limit = 10),
)
results = [*results[0], *results[1]]
relevant_results = [result for result in results if result.score < 0.5][:5]
if len(relevant_results) == 0:
return []
node_neighbours = await asyncio.gather(*[graph_engine.get_neighbours(result.id) for result in relevant_results])
neighbours = []
for neighbour_ids in node_neighbours:
neighbours.extend(neighbour_ids)
return neighbours

View file

@ -1,15 +0,0 @@
from cognee.infrastructure.databases.graph import get_graph_engine, get_graph_config
async def search_cypher(query: str):
"""
Use a Cypher query to search the graph and return the results.
"""
graph_config = get_graph_config()
if graph_config.graph_database_provider == "neo4j":
graph_engine = await get_graph_engine()
result = await graph_engine.graph().run(query)
return result
else:
raise ValueError("Unsupported search type for the used graph engine.")

View file

@ -1,27 +0,0 @@
from cognee.infrastructure.databases.vector import get_vector_engine
async def search_similarity(query: str) -> list[str, str]:
"""
Parameters:
- query (str): The query string to filter nodes by.
Returns:
- list(chunk): A list of objects providing information about the chunks related to query.
"""
vector_engine = get_vector_engine()
similar_results = await vector_engine.search("chunks", query, limit = 5)
results = [
parse_payload(result.payload) for result in similar_results
]
return results
def parse_payload(payload: dict) -> dict:
return {
"text": payload["text"],
"chunk_id": payload["chunk_id"],
"document_id": payload["document_id"],
}

View file

@ -1,17 +0,0 @@
from cognee.infrastructure.databases.vector import get_vector_engine
async def search_summary(query: str) -> list:
"""
Parameters:
- query (str): The query string to filter summaries by.
Returns:
- list[str, UUID]: A list of objects providing information about the summaries related to query.
"""
vector_engine = get_vector_engine()
summaries_results = await vector_engine.search("summaries", query, limit = 5)
summaries = [summary.payload for summary in summaries_results]
return summaries

View file

@ -1,16 +0,0 @@
from typing import Type
from pydantic import BaseModel
from cognee.infrastructure.llm.prompts import render_prompt
from cognee.infrastructure.llm.get_llm_client import get_llm_client
async def categorize_relevant_category(query: str, summary, response_model: Type[BaseModel]):
llm_client = get_llm_client()
enriched_query= render_prompt("categorize_categories.txt", {"query": query, "categories": summary})
system_prompt = " Choose the relevant categories and return appropriate output based on the model"
llm_output = await llm_client.acreate_structured_output(enriched_query, system_prompt, response_model)
return llm_output.model_dump()

View file

@ -1,15 +0,0 @@
from typing import Type
from pydantic import BaseModel
from cognee.infrastructure.llm.prompts import render_prompt
from cognee.infrastructure.llm.get_llm_client import get_llm_client
async def categorize_relevant_summary(query: str, summaries, response_model: Type[BaseModel]):
llm_client = get_llm_client()
enriched_query= render_prompt("categorize_summary.txt", {"query": query, "summaries": summaries})
system_prompt = "Choose the relevant summaries and return appropriate output based on the model"
llm_output = await llm_client.acreate_structured_output(enriched_query, system_prompt, response_model)
return llm_output

View file

@ -1,17 +0,0 @@
import logging
from typing import List, Dict
from cognee.modules.cognify.config import get_cognify_config
from .extraction.categorize_relevant_summary import categorize_relevant_summary
logger = logging.getLogger(__name__)
async def get_cognitive_layers(content: str, categories: List[Dict]):
try:
cognify_config = get_cognify_config()
return (await categorize_relevant_summary(
content,
categories[0],
cognify_config.summarization_model,
)).cognitive_layers
except Exception as error:
logger.error("Error extracting cognitive layers from content: %s", error, exc_info = True)
raise error

View file

@ -1 +0,0 @@
""" Placeholder for BM25 implementation"""

View file

@ -1 +0,0 @@
"""Placeholder for fusions search implementation"""

View file

@ -1,36 +0,0 @@
import asyncio
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
async def search_traverse(query: str):
node_id = query
rules = set()
graph_engine = await get_graph_engine()
vector_engine = get_vector_engine()
exact_node = await graph_engine.extract_node(node_id)
if exact_node is not None and "uuid" in exact_node:
edges = await graph_engine.get_edges(exact_node["uuid"])
for edge in edges:
rules.add(f"{edge[0]} {edge[2]['relationship_name']} {edge[1]}")
else:
results = await asyncio.gather(
vector_engine.search("entities", query_text = query, limit = 10),
vector_engine.search("classification", query_text = query, limit = 10),
)
results = [*results[0], *results[1]]
relevant_results = [result for result in results if result.score < 0.5][:5]
if len(relevant_results) > 0:
for result in relevant_results:
graph_node_id = result.id
edges = await graph_engine.get_edges(graph_node_id)
for edge in edges:
rules.add(f"{edge[0]} {edge[2]['relationship_name']} {edge[1]}")
return list(rules)

View file

@ -0,0 +1,46 @@
import json
from uuid import UUID
from datetime import datetime
from pydantic_core import PydanticUndefined
from cognee.infrastructure.engine import DataPoint
class JSONEncoder(json.JSONEncoder):
    """JSON encoder that also serializes datetime and UUID values."""

    def default(self, obj):
        # datetime -> ISO 8601 string; UUID -> canonical string form.
        if isinstance(obj, datetime):
            return obj.isoformat()
        if isinstance(obj, UUID):
            return str(obj)
        # Anything else falls through to the stdlib handler (raises TypeError).
        return json.JSONEncoder.default(self, obj)
from pydantic import create_model
def copy_model(model: DataPoint, include_fields: dict = {}, exclude_fields: list = []):
    """Build a new pydantic model class with a copy of ``model``'s fields.

    Parameters:
        model: The model class to copy fields from.
        include_fields: Extra fields to add, mapping name -> (annotation, default).
        exclude_fields: Field names to leave out of the copy.

    Returns:
        A freshly created pydantic model class named like the original.
    """
    fields = {
        # In pydantic v2, field.default is already PydanticUndefined for
        # required fields, so pass it through unchanged. Coercing None to
        # PydanticUndefined (as before) wrongly turned fields with an
        # explicit ``None`` default into required fields.
        name: (field.annotation, field.default)
        for name, field in model.model_fields.items()
        if name not in exclude_fields
    }

    final_fields = {
        **fields,
        **include_fields
    }

    return create_model(model.__name__, **final_fields)
def get_own_properties(data_point: DataPoint):
    """Collect the plain ("own") properties of a DataPoint.

    Skips metadata, dicts, nested DataPoints and non-empty lists of
    DataPoints — those are represented as separate nodes/edges rather
    than inline properties.
    """
    properties = {}

    for field_name, field_value in data_point:
        # The len() guard prevents an IndexError: the previous version
        # indexed field_value[0] unconditionally, crashing on empty lists.
        if field_name == "_metadata" \
            or isinstance(field_value, dict) \
            or isinstance(field_value, DataPoint) \
            or (isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint)):
            continue

        properties[field_name] = field_value

    return properties

View file

@ -1,84 +1,95 @@
from typing import List, Union, Literal, Optional
from pydantic import BaseModel
from typing import Any, List, Union, Literal, Optional
from cognee.infrastructure.engine import DataPoint
class BaseClass(BaseModel):
id: str
name: str
type: Literal["Class"] = "Class"
description: str
constructor_parameters: Optional[List[str]] = None
class Class(BaseModel):
id: str
name: str
type: Literal["Class"] = "Class"
description: str
constructor_parameters: Optional[List[str]] = None
from_class: Optional[BaseClass] = None
class ClassInstance(BaseModel):
id: str
name: str
type: Literal["ClassInstance"] = "ClassInstance"
description: str
from_class: Class
class Function(BaseModel):
id: str
name: str
type: Literal["Function"] = "Function"
description: str
parameters: Optional[List[str]] = None
return_type: str
is_static: Optional[bool] = False
class Variable(BaseModel):
class Variable(DataPoint):
id: str
name: str
type: Literal["Variable"] = "Variable"
description: str
is_static: Optional[bool] = False
default_value: Optional[str] = None
data_type: str
class Operator(BaseModel):
_metadata = {
"index_fields": ["name"]
}
class Operator(DataPoint):
id: str
name: str
type: Literal["Operator"] = "Operator"
description: str
return_type: str
class ExpressionPart(BaseModel):
class Class(DataPoint):
id: str
name: str
type: Literal["Class"] = "Class"
description: str
constructor_parameters: List[Variable]
extended_from_class: Optional["Class"] = None
has_methods: list["Function"]
_metadata = {
"index_fields": ["name"]
}
class ClassInstance(DataPoint):
id: str
name: str
type: Literal["ClassInstance"] = "ClassInstance"
description: str
from_class: Class
instantiated_by: Union["Function"]
instantiation_arguments: List[Variable]
_metadata = {
"index_fields": ["name"]
}
class Function(DataPoint):
id: str
name: str
type: Literal["Function"] = "Function"
description: str
parameters: List[Variable]
return_type: str
is_static: Optional[bool] = False
_metadata = {
"index_fields": ["name"]
}
class FunctionCall(DataPoint):
id: str
type: Literal["FunctionCall"] = "FunctionCall"
called_by: Union[Function, Literal["main"]]
function_called: Function
function_arguments: List[Any]
class Expression(DataPoint):
id: str
name: str
type: Literal["Expression"] = "Expression"
description: str
expression: str
members: List[Union[Variable, Function, Operator]]
members: List[Union[Variable, Function, Operator, "Expression"]]
class Expression(BaseModel):
id: str
name: str
type: Literal["Expression"] = "Expression"
description: str
expression: str
members: List[Union[Variable, Function, Operator, ExpressionPart]]
class Edge(BaseModel):
source_node_id: str
target_node_id: str
relationship_name: Literal["called in", "stored in", "defined in", "returned by", "instantiated in", "uses", "updates"]
class SourceCodeGraph(BaseModel):
class SourceCodeGraph(DataPoint):
id: str
name: str
description: str
language: str
nodes: List[Union[
Class,
ClassInstance,
Function,
FunctionCall,
Variable,
Operator,
Expression,
ClassInstance,
]]
edges: List[Edge]
Class.model_rebuild()
ClassInstance.model_rebuild()
Expression.model_rebuild()

View file

@ -1,6 +1,6 @@
""" This module contains utility functions for the cognee. """
import os
import datetime
from datetime import datetime, timezone
import graphistry
import networkx as nx
import numpy as np
@ -45,7 +45,7 @@ def send_telemetry(event_name: str, user_id, additional_properties: dict = {}):
host = "https://eu.i.posthog.com"
)
current_time = datetime.datetime.now()
current_time = datetime.now(timezone.utc)
properties = {
"time": current_time.strftime("%m/%d/%Y"),
"user_id": user_id,
@ -110,30 +110,36 @@ async def register_graphistry():
graphistry.register(api = 3, username = config.graphistry_username, password = config.graphistry_password)
def prepare_edges(graph):
return nx.to_pandas_edgelist(graph)
def prepare_edges(graph, source, target, edge_key):
edge_list = [{
source: str(edge[0]),
target: str(edge[1]),
edge_key: str(edge[2]),
} for edge in graph.edges(keys = True, data = True)]
return pd.DataFrame(edge_list)
def prepare_nodes(graph, include_size=False):
nodes_data = []
for node in graph.nodes:
node_info = graph.nodes[node]
description = node_info.get("layer_description", {}).get("layer", "Default Layer") if isinstance(
node_info.get("layer_description"), dict) else node_info.get("layer_description", "Default Layer")
# description = node_info['layer_description']['layer'] if isinstance(node_info.get('layer_description'), dict) and 'layer' in node_info['layer_description'] else node_info.get('layer_description', node)
# if isinstance(node_info.get('layer_description'), dict) and 'layer' in node_info.get('layer_description'):
# description = node_info['layer_description']['layer']
# # Use 'layer_description' directly if it's not a dictionary, otherwise default to node ID
# else:
# description = node_info.get('layer_description', node)
node_data = {"id": node, "layer_description": description}
if not node_info:
continue
node_data = {
"id": str(node),
"name": node_info["name"] if "name" in node_info else str(node),
}
if include_size:
default_size = 10 # Default node size
larger_size = 20 # Size for nodes with specific keywords in their ID
keywords = ["DOCUMENT", "User", "LAYER"]
keywords = ["DOCUMENT", "User"]
node_size = larger_size if any(keyword in str(node) for keyword in keywords) else default_size
node_data["size"] = node_size
nodes_data.append(node_data)
return pd.DataFrame(nodes_data)
@ -153,28 +159,28 @@ async def render_graph(graph, include_nodes=False, include_color=False, include_
graph = networkx_graph
edges = prepare_edges(graph)
plotter = graphistry.edges(edges, "source", "target")
edges = prepare_edges(graph, "source_node", "target_node", "relationship_name")
plotter = graphistry.edges(edges, "source_node", "target_node")
plotter = plotter.bind(edge_label = "relationship_name")
if include_nodes:
nodes = prepare_nodes(graph, include_size=include_size)
nodes = prepare_nodes(graph, include_size = include_size)
plotter = plotter.nodes(nodes, "id")
if include_size:
plotter = plotter.bind(point_size="size")
plotter = plotter.bind(point_size = "size")
if include_color:
unique_layers = nodes["layer_description"].unique()
color_palette = generate_color_palette(unique_layers)
plotter = plotter.encode_point_color("layer_description", categorical_mapping=color_palette,
default_mapping="silver")
pass
# unique_layers = nodes["layer_description"].unique()
# color_palette = generate_color_palette(unique_layers)
# plotter = plotter.encode_point_color("layer_description", categorical_mapping=color_palette,
# default_mapping="silver")
if include_labels:
plotter = plotter.bind(point_label = "layer_description")
plotter = plotter.bind(point_label = "name")
# Visualization

View file

@ -1,10 +0,0 @@
from .summarization.summarize_text import summarize_text
from .chunk_naive_llm_classifier.chunk_naive_llm_classifier import chunk_naive_llm_classifier
from .chunk_remove_disconnected.chunk_remove_disconnected import chunk_remove_disconnected
from .chunk_update_check.chunk_update_check import chunk_update_check
from .save_chunks_to_store.save_chunks_to_store import save_chunks_to_store
from .source_documents_to_chunks.source_documents_to_chunks import source_documents_to_chunks
from .infer_data_ontology.infer_data_ontology import infer_data_ontology
from .check_permissions_on_documents.check_permissions_on_documents import check_permissions_on_documents
from .classify_documents.classify_documents import classify_documents
from .graph.chunks_into_graph import chunks_into_graph

View file

@ -5,7 +5,7 @@ from pydantic import BaseModel
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine, DataPoint
from cognee.modules.data.extraction.extract_categories import extract_categories
from cognee.modules.chunking import DocumentChunk
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
async def chunk_naive_llm_classifier(data_chunks: list[DocumentChunk], classification_model: Type[BaseModel]):
@ -65,7 +65,7 @@ async def chunk_naive_llm_classifier(data_chunks: list[DocumentChunk], classific
"chunk_id": str(data_chunk.chunk_id),
"document_id": str(data_chunk.document_id),
}),
embed_field="text",
index_fields=["text"],
)
)
@ -104,7 +104,7 @@ async def chunk_naive_llm_classifier(data_chunks: list[DocumentChunk], classific
"chunk_id": str(data_chunk.chunk_id),
"document_id": str(data_chunk.document_id),
}),
embed_field="text",
index_fields=["text"],
)
)

View file

@ -1,39 +0,0 @@
import logging
from cognee.base_config import get_base_config
BaseConfig = get_base_config()
async def translate_text(data, source_language:str='sr', target_language:str='en', region_name='eu-west-1'):
"""
Translate text from source language to target language using AWS Translate.
Parameters:
data (str): The text to be translated.
source_language (str): The source language code (e.g., 'sr' for Serbian). ISO 639-2 Code https://www.loc.gov/standards/iso639-2/php/code_list.php
target_language (str): The target language code (e.g., 'en' for English). ISO 639-2 Code https://www.loc.gov/standards/iso639-2/php/code_list.php
region_name (str): AWS region name.
Returns:
str: Translated text or an error message.
"""
import boto3
from botocore.exceptions import BotoCoreError, ClientError
if not data:
yield "No text provided for translation."
if not source_language or not target_language:
yield "Both source and target language codes are required."
try:
translate = boto3.client(service_name='translate', region_name=region_name, use_ssl=True)
result = translate.translate_text(Text=data, SourceLanguageCode=source_language, TargetLanguageCode=target_language)
yield result.get('TranslatedText', 'No translation found.')
except BotoCoreError as e:
logging.info(f"BotoCoreError occurred: {e}")
yield "Error with AWS Translate service configuration or request."
except ClientError as e:
logging.info(f"ClientError occurred: {e}")
yield "Error with AWS client or network issue."

View file

@ -1,26 +0,0 @@
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.chunking import DocumentChunk
async def chunk_update_check(data_chunks: list[DocumentChunk], collection_name: str) -> list[DocumentChunk]:
vector_engine = get_vector_engine()
if not await vector_engine.has_collection(collection_name):
# If collection doesn't exist, all data_chunks are new
return data_chunks
existing_chunks = await vector_engine.retrieve(
collection_name,
[str(chunk.chunk_id) for chunk in data_chunks],
)
existing_chunks_map = {str(chunk.id): chunk.payload for chunk in existing_chunks}
affected_data_chunks = []
for chunk in data_chunks:
if chunk.chunk_id not in existing_chunks_map or \
chunk.text != existing_chunks_map[chunk.chunk_id]["text"]:
affected_data_chunks.append(chunk)
return affected_data_chunks

View file

@ -2,3 +2,4 @@ from .query_chunks import query_chunks
from .chunk_by_word import chunk_by_word
from .chunk_by_sentence import chunk_by_sentence
from .chunk_by_paragraph import chunk_by_paragraph
from .remove_disconnected_chunks import remove_disconnected_chunks

View file

@ -1,4 +1,4 @@
from cognee.tasks.chunking import chunk_by_paragraph
from cognee.tasks.chunks import chunk_by_paragraph
if __name__ == "__main__":
def test_chunking_on_whole_text():

View file

@ -10,7 +10,7 @@ async def query_chunks(query: str) -> list[dict]:
"""
vector_engine = get_vector_engine()
found_chunks = await vector_engine.search("chunks", query, limit = 5)
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit = 5)
chunks = [result.payload for result in found_chunks]

View file

@ -1,7 +1,7 @@
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.chunking import DocumentChunk
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
async def chunk_remove_disconnected(data_chunks: list[DocumentChunk]) -> list[DocumentChunk]:
async def remove_disconnected_chunks(data_chunks: list[DocumentChunk]) -> list[DocumentChunk]:
graph_engine = await get_graph_engine()
document_ids = set((data_chunk.document_id for data_chunk in data_chunks))

View file

@ -1,13 +0,0 @@
from cognee.modules.data.models import Data
from cognee.modules.data.processing.document_types import Document, PdfDocument, AudioDocument, ImageDocument, TextDocument
def classify_documents(data_documents: list[Data]) -> list[Document]:
documents = [
PdfDocument(id = data_item.id, title=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location) if data_item.extension == "pdf" else
AudioDocument(id = data_item.id, title=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location) if data_item.extension == "audio" else
ImageDocument(id = data_item.id, title=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location) if data_item.extension == "image" else
TextDocument(id = data_item.id, title=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location)
for data_item in data_documents
]
return documents

View file

@ -0,0 +1,3 @@
from .classify_documents import classify_documents
from .extract_chunks_from_documents import extract_chunks_from_documents
from .check_permissions_on_documents import check_permissions_on_documents

View file

@ -0,0 +1,13 @@
from cognee.modules.data.models import Data
from cognee.modules.data.processing.document_types import Document, PdfDocument, AudioDocument, ImageDocument, TextDocument
def classify_documents(data_documents: list[Data]) -> list[Document]:
documents = [
PdfDocument(id = data_item.id, name=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location) if data_item.extension == "pdf" else
AudioDocument(id = data_item.id, name=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location) if data_item.extension == "audio" else
ImageDocument(id = data_item.id, name=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location) if data_item.extension == "image" else
TextDocument(id = data_item.id, name=f"{data_item.name}.{data_item.extension}", raw_data_location=data_item.raw_data_location)
for data_item in data_documents
]
return documents

View file

@ -0,0 +1,7 @@
from cognee.modules.data.processing.document_types.Document import Document
async def extract_chunks_from_documents(documents: list[Document], chunk_size: int = 1024):
for document in documents:
for document_chunk in document.read(chunk_size = chunk_size):
yield document_chunk

View file

@ -1,2 +1,3 @@
from .chunks_into_graph import chunks_into_graph
from .extract_graph_from_data import extract_graph_from_data
from .extract_graph_from_code import extract_graph_from_code
from .query_graph_connections import query_graph_connections

View file

@ -1,213 +0,0 @@
import json
import asyncio
from uuid import uuid5, NAMESPACE_OID
from datetime import datetime, timezone
from typing import Type
from pydantic import BaseModel
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import DataPoint, get_vector_engine
from cognee.modules.data.extraction.knowledge_graph.extract_content_graph import extract_content_graph
from cognee.modules.chunking import DocumentChunk
from cognee.modules.graph.utils import generate_node_id, generate_node_name
class EntityNode(BaseModel):
uuid: str
name: str
type: str
description: str
created_at: datetime
updated_at: datetime
async def chunks_into_graph(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel], collection_name: str):
chunk_graphs = await asyncio.gather(
*[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
)
vector_engine = get_vector_engine()
graph_engine = await get_graph_engine()
has_collection = await vector_engine.has_collection(collection_name)
if not has_collection:
await vector_engine.create_collection(collection_name, payload_schema = EntityNode)
processed_nodes = {}
type_node_edges = []
entity_node_edges = []
type_entity_edges = []
for (chunk_index, chunk) in enumerate(data_chunks):
chunk_graph = chunk_graphs[chunk_index]
for node in chunk_graph.nodes:
type_node_id = generate_node_id(node.type)
entity_node_id = generate_node_id(node.id)
if type_node_id not in processed_nodes:
type_node_edges.append((str(chunk.chunk_id), type_node_id, "contains_entity_type"))
processed_nodes[type_node_id] = True
if entity_node_id not in processed_nodes:
entity_node_edges.append((str(chunk.chunk_id), entity_node_id, "contains_entity"))
type_entity_edges.append((entity_node_id, type_node_id, "is_entity_type"))
processed_nodes[entity_node_id] = True
graph_node_edges = [
(edge.target_node_id, edge.source_node_id, edge.relationship_name) \
for edge in chunk_graph.edges
]
existing_edges = await graph_engine.has_edges([
*type_node_edges,
*entity_node_edges,
*type_entity_edges,
*graph_node_edges,
])
existing_edges_map = {}
existing_nodes_map = {}
for edge in existing_edges:
existing_edges_map[edge[0] + edge[1] + edge[2]] = True
existing_nodes_map[edge[0]] = True
graph_nodes = []
graph_edges = []
data_points = []
for (chunk_index, chunk) in enumerate(data_chunks):
graph = chunk_graphs[chunk_index]
if graph is None:
continue
for node in graph.nodes:
node_id = generate_node_id(node.id)
node_name = generate_node_name(node.name)
type_node_id = generate_node_id(node.type)
type_node_name = generate_node_name(node.type)
if node_id not in existing_nodes_map:
node_data = dict(
uuid = node_id,
name = node_name,
type = node_name,
description = node.description,
created_at = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
updated_at = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
)
graph_nodes.append((
node_id,
dict(
**node_data,
properties = json.dumps(node.properties),
)
))
data_points.append(DataPoint[EntityNode](
id = str(uuid5(NAMESPACE_OID, node_id)),
payload = node_data,
embed_field = "name",
))
existing_nodes_map[node_id] = True
edge_key = str(chunk.chunk_id) + node_id + "contains_entity"
if edge_key not in existing_edges_map:
graph_edges.append((
str(chunk.chunk_id),
node_id,
"contains_entity",
dict(
relationship_name = "contains_entity",
source_node_id = str(chunk.chunk_id),
target_node_id = node_id,
),
))
# Add relationship between entity type and entity itself: "Jake is Person"
graph_edges.append((
node_id,
type_node_id,
"is_entity_type",
dict(
relationship_name = "is_entity_type",
source_node_id = type_node_id,
target_node_id = node_id,
),
))
existing_edges_map[edge_key] = True
if type_node_id not in existing_nodes_map:
type_node_data = dict(
uuid = type_node_id,
name = type_node_name,
type = type_node_id,
description = type_node_name,
created_at = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
updated_at = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
)
graph_nodes.append((type_node_id, dict(
**type_node_data,
properties = json.dumps(node.properties)
)))
data_points.append(DataPoint[EntityNode](
id = str(uuid5(NAMESPACE_OID, type_node_id)),
payload = type_node_data,
embed_field = "name",
))
existing_nodes_map[type_node_id] = True
edge_key = str(chunk.chunk_id) + type_node_id + "contains_entity_type"
if edge_key not in existing_edges_map:
graph_edges.append((
str(chunk.chunk_id),
type_node_id,
"contains_entity_type",
dict(
relationship_name = "contains_entity_type",
source_node_id = str(chunk.chunk_id),
target_node_id = type_node_id,
),
))
existing_edges_map[edge_key] = True
# Add relationship that came from graphs.
for edge in graph.edges:
source_node_id = generate_node_id(edge.source_node_id)
target_node_id = generate_node_id(edge.target_node_id)
relationship_name = generate_node_name(edge.relationship_name)
edge_key = source_node_id + target_node_id + relationship_name
if edge_key not in existing_edges_map:
graph_edges.append((
generate_node_id(edge.source_node_id),
generate_node_id(edge.target_node_id),
edge.relationship_name,
dict(
relationship_name = generate_node_name(edge.relationship_name),
source_node_id = generate_node_id(edge.source_node_id),
target_node_id = generate_node_id(edge.target_node_id),
properties = json.dumps(edge.properties),
),
))
existing_edges_map[edge_key] = True
if len(data_points) > 0:
await vector_engine.create_data_points(collection_name, data_points)
if len(graph_nodes) > 0:
await graph_engine.add_nodes(graph_nodes)
if len(graph_edges) > 0:
await graph_engine.add_edges(graph_edges)
return data_chunks

Some files were not shown because too many files have changed in this diff Show more