diff --git a/.github/workflows/e2e_tests.yml b/.github/workflows/e2e_tests.yml index 305e21218..9548ef493 100644 --- a/.github/workflows/e2e_tests.yml +++ b/.github/workflows/e2e_tests.yml @@ -1,6 +1,4 @@ name: Reusable Integration Tests -permissions: - contents: read on: workflow_call: @@ -267,8 +265,6 @@ jobs: EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} run: uv run python ./cognee/tests/test_edge_ingestion.py - - run_concurrent_subprocess_access_test: name: Concurrent Subprocess access test runs-on: ubuntu-latest @@ -332,50 +328,24 @@ jobs: DB_PASSWORD: cognee run: uv run python ./cognee/tests/test_concurrent_subprocess_access.py - run_conversation_sessions_test: - name: Conversation sessions test - runs-on: ubuntu-latest - defaults: - run: - shell: bash - services: - postgres: - image: pgvector/pgvector:pg17 - env: - POSTGRES_USER: cognee - POSTGRES_PASSWORD: cognee - POSTGRES_DB: cognee_db - options: >- - --health-cmd pg_isready - --health-interval 10s - --health-timeout 5s - --health-retries 5 - ports: - - 5432:5432 - - redis: - image: redis:7 - ports: - - 6379:6379 - options: >- - --health-cmd "redis-cli ping" - --health-interval 5s - --health-timeout 3s - --health-retries 5 - + test-entity-extraction: + name: Test Entity Extraction + runs-on: ubuntu-22.04 steps: - - name: Checkout repository + - name: Check out repository uses: actions/checkout@v4 - name: Cognee Setup uses: ./.github/actions/cognee_setup with: python-version: '3.11.x' - extra-dependencies: "postgres redis" - - name: Run Conversation session tests + - name: Dependencies already installed + run: echo "Dependencies already installed in setup" + + - name: Run Entity Extraction Test env: - ENV: dev + ENV: 'dev' LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} LLM_API_KEY: ${{ secrets.LLM_API_KEY }} @@ -384,12 +354,4 @@ jobs: EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} - GRAPH_DATABASE_PROVIDER: 'kuzu' - CACHING: true - DB_PROVIDER: 'postgres' - DB_NAME: 'cognee_db' - DB_HOST: '127.0.0.1' - DB_PORT: 5432 - DB_USERNAME: cognee - DB_PASSWORD: cognee - run: uv run python ./cognee/tests/test_conversation_history.py \ No newline at end of file + run: uv run python ./cognee/tests/tasks/entity_extraction/entity_extraction_test.py diff --git a/cognee/tests/tasks/entity_extraction/entity_extraction_test.py b/cognee/tests/tasks/entity_extraction/entity_extraction_test.py new file mode 100644 index 000000000..39e883e09 --- /dev/null +++ b/cognee/tests/tasks/entity_extraction/entity_extraction_test.py @@ -0,0 +1,89 @@ +import os +import pathlib +import asyncio + +import cognee +import cognee.modules.ingestion as ingestion +from cognee.infrastructure.llm import get_max_chunk_tokens +from cognee.infrastructure.llm.extraction import extract_content_graph +from cognee.modules.chunking.TextChunker import TextChunker +from cognee.modules.data.processing.document_types import TextDocument +from cognee.modules.users.methods import get_default_user +from cognee.shared.data_models import KnowledgeGraph +from cognee.tasks.documents import extract_chunks_from_documents +from cognee.tasks.ingestion import save_data_item_to_storage +from cognee.infrastructure.files.utils.open_data_file import open_data_file + + +async def extract_graphs(document_chunks): + """ + Extract graph, and check if entities are present + """ + + extraction_results = await asyncio.gather( + *[extract_content_graph(chunk.text, KnowledgeGraph) for chunk in document_chunks] + ) + + return all( + any( + term in node.name.lower() + for extraction_result in extraction_results + for node in extraction_result.nodes + ) + for term in ("qubit", "algorithm", "superposition") + ) + + +async def main(): + """ + Test how well the entity extraction works. Repeat graph generation a few times. + If 80% or more graphs are correctly generated, the test passes. + """ + + file_path = os.path.join( + pathlib.Path(__file__).parent.parent.parent, "test_data/Quantum_computers.txt" + ) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + await cognee.add("NLP is a subfield of computer science.") + + original_file_path = await save_data_item_to_storage(file_path) + + async with open_data_file(original_file_path) as file: + classified_data = ingestion.classify(file) + + # data_id is the hash of original file contents + owner id to avoid duplicate data + data_id = ingestion.identify(classified_data, await get_default_user()) + + await cognee.add(file_path) + + text_document = TextDocument( + id=data_id, + type="text", + mime_type="text/plain", + name="quantum_text", + raw_data_location=file_path, + external_metadata=None, + ) + + document_chunks = [] + async for chunk in extract_chunks_from_documents( + [text_document], max_chunk_size=get_max_chunk_tokens(), chunker=TextChunker + ): + document_chunks.append(chunk) + + number_of_reps = 5 + + graph_results = await asyncio.gather( + *[extract_graphs(document_chunks) for _ in range(number_of_reps)] + ) + + correct_graphs = [result for result in graph_results if result] + + assert len(correct_graphs) >= 0.8 * number_of_reps + + +if __name__ == "__main__": + asyncio.run(main())