From 15b7b8ef2b10316c1e28799e35fa2cb60e911e0b Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Wed, 20 Nov 2024 14:54:35 +0100 Subject: [PATCH] fix: Resolve issue with table names in SQL commands Some SQL commands require lowercase characters in table names unless the table name is wrapped in quotes. Renamed all new tables to use lowercase names. Fix COG-677 --- cognee/infrastructure/engine/models/DataPoint.py | 1 + cognee/modules/chunking/models/DocumentChunk.py | 1 + cognee/modules/engine/models/Entity.py | 1 + cognee/modules/engine/models/EntityType.py | 1 + cognee/tasks/chunks/query_chunks.py | 2 +- cognee/tasks/graph/query_graph_connections.py | 4 ++-- cognee/tasks/storage/index_data_points.py | 4 ++-- cognee/tasks/summarization/models.py | 1 + cognee/tasks/summarization/query_summaries.py | 2 +- cognee/tests/test_pgvector.py | 2 +- 10 files changed, 12 insertions(+), 7 deletions(-) diff --git a/cognee/infrastructure/engine/models/DataPoint.py b/cognee/infrastructure/engine/models/DataPoint.py index 337306cb6..f8ea1c9f0 100644 --- a/cognee/infrastructure/engine/models/DataPoint.py +++ b/cognee/infrastructure/engine/models/DataPoint.py @@ -8,6 +8,7 @@ class MetaData(TypedDict): index_fields: list[str] class DataPoint(BaseModel): + __tablename__ = "data_point" id: UUID = Field(default_factory = uuid4) updated_at: Optional[datetime] = datetime.now(timezone.utc) _metadata: Optional[MetaData] = { diff --git a/cognee/modules/chunking/models/DocumentChunk.py b/cognee/modules/chunking/models/DocumentChunk.py index 975edb27e..b5b1cef94 100644 --- a/cognee/modules/chunking/models/DocumentChunk.py +++ b/cognee/modules/chunking/models/DocumentChunk.py @@ -3,6 +3,7 @@ from cognee.infrastructure.engine import DataPoint from cognee.modules.data.processing.document_types import Document class DocumentChunk(DataPoint): + __tablename__ = "document_chunk" text: str word_count: int chunk_index: int diff --git a/cognee/modules/engine/models/Entity.py b/cognee/modules/engine/models/Entity.py 
index c43774e38..cf946ceb6 100644 --- a/cognee/modules/engine/models/Entity.py +++ b/cognee/modules/engine/models/Entity.py @@ -3,6 +3,7 @@ from cognee.modules.chunking.models.DocumentChunk import DocumentChunk from .EntityType import EntityType class Entity(DataPoint): + __tablename__ = "entity" name: str is_a: EntityType description: str diff --git a/cognee/modules/engine/models/EntityType.py b/cognee/modules/engine/models/EntityType.py index b4f495857..56092f261 100644 --- a/cognee/modules/engine/models/EntityType.py +++ b/cognee/modules/engine/models/EntityType.py @@ -2,6 +2,7 @@ from cognee.infrastructure.engine import DataPoint from cognee.modules.chunking.models.DocumentChunk import DocumentChunk class EntityType(DataPoint): + __tablename__ = "entity_type" name: str type: str description: str diff --git a/cognee/tasks/chunks/query_chunks.py b/cognee/tasks/chunks/query_chunks.py index 93f32a640..399528ee9 100644 --- a/cognee/tasks/chunks/query_chunks.py +++ b/cognee/tasks/chunks/query_chunks.py @@ -10,7 +10,7 @@ async def query_chunks(query: str) -> list[dict]: """ vector_engine = get_vector_engine() - found_chunks = await vector_engine.search("DocumentChunk_text", query, limit = 5) + found_chunks = await vector_engine.search("document_chunk_text", query, limit = 5) chunks = [result.payload for result in found_chunks] diff --git a/cognee/tasks/graph/query_graph_connections.py b/cognee/tasks/graph/query_graph_connections.py index cd4d76a5e..4020ddd13 100644 --- a/cognee/tasks/graph/query_graph_connections.py +++ b/cognee/tasks/graph/query_graph_connections.py @@ -27,8 +27,8 @@ async def query_graph_connections(query: str, exploration_levels = 1) -> list[(s else: vector_engine = get_vector_engine() results = await asyncio.gather( - vector_engine.search("Entity_name", query_text = query, limit = 5), - vector_engine.search("EntityType_name", query_text = query, limit = 5), + vector_engine.search("entity_name", query_text = query, limit = 5), + 
vector_engine.search("entity_type_name", query_text = query, limit = 5), ) results = [*results[0], *results[1]] relevant_results = [result for result in results if result.score < 0.5][:5] diff --git a/cognee/tasks/storage/index_data_points.py b/cognee/tasks/storage/index_data_points.py index dc74d705d..12903173a 100644 --- a/cognee/tasks/storage/index_data_points.py +++ b/cognee/tasks/storage/index_data_points.py @@ -16,10 +16,10 @@ async def index_data_points(data_points: list[DataPoint]): data_point_type = type(data_point) for field_name in data_point._metadata["index_fields"]: - index_name = f"{data_point_type.__name__}.{field_name}" + index_name = f"{data_point_type.__tablename__}.{field_name}" if index_name not in created_indexes: - await vector_engine.create_vector_index(data_point_type.__name__, field_name) + await vector_engine.create_vector_index(data_point_type.__tablename__, field_name) created_indexes[index_name] = True if index_name not in index_points: diff --git a/cognee/tasks/summarization/models.py b/cognee/tasks/summarization/models.py index c6a932b37..955c0e2fa 100644 --- a/cognee/tasks/summarization/models.py +++ b/cognee/tasks/summarization/models.py @@ -3,6 +3,7 @@ from cognee.modules.chunking.models.DocumentChunk import DocumentChunk from cognee.modules.data.processing.document_types import Document class TextSummary(DataPoint): + __tablename__ = "text_summary" text: str made_from: DocumentChunk diff --git a/cognee/tasks/summarization/query_summaries.py b/cognee/tasks/summarization/query_summaries.py index 896839143..d9ec0fa00 100644 --- a/cognee/tasks/summarization/query_summaries.py +++ b/cognee/tasks/summarization/query_summaries.py @@ -10,7 +10,7 @@ async def query_summaries(query: str) -> list: """ vector_engine = get_vector_engine() - summaries_results = await vector_engine.search("TextSummary_text", query, limit = 5) + summaries_results = await vector_engine.search("text_summary_text", query, limit = 5) summaries = [summary.payload for 
summary in summaries_results] diff --git a/cognee/tests/test_pgvector.py b/cognee/tests/test_pgvector.py index b5a6fc446..1466e195f 100644 --- a/cognee/tests/test_pgvector.py +++ b/cognee/tests/test_pgvector.py @@ -65,7 +65,7 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0] + random_node = (await vector_engine.search("entity_name", "Quantum computer"))[0] random_node_name = random_node.payload["text"] search_results = await cognee.search(SearchType.INSIGHTS, query_text = random_node_name)