This commit is contained in:
Vasilije 2026-01-07 15:27:03 +00:00 committed by GitHub
commit e3950e6cfe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 349 additions and 0 deletions

View file

@ -1 +1,7 @@
from .get_formatted_graph_data import get_formatted_graph_data
from .delete_data_related_nodes import delete_data_related_nodes
from .delete_data_related_edges import delete_data_related_edges
from .delete_dataset_related_nodes import delete_dataset_related_nodes
from .delete_dataset_related_edges import delete_dataset_related_edges

View file

@ -0,0 +1,13 @@
from uuid import UUID
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from cognee.infrastructure.databases.relational import with_async_session
from cognee.modules.graph.models import Edge
@with_async_session
async def delete_data_related_edges(data_id: UUID, session: AsyncSession):
    """Delete every ``Edge`` row whose ``data_id`` equals the given id.

    The matching rows are loaded first, then removed with a single
    ``DELETE`` filtered on their primary keys. The ``DELETE`` is issued even
    when the lookup returns no rows (the id list is then empty).

    Args:
        data_id: Value compared against ``Edge.data_id``.
        session: Async session both statements run on.
            NOTE(review): presumably supplied by ``with_async_session`` when
            the caller omits it - confirm against the decorator.
    """
    lookup = select(Edge).where(Edge.data_id == data_id)
    found = await session.scalars(lookup)
    edge_ids = [row.id for row in found.all()]

    removal = delete(Edge).where(Edge.id.in_(edge_ids))
    await session.execute(removal)

View file

@ -0,0 +1,13 @@
from uuid import UUID
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from cognee.infrastructure.databases.relational import with_async_session
from cognee.modules.graph.models import Node
@with_async_session
async def delete_data_related_nodes(data_id: UUID, session: AsyncSession):
    """Delete every ``Node`` row whose ``data_id`` equals the given id.

    The matching rows are loaded first, then removed with a single
    ``DELETE`` filtered on their primary keys. The ``DELETE`` is issued even
    when the lookup returns no rows (the id list is then empty).

    Args:
        data_id: Value compared against ``Node.data_id``.
        session: Async session both statements run on.
            NOTE(review): presumably supplied by ``with_async_session`` when
            the caller omits it - confirm against the decorator.
    """
    lookup = select(Node).where(Node.data_id == data_id)
    found = await session.scalars(lookup)
    node_ids = [row.id for row in found.all()]

    removal = delete(Node).where(Node.id.in_(node_ids))
    await session.execute(removal)

View file

@ -0,0 +1,13 @@
from uuid import UUID
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from cognee.infrastructure.databases.relational import with_async_session
from cognee.modules.graph.models import Edge
@with_async_session
async def delete_dataset_related_edges(dataset_id: UUID, session: AsyncSession):
    """Delete every ``Edge`` row whose ``dataset_id`` equals the given id.

    The matching rows are loaded first, then removed with a single
    ``DELETE`` filtered on their primary keys. The ``DELETE`` is issued even
    when the lookup returns no rows (the id list is then empty).

    Args:
        dataset_id: Value compared against ``Edge.dataset_id``.
        session: Async session both statements run on.
            NOTE(review): presumably supplied by ``with_async_session`` when
            the caller omits it - confirm against the decorator.
    """
    lookup = select(Edge).where(Edge.dataset_id == dataset_id)
    found = await session.scalars(lookup)
    edge_ids = [row.id for row in found.all()]

    removal = delete(Edge).where(Edge.id.in_(edge_ids))
    await session.execute(removal)

View file

@ -0,0 +1,13 @@
from uuid import UUID
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from cognee.infrastructure.databases.relational import with_async_session
from cognee.modules.graph.models import Node
@with_async_session
async def delete_dataset_related_nodes(dataset_id: UUID, session: AsyncSession):
    """Delete every ``Node`` row whose ``dataset_id`` equals the given id.

    The matching rows are loaded first, then removed with a single
    ``DELETE`` filtered on their primary keys. The ``DELETE`` is issued even
    when the lookup returns no rows (the id list is then empty).

    Args:
        dataset_id: Value compared against ``Node.dataset_id``.
        session: Async session both statements run on.
            NOTE(review): presumably supplied by ``with_async_session`` when
            the caller omits it - confirm against the decorator.
    """
    lookup = select(Node).where(Node.dataset_id == dataset_id)
    found = await session.scalars(lookup)
    node_ids = [row.id for row in found.all()]

    removal = delete(Node).where(Node.id.in_(node_ids))
    await session.execute(removal)

View file

@ -0,0 +1,58 @@
from datetime import datetime, timezone
from sqlalchemy import (
# event,
DateTime,
JSON,
UUID,
Text,
)
# from sqlalchemy.schema import DDL
from sqlalchemy.orm import Mapped, mapped_column
from cognee.infrastructure.databases.relational import Base
class Edge(Base):
    """ORM model mapped to the ``edges`` table.

    Every column except ``label`` and ``attributes`` is declared NOT NULL.
    ``data_id`` and ``dataset_id`` each carry a single-column index.
    """

    __tablename__ = "edges"

    # Primary key.
    id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), primary_key=True)

    # Required UUID columns without an index.
    slug: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
    user_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)

    # Required UUID columns, each individually indexed.
    data_id: Mapped[UUID] = mapped_column(
        UUID(as_uuid=True),
        nullable=False,
        index=True,
    )
    dataset_id: Mapped[UUID] = mapped_column(
        UUID(as_uuid=True),
        nullable=False,
        index=True,
    )

    # Endpoints of the edge; plain UUID columns (no foreign key is declared).
    source_node_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
    destination_node_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)

    relationship_name: Mapped[str] = mapped_column(Text, nullable=False)

    # Optional free-form columns.
    label: Mapped[str | None] = mapped_column(Text)
    attributes: Mapped[dict | None] = mapped_column(JSON)

    # Timezone-aware timestamp; defaults to the current UTC time on the Python side.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(tz=timezone.utc),
    )

    # Per-user partitioning is currently disabled:
    # __table_args__ = (
    #     {"postgresql_partition_by": "HASH (user_id)"},  # partitioning by user
    # )
# Enable row-level security (RLS) for edges
# enable_edge_rls = DDL("""
# ALTER TABLE edges ENABLE ROW LEVEL SECURITY;
# """)
# create_user_isolation_policy = DDL("""
# CREATE POLICY user_isolation_policy
# ON edges
# USING (user_id = current_setting('app.current_user_id')::uuid)
# WITH CHECK (user_id = current_setting('app.current_user_id')::uuid);
# """)
# event.listen(Edge.__table__, "after_create", enable_edge_rls)
# event.listen(Edge.__table__, "after_create", create_user_isolation_policy)

View file

@ -0,0 +1,59 @@
from datetime import datetime, timezone
from sqlalchemy import (
DateTime,
Index,
# event,
String,
JSON,
UUID,
)
# from sqlalchemy.schema import DDL
from sqlalchemy.orm import Mapped, mapped_column
from cognee.infrastructure.databases.relational import Base
class Node(Base):
    """ORM model mapped to the ``nodes`` table.

    Every column except ``label`` and ``attributes`` is declared NOT NULL.
    ``dataset_id`` has its own index and also leads two composite indexes
    (with ``slug`` and with ``data_id``).
    """

    __tablename__ = "nodes"

    # Primary key.
    id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), primary_key=True)

    # Required UUID columns without a single-column index.
    slug: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
    user_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
    data_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)

    # Required UUID column with its own index.
    dataset_id: Mapped[UUID] = mapped_column(
        UUID(as_uuid=True),
        nullable=False,
        index=True,
    )

    # Short strings, capped at 255 characters; only ``type`` is required.
    label: Mapped[str | None] = mapped_column(String(255))
    type: Mapped[str] = mapped_column(String(255), nullable=False)

    # JSON payloads; ``indexed_fields`` is required, ``attributes`` is optional.
    indexed_fields: Mapped[list] = mapped_column(JSON, nullable=False)
    attributes: Mapped[dict | None] = mapped_column(JSON)

    # Timezone-aware timestamp; defaults to the current UTC time on the Python side.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(tz=timezone.utc),
    )

    __table_args__ = (
        # Composite indexes, both led by dataset_id.
        Index("index_node_dataset_slug", "dataset_id", "slug"),
        Index("index_node_dataset_data", "dataset_id", "data_id"),
        # Per-user partitioning is currently disabled:
        # {"postgresql_partition_by": "HASH (user_id)"},  # HASH partitioning on user_id
    )
# Enable row-level security (RLS) for nodes
# enable_node_rls = DDL("""
# ALTER TABLE nodes ENABLE ROW LEVEL SECURITY;
# """)
# create_user_isolation_policy = DDL("""
# CREATE POLICY user_isolation_policy
# ON nodes
# USING (user_id = current_setting('app.current_user_id')::uuid)
# WITH CHECK (user_id = current_setting('app.current_user_id')::uuid);
# """)
# event.listen(Node.__table__, "after_create", enable_node_rls)
# event.listen(Node.__table__, "after_create", create_user_isolation_policy)

View file

@ -0,0 +1,2 @@
from .Edge import Edge
from .Node import Node

View file

@ -0,0 +1,43 @@
import pytest
from uuid import uuid4
from types import SimpleNamespace
from unittest.mock import AsyncMock
from cognee.modules.graph.methods import delete_data_related_edges
class DummyScalarResult:
    """Test double for the result object awaited from ``session.scalars``.

    Only ``all()`` is provided, which is the single method the code under
    test calls on the result.
    """

    def __init__(self, items):
        """Remember ``items`` so ``all()`` can hand them back unchanged."""
        self._items = items

    def all(self):
        """Return the exact object passed to the constructor."""
        return self._items
class FakeEdge:
    """Bare stand-in for an ``Edge`` row that exposes only an ``id``."""

    def __init__(self, edge_id):
        """Store ``edge_id`` as the public ``id`` attribute."""
        self.id = edge_id
@pytest.mark.asyncio
async def test_delete_data_related_edges_deletes_found_rows():
    """With two edges found, ``scalars`` and ``execute`` are each awaited once."""
    found = DummyScalarResult([FakeEdge(1), FakeEdge(2)])
    fake_session = SimpleNamespace(
        scalars=AsyncMock(return_value=found),
        execute=AsyncMock(),
    )

    await delete_data_related_edges(uuid4(), session=fake_session)

    fake_session.scalars.assert_awaited_once()
    fake_session.execute.assert_awaited_once()
@pytest.mark.asyncio
async def test_delete_data_related_edges_handles_empty_list():
    """With no edges found, ``execute`` is still awaited exactly once."""
    nothing_found = DummyScalarResult([])
    fake_session = SimpleNamespace(
        scalars=AsyncMock(return_value=nothing_found),
        execute=AsyncMock(),
    )

    await delete_data_related_edges(uuid4(), session=fake_session)

    fake_session.scalars.assert_awaited_once()
    fake_session.execute.assert_awaited_once()

View file

@ -0,0 +1,43 @@
import pytest
from uuid import uuid4
from types import SimpleNamespace
from unittest.mock import AsyncMock
from cognee.modules.graph.methods import delete_data_related_nodes
class DummyScalarResult:
    """Test double for the result object awaited from ``session.scalars``.

    Only ``all()`` is provided, which is the single method the code under
    test calls on the result.
    """

    def __init__(self, items):
        """Remember ``items`` so ``all()`` can hand them back unchanged."""
        self._items = items

    def all(self):
        """Return the exact object passed to the constructor."""
        return self._items
class FakeNode:
    """Bare stand-in for a ``Node`` row that exposes only an ``id``."""

    def __init__(self, node_id):
        """Store ``node_id`` as the public ``id`` attribute."""
        self.id = node_id
@pytest.mark.asyncio
async def test_delete_data_related_nodes_deletes_found_rows():
    """With two nodes found, ``scalars`` and ``execute`` are each awaited once."""
    found = DummyScalarResult([FakeNode(1), FakeNode(2)])
    fake_session = SimpleNamespace(
        scalars=AsyncMock(return_value=found),
        execute=AsyncMock(),
    )

    await delete_data_related_nodes(uuid4(), session=fake_session)

    fake_session.scalars.assert_awaited_once()
    fake_session.execute.assert_awaited_once()
@pytest.mark.asyncio
async def test_delete_data_related_nodes_handles_empty_list():
    """With no nodes found, ``execute`` is still awaited exactly once."""
    nothing_found = DummyScalarResult([])
    fake_session = SimpleNamespace(
        scalars=AsyncMock(return_value=nothing_found),
        execute=AsyncMock(),
    )

    await delete_data_related_nodes(uuid4(), session=fake_session)

    fake_session.scalars.assert_awaited_once()
    fake_session.execute.assert_awaited_once()

View file

@ -0,0 +1,43 @@
import pytest
from uuid import uuid4
from types import SimpleNamespace
from unittest.mock import AsyncMock
from cognee.modules.graph.methods import delete_dataset_related_edges
class DummyScalarResult:
    """Test double for the result object awaited from ``session.scalars``.

    Only ``all()`` is provided, which is the single method the code under
    test calls on the result.
    """

    def __init__(self, items):
        """Remember ``items`` so ``all()`` can hand them back unchanged."""
        self._items = items

    def all(self):
        """Return the exact object passed to the constructor."""
        return self._items
class FakeEdge:
    """Bare stand-in for an ``Edge`` row that exposes only an ``id``."""

    def __init__(self, edge_id):
        """Store ``edge_id`` as the public ``id`` attribute."""
        self.id = edge_id
@pytest.mark.asyncio
async def test_delete_dataset_related_edges_deletes_found_rows():
    """With two edges found, ``scalars`` and ``execute`` are each awaited once."""
    found = DummyScalarResult([FakeEdge(1), FakeEdge(2)])
    fake_session = SimpleNamespace(
        scalars=AsyncMock(return_value=found),
        execute=AsyncMock(),
    )

    await delete_dataset_related_edges(uuid4(), session=fake_session)

    fake_session.scalars.assert_awaited_once()
    fake_session.execute.assert_awaited_once()
@pytest.mark.asyncio
async def test_delete_dataset_related_edges_handles_empty_list():
    """With no edges found, ``execute`` is still awaited exactly once."""
    nothing_found = DummyScalarResult([])
    fake_session = SimpleNamespace(
        scalars=AsyncMock(return_value=nothing_found),
        execute=AsyncMock(),
    )

    await delete_dataset_related_edges(uuid4(), session=fake_session)

    fake_session.scalars.assert_awaited_once()
    fake_session.execute.assert_awaited_once()

View file

@ -0,0 +1,43 @@
import pytest
from uuid import uuid4
from types import SimpleNamespace
from unittest.mock import AsyncMock
from cognee.modules.graph.methods import delete_dataset_related_nodes
class DummyScalarResult:
    """Test double for the result object awaited from ``session.scalars``.

    Only ``all()`` is provided, which is the single method the code under
    test calls on the result.
    """

    def __init__(self, items):
        """Remember ``items`` so ``all()`` can hand them back unchanged."""
        self._items = items

    def all(self):
        """Return the exact object passed to the constructor."""
        return self._items
class FakeNode:
    """Bare stand-in for a ``Node`` row that exposes only an ``id``."""

    def __init__(self, node_id):
        """Store ``node_id`` as the public ``id`` attribute."""
        self.id = node_id
@pytest.mark.asyncio
async def test_delete_dataset_related_nodes_deletes_found_rows():
    """With two nodes found, ``scalars`` and ``execute`` are each awaited once."""
    found = DummyScalarResult([FakeNode(1), FakeNode(2)])
    fake_session = SimpleNamespace(
        scalars=AsyncMock(return_value=found),
        execute=AsyncMock(),
    )

    await delete_dataset_related_nodes(uuid4(), session=fake_session)

    fake_session.scalars.assert_awaited_once()
    fake_session.execute.assert_awaited_once()
@pytest.mark.asyncio
async def test_delete_dataset_related_nodes_handles_empty_list():
    """With no nodes found, ``execute`` is still awaited exactly once."""
    nothing_found = DummyScalarResult([])
    fake_session = SimpleNamespace(
        scalars=AsyncMock(return_value=nothing_found),
        execute=AsyncMock(),
    )

    await delete_dataset_related_nodes(uuid4(), session=fake_session)

    fake_session.scalars.assert_awaited_once()
    fake_session.execute.assert_awaited_once()