fix: update test for custom model
This commit is contained in:
parent
d720abee01
commit
75d705463f
2 changed files with 103 additions and 43 deletions
|
|
@ -1,14 +1,8 @@
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import cognee
|
import cognee
|
||||||
from cognee.modules.search.operations import get_history
|
|
||||||
from cognee.modules.users.methods import get_default_user
|
|
||||||
from cognee.shared.logging_utils import get_logger
|
|
||||||
from cognee.modules.search.types import SearchType
|
|
||||||
from cognee.low_level import DataPoint
|
from cognee.low_level import DataPoint
|
||||||
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
logger = get_logger()
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
data_directory_path = str(
|
data_directory_path = str(
|
||||||
|
|
@ -53,48 +47,114 @@ async def main():
|
||||||
)
|
)
|
||||||
|
|
||||||
await cognee.add(text)
|
await cognee.add(text)
|
||||||
|
|
||||||
await cognee.cognify(graph_model=ProgrammingLanguage)
|
await cognee.cognify(graph_model=ProgrammingLanguage)
|
||||||
|
|
||||||
graph_file_path = str(
|
|
||||||
pathlib.Path(
|
await cognee.visualize_graph(destination_file_path="cognee/tests/test_custom_model.html")
|
||||||
os.path.join(
|
|
||||||
pathlib.Path(__file__).parent,
|
|
||||||
".artifacts/test_custom_model/graph_visualization.html",
|
|
||||||
)
|
|
||||||
).resolve()
|
|
||||||
)
|
|
||||||
await cognee.visualize_graph(graph_file_path)
|
|
||||||
|
|
||||||
# Completion query that uses graph data to form context.
|
graph_engine = await get_graph_engine()
|
||||||
completion = await cognee.search(SearchType.GRAPH_COMPLETION, "What is python?")
|
|
||||||
assert len(completion) != 0, "Graph completion search didn't return any result."
|
graph_db_provider = os.getenv("GRAPH_DATABASE_PROVIDER", "kuzu").lower()
|
||||||
print("Graph completion result is:")
|
|
||||||
print(completion)
|
# Query for Python entity and verify it exists with correct type
|
||||||
|
python_found = False
|
||||||
|
python_type = None
|
||||||
|
|
||||||
|
if graph_db_provider in ["neo4j", "neptune", "neptune_analytics"]:
|
||||||
|
query = """
|
||||||
|
MATCH (n)
|
||||||
|
WHERE n.name = 'Python'
|
||||||
|
RETURN n.name as name, n.type as type
|
||||||
|
"""
|
||||||
|
results = await graph_engine.query(query)
|
||||||
|
if results:
|
||||||
|
python_found = True
|
||||||
|
python_type = results[0]["type"]
|
||||||
|
|
||||||
|
elif graph_db_provider == "kuzu":
|
||||||
|
query = """
|
||||||
|
MATCH (n:Node)
|
||||||
|
WHERE n.name = 'Python'
|
||||||
|
RETURN n.name, n.type
|
||||||
|
"""
|
||||||
|
results = await graph_engine.query(query)
|
||||||
|
if results:
|
||||||
|
python_found = True
|
||||||
|
python_type = results[0][1]
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported graph database provider: {graph_db_provider}")
|
||||||
|
|
||||||
# Completion query that uses document chunks to form context.
|
assert python_found, "Python entity was not extracted from the text"
|
||||||
completion = await cognee.search(SearchType.RAG_COMPLETION, "What is Python?")
|
assert python_type == "ProgrammingLanguage", f"Python entity has incorrect type: {python_type}, expected: ProgrammingLanguage"
|
||||||
assert len(completion) != 0, "Completion search didn't return any result."
|
|
||||||
print("Completion result is:")
|
# Query for entities that should NOT exist (Guido van Rossum and 1991)
|
||||||
print(completion)
|
guido_found = False
|
||||||
|
year_1991_found = False
|
||||||
|
|
||||||
|
if graph_db_provider in ["neo4j", "neptune", "neptune_analytics"]:
|
||||||
|
query = """
|
||||||
|
MATCH (n)
|
||||||
|
WHERE n.name IN ['Guido van Rossum', '1991']
|
||||||
|
RETURN n.name as name
|
||||||
|
"""
|
||||||
|
results = await graph_engine.query(query)
|
||||||
|
for result in results:
|
||||||
|
if result["name"] == "Guido van Rossum":
|
||||||
|
guido_found = True
|
||||||
|
elif result["name"] == "1991":
|
||||||
|
year_1991_found = True
|
||||||
|
|
||||||
|
elif graph_db_provider == "kuzu":
|
||||||
|
query = """
|
||||||
|
MATCH (n:Node)
|
||||||
|
WHERE n.name IN ['Guido van Rossum', '1991']
|
||||||
|
RETURN n.name
|
||||||
|
"""
|
||||||
|
results = await graph_engine.query(query)
|
||||||
|
for result in results:
|
||||||
|
if result[0] == "Guido van Rossum":
|
||||||
|
guido_found = True
|
||||||
|
elif result[0] == "1991":
|
||||||
|
year_1991_found = True
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported graph database provider: {graph_db_provider}")
|
||||||
|
|
||||||
# Query all summaries related to query.
|
|
||||||
summaries = await cognee.search(SearchType.SUMMARIES, "Python")
|
assert not guido_found, "Guido van Rossum should not be extracted as it's not in the custom graph model"
|
||||||
assert len(summaries) != 0, "Summaries search didn't return any results."
|
assert not year_1991_found, "1991 should not be extracted as it's not in the custom graph model"
|
||||||
print("Summary results are:")
|
|
||||||
for summary in summaries:
|
# Query for Field entities that might have been extracted (data analysis, web development, machine learning)
|
||||||
print(summary)
|
field_entities = []
|
||||||
|
|
||||||
|
if graph_db_provider in ["neo4j", "neptune", "neptune_analytics"]:
|
||||||
|
query = """
|
||||||
|
MATCH (n)
|
||||||
|
WHERE n.type = 'Field'
|
||||||
|
RETURN n.name as name
|
||||||
|
"""
|
||||||
|
results = await graph_engine.query(query)
|
||||||
|
field_entities = [r["name"] for r in results]
|
||||||
|
|
||||||
|
elif graph_db_provider == "kuzu":
|
||||||
|
query = """
|
||||||
|
MATCH (n:Node)
|
||||||
|
WHERE n.type = 'Field'
|
||||||
|
RETURN n.name
|
||||||
|
"""
|
||||||
|
results = await graph_engine.query(query)
|
||||||
|
field_entities = [r[0] for r in results if r[0]]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported graph database provider: {graph_db_provider}")
|
||||||
|
|
||||||
chunks = await cognee.search(SearchType.CHUNKS, query_text="Python")
|
assert len(field_entities) > 0, f"No Field entities were extracted. Expected fields like 'data analysis', 'web development', 'machine learning' but got: {field_entities}"
|
||||||
assert len(chunks) != 0, "Chunks search didn't return any results."
|
|
||||||
print("Chunk results are:")
|
expected_fields = ["data analysis", "web development", "machine learning"]
|
||||||
for chunk in chunks:
|
found_expected_fields = [f for f in expected_fields if any(f in field.lower() for field in field_entities)]
|
||||||
print(chunk)
|
|
||||||
|
assert len(found_expected_fields) > 0, f"None of the expected Field entities were found. Expected at least one of {expected_fields}, but got: {field_entities}"
|
||||||
user = await get_default_user()
|
|
||||||
history = await get_history(user.id)
|
|
||||||
|
|
||||||
assert len(history) == 8, "Search history is not correct."
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue