Merge 78d9d20815 into af72dd2fc2
This commit is contained in:
commit
e3950e6cfe
12 changed files with 349 additions and 0 deletions
|
|
@ -1 +1,7 @@
|
|||
from .get_formatted_graph_data import get_formatted_graph_data
|
||||
|
||||
from .delete_data_related_nodes import delete_data_related_nodes
|
||||
from .delete_data_related_edges import delete_data_related_edges
|
||||
|
||||
from .delete_dataset_related_nodes import delete_dataset_related_nodes
|
||||
from .delete_dataset_related_edges import delete_dataset_related_edges
|
||||
|
|
|
|||
13
cognee/modules/graph/methods/delete_data_related_edges.py
Normal file
13
cognee/modules/graph/methods/delete_data_related_edges.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
from uuid import UUID
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cognee.infrastructure.databases.relational import with_async_session
|
||||
from cognee.modules.graph.models import Edge
|
||||
|
||||
|
||||
@with_async_session
async def delete_data_related_edges(data_id: UUID, session: AsyncSession):
    """Delete every Edge row tied to the given data item.

    Fetches the matching edges first, then issues a bulk delete by primary
    key.  The session is injected by the with_async_session decorator.
    """
    result = await session.scalars(select(Edge).where(Edge.data_id == data_id))
    edge_ids = [edge.id for edge in result.all()]

    await session.execute(delete(Edge).where(Edge.id.in_(edge_ids)))
|
||||
13
cognee/modules/graph/methods/delete_data_related_nodes.py
Normal file
13
cognee/modules/graph/methods/delete_data_related_nodes.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
from uuid import UUID
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cognee.infrastructure.databases.relational import with_async_session
|
||||
from cognee.modules.graph.models import Node
|
||||
|
||||
|
||||
@with_async_session
async def delete_data_related_nodes(data_id: UUID, session: AsyncSession):
    """Delete every Node row tied to the given data item.

    Fetches the matching nodes first, then issues a bulk delete by primary
    key.  The session is injected by the with_async_session decorator.
    """
    result = await session.scalars(select(Node).where(Node.data_id == data_id))
    node_ids = [node.id for node in result.all()]

    await session.execute(delete(Node).where(Node.id.in_(node_ids)))
|
||||
13
cognee/modules/graph/methods/delete_dataset_related_edges.py
Normal file
13
cognee/modules/graph/methods/delete_dataset_related_edges.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
from uuid import UUID
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cognee.infrastructure.databases.relational import with_async_session
|
||||
from cognee.modules.graph.models import Edge
|
||||
|
||||
|
||||
@with_async_session
async def delete_dataset_related_edges(dataset_id: UUID, session: AsyncSession):
    """Delete every Edge row belonging to the given dataset.

    Fetches the matching edges first, then issues a bulk delete by primary
    key.  The session is injected by the with_async_session decorator.
    """
    result = await session.scalars(select(Edge).where(Edge.dataset_id == dataset_id))
    edge_ids = [edge.id for edge in result.all()]

    await session.execute(delete(Edge).where(Edge.id.in_(edge_ids)))
|
||||
13
cognee/modules/graph/methods/delete_dataset_related_nodes.py
Normal file
13
cognee/modules/graph/methods/delete_dataset_related_nodes.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
from uuid import UUID
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cognee.infrastructure.databases.relational import with_async_session
|
||||
from cognee.modules.graph.models import Node
|
||||
|
||||
|
||||
@with_async_session
async def delete_dataset_related_nodes(dataset_id: UUID, session: AsyncSession):
    """Delete every Node row belonging to the given dataset.

    Fetches the matching nodes first, then issues a bulk delete by primary
    key.  The session is injected by the with_async_session decorator.
    """
    result = await session.scalars(select(Node).where(Node.dataset_id == dataset_id))
    node_ids = [node.id for node in result.all()]

    await session.execute(delete(Node).where(Node.id.in_(node_ids)))
|
||||
58
cognee/modules/graph/models/Edge.py
Normal file
58
cognee/modules/graph/models/Edge.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
from datetime import datetime, timezone
|
||||
from sqlalchemy import (
|
||||
# event,
|
||||
DateTime,
|
||||
JSON,
|
||||
UUID,
|
||||
Text,
|
||||
)
|
||||
|
||||
# from sqlalchemy.schema import DDL
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from cognee.infrastructure.databases.relational import Base
|
||||
|
||||
|
||||
class Edge(Base):
    """Relational projection of a graph edge.

    Each row connects two node ids under a named relationship and is scoped
    to a user, a data item, and a dataset so related rows can be bulk-deleted
    when their owner is removed (see the delete_*_related_edges methods).
    """

    __tablename__ = "edges"

    id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), primary_key=True)

    # NOTE(review): the distinction between "slug" and "id" is not visible
    # here — confirm intended semantics with the code that inserts edges.
    slug: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)

    user_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)

    # Indexed because delete_data_related_edges filters on data_id.
    data_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), index=True, nullable=False)

    # Indexed because delete_dataset_related_edges filters on dataset_id.
    dataset_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), index=True, nullable=False)

    # Endpoints of the edge; no FK constraint is declared here, only raw ids.
    source_node_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
    destination_node_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)

    relationship_name: Mapped[str] = mapped_column(Text, nullable=False)

    # Optional human-readable label and free-form JSON payload.
    label: Mapped[str | None] = mapped_column(Text)
    attributes: Mapped[dict | None] = mapped_column(JSON)

    # Timezone-aware creation timestamp, defaulted to UTC at insert time.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False
    )

    # Kept for a possible future HASH partitioning of edges by user.
    # __table_args__ = (
    #     {"postgresql_partition_by": "HASH (user_id)"},  # partitioning by user
    # )
|
||||
|
||||
|
||||
# Enable row-level security (RLS) for edges
|
||||
# enable_edge_rls = DDL("""
|
||||
# ALTER TABLE edges ENABLE ROW LEVEL SECURITY;
|
||||
# """)
|
||||
# create_user_isolation_policy = DDL("""
|
||||
# CREATE POLICY user_isolation_policy
|
||||
# ON edges
|
||||
# USING (user_id = current_setting('app.current_user_id')::uuid)
|
||||
# WITH CHECK (user_id = current_setting('app.current_user_id')::uuid);
|
||||
# """)
|
||||
|
||||
# event.listen(Edge.__table__, "after_create", enable_edge_rls)
|
||||
# event.listen(Edge.__table__, "after_create", create_user_isolation_policy)
|
||||
59
cognee/modules/graph/models/Node.py
Normal file
59
cognee/modules/graph/models/Node.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
from datetime import datetime, timezone
|
||||
from sqlalchemy import (
|
||||
DateTime,
|
||||
Index,
|
||||
# event,
|
||||
String,
|
||||
JSON,
|
||||
UUID,
|
||||
)
|
||||
|
||||
# from sqlalchemy.schema import DDL
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from cognee.infrastructure.databases.relational import Base
|
||||
|
||||
|
||||
class Node(Base):
    """Relational projection of a graph node.

    Each row is scoped to a user, a data item, and a dataset so related rows
    can be bulk-deleted when their owner is removed (see the
    delete_*_related_nodes methods).
    """

    __tablename__ = "nodes"

    id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), primary_key=True)

    # NOTE(review): the distinction between "slug" and "id" is not visible
    # here — confirm intended semantics with the code that inserts nodes.
    slug: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)

    user_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)

    # Filtered on by delete_data_related_nodes; covered by the composite
    # index_node_dataset_data index below rather than a single-column index.
    data_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)

    # Indexed because delete_dataset_related_nodes filters on dataset_id.
    dataset_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), index=True, nullable=False)

    # Optional display label; "type" (required) shadows the builtin but is
    # conventional for ORM column names.
    label: Mapped[str | None] = mapped_column(String(255))
    type: Mapped[str] = mapped_column(String(255), nullable=False)
    # NOTE(review): presumably the field names used for vector indexing —
    # confirm against the code that writes nodes.
    indexed_fields: Mapped[list] = mapped_column(JSON, nullable=False)

    # Free-form JSON payload.
    attributes: Mapped[dict | None] = mapped_column(JSON)

    # Timezone-aware creation timestamp, defaulted to UTC at insert time.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False
    )

    # Composite indexes supporting lookups by dataset + slug and the
    # dataset + data_id filter used by the delete helpers.
    __table_args__ = (
        Index("index_node_dataset_slug", "dataset_id", "slug"),
        Index("index_node_dataset_data", "dataset_id", "data_id"),
        # {"postgresql_partition_by": "HASH (user_id)"},  # HASH partitioning on user_id
    )
|
||||
|
||||
|
||||
# Enable row-level security (RLS) for nodes
|
||||
# enable_node_rls = DDL("""
|
||||
# ALTER TABLE nodes ENABLE ROW LEVEL SECURITY;
|
||||
# """)
|
||||
# create_user_isolation_policy = DDL("""
|
||||
# CREATE POLICY user_isolation_policy
|
||||
# ON nodes
|
||||
# USING (user_id = current_setting('app.current_user_id')::uuid)
|
||||
# WITH CHECK (user_id = current_setting('app.current_user_id')::uuid);
|
||||
# """)
|
||||
|
||||
# event.listen(Node.__table__, "after_create", enable_node_rls)
|
||||
# event.listen(Node.__table__, "after_create", create_user_isolation_policy)
|
||||
2
cognee/modules/graph/models/__init__.py
Normal file
2
cognee/modules/graph/models/__init__.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
from .Edge import Edge
|
||||
from .Node import Node
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
import pytest
|
||||
from uuid import uuid4
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from cognee.modules.graph.methods import delete_data_related_edges
|
||||
|
||||
|
||||
class DummyScalarResult:
    """Test double mimicking the ScalarResult returned by session.scalars()."""

    def __init__(self, items):
        """Wrap the fixed list of rows that all() should hand back."""
        self._items = items

    def all(self):
        """Return the wrapped rows, mirroring ScalarResult.all()."""
        return self._items
|
||||
|
||||
|
||||
class FakeEdge:
    """Bare edge stub exposing only the id attribute the code under test reads."""

    def __init__(self, edge_id):
        self.id = edge_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_data_related_edges_deletes_found_rows():
    """When matching edges exist, both the select and the delete run once."""
    session = SimpleNamespace(
        scalars=AsyncMock(return_value=DummyScalarResult([FakeEdge(1), FakeEdge(2)])),
        execute=AsyncMock(),
    )

    await delete_data_related_edges(uuid4(), session=session)

    session.scalars.assert_awaited_once()
    session.execute.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_data_related_edges_handles_empty_list():
    """With no matching edges the delete still executes (empty id list)."""
    session = SimpleNamespace(
        scalars=AsyncMock(return_value=DummyScalarResult([])),
        execute=AsyncMock(),
    )

    await delete_data_related_edges(uuid4(), session=session)

    session.scalars.assert_awaited_once()
    session.execute.assert_awaited_once()
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
import pytest
|
||||
from uuid import uuid4
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from cognee.modules.graph.methods import delete_data_related_nodes
|
||||
|
||||
|
||||
class DummyScalarResult:
    """Test double mimicking the ScalarResult returned by session.scalars()."""

    def __init__(self, items):
        """Wrap the fixed list of rows that all() should hand back."""
        self._items = items

    def all(self):
        """Return the wrapped rows, mirroring ScalarResult.all()."""
        return self._items
|
||||
|
||||
|
||||
class FakeNode:
    """Bare node stub exposing only the id attribute the code under test reads."""

    def __init__(self, node_id):
        self.id = node_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_data_related_nodes_deletes_found_rows():
    """When matching nodes exist, both the select and the delete run once."""
    session = SimpleNamespace(
        scalars=AsyncMock(return_value=DummyScalarResult([FakeNode(1), FakeNode(2)])),
        execute=AsyncMock(),
    )

    await delete_data_related_nodes(uuid4(), session=session)

    session.scalars.assert_awaited_once()
    session.execute.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_data_related_nodes_handles_empty_list():
    """With no matching nodes the delete still executes (empty id list)."""
    session = SimpleNamespace(
        scalars=AsyncMock(return_value=DummyScalarResult([])),
        execute=AsyncMock(),
    )

    await delete_data_related_nodes(uuid4(), session=session)

    session.scalars.assert_awaited_once()
    session.execute.assert_awaited_once()
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
import pytest
|
||||
from uuid import uuid4
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from cognee.modules.graph.methods import delete_dataset_related_edges
|
||||
|
||||
|
||||
class DummyScalarResult:
    """Test double mimicking the ScalarResult returned by session.scalars()."""

    def __init__(self, items):
        """Wrap the fixed list of rows that all() should hand back."""
        self._items = items

    def all(self):
        """Return the wrapped rows, mirroring ScalarResult.all()."""
        return self._items
|
||||
|
||||
|
||||
class FakeEdge:
    """Bare edge stub exposing only the id attribute the code under test reads."""

    def __init__(self, edge_id):
        self.id = edge_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_dataset_related_edges_deletes_found_rows():
    """When matching edges exist, both the select and the delete run once."""
    session = SimpleNamespace(
        scalars=AsyncMock(return_value=DummyScalarResult([FakeEdge(1), FakeEdge(2)])),
        execute=AsyncMock(),
    )

    await delete_dataset_related_edges(uuid4(), session=session)

    session.scalars.assert_awaited_once()
    session.execute.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_dataset_related_edges_handles_empty_list():
    """With no matching edges the delete still executes (empty id list)."""
    session = SimpleNamespace(
        scalars=AsyncMock(return_value=DummyScalarResult([])),
        execute=AsyncMock(),
    )

    await delete_dataset_related_edges(uuid4(), session=session)

    session.scalars.assert_awaited_once()
    session.execute.assert_awaited_once()
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
import pytest
|
||||
from uuid import uuid4
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from cognee.modules.graph.methods import delete_dataset_related_nodes
|
||||
|
||||
|
||||
class DummyScalarResult:
    """Test double mimicking the ScalarResult returned by session.scalars()."""

    def __init__(self, items):
        """Wrap the fixed list of rows that all() should hand back."""
        self._items = items

    def all(self):
        """Return the wrapped rows, mirroring ScalarResult.all()."""
        return self._items
|
||||
|
||||
|
||||
class FakeNode:
    """Bare node stub exposing only the id attribute the code under test reads."""

    def __init__(self, node_id):
        self.id = node_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_dataset_related_nodes_deletes_found_rows():
    """When matching nodes exist, both the select and the delete run once."""
    session = SimpleNamespace(
        scalars=AsyncMock(return_value=DummyScalarResult([FakeNode(1), FakeNode(2)])),
        execute=AsyncMock(),
    )

    await delete_dataset_related_nodes(uuid4(), session=session)

    session.scalars.assert_awaited_once()
    session.execute.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_dataset_related_nodes_handles_empty_list():
    """With no matching nodes the delete still executes (empty id list)."""
    session = SimpleNamespace(
        scalars=AsyncMock(return_value=DummyScalarResult([])),
        execute=AsyncMock(),
    )

    await delete_dataset_related_nodes(uuid4(), session=session)

    session.scalars.assert_awaited_once()
    session.execute.assert_awaited_once()
|
||||
Loading…
Add table
Reference in a new issue