chore: Format

This commit is contained in:
Andrej Milicevic 2025-10-15 10:08:40 +02:00
parent ac5fe4761b
commit 6cb54c94f1

View file

@ -17,27 +17,27 @@ from cognee.infrastructure.files.utils.open_data_file import open_data_file
async def extract_graphs(document_chunks):
    """
    Extract a graph from every chunk and check that the expected entities are present.

    ``extract_content_graph`` is awaited concurrently for all chunks via
    ``asyncio.gather``. The check passes only if each of the terms "qubit",
    "algorithm" and "superposition" occurs as a case-insensitive substring of
    at least one node name across all of the extracted graphs.

    Args:
        document_chunks: Iterable of chunk objects; each must expose a ``text``
            attribute, which is passed to ``extract_content_graph``.

    Returns:
        bool: True when every expected term is found in some node name,
        False otherwise (including when there are no chunks at all).
    """
    extraction_results = await asyncio.gather(
        *[extract_content_graph(chunk.text, KnowledgeGraph) for chunk in document_chunks]
    )

    # Lower-case every node name once, rather than once per searched term.
    node_names = [
        node.name.lower()
        for extraction_result in extraction_results
        for node in extraction_result.nodes
    ]

    return all(
        any(term in node_name for node_name in node_names)
        for term in ("qubit", "algorithm", "superposition")
    )
async def main():
"""
Test how well the entity extraction works. Repeat graph generation a few times.
If 80% or more graphs are correctly generated, the test passes.
Test how well the entity extraction works. Repeat graph generation a few times.
If 80% or more graphs are correctly generated, the test passes.
"""
file_path = os.path.join(
@ -47,7 +47,6 @@ async def main():
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await cognee.add("NLP is a subfield of computer science.")
original_file_path = await save_data_item_to_storage(file_path)
@ -66,31 +65,25 @@ async def main():
mime_type="text/plain",
name="quantum_text",
raw_data_location=file_path,
external_metadata=None
external_metadata=None,
)
document_chunks = []
async for chunk in extract_chunks_from_documents(
[text_document],
max_chunk_size=get_max_chunk_tokens(),
chunker=TextChunker
[text_document], max_chunk_size=get_max_chunk_tokens(), chunker=TextChunker
):
document_chunks.append(chunk)
document_chunks.append(chunk)
number_of_reps = 5
graph_results = await asyncio.gather(
*[
extract_graphs(document_chunks)
for _ in range(number_of_reps)
]
*[extract_graphs(document_chunks) for _ in range(number_of_reps)]
)
correct_graphs = [result for result in graph_results if result]
assert len(correct_graphs) >= 0.8 * number_of_reps
# Script entry point: run the async entity-extraction test exactly once.
if __name__ == "__main__":
    asyncio.run(main())