test: cog 3168/add entity extraction tests (#1572)
<!-- .github/pull_request_template.md --> ## Description <!-- Please provide a clear, human-generated description of the changes in this PR. DO NOT use AI-generated descriptions. We want to understand your thought process and reasoning. --> Added a test that checks whether graphs are correctly generated and whether the entities we expect are actually present. It could be improved with a bigger input file and more assertions, depending on how heavyweight we want the test to be. ## Type of Change <!-- Please check the relevant option --> - [ ] Bug fix (non-breaking change that fixes an issue) - [ ] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [ ] Code refactoring - [ ] Performance improvement - [ ] Other (please specify): ## Screenshots/Videos (if applicable) <!-- Add screenshots or videos to help explain your changes --> ## Pre-submission Checklist <!-- Please check all boxes that apply before submitting your PR --> - [ ] **I have tested my changes thoroughly before submitting this PR** - [ ] **This PR contains minimal changes necessary to address the issue/feature** - [ ] My code follows the project's coding standards and style guidelines - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I have added necessary documentation (if applicable) - [ ] All new and existing tests pass - [ ] I have searched existing PRs to ensure this change hasn't been submitted already - [ ] I have linked any relevant issues in the description - [ ] My commits have clear and descriptive messages ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
commit
1e458b29b6
2 changed files with 118 additions and 5 deletions
34
.github/workflows/e2e_tests.yml
vendored
34
.github/workflows/e2e_tests.yml
vendored
|
|
@ -1,6 +1,4 @@
|
|||
name: Reusable Integration Tests
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
|
|
@ -267,8 +265,6 @@ jobs:
|
|||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: uv run python ./cognee/tests/test_edge_ingestion.py
|
||||
|
||||
|
||||
|
||||
run_concurrent_subprocess_access_test:
|
||||
name: Concurrent Subprocess access test
|
||||
runs-on: ubuntu-latest
|
||||
|
|
@ -330,4 +326,32 @@ jobs:
|
|||
DB_PORT: 5432
|
||||
DB_USERNAME: cognee
|
||||
DB_PASSWORD: cognee
|
||||
run: uv run python ./cognee/tests/test_concurrent_subprocess_access.py
|
||||
run: uv run python ./cognee/tests/test_concurrent_subprocess_access.py
|
||||
|
||||
test-entity-extraction:
|
||||
name: Test Entity Extraction
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Cognee Setup
|
||||
uses: ./.github/actions/cognee_setup
|
||||
with:
|
||||
python-version: '3.11.x'
|
||||
|
||||
- name: Dependencies already installed
|
||||
run: echo "Dependencies already installed in setup"
|
||||
|
||||
- name: Run Entity Extraction Test
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: uv run python ./cognee/tests/tasks/entity_extraction/entity_extraction_test.py
|
||||
|
|
|
|||
|
|
@ -0,0 +1,89 @@
|
|||
import os
|
||||
import pathlib
|
||||
import asyncio
|
||||
|
||||
import cognee
|
||||
import cognee.modules.ingestion as ingestion
|
||||
from cognee.infrastructure.llm import get_max_chunk_tokens
|
||||
from cognee.infrastructure.llm.extraction import extract_content_graph
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from cognee.modules.data.processing.document_types import TextDocument
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
from cognee.tasks.documents import extract_chunks_from_documents
|
||||
from cognee.tasks.ingestion import save_data_item_to_storage
|
||||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||
|
||||
|
||||
async def extract_graphs(document_chunks):
|
||||
"""
|
||||
Extract graph, and check if entities are present
|
||||
"""
|
||||
|
||||
extraction_results = await asyncio.gather(
|
||||
*[extract_content_graph(chunk.text, KnowledgeGraph) for chunk in document_chunks]
|
||||
)
|
||||
|
||||
return all(
|
||||
any(
|
||||
term in node.name.lower()
|
||||
for extraction_result in extraction_results
|
||||
for node in extraction_result.nodes
|
||||
)
|
||||
for term in ("qubit", "algorithm", "superposition")
|
||||
)
|
||||
|
||||
|
||||
async def main():
|
||||
"""
|
||||
Test how well the entity extraction works. Repeat graph generation a few times.
|
||||
If 80% or more graphs are correctly generated, the test passes.
|
||||
"""
|
||||
|
||||
file_path = os.path.join(
|
||||
pathlib.Path(__file__).parent.parent.parent, "test_data/Quantum_computers.txt"
|
||||
)
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
await cognee.add("NLP is a subfield of computer science.")
|
||||
|
||||
original_file_path = await save_data_item_to_storage(file_path)
|
||||
|
||||
async with open_data_file(original_file_path) as file:
|
||||
classified_data = ingestion.classify(file)
|
||||
|
||||
# data_id is the hash of original file contents + owner id to avoid duplicate data
|
||||
data_id = ingestion.identify(classified_data, await get_default_user())
|
||||
|
||||
await cognee.add(file_path)
|
||||
|
||||
text_document = TextDocument(
|
||||
id=data_id,
|
||||
type="text",
|
||||
mime_type="text/plain",
|
||||
name="quantum_text",
|
||||
raw_data_location=file_path,
|
||||
external_metadata=None,
|
||||
)
|
||||
|
||||
document_chunks = []
|
||||
async for chunk in extract_chunks_from_documents(
|
||||
[text_document], max_chunk_size=get_max_chunk_tokens(), chunker=TextChunker
|
||||
):
|
||||
document_chunks.append(chunk)
|
||||
|
||||
number_of_reps = 5
|
||||
|
||||
graph_results = await asyncio.gather(
|
||||
*[extract_graphs(document_chunks) for _ in range(number_of_reps)]
|
||||
)
|
||||
|
||||
correct_graphs = [result for result in graph_results if result]
|
||||
|
||||
assert len(correct_graphs) >= 0.8 * number_of_reps
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Loading…
Add table
Reference in a new issue