fix: custom model pipeline (#508)
## Description

Fixes the custom model pipeline.

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

## Summary by CodeRabbit

- **New Features**
  - Graph visualizations can now be exported to a user-specified file path for more flexible output management.
  - The text embedding process supports an additional tokenizer (Mistral).
  - A new `ExtendableDataPoint` class has been introduced for future extensions.
  - New JSON files for companies and individuals have been added to facilitate testing and data processing.
- **Improvements**
  - Search functionality now uses updated identifiers for more reliable content retrieval.
  - Metadata handling has been streamlined across various classes by removing unnecessary type specifications.
  - Serialization of properties in the Neo4j adapter has been enhanced for better handling of complex structures.
  - Database setup has been improved with a new asynchronous `setup` function.
- **Chores**
  - Dependency and configuration updates improve overall stability and performance.
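As a quick orientation for reviewers, here is a minimal sketch of the headline feature; the destination path argument is optional, and without it the HTML file still lands in the home directory as before:

```python
import asyncio
import cognee


async def main():
    # ... after cognee.add(...) and cognee.cognify(...) have built a graph ...
    await cognee.visualize_graph("./.artifacts/graph_visualization.html")


asyncio.run(main())
```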
This commit is contained in:
parent 6be6b3d222
commit f75e35c337

66 changed files with 1248 additions and 748 deletions
.gitignore (vendored, 1 change)
@@ -179,6 +179,7 @@ cognee/cache/
# Default cognee system directory, used in development
.cognee_system/
.data_storage/
.artifacts/
.anon_id

node_modules/
@@ -9,13 +9,19 @@ import asyncio
from cognee.shared.utils import setup_logging


async def visualize_graph():
async def visualize_graph(destination_file_path: str = None):
    graph_engine = await get_graph_engine()
    graph_data = await graph_engine.get_graph_data()
    logging.info(graph_data)

    graph = await cognee_network_visualization(graph_data)
    logging.info("The HTML file has been stored on your home directory! Navigate there with cd ~")
    graph = await cognee_network_visualization(graph_data, destination_file_path)

    if destination_file_path:
        logging.info(f"The HTML file has been stored at path: {destination_file_path}")
    else:
        logging.info(
            "The HTML file has been stored on your home directory! Navigate there with cd ~"
        )

    return graph
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
"""Neo4j Adapter for Graph Database"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
from textwrap import dedent
|
||||
from typing import Optional, Any, List, Dict, Union
|
||||
from typing import Optional, Any, List, Dict
|
||||
from contextlib import asynccontextmanager
|
||||
from uuid import UUID
|
||||
from neo4j import AsyncSession
|
||||
|
|
@ -11,6 +12,7 @@ from neo4j import AsyncGraphDatabase
|
|||
from neo4j.exceptions import Neo4jError
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
||||
from cognee.modules.storage.utils import JSONEncoder
|
||||
|
||||
logger = logging.getLogger("Neo4jAdapter")
|
||||
|
||||
|
|
@@ -434,6 +436,10 @@ class Neo4jAdapter(GraphDBInterface):
                serialized_properties[property_key] = str(property_value)
                continue

            if isinstance(property_value, dict):
                serialized_properties[property_key] = json.dumps(property_value, cls=JSONEncoder)
                continue

            serialized_properties[property_key] = property_value

        return serialized_properties
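Nested dict values (for example the new `metadata` field) are not valid Neo4j property values, so the adapter now stores them as JSON strings. A minimal sketch of the idea; the real adapter uses cognee's own `JSONEncoder` so that UUIDs and timestamps inside the dict also serialize, while this sketch relies only on the standard library:

```python
import json

properties = {"name": "Entity1", "metadata": {"index_fields": ["name"]}}

# Dicts become JSON strings, everything else is stored as-is.
serialized = {
    key: json.dumps(value) if isinstance(value, dict) else value
    for key, value in properties.items()
}
# serialized["metadata"] == '{"index_fields": ["name"]}' — storable as a plain string property.
```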
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from cognee.infrastructure.engine import DataPoint
|
|||
class IndexSchema(DataPoint):
|
||||
text: str
|
||||
|
||||
_metadata: dict = {"index_fields": ["text"], "type": "IndexSchema"}
|
||||
metadata: dict = {"index_fields": ["text"]}
|
||||
|
||||
|
||||
class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ class SQLAlchemyAdapter:
|
|||
await connection.execute(text(f"CREATE SCHEMA IF NOT EXISTS {schema_name};"))
|
||||
await connection.execute(
|
||||
text(
|
||||
f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} ({', '.join(fields_query_parts)});"
|
||||
f'CREATE TABLE IF NOT EXISTS {schema_name}."{table_name}" ({", ".join(fields_query_parts)});'
|
||||
)
|
||||
)
|
||||
await connection.close()
|
||||
|
|
@ -71,10 +71,10 @@ class SQLAlchemyAdapter:
|
|||
if self.engine.dialect.name == "sqlite":
|
||||
# SQLite doesn’t support schema namespaces and the CASCADE keyword.
|
||||
# However, foreign key constraint can be defined with ON DELETE CASCADE during table creation.
|
||||
await connection.execute(text(f"DROP TABLE IF EXISTS {table_name};"))
|
||||
await connection.execute(text(f'DROP TABLE IF EXISTS "{table_name}";'))
|
||||
else:
|
||||
await connection.execute(
|
||||
text(f"DROP TABLE IF EXISTS {schema_name}.{table_name} CASCADE;")
|
||||
text(f'DROP TABLE IF EXISTS {schema_name}."{table_name}" CASCADE;')
|
||||
)
|
||||
|
||||
async def insert_data(
|
||||
|
|
@ -252,7 +252,7 @@ class SQLAlchemyAdapter:
|
|||
|
||||
async def get_data(self, table_name: str, filters: dict = None):
|
||||
async with self.engine.begin() as connection:
|
||||
query = f"SELECT * FROM {table_name}"
|
||||
query = f'SELECT * FROM "{table_name}"'
|
||||
if filters:
|
||||
filter_conditions = " AND ".join(
|
||||
[
|
||||
|
|
@ -336,7 +336,7 @@ class SQLAlchemyAdapter:
|
|||
await connection.run_sync(metadata.reflect, schema=schema_name)
|
||||
for table in metadata.sorted_tables:
|
||||
drop_table_query = text(
|
||||
f"DROP TABLE IF EXISTS {schema_name}.{table.name} CASCADE"
|
||||
f'DROP TABLE IF EXISTS {schema_name}."{table.name}" CASCADE'
|
||||
)
|
||||
await connection.execute(drop_table_query)
|
||||
metadata.clear()
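The recurring change in this adapter is wrapping `table_name` in double quotes. Presumably this is because table and collection names are now derived from `DataPoint` class names (for example `DocumentChunk_text`), and Postgres folds unquoted identifiers to lower case, so the mixed-case name would not be found without quoting. Illustrative sketch only; `schema_name` here is a placeholder:

```python
schema_name = "public"  # placeholder, the adapter passes its own schema name
table_name = "DocumentChunk_text"

# Unquoted, Postgres would look up documentchunk_text and miss the table.
drop_statement = f'DROP TABLE IF EXISTS {schema_name}."{table_name}" CASCADE;'
select_statement = f'SELECT * FROM "{table_name}"'
```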
|
||||
|
|
|
|||
|
|
@@ -8,6 +8,7 @@ from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import Em
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
from cognee.infrastructure.llm.tokenizer.Gemini import GeminiTokenizer
from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer
from cognee.infrastructure.llm.tokenizer.Mistral import MistralTokenizer
from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer

litellm.set_verbose = False
@@ -126,6 +127,8 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
            tokenizer = TikTokenTokenizer(model=model, max_tokens=self.max_tokens)
        elif "gemini" in self.provider.lower():
            tokenizer = GeminiTokenizer(model=model, max_tokens=self.max_tokens)
        elif "mistral" in self.provider.lower():
            tokenizer = MistralTokenizer(model=model, max_tokens=self.max_tokens)
        else:
            tokenizer = HuggingFaceTokenizer(model=self.model, max_tokens=self.max_tokens)
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ class IndexSchema(DataPoint):
|
|||
id: str
|
||||
text: str
|
||||
|
||||
_metadata: dict = {"index_fields": ["text"], "type": "IndexSchema"}
|
||||
metadata: dict = {"index_fields": ["text"]}
|
||||
|
||||
|
||||
class LanceDBAdapter(VectorDBInterface):
|
||||
|
|
@ -245,7 +245,7 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
[
|
||||
IndexSchema(
|
||||
id=str(data_point.id),
|
||||
text=getattr(data_point, data_point._metadata["index_fields"][0]),
|
||||
text=getattr(data_point, data_point.metadata["index_fields"][0]),
|
||||
)
|
||||
for data_point in data_points
|
||||
],
|
||||
|
|
@ -269,5 +269,5 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
include_fields={
|
||||
"id": (str, ...),
|
||||
},
|
||||
exclude_fields=["_metadata"],
|
||||
exclude_fields=["metadata"],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ logger = logging.getLogger("MilvusAdapter")
|
|||
class IndexSchema(DataPoint):
|
||||
text: str
|
||||
|
||||
_metadata: dict = {"index_fields": ["text"], "type": "IndexSchema"}
|
||||
metadata: dict = {"index_fields": ["text"]}
|
||||
|
||||
|
||||
class MilvusAdapter(VectorDBInterface):
|
||||
|
|
@ -133,7 +133,7 @@ class MilvusAdapter(VectorDBInterface):
|
|||
formatted_data_points = [
|
||||
IndexSchema(
|
||||
id=data_point.id,
|
||||
text=getattr(data_point, data_point._metadata["index_fields"][0]),
|
||||
text=getattr(data_point, data_point.metadata["index_fields"][0]),
|
||||
)
|
||||
for data_point in data_points
|
||||
]
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from ..utils import normalize_distances
|
|||
class IndexSchema(DataPoint):
|
||||
text: str
|
||||
|
||||
_metadata: dict = {"index_fields": ["text"], "type": "IndexSchema"}
|
||||
metadata: dict = {"index_fields": ["text"]}
|
||||
|
||||
|
||||
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ logger = logging.getLogger("QDrantAdapter")
|
|||
class IndexSchema(DataPoint):
|
||||
text: str
|
||||
|
||||
_metadata: dict = {"index_fields": ["text"], "type": "IndexSchema"}
|
||||
metadata: dict = {"index_fields": ["text"]}
|
||||
|
||||
|
||||
# class CollectionConfig(BaseModel, extra = "forbid"):
|
||||
|
|
@ -131,7 +131,7 @@ class QDrantAdapter(VectorDBInterface):
|
|||
[
|
||||
IndexSchema(
|
||||
id=data_point.id,
|
||||
text=getattr(data_point, data_point._metadata["index_fields"][0]),
|
||||
text=getattr(data_point, data_point.metadata["index_fields"][0]),
|
||||
)
|
||||
for data_point in data_points
|
||||
],
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ logger = logging.getLogger("WeaviateAdapter")
|
|||
class IndexSchema(DataPoint):
|
||||
text: str
|
||||
|
||||
_metadata: dict = {"index_fields": ["text"], "type": "IndexSchema"}
|
||||
metadata: dict = {"index_fields": ["text"]}
|
||||
|
||||
|
||||
class WeaviateAdapter(VectorDBInterface):
|
||||
|
|
|
|||
|
|
@@ -1 +1,2 @@
from .models.DataPoint import DataPoint
from .models.ExtendableDataPoint import ExtendableDataPoint
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import pickle
|
|||
|
||||
# Define metadata type
|
||||
class MetaData(TypedDict):
|
||||
type: str
|
||||
index_fields: list[str]
|
||||
|
||||
|
||||
|
|
@ -24,35 +25,35 @@ class DataPoint(BaseModel):
|
|||
)
|
||||
version: int = 1 # Default version
|
||||
topological_rank: Optional[int] = 0
|
||||
_metadata: Optional[MetaData] = {"index_fields": [], "type": "DataPoint"}
|
||||
metadata: Optional[MetaData] = {"index_fields": []}
|
||||
|
||||
@classmethod
|
||||
def get_embeddable_data(self, data_point):
|
||||
def get_embeddable_data(self, data_point: "DataPoint"):
|
||||
if (
|
||||
data_point._metadata
|
||||
and len(data_point._metadata["index_fields"]) > 0
|
||||
and hasattr(data_point, data_point._metadata["index_fields"][0])
|
||||
data_point.metadata
|
||||
and len(data_point.metadata["index_fields"]) > 0
|
||||
and hasattr(data_point, data_point.metadata["index_fields"][0])
|
||||
):
|
||||
attribute = getattr(data_point, data_point._metadata["index_fields"][0])
|
||||
attribute = getattr(data_point, data_point.metadata["index_fields"][0])
|
||||
|
||||
if isinstance(attribute, str):
|
||||
return attribute.strip()
|
||||
return attribute
|
||||
|
||||
@classmethod
|
||||
def get_embeddable_properties(self, data_point):
|
||||
def get_embeddable_properties(self, data_point: "DataPoint"):
|
||||
"""Retrieve all embeddable properties."""
|
||||
if data_point._metadata and len(data_point._metadata["index_fields"]) > 0:
|
||||
if data_point.metadata and len(data_point.metadata["index_fields"]) > 0:
|
||||
return [
|
||||
getattr(data_point, field, None) for field in data_point._metadata["index_fields"]
|
||||
getattr(data_point, field, None) for field in data_point.metadata["index_fields"]
|
||||
]
|
||||
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def get_embeddable_property_names(self, data_point):
|
||||
def get_embeddable_property_names(self, data_point: "DataPoint"):
|
||||
"""Retrieve names of embeddable properties."""
|
||||
return data_point._metadata["index_fields"] or []
|
||||
return data_point.metadata["index_fields"] or []
|
||||
|
||||
def update_version(self):
|
||||
"""Update the version and updated_at timestamp."""
|
||||
|
|
|
|||
|
|
@@ -0,0 +1,5 @@
from .DataPoint import DataPoint


class ExtendableDataPoint(DataPoint):
    pass
|
||||
|
|
@ -1,7 +1,6 @@
|
|||
"""Adapter for Generic API LLM provider API"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Type
|
||||
from typing import Type
|
||||
|
||||
from pydantic import BaseModel
|
||||
import instructor
|
||||
|
|
@ -24,13 +23,9 @@ class GenericAPIAdapter(LLMInterface):
|
|||
self.endpoint = endpoint
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
llm_config = get_llm_config()
|
||||
|
||||
if llm_config.llm_provider == "ollama":
|
||||
self.aclient = instructor.from_litellm(litellm.acompletion)
|
||||
|
||||
else:
|
||||
self.aclient = instructor.from_litellm(litellm.acompletion)
|
||||
self.aclient = instructor.from_litellm(
|
||||
litellm.acompletion, mode=instructor.Mode.JSON, api_key=api_key
|
||||
)
|
||||
|
||||
async def acreate_structured_output(
|
||||
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||
|
|
|
|||
|
|
@@ -34,5 +34,5 @@ class HuggingFaceTokenizer(TokenizerInterface):
        return len(self.tokenizer.tokenize(text))

    def decode_single_token(self, encoding: int):
        # Gemini tokenizer doesn't have the option to decode tokens
        # HuggingFace tokenizer doesn't have the option to decode tokens
        raise NotImplementedError
|
||||
|
|
|
|||
cognee/infrastructure/llm/tokenizer/Mistral/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .adapter import MistralTokenizer
|
||||
cognee/infrastructure/llm/tokenizer/Mistral/adapter.py (new file, 47 lines)
@@ -0,0 +1,47 @@
from typing import List, Any

from ..tokenizer_interface import TokenizerInterface


class MistralTokenizer(TokenizerInterface):
    def __init__(
        self,
        model: str,
        max_tokens: int = 3072,
    ):
        self.model = model
        self.max_tokens = max_tokens

        # Import here to make it an optional dependency
        from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

        self.tokenizer = MistralTokenizer.from_model(model)

    def extract_tokens(self, text: str) -> List[Any]:
        from mistral_common.protocol.instruct.request import ChatCompletionRequest
        from mistral_common.protocol.instruct.messages import UserMessage
        from mistral_common.tokens.tokenizers.base import Tokenized

        encoding: Tokenized = self.tokenizer.encode_chat_completion(
            ChatCompletionRequest(
                messages=[UserMessage(role="user", content=text)],
                model=self.model,
            )
        )
        return encoding.tokens

    def count_tokens(self, text: str) -> int:
        """
        Returns the number of tokens in the given text.
        Args:
            text: str

        Returns:
            number of tokens in the given text

        """
        return len(self.extract_tokens(text))

    def decode_single_token(self, encoding: int):
        # Mistral tokenizer doesn't have the option to decode tokens
        raise NotImplementedError
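A quick usage sketch of the new tokenizer; the model name is illustrative (any model accepted by `mistral_common`'s `MistralTokenizer.from_model` should work) and the new `mistral` extra must be installed:

```python
from cognee.infrastructure.llm.tokenizer.Mistral import MistralTokenizer

# "mistral-large-latest" is an assumed model name for illustration.
tokenizer = MistralTokenizer(model="mistral-large-latest", max_tokens=3072)
print(tokenizer.count_tokens("Natural language gets tokenized before embedding."))
```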
|
||||
cognee/low_level.py (new file, 2 lines)
@@ -0,0 +1,2 @@
from cognee.infrastructure.engine import ExtendableDataPoint as DataPoint
from cognee.modules.engine.operations.setup import setup
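This small module is the public entry point for the low-level custom model pipeline. A condensed sketch of the intended flow; the `Person` model and the names are illustrative, and the full version is the new `cognee/tests/low_level/pipeline.py` test added further down in this PR:

```python
import asyncio

from cognee import prune
from cognee.low_level import setup, DataPoint
from cognee.pipelines import run_tasks, Task
from cognee.tasks.storage import add_data_points


class Person(DataPoint):
    name: str
    metadata: dict = {"index_fields": ["name"]}


def ingest():
    # Any iterable of DataPoint instances works as the first pipeline task's output.
    return [Person(name="John Doe"), Person(name="Jane Smith")]


async def main():
    await prune.prune_data()
    await prune.prune_system(metadata=True)
    await setup()  # creates the relational and pgvector tables

    pipeline = run_tasks([Task(ingest), Task(add_data_points)])
    async for status in pipeline:
        print(status)


asyncio.run(main())
```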
|
||||
|
|
@ -53,7 +53,7 @@ class TextChunker:
|
|||
chunk_index=self.chunk_index,
|
||||
cut_type=chunk_data["cut_type"],
|
||||
contains=[],
|
||||
_metadata={
|
||||
metadata={
|
||||
"index_fields": ["text"],
|
||||
},
|
||||
)
|
||||
|
|
@ -73,7 +73,7 @@ class TextChunker:
|
|||
chunk_index=self.chunk_index,
|
||||
cut_type=paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
|
||||
contains=[],
|
||||
_metadata={
|
||||
metadata={
|
||||
"index_fields": ["text"],
|
||||
},
|
||||
)
|
||||
|
|
@ -97,7 +97,7 @@ class TextChunker:
|
|||
chunk_index=self.chunk_index,
|
||||
cut_type=paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
|
||||
contains=[],
|
||||
_metadata={"index_fields": ["text"]},
|
||||
metadata={"index_fields": ["text"]},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List, Optional
|
||||
from typing import List
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.data.processing.document_types import Document
|
||||
|
|
@ -16,4 +16,4 @@ class DocumentChunk(DataPoint):
|
|||
pydantic_type: str = "DocumentChunk"
|
||||
contains: List[Entity] = None
|
||||
|
||||
_metadata: dict = {"index_fields": ["text"], "type": "DocumentChunk"}
|
||||
metadata: dict = {"index_fields": ["text"]}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,4 @@
|
|||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
||||
|
||||
|
|
@ -9,7 +7,7 @@ class Document(DataPoint):
|
|||
raw_data_location: str
|
||||
external_metadata: Optional[str]
|
||||
mime_type: str
|
||||
_metadata: dict = {"index_fields": ["name"], "type": "Document"}
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
def read(self, chunk_size: int, chunker=str, max_chunk_tokens: Optional[int] = None) -> str:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
from typing import Optional
|
||||
|
||||
from .ChunkerMapping import ChunkerConfig
|
||||
from .Document import Document
|
||||
|
||||
|
|
|
|||
|
|
@ -9,4 +9,4 @@ class Entity(DataPoint):
|
|||
description: str
|
||||
pydantic_type: str = "Entity"
|
||||
|
||||
_metadata: dict = {"index_fields": ["name"], "type": "Entity"}
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
|
|
|||
|
|
@ -7,4 +7,4 @@ class EntityType(DataPoint):
|
|||
description: str
|
||||
pydantic_type: str = "EntityType"
|
||||
|
||||
_metadata: dict = {"index_fields": ["name"], "type": "EntityType"}
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
|
|
|||
cognee/modules/engine/operations/setup.py (new file, 11 lines)
@@ -0,0 +1,11 @@
from cognee.infrastructure.databases.relational import (
    create_db_and_tables as create_relational_db_and_tables,
)
from cognee.infrastructure.databases.vector.pgvector import (
    create_db_and_tables as create_pgvector_db_and_tables,
)


async def setup():
    await create_relational_db_and_tables()
    await create_pgvector_db_and_tables()
|
||||
|
|
@ -1,5 +1,3 @@
|
|||
from typing import Optional
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
||||
|
||||
|
|
@ -8,4 +6,4 @@ class EdgeType(DataPoint):
|
|||
relationship_name: str
|
||||
number_of_edges: int
|
||||
|
||||
_metadata: dict = {"index_fields": ["relationship_name"], "type": "EdgeType"}
|
||||
metadata: dict = {"index_fields": ["relationship_name"]}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from cognee.infrastructure.engine import DataPoint
|
|||
|
||||
|
||||
def convert_node_to_data_point(node_data: dict) -> DataPoint:
|
||||
subclass = find_subclass_by_name(DataPoint, node_data._metadata["type"])
|
||||
subclass = find_subclass_by_name(DataPoint, node_data["type"])
|
||||
|
||||
return subclass(**node_data)
|
||||
|
||||
|
|
|
|||
|
|
@ -17,12 +17,14 @@ async def get_graph_from_model(
|
|||
edges = []
|
||||
visited_properties = visited_properties or {}
|
||||
|
||||
data_point_properties = {}
|
||||
data_point_properties = {
|
||||
"type": type(data_point).__name__,
|
||||
}
|
||||
excluded_properties = set()
|
||||
properties_to_visit = set()
|
||||
|
||||
for field_name, field_value in data_point:
|
||||
if field_name == "_metadata":
|
||||
if field_name == "metadata":
|
||||
continue
|
||||
|
||||
if isinstance(field_value, DataPoint):
|
||||
|
|
@ -60,7 +62,6 @@ async def get_graph_from_model(
|
|||
SimpleDataPointModel = copy_model(
|
||||
type(data_point),
|
||||
include_fields={
|
||||
"_metadata": (dict, data_point._metadata),
|
||||
"__tablename__": (str, data_point.__tablename__),
|
||||
},
|
||||
exclude_fields=list(excluded_properties),
|
||||
|
|
|
|||
|
|
@ -84,10 +84,10 @@ async def brute_force_search(
|
|||
|
||||
if collections is None:
|
||||
collections = [
|
||||
"entity_name",
|
||||
"text_summary_text",
|
||||
"entity_type_name",
|
||||
"document_chunk_text",
|
||||
"Entity_name",
|
||||
"TextSummary_text",
|
||||
"EntityType_name",
|
||||
"DocumentChunk_text",
|
||||
]
|
||||
|
||||
try:
|
||||
|
|
@ -127,9 +127,14 @@ async def brute_force_search(
|
|||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
except Exception as error:
|
||||
logging.error(
|
||||
"Error during brute force search for user: %s, query: %s. Error: %s", user.id, query, e
|
||||
"Error during brute force search for user: %s, query: %s. Error: %s",
|
||||
user.id,
|
||||
query,
|
||||
error,
|
||||
)
|
||||
send_telemetry("cognee.brute_force_triplet_search EXECUTION FAILED", user.id)
|
||||
raise RuntimeError("An error occurred during brute force search") from e
|
||||
send_telemetry(
|
||||
"cognee.brute_force_triplet_search EXECUTION FAILED", user.id, {"error": str(error)}
|
||||
)
|
||||
raise RuntimeError("An error occurred during brute force search") from error
|
||||
|
|
|
|||
|
|
@ -101,7 +101,6 @@ async def code_description_to_code_part(
|
|||
"text",
|
||||
"file_path",
|
||||
"source_code",
|
||||
"pydantic_type",
|
||||
],
|
||||
edge_properties_to_project=["relationship_name"],
|
||||
)
|
||||
|
|
@ -117,13 +116,13 @@ async def code_description_to_code_part(
|
|||
continue
|
||||
|
||||
for code_file in node_to_search_from.get_skeleton_neighbours():
|
||||
if code_file.get_attribute("pydantic_type") == "SourceCodeChunk":
|
||||
if code_file.get_attribute("type") == "SourceCodeChunk":
|
||||
for code_file_edge in code_file.get_skeleton_edges():
|
||||
if code_file_edge.get_attribute("relationship_name") == "code_chunk_of":
|
||||
code_pieces_to_return.add(code_file_edge.get_destination_node())
|
||||
elif code_file.get_attribute("pydantic_type") == "CodePart":
|
||||
elif code_file.get_attribute("type") == "CodePart":
|
||||
code_pieces_to_return.add(code_file)
|
||||
elif code_file.get_attribute("pydantic_type") == "CodeFile":
|
||||
elif code_file.get_attribute("type") == "CodeFile":
|
||||
for code_file_edge in code_file.get_skeleton_edges():
|
||||
if code_file_edge.get_attribute("relationship_name") == "contains":
|
||||
code_pieces_to_return.add(code_file_edge.get_destination_node())
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ def get_own_properties(data_point: DataPoint):
|
|||
|
||||
for field_name, field_value in data_point:
|
||||
if (
|
||||
field_name == "_metadata"
|
||||
field_name == "metadata"
|
||||
or isinstance(field_value, dict)
|
||||
or isinstance(field_value, DataPoint)
|
||||
or (isinstance(field_value, list) and isinstance(field_value[0], DataPoint))
|
||||
|
|
|
|||
|
|
@ -1,9 +1,15 @@
|
|||
import logging
|
||||
import networkx as nx
|
||||
import json
|
||||
import os
|
||||
|
||||
from cognee.infrastructure.files.storage import LocalStorage
|
||||
|
||||
async def cognee_network_visualization(graph_data):
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def cognee_network_visualization(graph_data, destination_file_path: str = None):
|
||||
nodes_data, edges_data = graph_data
|
||||
|
||||
G = nx.DiGraph()
|
||||
|
|
@ -19,7 +25,7 @@ async def cognee_network_visualization(graph_data):
|
|||
for node_id, node_info in nodes_data:
|
||||
node_info = node_info.copy()
|
||||
node_info["id"] = str(node_id)
|
||||
node_info["color"] = color_map.get(node_info.get("pydantic_type", "default"), "#D3D3D3")
|
||||
node_info["color"] = color_map.get(node_info.get("type", "default"), "#D3D3D3")
|
||||
node_info["name"] = node_info.get("name", str(node_id))
|
||||
|
||||
try:
|
||||
|
|
@ -178,12 +184,15 @@ async def cognee_network_visualization(graph_data):
|
|||
html_content = html_template.replace("{nodes}", json.dumps(nodes_list))
|
||||
html_content = html_content.replace("{links}", json.dumps(links_list))
|
||||
|
||||
home_dir = os.path.expanduser("~")
|
||||
output_file = os.path.join(home_dir, "graph_visualization.html")
|
||||
if not destination_file_path:
|
||||
home_dir = os.path.expanduser("~")
|
||||
destination_file_path = os.path.join(home_dir, "graph_visualization.html")
|
||||
|
||||
with open(output_file, "w") as f:
|
||||
LocalStorage.ensure_directory_exists(os.path.dirname(destination_file_path))
|
||||
|
||||
with open(destination_file_path, "w") as f:
|
||||
f.write(html_content)
|
||||
|
||||
print(f"Graph visualization saved as {output_file}")
|
||||
logger.info(f"Graph visualization saved as {destination_file_path}")
|
||||
|
||||
return html_content
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ class Repository(DataPoint):
|
|||
__tablename__ = "Repository"
|
||||
path: str
|
||||
pydantic_type: str = "Repository"
|
||||
_metadata: dict = {"index_fields": [], "type": "Repository"}
|
||||
metadata: dict = {"index_fields": []}
|
||||
|
||||
|
||||
class CodeFile(DataPoint):
|
||||
|
|
@ -18,7 +18,7 @@ class CodeFile(DataPoint):
|
|||
depends_on: Optional[List["CodeFile"]] = None
|
||||
depends_directly_on: Optional[List["CodeFile"]] = None
|
||||
contains: Optional[List["CodePart"]] = None
|
||||
_metadata: dict = {"index_fields": [], "type": "CodeFile"}
|
||||
metadata: dict = {"index_fields": []}
|
||||
|
||||
|
||||
class CodePart(DataPoint):
|
||||
|
|
@ -27,7 +27,7 @@ class CodePart(DataPoint):
|
|||
# part_of: Optional[CodeFile] = None
|
||||
pydantic_type: str = "CodePart"
|
||||
source_code: Optional[str] = None
|
||||
_metadata: dict = {"index_fields": [], "type": "CodePart"}
|
||||
metadata: dict = {"index_fields": []}
|
||||
|
||||
|
||||
class SourceCodeChunk(DataPoint):
|
||||
|
|
@ -37,7 +37,7 @@ class SourceCodeChunk(DataPoint):
|
|||
pydantic_type: str = "SourceCodeChunk"
|
||||
previous_chunk: Optional["SourceCodeChunk"] = None
|
||||
|
||||
_metadata: dict = {"index_fields": ["source_code"], "type": "SourceCodeChunk"}
|
||||
metadata: dict = {"index_fields": ["source_code"]}
|
||||
|
||||
|
||||
CodeFile.model_rebuild()
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ class Variable(DataPoint):
|
|||
default_value: Optional[str] = None
|
||||
data_type: str
|
||||
|
||||
_metadata = {"index_fields": ["name"], "type": "Variable"}
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
|
||||
class Operator(DataPoint):
|
||||
|
|
@ -19,7 +19,7 @@ class Operator(DataPoint):
|
|||
name: str
|
||||
description: str
|
||||
return_type: str
|
||||
_metadata = {"index_fields": ["name"], "type": "Operator"}
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
|
||||
class Class(DataPoint):
|
||||
|
|
@ -30,7 +30,7 @@ class Class(DataPoint):
|
|||
extended_from_class: Optional["Class"] = None
|
||||
has_methods: List["Function"]
|
||||
|
||||
_metadata = {"index_fields": ["name"], "type": "Class"}
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
|
||||
class ClassInstance(DataPoint):
|
||||
|
|
@ -41,7 +41,7 @@ class ClassInstance(DataPoint):
|
|||
instantiated_by: Union["Function"]
|
||||
instantiation_arguments: List[Variable]
|
||||
|
||||
_metadata = {"index_fields": ["name"], "type": "ClassInstance"}
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
|
||||
class Function(DataPoint):
|
||||
|
|
@ -52,7 +52,7 @@ class Function(DataPoint):
|
|||
return_type: str
|
||||
is_static: Optional[bool] = False
|
||||
|
||||
_metadata = {"index_fields": ["name"], "type": "Function"}
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
|
||||
class FunctionCall(DataPoint):
|
||||
|
|
@ -60,7 +60,7 @@ class FunctionCall(DataPoint):
|
|||
called_by: Union[Function, Literal["main"]]
|
||||
function_called: Function
|
||||
function_arguments: List[Any]
|
||||
_metadata = {"index_fields": [], "type": "FunctionCall"}
|
||||
metadata: dict = {"index_fields": []}
|
||||
|
||||
|
||||
class Expression(DataPoint):
|
||||
|
|
@ -69,7 +69,7 @@ class Expression(DataPoint):
|
|||
description: str
|
||||
expression: str
|
||||
members: List[Union[Variable, Function, Operator, "Expression"]]
|
||||
_metadata = {"index_fields": ["name"], "type": "Expression"}
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
|
||||
class SourceCodeGraph(DataPoint):
|
||||
|
|
@ -88,7 +88,7 @@ class SourceCodeGraph(DataPoint):
|
|||
Expression,
|
||||
]
|
||||
]
|
||||
_metadata = {"index_fields": ["name"], "type": "SourceCodeGraph"}
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
|
||||
Class.model_rebuild()
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ import matplotlib.pyplot as plt
|
|||
|
||||
import logging
|
||||
import sys
|
||||
import json
|
||||
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
|
|
@ -164,6 +163,7 @@ def prepare_nodes(graph, include_size=False):
|
|||
continue
|
||||
|
||||
node_data = {
|
||||
**node_info,
|
||||
"id": str(node),
|
||||
"name": node_info["name"] if "name" in node_info else str(node),
|
||||
}
|
||||
|
|
@ -183,7 +183,7 @@ def prepare_nodes(graph, include_size=False):
|
|||
|
||||
|
||||
async def render_graph(
|
||||
graph, include_nodes=False, include_color=False, include_size=False, include_labels=False
|
||||
graph=None, include_nodes=True, include_color=False, include_size=False, include_labels=True
|
||||
):
|
||||
await register_graphistry()
|
||||
|
||||
|
|
|
|||
|
|
@ -15,11 +15,11 @@ async def query_chunks(query: str) -> list[dict]:
|
|||
Notes:
|
||||
- The function uses the `search` method of the vector engine to find matches.
|
||||
- Limits the results to the top 5 matching chunks to balance performance and relevance.
|
||||
- Ensure that the vector database is properly initialized and contains the "document_chunk_text" collection.
|
||||
- Ensure that the vector database is properly initialized and contains the "DocumentChunk_text" collection.
|
||||
"""
|
||||
vector_engine = get_vector_engine()
|
||||
|
||||
found_chunks = await vector_engine.search("document_chunk_text", query, limit=5)
|
||||
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=5)
|
||||
|
||||
chunks = [result.payload for result in found_chunks]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.engine import ExtendableDataPoint
|
||||
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
|
||||
from cognee.tasks.completion.exceptions import NoRelevantDataFound
|
||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
|
||||
|
|
@ -34,9 +35,20 @@ async def graph_query_completion(query: str) -> list:
|
|||
- The `brute_force_triplet_search` is used to retrieve relevant graph data.
|
||||
- Prompts are dynamically rendered and provided to the LLM for contextual understanding.
|
||||
- Ensure that the LLM client and graph database are properly configured and accessible.
|
||||
|
||||
"""
|
||||
found_triplets = await brute_force_triplet_search(query, top_k=5)
|
||||
|
||||
subclasses = get_all_subclasses(ExtendableDataPoint)
|
||||
|
||||
vector_index_collections = []
|
||||
|
||||
for subclass in subclasses:
|
||||
index_fields = subclass.model_fields["metadata"].default.get("index_fields", [])
|
||||
for field_name in index_fields:
|
||||
vector_index_collections.append(f"{subclass.__name__}_{field_name}")
|
||||
|
||||
found_triplets = await brute_force_triplet_search(
|
||||
query, top_k=5, collections=vector_index_collections or None
|
||||
)
|
||||
|
||||
if len(found_triplets) == 0:
|
||||
raise NoRelevantDataFound
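The retrieval side now derives vector collection names from each `DataPoint` subclass and its `metadata["index_fields"]` instead of hard-coded snake_case names, which is why the search targets elsewhere in this PR changed to identifiers like `Entity_name` and `DocumentChunk_text`. A minimal sketch of the derivation; the two classes are illustrative:

```python
from cognee.low_level import DataPoint


class Entity(DataPoint):
    name: str
    metadata: dict = {"index_fields": ["name"]}


class TextSummary(DataPoint):
    text: str
    metadata: dict = {"index_fields": ["text"]}


collections = [
    f"{subclass.__name__}_{field_name}"
    for subclass in (Entity, TextSummary)
    for field_name in subclass.model_fields["metadata"].default.get("index_fields", [])
]
print(collections)  # ['Entity_name', 'TextSummary_text']
```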
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ async def query_completion(query: str) -> list:
|
|||
"""
|
||||
vector_engine = get_vector_engine()
|
||||
|
||||
found_chunks = await vector_engine.search("document_chunk_text", query, limit=1)
|
||||
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=1)
|
||||
|
||||
if len(found_chunks) == 0:
|
||||
raise NoRelevantDataFound
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from cognee.modules.graph.utils import (
|
|||
expand_with_nodes_and_edges,
|
||||
retrieve_existing_edges,
|
||||
)
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
from cognee.tasks.storage import add_data_points
|
||||
|
||||
|
||||
|
|
@ -18,7 +19,6 @@ async def extract_graph_from_data(
|
|||
) -> List[DocumentChunk]:
|
||||
"""
|
||||
Extracts and integrates a knowledge graph from the text content of document chunks using a specified graph model.
|
||||
|
||||
"""
|
||||
|
||||
chunk_graphs = await asyncio.gather(
|
||||
|
|
@ -26,6 +26,13 @@ async def extract_graph_from_data(
|
|||
)
|
||||
graph_engine = await get_graph_engine()
|
||||
|
||||
if graph_model is not KnowledgeGraph:
|
||||
for chunk_index, chunk_graph in enumerate(chunk_graphs):
|
||||
data_chunks[chunk_index].contains = chunk_graph
|
||||
|
||||
await add_data_points(chunk_graphs)
|
||||
return data_chunks
|
||||
|
||||
existing_edges_map = await retrieve_existing_edges(
|
||||
data_chunks,
|
||||
chunk_graphs,
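This branch is the heart of the custom model fix: when `cognify` is called with a graph model other than the default `KnowledgeGraph`, the per-chunk graphs are attached to their chunks and stored directly via `add_data_points`, skipping the generic node/edge expansion. A trimmed sketch of the caller side, adapted from the new `test_custom_model.py` (fields abbreviated):

```python
import cognee
from cognee.low_level import DataPoint


class ProgrammingLanguageType(DataPoint):
    name: str = "Programming Language"
    metadata: dict = {"index_fields": ["name"]}


class ProgrammingLanguage(DataPoint):
    name: str
    is_type: ProgrammingLanguageType
    metadata: dict = {"index_fields": ["name"]}


async def run():
    await cognee.add("Python is an interpreted, high-level, general-purpose programming language.")
    # The custom model takes the early-return path added in this hunk.
    await cognee.cognify(graph_model=ProgrammingLanguage)
```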
|
||||
|
|
|
|||
|
|
@ -28,8 +28,8 @@ async def query_graph_connections(query: str, exploration_levels=1) -> list[(str
|
|||
else:
|
||||
vector_engine = get_vector_engine()
|
||||
results = await asyncio.gather(
|
||||
vector_engine.search("entity_name", query_text=query, limit=5),
|
||||
vector_engine.search("entity_type_name", query_text=query, limit=5),
|
||||
vector_engine.search("Entity_name", query_text=query, limit=5),
|
||||
vector_engine.search("EntityType_name", query_text=query, limit=5),
|
||||
)
|
||||
results = [*results[0], *results[1]]
|
||||
relevant_results = [result for result in results if result.score < 0.5][:5]
|
||||
|
|
|
|||
|
|
@ -16,25 +16,25 @@ async def index_data_points(data_points: list[DataPoint]):
|
|||
for data_point in data_points:
|
||||
data_point_type = type(data_point)
|
||||
|
||||
for field_name in data_point._metadata["index_fields"]:
|
||||
for field_name in data_point.metadata["index_fields"]:
|
||||
if getattr(data_point, field_name, None) is None:
|
||||
continue
|
||||
|
||||
index_name = f"{data_point_type.__tablename__}.{field_name}"
|
||||
index_name = f"{data_point_type.__name__}_{field_name}"
|
||||
|
||||
if index_name not in created_indexes:
|
||||
await vector_engine.create_vector_index(data_point_type.__tablename__, field_name)
|
||||
await vector_engine.create_vector_index(data_point_type.__name__, field_name)
|
||||
created_indexes[index_name] = True
|
||||
|
||||
if index_name not in index_points:
|
||||
index_points[index_name] = []
|
||||
|
||||
indexed_data_point = data_point.model_copy()
|
||||
indexed_data_point._metadata["index_fields"] = [field_name]
|
||||
indexed_data_point.metadata["index_fields"] = [field_name]
|
||||
index_points[index_name].append(indexed_data_point)
|
||||
|
||||
for index_name, indexable_points in index_points.items():
|
||||
index_name, field_name = index_name.split(".")
|
||||
index_name, field_name = index_name.split("_")
|
||||
try:
|
||||
await vector_engine.index_data_points(index_name, field_name, indexable_points)
|
||||
except EmbeddingException as e:
|
||||
|
|
@ -101,13 +101,13 @@ if __name__ == "__main__":
|
|||
class Car(DataPoint):
|
||||
model: str
|
||||
color: str
|
||||
_metadata = {"index_fields": ["name"], "type": "Car"}
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
class Person(DataPoint):
|
||||
name: str
|
||||
age: int
|
||||
owns_car: list[Car]
|
||||
_metadata = {"index_fields": ["name"], "type": "Person"}
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
car1 = Car(model="Tesla Model S", color="Blue")
|
||||
car2 = Car(model="Toyota Camry", color="Red")
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ async def index_graph_edges():
|
|||
edge = EdgeType(relationship_name=text, number_of_edges=count)
|
||||
data_point_type = type(edge)
|
||||
|
||||
for field_name in edge._metadata["index_fields"]:
|
||||
for field_name in edge.metadata["index_fields"]:
|
||||
index_name = f"{data_point_type.__tablename__}.{field_name}"
|
||||
|
||||
if index_name not in created_indexes:
|
||||
|
|
@ -61,7 +61,7 @@ async def index_graph_edges():
|
|||
index_points[index_name] = []
|
||||
|
||||
indexed_data_point = edge.model_copy()
|
||||
indexed_data_point._metadata["index_fields"] = [field_name]
|
||||
indexed_data_point.metadata["index_fields"] = [field_name]
|
||||
index_points[index_name].append(indexed_data_point)
|
||||
|
||||
for index_name, indexable_points in index_points.items():
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ class TextSummary(DataPoint):
|
|||
text: str
|
||||
made_from: DocumentChunk
|
||||
|
||||
_metadata: dict = {"index_fields": ["text"], "type": "TextSummary"}
|
||||
metadata: dict = {"index_fields": ["text"]}
|
||||
|
||||
|
||||
class CodeSummary(DataPoint):
|
||||
|
|
@ -19,4 +19,4 @@ class CodeSummary(DataPoint):
|
|||
summarizes: Union[CodeFile, CodePart, SourceCodeChunk]
|
||||
pydantic_type: str = "CodeSummary"
|
||||
|
||||
_metadata: dict = {"index_fields": ["text"], "type": "CodeSummary"}
|
||||
metadata: dict = {"index_fields": ["text"]}
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ async def query_summaries(query: str) -> list:
|
|||
"""
|
||||
vector_engine = get_vector_engine()
|
||||
|
||||
summaries_results = await vector_engine.search("text_summary_text", query, limit=5)
|
||||
summaries_results = await vector_engine.search("TextSummary_text", query, limit=5)
|
||||
|
||||
summaries = [summary.payload for summary in summaries_results]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from cognee.infrastructure.engine import DataPoint
|
||||
from typing import ClassVar, Optional
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class GraphitiNode(DataPoint):
|
||||
|
|
@ -9,4 +9,4 @@ class GraphitiNode(DataPoint):
|
|||
summary: Optional[str] = None
|
||||
pydantic_type: str = "GraphitiNode"
|
||||
|
||||
_metadata: dict = {"index_fields": ["name", "summary", "content"], "type": "GraphitiNode"}
|
||||
metadata: dict = {"index_fields": ["name", "summary", "content"]}
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ async def index_and_transform_graphiti_nodes_and_edges():
|
|||
|
||||
data_point_type = type(graphiti_node)
|
||||
|
||||
for field_name in graphiti_node._metadata["index_fields"]:
|
||||
for field_name in graphiti_node.metadata["index_fields"]:
|
||||
index_name = f"{data_point_type.__tablename__}.{field_name}"
|
||||
|
||||
if index_name not in created_indexes:
|
||||
|
|
@ -48,7 +48,7 @@ async def index_and_transform_graphiti_nodes_and_edges():
|
|||
|
||||
if getattr(graphiti_node, field_name, None) is not None:
|
||||
indexed_data_point = graphiti_node.model_copy()
|
||||
indexed_data_point._metadata["index_fields"] = [field_name]
|
||||
indexed_data_point.metadata["index_fields"] = [field_name]
|
||||
index_points[index_name].append(indexed_data_point)
|
||||
|
||||
for index_name, indexable_points in index_points.items():
|
||||
|
|
@ -65,7 +65,7 @@ async def index_and_transform_graphiti_nodes_and_edges():
|
|||
edge = EdgeType(relationship_name=text, number_of_edges=count)
|
||||
data_point_type = type(edge)
|
||||
|
||||
for field_name in edge._metadata["index_fields"]:
|
||||
for field_name in edge.metadata["index_fields"]:
|
||||
index_name = f"{data_point_type.__tablename__}.{field_name}"
|
||||
|
||||
if index_name not in created_indexes:
|
||||
|
|
@ -76,7 +76,7 @@ async def index_and_transform_graphiti_nodes_and_edges():
|
|||
index_points[index_name] = []
|
||||
|
||||
indexed_data_point = edge.model_copy()
|
||||
indexed_data_point._metadata["index_fields"] = [field_name]
|
||||
indexed_data_point.metadata["index_fields"] = [field_name]
|
||||
index_points[index_name].append(indexed_data_point)
|
||||
|
||||
for index_name, indexable_points in index_points.items():
|
||||
|
|
|
|||
38
cognee/tests/low_level/companies.json
Normal file
38
cognee/tests/low_level/companies.json
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
[
|
||||
{
|
||||
"name": "TechNova Inc.",
|
||||
"departments": [
|
||||
"Engineering",
|
||||
"Marketing"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "GreenFuture Solutions",
|
||||
"departments": [
|
||||
"Research & Development",
|
||||
"Sales",
|
||||
"Customer Support"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Skyline Financials",
|
||||
"departments": [
|
||||
"Accounting"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "MediCare Plus",
|
||||
"departments": [
|
||||
"Healthcare",
|
||||
"Administration"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "NextGen Robotics",
|
||||
"departments": [
|
||||
"AI Development",
|
||||
"Manufacturing",
|
||||
"HR"
|
||||
]
|
||||
}
|
||||
]
|
||||
52
cognee/tests/low_level/people.json
Normal file
52
cognee/tests/low_level/people.json
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
[
|
||||
{
|
||||
"name": "John Doe",
|
||||
"company": "TechNova Inc.",
|
||||
"department": "Engineering"
|
||||
},
|
||||
{
|
||||
"name": "Jane Smith",
|
||||
"company": "TechNova Inc.",
|
||||
"department": "Marketing"
|
||||
},
|
||||
{
|
||||
"name": "Alice Johnson",
|
||||
"company": "GreenFuture Solutions",
|
||||
"department": "Sales"
|
||||
},
|
||||
{
|
||||
"name": "Bob Williams",
|
||||
"company": "GreenFuture Solutions",
|
||||
"department": "Customer Support"
|
||||
},
|
||||
{
|
||||
"name": "Michael Brown",
|
||||
"company": "Skyline Financials",
|
||||
"department": "Accounting"
|
||||
},
|
||||
{
|
||||
"name": "Emily Davis",
|
||||
"company": "MediCare Plus",
|
||||
"department": "Healthcare"
|
||||
},
|
||||
{
|
||||
"name": "David Wilson",
|
||||
"company": "MediCare Plus",
|
||||
"department": "Administration"
|
||||
},
|
||||
{
|
||||
"name": "Emma Thompson",
|
||||
"company": "NextGen Robotics",
|
||||
"department": "AI Development"
|
||||
},
|
||||
{
|
||||
"name": "Chris Martin",
|
||||
"company": "NextGen Robotics",
|
||||
"department": "Manufacturing"
|
||||
},
|
||||
{
|
||||
"name": "Sophia White",
|
||||
"company": "NextGen Robotics",
|
||||
"department": "HR"
|
||||
}
|
||||
]
|
||||
94
cognee/tests/low_level/pipeline.py
Normal file
94
cognee/tests/low_level/pipeline.py
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
import os
|
||||
import json
|
||||
import asyncio
|
||||
from cognee import prune
|
||||
from cognee import visualize_graph
|
||||
from cognee.low_level import setup, DataPoint
|
||||
from cognee.pipelines import run_tasks, Task
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.shared.utils import render_graph
|
||||
|
||||
|
||||
class Person(DataPoint):
|
||||
name: str
|
||||
|
||||
|
||||
class Department(DataPoint):
|
||||
name: str
|
||||
employees: list[Person]
|
||||
|
||||
|
||||
class CompanyType(DataPoint):
|
||||
name: str = "Company"
|
||||
|
||||
|
||||
class Company(DataPoint):
|
||||
name: str
|
||||
departments: list[Department]
|
||||
is_type: CompanyType
|
||||
|
||||
|
||||
def ingest_files():
|
||||
companies_file_path = os.path.join(os.path.dirname(__file__), "companies.json")
|
||||
companies = json.loads(open(companies_file_path, "r").read())
|
||||
|
||||
people_file_path = os.path.join(os.path.dirname(__file__), "people.json")
|
||||
people = json.loads(open(people_file_path, "r").read())
|
||||
|
||||
people_data_points = {}
|
||||
departments_data_points = {}
|
||||
|
||||
for person in people:
|
||||
new_person = Person(name=person["name"])
|
||||
people_data_points[person["name"]] = new_person
|
||||
|
||||
if person["department"] not in departments_data_points:
|
||||
departments_data_points[person["department"]] = Department(
|
||||
name=person["department"], employees=[new_person]
|
||||
)
|
||||
else:
|
||||
departments_data_points[person["department"]].employees.append(new_person)
|
||||
|
||||
companies_data_points = {}
|
||||
|
||||
# Create a single CompanyType node, so we connect all companies to it.
|
||||
companyType = CompanyType()
|
||||
|
||||
for company in companies:
|
||||
new_company = Company(name=company["name"], departments=[], is_type=companyType)
|
||||
companies_data_points[company["name"]] = new_company
|
||||
|
||||
for department_name in company["departments"]:
|
||||
if department_name not in departments_data_points:
|
||||
departments_data_points[department_name] = Department(
|
||||
name=department_name, employees=[]
|
||||
)
|
||||
|
||||
new_company.departments.append(departments_data_points[department_name])
|
||||
|
||||
return companies_data_points.values()
|
||||
|
||||
|
||||
async def main():
|
||||
await prune.prune_data()
|
||||
await prune.prune_system(metadata=True)
|
||||
|
||||
await setup()
|
||||
|
||||
pipeline = run_tasks([Task(ingest_files), Task(add_data_points)])
|
||||
|
||||
async for status in pipeline:
|
||||
print(status)
|
||||
|
||||
# Get a graphistry url (Register for a free account at https://www.graphistry.com)
|
||||
await render_graph()
|
||||
|
||||
# Or use our simple graph preview
|
||||
graph_file_path = str(
|
||||
os.path.join(os.path.dirname(__file__), ".artifacts/graph_visualization.html")
|
||||
)
|
||||
await visualize_graph(graph_file_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
104
cognee/tests/test_custom_model.py
Executable file
104
cognee/tests/test_custom_model.py
Executable file
|
|
@ -0,0 +1,104 @@
|
|||
import os
|
||||
import logging
|
||||
import pathlib
|
||||
import cognee
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.shared.utils import render_graph
|
||||
from cognee.low_level import DataPoint
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
|
||||
async def main():
|
||||
data_directory_path = str(
|
||||
pathlib.Path(
|
||||
os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_custom_model")
|
||||
).resolve()
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
cognee_directory_path = str(
|
||||
pathlib.Path(
|
||||
os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_custom_model")
|
||||
).resolve()
|
||||
)
|
||||
cognee.config.system_root_directory(cognee_directory_path)
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
# Define a custom graph model for programming languages.
|
||||
class FieldType(DataPoint):
|
||||
name: str = "Field"
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
class Field(DataPoint):
|
||||
name: str
|
||||
is_type: FieldType
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
class ProgrammingLanguageType(DataPoint):
|
||||
name: str = "Programming Language"
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
class ProgrammingLanguage(DataPoint):
|
||||
name: str
|
||||
used_in: list[Field] = []
|
||||
is_type: ProgrammingLanguageType
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
text = (
|
||||
"Python is an interpreted, high-level, general-purpose programming language. It was created by Guido van Rossum and first released in 1991. "
|
||||
+ "Python is widely used in data analysis, web development, and machine learning."
|
||||
)
|
||||
|
||||
await cognee.add(text)
|
||||
|
||||
await cognee.cognify(graph_model=ProgrammingLanguage)
|
||||
|
||||
url = await render_graph()
|
||||
print(f"Graphistry URL: {url}")
|
||||
|
||||
graph_file_path = str(
|
||||
pathlib.Path(
|
||||
os.path.join(
|
||||
pathlib.Path(__file__).parent,
|
||||
".artifacts/test_custom_model/graph_visualization.html",
|
||||
)
|
||||
).resolve()
|
||||
)
|
||||
await cognee.visualize_graph(graph_file_path)
|
||||
|
||||
# Completion query that uses graph data to form context.
|
||||
completion = await cognee.search(SearchType.GRAPH_COMPLETION, "What is python?")
|
||||
assert len(completion) != 0, "Graph completion search didn't return any result."
|
||||
print("Graph completion result is:")
|
||||
print(completion)
|
||||
|
||||
# Completion query that uses document chunks to form context.
|
||||
completion = await cognee.search(SearchType.COMPLETION, "What is Python?")
|
||||
assert len(completion) != 0, "Completion search didn't return any result."
|
||||
print("Completion result is:")
|
||||
print(completion)
|
||||
|
||||
# Query all summaries related to query.
|
||||
summaries = await cognee.search(SearchType.SUMMARIES, "Python")
|
||||
assert len(summaries) != 0, "Summaries search didn't return any results."
|
||||
print("Summary results are:")
|
||||
for summary in summaries:
|
||||
print(summary)
|
||||
|
||||
chunks = await cognee.search(SearchType.CHUNKS, query_text="Python")
|
||||
assert len(chunks) != 0, "Chunks search didn't return any results."
|
||||
print("Chunk results are:")
|
||||
for chunk in chunks:
|
||||
print(chunk)
|
||||
|
||||
history = await cognee.get_search_history()
|
||||
|
||||
assert len(history) == 8, "Search history is not correct."
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main(), debug=True)
|
||||
|
|
@ -44,7 +44,7 @@ async def main():
|
|||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
random_node = (await vector_engine.search("entity_name", "AI"))[0]
|
||||
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
|
||||
random_node_name = random_node.payload["text"]
|
||||
|
||||
search_results = await cognee.search(
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ async def main():
|
|||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
random_node = (await vector_engine.search("entity_name", "Quantum computer"))[0]
|
||||
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
|
||||
random_node_name = random_node.payload["text"]
|
||||
|
||||
search_results = await cognee.search(
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ async def main():
|
|||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
random_node = (await vector_engine.search("entity_name", "Quantum computer"))[0]
|
||||
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
|
||||
random_node_name = random_node.payload["text"]
|
||||
|
||||
search_results = await cognee.search(
|
||||
|
|
|
|||
|
|
@ -123,7 +123,7 @@ async def main():
|
|||
await test_getting_of_documents(dataset_name_1)
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
random_node = (await vector_engine.search("entity_name", "Quantum computer"))[0]
|
||||
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
|
||||
random_node_name = random_node.payload["text"]
|
||||
|
||||
search_results = await cognee.search(
|
||||
|
|
@ -142,6 +142,15 @@ async def main():
|
|||
for result in search_results:
|
||||
print(f"{result}\n")
|
||||
|
||||
graph_completion = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION,
|
||||
query_text=random_node_name,
|
||||
datasets=[dataset_name_2],
|
||||
)
|
||||
assert len(graph_completion) != 0, "Completion result is empty."
|
||||
print("Completion result is:")
|
||||
print(graph_completion)
|
||||
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.SUMMARIES, query_text=random_node_name
|
||||
)
|
||||
|
|
@ -151,7 +160,7 @@ async def main():
|
|||
print(f"{result}\n")
|
||||
|
||||
history = await cognee.get_search_history()
|
||||
assert len(history) == 6, "Search history is not correct."
|
||||
assert len(history) == 8, "Search history is not correct."
|
||||
|
||||
results = await brute_force_triplet_search("What is a quantum computer?")
|
||||
assert len(results) > 0
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ async def main():
|
|||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
random_node = (await vector_engine.search("entity_name", "Quantum computer"))[0]
|
||||
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
|
||||
random_node_name = random_node.payload["text"]
|
||||
|
||||
search_results = await cognee.search(
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ async def main():
|
|||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
random_node = (await vector_engine.search("entity_name", "Quantum computer"))[0]
|
||||
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
|
||||
random_node_name = random_node.payload["text"]
|
||||
|
||||
search_results = await cognee.search(
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ random.seed(1500)
|
|||
|
||||
class Repository(DataPoint):
|
||||
path: str
|
||||
_metadata = {"index_fields": [], "type": "Repository"}
|
||||
metadata: dict = {"index_fields": []}
|
||||
|
||||
|
||||
class CodeFile(DataPoint):
|
||||
|
|
@ -20,13 +20,13 @@ class CodeFile(DataPoint):
|
|||
contains: List["CodePart"] = []
|
||||
depends_on: List["CodeFile"] = []
|
||||
source_code: str
|
||||
_metadata = {"index_fields": [], "type": "CodeFile"}
|
||||
metadata: dict = {"index_fields": []}
|
||||
|
||||
|
||||
class CodePart(DataPoint):
|
||||
part_of: CodeFile
|
||||
source_code: str
|
||||
_metadata = {"index_fields": [], "type": "CodePart"}
|
||||
metadata: dict = {"index_fields": []}
|
||||
|
||||
|
||||
CodeFile.model_rebuild()
|
||||
|
|
|
|||
|
|
@ -9,25 +9,25 @@ from cognee.modules.graph.utils import get_graph_from_model
|
|||
|
||||
class Document(DataPoint):
|
||||
path: str
|
||||
_metadata = {"index_fields": [], "type": "Document"}
|
||||
metadata: dict = {"index_fields": []}
|
||||
|
||||
|
||||
class DocumentChunk(DataPoint):
|
||||
part_of: Document
|
||||
text: str
|
||||
contains: List["Entity"] = None
|
||||
_metadata = {"index_fields": ["text"], "type": "DocumentChunk"}
|
||||
metadata: dict = {"index_fields": ["text"]}
|
||||
|
||||
|
||||
class EntityType(DataPoint):
|
||||
name: str
|
||||
_metadata = {"index_fields": ["name"], "type": "EntityType"}
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
|
||||
class Entity(DataPoint):
|
||||
name: str
|
||||
is_type: EntityType
|
||||
_metadata = {"index_fields": ["name"], "type": "Entity"}
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
|
||||
DocumentChunk.model_rebuild()
|
||||
|
|
|
|||
|
|
@ -7,25 +7,25 @@ from cognee.modules.graph.utils import get_graph_from_model
|
|||
|
||||
class Document(DataPoint):
|
||||
path: str
|
||||
_metadata = {"index_fields": [], "type": "Document"}
|
||||
metadata: dict = {"index_fields": []}
|
||||
|
||||
|
||||
class DocumentChunk(DataPoint):
|
||||
part_of: Document
|
||||
text: str
|
||||
contains: List["Entity"] = None
|
||||
_metadata = {"index_fields": ["text"], "type": "DocumentChunk"}
|
||||
metadata: dict = {"index_fields": ["text"]}
|
||||
|
||||
|
||||
class EntityType(DataPoint):
|
||||
name: str
|
||||
_metadata = {"index_fields": ["name"], "type": "EntityType"}
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
|
||||
class Entity(DataPoint):
|
||||
name: str
|
||||
is_type: EntityType
|
||||
_metadata = {"index_fields": ["name"], "type": "Entity"}
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
|
||||
DocumentChunk.model_rebuild()
|
||||
|
|
|
|||
|
|
@ -7,11 +7,11 @@ from cognee.modules.visualization.cognee_network_visualization import (
|
|||
@pytest.mark.asyncio
|
||||
async def test_create_cognee_style_network_with_logo():
|
||||
nodes_data = [
|
||||
(1, {"pydantic_type": "Entity", "name": "Node1", "updated_at": 123, "created_at": 123}),
|
||||
(1, {"type": "Entity", "name": "Node1", "updated_at": 123, "created_at": 123}),
|
||||
(
|
||||
2,
|
||||
{
|
||||
"pydantic_type": "DocumentChunk",
|
||||
"type": "DocumentChunk",
|
||||
"name": "Node2",
|
||||
"updated_at": 123,
|
||||
"created_at": 123,
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ async def get_context_with_simple_rag(instance: dict) -> str:
|
|||
await cognify_instance(instance)
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
found_chunks = await vector_engine.search("document_chunk_text", instance["question"], limit=5)
|
||||
found_chunks = await vector_engine.search("DocumentChunk_text", instance["question"], limit=5)
|
||||
|
||||
search_results_str = "\n".join([context_item.payload["text"] for context_item in found_chunks])
|
||||
|
||||
|
|
|
|||
|
|
@ -797,7 +797,7 @@
|
|||
"from cognee.infrastructure.databases.vector import get_vector_engine\n",
|
||||
"\n",
|
||||
"vector_engine = get_vector_engine()\n",
|
||||
"results = await search(vector_engine, \"entity_name\", \"sarah.nguyen@example.com\")\n",
|
||||
"results = await search(vector_engine, \"Entity_name\", \"sarah.nguyen@example.com\")\n",
|
||||
"for result in results:\n",
|
||||
" print(result)"
|
||||
]
|
||||
|
|
@ -827,7 +827,7 @@
|
|||
"source": [
|
||||
"from cognee.api.v1.search import SearchType\n",
|
||||
"\n",
|
||||
"node = (await vector_engine.search(\"entity_name\", \"sarah.nguyen@example.com\"))[0]\n",
|
||||
"node = (await vector_engine.search(\"Entity_name\", \"sarah.nguyen@example.com\"))[0]\n",
|
||||
"node_name = node.payload[\"text\"]\n",
|
||||
"\n",
|
||||
"search_results = await cognee.search(query_type=SearchType.SUMMARIES, query_text = node_name)\n",
|
||||
|
|
@ -1237,7 +1237,7 @@
|
|||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "cognee-c83GrcRT-py3.11",
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
|
@ -1251,7 +1251,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.10"
|
||||
"version": "3.11.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
|
|||
|
|
@ -533,7 +533,7 @@
|
|||
"from cognee.infrastructure.databases.vector import get_vector_engine\n",
|
||||
"\n",
|
||||
"vector_engine = get_vector_engine()\n",
|
||||
"results = await search(vector_engine, \"entity_name\", \"sarah.nguyen@example.com\")\n",
|
||||
"results = await search(vector_engine, \"Entity_name\", \"sarah.nguyen@example.com\")\n",
|
||||
"for result in results:\n",
|
||||
" print(result)"
|
||||
]
|
||||
|
|
@ -563,7 +563,7 @@
|
|||
"source": [
|
||||
"from cognee.api.v1.search import SearchType\n",
|
||||
"\n",
|
||||
"node = (await vector_engine.search(\"entity_name\", \"sarah.nguyen@example.com\"))[0]\n",
|
||||
"node = (await vector_engine.search(\"Entity_name\", \"sarah.nguyen@example.com\"))[0]\n",
|
||||
"node_name = node.payload[\"text\"]\n",
|
||||
"\n",
|
||||
"search_results = await cognee.search(query_type=SearchType.SUMMARIES, query_text=node_name)\n",
|
||||
|
|
|
|||
poetry.lock (generated, 1299 changes)
File diff suppressed because it is too large.
|
|
@@ -80,6 +80,7 @@ nltk = "3.9.1"
google-generativeai = {version = "^0.8.4", optional = true}
parso = {version = "^0.8.4", optional = true}
jedi = {version = "^0.19.2", optional = true}
mistral-common = {version = "^1.5.2", optional = true}


[tool.poetry.extras]
@@ -93,6 +94,7 @@ langchain = ["langsmith", "langchain_text_splitters"]
llama-index = ["llama-index-core"]
gemini = ["google-generativeai"]
huggingface = ["transformers"]
mistral = ["mistral-common"]
deepeval = ["deepeval"]
posthog = ["posthog"]
falkordb = ["falkordb"]