fix: update test for custom model

This commit is contained in:
Hande 2025-11-04 13:43:59 +01:00
parent d720abee01
commit 75d705463f
2 changed files with 103 additions and 43 deletions

View file

@ -1,14 +1,8 @@
import os
import pathlib
import cognee
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
from cognee.low_level import DataPoint
logger = get_logger()
from cognee.infrastructure.databases.graph import get_graph_engine
async def main():
data_directory_path = str(
@ -53,48 +47,114 @@ async def main():
)
await cognee.add(text)
await cognee.cognify(graph_model=ProgrammingLanguage)
graph_file_path = str(
pathlib.Path(
os.path.join(
pathlib.Path(__file__).parent,
".artifacts/test_custom_model/graph_visualization.html",
)
).resolve()
)
await cognee.visualize_graph(graph_file_path)
await cognee.visualize_graph(destination_file_path="cognee/tests/test_custom_model.html")
# Completion query that uses graph data to form context.
completion = await cognee.search(SearchType.GRAPH_COMPLETION, "What is python?")
assert len(completion) != 0, "Graph completion search didn't return any result."
print("Graph completion result is:")
print(completion)
graph_engine = await get_graph_engine()
graph_db_provider = os.getenv("GRAPH_DATABASE_PROVIDER", "kuzu").lower()
# Query for Python entity and verify it exists with correct type
python_found = False
python_type = None
if graph_db_provider in ["neo4j", "neptune", "neptune_analytics"]:
query = """
MATCH (n)
WHERE n.name = 'Python'
RETURN n.name as name, n.type as type
"""
results = await graph_engine.query(query)
if results:
python_found = True
python_type = results[0]["type"]
elif graph_db_provider == "kuzu":
query = """
MATCH (n:Node)
WHERE n.name = 'Python'
RETURN n.name, n.type
"""
results = await graph_engine.query(query)
if results:
python_found = True
python_type = results[0][1]
else:
raise ValueError(f"Unsupported graph database provider: {graph_db_provider}")
# Completion query that uses document chunks to form context.
completion = await cognee.search(SearchType.RAG_COMPLETION, "What is Python?")
assert len(completion) != 0, "Completion search didn't return any result."
print("Completion result is:")
print(completion)
assert python_found, "Python entity was not extracted from the text"
assert python_type == "ProgrammingLanguage", f"Python entity has incorrect type: {python_type}, expected: ProgrammingLanguage"
# Query for entities that should NOT exist (Guido van Rossum and 1991)
guido_found = False
year_1991_found = False
if graph_db_provider in ["neo4j", "neptune", "neptune_analytics"]:
query = """
MATCH (n)
WHERE n.name IN ['Guido van Rossum', '1991']
RETURN n.name as name
"""
results = await graph_engine.query(query)
for result in results:
if result["name"] == "Guido van Rossum":
guido_found = True
elif result["name"] == "1991":
year_1991_found = True
elif graph_db_provider == "kuzu":
query = """
MATCH (n:Node)
WHERE n.name IN ['Guido van Rossum', '1991']
RETURN n.name
"""
results = await graph_engine.query(query)
for result in results:
if result[0] == "Guido van Rossum":
guido_found = True
elif result[0] == "1991":
year_1991_found = True
else:
raise ValueError(f"Unsupported graph database provider: {graph_db_provider}")
# Query all summaries related to query.
summaries = await cognee.search(SearchType.SUMMARIES, "Python")
assert len(summaries) != 0, "Summaries search didn't return any results."
print("Summary results are:")
for summary in summaries:
print(summary)
assert not guido_found, "Guido van Rossum should not be extracted as it's not in the custom graph model"
assert not year_1991_found, "1991 should not be extracted as it's not in the custom graph model"
# Query for Field entities that might have been extracted (data analysis, web development, machine learning)
field_entities = []
if graph_db_provider in ["neo4j", "neptune", "neptune_analytics"]:
query = """
MATCH (n)
WHERE n.type = 'Field'
RETURN n.name as name
"""
results = await graph_engine.query(query)
field_entities = [r["name"] for r in results]
elif graph_db_provider == "kuzu":
query = """
MATCH (n:Node)
WHERE n.type = 'Field'
RETURN n.name
"""
results = await graph_engine.query(query)
field_entities = [r[0] for r in results if r[0]]
else:
raise ValueError(f"Unsupported graph database provider: {graph_db_provider}")
chunks = await cognee.search(SearchType.CHUNKS, query_text="Python")
assert len(chunks) != 0, "Chunks search didn't return any results."
print("Chunk results are:")
for chunk in chunks:
print(chunk)
user = await get_default_user()
history = await get_history(user.id)
assert len(history) == 8, "Search history is not correct."
assert len(field_entities) > 0, f"No Field entities were extracted. Expected fields like 'data analysis', 'web development', 'machine learning' but got: {field_entities}"
expected_fields = ["data analysis", "web development", "machine learning"]
found_expected_fields = [f for f in expected_fields if any(f in field.lower() for field in field_entities)]
assert len(found_expected_fields) > 0, f"None of the expected Field entities were found. Expected at least one of {expected_fields}, but got: {field_entities}"
if __name__ == "__main__":