Merge branch 'dev' into feature/web_scraping_connector_task

This commit is contained in: commit 4e5c681e62
30 changed files with 719 additions and 12821 deletions
.github/workflows/db_examples_tests.yml (10 changes, vendored)

@@ -54,6 +54,10 @@ jobs:
        with:
          python-version: ${{ inputs.python-version }}

+     - name: Setup Neo4j with GDS
+       uses: ./.github/actions/setup_neo4j
+       id: neo4j
+
      - name: Run Neo4j Example
        env:
          ENV: dev
@@ -66,9 +70,9 @@ jobs:
          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
          GRAPH_DATABASE_PROVIDER: "neo4j"
-         GRAPH_DATABASE_URL: ${{ secrets.NEO4J_API_URL }}
-         GRAPH_DATABASE_USERNAME: "neo4j"
-         GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }}
+         GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
+         GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
+         GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
        run: |
          uv run python examples/database_examples/neo4j_example.py

.github/workflows/distributed_test.yml (new file, 73 additions, vendored)

@@ -0,0 +1,73 @@
name: Distributed Cognee test with modal
permissions:
  contents: read
on:
  workflow_call:
    inputs:
      python-version:
        required: false
        type: string
        default: '3.11.x'
    secrets:
      LLM_MODEL:
        required: true
      LLM_ENDPOINT:
        required: true
      LLM_API_KEY:
        required: true
      LLM_API_VERSION:
        required: true
      EMBEDDING_MODEL:
        required: true
      EMBEDDING_ENDPOINT:
        required: true
      EMBEDDING_API_KEY:
        required: true
      EMBEDDING_API_VERSION:
        required: true
      OPENAI_API_KEY:
        required: true

jobs:
  run-server-start-test:
    name: Distributed Cognee test (Modal)
    runs-on: ubuntu-22.04
    steps:
      - name: Check out
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Cognee Setup
        uses: ./.github/actions/cognee_setup
        with:
          python-version: '3.11.x'
          extra-dependencies: "distributed postgres"

      - name: Run Distributed Cognee (Modal)
        env:
          ENV: 'dev'
          LLM_MODEL: ${{ secrets.LLM_MODEL }}
          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
          MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
          MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
          MODAL_SECRET_NAME: ${{ secrets.MODAL_SECRET_NAME }}
          GRAPH_DATABASE_PROVIDER: "neo4j"
          GRAPH_DATABASE_URL: ${{ secrets.AZURE_NEO4j_URL }}
          GRAPH_DATABASE_USERNAME: ${{ secrets.AZURE_NEO4J_USERNAME }}
          GRAPH_DATABASE_PASSWORD: ${{ secrets.AZURE_NEO4J_PW }}
          DB_PROVIDER: "postgres"
          DB_NAME: ${{ secrets.AZURE_POSTGRES_DB_NAME }}
          DB_HOST: ${{ secrets.AZURE_POSTGRES_HOST }}
          DB_PORT: ${{ secrets.AZURE_POSTGRES_PORT }}
          DB_USERNAME: ${{ secrets.AZURE_POSTGRES_USERNAME }}
          DB_PASSWORD: ${{ secrets.AZURE_POSTGRES_PW }}
          VECTOR_DB_PROVIDER: "pgvector"
          COGNEE_DISTRIBUTED: "true"
        run: uv run modal run ./distributed/entrypoint.py

.github/workflows/graph_db_tests.yml (10 changes, vendored)

@@ -71,6 +71,10 @@ jobs:
        with:
          python-version: ${{ inputs.python-version }}

+     - name: Setup Neo4j with GDS
+       uses: ./.github/actions/setup_neo4j
+       id: neo4j
+
      - name: Run default Neo4j
        env:
          ENV: 'dev'
@@ -83,9 +87,9 @@ jobs:
          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
          GRAPH_DATABASE_PROVIDER: "neo4j"
-         GRAPH_DATABASE_URL: ${{ secrets.NEO4J_API_URL }}
-         GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }}
-         GRAPH_DATABASE_USERNAME: "neo4j"
+         GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
+         GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
+         GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
        run: uv run python ./cognee/tests/test_neo4j.py

      - name: Run Weighted Edges Tests with Neo4j

@@ -186,6 +186,10 @@ jobs:
          python-version: '3.11.x'
          extra-dependencies: "postgres"

+     - name: Setup Neo4j with GDS
+       uses: ./.github/actions/setup_neo4j
+       id: neo4j
+
      - name: Install specific db dependency
        run: echo "Dependencies already installed in setup"

@@ -206,9 +210,9 @@ jobs:
        env:
          ENV: 'dev'
          GRAPH_DATABASE_PROVIDER: "neo4j"
-         GRAPH_DATABASE_URL: ${{ secrets.NEO4J_API_URL }}
-         GRAPH_DATABASE_PASSWORD: ${{ secrets.NEO4J_API_KEY }}
-         GRAPH_DATABASE_USERNAME: "neo4j"
+         GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
+         GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
+         GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}

          LLM_PROVIDER: openai
          LLM_MODEL: ${{ secrets.LLM_MODEL }}

.github/workflows/search_db_tests.yml (47 changes, vendored)

@@ -51,20 +51,6 @@ jobs:
    name: Search test for Neo4j/LanceDB/Sqlite
    runs-on: ubuntu-22.04
    if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/lance/sqlite') }}
-   services:
-     neo4j:
-       image: neo4j:5.11
-       env:
-         NEO4J_AUTH: neo4j/pleaseletmein
-         NEO4J_PLUGINS: '["apoc","graph-data-science"]'
-       ports:
-         - 7474:7474
-         - 7687:7687
-       options: >-
-         --health-cmd="cypher-shell -u neo4j -p pleaseletmein 'RETURN 1'"
-         --health-interval=10s
-         --health-timeout=5s
-         --health-retries=5

    steps:
      - name: Check out
@@ -77,6 +63,10 @@ jobs:
        with:
          python-version: ${{ inputs.python-version }}

+     - name: Setup Neo4j with GDS
+       uses: ./.github/actions/setup_neo4j
+       id: neo4j
+
      - name: Dependencies already installed
        run: echo "Dependencies already installed in setup"

@@ -94,9 +84,9 @@ jobs:
          GRAPH_DATABASE_PROVIDER: 'neo4j'
          VECTOR_DB_PROVIDER: 'lancedb'
          DB_PROVIDER: 'sqlite'
-         GRAPH_DATABASE_URL: bolt://localhost:7687
-         GRAPH_DATABASE_USERNAME: neo4j
-         GRAPH_DATABASE_PASSWORD: pleaseletmein
+         GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
+         GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
+         GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
        run: uv run python ./cognee/tests/test_search_db.py

  run-kuzu-pgvector-postgres-search-tests:

@@ -158,19 +148,6 @@ jobs:
    runs-on: ubuntu-22.04
    if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/pgvector/postgres') }}
    services:
-     neo4j:
-       image: neo4j:5.11
-       env:
-         NEO4J_AUTH: neo4j/pleaseletmein
-         NEO4J_PLUGINS: '["apoc","graph-data-science"]'
-       ports:
-         - 7474:7474
-         - 7687:7687
-       options: >-
-         --health-cmd="cypher-shell -u neo4j -p pleaseletmein 'RETURN 1'"
-         --health-interval=10s
-         --health-timeout=5s
-         --health-retries=5
      postgres:
        image: pgvector/pgvector:pg17
        env:
@@ -196,6 +173,10 @@ jobs:
          python-version: ${{ inputs.python-version }}
          extra-dependencies: "postgres"

+     - name: Setup Neo4j with GDS
+       uses: ./.github/actions/setup_neo4j
+       id: neo4j
+
      - name: Dependencies already installed
        run: echo "Dependencies already installed in setup"

@@ -213,9 +194,9 @@ jobs:
          GRAPH_DATABASE_PROVIDER: 'neo4j'
          VECTOR_DB_PROVIDER: 'pgvector'
          DB_PROVIDER: 'postgres'
-         GRAPH_DATABASE_URL: bolt://localhost:7687
-         GRAPH_DATABASE_USERNAME: neo4j
-         GRAPH_DATABASE_PASSWORD: pleaseletmein
+         GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
+         GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
+         GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
          DB_NAME: cognee_db
          DB_HOST: 127.0.0.1
          DB_PORT: 5432

.github/workflows/temporal_graph_tests.yml (24 changes, vendored)

@@ -51,20 +51,6 @@ jobs:
    name: Temporal Graph test Neo4j (lancedb + sqlite)
    runs-on: ubuntu-22.04
    if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/lance/sqlite') }}
-   services:
-     neo4j:
-       image: neo4j:5.11
-       env:
-         NEO4J_AUTH: neo4j/pleaseletmein
-         NEO4J_PLUGINS: '["apoc","graph-data-science"]'
-       ports:
-         - 7474:7474
-         - 7687:7687
-       options: >-
-         --health-cmd="cypher-shell -u neo4j -p pleaseletmein 'RETURN 1'"
-         --health-interval=10s
-         --health-timeout=5s
-         --health-retries=5

    steps:
      - name: Check out
@@ -77,6 +63,10 @@ jobs:
        with:
          python-version: ${{ inputs.python-version }}

+     - name: Setup Neo4j with GDS
+       uses: ./.github/actions/setup_neo4j
+       id: neo4j
+
      - name: Dependencies already installed
        run: echo "Dependencies already installed in setup"

@@ -94,9 +84,9 @@ jobs:
          GRAPH_DATABASE_PROVIDER: 'neo4j'
          VECTOR_DB_PROVIDER: 'lancedb'
          DB_PROVIDER: 'sqlite'
-         GRAPH_DATABASE_URL: bolt://localhost:7687
-         GRAPH_DATABASE_USERNAME: neo4j
-         GRAPH_DATABASE_PASSWORD: pleaseletmein
+         GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
+         GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
+         GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
        run: uv run python ./cognee/tests/test_temporal_graph.py

  run_temporal_graph_kuzu_postgres_pgvector:

.github/workflows/test_suites.yml (8 changes, vendored)

@@ -27,6 +27,12 @@ jobs:
    uses: ./.github/workflows/e2e_tests.yml
    secrets: inherit

+ distributed-tests:
+   name: Distributed Cognee Test
+   needs: [ basic-tests, e2e-tests, graph-db-tests ]
+   uses: ./.github/workflows/distributed_test.yml
+   secrets: inherit
+
  cli-tests:
    name: CLI Tests
    uses: ./.github/workflows/cli_tests.yml
@@ -104,7 +110,7 @@ jobs:

  db-examples-tests:
    name: DB Examples Tests
-   needs: [vector-db-tests, graph-db-tests, relational-db-migration-tests]
+   needs: [vector-db-tests, graph-db-tests, relational-db-migration-tests, distributed-tests]
    uses: ./.github/workflows/db_examples_tests.yml
    secrets: inherit

.github/workflows/weighted_edges_tests.yml (7 changes, vendored)

@@ -86,12 +86,19 @@ jobs:
        with:
          python-version: '3.11'

+     - name: Setup Neo4j with GDS
+       uses: ./.github/actions/setup_neo4j
+       id: neo4j
+
      - name: Dependencies already installed
        run: echo "Dependencies already installed in setup"

      - name: Run Weighted Edges Tests
        env:
          GRAPH_DATABASE_PROVIDER: ${{ matrix.graph_db_provider }}
+         GRAPH_DATABASE_URL: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-url || '' }}
+         GRAPH_DATABASE_USERNAME: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-username || '' }}
+         GRAPH_DATABASE_PASSWORD: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-password || '' }}
        run: |
          uv run pytest cognee/tests/unit/interfaces/graph/test_weighted_edges.py -v --tb=short

@@ -217,10 +217,24 @@ export default function GraphVisualization({ ref, data, graphControls, className

  const [graphShape, setGraphShape] = useState<string>();

+ const zoomToFit: ForceGraphMethods["zoomToFit"] = (
+   durationMs?: number,
+   padding?: number,
+   nodeFilter?: (node: NodeObject) => boolean
+ ) => {
+   if (!graphRef.current) {
+     console.warn("GraphVisualization: graphRef not ready yet");
+     return undefined as any;
+   }
+
+   return graphRef.current.zoomToFit?.(durationMs, padding, nodeFilter);
+ };
+
  useImperativeHandle(ref, () => ({
-   zoomToFit: graphRef.current!.zoomToFit,
-   setGraphShape: setGraphShape,
+   zoomToFit,
+   setGraphShape,
  }));

  return (
    <div ref={containerRef} className={classNames("w-full h-full", className)} id="graph-container">

@@ -117,5 +117,4 @@ async def add_rule_associations(data: str, rules_nodeset_name: str):

    if len(edges_to_save) > 0:
        await graph_engine.add_edges(edges_to_save)
-       await index_graph_edges()
+       await index_graph_edges(edges_to_save)

@@ -68,6 +68,7 @@ class Neo4jAdapter(GraphDBInterface):
            auth=auth,
            max_connection_lifetime=120,
            notifications_min_severity="OFF",
+           keep_alive=True,
        )

    async def initialize(self) -> None:
@@ -205,7 +206,7 @@ class Neo4jAdapter(GraphDBInterface):
            {
                "node_id": str(node.id),
                "label": type(node).__name__,
-               "properties": self.serialize_properties(node.model_dump()),
+               "properties": self.serialize_properties(dict(node)),
            }
            for node in nodes
        ]

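The Neo4jAdapter hunk above turns on TCP keep-alives for idle Bolt connections and serializes nodes with dict(node) instead of node.model_dump(). A minimal sketch of the same driver configuration, assuming the official neo4j Python driver; the URL and credentials here are placeholders, not values from this repository, and option availability depends on the driver version:

    # Sketch only: async Neo4j driver configured with the options the hunk sets.
    from neo4j import AsyncGraphDatabase

    driver = AsyncGraphDatabase.driver(
        "bolt://localhost:7687",       # placeholder URL
        auth=("neo4j", "password"),    # placeholder credentials
        max_connection_lifetime=120,   # recycle Bolt connections after 120 s
        keep_alive=True,               # keep idle Bolt sockets alive at the TCP level
    )
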
@@ -8,7 +8,7 @@ from cognee.infrastructure.utils.calculate_backoff import calculate_backoff
logger = get_logger("deadlock_retry")


-def deadlock_retry(max_retries=5):
+def deadlock_retry(max_retries=10):
    """
    Decorator that automatically retries an asynchronous function when rate limit errors occur.

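deadlock_retry now allows up to 10 attempts by default. A rough, illustrative sketch of what such a retry decorator looks like; it uses an inline exponential backoff instead of cognee's calculate_backoff helper and retries every exception, which the real decorator does not:

    # Illustrative only: a retry decorator in the spirit of deadlock_retry.
    import asyncio
    from functools import wraps


    def retry_async(max_retries: int = 10, base_delay: float = 0.1):
        def decorator(func):
            @wraps(func)
            async def wrapper(*args, **kwargs):
                for attempt in range(max_retries):
                    try:
                        return await func(*args, **kwargs)
                    except Exception:
                        if attempt == max_retries - 1:
                            raise
                        await asyncio.sleep(base_delay * (2 ** attempt))
            return wrapper
        return decorator


    @retry_async(max_retries=10)
    async def flaky_write():
        return "ok"


    print(asyncio.run(flaky_write()))
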
@@ -9,7 +9,10 @@ async def get_dataset_data(dataset_id: UUID) -> list[Data]:

    async with db_engine.get_async_session() as session:
        result = await session.execute(
-           select(Data).join(Data.datasets).filter((Dataset.id == dataset_id))
+           select(Data)
+           .join(Data.datasets)
+           .filter((Dataset.id == dataset_id))
+           .order_by(Data.data_size.desc())
        )

        data = list(result.scalars().all())

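The get_dataset_data hunk above adds .order_by(Data.data_size.desc()) so larger items come back first. The same SQLAlchemy 2.0 ordering pattern on a toy model; Item is a stand-in, not a cognee class:

    # Illustrative only: descending order_by on a toy declarative model.
    from sqlalchemy import select
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


    class Base(DeclarativeBase):
        pass


    class Item(Base):  # stand-in for cognee's Data model
        __tablename__ = "items"
        id: Mapped[int] = mapped_column(primary_key=True)
        data_size: Mapped[int] = mapped_column(default=0)


    # Largest rows first, mirroring .order_by(Data.data_size.desc()) in the hunk above.
    statement = select(Item).order_by(Item.data_size.desc())
    print(statement)
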
@@ -5,7 +5,6 @@ from typing import Optional

class TableRow(DataPoint):
    name: str
-   is_a: Optional[TableType] = None
    description: str
    properties: str

@@ -4,35 +4,28 @@ import asyncio
from uuid import UUID
from typing import Any, List
from functools import wraps
-from sqlalchemy import select
-
-import cognee.modules.ingestion as ingestion
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.pipelines.operations.run_tasks_distributed import run_tasks_distributed
from cognee.modules.users.models import User
-from cognee.modules.data.models import Data
-from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.shared.logging_utils import get_logger
from cognee.modules.users.methods import get_default_user
from cognee.modules.pipelines.utils import generate_pipeline_id
from cognee.modules.pipelines.exceptions import PipelineRunFailedError
-from cognee.tasks.ingestion import save_data_item_to_storage, resolve_data_directories
+from cognee.tasks.ingestion import resolve_data_directories
from cognee.modules.pipelines.models.PipelineRunInfo import (
    PipelineRunCompleted,
    PipelineRunErrored,
    PipelineRunStarted,
-   PipelineRunYield,
-   PipelineRunAlreadyCompleted,
)
-from cognee.modules.pipelines.models.DataItemStatus import DataItemStatus

from cognee.modules.pipelines.operations import (
    log_pipeline_run_start,
    log_pipeline_run_complete,
    log_pipeline_run_error,
)
from .run_tasks_with_telemetry import run_tasks_with_telemetry
+from .run_tasks_data_item import run_tasks_data_item
from ..tasks.task import Task


@@ -68,176 +61,6 @@ async def run_tasks(
    context: dict = None,
    incremental_loading: bool = False,
):
-   (170 removed lines: the nested _run_tasks_data_item_incremental, _run_tasks_data_item_regular
-   and _run_tasks_data_item helper coroutines; their bodies now live in the new module
-   cognee/modules/pipelines/operations/run_tasks_data_item.py, shown below)
    if not user:
        user = await get_default_user()

@@ -269,7 +92,7 @@ async def run_tasks(
    # Create async tasks per data item that will run the pipeline for the data item
    data_item_tasks = [
        asyncio.create_task(
-           _run_tasks_data_item(
+           run_tasks_data_item(
                data_item,
                dataset,
                tasks,

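run_tasks now delegates per-item work to the imported run_tasks_data_item and schedules one asyncio task per data item. A self-contained sketch of that fan-out shape, with process_item standing in for the real coroutine:

    # Sketch of the per-data-item fan-out: one asyncio task per item, gathered at the end.
    import asyncio


    async def process_item(item: str) -> dict:
        await asyncio.sleep(0)  # placeholder for the real pipeline work
        return {"item": item, "status": "completed"}


    async def run_all(items: list[str]) -> list[dict]:
        tasks = [asyncio.create_task(process_item(item)) for item in items]
        return await asyncio.gather(*tasks)


    print(asyncio.run(run_all(["a.txt", "b.txt"])))
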
cognee/modules/pipelines/operations/run_tasks_data_item.py (new file, 261 additions)

@@ -0,0 +1,261 @@
"""
Data item processing functions for pipeline operations.

This module contains reusable functions for processing individual data items
within pipeline operations, supporting both incremental and regular processing modes.
"""

import os
from typing import Any, Dict, AsyncGenerator, Optional
from sqlalchemy import select

import cognee.modules.ingestion as ingestion
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.shared.logging_utils import get_logger
from cognee.modules.users.models import User
from cognee.modules.data.models import Data, Dataset
from cognee.tasks.ingestion import save_data_item_to_storage
from cognee.modules.pipelines.models.PipelineRunInfo import (
    PipelineRunCompleted,
    PipelineRunErrored,
    PipelineRunYield,
    PipelineRunAlreadyCompleted,
)
from cognee.modules.pipelines.models.DataItemStatus import DataItemStatus
from cognee.modules.pipelines.operations.run_tasks_with_telemetry import run_tasks_with_telemetry
from ..tasks.task import Task

logger = get_logger("run_tasks_data_item")


async def run_tasks_data_item_incremental(
    data_item: Any,
    dataset: Dataset,
    tasks: list[Task],
    pipeline_name: str,
    pipeline_id: str,
    pipeline_run_id: str,
    context: Optional[Dict[str, Any]],
    user: User,
) -> AsyncGenerator[Dict[str, Any], None]:
    """
    Process a single data item with incremental loading support.

    This function handles incremental processing by checking if the data item
    has already been processed for the given pipeline and dataset. If it has,
    it skips processing and returns a completion status.

    Args:
        data_item: The data item to process
        dataset: The dataset containing the data item
        tasks: List of tasks to execute on the data item
        pipeline_name: Name of the pipeline
        pipeline_id: Unique identifier for the pipeline
        pipeline_run_id: Unique identifier for this pipeline run
        context: Optional context dictionary
        user: User performing the operation

    Yields:
        Dict containing run_info and data_id for each processing step
    """
    db_engine = get_relational_engine()

    # If incremental_loading of data is set to True don't process documents already processed by pipeline
    # If data is being added to Cognee for the first time calculate the id of the data
    if not isinstance(data_item, Data):
        file_path = await save_data_item_to_storage(data_item)
        # Ingest data and add metadata
        async with open_data_file(file_path) as file:
            classified_data = ingestion.classify(file)
            # data_id is the hash of file contents + owner id to avoid duplicate data
            data_id = ingestion.identify(classified_data, user)
    else:
        # If data was already processed by Cognee get data id
        data_id = data_item.id

    # Check pipeline status, if Data already processed for pipeline before skip current processing
    async with db_engine.get_async_session() as session:
        data_point = (
            await session.execute(select(Data).filter(Data.id == data_id))
        ).scalar_one_or_none()
        if data_point:
            if (
                data_point.pipeline_status.get(pipeline_name, {}).get(str(dataset.id))
                == DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED
            ):
                yield {
                    "run_info": PipelineRunAlreadyCompleted(
                        pipeline_run_id=pipeline_run_id,
                        dataset_id=dataset.id,
                        dataset_name=dataset.name,
                    ),
                    "data_id": data_id,
                }
                return

    try:
        # Process data based on data_item and list of tasks
        async for result in run_tasks_with_telemetry(
            tasks=tasks,
            data=[data_item],
            user=user,
            pipeline_name=pipeline_id,
            context=context,
        ):
            yield PipelineRunYield(
                pipeline_run_id=pipeline_run_id,
                dataset_id=dataset.id,
                dataset_name=dataset.name,
                payload=result,
            )

        # Update pipeline status for Data element
        async with db_engine.get_async_session() as session:
            data_point = (
                await session.execute(select(Data).filter(Data.id == data_id))
            ).scalar_one_or_none()
            data_point.pipeline_status[pipeline_name] = {
                str(dataset.id): DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED
            }
            await session.merge(data_point)
            await session.commit()

        yield {
            "run_info": PipelineRunCompleted(
                pipeline_run_id=pipeline_run_id,
                dataset_id=dataset.id,
                dataset_name=dataset.name,
            ),
            "data_id": data_id,
        }

    except Exception as error:
        # Temporarily swallow error and try to process rest of documents first, then re-raise error at end of data ingestion pipeline
        logger.error(
            f"Exception caught while processing data: {error}.\n Data processing failed for data item: {data_item}."
        )
        yield {
            "run_info": PipelineRunErrored(
                pipeline_run_id=pipeline_run_id,
                payload=repr(error),
                dataset_id=dataset.id,
                dataset_name=dataset.name,
            ),
            "data_id": data_id,
        }

        if os.getenv("RAISE_INCREMENTAL_LOADING_ERRORS", "true").lower() == "true":
            raise error


async def run_tasks_data_item_regular(
    data_item: Any,
    dataset: Dataset,
    tasks: list[Task],
    pipeline_id: str,
    pipeline_run_id: str,
    context: Optional[Dict[str, Any]],
    user: User,
) -> AsyncGenerator[Dict[str, Any], None]:
    """
    Process a single data item in regular (non-incremental) mode.

    This function processes a data item without checking for previous processing
    status, executing all tasks on the data item.

    Args:
        data_item: The data item to process
        dataset: The dataset containing the data item
        tasks: List of tasks to execute on the data item
        pipeline_id: Unique identifier for the pipeline
        pipeline_run_id: Unique identifier for this pipeline run
        context: Optional context dictionary
        user: User performing the operation

    Yields:
        Dict containing run_info for each processing step
    """
    # Process data based on data_item and list of tasks
    async for result in run_tasks_with_telemetry(
        tasks=tasks,
        data=[data_item],
        user=user,
        pipeline_name=pipeline_id,
        context=context,
    ):
        yield PipelineRunYield(
            pipeline_run_id=pipeline_run_id,
            dataset_id=dataset.id,
            dataset_name=dataset.name,
            payload=result,
        )

    yield {
        "run_info": PipelineRunCompleted(
            pipeline_run_id=pipeline_run_id,
            dataset_id=dataset.id,
            dataset_name=dataset.name,
        )
    }


async def run_tasks_data_item(
    data_item: Any,
    dataset: Dataset,
    tasks: list[Task],
    pipeline_name: str,
    pipeline_id: str,
    pipeline_run_id: str,
    context: Optional[Dict[str, Any]],
    user: User,
    incremental_loading: bool,
) -> Optional[Dict[str, Any]]:
    """
    Process a single data item, choosing between incremental and regular processing.

    This is the main entry point for data item processing that delegates to either
    incremental or regular processing based on the incremental_loading flag.

    Args:
        data_item: The data item to process
        dataset: The dataset containing the data item
        tasks: List of tasks to execute on the data item
        pipeline_name: Name of the pipeline
        pipeline_id: Unique identifier for the pipeline
        pipeline_run_id: Unique identifier for this pipeline run
        context: Optional context dictionary
        user: User performing the operation
        incremental_loading: Whether to use incremental processing

    Returns:
        Dict containing the final processing result, or None if processing was skipped
    """
    # Go through async generator and return data item processing result. Result can be PipelineRunAlreadyCompleted when data item is skipped,
    # PipelineRunCompleted when processing was successful and PipelineRunErrored if there were issues
    result = None
    if incremental_loading:
        async for result in run_tasks_data_item_incremental(
            data_item=data_item,
            dataset=dataset,
            tasks=tasks,
            pipeline_name=pipeline_name,
            pipeline_id=pipeline_id,
            pipeline_run_id=pipeline_run_id,
            context=context,
            user=user,
        ):
            pass
    else:
        async for result in run_tasks_data_item_regular(
            data_item=data_item,
            dataset=dataset,
            tasks=tasks,
            pipeline_id=pipeline_id,
            pipeline_run_id=pipeline_run_id,
            context=context,
            user=user,
        ):
            pass

    return result

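A small usage sketch of the drain-the-generator pattern run_tasks_data_item relies on: intermediate yields are consumed and only the last one is kept. fake_pipeline is a stand-in for the real async generators in this module:

    # Sketch of "drain the generator, keep the last item".
    import asyncio
    from typing import AsyncGenerator, Optional


    async def fake_pipeline() -> AsyncGenerator[dict, None]:
        yield {"run_info": "started"}
        yield {"run_info": "completed", "data_id": "1234"}


    async def last_result() -> Optional[dict]:
        result = None
        async for result in fake_pipeline():
            pass  # intermediate yields are ignored
        return result


    print(asyncio.run(last_result()))  # {'run_info': 'completed', 'data_id': '1234'}
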
@@ -3,49 +3,96 @@ try:
except ModuleNotFoundError:
    modal = None

+from typing import Any, List, Optional
+from uuid import UUID
+
+from cognee.modules.pipelines.tasks.task import Task
from cognee.infrastructure.databases.relational import get_relational_engine
+from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.pipelines.models import (
    PipelineRunStarted,
-   PipelineRunYield,
    PipelineRunCompleted,
+   PipelineRunErrored,
)
-from cognee.modules.pipelines.operations import log_pipeline_run_start, log_pipeline_run_complete
-from cognee.modules.pipelines.utils.generate_pipeline_id import generate_pipeline_id
+from cognee.modules.pipelines.operations import (
+   log_pipeline_run_start,
+   log_pipeline_run_complete,
+   log_pipeline_run_error,
+)
+from cognee.modules.pipelines.utils import generate_pipeline_id
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
+from cognee.modules.users.models import User
-from .run_tasks_with_telemetry import run_tasks_with_telemetry
+from cognee.modules.pipelines.exceptions import PipelineRunFailedError
+from cognee.tasks.ingestion import resolve_data_directories
+from .run_tasks_data_item import run_tasks_data_item

logger = get_logger("run_tasks_distributed()")


if modal:
+   import os
    from distributed.app import app
    from distributed.modal_image import image

+   secret_name = os.environ.get("MODAL_SECRET_NAME", "distributed_cognee")
+
    @app.function(
        retries=3,
        image=image,
        timeout=86400,
        max_containers=50,
-       secrets=[modal.Secret.from_name("distributed_cognee")],
+       secrets=[modal.Secret.from_name(secret_name)],
    )
-   async def run_tasks_on_modal(tasks, data_item, user, pipeline_name, context):
-       pipeline_run = run_tasks_with_telemetry(tasks, data_item, user, pipeline_name, context)
-
-       run_info = None
-
-       async for pipeline_run_info in pipeline_run:
-           run_info = pipeline_run_info
-
-       return run_info
+   async def run_tasks_on_modal(
+       data_item,
+       dataset_id: UUID,
+       tasks: List[Task],
+       pipeline_name: str,
+       pipeline_id: str,
+       pipeline_run_id: str,
+       context: Optional[dict],
+       user: User,
+       incremental_loading: bool,
+   ):
+       """
+       Wrapper that runs the run_tasks_data_item function.
+       This is the function/code that runs on modal executor and produces the graph/vector db objects
+       """
+       from cognee.infrastructure.databases.relational import get_relational_engine
+
+       async with get_relational_engine().get_async_session() as session:
+           from cognee.modules.data.models import Dataset
+
+           dataset = await session.get(Dataset, dataset_id)
+
+       result = await run_tasks_data_item(
+           data_item=data_item,
+           dataset=dataset,
+           tasks=tasks,
+           pipeline_name=pipeline_name,
+           pipeline_id=pipeline_id,
+           pipeline_run_id=pipeline_run_id,
+           context=context,
+           user=user,
+           incremental_loading=incremental_loading,
+       )
+
+       return result


-async def run_tasks_distributed(tasks, dataset_id, data, user, pipeline_name, context):
+async def run_tasks_distributed(
+   tasks: List[Task],
+   dataset_id: UUID,
+   data: List[Any] = None,
+   user: User = None,
+   pipeline_name: str = "unknown_pipeline",
+   context: dict = None,
+   incremental_loading: bool = False,
+):
    if not user:
        user = await get_default_user()

+   # Get dataset object
    db_engine = get_relational_engine()
    async with db_engine.get_async_session() as session:
        from cognee.modules.data.models import Dataset

@@ -53,9 +100,7 @@ async def run_tasks_distributed(tasks, dataset_id, data, user, pipeline_name, co
        dataset = await session.get(Dataset, dataset_id)

    pipeline_id = generate_pipeline_id(user.id, dataset.id, pipeline_name)

    pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data)

    pipeline_run_id = pipeline_run.pipeline_run_id

    yield PipelineRunStarted(

@@ -65,30 +110,67 @@ async def run_tasks_distributed(tasks, dataset_id, data, user, pipeline_name, co
        payload=data,
    )

-   data_count = len(data) if isinstance(data, list) else 1
-
-   arguments = [
-       [tasks] * data_count,
-       [[data_item] for data_item in data[:data_count]] if data_count > 1 else [data],
-       [user] * data_count,
-       [pipeline_name] * data_count,
-       [context] * data_count,
-   ]
-
-   async for result in run_tasks_on_modal.map.aio(*arguments):
-       logger.info(f"Received result: {result}")
-
-       yield PipelineRunYield(
-           pipeline_run_id=pipeline_run_id,
-           dataset_id=dataset.id,
-           dataset_name=dataset.name,
-           payload=result,
-       )
-
-   await log_pipeline_run_complete(pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data)
-
-   yield PipelineRunCompleted(
-       pipeline_run_id=pipeline_run_id,
-       dataset_id=dataset.id,
-       dataset_name=dataset.name,
-   )
+   try:
+       if not isinstance(data, list):
+           data = [data]
+
+       data = await resolve_data_directories(data)
+
+       number_of_data_items = len(data) if isinstance(data, list) else 1
+
+       data_item_tasks = [
+           data,
+           [dataset.id] * number_of_data_items,
+           [tasks] * number_of_data_items,
+           [pipeline_name] * number_of_data_items,
+           [pipeline_id] * number_of_data_items,
+           [pipeline_run_id] * number_of_data_items,
+           [context] * number_of_data_items,
+           [user] * number_of_data_items,
+           [incremental_loading] * number_of_data_items,
+       ]
+
+       results = []
+       async for result in run_tasks_on_modal.map.aio(*data_item_tasks):
+           if not result:
+               continue
+           results.append(result)
+
+       # Remove skipped results
+       results = [r for r in results if r]
+
+       # If any data item failed, raise PipelineRunFailedError
+       errored = [
+           r
+           for r in results
+           if r and r.get("run_info") and isinstance(r["run_info"], PipelineRunErrored)
+       ]
+       if errored:
+           raise PipelineRunFailedError("Pipeline run failed. Data item could not be processed.")
+
+       await log_pipeline_run_complete(
+           pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data
+       )
+
+       yield PipelineRunCompleted(
+           pipeline_run_id=pipeline_run_id,
+           dataset_id=dataset.id,
+           dataset_name=dataset.name,
+           data_ingestion_info=results,
+       )
+   except Exception as error:
+       await log_pipeline_run_error(
+           pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data, error
+       )
+
+       yield PipelineRunErrored(
+           pipeline_run_id=pipeline_run_id,
+           payload=repr(error),
+           dataset_id=dataset.id,
+           dataset_name=dataset.name,
+           data_ingestion_info=locals().get("results"),
+       )
+
+       if not isinstance(error, PipelineRunFailedError):
+           raise

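The rewritten run_tasks_distributed collects per-item results from the Modal map, drops skipped ones, and fails the whole run if any item errored. A Modal-free sketch of just that aggregation step; PipelineRunFailed and the "errored" marker are stand-ins for cognee's types:

    # Sketch of the result-aggregation logic only (no Modal dependency).
    class PipelineRunFailed(Exception):
        pass


    def aggregate(results):
        kept = [r for r in results if r]  # drop skipped / empty results
        if any(r.get("run_info") == "errored" for r in kept):
            raise PipelineRunFailed("Pipeline run failed. Data item could not be processed.")
        return kept


    print(aggregate([{"run_info": "completed"}, None]))  # [{'run_info': 'completed'}]
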
@@ -194,7 +194,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
            belongs_to_set=interactions_node_set,
        )

-       await add_data_points(data_points=[cognee_user_interaction], update_edge_collection=False)
+       await add_data_points(data_points=[cognee_user_interaction])

        relationships = []
        relationship_name = "used_graph_element_to_answer"

@@ -8,7 +8,7 @@ from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.base_feedback import BaseFeedback
from cognee.modules.retrieval.utils.models import CogneeUserFeedback
from cognee.modules.retrieval.utils.models import UserFeedbackEvaluation
-from cognee.tasks.storage import add_data_points
+from cognee.tasks.storage import add_data_points, index_graph_edges

logger = get_logger("CompletionRetriever")

@@ -47,7 +47,7 @@ class UserQAFeedback(BaseFeedback):
            belongs_to_set=feedbacks_node_set,
        )

-       await add_data_points(data_points=[cognee_user_feedback], update_edge_collection=False)
+       await add_data_points(data_points=[cognee_user_feedback])

        relationships = []
        relationship_name = "gives_feedback_to"

@@ -76,6 +76,7 @@ class UserQAFeedback(BaseFeedback):
        if len(relationships) > 0:
            graph_engine = await get_graph_engine()
            await graph_engine.add_edges(relationships)
+           await index_graph_edges(relationships)
            await graph_engine.apply_feedback_weight(
                node_ids=to_node_ids, weight=feedback_sentiment.score
            )

@@ -124,5 +124,4 @@ async def add_rule_associations(

    if len(edges_to_save) > 0:
        await graph_engine.add_edges(edges_to_save)
-       await index_graph_edges()
+       await index_graph_edges(edges_to_save)

@@ -4,6 +4,7 @@ from pydantic import BaseModel

from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.ontology.ontology_env_config import get_ontology_env_config
+from cognee.tasks.storage import index_graph_edges
from cognee.tasks.storage.add_data_points import add_data_points
from cognee.modules.ontology.ontology_config import Config
from cognee.modules.ontology.get_default_ontology_resolver import (

@@ -88,6 +89,7 @@ async def integrate_chunk_graphs(

    if len(graph_edges) > 0:
        await graph_engine.add_edges(graph_edges)
+       await index_graph_edges(graph_edges)

    return data_chunks

@@ -10,9 +10,7 @@ from cognee.tasks.storage.exceptions import (
)


-async def add_data_points(
-   data_points: List[DataPoint], update_edge_collection: bool = True
-) -> List[DataPoint]:
+async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
    """
    Add a batch of data points to the graph database by extracting nodes and edges,
    deduplicating them, and indexing them for retrieval.

@@ -25,9 +23,6 @@
    Args:
        data_points (List[DataPoint]):
            A list of data points to process and insert into the graph.
-       update_edge_collection (bool, optional):
-           Whether to update the edge index after adding edges.
-           Defaults to True.

    Returns:
        List[DataPoint]:

@@ -73,12 +68,10 @@

    graph_engine = await get_graph_engine()

+   await graph_engine.add_nodes(nodes)
    await index_data_points(nodes)

-   await graph_engine.add_nodes(nodes)
    await graph_engine.add_edges(edges)
+   await index_graph_edges(edges)
-   if update_edge_collection:
-       await index_graph_edges()

    return data_points

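add_data_points now writes nodes before indexing them and indexes edges from the in-memory batch instead of re-reading the whole graph. A toy sketch of that ordering with stand-in coroutines (none of these are cognee functions):

    # Toy sketch of the new write/index ordering in add_data_points.
    import asyncio


    async def add_nodes(nodes):
        print("write nodes:", nodes)


    async def index_nodes(nodes):
        print("index nodes:", nodes)


    async def add_edges(edges):
        print("write edges:", edges)


    async def index_edges(edges):
        print("index edges:", edges)  # indexes the batch that was just written


    async def add_batch(nodes, edges):
        await add_nodes(nodes)    # write nodes first,
        await index_nodes(nodes)  # then index them,
        await add_edges(edges)    # then write edges,
        await index_edges(edges)  # then index the same edge batch (no full-graph rescan)


    asyncio.run(add_batch(["n1", "n2"], [("n1", "n2", "related_to")]))
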
@@ -1,15 +1,18 @@
 from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
-from cognee.shared.logging_utils import get_logger, ERROR
+from cognee.shared.logging_utils import get_logger
 from collections import Counter
+from typing import Optional, Dict, Any, List, Tuple, Union
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.modules.graph.models.EdgeType import EdgeType
+from cognee.infrastructure.databases.graph.graph_db_interface import EdgeData

-logger = get_logger(level=ERROR)
+logger = get_logger()


-async def index_graph_edges():
+async def index_graph_edges(
+    edges_data: Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]] = None,
+):
     """
     Indexes graph edges by creating and managing vector indexes for relationship types.

@@ -35,13 +38,17 @@ async def index_graph_edges():
    index_points = {}

        vector_engine = get_vector_engine()
-        graph_engine = await get_graph_engine()
+        if edges_data is None:
+            graph_engine = await get_graph_engine()
+            _, edges_data = await graph_engine.get_graph_data()
+            logger.warning(
+                "Your graph edge embedding is deprecated, please pass edges to the index_graph_edges directly."
+            )
    except Exception as e:
        logger.error("Failed to initialize engines: %s", e)
        raise RuntimeError("Initialization error") from e

-    _, edges_data = await graph_engine.get_graph_data()

    edge_types = Counter(
        item.get("relationship_name")
        for edge in edges_data
@@ -29,6 +29,3 @@ RUN poetry install --extras neo4j --extras postgres --extras aws --extras distri

 COPY cognee/ /app/cognee
 COPY distributed/ /app/distributed
-RUN chmod +x /app/distributed/entrypoint.sh
-
-ENTRYPOINT ["/app/distributed/entrypoint.sh"]
@@ -10,6 +10,7 @@ from distributed.app import app
 from distributed.queues import add_nodes_and_edges_queue, add_data_points_queue
 from distributed.workers.graph_saving_worker import graph_saving_worker
 from distributed.workers.data_point_saving_worker import data_point_saving_worker
+from distributed.signal import QueueSignal

 logger = get_logger()

@@ -23,13 +24,14 @@ async def main():
     await add_nodes_and_edges_queue.clear.aio()
     await add_data_points_queue.clear.aio()

-    number_of_graph_saving_workers = 1  # Total number of graph_saving_worker to spawn
-    number_of_data_point_saving_workers = 5  # Total number of graph_saving_worker to spawn
+    number_of_graph_saving_workers = 1  # Total number of graph_saving_worker to spawn (MAX 1)
+    number_of_data_point_saving_workers = (
+        10  # Total number of graph_saving_worker to spawn (MAX 10)
+    )

-    results = []
     consumer_futures = []

-    # await prune.prune_data()  # We don't want to delete files on s3
+    await prune.prune_data()  # This prunes the data from the file storage
     # Delete DBs and saved files from metastore
     await prune.prune_system(metadata=True)

@@ -45,16 +47,28 @@ async def main():
         worker_future = data_point_saving_worker.spawn()
         consumer_futures.append(worker_future)

+    """ Example: Setting and adding S3 path as input
     s3_bucket_path = os.getenv("S3_BUCKET_PATH")
     s3_data_path = "s3://" + s3_bucket_path

     await cognee.add(s3_data_path, dataset_name="s3-files")
+    """
+    await cognee.add(
+        [
+            "Audi is a German car manufacturer",
+            "The Netherlands is next to Germany",
+            "Berlin is the capital of Germany",
+            "The Rhine is a major European river",
+            "BMW produces luxury vehicles",
+        ],
+        dataset_name="s3-files",
+    )

     await cognee.cognify(datasets=["s3-files"])

-    # Push empty tuple into the queue to signal the end of data.
-    await add_nodes_and_edges_queue.put.aio(())
-    await add_data_points_queue.put.aio(())
+    # Put Processing end signal into the queues to stop the consumers
+    await add_nodes_and_edges_queue.put.aio(QueueSignal.STOP)
+    await add_data_points_queue.put.aio(QueueSignal.STOP)

     for consumer_future in consumer_futures:
         try:
@@ -64,8 +78,6 @@ async def main():
         except Exception as e:
             logger.error(e)

-    print(results)
-

 if __name__ == "__main__":
     asyncio.run(main())
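The shutdown handshake above relies on one QueueSignal.STOP per queue: the producer enqueues it once, and each worker that consumes it re-enqueues it before exiting so the remaining workers see it too. A minimal local analogue using asyncio.Queue instead of Modal queues (all names here are illustrative, not from the diff):

    import asyncio

    STOP = "STOP"  # stands in for QueueSignal.STOP

    async def consumer(name: str, queue: asyncio.Queue) -> None:
        while True:
            item = await queue.get()
            if item == STOP:
                # Re-broadcast the sentinel so the other consumers also shut down.
                await queue.put(STOP)
                print(f"{name} stopping")
                return
            print(f"{name} processing {item}")

    async def main() -> None:
        queue: asyncio.Queue = asyncio.Queue()
        consumers = [asyncio.create_task(consumer(f"worker-{i}", queue)) for i in range(3)]
        for item in ["a", "b", "c"]:
            await queue.put(item)
        await queue.put(STOP)  # one sentinel is enough for all consumers
        await asyncio.gather(*consumers)

    asyncio.run(main())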
distributed/poetry.lock (generated, 12238 lines)
File diff suppressed because it is too large.
@@ -1,185 +0,0 @@
-[project]
-name = "cognee"
-version = "0.2.2.dev0"
-description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning."
-authors = [
-    { name = "Vasilije Markovic" },
-    { name = "Boris Arzentar" },
-]
-requires-python = ">=3.10,<=3.13"
-readme = "README.md"
-license = "Apache-2.0"
-classifiers = [
-    "Development Status :: 4 - Beta",
-    "Intended Audience :: Developers",
-    "License :: OSI Approved :: Apache Software License",
-    "Topic :: Software Development :: Libraries",
-    "Operating System :: MacOS :: MacOS X",
-    "Operating System :: POSIX :: Linux",
-    "Operating System :: Microsoft :: Windows",
-]
-dependencies = [
-    "openai>=1.80.1,<2",
-    "python-dotenv>=1.0.1,<2.0.0",
-    "pydantic>=2.11.7,<3.0.0",
-    "pydantic-settings>=2.10.1,<3",
-    "typing_extensions>=4.12.2,<5.0.0",
-    "nltk>=3.9.1,<4.0.0",
-    "numpy>=1.26.4, <=4.0.0",
-    "pandas>=2.2.2,<3.0.0",
-    # Note: New s3fs and boto3 versions don't work well together
-    # Always use comaptible fixed versions of these two dependencies
-    "s3fs[boto3]==2025.3.2",
-    "sqlalchemy>=2.0.39,<3.0.0",
-    "aiosqlite>=0.20.0,<1.0.0",
-    "tiktoken>=0.8.0,<1.0.0",
-    "litellm>=1.57.4, <1.71.0",
-    "instructor>=1.9.1,<2.0.0",
-    "langfuse>=2.32.0,<3",
-    "filetype>=1.2.0,<2.0.0",
-    "aiohttp>=3.11.14,<4.0.0",
-    "aiofiles>=23.2.1,<24.0.0",
-    "rdflib>=7.1.4,<7.2.0",
-    "pypdf>=4.1.0,<7.0.0",
-    "jinja2>=3.1.3,<4",
-    "matplotlib>=3.8.3,<4",
-    "networkx>=3.4.2,<4",
-    "lancedb>=0.24.0,<1.0.0",
-    "alembic>=1.13.3,<2",
-    "pre-commit>=4.0.1,<5",
-    "scikit-learn>=1.6.1,<2",
-    "limits>=4.4.1,<5",
-    "fastapi>=0.115.7,<1.0.0",
-    "python-multipart>=0.0.20,<1.0.0",
-    "fastapi-users[sqlalchemy]>=14.0.1,<15.0.0",
-    "dlt[sqlalchemy]>=1.9.0,<2",
-    "sentry-sdk[fastapi]>=2.9.0,<3",
-    "structlog>=25.2.0,<26",
-    "pympler>=1.1,<2.0.0",
-    "onnxruntime>=1.0.0,<2.0.0",
-    "pylance>=0.22.0,<1.0.0",
-    "kuzu (==0.11.0)"
-]
-
-[project.optional-dependencies]
-api = [
-    "uvicorn>=0.34.0,<1.0.0",
-    "gunicorn>=20.1.0,<24",
-    "websockets>=15.0.1,<16.0.0"
-]
-distributed = [
-    "modal>=1.0.5,<2.0.0",
-]
-
-neo4j = ["neo4j>=5.28.0,<6"]
-postgres = [
-    "psycopg2>=2.9.10,<3",
-    "pgvector>=0.3.5,<0.4",
-    "asyncpg>=0.30.0,<1.0.0",
-]
-postgres-binary = [
-    "psycopg2-binary>=2.9.10,<3.0.0",
-    "pgvector>=0.3.5,<0.4",
-    "asyncpg>=0.30.0,<1.0.0",
-]
-notebook = ["notebook>=7.1.0,<8"]
-langchain = [
-    "langsmith>=0.2.3,<1.0.0",
-    "langchain_text_splitters>=0.3.2,<1.0.0",
-]
-llama-index = ["llama-index-core>=0.12.11,<0.13"]
-gemini = ["google-generativeai>=0.8.4,<0.9"]
-huggingface = ["transformers>=4.46.3,<5"]
-ollama = ["transformers>=4.46.3,<5"]
-mistral = ["mistral-common>=1.5.2,<2"]
-anthropic = ["anthropic>=0.26.1,<0.27"]
-deepeval = ["deepeval>=2.0.1,<3"]
-posthog = ["posthog>=3.5.0,<4"]
-groq = ["groq>=0.8.0,<1.0.0"]
-chromadb = [
-    "chromadb>=0.3.0,<0.7",
-    "pypika==0.48.8",
-]
-docs = ["unstructured[csv, doc, docx, epub, md, odt, org, ppt, pptx, rst, rtf, tsv, xlsx]>=0.18.1,<19"]
-codegraph = [
-    "fastembed<=0.6.0 ; python_version < '3.13'",
-    "transformers>=4.46.3,<5",
-    "tree-sitter>=0.24.0,<0.25",
-    "tree-sitter-python>=0.23.6,<0.24",
-]
-evals = [
-    "plotly>=6.0.0,<7",
-    "gdown>=5.2.0,<6",
-]
-gui = [
-    "pyside6>=6.8.3,<7",
-    "qasync>=0.27.1,<0.28",
-]
-graphiti = ["graphiti-core>=0.7.0,<0.8"]
-# Note: New s3fs and boto3 versions don't work well together
-# Always use comaptible fixed versions of these two dependencies
-aws = ["s3fs[boto3]==2025.3.2"]
-dev = [
-    "pytest>=7.4.0,<8",
-    "pytest-cov>=6.1.1,<7.0.0",
-    "pytest-asyncio>=0.21.1,<0.22",
-    "coverage>=7.3.2,<8",
-    "mypy>=1.7.1,<2",
-    "notebook>=7.1.0,<8",
-    "deptry>=0.20.0,<0.21",
-    "pylint>=3.0.3,<4",
-    "ruff>=0.9.2,<1.0.0",
-    "tweepy>=4.14.0,<5.0.0",
-    "gitpython>=3.1.43,<4",
-    "mkdocs-material>=9.5.42,<10",
-    "mkdocs-minify-plugin>=0.8.0,<0.9",
-    "mkdocstrings[python]>=0.26.2,<0.27",
-]
-debug = ["debugpy>=1.8.9,<2.0.0"]
-
-[project.urls]
-Homepage = "https://www.cognee.ai"
-Repository = "https://github.com/topoteretes/cognee"
-
-[build-system]
-requires = ["hatchling"]
-build-backend = "hatchling.build"
-
-[tool.hatch.build]
-exclude = [
-    "/bin",
-    "/dist",
-    "/.data",
-    "/.github",
-    "/alembic",
-    "/deployment",
-    "/cognee-mcp",
-    "/cognee-frontend",
-    "/examples",
-    "/helm",
-    "/licenses",
-    "/logs",
-    "/notebooks",
-    "/profiling",
-    "/tests",
-    "/tools",
-]
-
-[tool.hatch.build.targets.wheel]
-packages = ["cognee", "distributed"]
-
-[tool.ruff]
-line-length = 100
-exclude = [
-    "migrations/",  # Ignore migrations directory
-    "notebooks/",  # Ignore notebook files
-    "build/",  # Ignore build directory
-    "cognee/pipelines.py",
-    "cognee/modules/users/models/Group.py",
-    "cognee/modules/users/models/ACL.py",
-    "cognee/modules/pipelines/models/Task.py",
-    "cognee/modules/data/models/Dataset.py"
-]
-
-[tool.ruff.lint]
-ignore = ["F401"]
distributed/signal.py (new file, 5 lines)
@@ -0,0 +1,5 @@
+from enum import Enum
+
+
+class QueueSignal(str, Enum):
+    STOP = "STOP"
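One practical property of the str mixin used for QueueSignal: members compare equal to their underlying string, so the equality checks in the workers keep matching even if the sentinel ever arrives as a plain "STOP" string after serialization. A small standalone check (not part of the diff):

    from enum import Enum

    class QueueSignal(str, Enum):
        STOP = "STOP"

    # str-mixin enum members compare equal to their plain string value.
    assert QueueSignal.STOP == "STOP"
    assert "STOP" == QueueSignal.STOP
    assert isinstance(QueueSignal.STOP, str)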
@@ -1,16 +1,17 @@
+import os
 import modal
 import asyncio
 from sqlalchemy.exc import OperationalError, DBAPIError
 from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

 from distributed.app import app
+from distributed.signal import QueueSignal
 from distributed.modal_image import image
 from distributed.queues import add_data_points_queue

 from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.vector import get_vector_engine


 logger = get_logger("data_point_saving_worker")

@@ -39,55 +40,84 @@ def is_deadlock_error(error):
     return False


+secret_name = os.environ.get("MODAL_SECRET_NAME", "distributed_cognee")
+
+
 @app.function(
     retries=3,
     image=image,
     timeout=86400,
-    max_containers=5,
-    secrets=[modal.Secret.from_name("distributed_cognee")],
+    max_containers=10,
+    secrets=[modal.Secret.from_name(secret_name)],
 )
 async def data_point_saving_worker():
     print("Started processing of data points; starting vector engine queue.")
     vector_engine = get_vector_engine()
+    # Defines how many data packets do we glue together from the modal queue before embedding call and ingestion
+    BATCH_SIZE = 25
+    stop_seen = False

     while True:
+        if stop_seen:
+            print("Finished processing all data points; stopping vector engine queue consumer.")
+            return True
+
         if await add_data_points_queue.len.aio() != 0:
             try:
-                add_data_points_request = await add_data_points_queue.get.aio(block=False)
+                print("Remaining elements in queue:")
+                print(await add_data_points_queue.len.aio())
+
+                # collect batched requests
+                batched_points = {}
+                for _ in range(min(BATCH_SIZE, await add_data_points_queue.len.aio())):
+                    add_data_points_request = await add_data_points_queue.get.aio(block=False)
+
+                    if not add_data_points_request:
+                        continue
+
+                    if add_data_points_request == QueueSignal.STOP:
+                        await add_data_points_queue.put.aio(QueueSignal.STOP)
+                        stop_seen = True
+                        break
+
+                    if len(add_data_points_request) == 2:
+                        collection_name, data_points = add_data_points_request
+                        if collection_name not in batched_points:
+                            batched_points[collection_name] = []
+                        batched_points[collection_name].extend(data_points)
+                    else:
+                        print("NoneType or invalid request detected.")
+
+                if batched_points:
+                    for collection_name, data_points in batched_points.items():
+                        print(
+                            f"Adding {len(data_points)} data points to '{collection_name}' collection."
+                        )
+
+                        @retry(
+                            retry=retry_if_exception_type(VectorDatabaseDeadlockError),
+                            stop=stop_after_attempt(3),
+                            wait=wait_exponential(multiplier=2, min=1, max=6),
+                        )
+                        async def add_data_points():
+                            try:
+                                await vector_engine.create_data_points(
+                                    collection_name, data_points, distributed=False
+                                )
+                            except DBAPIError as error:
+                                if is_deadlock_error(error):
+                                    raise VectorDatabaseDeadlockError()
+                            except OperationalError as error:
+                                if is_deadlock_error(error):
+                                    raise VectorDatabaseDeadlockError()
+
+                        await add_data_points()
+                        print(f"Finished adding data points to '{collection_name}'.")
+
             except modal.exception.DeserializationError as error:
                 logger.error(f"Deserialization error: {str(error)}")
                 continue

-            if len(add_data_points_request) == 0:
-                print("Finished processing all data points; stopping vector engine queue.")
-                return True
-
-            if len(add_data_points_request) == 2:
-                (collection_name, data_points) = add_data_points_request
-
-                print(f"Adding {len(data_points)} data points to '{collection_name}' collection.")
-
-                @retry(
-                    retry=retry_if_exception_type(VectorDatabaseDeadlockError),
-                    stop=stop_after_attempt(3),
-                    wait=wait_exponential(multiplier=2, min=1, max=6),
-                )
-                async def add_data_points():
-                    try:
-                        await vector_engine.create_data_points(
-                            collection_name, data_points, distributed=False
-                        )
-                    except DBAPIError as error:
-                        if is_deadlock_error(error):
-                            raise VectorDatabaseDeadlockError()
-                    except OperationalError as error:
-                        if is_deadlock_error(error):
-                            raise VectorDatabaseDeadlockError()
-
-                await add_data_points()
-
-                print("Finished adding data points.")
-
         else:
             print("No jobs, go to sleep.")
             await asyncio.sleep(5)
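The new consumer drains up to BATCH_SIZE requests per pass and merges them per collection before making one create_data_points call per collection. A queue-free sketch of that grouping rule under the same assumptions (plain list input, and a string sentinel standing in for QueueSignal.STOP; names are illustrative):

    from typing import Any, Dict, List, Tuple, Union

    STOP = "STOP"  # stands in for QueueSignal.STOP

    def batch_requests(
        requests: List[Union[str, Tuple[str, List[Any]]]],
        batch_size: int = 25,
    ) -> Tuple[Dict[str, List[Any]], bool]:
        """Group up to batch_size (collection_name, data_points) requests by collection.

        Returns the grouped data points and whether the STOP sentinel was seen."""
        batched: Dict[str, List[Any]] = {}
        stop_seen = False
        for request in requests[:batch_size]:
            if not request:
                continue
            if request == STOP:
                stop_seen = True
                break
            collection_name, data_points = request
            batched.setdefault(collection_name, []).extend(data_points)
        return batched, stop_seen

    # Example: two collections plus the stop sentinel.
    grouped, stop = batch_requests([("docs", [1, 2]), ("docs", [3]), ("chunks", [4]), STOP])
    assert grouped == {"docs": [1, 2, 3], "chunks": [4]} and stop is True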
@@ -1,8 +1,10 @@
+import os
 import modal
 import asyncio
 from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

 from distributed.app import app
+from distributed.signal import QueueSignal
 from distributed.modal_image import image
 from distributed.queues import add_nodes_and_edges_queue

@@ -10,7 +12,6 @@ from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.infrastructure.databases.graph.config import get_graph_config
-

 logger = get_logger("graph_saving_worker")

@@ -37,68 +38,91 @@ def is_deadlock_error(error):
     return False


+secret_name = os.environ.get("MODAL_SECRET_NAME", "distributed_cognee")
+
+
 @app.function(
     retries=3,
     image=image,
     timeout=86400,
-    max_containers=5,
-    secrets=[modal.Secret.from_name("distributed_cognee")],
+    max_containers=1,
+    secrets=[modal.Secret.from_name(secret_name)],
 )
 async def graph_saving_worker():
     print("Started processing of nodes and edges; starting graph engine queue.")
     graph_engine = await get_graph_engine()
+    # Defines how many data packets do we glue together from the queue before ingesting them into the graph database
+    BATCH_SIZE = 25
+    stop_seen = False

     while True:
+        if stop_seen:
+            print("Finished processing all data points; stopping graph engine queue consumer.")
+            return True
+
         if await add_nodes_and_edges_queue.len.aio() != 0:
             try:
-                nodes_and_edges = await add_nodes_and_edges_queue.get.aio(block=False)
+                print("Remaining elements in queue:")
+                print(await add_nodes_and_edges_queue.len.aio())
+
+                all_nodes, all_edges = [], []
+                for _ in range(min(BATCH_SIZE, await add_nodes_and_edges_queue.len.aio())):
+                    nodes_and_edges = await add_nodes_and_edges_queue.get.aio(block=False)
+
+                    if not nodes_and_edges:
+                        continue
+
+                    if nodes_and_edges == QueueSignal.STOP:
+                        await add_nodes_and_edges_queue.put.aio(QueueSignal.STOP)
+                        stop_seen = True
+                        break
+
+                    if len(nodes_and_edges) == 2:
+                        nodes, edges = nodes_and_edges
+                        all_nodes.extend(nodes)
+                        all_edges.extend(edges)
+                    else:
+                        print("None Type detected.")
+
+                if all_nodes or all_edges:
+                    print(f"Adding {len(all_nodes)} nodes and {len(all_edges)} edges.")
+
+                    @retry(
+                        retry=retry_if_exception_type(GraphDatabaseDeadlockError),
+                        stop=stop_after_attempt(3),
+                        wait=wait_exponential(multiplier=2, min=1, max=6),
+                    )
+                    async def save_graph_nodes(new_nodes):
+                        try:
+                            await graph_engine.add_nodes(new_nodes, distributed=False)
+                        except Exception as error:
+                            if is_deadlock_error(error):
+                                raise GraphDatabaseDeadlockError()
+
+                    @retry(
+                        retry=retry_if_exception_type(GraphDatabaseDeadlockError),
+                        stop=stop_after_attempt(3),
+                        wait=wait_exponential(multiplier=2, min=1, max=6),
+                    )
+                    async def save_graph_edges(new_edges):
+                        try:
+                            await graph_engine.add_edges(new_edges, distributed=False)
+                        except Exception as error:
+                            if is_deadlock_error(error):
+                                raise GraphDatabaseDeadlockError()
+
+                    if all_nodes:
+                        await save_graph_nodes(all_nodes)
+
+                    if all_edges:
+                        await save_graph_edges(all_edges)
+
+                    print("Finished adding nodes and edges.")
+
             except modal.exception.DeserializationError as error:
                 logger.error(f"Deserialization error: {str(error)}")
                 continue

-            if len(nodes_and_edges) == 0:
-                print("Finished processing all nodes and edges; stopping graph engine queue.")
-                return True
-
-            if len(nodes_and_edges) == 2:
-                print(
-                    f"Adding {len(nodes_and_edges[0])} nodes and {len(nodes_and_edges[1])} edges."
-                )
-                nodes = nodes_and_edges[0]
-                edges = nodes_and_edges[1]
-
-                @retry(
-                    retry=retry_if_exception_type(GraphDatabaseDeadlockError),
-                    stop=stop_after_attempt(3),
-                    wait=wait_exponential(multiplier=2, min=1, max=6),
-                )
-                async def save_graph_nodes(new_nodes):
-                    try:
-                        await graph_engine.add_nodes(new_nodes, distributed=False)
-                    except Exception as error:
-                        if is_deadlock_error(error):
-                            raise GraphDatabaseDeadlockError()
-
-                @retry(
-                    retry=retry_if_exception_type(GraphDatabaseDeadlockError),
-                    stop=stop_after_attempt(3),
-                    wait=wait_exponential(multiplier=2, min=1, max=6),
-                )
-                async def save_graph_edges(new_edges):
-                    try:
-                        await graph_engine.add_edges(new_edges, distributed=False)
-                    except Exception as error:
-                        if is_deadlock_error(error):
-                            raise GraphDatabaseDeadlockError()
-
-                if nodes:
-                    await save_graph_nodes(nodes)
-
-                if edges:
-                    await save_graph_edges(edges)
-
-                print("Finished adding nodes and edges.")
-
         else:
             print("No jobs, go to sleep.")
             await asyncio.sleep(5)
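Both workers wrap their database writes in the same tenacity pattern: translate a detected deadlock into a dedicated exception type, then retry the write up to three times with exponential backoff. A self-contained sketch of just that wrapper (the simulated failure and the names are illustrative):

    import asyncio

    from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential


    class GraphDatabaseDeadlockError(Exception):
        """Raised when the driver reports a deadlock, so tenacity retries the write."""


    _attempts = {"count": 0}


    @retry(
        retry=retry_if_exception_type(GraphDatabaseDeadlockError),
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=2, min=1, max=6),
    )
    async def save_with_retry() -> str:
        _attempts["count"] += 1
        if _attempts["count"] < 3:
            # Simulate a transient deadlock on the first two attempts.
            raise GraphDatabaseDeadlockError()
        return "saved"


    print(asyncio.run(save_with_retry()))  # succeeds on the third attempt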