Add type to DataPoint metadata (#364)

* Add type to DataPoint metadata

* Add missing index_fields

* Use DataPoint UUID type in pgvector create_data_points

* Make _metadata mandatory everywhere
This commit is contained in:
alekszievr 2024-12-16 16:27:03 +01:00 committed by GitHub
parent 5360093097
commit bfa0f06fb4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 167 additions and 65 deletions

View file

@ -1,21 +1,26 @@
import asyncio import asyncio
# from datetime import datetime # from datetime import datetime
import json import json
from uuid import UUID
from textwrap import dedent from textwrap import dedent
from uuid import UUID
from falkordb import FalkorDB from falkordb import FalkorDB
from cognee.exceptions import InvalidValueError from cognee.exceptions import InvalidValueError
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.databases.graph.graph_db_interface import \
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface GraphDBInterface
from cognee.infrastructure.databases.vector.embeddings import EmbeddingEngine from cognee.infrastructure.databases.vector.embeddings import EmbeddingEngine
from cognee.infrastructure.databases.vector.vector_db_interface import VectorDBInterface from cognee.infrastructure.databases.vector.vector_db_interface import \
VectorDBInterface
from cognee.infrastructure.engine import DataPoint
class IndexSchema(DataPoint): class IndexSchema(DataPoint):
text: str text: str
_metadata: dict = { _metadata: dict = {
"index_fields": ["text"] "index_fields": ["text"],
"type": "IndexSchema"
} }
class FalkorDBAdapter(VectorDBInterface, GraphDBInterface): class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):

View file

@ -1,25 +1,29 @@
from typing import List, Optional, get_type_hints, Generic, TypeVar
import asyncio import asyncio
from typing import Generic, List, Optional, TypeVar, get_type_hints
from uuid import UUID from uuid import UUID
import lancedb import lancedb
from lancedb.pydantic import LanceModel, Vector
from pydantic import BaseModel from pydantic import BaseModel
from lancedb.pydantic import Vector, LanceModel
from cognee.exceptions import InvalidValueError from cognee.exceptions import InvalidValueError
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.files.storage import LocalStorage from cognee.infrastructure.files.storage import LocalStorage
from cognee.modules.storage.utils import copy_model, get_own_properties from cognee.modules.storage.utils import copy_model, get_own_properties
from ..models.ScoredResult import ScoredResult
from ..vector_db_interface import VectorDBInterface
from ..utils import normalize_distances
from ..embeddings.EmbeddingEngine import EmbeddingEngine from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..models.ScoredResult import ScoredResult
from ..utils import normalize_distances
from ..vector_db_interface import VectorDBInterface
class IndexSchema(DataPoint): class IndexSchema(DataPoint):
id: str id: str
text: str text: str
_metadata: dict = { _metadata: dict = {
"index_fields": ["text"] "index_fields": ["text"],
"type": "IndexSchema"
} }
class LanceDBAdapter(VectorDBInterface): class LanceDBAdapter(VectorDBInterface):

View file

@ -4,10 +4,12 @@ import asyncio
import logging import logging
from typing import List, Optional from typing import List, Optional
from uuid import UUID from uuid import UUID
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from ..vector_db_interface import VectorDBInterface
from ..models.ScoredResult import ScoredResult
from ..embeddings.EmbeddingEngine import EmbeddingEngine from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..models.ScoredResult import ScoredResult
from ..vector_db_interface import VectorDBInterface
logger = logging.getLogger("MilvusAdapter") logger = logging.getLogger("MilvusAdapter")
@ -16,7 +18,8 @@ class IndexSchema(DataPoint):
text: str text: str
_metadata: dict = { _metadata: dict = {
"index_fields": ["text"] "index_fields": ["text"],
"type": "IndexSchema"
} }

View file

@ -1,6 +1,7 @@
import asyncio import asyncio
from uuid import UUID
from typing import List, Optional, get_type_hints from typing import List, Optional, get_type_hints
from uuid import UUID
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy import JSON, Column, Table, select, delete, MetaData from sqlalchemy import JSON, Column, Table, select, delete, MetaData
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
@ -9,19 +10,21 @@ from cognee.exceptions import InvalidValueError
from cognee.infrastructure.databases.exceptions import EntityNotFoundError from cognee.infrastructure.databases.exceptions import EntityNotFoundError
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from .serialize_data import serialize_data
from ..models.ScoredResult import ScoredResult
from ..vector_db_interface import VectorDBInterface
from ..utils import normalize_distances
from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
from ...relational.ModelBase import Base from ...relational.ModelBase import Base
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..models.ScoredResult import ScoredResult
from ..utils import normalize_distances
from ..vector_db_interface import VectorDBInterface
from .serialize_data import serialize_data
class IndexSchema(DataPoint): class IndexSchema(DataPoint):
text: str text: str
_metadata: dict = { _metadata: dict = {
"index_fields": ["text"] "index_fields": ["text"],
"type": "IndexSchema"
} }
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
@ -89,6 +92,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
async def create_data_points( async def create_data_points(
self, collection_name: str, data_points: List[DataPoint] self, collection_name: str, data_points: List[DataPoint]
): ):
data_point_types = get_type_hints(DataPoint)
if not await self.has_collection(collection_name): if not await self.has_collection(collection_name):
await self.create_collection( await self.create_collection(
collection_name = collection_name, collection_name = collection_name,
@ -108,7 +112,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
primary_key: Mapped[int] = mapped_column( primary_key: Mapped[int] = mapped_column(
primary_key=True, autoincrement=True primary_key=True, autoincrement=True
) )
id: Mapped[type(data_points[0].id)] id: Mapped[data_point_types["id"]]
payload = Column(JSON) payload = Column(JSON)
vector = Column(self.Vector(vector_size)) vector = Column(self.Vector(vector_size))

View file

@ -1,13 +1,16 @@
import logging import logging
from typing import Dict, List, Optional
from uuid import UUID from uuid import UUID
from typing import List, Dict, Optional
from qdrant_client import AsyncQdrantClient, models from qdrant_client import AsyncQdrantClient, models
from cognee.exceptions import InvalidValueError from cognee.exceptions import InvalidValueError
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult from cognee.infrastructure.databases.vector.models.ScoredResult import \
ScoredResult
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from ..vector_db_interface import VectorDBInterface
from ..embeddings.EmbeddingEngine import EmbeddingEngine from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..vector_db_interface import VectorDBInterface
logger = logging.getLogger("QDrantAdapter") logger = logging.getLogger("QDrantAdapter")
@ -15,7 +18,8 @@ class IndexSchema(DataPoint):
text: str text: str
_metadata: dict = { _metadata: dict = {
"index_fields": ["text"] "index_fields": ["text"],
"type": "IndexSchema"
} }
# class CollectionConfig(BaseModel, extra = "forbid"): # class CollectionConfig(BaseModel, extra = "forbid"):

View file

@ -5,9 +5,10 @@ from uuid import UUID
from cognee.exceptions import InvalidValueError from cognee.exceptions import InvalidValueError
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from ..vector_db_interface import VectorDBInterface
from ..models.ScoredResult import ScoredResult
from ..embeddings.EmbeddingEngine import EmbeddingEngine from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..models.ScoredResult import ScoredResult
from ..vector_db_interface import VectorDBInterface
logger = logging.getLogger("WeaviateAdapter") logger = logging.getLogger("WeaviateAdapter")
@ -15,7 +16,8 @@ class IndexSchema(DataPoint):
text: str text: str
_metadata: dict = { _metadata: dict = {
"index_fields": ["text"] "index_fields": ["text"],
"type": "IndexSchema"
} }
class WeaviateAdapter(VectorDBInterface): class WeaviateAdapter(VectorDBInterface):

View file

@ -1,8 +1,10 @@
from typing_extensions import TypedDict
from uuid import UUID, uuid4
from typing import Optional
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Optional
from uuid import UUID, uuid4
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import TypedDict
class MetaData(TypedDict): class MetaData(TypedDict):
index_fields: list[str] index_fields: list[str]
@ -13,7 +15,8 @@ class DataPoint(BaseModel):
updated_at: Optional[datetime] = datetime.now(timezone.utc) updated_at: Optional[datetime] = datetime.now(timezone.utc)
topological_rank: Optional[int] = 0 topological_rank: Optional[int] = 0
_metadata: Optional[MetaData] = { _metadata: Optional[MetaData] = {
"index_fields": [] "index_fields": [],
"type": "DataPoint"
} }
# class Config: # class Config:
@ -39,4 +42,4 @@ class DataPoint(BaseModel):
@classmethod @classmethod
def get_embeddable_property_names(self, data_point): def get_embeddable_property_names(self, data_point):
return data_point._metadata["index_fields"] or [] return data_point._metadata["index_fields"] or []

View file

@ -1,8 +1,10 @@
from typing import List, Optional from typing import List, Optional
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.processing.document_types import Document from cognee.modules.data.processing.document_types import Document
from cognee.modules.engine.models import Entity from cognee.modules.engine.models import Entity
class DocumentChunk(DataPoint): class DocumentChunk(DataPoint):
__tablename__ = "document_chunk" __tablename__ = "document_chunk"
text: str text: str
@ -12,6 +14,7 @@ class DocumentChunk(DataPoint):
is_part_of: Document is_part_of: Document
contains: List[Entity] = None contains: List[Entity] = None
_metadata: Optional[dict] = { _metadata: dict = {
"index_fields": ["text"], "index_fields": ["text"],
"type": "DocumentChunk"
} }

View file

@ -1,12 +1,17 @@
from cognee.infrastructure.engine import DataPoint
from uuid import UUID from uuid import UUID
from cognee.infrastructure.engine import DataPoint
class Document(DataPoint): class Document(DataPoint):
type: str
name: str name: str
raw_data_location: str raw_data_location: str
metadata_id: UUID metadata_id: UUID
mime_type: str mime_type: str
_metadata: dict = {
"index_fields": ["name"],
"type": "Document"
}
def read(self, chunk_size: int) -> str: def read(self, chunk_size: int) -> str:
pass pass

View file

@ -10,4 +10,5 @@ class Entity(DataPoint):
_metadata: dict = { _metadata: dict = {
"index_fields": ["name"], "index_fields": ["name"],
"type": "Entity"
} }

View file

@ -1,11 +1,12 @@
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
class EntityType(DataPoint): class EntityType(DataPoint):
__tablename__ = "entity_type" __tablename__ = "entity_type"
name: str name: str
type: str
description: str description: str
_metadata: dict = { _metadata: dict = {
"index_fields": ["name"], "index_fields": ["name"],
"type": "EntityType"
} }

View file

@ -1,11 +1,14 @@
from typing import Optional from typing import Optional
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
class EdgeType(DataPoint): class EdgeType(DataPoint):
__tablename__ = "edge_type" __tablename__ = "edge_type"
relationship_name: str relationship_name: str
number_of_edges: int number_of_edges: int
_metadata: Optional[dict] = { _metadata: dict = {
"index_fields": ["relationship_name"], "index_fields": ["relationship_name"],
"type": "EdgeType"
} }

View file

@ -2,7 +2,7 @@ from cognee.infrastructure.engine import DataPoint
def convert_node_to_data_point(node_data: dict) -> DataPoint: def convert_node_to_data_point(node_data: dict) -> DataPoint:
subclass = find_subclass_by_name(DataPoint, node_data["type"]) subclass = find_subclass_by_name(DataPoint, node_data._metadata["type"])
return subclass(**node_data) return subclass(**node_data)

View file

@ -1,15 +1,19 @@
from typing import List, Optional from typing import List, Optional
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
class Repository(DataPoint): class Repository(DataPoint):
__tablename__ = "Repository" __tablename__ = "Repository"
path: str path: str
type: Optional[str] = "Repository" _metadata: dict = {
"index_fields": ["source_code"],
"type": "Repository"
}
class CodeFile(DataPoint): class CodeFile(DataPoint):
__tablename__ = "codefile" __tablename__ = "codefile"
extracted_id: str # actually file path extracted_id: str # actually file path
type: Optional[str] = "CodeFile"
source_code: Optional[str] = None source_code: Optional[str] = None
part_of: Optional[Repository] = None part_of: Optional[Repository] = None
depends_on: Optional[List["CodeFile"]] = None depends_on: Optional[List["CodeFile"]] = None
@ -17,24 +21,27 @@ class CodeFile(DataPoint):
contains: Optional[List["CodePart"]] = None contains: Optional[List["CodePart"]] = None
_metadata: dict = { _metadata: dict = {
"index_fields": ["source_code"] "index_fields": ["source_code"],
"type": "CodeFile"
} }
class CodePart(DataPoint): class CodePart(DataPoint):
__tablename__ = "codepart" __tablename__ = "codepart"
# part_of: Optional[CodeFile] # part_of: Optional[CodeFile]
source_code: str source_code: str
type: Optional[str] = "CodePart"
_metadata: dict = { _metadata: dict = {
"index_fields": ["source_code"] "index_fields": ["source_code"],
"type": "CodePart"
} }
class CodeRelationship(DataPoint): class CodeRelationship(DataPoint):
source_id: str source_id: str
target_id: str target_id: str
type: str # between files
relation: str # depends on or depends directly relation: str # depends on or depends directly
_metadata: dict = {
"type": "CodeRelationship"
}
CodeFile.model_rebuild() CodeFile.model_rebuild()
CodePart.model_rebuild() CodePart.model_rebuild()

View file

@ -1,79 +1,90 @@
from typing import Any, List, Union, Literal, Optional from typing import Any, List, Literal, Optional, Union
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
class Variable(DataPoint): class Variable(DataPoint):
id: str id: str
name: str name: str
type: Literal["Variable"] = "Variable"
description: str description: str
is_static: Optional[bool] = False is_static: Optional[bool] = False
default_value: Optional[str] = None default_value: Optional[str] = None
data_type: str data_type: str
_metadata = { _metadata = {
"index_fields": ["name"] "index_fields": ["name"],
"type": "Variable"
} }
class Operator(DataPoint): class Operator(DataPoint):
id: str id: str
name: str name: str
type: Literal["Operator"] = "Operator"
description: str description: str
return_type: str return_type: str
_metadata = {
"index_fields": ["name"],
"type": "Operator"
}
class Class(DataPoint): class Class(DataPoint):
id: str id: str
name: str name: str
type: Literal["Class"] = "Class"
description: str description: str
constructor_parameters: List[Variable] constructor_parameters: List[Variable]
extended_from_class: Optional["Class"] = None extended_from_class: Optional["Class"] = None
has_methods: List["Function"] has_methods: List["Function"]
_metadata = { _metadata = {
"index_fields": ["name"] "index_fields": ["name"],
"type": "Class"
} }
class ClassInstance(DataPoint): class ClassInstance(DataPoint):
id: str id: str
name: str name: str
type: Literal["ClassInstance"] = "ClassInstance"
description: str description: str
from_class: Class from_class: Class
instantiated_by: Union["Function"] instantiated_by: Union["Function"]
instantiation_arguments: List[Variable] instantiation_arguments: List[Variable]
_metadata = { _metadata = {
"index_fields": ["name"] "index_fields": ["name"],
"type": "ClassInstance"
} }
class Function(DataPoint): class Function(DataPoint):
id: str id: str
name: str name: str
type: Literal["Function"] = "Function"
description: str description: str
parameters: List[Variable] parameters: List[Variable]
return_type: str return_type: str
is_static: Optional[bool] = False is_static: Optional[bool] = False
_metadata = { _metadata = {
"index_fields": ["name"] "index_fields": ["name"],
"type": "Function"
} }
class FunctionCall(DataPoint): class FunctionCall(DataPoint):
id: str id: str
type: Literal["FunctionCall"] = "FunctionCall"
called_by: Union[Function, Literal["main"]] called_by: Union[Function, Literal["main"]]
function_called: Function function_called: Function
function_arguments: List[Any] function_arguments: List[Any]
_metadata = {
"index_fields": [],
"type": "FunctionCall"
}
class Expression(DataPoint): class Expression(DataPoint):
id: str id: str
name: str name: str
type: Literal["Expression"] = "Expression"
description: str description: str
expression: str expression: str
members: List[Union[Variable, Function, Operator, "Expression"]] members: List[Union[Variable, Function, Operator, "Expression"]]
_metadata = {
"index_fields": ["name"],
"type": "Expression"
}
class SourceCodeGraph(DataPoint): class SourceCodeGraph(DataPoint):
id: str id: str
@ -89,8 +100,13 @@ class SourceCodeGraph(DataPoint):
Operator, Operator,
Expression, Expression,
]] ]]
_metadata = {
"index_fields": ["name"],
"type": "SourceCodeGraph"
}
Class.model_rebuild() Class.model_rebuild()
ClassInstance.model_rebuild() ClassInstance.model_rebuild()
Expression.model_rebuild() Expression.model_rebuild()
FunctionCall.model_rebuild() FunctionCall.model_rebuild()
SourceCodeGraph.model_rebuild() SourceCodeGraph.model_rebuild()

View file

@ -1,6 +1,7 @@
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
async def index_data_points(data_points: list[DataPoint]): async def index_data_points(data_points: list[DataPoint]):
created_indexes = {} created_indexes = {}
index_points = {} index_points = {}
@ -80,11 +81,20 @@ if __name__ == "__main__":
class Car(DataPoint): class Car(DataPoint):
model: str model: str
color: str color: str
_metadata = {
"index_fields": ["name"],
"type": "Car"
}
class Person(DataPoint): class Person(DataPoint):
name: str name: str
age: int age: int
owns_car: list[Car] owns_car: list[Car]
_metadata = {
"index_fields": ["name"],
"type": "Person"
}
car1 = Car(model = "Tesla Model S", color = "Blue") car1 = Car(model = "Tesla Model S", color = "Blue")
car2 = Car(model = "Toyota Camry", color = "Red") car2 = Car(model = "Toyota Camry", color = "Red")
@ -92,4 +102,4 @@ if __name__ == "__main__":
data_points = get_data_points_from_model(person) data_points = get_data_points_from_model(person)
print(data_points) print(data_points)

View file

@ -10,6 +10,7 @@ class TextSummary(DataPoint):
_metadata: dict = { _metadata: dict = {
"index_fields": ["text"], "index_fields": ["text"],
"type": "TextSummary"
} }
@ -20,4 +21,5 @@ class CodeSummary(DataPoint):
_metadata: dict = { _metadata: dict = {
"index_fields": ["text"], "index_fields": ["text"],
"type": "CodeSummary"
} }

View file

@ -2,7 +2,7 @@ import asyncio
import random import random
import time import time
from typing import List from typing import List
from uuid import uuid5, NAMESPACE_OID from uuid import NAMESPACE_OID, uuid5
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.utils import get_graph_from_model from cognee.modules.graph.utils import get_graph_from_model
@ -11,16 +11,28 @@ random.seed(1500)
class Repository(DataPoint): class Repository(DataPoint):
path: str path: str
_metadata = {
"index_fields": [],
"type": "Repository"
}
class CodeFile(DataPoint): class CodeFile(DataPoint):
part_of: Repository part_of: Repository
contains: List["CodePart"] = [] contains: List["CodePart"] = []
depends_on: List["CodeFile"] = [] depends_on: List["CodeFile"] = []
source_code: str source_code: str
_metadata = {
"index_fields": [],
"type": "CodeFile"
}
class CodePart(DataPoint): class CodePart(DataPoint):
part_of: CodeFile part_of: CodeFile
source_code: str source_code: str
_metadata = {
"index_fields": [],
"type": "CodePart"
}
CodeFile.model_rebuild() CodeFile.model_rebuild()
CodePart.model_rebuild() CodePart.model_rebuild()

View file

@ -1,25 +1,42 @@
import asyncio import asyncio
import random import random
from typing import List from typing import List
from uuid import uuid5, NAMESPACE_OID from uuid import NAMESPACE_OID, uuid5
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.utils import get_graph_from_model from cognee.modules.graph.utils import get_graph_from_model
class Document(DataPoint): class Document(DataPoint):
path: str path: str
_metadata = {
"index_fields": [],
"type": "Document"
}
class DocumentChunk(DataPoint): class DocumentChunk(DataPoint):
part_of: Document part_of: Document
text: str text: str
contains: List["Entity"] = None contains: List["Entity"] = None
_metadata = {
"index_fields": ["text"],
"type": "DocumentChunk"
}
class EntityType(DataPoint): class EntityType(DataPoint):
name: str name: str
_metadata = {
"index_fields": ["name"],
"type": "EntityType"
}
class Entity(DataPoint): class Entity(DataPoint):
name: str name: str
is_type: EntityType is_type: EntityType
_metadata = {
"index_fields": ["name"],
"type": "Entity"
}
DocumentChunk.model_rebuild() DocumentChunk.model_rebuild()