diff --git a/cognee/tests/test_custom_model.py b/cognee/tests/test_custom_model.py index cdf41d605..56b4697ca 100755 --- a/cognee/tests/test_custom_model.py +++ b/cognee/tests/test_custom_model.py @@ -1,14 +1,8 @@ import os import pathlib import cognee -from cognee.modules.search.operations import get_history -from cognee.modules.users.methods import get_default_user -from cognee.shared.logging_utils import get_logger -from cognee.modules.search.types import SearchType from cognee.low_level import DataPoint - -logger = get_logger() - +from cognee.infrastructure.databases.graph import get_graph_engine async def main(): data_directory_path = str( @@ -53,48 +47,114 @@ async def main(): ) await cognee.add(text) - await cognee.cognify(graph_model=ProgrammingLanguage) - graph_file_path = str( - pathlib.Path( - os.path.join( - pathlib.Path(__file__).parent, - ".artifacts/test_custom_model/graph_visualization.html", - ) - ).resolve() - ) - await cognee.visualize_graph(graph_file_path) + + await cognee.visualize_graph(destination_file_path="cognee/tests/test_custom_model.html") - # Completion query that uses graph data to form context. - completion = await cognee.search(SearchType.GRAPH_COMPLETION, "What is python?") - assert len(completion) != 0, "Graph completion search didn't return any result." - print("Graph completion result is:") - print(completion) + graph_engine = await get_graph_engine() + + graph_db_provider = os.getenv("GRAPH_DATABASE_PROVIDER", "kuzu").lower() + + # Query for Python entity and verify it exists with correct type + python_found = False + python_type = None + + if graph_db_provider in ["neo4j", "neptune", "neptune_analytics"]: + query = """ + MATCH (n) + WHERE n.name = 'Python' + RETURN n.name as name, n.type as type + """ + results = await graph_engine.query(query) + if results: + python_found = True + python_type = results[0]["type"] + + elif graph_db_provider == "kuzu": + query = """ + MATCH (n:Node) + WHERE n.name = 'Python' + RETURN n.name, n.type + """ + results = await graph_engine.query(query) + if results: + python_found = True + python_type = results[0][1] + + else: + raise ValueError(f"Unsupported graph database provider: {graph_db_provider}") - # Completion query that uses document chunks to form context. - completion = await cognee.search(SearchType.RAG_COMPLETION, "What is Python?") - assert len(completion) != 0, "Completion search didn't return any result." - print("Completion result is:") - print(completion) + assert python_found, "Python entity was not extracted from the text" + assert python_type == "ProgrammingLanguage", f"Python entity has incorrect type: {python_type}, expected: ProgrammingLanguage" + + # Query for entities that should NOT exist (Guido van Rossum and 1991) + guido_found = False + year_1991_found = False + + if graph_db_provider in ["neo4j", "neptune", "neptune_analytics"]: + query = """ + MATCH (n) + WHERE n.name IN ['Guido van Rossum', '1991'] + RETURN n.name as name + """ + results = await graph_engine.query(query) + for result in results: + if result["name"] == "Guido van Rossum": + guido_found = True + elif result["name"] == "1991": + year_1991_found = True + + elif graph_db_provider == "kuzu": + query = """ + MATCH (n:Node) + WHERE n.name IN ['Guido van Rossum', '1991'] + RETURN n.name + """ + results = await graph_engine.query(query) + for result in results: + if result[0] == "Guido van Rossum": + guido_found = True + elif result[0] == "1991": + year_1991_found = True + + else: + raise ValueError(f"Unsupported graph database provider: {graph_db_provider}") - # Query all summaries related to query. - summaries = await cognee.search(SearchType.SUMMARIES, "Python") - assert len(summaries) != 0, "Summaries search didn't return any results." - print("Summary results are:") - for summary in summaries: - print(summary) + + assert not guido_found, "Guido van Rossum should not be extracted as it's not in the custom graph model" + assert not year_1991_found, "1991 should not be extracted as it's not in the custom graph model" + + # Query for Field entities that might have been extracted (data analysis, web development, machine learning) + field_entities = [] + + if graph_db_provider in ["neo4j", "neptune", "neptune_analytics"]: + query = """ + MATCH (n) + WHERE n.type = 'Field' + RETURN n.name as name + """ + results = await graph_engine.query(query) + field_entities = [r["name"] for r in results] + + elif graph_db_provider == "kuzu": + query = """ + MATCH (n:Node) + WHERE n.type = 'Field' + RETURN n.name + """ + results = await graph_engine.query(query) + field_entities = [r[0] for r in results if r[0]] + else: + raise ValueError(f"Unsupported graph database provider: {graph_db_provider}") - chunks = await cognee.search(SearchType.CHUNKS, query_text="Python") - assert len(chunks) != 0, "Chunks search didn't return any results." - print("Chunk results are:") - for chunk in chunks: - print(chunk) - - user = await get_default_user() - history = await get_history(user.id) - - assert len(history) == 8, "Search history is not correct." + assert len(field_entities) > 0, f"No Field entities were extracted. Expected fields like 'data analysis', 'web development', 'machine learning' but got: {field_entities}" + + expected_fields = ["data analysis", "web development", "machine learning"] + found_expected_fields = [f for f in expected_fields if any(f in field.lower() for field in field_entities)] + + assert len(found_expected_fields) > 0, f"None of the expected Field entities were found. Expected at least one of {expected_fields}, but got: {field_entities}" + if __name__ == "__main__": diff --git a/examples/python/custom-graph-model-example.py b/examples/python/custom_graph_model_example.py similarity index 100% rename from examples/python/custom-graph-model-example.py rename to examples/python/custom_graph_model_example.py