chore: Format
This commit is contained in:
parent
ac5fe4761b
commit
6cb54c94f1
1 changed file with 17 additions and 24 deletions
|
|
@@ -17,27 +17,27 @@ from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||||
|
|
||||||
async def extract_graphs(document_chunks):
|
async def extract_graphs(document_chunks):
|
||||||
"""
|
"""
|
||||||
Extract graph, and check if entities are present
|
Extract graph, and check if entities are present
|
||||||
"""
|
"""
|
||||||
|
|
||||||
extraction_results = await asyncio.gather(
|
extraction_results = await asyncio.gather(
|
||||||
*[
|
*[extract_content_graph(chunk.text, KnowledgeGraph) for chunk in document_chunks]
|
||||||
extract_content_graph(chunk.text, KnowledgeGraph)
|
|
||||||
for chunk in document_chunks
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return all(
|
return all(
|
||||||
any(term in node.name.lower()
|
any(
|
||||||
for extraction_result in extraction_results
|
term in node.name.lower()
|
||||||
for node in extraction_result.nodes)
|
for extraction_result in extraction_results
|
||||||
for term in ("qubit", "algorithm", "superposition")
|
for node in extraction_result.nodes
|
||||||
)
|
)
|
||||||
|
for term in ("qubit", "algorithm", "superposition")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
"""
|
"""
|
||||||
Test how well the entity extraction works. Repeat graph generation a few times.
|
Test how well the entity extraction works. Repeat graph generation a few times.
|
||||||
If 80% or more graphs are correctly generated, the test passes.
|
If 80% or more graphs are correctly generated, the test passes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
file_path = os.path.join(
|
file_path = os.path.join(
|
||||||
|
|
@@ -47,7 +47,6 @@ async def main():
|
||||||
await cognee.prune.prune_data()
|
await cognee.prune.prune_data()
|
||||||
await cognee.prune.prune_system(metadata=True)
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
|
||||||
|
|
||||||
await cognee.add("NLP is a subfield of computer science.")
|
await cognee.add("NLP is a subfield of computer science.")
|
||||||
|
|
||||||
original_file_path = await save_data_item_to_storage(file_path)
|
original_file_path = await save_data_item_to_storage(file_path)
|
||||||
|
|
@@ -66,31 +65,25 @@ async def main():
|
||||||
mime_type="text/plain",
|
mime_type="text/plain",
|
||||||
name="quantum_text",
|
name="quantum_text",
|
||||||
raw_data_location=file_path,
|
raw_data_location=file_path,
|
||||||
external_metadata=None
|
external_metadata=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
document_chunks = []
|
document_chunks = []
|
||||||
async for chunk in extract_chunks_from_documents(
|
async for chunk in extract_chunks_from_documents(
|
||||||
[text_document],
|
[text_document], max_chunk_size=get_max_chunk_tokens(), chunker=TextChunker
|
||||||
max_chunk_size=get_max_chunk_tokens(),
|
|
||||||
chunker=TextChunker
|
|
||||||
):
|
):
|
||||||
document_chunks.append(chunk)
|
document_chunks.append(chunk)
|
||||||
|
|
||||||
|
|
||||||
number_of_reps = 5
|
number_of_reps = 5
|
||||||
|
|
||||||
graph_results = await asyncio.gather(
|
graph_results = await asyncio.gather(
|
||||||
*[
|
*[extract_graphs(document_chunks) for _ in range(number_of_reps)]
|
||||||
extract_graphs(document_chunks)
|
|
||||||
for _ in range(number_of_reps)
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
correct_graphs = [result for result in graph_results if result]
|
correct_graphs = [result for result in graph_results if result]
|
||||||
|
|
||||||
assert len(correct_graphs) >= 0.8 * number_of_reps
|
assert len(correct_graphs) >= 0.8 * number_of_reps
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue