fix: support structured data conversion to data points (#512)
<!-- .github/pull_request_template.md --> ## Description <!-- Provide a clear description of the changes in this PR --> ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - New Features - Introduced version tracking and enhanced metadata in core data models for improved data consistency. - Bug Fixes - Improved error handling during graph data loading to prevent disruptions from unexpected identifier formats. - Refactor - Centralized identifier parsing and streamlined model definitions, ensuring smoother and more consistent operations across search, retrieval, and indexing workflows. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
parent
55a50153c6
commit
8f84713b54
25 changed files with 51 additions and 63 deletions
|
|
@ -25,8 +25,6 @@ from cognee.tasks.documents import (
|
||||||
)
|
)
|
||||||
from cognee.tasks.graph import extract_graph_from_data
|
from cognee.tasks.graph import extract_graph_from_data
|
||||||
from cognee.tasks.storage import add_data_points
|
from cognee.tasks.storage import add_data_points
|
||||||
from cognee.modules.data.methods import store_descriptive_metrics
|
|
||||||
from cognee.tasks.storage.index_graph_edges import index_graph_edges
|
|
||||||
from cognee.tasks.summarization import summarize_text
|
from cognee.tasks.summarization import summarize_text
|
||||||
|
|
||||||
logger = logging.getLogger("cognify.v2")
|
logger = logging.getLogger("cognify.v2")
|
||||||
|
|
@ -112,8 +110,6 @@ async def run_cognify_pipeline(dataset: Dataset, user: User, tasks: list[Task]):
|
||||||
async for result in pipeline:
|
async for result in pipeline:
|
||||||
print(result)
|
print(result)
|
||||||
|
|
||||||
await index_graph_edges()
|
|
||||||
|
|
||||||
send_telemetry("cognee.cognify EXECUTION COMPLETED", user.id)
|
send_telemetry("cognee.cognify EXECUTION COMPLETED", user.id)
|
||||||
|
|
||||||
await log_pipeline_status(
|
await log_pipeline_status(
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ import os
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from re import A
|
|
||||||
from typing import Dict, Any, List, Union
|
from typing import Dict, Any, List, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
import aiofiles
|
import aiofiles
|
||||||
|
|
@ -13,6 +12,7 @@ import aiofiles.os as aiofiles_os
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
from cognee.infrastructure.engine.utils import parse_id
|
||||||
from cognee.modules.storage.utils import JSONEncoder
|
from cognee.modules.storage.utils import JSONEncoder
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
@ -268,7 +268,11 @@ class NetworkXAdapter(GraphDBInterface):
|
||||||
for node in graph_data["nodes"]:
|
for node in graph_data["nodes"]:
|
||||||
try:
|
try:
|
||||||
if not isinstance(node["id"], UUID):
|
if not isinstance(node["id"], UUID):
|
||||||
node["id"] = UUID(node["id"])
|
try:
|
||||||
|
node["id"] = UUID(node["id"])
|
||||||
|
except Exception:
|
||||||
|
# If conversion fails, keep the original id
|
||||||
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
raise e
|
raise e
|
||||||
|
|
@ -285,12 +289,12 @@ class NetworkXAdapter(GraphDBInterface):
|
||||||
for edge in graph_data["links"]:
|
for edge in graph_data["links"]:
|
||||||
try:
|
try:
|
||||||
if not isinstance(edge["source"], UUID):
|
if not isinstance(edge["source"], UUID):
|
||||||
source_id = UUID(edge["source"])
|
source_id = parse_id(edge["source"])
|
||||||
else:
|
else:
|
||||||
source_id = edge["source"]
|
source_id = edge["source"]
|
||||||
|
|
||||||
if not isinstance(edge["target"], UUID):
|
if not isinstance(edge["target"], UUID):
|
||||||
target_id = UUID(edge["target"])
|
target_id = parse_id(edge["target"])
|
||||||
else:
|
else:
|
||||||
target_id = edge["target"]
|
target_id = edge["target"]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -69,7 +69,7 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
||||||
return ",".join([f"{key}:{parse_value(value)}" for key, value in properties.items()])
|
return ",".join([f"{key}:{parse_value(value)}" for key, value in properties.items()])
|
||||||
|
|
||||||
async def create_data_point_query(self, data_point: DataPoint, vectorized_values: dict):
|
async def create_data_point_query(self, data_point: DataPoint, vectorized_values: dict):
|
||||||
node_label = type(data_point).__tablename__
|
node_label = type(data_point).__name__
|
||||||
property_names = DataPoint.get_embeddable_property_names(data_point)
|
property_names = DataPoint.get_embeddable_property_names(data_point)
|
||||||
|
|
||||||
node_properties = await self.stringify_properties(
|
node_properties = await self.stringify_properties(
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Generic, List, Optional, TypeVar, get_type_hints
|
from typing import Generic, List, Optional, TypeVar, get_type_hints
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
import lancedb
|
import lancedb
|
||||||
from lancedb.pydantic import LanceModel, Vector
|
from lancedb.pydantic import LanceModel, Vector
|
||||||
|
|
@ -8,6 +7,7 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
from cognee.exceptions import InvalidValueError
|
from cognee.exceptions import InvalidValueError
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
from cognee.infrastructure.engine.utils import parse_id
|
||||||
from cognee.infrastructure.files.storage import LocalStorage
|
from cognee.infrastructure.files.storage import LocalStorage
|
||||||
from cognee.modules.storage.utils import copy_model, get_own_properties
|
from cognee.modules.storage.utils import copy_model, get_own_properties
|
||||||
|
|
||||||
|
|
@ -133,7 +133,7 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
|
|
||||||
return [
|
return [
|
||||||
ScoredResult(
|
ScoredResult(
|
||||||
id=UUID(result["id"]),
|
id=parse_id(result["id"]),
|
||||||
payload=result["payload"],
|
payload=result["payload"],
|
||||||
score=0,
|
score=0,
|
||||||
)
|
)
|
||||||
|
|
@ -162,7 +162,7 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
|
|
||||||
return [
|
return [
|
||||||
ScoredResult(
|
ScoredResult(
|
||||||
id=UUID(result["id"]),
|
id=parse_id(result["id"]),
|
||||||
payload=result["payload"],
|
payload=result["payload"],
|
||||||
score=normalized_values[value_index],
|
score=normalized_values[value_index],
|
||||||
)
|
)
|
||||||
|
|
@ -195,7 +195,7 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
|
|
||||||
return [
|
return [
|
||||||
ScoredResult(
|
ScoredResult(
|
||||||
id=UUID(result["id"]),
|
id=parse_id(result["id"]),
|
||||||
payload=result["payload"],
|
payload=result["payload"],
|
||||||
score=normalized_values[value_index],
|
score=normalized_values[value_index],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,9 @@ from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
from cognee.infrastructure.engine.utils import parse_id
|
||||||
|
|
||||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
from ..models.ScoredResult import ScoredResult
|
from ..models.ScoredResult import ScoredResult
|
||||||
|
|
@ -193,7 +193,7 @@ class MilvusAdapter(VectorDBInterface):
|
||||||
|
|
||||||
return [
|
return [
|
||||||
ScoredResult(
|
ScoredResult(
|
||||||
id=UUID(result["id"]),
|
id=parse_id(result["id"]),
|
||||||
score=result["distance"],
|
score=result["distance"],
|
||||||
payload=result.get("entity", {}),
|
payload=result.get("entity", {}),
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||||
from cognee.exceptions import InvalidValueError
|
from cognee.exceptions import InvalidValueError
|
||||||
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
|
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
from cognee.infrastructure.engine.utils import parse_id
|
||||||
|
|
||||||
from ...relational.ModelBase import Base
|
from ...relational.ModelBase import Base
|
||||||
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
|
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
|
||||||
|
|
@ -169,7 +170,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
results = results.all()
|
results = results.all()
|
||||||
|
|
||||||
return [
|
return [
|
||||||
ScoredResult(id=UUID(result.id), payload=result.payload, score=0)
|
ScoredResult(id=parse_id(result.id), payload=result.payload, score=0)
|
||||||
for result in results
|
for result in results
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -208,7 +209,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
|
|
||||||
# Create and return ScoredResult objects
|
# Create and return ScoredResult objects
|
||||||
return [
|
return [
|
||||||
ScoredResult(id=UUID(str(row.id)), payload=row.payload, score=row.similarity)
|
ScoredResult(id=parse_id(str(row.id)), payload=row.payload, score=row.similarity)
|
||||||
for row in vector_list
|
for row in vector_list
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -249,7 +250,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
for vector in closest_items:
|
for vector in closest_items:
|
||||||
vector_list.append(
|
vector_list.append(
|
||||||
{
|
{
|
||||||
"id": UUID(str(vector.id)),
|
"id": parse_id(str(vector.id)),
|
||||||
"payload": vector.payload,
|
"payload": vector.payload,
|
||||||
"_distance": vector.similarity,
|
"_distance": vector.similarity,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
|
from cognee.infrastructure.engine.utils import parse_id
|
||||||
from qdrant_client import AsyncQdrantClient, models
|
from qdrant_client import AsyncQdrantClient, models
|
||||||
|
|
||||||
from cognee.exceptions import InvalidValueError
|
from cognee.exceptions import InvalidValueError
|
||||||
|
|
@ -170,10 +170,10 @@ class QDrantAdapter(VectorDBInterface):
|
||||||
|
|
||||||
return [
|
return [
|
||||||
ScoredResult(
|
ScoredResult(
|
||||||
id=UUID(result.id),
|
id=parse_id(result.id),
|
||||||
payload={
|
payload={
|
||||||
**result.payload,
|
**result.payload,
|
||||||
"id": UUID(result.id),
|
"id": parse_id(result.id),
|
||||||
},
|
},
|
||||||
score=1 - result.score,
|
score=1 - result.score,
|
||||||
)
|
)
|
||||||
|
|
@ -209,10 +209,10 @@ class QDrantAdapter(VectorDBInterface):
|
||||||
|
|
||||||
return [
|
return [
|
||||||
ScoredResult(
|
ScoredResult(
|
||||||
id=UUID(result.id),
|
id=parse_id(result.id),
|
||||||
payload={
|
payload={
|
||||||
**result.payload,
|
**result.payload,
|
||||||
"id": UUID(result.id),
|
"id": parse_id(result.id),
|
||||||
},
|
},
|
||||||
score=1 - result.score,
|
score=1 - result.score,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,10 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from cognee.exceptions import InvalidValueError
|
from cognee.exceptions import InvalidValueError
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
from cognee.infrastructure.engine.utils import parse_id
|
||||||
|
|
||||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
from ..models.ScoredResult import ScoredResult
|
from ..models.ScoredResult import ScoredResult
|
||||||
|
|
@ -188,7 +188,7 @@ class WeaviateAdapter(VectorDBInterface):
|
||||||
|
|
||||||
return [
|
return [
|
||||||
ScoredResult(
|
ScoredResult(
|
||||||
id=UUID(str(result.uuid)),
|
id=parse_id(str(result.uuid)),
|
||||||
payload=result.properties,
|
payload=result.properties,
|
||||||
score=1 - float(result.metadata.score),
|
score=1 - float(result.metadata.score),
|
||||||
)
|
)
|
||||||
|
|
@ -221,7 +221,7 @@ class WeaviateAdapter(VectorDBInterface):
|
||||||
|
|
||||||
return [
|
return [
|
||||||
ScoredResult(
|
ScoredResult(
|
||||||
id=UUID(str(result.uuid)),
|
id=parse_id(str(result.uuid)),
|
||||||
payload=result.properties,
|
payload=result.properties,
|
||||||
score=1 - float(result.metadata.score),
|
score=1 - float(result.metadata.score),
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,6 @@ class MetaData(TypedDict):
|
||||||
|
|
||||||
# Updated DataPoint model with versioning and new fields
|
# Updated DataPoint model with versioning and new fields
|
||||||
class DataPoint(BaseModel):
|
class DataPoint(BaseModel):
|
||||||
__tablename__ = "data_point"
|
|
||||||
id: UUID = Field(default_factory=uuid4)
|
id: UUID = Field(default_factory=uuid4)
|
||||||
created_at: int = Field(
|
created_at: int = Field(
|
||||||
default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)
|
default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)
|
||||||
|
|
|
||||||
1
cognee/infrastructure/engine/utils/__init__.py
Normal file
1
cognee/infrastructure/engine/utils/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
from .parse_id import parse_id
|
||||||
10
cognee/infrastructure/engine/utils/parse_id.py
Normal file
10
cognee/infrastructure/engine/utils/parse_id.py
Normal file
|
|
@ -0,0 +1,10 @@
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
|
||||||
|
def parse_id(id: Any) -> Any:
    """Best-effort conversion of an identifier to a ``UUID``.

    Centralized helper used by the graph/vector adapters so that string ids
    coming back from storage engines are normalized to ``uuid.UUID``.

    Args:
        id: Any identifier value; only ``str`` values are converted.

    Returns:
        A ``UUID`` when ``id`` is a string in a valid UUID format;
        otherwise the original value, unchanged.
    """
    if isinstance(id, str):
        try:
            return UUID(id)
        # UUID() raises ValueError for a malformed string; keep the
        # original value in that case (callers tolerate non-UUID ids).
        except ValueError:
            pass
    return id
|
||||||
|
|
@ -6,14 +6,12 @@ from cognee.modules.engine.models import Entity
|
||||||
|
|
||||||
|
|
||||||
class DocumentChunk(DataPoint):
|
class DocumentChunk(DataPoint):
|
||||||
__tablename__ = "document_chunk"
|
|
||||||
text: str
|
text: str
|
||||||
word_count: int
|
word_count: int
|
||||||
token_count: int
|
token_count: int
|
||||||
chunk_index: int
|
chunk_index: int
|
||||||
cut_type: str
|
cut_type: str
|
||||||
is_part_of: Document
|
is_part_of: Document
|
||||||
pydantic_type: str = "DocumentChunk"
|
|
||||||
contains: List[Entity] = None
|
contains: List[Entity] = None
|
||||||
|
|
||||||
metadata: dict = {"index_fields": ["text"]}
|
metadata: dict = {"index_fields": ["text"]}
|
||||||
|
|
|
||||||
|
|
@ -3,10 +3,8 @@ from cognee.modules.engine.models.EntityType import EntityType
|
||||||
|
|
||||||
|
|
||||||
class Entity(DataPoint):
|
class Entity(DataPoint):
|
||||||
__tablename__ = "entity"
|
|
||||||
name: str
|
name: str
|
||||||
is_a: EntityType
|
is_a: EntityType
|
||||||
description: str
|
description: str
|
||||||
pydantic_type: str = "Entity"
|
|
||||||
|
|
||||||
metadata: dict = {"index_fields": ["name"]}
|
metadata: dict = {"index_fields": ["name"]}
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,7 @@ from cognee.infrastructure.engine import DataPoint
|
||||||
|
|
||||||
|
|
||||||
class EntityType(DataPoint):
|
class EntityType(DataPoint):
|
||||||
__tablename__ = "entity_type"
|
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
pydantic_type: str = "EntityType"
|
|
||||||
|
|
||||||
metadata: dict = {"index_fields": ["name"]}
|
metadata: dict = {"index_fields": ["name"]}
|
||||||
|
|
|
||||||
|
|
@ -137,7 +137,7 @@ class CogneeGraph(CogneeAbstractGraph):
|
||||||
raise ValueError("Failed to generate query embedding.")
|
raise ValueError("Failed to generate query embedding.")
|
||||||
|
|
||||||
edge_distances = await vector_engine.get_distance_from_collection_elements(
|
edge_distances = await vector_engine.get_distance_from_collection_elements(
|
||||||
"edge_type_relationship_name", query_text=query
|
"EdgeType_relationship_name", query_text=query
|
||||||
)
|
)
|
||||||
|
|
||||||
embedding_map = {result.payload["text"]: result.score for result in edge_distances}
|
embedding_map = {result.payload["text"]: result.score for result in edge_distances}
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@ from cognee.infrastructure.engine import DataPoint
|
||||||
|
|
||||||
|
|
||||||
class EdgeType(DataPoint):
|
class EdgeType(DataPoint):
|
||||||
__tablename__ = "edge_type"
|
|
||||||
relationship_name: str
|
relationship_name: str
|
||||||
number_of_edges: int
|
number_of_edges: int
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -61,9 +61,6 @@ async def get_graph_from_model(
|
||||||
if include_root and str(data_point.id) not in added_nodes:
|
if include_root and str(data_point.id) not in added_nodes:
|
||||||
SimpleDataPointModel = copy_model(
|
SimpleDataPointModel = copy_model(
|
||||||
type(data_point),
|
type(data_point),
|
||||||
include_fields={
|
|
||||||
"__tablename__": (str, data_point.__tablename__),
|
|
||||||
},
|
|
||||||
exclude_fields=list(excluded_properties),
|
exclude_fields=list(excluded_properties),
|
||||||
)
|
)
|
||||||
nodes.append(SimpleDataPointModel(**data_point_properties))
|
nodes.append(SimpleDataPointModel(**data_point_properties))
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
import json
|
import json
|
||||||
from uuid import UUID
|
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
from cognee.exceptions import InvalidValueError
|
from cognee.exceptions import InvalidValueError
|
||||||
|
from cognee.infrastructure.engine.utils import parse_id
|
||||||
from cognee.modules.retrieval.code_graph_retrieval import code_graph_retrieval
|
from cognee.modules.retrieval.code_graph_retrieval import code_graph_retrieval
|
||||||
from cognee.modules.search.types import SearchType
|
from cognee.modules.search.types import SearchType
|
||||||
from cognee.modules.storage.utils import JSONEncoder
|
from cognee.modules.storage.utils import JSONEncoder
|
||||||
|
|
@ -32,7 +32,7 @@ async def search(
|
||||||
|
|
||||||
for search_result in search_results:
|
for search_result in search_results:
|
||||||
document_id = search_result["document_id"] if "document_id" in search_result else None
|
document_id = search_result["document_id"] if "document_id" in search_result else None
|
||||||
document_id = UUID(document_id) if isinstance(document_id, str) else document_id
|
document_id = parse_id(document_id)
|
||||||
|
|
||||||
if document_id is None or document_id in own_document_ids:
|
if document_id is None or document_id in own_document_ids:
|
||||||
filtered_search_results.append(search_result)
|
filtered_search_results.append(search_result)
|
||||||
|
|
|
||||||
|
|
@ -3,16 +3,12 @@ from cognee.infrastructure.engine import DataPoint
|
||||||
|
|
||||||
|
|
||||||
class Repository(DataPoint):
|
class Repository(DataPoint):
|
||||||
__tablename__ = "Repository"
|
|
||||||
path: str
|
path: str
|
||||||
pydantic_type: str = "Repository"
|
|
||||||
metadata: dict = {"index_fields": []}
|
metadata: dict = {"index_fields": []}
|
||||||
|
|
||||||
|
|
||||||
class CodeFile(DataPoint):
|
class CodeFile(DataPoint):
|
||||||
__tablename__ = "codefile"
|
|
||||||
extracted_id: str # actually file path
|
extracted_id: str # actually file path
|
||||||
pydantic_type: str = "CodeFile"
|
|
||||||
source_code: Optional[str] = None
|
source_code: Optional[str] = None
|
||||||
part_of: Optional[Repository] = None
|
part_of: Optional[Repository] = None
|
||||||
depends_on: Optional[List["CodeFile"]] = None
|
depends_on: Optional[List["CodeFile"]] = None
|
||||||
|
|
@ -22,19 +18,15 @@ class CodeFile(DataPoint):
|
||||||
|
|
||||||
|
|
||||||
class CodePart(DataPoint):
|
class CodePart(DataPoint):
|
||||||
__tablename__ = "codepart"
|
|
||||||
file_path: str # file path
|
file_path: str # file path
|
||||||
# part_of: Optional[CodeFile] = None
|
# part_of: Optional[CodeFile] = None
|
||||||
pydantic_type: str = "CodePart"
|
|
||||||
source_code: Optional[str] = None
|
source_code: Optional[str] = None
|
||||||
metadata: dict = {"index_fields": []}
|
metadata: dict = {"index_fields": []}
|
||||||
|
|
||||||
|
|
||||||
class SourceCodeChunk(DataPoint):
|
class SourceCodeChunk(DataPoint):
|
||||||
__tablename__ = "sourcecodechunk"
|
|
||||||
code_chunk_of: Optional[CodePart] = None
|
code_chunk_of: Optional[CodePart] = None
|
||||||
source_code: Optional[str] = None
|
source_code: Optional[str] = None
|
||||||
pydantic_type: str = "SourceCodeChunk"
|
|
||||||
previous_chunk: Optional["SourceCodeChunk"] = None
|
previous_chunk: Optional["SourceCodeChunk"] = None
|
||||||
|
|
||||||
metadata: dict = {"index_fields": ["source_code"]}
|
metadata: dict = {"index_fields": ["source_code"]}
|
||||||
|
|
|
||||||
|
|
@ -231,7 +231,6 @@ class SummarizedContent(BaseModel):
|
||||||
|
|
||||||
summary: str
|
summary: str
|
||||||
description: str
|
description: str
|
||||||
pydantic_type: str = "SummarizedContent"
|
|
||||||
|
|
||||||
|
|
||||||
class SummarizedFunction(BaseModel):
|
class SummarizedFunction(BaseModel):
|
||||||
|
|
@ -240,7 +239,6 @@ class SummarizedFunction(BaseModel):
|
||||||
inputs: Optional[List[str]] = None
|
inputs: Optional[List[str]] = None
|
||||||
outputs: Optional[List[str]] = None
|
outputs: Optional[List[str]] = None
|
||||||
decorators: Optional[List[str]] = None
|
decorators: Optional[List[str]] = None
|
||||||
pydantic_type: str = "SummarizedFunction"
|
|
||||||
|
|
||||||
|
|
||||||
class SummarizedClass(BaseModel):
|
class SummarizedClass(BaseModel):
|
||||||
|
|
@ -248,7 +246,6 @@ class SummarizedClass(BaseModel):
|
||||||
description: str
|
description: str
|
||||||
methods: Optional[List[SummarizedFunction]] = None
|
methods: Optional[List[SummarizedFunction]] = None
|
||||||
decorators: Optional[List[str]] = None
|
decorators: Optional[List[str]] = None
|
||||||
pydantic_type: str = "SummarizedClass"
|
|
||||||
|
|
||||||
|
|
||||||
class SummarizedCode(BaseModel):
|
class SummarizedCode(BaseModel):
|
||||||
|
|
@ -259,7 +256,6 @@ class SummarizedCode(BaseModel):
|
||||||
classes: List[SummarizedClass] = []
|
classes: List[SummarizedClass] = []
|
||||||
functions: List[SummarizedFunction] = []
|
functions: List[SummarizedFunction] = []
|
||||||
workflow_description: Optional[str] = None
|
workflow_description: Optional[str] = None
|
||||||
pydantic_type: str = "SummarizedCode"
|
|
||||||
|
|
||||||
|
|
||||||
class GraphDBType(Enum):
|
class GraphDBType(Enum):
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from cognee.modules.graph.utils import deduplicate_nodes_and_edges, get_graph_from_model
|
from cognee.modules.graph.utils import deduplicate_nodes_and_edges, get_graph_from_model
|
||||||
from .index_data_points import index_data_points
|
from .index_data_points import index_data_points
|
||||||
|
from .index_graph_edges import index_graph_edges
|
||||||
|
|
||||||
|
|
||||||
async def add_data_points(data_points: list[DataPoint]):
|
async def add_data_points(data_points: list[DataPoint]):
|
||||||
|
|
@ -38,4 +39,7 @@ async def add_data_points(data_points: list[DataPoint]):
|
||||||
await graph_engine.add_nodes(nodes)
|
await graph_engine.add_nodes(nodes)
|
||||||
await graph_engine.add_edges(edges)
|
await graph_engine.add_edges(edges)
|
||||||
|
|
||||||
|
# This step has to happen after adding nodes and edges because we query the graph.
|
||||||
|
await index_graph_edges()
|
||||||
|
|
||||||
return data_points
|
return data_points
|
||||||
|
|
|
||||||
|
|
@ -51,10 +51,10 @@ async def index_graph_edges():
|
||||||
data_point_type = type(edge)
|
data_point_type = type(edge)
|
||||||
|
|
||||||
for field_name in edge.metadata["index_fields"]:
|
for field_name in edge.metadata["index_fields"]:
|
||||||
index_name = f"{data_point_type.__tablename__}.{field_name}"
|
index_name = f"{data_point_type.__name__}.{field_name}"
|
||||||
|
|
||||||
if index_name not in created_indexes:
|
if index_name not in created_indexes:
|
||||||
await vector_engine.create_vector_index(data_point_type.__tablename__, field_name)
|
await vector_engine.create_vector_index(data_point_type.__name__, field_name)
|
||||||
created_indexes[index_name] = True
|
created_indexes[index_name] = True
|
||||||
|
|
||||||
if index_name not in index_points:
|
if index_name not in index_points:
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,6 @@ from cognee.shared.CodeGraphEntities import CodeFile, CodePart, SourceCodeChunk
|
||||||
|
|
||||||
|
|
||||||
class TextSummary(DataPoint):
|
class TextSummary(DataPoint):
|
||||||
__tablename__ = "text_summary"
|
|
||||||
text: str
|
text: str
|
||||||
made_from: DocumentChunk
|
made_from: DocumentChunk
|
||||||
|
|
||||||
|
|
@ -14,9 +13,7 @@ class TextSummary(DataPoint):
|
||||||
|
|
||||||
|
|
||||||
class CodeSummary(DataPoint):
|
class CodeSummary(DataPoint):
|
||||||
__tablename__ = "code_summary"
|
|
||||||
text: str
|
text: str
|
||||||
summarizes: Union[CodeFile, CodePart, SourceCodeChunk]
|
summarizes: Union[CodeFile, CodePart, SourceCodeChunk]
|
||||||
pydantic_type: str = "CodeSummary"
|
|
||||||
|
|
||||||
metadata: dict = {"index_fields": ["text"]}
|
metadata: dict = {"index_fields": ["text"]}
|
||||||
|
|
|
||||||
|
|
@ -3,10 +3,8 @@ from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
class GraphitiNode(DataPoint):
|
class GraphitiNode(DataPoint):
|
||||||
__tablename__ = "graphitinode"
|
|
||||||
content: Optional[str] = None
|
content: Optional[str] = None
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
summary: Optional[str] = None
|
summary: Optional[str] = None
|
||||||
pydantic_type: str = "GraphitiNode"
|
|
||||||
|
|
||||||
metadata: dict = {"index_fields": ["name", "summary", "content"]}
|
metadata: dict = {"index_fields": ["name", "summary", "content"]}
|
||||||
|
|
|
||||||
|
|
@ -37,10 +37,10 @@ async def index_and_transform_graphiti_nodes_and_edges():
|
||||||
data_point_type = type(graphiti_node)
|
data_point_type = type(graphiti_node)
|
||||||
|
|
||||||
for field_name in graphiti_node.metadata["index_fields"]:
|
for field_name in graphiti_node.metadata["index_fields"]:
|
||||||
index_name = f"{data_point_type.__tablename__}.{field_name}"
|
index_name = f"{data_point_type.__name__}.{field_name}"
|
||||||
|
|
||||||
if index_name not in created_indexes:
|
if index_name not in created_indexes:
|
||||||
await vector_engine.create_vector_index(data_point_type.__tablename__, field_name)
|
await vector_engine.create_vector_index(data_point_type.__name__, field_name)
|
||||||
created_indexes[index_name] = True
|
created_indexes[index_name] = True
|
||||||
|
|
||||||
if index_name not in index_points:
|
if index_name not in index_points:
|
||||||
|
|
@ -66,10 +66,10 @@ async def index_and_transform_graphiti_nodes_and_edges():
|
||||||
data_point_type = type(edge)
|
data_point_type = type(edge)
|
||||||
|
|
||||||
for field_name in edge.metadata["index_fields"]:
|
for field_name in edge.metadata["index_fields"]:
|
||||||
index_name = f"{data_point_type.__tablename__}.{field_name}"
|
index_name = f"{data_point_type.__name__}.{field_name}"
|
||||||
|
|
||||||
if index_name not in created_indexes:
|
if index_name not in created_indexes:
|
||||||
await vector_engine.create_vector_index(data_point_type.__tablename__, field_name)
|
await vector_engine.create_vector_index(data_point_type.__name__, field_name)
|
||||||
created_indexes[index_name] = True
|
created_indexes[index_name] = True
|
||||||
|
|
||||||
if index_name not in index_points:
|
if index_name not in index_points:
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue