added check
parent a06b3fc7e4
commit b7cf8f2f3c
1 changed file with 286 additions and 1 deletion
@@ -8,6 +8,280 @@ from cognee.modules.search.types import SearchType
from cognee.modules.users.methods import get_default_user, create_user
from cognee.modules.users.permissions.methods import authorized_give_permission_on_datasets


async def test_knowledge_graph_quality_with_gpt4o():
    """
    Verify that all main concepts and entities from a specific document are found
    in the knowledge graph, using the GPT-4o model for high-quality entity extraction.

    This test addresses the issue that HotPotQA questions may not reflect a decline
    in knowledge-graph quality after data model changes.
    """

    # Configure GPT-4o for best quality
    os.environ["LLM_MODEL"] = "gpt-4o"
    cognee.config.set_llm_model("gpt-4o")

    # Ensure we have an API key
    if not os.environ.get("LLM_API_KEY"):
        raise ValueError("LLM_API_KEY must be set for this test")

    # Set up test directories
    data_directory_path = str(
        pathlib.Path(
            os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_kg_quality")
        ).resolve()
    )
    cognee_directory_path = str(
        pathlib.Path(
            os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_kg_quality")
        ).resolve()
    )

    cognee.config.data_root_directory(data_directory_path)
    cognee.config.system_root_directory(cognee_directory_path)

    # Clean up before starting
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    # Get the test document path
    test_document_path = os.path.join(
        pathlib.Path(__file__).parent, "test_data/Natural_language_processing.txt"
    )

    # Expected entities and concepts from the NLP document
    expected_entities = [
        "Natural language processing",
        "NLP",
        "computer science",
        "information retrieval",
        "machine learning",
        "neural network",
        "speech recognition",
        "natural-language understanding",
        "natural-language generation",
        "theoretical linguistics",
        "text corpora",
        "speech corpora",
        "statistical approaches",
        "probabilistic approaches",
        "rule-based approaches",
        "documents",
        "language",
        "computers",
    ]

    expected_concepts = [
        "NLP is a subfield of computer science",
        "NLP is interdisciplinary",
        "NLP involves processing natural language datasets",
        "NLP uses machine learning approaches",
        "NLP borrows ideas from theoretical linguistics",
        "NLP can extract information from documents",
        "NLP can categorize and organize documents",
        "NLP involves speech recognition",
        "NLP involves natural-language understanding",
        "NLP involves natural-language generation",
        "computers can understand document contents",
        "neural networks are used in NLP",
        "statistical approaches are used in NLP",
    ]
print("=" * 80)
|
||||
print("KNOWLEDGE GRAPH QUALITY TEST WITH GPT-4o")
|
||||
print("=" * 80)
|
||||
print(f"Using model: {os.environ.get('LLM_MODEL', 'gpt-4o')}")
|
||||
print(f"Test document: {test_document_path}")
|
||||
print()
|
||||
|
||||
# Add and process the document
|
||||
print("Adding document to cognee...")
|
||||
await cognee.add([test_document_path], dataset_name="NLP_TEST")
|
||||
|
||||
user = await get_default_user()
|
||||
|
||||
print("Processing document with cognify...")
|
||||
await cognee.cognify(["NLP_TEST"], user=user)
|
||||
print("Document processing completed.")
|
||||
print()
|
||||
|
||||
# Test different search types to find entities and concepts
|
||||
search_types_to_test = [
|
||||
(SearchType.INSIGHTS, "Get entity relationships and connections"),
|
||||
(SearchType.GRAPH_COMPLETION, "Natural language completion with graph context"),
|
||||
(SearchType.CHUNKS, "Find relevant document chunks"),
|
||||
(SearchType.SUMMARIES, "Get content summaries"),
|
||||
]
|
||||
|
||||
all_found_results = {}
|
||||
|
||||

    for search_type, description in search_types_to_test:
        print(f"Testing {search_type.value} search - {description}")
        print("-" * 60)

        # Search for entities
        entity_results = await cognee.search(
            query_type=search_type,
            query_text="What are the main entities, concepts, and terms mentioned in this document?",
            user=user,
            top_k=20,
        )

        # Search for relationships
        relationship_results = await cognee.search(
            query_type=search_type,
            query_text="What are the key relationships and connections between concepts in this document?",
            user=user,
            top_k=20,
        )

        all_found_results[search_type.value] = {
            "entities": entity_results,
            "relationships": relationship_results,
        }

        print(f"Entity search results ({len(entity_results)} items):")
        for i, result in enumerate(entity_results[:3]):  # Show first 3 results
            print(f"  {i + 1}. {result}")

        print(f"Relationship search results ({len(relationship_results)} items):")
        for i, result in enumerate(relationship_results[:3]):  # Show first 3 results
            print(f"  {i + 1}. {result}")
        print()

    # Analyze results and check for expected entities and concepts
    print("ANALYSIS: Expected vs Found")
    print("=" * 80)

    # Combine all results into a single text for analysis
    all_results_text = ""
    for search_type, results in all_found_results.items():
        for result_type, result_list in results.items():
            all_results_text += f" {' '.join(str(r) for r in result_list)}"

    all_results_text = all_results_text.lower()
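
    # Matching note: an entity counts as found if it appears verbatim or with
    # hyphens and spaces swapped, e.g. "natural-language understanding" also
    # matches "natural language understanding" in the combined results text.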
print("ENTITY ANALYSIS:")
|
||||
print("-" * 40)
|
||||
found_entities = []
|
||||
missing_entities = []
|
||||
|
||||
for entity in expected_entities:
|
||||
entity_lower = entity.lower()
|
||||
# Check if entity or its variations are found
|
||||
if (
|
||||
entity_lower in all_results_text
|
||||
or entity_lower.replace("-", " ") in all_results_text
|
||||
or entity_lower.replace(" ", "-") in all_results_text
|
||||
):
|
||||
found_entities.append(entity)
|
||||
print(f"✓ FOUND: {entity}")
|
||||
else:
|
||||
missing_entities.append(entity)
|
||||
print(f"✗ MISSING: {entity}")
|
||||
|
||||
print()
|
||||
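
    # Worked example for the heuristic below: "NLP is a subfield of computer
    # science" yields the key words ["nlp", "subfield", "computer", "science"];
    # if three of the four appear in the results text, coverage is 75%, which
    # clears the 60% threshold.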
print("CONCEPT ANALYSIS:")
|
||||
print("-" * 40)
|
||||
found_concepts = []
|
||||
missing_concepts = []
|
||||
|
||||
for concept in expected_concepts:
|
||||
concept_lower = concept.lower()
|
||||
# Check if key parts of the concept are found
|
||||
concept_words = concept_lower.split()
|
||||
key_words = [
|
||||
word
|
||||
for word in concept_words
|
||||
if len(word) > 2
|
||||
and word not in ["the", "and", "are", "can", "involves", "uses", "from"]
|
||||
]
|
||||
|
||||
if len(key_words) > 0:
|
||||
found_key_words = sum(1 for word in key_words if word in all_results_text)
|
||||
coverage = found_key_words / len(key_words)
|
||||
|
||||
if coverage >= 0.6: # At least 60% of key words found
|
||||
found_concepts.append(concept)
|
||||
print(f"✓ FOUND: {concept} (coverage: {coverage:.1%})")
|
||||
else:
|
||||
missing_concepts.append(concept)
|
||||
print(f"✗ MISSING: {concept} (coverage: {coverage:.1%})")
|
||||
else:
|
||||
missing_concepts.append(concept)
|
||||
print(f"✗ MISSING: {concept} (no key words)")
|
||||
|
||||

    print()
    print("SUMMARY:")
    print("=" * 40)
    print(f"Expected entities: {len(expected_entities)}")
    print(f"Found entities: {len(found_entities)}")
    print(f"Missing entities: {len(missing_entities)}")
    print(f"Entity coverage: {len(found_entities) / len(expected_entities):.1%}")
    print()
    print(f"Expected concepts: {len(expected_concepts)}")
    print(f"Found concepts: {len(found_concepts)}")
    print(f"Missing concepts: {len(missing_concepts)}")
    print(f"Concept coverage: {len(found_concepts) / len(expected_concepts):.1%}")
    print()

    # Test assertions
    entity_coverage = len(found_entities) / len(expected_entities)
    concept_coverage = len(found_concepts) / len(expected_concepts)

    print("QUALITY ASSESSMENT:")
    print("-" * 40)

    # We expect high coverage with GPT-4o
    min_entity_coverage = 0.70  # At least 70% of entities should be found
    min_concept_coverage = 0.60  # At least 60% of concepts should be found

    if entity_coverage >= min_entity_coverage:
        print(
            f"✓ PASS: Entity coverage ({entity_coverage:.1%}) meets minimum requirement ({min_entity_coverage:.1%})"
        )
    else:
        print(
            f"✗ FAIL: Entity coverage ({entity_coverage:.1%}) below minimum requirement ({min_entity_coverage:.1%})"
        )

    if concept_coverage >= min_concept_coverage:
        print(
            f"✓ PASS: Concept coverage ({concept_coverage:.1%}) meets minimum requirement ({min_concept_coverage:.1%})"
        )
    else:
        print(
            f"✗ FAIL: Concept coverage ({concept_coverage:.1%}) below minimum requirement ({min_concept_coverage:.1%})"
        )
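
    # Example: an entity coverage of 0.78 and a concept coverage of 0.69
    # average to an overall quality score of 0.735, reported below as 73.5%.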
    overall_quality = (entity_coverage + concept_coverage) / 2
    print(f"Overall quality score: {overall_quality:.1%}")

    # Assert that we have acceptable quality
    assert entity_coverage >= min_entity_coverage, (
        f"Entity coverage {entity_coverage:.1%} below minimum {min_entity_coverage:.1%}"
    )
    assert concept_coverage >= min_concept_coverage, (
        f"Concept coverage {concept_coverage:.1%} below minimum {min_concept_coverage:.1%}"
    )

    print()
    print("=" * 80)
    print("KNOWLEDGE GRAPH QUALITY TEST COMPLETED SUCCESSFULLY")
    print("=" * 80)

    return {
        "entity_coverage": entity_coverage,
        "concept_coverage": concept_coverage,
        "overall_quality": overall_quality,
        "found_entities": found_entities,
        "missing_entities": missing_entities,
        "found_concepts": found_concepts,
        "missing_concepts": missing_concepts,
    }

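# Illustrative way to drive the coroutine and consume its returned metrics
# (the keys are those of the return dict above):
#
#     results = asyncio.run(test_knowledge_graph_quality_with_gpt4o())
#     print(results["overall_quality"], results["missing_entities"])
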
logger = get_logger()

@@ -197,7 +471,18 @@ async def main():
    await cognee.delete([explanation_file_path], dataset_id=test_user_dataset_id, user=default_user)


async def main_quality_test():
    """Main function to run the knowledge graph quality test."""
    await test_knowledge_graph_quality_with_gpt4o()


if __name__ == "__main__":
    import asyncio
    import sys

    if len(sys.argv) > 1 and sys.argv[1] == "quality":
        print("Running Knowledge Graph Quality Test...")
        asyncio.run(main_quality_test())
    else:
        print("Running Permissions Test...")
        asyncio.run(main())
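
# Invocation (the module filename is a placeholder):
#     python <this_test_file>.py            runs the permissions test
#     python <this_test_file>.py quality    runs the knowledge graph quality test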