fix: custom model pipeline (#508)

<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

This PR fixes the custom model pipeline so that user-defined graph models can be ingested, stored, and searched end to end:

- The `DataPoint._metadata` field is renamed to `metadata` and the redundant `"type"` entry is dropped; node types are now derived from the class name.
- A new `ExtendableDataPoint` class and a `cognee.low_level` module (exposing `DataPoint` and an async `setup()` helper) support building custom pipelines.
- Vector index collections are now named `<ClassName>_<field>` (e.g. `Entity_name`, `DocumentChunk_text`), and all search call sites are updated accordingly.
- `extract_graph_from_data` stores nodes from custom graph models directly instead of assuming `KnowledgeGraph`.
- Graph visualizations can be written to a user-specified file path, the Neo4j adapter JSON-encodes dict properties, and a Mistral tokenizer is available as an optional extra.
- A low-level pipeline example (with sample `companies.json`/`people.json` data) and `cognee/tests/test_custom_model.py` cover the new flow.
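
A condensed sketch of the flow exercised by the new test (the model classes and query below are illustrative; see `cognee/tests/test_custom_model.py` for the full version):

```python
import asyncio

import cognee
from cognee.low_level import DataPoint
from cognee.modules.search.types import SearchType


class FieldType(DataPoint):
    name: str = "Field"
    metadata: dict = {"index_fields": ["name"]}


class Field(DataPoint):
    name: str
    is_type: FieldType
    metadata: dict = {"index_fields": ["name"]}


class ProgrammingLanguage(DataPoint):
    name: str
    used_in: list[Field] = []
    metadata: dict = {"index_fields": ["name"]}


async def main():
    await cognee.add("Python is an interpreted, high-level, general-purpose programming language.")
    # cognify now accepts a custom graph model instead of the default KnowledgeGraph.
    await cognee.cognify(graph_model=ProgrammingLanguage)

    # Completion query that uses graph data to form context.
    completion = await cognee.search(SearchType.GRAPH_COMPLETION, "What is Python?")
    print(completion)


asyncio.run(main())
```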

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.


<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit


- **New Features**
• Graph visualizations can now be exported to a user-specified file path instead of always being written to the home directory (see the sketch after this list).
• The text embedding engine gains a Mistral tokenizer option, available through the optional `mistral-common` dependency.
• A new `ExtendableDataPoint` class (exported as `DataPoint` from the new `cognee.low_level` module) has been introduced as the base for custom data models.
• New `companies.json` and `people.json` sample files have been added to support the low-level example and tests.
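
For instance, the visualization export can now target a specific file (a small sketch; the path is illustrative, and omitting it keeps the previous home-directory default):

```python
import asyncio

from cognee import visualize_graph


async def export_graph():
    # Write the interactive HTML graph to a caller-specified location.
    await visualize_graph(".artifacts/graph_visualization.html")


asyncio.run(export_graph())
```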

- **Improvements**
• Search now targets the updated vector index identifiers (`<ClassName>_<field>`, e.g. `Entity_name`, `DocumentChunk_text`) for more reliable content retrieval (see the sketch after this list).
• Metadata handling has been streamlined across data point classes: the `_metadata` field is now `metadata` and the redundant `type` entry has been removed.
• The Neo4j adapter now JSON-encodes dictionary properties, improving serialization of complex structures.
• Database setup is handled by a new asynchronous `setup()` function that creates the relational and pgvector tables.
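
A rough illustration of the new setup helper and the renamed collections (identifiers taken from the diff; the query is illustrative):

```python
import asyncio

from cognee.low_level import setup
from cognee.infrastructure.databases.vector import get_vector_engine


async def main():
    # New async setup helper: creates the relational (and pgvector) tables.
    await setup()

    # Collections are now named "<ClassName>_<field>" (e.g. "Entity_name",
    # "DocumentChunk_text") instead of the old lowercase "entity_name" style.
    vector_engine = get_vector_engine()
    results = await vector_engine.search("DocumentChunk_text", "Python", limit=5)
    for result in results:
        print(result.payload)


asyncio.run(main())
```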

- **Chores**
• Dependency and configuration updates improve overall stability and
performance.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
Boris committed on 2025-02-08 02:00:15 +01:00 (via GitHub)
commit f75e35c337, parent 6be6b3d222
66 changed files with 1248 additions and 748 deletions

1
.gitignore vendored
View file

@ -179,6 +179,7 @@ cognee/cache/
# Default cognee system directory, used in development
.cognee_system/
.data_storage/
.artifacts/
.anon_id
node_modules/

View file

@ -9,13 +9,19 @@ import asyncio
from cognee.shared.utils import setup_logging
async def visualize_graph():
async def visualize_graph(destination_file_path: str = None):
graph_engine = await get_graph_engine()
graph_data = await graph_engine.get_graph_data()
logging.info(graph_data)
graph = await cognee_network_visualization(graph_data)
logging.info("The HTML file has been stored on your home directory! Navigate there with cd ~")
graph = await cognee_network_visualization(graph_data, destination_file_path)
if destination_file_path:
logging.info(f"The HTML file has been stored at path: {destination_file_path}")
else:
logging.info(
"The HTML file has been stored on your home directory! Navigate there with cd ~"
)
return graph

View file

@ -1,9 +1,10 @@
"""Neo4j Adapter for Graph Database"""
import json
import logging
import asyncio
from textwrap import dedent
from typing import Optional, Any, List, Dict, Union
from typing import Optional, Any, List, Dict
from contextlib import asynccontextmanager
from uuid import UUID
from neo4j import AsyncSession
@ -11,6 +12,7 @@ from neo4j import AsyncGraphDatabase
from neo4j.exceptions import Neo4jError
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
from cognee.modules.storage.utils import JSONEncoder
logger = logging.getLogger("Neo4jAdapter")
@ -434,6 +436,10 @@ class Neo4jAdapter(GraphDBInterface):
serialized_properties[property_key] = str(property_value)
continue
if isinstance(property_value, dict):
serialized_properties[property_key] = json.dumps(property_value, cls=JSONEncoder)
continue
serialized_properties[property_key] = property_value
return serialized_properties

View file

@ -18,7 +18,7 @@ from cognee.infrastructure.engine import DataPoint
class IndexSchema(DataPoint):
text: str
_metadata: dict = {"index_fields": ["text"], "type": "IndexSchema"}
metadata: dict = {"index_fields": ["text"]}
class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):

View file

@ -61,7 +61,7 @@ class SQLAlchemyAdapter:
await connection.execute(text(f"CREATE SCHEMA IF NOT EXISTS {schema_name};"))
await connection.execute(
text(
f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} ({', '.join(fields_query_parts)});"
f'CREATE TABLE IF NOT EXISTS {schema_name}."{table_name}" ({", ".join(fields_query_parts)});'
)
)
await connection.close()
@ -71,10 +71,10 @@ class SQLAlchemyAdapter:
if self.engine.dialect.name == "sqlite":
# SQLite doesnt support schema namespaces and the CASCADE keyword.
# However, foreign key constraint can be defined with ON DELETE CASCADE during table creation.
await connection.execute(text(f"DROP TABLE IF EXISTS {table_name};"))
await connection.execute(text(f'DROP TABLE IF EXISTS "{table_name}";'))
else:
await connection.execute(
text(f"DROP TABLE IF EXISTS {schema_name}.{table_name} CASCADE;")
text(f'DROP TABLE IF EXISTS {schema_name}."{table_name}" CASCADE;')
)
async def insert_data(
@ -252,7 +252,7 @@ class SQLAlchemyAdapter:
async def get_data(self, table_name: str, filters: dict = None):
async with self.engine.begin() as connection:
query = f"SELECT * FROM {table_name}"
query = f'SELECT * FROM "{table_name}"'
if filters:
filter_conditions = " AND ".join(
[
@ -336,7 +336,7 @@ class SQLAlchemyAdapter:
await connection.run_sync(metadata.reflect, schema=schema_name)
for table in metadata.sorted_tables:
drop_table_query = text(
f"DROP TABLE IF EXISTS {schema_name}.{table.name} CASCADE"
f'DROP TABLE IF EXISTS {schema_name}."{table.name}" CASCADE'
)
await connection.execute(drop_table_query)
metadata.clear()

View file

@ -8,6 +8,7 @@ from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import Em
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
from cognee.infrastructure.llm.tokenizer.Gemini import GeminiTokenizer
from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer
from cognee.infrastructure.llm.tokenizer.Mistral import MistralTokenizer
from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer
litellm.set_verbose = False
@ -126,6 +127,8 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
tokenizer = TikTokenTokenizer(model=model, max_tokens=self.max_tokens)
elif "gemini" in self.provider.lower():
tokenizer = GeminiTokenizer(model=model, max_tokens=self.max_tokens)
elif "mistral" in self.provider.lower():
tokenizer = MistralTokenizer(model=model, max_tokens=self.max_tokens)
else:
tokenizer = HuggingFaceTokenizer(model=self.model, max_tokens=self.max_tokens)

View file

@ -21,7 +21,7 @@ class IndexSchema(DataPoint):
id: str
text: str
_metadata: dict = {"index_fields": ["text"], "type": "IndexSchema"}
metadata: dict = {"index_fields": ["text"]}
class LanceDBAdapter(VectorDBInterface):
@ -245,7 +245,7 @@ class LanceDBAdapter(VectorDBInterface):
[
IndexSchema(
id=str(data_point.id),
text=getattr(data_point, data_point._metadata["index_fields"][0]),
text=getattr(data_point, data_point.metadata["index_fields"][0]),
)
for data_point in data_points
],
@ -269,5 +269,5 @@ class LanceDBAdapter(VectorDBInterface):
include_fields={
"id": (str, ...),
},
exclude_fields=["_metadata"],
exclude_fields=["metadata"],
)

View file

@ -17,7 +17,7 @@ logger = logging.getLogger("MilvusAdapter")
class IndexSchema(DataPoint):
text: str
_metadata: dict = {"index_fields": ["text"], "type": "IndexSchema"}
metadata: dict = {"index_fields": ["text"]}
class MilvusAdapter(VectorDBInterface):
@ -133,7 +133,7 @@ class MilvusAdapter(VectorDBInterface):
formatted_data_points = [
IndexSchema(
id=data_point.id,
text=getattr(data_point, data_point._metadata["index_fields"][0]),
text=getattr(data_point, data_point.metadata["index_fields"][0]),
)
for data_point in data_points
]

View file

@ -22,7 +22,7 @@ from ..utils import normalize_distances
class IndexSchema(DataPoint):
text: str
_metadata: dict = {"index_fields": ["text"], "type": "IndexSchema"}
metadata: dict = {"index_fields": ["text"]}
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):

View file

@ -17,7 +17,7 @@ logger = logging.getLogger("QDrantAdapter")
class IndexSchema(DataPoint):
text: str
_metadata: dict = {"index_fields": ["text"], "type": "IndexSchema"}
metadata: dict = {"index_fields": ["text"]}
# class CollectionConfig(BaseModel, extra = "forbid"):
@ -131,7 +131,7 @@ class QDrantAdapter(VectorDBInterface):
[
IndexSchema(
id=data_point.id,
text=getattr(data_point, data_point._metadata["index_fields"][0]),
text=getattr(data_point, data_point.metadata["index_fields"][0]),
)
for data_point in data_points
],

View file

@ -16,7 +16,7 @@ logger = logging.getLogger("WeaviateAdapter")
class IndexSchema(DataPoint):
text: str
_metadata: dict = {"index_fields": ["text"], "type": "IndexSchema"}
metadata: dict = {"index_fields": ["text"]}
class WeaviateAdapter(VectorDBInterface):

View file

@ -1 +1,2 @@
from .models.DataPoint import DataPoint
from .models.ExtendableDataPoint import ExtendableDataPoint

View file

@ -9,6 +9,7 @@ import pickle
# Define metadata type
class MetaData(TypedDict):
type: str
index_fields: list[str]
@ -24,35 +25,35 @@ class DataPoint(BaseModel):
)
version: int = 1 # Default version
topological_rank: Optional[int] = 0
_metadata: Optional[MetaData] = {"index_fields": [], "type": "DataPoint"}
metadata: Optional[MetaData] = {"index_fields": []}
@classmethod
def get_embeddable_data(self, data_point):
def get_embeddable_data(self, data_point: "DataPoint"):
if (
data_point._metadata
and len(data_point._metadata["index_fields"]) > 0
and hasattr(data_point, data_point._metadata["index_fields"][0])
data_point.metadata
and len(data_point.metadata["index_fields"]) > 0
and hasattr(data_point, data_point.metadata["index_fields"][0])
):
attribute = getattr(data_point, data_point._metadata["index_fields"][0])
attribute = getattr(data_point, data_point.metadata["index_fields"][0])
if isinstance(attribute, str):
return attribute.strip()
return attribute
@classmethod
def get_embeddable_properties(self, data_point):
def get_embeddable_properties(self, data_point: "DataPoint"):
"""Retrieve all embeddable properties."""
if data_point._metadata and len(data_point._metadata["index_fields"]) > 0:
if data_point.metadata and len(data_point.metadata["index_fields"]) > 0:
return [
getattr(data_point, field, None) for field in data_point._metadata["index_fields"]
getattr(data_point, field, None) for field in data_point.metadata["index_fields"]
]
return []
@classmethod
def get_embeddable_property_names(self, data_point):
def get_embeddable_property_names(self, data_point: "DataPoint"):
"""Retrieve names of embeddable properties."""
return data_point._metadata["index_fields"] or []
return data_point.metadata["index_fields"] or []
def update_version(self):
"""Update the version and updated_at timestamp."""

View file

@ -0,0 +1,5 @@
from .DataPoint import DataPoint
class ExtendableDataPoint(DataPoint):
pass

View file

@ -1,7 +1,6 @@
"""Adapter for Generic API LLM provider API"""
import asyncio
from typing import List, Type
from typing import Type
from pydantic import BaseModel
import instructor
@ -24,13 +23,9 @@ class GenericAPIAdapter(LLMInterface):
self.endpoint = endpoint
self.max_tokens = max_tokens
llm_config = get_llm_config()
if llm_config.llm_provider == "ollama":
self.aclient = instructor.from_litellm(litellm.acompletion)
else:
self.aclient = instructor.from_litellm(litellm.acompletion)
self.aclient = instructor.from_litellm(
litellm.acompletion, mode=instructor.Mode.JSON, api_key=api_key
)
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]

View file

@ -34,5 +34,5 @@ class HuggingFaceTokenizer(TokenizerInterface):
return len(self.tokenizer.tokenize(text))
def decode_single_token(self, encoding: int):
# Gemini tokenizer doesn't have the option to decode tokens
# HuggingFace tokenizer doesn't have the option to decode tokens
raise NotImplementedError

View file

@ -0,0 +1 @@
from .adapter import MistralTokenizer

View file

@ -0,0 +1,47 @@
from typing import List, Any
from ..tokenizer_interface import TokenizerInterface
class MistralTokenizer(TokenizerInterface):
def __init__(
self,
model: str,
max_tokens: int = 3072,
):
self.model = model
self.max_tokens = max_tokens
# Import here to make it an optional dependency
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
self.tokenizer = MistralTokenizer.from_model(model)
def extract_tokens(self, text: str) -> List[Any]:
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.tokens.tokenizers.base import Tokenized
encoding: Tokenized = self.tokenizer.encode_chat_completion(
ChatCompletionRequest(
messages=[UserMessage(role="user", content=text)],
model=self.model,
)
)
return encoding.tokens
def count_tokens(self, text: str) -> int:
"""
Returns the number of tokens in the given text.
Args:
text: str
Returns:
number of tokens in the given text
"""
return len(self.extract_tokens(text))
def decode_single_token(self, encoding: int):
# Mistral tokenizer doesn't have the option to decode tokens
raise NotImplementedError

2
cognee/low_level.py Normal file
View file

@ -0,0 +1,2 @@
from cognee.infrastructure.engine import ExtendableDataPoint as DataPoint
from cognee.modules.engine.operations.setup import setup

View file

@ -53,7 +53,7 @@ class TextChunker:
chunk_index=self.chunk_index,
cut_type=chunk_data["cut_type"],
contains=[],
_metadata={
metadata={
"index_fields": ["text"],
},
)
@ -73,7 +73,7 @@ class TextChunker:
chunk_index=self.chunk_index,
cut_type=paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
contains=[],
_metadata={
metadata={
"index_fields": ["text"],
},
)
@ -97,7 +97,7 @@ class TextChunker:
chunk_index=self.chunk_index,
cut_type=paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
contains=[],
_metadata={"index_fields": ["text"]},
metadata={"index_fields": ["text"]},
)
except Exception as e:
logger.error(e)

View file

@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.processing.document_types import Document
@ -16,4 +16,4 @@ class DocumentChunk(DataPoint):
pydantic_type: str = "DocumentChunk"
contains: List[Entity] = None
_metadata: dict = {"index_fields": ["text"], "type": "DocumentChunk"}
metadata: dict = {"index_fields": ["text"]}

View file

@ -1,6 +1,4 @@
from typing import Optional
from uuid import UUID
from cognee.infrastructure.engine import DataPoint
@ -9,7 +7,7 @@ class Document(DataPoint):
raw_data_location: str
external_metadata: Optional[str]
mime_type: str
_metadata: dict = {"index_fields": ["name"], "type": "Document"}
metadata: dict = {"index_fields": ["name"]}
def read(self, chunk_size: int, chunker=str, max_chunk_tokens: Optional[int] = None) -> str:
pass

View file

@ -1,5 +1,3 @@
from typing import Optional
from .ChunkerMapping import ChunkerConfig
from .Document import Document

View file

@ -9,4 +9,4 @@ class Entity(DataPoint):
description: str
pydantic_type: str = "Entity"
_metadata: dict = {"index_fields": ["name"], "type": "Entity"}
metadata: dict = {"index_fields": ["name"]}

View file

@ -7,4 +7,4 @@ class EntityType(DataPoint):
description: str
pydantic_type: str = "EntityType"
_metadata: dict = {"index_fields": ["name"], "type": "EntityType"}
metadata: dict = {"index_fields": ["name"]}

View file

@ -0,0 +1,11 @@
from cognee.infrastructure.databases.relational import (
create_db_and_tables as create_relational_db_and_tables,
)
from cognee.infrastructure.databases.vector.pgvector import (
create_db_and_tables as create_pgvector_db_and_tables,
)
async def setup():
await create_relational_db_and_tables()
await create_pgvector_db_and_tables()

View file

@ -1,5 +1,3 @@
from typing import Optional
from cognee.infrastructure.engine import DataPoint
@ -8,4 +6,4 @@ class EdgeType(DataPoint):
relationship_name: str
number_of_edges: int
_metadata: dict = {"index_fields": ["relationship_name"], "type": "EdgeType"}
metadata: dict = {"index_fields": ["relationship_name"]}

View file

@ -2,7 +2,7 @@ from cognee.infrastructure.engine import DataPoint
def convert_node_to_data_point(node_data: dict) -> DataPoint:
subclass = find_subclass_by_name(DataPoint, node_data._metadata["type"])
subclass = find_subclass_by_name(DataPoint, node_data["type"])
return subclass(**node_data)

View file

@ -17,12 +17,14 @@ async def get_graph_from_model(
edges = []
visited_properties = visited_properties or {}
data_point_properties = {}
data_point_properties = {
"type": type(data_point).__name__,
}
excluded_properties = set()
properties_to_visit = set()
for field_name, field_value in data_point:
if field_name == "_metadata":
if field_name == "metadata":
continue
if isinstance(field_value, DataPoint):
@ -60,7 +62,6 @@ async def get_graph_from_model(
SimpleDataPointModel = copy_model(
type(data_point),
include_fields={
"_metadata": (dict, data_point._metadata),
"__tablename__": (str, data_point.__tablename__),
},
exclude_fields=list(excluded_properties),

View file

@ -84,10 +84,10 @@ async def brute_force_search(
if collections is None:
collections = [
"entity_name",
"text_summary_text",
"entity_type_name",
"document_chunk_text",
"Entity_name",
"TextSummary_text",
"EntityType_name",
"DocumentChunk_text",
]
try:
@ -127,9 +127,14 @@ async def brute_force_search(
return results
except Exception as e:
except Exception as error:
logging.error(
"Error during brute force search for user: %s, query: %s. Error: %s", user.id, query, e
"Error during brute force search for user: %s, query: %s. Error: %s",
user.id,
query,
error,
)
send_telemetry("cognee.brute_force_triplet_search EXECUTION FAILED", user.id)
raise RuntimeError("An error occurred during brute force search") from e
send_telemetry(
"cognee.brute_force_triplet_search EXECUTION FAILED", user.id, {"error": str(error)}
)
raise RuntimeError("An error occurred during brute force search") from error

View file

@ -101,7 +101,6 @@ async def code_description_to_code_part(
"text",
"file_path",
"source_code",
"pydantic_type",
],
edge_properties_to_project=["relationship_name"],
)
@ -117,13 +116,13 @@ async def code_description_to_code_part(
continue
for code_file in node_to_search_from.get_skeleton_neighbours():
if code_file.get_attribute("pydantic_type") == "SourceCodeChunk":
if code_file.get_attribute("type") == "SourceCodeChunk":
for code_file_edge in code_file.get_skeleton_edges():
if code_file_edge.get_attribute("relationship_name") == "code_chunk_of":
code_pieces_to_return.add(code_file_edge.get_destination_node())
elif code_file.get_attribute("pydantic_type") == "CodePart":
elif code_file.get_attribute("type") == "CodePart":
code_pieces_to_return.add(code_file)
elif code_file.get_attribute("pydantic_type") == "CodeFile":
elif code_file.get_attribute("type") == "CodeFile":
for code_file_edge in code_file.get_skeleton_edges():
if code_file_edge.get_attribute("relationship_name") == "contains":
code_pieces_to_return.add(code_file_edge.get_destination_node())

View file

@ -36,7 +36,7 @@ def get_own_properties(data_point: DataPoint):
for field_name, field_value in data_point:
if (
field_name == "_metadata"
field_name == "metadata"
or isinstance(field_value, dict)
or isinstance(field_value, DataPoint)
or (isinstance(field_value, list) and isinstance(field_value[0], DataPoint))

View file

@ -1,9 +1,15 @@
import logging
import networkx as nx
import json
import os
from cognee.infrastructure.files.storage import LocalStorage
async def cognee_network_visualization(graph_data):
logger = logging.getLogger(__name__)
async def cognee_network_visualization(graph_data, destination_file_path: str = None):
nodes_data, edges_data = graph_data
G = nx.DiGraph()
@ -19,7 +25,7 @@ async def cognee_network_visualization(graph_data):
for node_id, node_info in nodes_data:
node_info = node_info.copy()
node_info["id"] = str(node_id)
node_info["color"] = color_map.get(node_info.get("pydantic_type", "default"), "#D3D3D3")
node_info["color"] = color_map.get(node_info.get("type", "default"), "#D3D3D3")
node_info["name"] = node_info.get("name", str(node_id))
try:
@ -178,12 +184,15 @@ async def cognee_network_visualization(graph_data):
html_content = html_template.replace("{nodes}", json.dumps(nodes_list))
html_content = html_content.replace("{links}", json.dumps(links_list))
home_dir = os.path.expanduser("~")
output_file = os.path.join(home_dir, "graph_visualization.html")
if not destination_file_path:
home_dir = os.path.expanduser("~")
destination_file_path = os.path.join(home_dir, "graph_visualization.html")
with open(output_file, "w") as f:
LocalStorage.ensure_directory_exists(os.path.dirname(destination_file_path))
with open(destination_file_path, "w") as f:
f.write(html_content)
print(f"Graph visualization saved as {output_file}")
logger.info(f"Graph visualization saved as {destination_file_path}")
return html_content

View file

@ -6,7 +6,7 @@ class Repository(DataPoint):
__tablename__ = "Repository"
path: str
pydantic_type: str = "Repository"
_metadata: dict = {"index_fields": [], "type": "Repository"}
metadata: dict = {"index_fields": []}
class CodeFile(DataPoint):
@ -18,7 +18,7 @@ class CodeFile(DataPoint):
depends_on: Optional[List["CodeFile"]] = None
depends_directly_on: Optional[List["CodeFile"]] = None
contains: Optional[List["CodePart"]] = None
_metadata: dict = {"index_fields": [], "type": "CodeFile"}
metadata: dict = {"index_fields": []}
class CodePart(DataPoint):
@ -27,7 +27,7 @@ class CodePart(DataPoint):
# part_of: Optional[CodeFile] = None
pydantic_type: str = "CodePart"
source_code: Optional[str] = None
_metadata: dict = {"index_fields": [], "type": "CodePart"}
metadata: dict = {"index_fields": []}
class SourceCodeChunk(DataPoint):
@ -37,7 +37,7 @@ class SourceCodeChunk(DataPoint):
pydantic_type: str = "SourceCodeChunk"
previous_chunk: Optional["SourceCodeChunk"] = None
_metadata: dict = {"index_fields": ["source_code"], "type": "SourceCodeChunk"}
metadata: dict = {"index_fields": ["source_code"]}
CodeFile.model_rebuild()

View file

@ -11,7 +11,7 @@ class Variable(DataPoint):
default_value: Optional[str] = None
data_type: str
_metadata = {"index_fields": ["name"], "type": "Variable"}
metadata: dict = {"index_fields": ["name"]}
class Operator(DataPoint):
@ -19,7 +19,7 @@ class Operator(DataPoint):
name: str
description: str
return_type: str
_metadata = {"index_fields": ["name"], "type": "Operator"}
metadata: dict = {"index_fields": ["name"]}
class Class(DataPoint):
@ -30,7 +30,7 @@ class Class(DataPoint):
extended_from_class: Optional["Class"] = None
has_methods: List["Function"]
_metadata = {"index_fields": ["name"], "type": "Class"}
metadata: dict = {"index_fields": ["name"]}
class ClassInstance(DataPoint):
@ -41,7 +41,7 @@ class ClassInstance(DataPoint):
instantiated_by: Union["Function"]
instantiation_arguments: List[Variable]
_metadata = {"index_fields": ["name"], "type": "ClassInstance"}
metadata: dict = {"index_fields": ["name"]}
class Function(DataPoint):
@ -52,7 +52,7 @@ class Function(DataPoint):
return_type: str
is_static: Optional[bool] = False
_metadata = {"index_fields": ["name"], "type": "Function"}
metadata: dict = {"index_fields": ["name"]}
class FunctionCall(DataPoint):
@ -60,7 +60,7 @@ class FunctionCall(DataPoint):
called_by: Union[Function, Literal["main"]]
function_called: Function
function_arguments: List[Any]
_metadata = {"index_fields": [], "type": "FunctionCall"}
metadata: dict = {"index_fields": []}
class Expression(DataPoint):
@ -69,7 +69,7 @@ class Expression(DataPoint):
description: str
expression: str
members: List[Union[Variable, Function, Operator, "Expression"]]
_metadata = {"index_fields": ["name"], "type": "Expression"}
metadata: dict = {"index_fields": ["name"]}
class SourceCodeGraph(DataPoint):
@ -88,7 +88,7 @@ class SourceCodeGraph(DataPoint):
Expression,
]
]
_metadata = {"index_fields": ["name"], "type": "SourceCodeGraph"}
metadata: dict = {"index_fields": ["name"]}
Class.model_rebuild()

View file

@ -13,7 +13,6 @@ import matplotlib.pyplot as plt
import logging
import sys
import json
from cognee.base_config import get_base_config
from cognee.infrastructure.databases.graph import get_graph_engine
@ -164,6 +163,7 @@ def prepare_nodes(graph, include_size=False):
continue
node_data = {
**node_info,
"id": str(node),
"name": node_info["name"] if "name" in node_info else str(node),
}
@ -183,7 +183,7 @@ def prepare_nodes(graph, include_size=False):
async def render_graph(
graph, include_nodes=False, include_color=False, include_size=False, include_labels=False
graph=None, include_nodes=True, include_color=False, include_size=False, include_labels=True
):
await register_graphistry()

View file

@ -15,11 +15,11 @@ async def query_chunks(query: str) -> list[dict]:
Notes:
- The function uses the `search` method of the vector engine to find matches.
- Limits the results to the top 5 matching chunks to balance performance and relevance.
- Ensure that the vector database is properly initialized and contains the "document_chunk_text" collection.
- Ensure that the vector database is properly initialized and contains the "DocumentChunk_text" collection.
"""
vector_engine = get_vector_engine()
found_chunks = await vector_engine.search("document_chunk_text", query, limit=5)
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=5)
chunks = [result.payload for result in found_chunks]

View file

@ -1,4 +1,5 @@
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.engine import ExtendableDataPoint
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.tasks.completion.exceptions import NoRelevantDataFound
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
@ -34,9 +35,20 @@ async def graph_query_completion(query: str) -> list:
- The `brute_force_triplet_search` is used to retrieve relevant graph data.
- Prompts are dynamically rendered and provided to the LLM for contextual understanding.
- Ensure that the LLM client and graph database are properly configured and accessible.
"""
found_triplets = await brute_force_triplet_search(query, top_k=5)
subclasses = get_all_subclasses(ExtendableDataPoint)
vector_index_collections = []
for subclass in subclasses:
index_fields = subclass.model_fields["metadata"].default.get("index_fields", [])
for field_name in index_fields:
vector_index_collections.append(f"{subclass.__name__}_{field_name}")
found_triplets = await brute_force_triplet_search(
query, top_k=5, collections=vector_index_collections or None
)
if len(found_triplets) == 0:
raise NoRelevantDataFound

View file

@ -23,7 +23,7 @@ async def query_completion(query: str) -> list:
"""
vector_engine = get_vector_engine()
found_chunks = await vector_engine.search("document_chunk_text", query, limit=1)
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=1)
if len(found_chunks) == 0:
raise NoRelevantDataFound

View file

@ -10,6 +10,7 @@ from cognee.modules.graph.utils import (
expand_with_nodes_and_edges,
retrieve_existing_edges,
)
from cognee.shared.data_models import KnowledgeGraph
from cognee.tasks.storage import add_data_points
@ -18,7 +19,6 @@ async def extract_graph_from_data(
) -> List[DocumentChunk]:
"""
Extracts and integrates a knowledge graph from the text content of document chunks using a specified graph model.
"""
chunk_graphs = await asyncio.gather(
@ -26,6 +26,13 @@ async def extract_graph_from_data(
)
graph_engine = await get_graph_engine()
if graph_model is not KnowledgeGraph:
for chunk_index, chunk_graph in enumerate(chunk_graphs):
data_chunks[chunk_index].contains = chunk_graph
await add_data_points(chunk_graphs)
return data_chunks
existing_edges_map = await retrieve_existing_edges(
data_chunks,
chunk_graphs,

View file

@ -28,8 +28,8 @@ async def query_graph_connections(query: str, exploration_levels=1) -> list[(str
else:
vector_engine = get_vector_engine()
results = await asyncio.gather(
vector_engine.search("entity_name", query_text=query, limit=5),
vector_engine.search("entity_type_name", query_text=query, limit=5),
vector_engine.search("Entity_name", query_text=query, limit=5),
vector_engine.search("EntityType_name", query_text=query, limit=5),
)
results = [*results[0], *results[1]]
relevant_results = [result for result in results if result.score < 0.5][:5]

View file

@ -16,25 +16,25 @@ async def index_data_points(data_points: list[DataPoint]):
for data_point in data_points:
data_point_type = type(data_point)
for field_name in data_point._metadata["index_fields"]:
for field_name in data_point.metadata["index_fields"]:
if getattr(data_point, field_name, None) is None:
continue
index_name = f"{data_point_type.__tablename__}.{field_name}"
index_name = f"{data_point_type.__name__}_{field_name}"
if index_name not in created_indexes:
await vector_engine.create_vector_index(data_point_type.__tablename__, field_name)
await vector_engine.create_vector_index(data_point_type.__name__, field_name)
created_indexes[index_name] = True
if index_name not in index_points:
index_points[index_name] = []
indexed_data_point = data_point.model_copy()
indexed_data_point._metadata["index_fields"] = [field_name]
indexed_data_point.metadata["index_fields"] = [field_name]
index_points[index_name].append(indexed_data_point)
for index_name, indexable_points in index_points.items():
index_name, field_name = index_name.split(".")
index_name, field_name = index_name.split("_")
try:
await vector_engine.index_data_points(index_name, field_name, indexable_points)
except EmbeddingException as e:
@ -101,13 +101,13 @@ if __name__ == "__main__":
class Car(DataPoint):
model: str
color: str
_metadata = {"index_fields": ["name"], "type": "Car"}
metadata: dict = {"index_fields": ["name"]}
class Person(DataPoint):
name: str
age: int
owns_car: list[Car]
_metadata = {"index_fields": ["name"], "type": "Person"}
metadata: dict = {"index_fields": ["name"]}
car1 = Car(model="Tesla Model S", color="Blue")
car2 = Car(model="Toyota Camry", color="Red")

View file

@ -50,7 +50,7 @@ async def index_graph_edges():
edge = EdgeType(relationship_name=text, number_of_edges=count)
data_point_type = type(edge)
for field_name in edge._metadata["index_fields"]:
for field_name in edge.metadata["index_fields"]:
index_name = f"{data_point_type.__tablename__}.{field_name}"
if index_name not in created_indexes:
@ -61,7 +61,7 @@ async def index_graph_edges():
index_points[index_name] = []
indexed_data_point = edge.model_copy()
indexed_data_point._metadata["index_fields"] = [field_name]
indexed_data_point.metadata["index_fields"] = [field_name]
index_points[index_name].append(indexed_data_point)
for index_name, indexable_points in index_points.items():

View file

@ -10,7 +10,7 @@ class TextSummary(DataPoint):
text: str
made_from: DocumentChunk
_metadata: dict = {"index_fields": ["text"], "type": "TextSummary"}
metadata: dict = {"index_fields": ["text"]}
class CodeSummary(DataPoint):
@ -19,4 +19,4 @@ class CodeSummary(DataPoint):
summarizes: Union[CodeFile, CodePart, SourceCodeChunk]
pydantic_type: str = "CodeSummary"
_metadata: dict = {"index_fields": ["text"], "type": "CodeSummary"}
metadata: dict = {"index_fields": ["text"]}

View file

@ -11,7 +11,7 @@ async def query_summaries(query: str) -> list:
"""
vector_engine = get_vector_engine()
summaries_results = await vector_engine.search("text_summary_text", query, limit=5)
summaries_results = await vector_engine.search("TextSummary_text", query, limit=5)
summaries = [summary.payload for summary in summaries_results]

View file

@ -1,5 +1,5 @@
from cognee.infrastructure.engine import DataPoint
from typing import ClassVar, Optional
from typing import Optional
class GraphitiNode(DataPoint):
@ -9,4 +9,4 @@ class GraphitiNode(DataPoint):
summary: Optional[str] = None
pydantic_type: str = "GraphitiNode"
_metadata: dict = {"index_fields": ["name", "summary", "content"], "type": "GraphitiNode"}
metadata: dict = {"index_fields": ["name", "summary", "content"]}

View file

@ -36,7 +36,7 @@ async def index_and_transform_graphiti_nodes_and_edges():
data_point_type = type(graphiti_node)
for field_name in graphiti_node._metadata["index_fields"]:
for field_name in graphiti_node.metadata["index_fields"]:
index_name = f"{data_point_type.__tablename__}.{field_name}"
if index_name not in created_indexes:
@ -48,7 +48,7 @@ async def index_and_transform_graphiti_nodes_and_edges():
if getattr(graphiti_node, field_name, None) is not None:
indexed_data_point = graphiti_node.model_copy()
indexed_data_point._metadata["index_fields"] = [field_name]
indexed_data_point.metadata["index_fields"] = [field_name]
index_points[index_name].append(indexed_data_point)
for index_name, indexable_points in index_points.items():
@ -65,7 +65,7 @@ async def index_and_transform_graphiti_nodes_and_edges():
edge = EdgeType(relationship_name=text, number_of_edges=count)
data_point_type = type(edge)
for field_name in edge._metadata["index_fields"]:
for field_name in edge.metadata["index_fields"]:
index_name = f"{data_point_type.__tablename__}.{field_name}"
if index_name not in created_indexes:
@ -76,7 +76,7 @@ async def index_and_transform_graphiti_nodes_and_edges():
index_points[index_name] = []
indexed_data_point = edge.model_copy()
indexed_data_point._metadata["index_fields"] = [field_name]
indexed_data_point.metadata["index_fields"] = [field_name]
index_points[index_name].append(indexed_data_point)
for index_name, indexable_points in index_points.items():

View file

@ -0,0 +1,38 @@
[
{
"name": "TechNova Inc.",
"departments": [
"Engineering",
"Marketing"
]
},
{
"name": "GreenFuture Solutions",
"departments": [
"Research & Development",
"Sales",
"Customer Support"
]
},
{
"name": "Skyline Financials",
"departments": [
"Accounting"
]
},
{
"name": "MediCare Plus",
"departments": [
"Healthcare",
"Administration"
]
},
{
"name": "NextGen Robotics",
"departments": [
"AI Development",
"Manufacturing",
"HR"
]
}
]

View file

@ -0,0 +1,52 @@
[
{
"name": "John Doe",
"company": "TechNova Inc.",
"department": "Engineering"
},
{
"name": "Jane Smith",
"company": "TechNova Inc.",
"department": "Marketing"
},
{
"name": "Alice Johnson",
"company": "GreenFuture Solutions",
"department": "Sales"
},
{
"name": "Bob Williams",
"company": "GreenFuture Solutions",
"department": "Customer Support"
},
{
"name": "Michael Brown",
"company": "Skyline Financials",
"department": "Accounting"
},
{
"name": "Emily Davis",
"company": "MediCare Plus",
"department": "Healthcare"
},
{
"name": "David Wilson",
"company": "MediCare Plus",
"department": "Administration"
},
{
"name": "Emma Thompson",
"company": "NextGen Robotics",
"department": "AI Development"
},
{
"name": "Chris Martin",
"company": "NextGen Robotics",
"department": "Manufacturing"
},
{
"name": "Sophia White",
"company": "NextGen Robotics",
"department": "HR"
}
]

View file

@ -0,0 +1,94 @@
import os
import json
import asyncio
from cognee import prune
from cognee import visualize_graph
from cognee.low_level import setup, DataPoint
from cognee.pipelines import run_tasks, Task
from cognee.tasks.storage import add_data_points
from cognee.shared.utils import render_graph
class Person(DataPoint):
name: str
class Department(DataPoint):
name: str
employees: list[Person]
class CompanyType(DataPoint):
name: str = "Company"
class Company(DataPoint):
name: str
departments: list[Department]
is_type: CompanyType
def ingest_files():
companies_file_path = os.path.join(os.path.dirname(__file__), "companies.json")
companies = json.loads(open(companies_file_path, "r").read())
people_file_path = os.path.join(os.path.dirname(__file__), "people.json")
people = json.loads(open(people_file_path, "r").read())
people_data_points = {}
departments_data_points = {}
for person in people:
new_person = Person(name=person["name"])
people_data_points[person["name"]] = new_person
if person["department"] not in departments_data_points:
departments_data_points[person["department"]] = Department(
name=person["department"], employees=[new_person]
)
else:
departments_data_points[person["department"]].employees.append(new_person)
companies_data_points = {}
# Create a single CompanyType node, so we connect all companies to it.
companyType = CompanyType()
for company in companies:
new_company = Company(name=company["name"], departments=[], is_type=companyType)
companies_data_points[company["name"]] = new_company
for department_name in company["departments"]:
if department_name not in departments_data_points:
departments_data_points[department_name] = Department(
name=department_name, employees=[]
)
new_company.departments.append(departments_data_points[department_name])
return companies_data_points.values()
async def main():
await prune.prune_data()
await prune.prune_system(metadata=True)
await setup()
pipeline = run_tasks([Task(ingest_files), Task(add_data_points)])
async for status in pipeline:
print(status)
# Get a graphistry url (Register for a free account at https://www.graphistry.com)
await render_graph()
# Or use our simple graph preview
graph_file_path = str(
os.path.join(os.path.dirname(__file__), ".artifacts/graph_visualization.html")
)
await visualize_graph(graph_file_path)
if __name__ == "__main__":
asyncio.run(main())

104
cognee/tests/test_custom_model.py Executable file
View file

@ -0,0 +1,104 @@
import os
import logging
import pathlib
import cognee
from cognee.modules.search.types import SearchType
from cognee.shared.utils import render_graph
from cognee.low_level import DataPoint
logging.basicConfig(level=logging.DEBUG)
async def main():
data_directory_path = str(
pathlib.Path(
os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_custom_model")
).resolve()
)
cognee.config.data_root_directory(data_directory_path)
cognee_directory_path = str(
pathlib.Path(
os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_custom_model")
).resolve()
)
cognee.config.system_root_directory(cognee_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
# Define a custom graph model for programming languages.
class FieldType(DataPoint):
name: str = "Field"
metadata: dict = {"index_fields": ["name"]}
class Field(DataPoint):
name: str
is_type: FieldType
metadata: dict = {"index_fields": ["name"]}
class ProgrammingLanguageType(DataPoint):
name: str = "Programming Language"
metadata: dict = {"index_fields": ["name"]}
class ProgrammingLanguage(DataPoint):
name: str
used_in: list[Field] = []
is_type: ProgrammingLanguageType
metadata: dict = {"index_fields": ["name"]}
text = (
"Python is an interpreted, high-level, general-purpose programming language. It was created by Guido van Rossum and first released in 1991. "
+ "Python is widely used in data analysis, web development, and machine learning."
)
await cognee.add(text)
await cognee.cognify(graph_model=ProgrammingLanguage)
url = await render_graph()
print(f"Graphistry URL: {url}")
graph_file_path = str(
pathlib.Path(
os.path.join(
pathlib.Path(__file__).parent,
".artifacts/test_custom_model/graph_visualization.html",
)
).resolve()
)
await cognee.visualize_graph(graph_file_path)
# Completion query that uses graph data to form context.
completion = await cognee.search(SearchType.GRAPH_COMPLETION, "What is python?")
assert len(completion) != 0, "Graph completion search didn't return any result."
print("Graph completion result is:")
print(completion)
# Completion query that uses document chunks to form context.
completion = await cognee.search(SearchType.COMPLETION, "What is Python?")
assert len(completion) != 0, "Completion search didn't return any result."
print("Completion result is:")
print(completion)
# Query all summaries related to query.
summaries = await cognee.search(SearchType.SUMMARIES, "Python")
assert len(summaries) != 0, "Summaries search didn't return any results."
print("Summary results are:")
for summary in summaries:
print(summary)
chunks = await cognee.search(SearchType.CHUNKS, query_text="Python")
assert len(chunks) != 0, "Chunks search didn't return any results."
print("Chunk results are:")
for chunk in chunks:
print(chunk)
history = await cognee.get_search_history()
assert len(history) == 8, "Search history is not correct."
if __name__ == "__main__":
import asyncio
asyncio.run(main(), debug=True)

View file

@ -44,7 +44,7 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("entity_name", "AI"))[0]
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(

View file

@ -55,7 +55,7 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("entity_name", "Quantum computer"))[0]
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(

View file

@ -48,7 +48,7 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("entity_name", "Quantum computer"))[0]
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(

View file

@ -123,7 +123,7 @@ async def main():
await test_getting_of_documents(dataset_name_1)
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("entity_name", "Quantum computer"))[0]
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(
@ -142,6 +142,15 @@ async def main():
for result in search_results:
print(f"{result}\n")
graph_completion = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION,
query_text=random_node_name,
datasets=[dataset_name_2],
)
assert len(graph_completion) != 0, "Completion result is empty."
print("Completion result is:")
print(graph_completion)
search_results = await cognee.search(
query_type=SearchType.SUMMARIES, query_text=random_node_name
)
@ -151,7 +160,7 @@ async def main():
print(f"{result}\n")
history = await cognee.get_search_history()
assert len(history) == 6, "Search history is not correct."
assert len(history) == 8, "Search history is not correct."
results = await brute_force_triplet_search("What is a quantum computer?")
assert len(results) > 0

View file

@ -48,7 +48,7 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("entity_name", "Quantum computer"))[0]
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(

View file

@ -48,7 +48,7 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("entity_name", "Quantum computer"))[0]
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(

View file

@ -12,7 +12,7 @@ random.seed(1500)
class Repository(DataPoint):
path: str
_metadata = {"index_fields": [], "type": "Repository"}
metadata: dict = {"index_fields": []}
class CodeFile(DataPoint):
@ -20,13 +20,13 @@ class CodeFile(DataPoint):
contains: List["CodePart"] = []
depends_on: List["CodeFile"] = []
source_code: str
_metadata = {"index_fields": [], "type": "CodeFile"}
metadata: dict = {"index_fields": []}
class CodePart(DataPoint):
part_of: CodeFile
source_code: str
_metadata = {"index_fields": [], "type": "CodePart"}
metadata: dict = {"index_fields": []}
CodeFile.model_rebuild()

View file

@ -9,25 +9,25 @@ from cognee.modules.graph.utils import get_graph_from_model
class Document(DataPoint):
path: str
_metadata = {"index_fields": [], "type": "Document"}
metadata: dict = {"index_fields": []}
class DocumentChunk(DataPoint):
part_of: Document
text: str
contains: List["Entity"] = None
_metadata = {"index_fields": ["text"], "type": "DocumentChunk"}
metadata: dict = {"index_fields": ["text"]}
class EntityType(DataPoint):
name: str
_metadata = {"index_fields": ["name"], "type": "EntityType"}
metadata: dict = {"index_fields": ["name"]}
class Entity(DataPoint):
name: str
is_type: EntityType
_metadata = {"index_fields": ["name"], "type": "Entity"}
metadata: dict = {"index_fields": ["name"]}
DocumentChunk.model_rebuild()

View file

@ -7,25 +7,25 @@ from cognee.modules.graph.utils import get_graph_from_model
class Document(DataPoint):
path: str
_metadata = {"index_fields": [], "type": "Document"}
metadata: dict = {"index_fields": []}
class DocumentChunk(DataPoint):
part_of: Document
text: str
contains: List["Entity"] = None
_metadata = {"index_fields": ["text"], "type": "DocumentChunk"}
metadata: dict = {"index_fields": ["text"]}
class EntityType(DataPoint):
name: str
_metadata = {"index_fields": ["name"], "type": "EntityType"}
metadata: dict = {"index_fields": ["name"]}
class Entity(DataPoint):
name: str
is_type: EntityType
_metadata = {"index_fields": ["name"], "type": "Entity"}
metadata: dict = {"index_fields": ["name"]}
DocumentChunk.model_rebuild()

View file

@ -7,11 +7,11 @@ from cognee.modules.visualization.cognee_network_visualization import (
@pytest.mark.asyncio
async def test_create_cognee_style_network_with_logo():
nodes_data = [
(1, {"pydantic_type": "Entity", "name": "Node1", "updated_at": 123, "created_at": 123}),
(1, {"type": "Entity", "name": "Node1", "updated_at": 123, "created_at": 123}),
(
2,
{
"pydantic_type": "DocumentChunk",
"type": "DocumentChunk",
"name": "Node2",
"updated_at": 123,
"created_at": 123,

View file

@ -110,7 +110,7 @@ async def get_context_with_simple_rag(instance: dict) -> str:
await cognify_instance(instance)
vector_engine = get_vector_engine()
found_chunks = await vector_engine.search("document_chunk_text", instance["question"], limit=5)
found_chunks = await vector_engine.search("DocumentChunk_text", instance["question"], limit=5)
search_results_str = "\n".join([context_item.payload["text"] for context_item in found_chunks])

View file

@ -797,7 +797,7 @@
"from cognee.infrastructure.databases.vector import get_vector_engine\n",
"\n",
"vector_engine = get_vector_engine()\n",
"results = await search(vector_engine, \"entity_name\", \"sarah.nguyen@example.com\")\n",
"results = await search(vector_engine, \"Entity_name\", \"sarah.nguyen@example.com\")\n",
"for result in results:\n",
" print(result)"
]
@ -827,7 +827,7 @@
"source": [
"from cognee.api.v1.search import SearchType\n",
"\n",
"node = (await vector_engine.search(\"entity_name\", \"sarah.nguyen@example.com\"))[0]\n",
"node = (await vector_engine.search(\"Entity_name\", \"sarah.nguyen@example.com\"))[0]\n",
"node_name = node.payload[\"text\"]\n",
"\n",
"search_results = await cognee.search(query_type=SearchType.SUMMARIES, query_text = node_name)\n",
@ -1237,7 +1237,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "cognee-c83GrcRT-py3.11",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
@ -1251,7 +1251,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
"version": "3.11.8"
}
},
"nbformat": 4,

View file

@ -533,7 +533,7 @@
"from cognee.infrastructure.databases.vector import get_vector_engine\n",
"\n",
"vector_engine = get_vector_engine()\n",
"results = await search(vector_engine, \"entity_name\", \"sarah.nguyen@example.com\")\n",
"results = await search(vector_engine, \"Entity_name\", \"sarah.nguyen@example.com\")\n",
"for result in results:\n",
" print(result)"
]
@ -563,7 +563,7 @@
"source": [
"from cognee.api.v1.search import SearchType\n",
"\n",
"node = (await vector_engine.search(\"entity_name\", \"sarah.nguyen@example.com\"))[0]\n",
"node = (await vector_engine.search(\"Entity_name\", \"sarah.nguyen@example.com\"))[0]\n",
"node_name = node.payload[\"text\"]\n",
"\n",
"search_results = await cognee.search(query_type=SearchType.SUMMARIES, query_text=node_name)\n",

1299
poetry.lock generated

File diff suppressed because it is too large.

View file

@ -80,6 +80,7 @@ nltk = "3.9.1"
google-generativeai = {version = "^0.8.4", optional = true}
parso = {version = "^0.8.4", optional = true}
jedi = {version = "^0.19.2", optional = true}
mistral-common = {version = "^1.5.2", optional = true}
[tool.poetry.extras]
@ -93,6 +94,7 @@ langchain = ["langsmith", "langchain_text_splitters"]
llama-index = ["llama-index-core"]
gemini = ["google-generativeai"]
huggingface = ["transformers"]
mistral = ["mistral-common"]
deepeval = ["deepeval"]
posthog = ["posthog"]
falkordb = ["falkordb"]