fix: Resolve issues with data label PR, add tests and upgrade migration
This commit is contained in:
parent
56b03c89f3
commit
b77961b0f1
5 changed files with 117 additions and 7 deletions
25
.github/workflows/e2e_tests.yml
vendored
25
.github/workflows/e2e_tests.yml
vendored
|
|
@ -315,6 +315,31 @@ jobs:
|
|||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: uv run python ./cognee/tests/test_multi_tenancy.py
|
||||
|
||||
test-data-label:
|
||||
name: Test adding of label for data in Cognee
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Cognee Setup
|
||||
uses: ./.github/actions/cognee_setup
|
||||
with:
|
||||
python-version: '3.11.x'
|
||||
|
||||
- name: Run custom data label test
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: uv run python ./cognee/tests/test_custom_data_label.py
|
||||
|
||||
test-graph-edges:
|
||||
name: Test graph edge ingestion
|
||||
runs-on: ubuntu-22.04
|
||||
|
|
|
|||
|
|
@ -13,15 +13,26 @@ import sqlalchemy as sa
|
|||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "a1b2c3d4e5f6"
|
||||
down_revision: Union[str, None] = "211ab850ef3d"
|
||||
down_revision: Union[str, None] = "46a6ce2bd2b2"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def _get_column(inspector, table, name, schema=None):
|
||||
for col in inspector.get_columns(table, schema=schema):
|
||||
if col["name"] == name:
|
||||
return col
|
||||
return None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"data",
|
||||
sa.Column("label", sa.String(), nullable=True)
|
||||
)
|
||||
conn = op.get_bind()
|
||||
insp = sa.inspect(conn)
|
||||
|
||||
label_column = _get_column(insp, "data", "label")
|
||||
if not label_column:
|
||||
op.add_column("data", sa.Column("label", sa.String(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("data", "label")
|
||||
op.drop_column("data", "label")
|
||||
|
|
|
|||
|
|
@ -10,13 +10,14 @@ from cognee.modules.pipelines.layers.reset_dataset_pipeline_run_status import (
|
|||
)
|
||||
from cognee.modules.engine.operations.setup import setup
|
||||
from cognee.tasks.ingestion import ingest_data, resolve_data_directories
|
||||
from cognee.tasks.ingestion.data_item import DataItem
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
async def add(
|
||||
data: Union[BinaryIO, list[BinaryIO], str, list[str]],
|
||||
data: Union[BinaryIO, list[BinaryIO], str, list[str], DataItem, list[DataItem]],
|
||||
dataset_name: str = "main_dataset",
|
||||
user: User = None,
|
||||
node_set: Optional[List[str]] = None,
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from cognee.shared.logging_utils import get_logger
|
|||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from cognee.tasks.web_scraper.utils import fetch_page_content
|
||||
from cognee.tasks.ingestion.data_item import DataItem
|
||||
|
||||
|
||||
logger = get_logger()
|
||||
|
|
@ -95,5 +96,9 @@ async def save_data_item_to_storage(data_item: Union[BinaryIO, str, Any]) -> str
|
|||
# data is text, save it to data storage and return the file path
|
||||
return await save_data_to_file(data_item)
|
||||
|
||||
if isinstance(data_item, DataItem):
|
||||
# If instance is DataItem use the underlying data
|
||||
return await save_data_item_to_storage(data_item.data)
|
||||
|
||||
# data is not a supported type
|
||||
raise IngestionError(message=f"Data type not supported: {type(data_item)}")
|
||||
|
|
|
|||
68
cognee/tests/test_custom_data_label.py
Normal file
68
cognee/tests/test_custom_data_label.py
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
import asyncio
|
||||
import cognee
|
||||
from cognee.shared.logging_utils import setup_logging, ERROR
|
||||
from cognee.api.v1.search import SearchType
|
||||
|
||||
|
||||
async def main():
|
||||
# Create a clean slate for cognee -- reset data and system state
|
||||
print("Resetting cognee data...")
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
print("Data reset complete.\n")
|
||||
|
||||
# cognee knowledge graph will be created based on this text
|
||||
text = """
|
||||
Natural language processing (NLP) is an interdisciplinary
|
||||
subfield of computer science and information retrieval.
|
||||
"""
|
||||
from cognee.tasks.ingestion.data_item import DataItem
|
||||
|
||||
test_item = DataItem(text, "test_item")
|
||||
# Add the text, and make it available for cognify
|
||||
await cognee.add(test_item)
|
||||
|
||||
# Use LLMs and cognee to create knowledge graph
|
||||
ret_val = await cognee.cognify()
|
||||
|
||||
query_text = "Tell me about NLP"
|
||||
print(f"Searching cognee for insights with query: '{query_text}'")
|
||||
# Query cognee for insights on the added text
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION, query_text=query_text
|
||||
)
|
||||
|
||||
print("Search results:")
|
||||
# Display results
|
||||
for result_text in search_results:
|
||||
print(result_text)
|
||||
|
||||
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
|
||||
|
||||
for pipeline in ret_val.values():
|
||||
dataset_id = pipeline.dataset_id
|
||||
|
||||
dataset_data = await get_dataset_data(dataset_id=dataset_id)
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
|
||||
data = [
|
||||
dict(
|
||||
**jsonable_encoder(data),
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
for data in dataset_data
|
||||
]
|
||||
|
||||
# Check if label is properly added and stored
|
||||
assert data[0]["label"] == "test_item"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger = setup_logging(log_level=ERROR)
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(main())
|
||||
finally:
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
Loading…
Add table
Reference in a new issue