fix: update test for custom model
This commit is contained in:
parent
d720abee01
commit
75d705463f
2 changed files with 103 additions and 43 deletions
|
|
@@ -1,14 +1,8 @@
|
|||
import os
|
||||
import pathlib
|
||||
import cognee
|
||||
from cognee.modules.search.operations import get_history
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.low_level import DataPoint
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
|
||||
async def main():
|
||||
data_directory_path = str(
|
||||
|
|
@@ -53,48 +47,114 @@ async def main():
|
|||
)
|
||||
|
||||
await cognee.add(text)
|
||||
|
||||
await cognee.cognify(graph_model=ProgrammingLanguage)
|
||||
|
||||
graph_file_path = str(
|
||||
pathlib.Path(
|
||||
os.path.join(
|
||||
pathlib.Path(__file__).parent,
|
||||
".artifacts/test_custom_model/graph_visualization.html",
|
||||
)
|
||||
).resolve()
|
||||
)
|
||||
await cognee.visualize_graph(graph_file_path)
|
||||
|
||||
await cognee.visualize_graph(destination_file_path="cognee/tests/test_custom_model.html")
|
||||
|
||||
# Completion query that uses graph data to form context.
|
||||
completion = await cognee.search(SearchType.GRAPH_COMPLETION, "What is python?")
|
||||
assert len(completion) != 0, "Graph completion search didn't return any result."
|
||||
print("Graph completion result is:")
|
||||
print(completion)
|
||||
graph_engine = await get_graph_engine()
|
||||
|
||||
graph_db_provider = os.getenv("GRAPH_DATABASE_PROVIDER", "kuzu").lower()
|
||||
|
||||
# Query for Python entity and verify it exists with correct type
|
||||
python_found = False
|
||||
python_type = None
|
||||
|
||||
if graph_db_provider in ["neo4j", "neptune", "neptune_analytics"]:
|
||||
query = """
|
||||
MATCH (n)
|
||||
WHERE n.name = 'Python'
|
||||
RETURN n.name as name, n.type as type
|
||||
"""
|
||||
results = await graph_engine.query(query)
|
||||
if results:
|
||||
python_found = True
|
||||
python_type = results[0]["type"]
|
||||
|
||||
elif graph_db_provider == "kuzu":
|
||||
query = """
|
||||
MATCH (n:Node)
|
||||
WHERE n.name = 'Python'
|
||||
RETURN n.name, n.type
|
||||
"""
|
||||
results = await graph_engine.query(query)
|
||||
if results:
|
||||
python_found = True
|
||||
python_type = results[0][1]
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported graph database provider: {graph_db_provider}")
|
||||
|
||||
# Completion query that uses document chunks to form context.
|
||||
completion = await cognee.search(SearchType.RAG_COMPLETION, "What is Python?")
|
||||
assert len(completion) != 0, "Completion search didn't return any result."
|
||||
print("Completion result is:")
|
||||
print(completion)
|
||||
assert python_found, "Python entity was not extracted from the text"
|
||||
assert python_type == "ProgrammingLanguage", f"Python entity has incorrect type: {python_type}, expected: ProgrammingLanguage"
|
||||
|
||||
# Query for entities that should NOT exist (Guido van Rossum and 1991)
|
||||
guido_found = False
|
||||
year_1991_found = False
|
||||
|
||||
if graph_db_provider in ["neo4j", "neptune", "neptune_analytics"]:
|
||||
query = """
|
||||
MATCH (n)
|
||||
WHERE n.name IN ['Guido van Rossum', '1991']
|
||||
RETURN n.name as name
|
||||
"""
|
||||
results = await graph_engine.query(query)
|
||||
for result in results:
|
||||
if result["name"] == "Guido van Rossum":
|
||||
guido_found = True
|
||||
elif result["name"] == "1991":
|
||||
year_1991_found = True
|
||||
|
||||
elif graph_db_provider == "kuzu":
|
||||
query = """
|
||||
MATCH (n:Node)
|
||||
WHERE n.name IN ['Guido van Rossum', '1991']
|
||||
RETURN n.name
|
||||
"""
|
||||
results = await graph_engine.query(query)
|
||||
for result in results:
|
||||
if result[0] == "Guido van Rossum":
|
||||
guido_found = True
|
||||
elif result[0] == "1991":
|
||||
year_1991_found = True
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported graph database provider: {graph_db_provider}")
|
||||
|
||||
# Query all summaries related to the search text.
|
||||
summaries = await cognee.search(SearchType.SUMMARIES, "Python")
|
||||
assert len(summaries) != 0, "Summaries search didn't return any results."
|
||||
print("Summary results are:")
|
||||
for summary in summaries:
|
||||
print(summary)
|
||||
|
||||
assert not guido_found, "Guido van Rossum should not be extracted as it's not in the custom graph model"
|
||||
assert not year_1991_found, "1991 should not be extracted as it's not in the custom graph model"
|
||||
|
||||
# Query for Field entities that might have been extracted (data analysis, web development, machine learning)
|
||||
field_entities = []
|
||||
|
||||
if graph_db_provider in ["neo4j", "neptune", "neptune_analytics"]:
|
||||
query = """
|
||||
MATCH (n)
|
||||
WHERE n.type = 'Field'
|
||||
RETURN n.name as name
|
||||
"""
|
||||
results = await graph_engine.query(query)
|
||||
field_entities = [r["name"] for r in results]
|
||||
|
||||
elif graph_db_provider == "kuzu":
|
||||
query = """
|
||||
MATCH (n:Node)
|
||||
WHERE n.type = 'Field'
|
||||
RETURN n.name
|
||||
"""
|
||||
results = await graph_engine.query(query)
|
||||
field_entities = [r[0] for r in results if r[0]]
|
||||
else:
|
||||
raise ValueError(f"Unsupported graph database provider: {graph_db_provider}")
|
||||
|
||||
chunks = await cognee.search(SearchType.CHUNKS, query_text="Python")
|
||||
assert len(chunks) != 0, "Chunks search didn't return any results."
|
||||
print("Chunk results are:")
|
||||
for chunk in chunks:
|
||||
print(chunk)
|
||||
|
||||
user = await get_default_user()
|
||||
history = await get_history(user.id)
|
||||
|
||||
assert len(history) == 8, "Search history is not correct."
|
||||
assert len(field_entities) > 0, f"No Field entities were extracted. Expected fields like 'data analysis', 'web development', 'machine learning' but got: {field_entities}"
|
||||
|
||||
expected_fields = ["data analysis", "web development", "machine learning"]
|
||||
found_expected_fields = [f for f in expected_fields if any(f in field.lower() for field in field_entities)]
|
||||
|
||||
assert len(found_expected_fields) > 0, f"None of the expected Field entities were found. Expected at least one of {expected_fields}, but got: {field_entities}"
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue