fix: code cleanup [COG-781] (#667)

<!-- .github/pull_request_template.md -->

## Description
This PR removes dead and duplicated code paths and tightens the public API:

- Renames the `cognify_v2` and `search_v2` modules to `cognify` and `search` and updates all imports.
- Drops the `authenticate_user` wrapper, the `get_search_history` helper, and the `get_data_list_for_user` task; tests and examples now call `get_history` and `get_default_user` directly.
- Deletes unused `repo_processor` tasks: dependency-graph enrichment and expansion, code part extraction, source-code chunking, and the jedi-based repository parser.
- Removes `send_telemetry` calls and the now-redundant try/except wrappers from the cognify and ingestion pipelines.
- Changes `ingest_data` to return the stored data documents instead of the DLT run info, and widens `get_datasets_by_name` to accept a single name or a list of names.
- Examples now compute pipeline run metrics via `cognee.modules.metrics.operations.get_pipeline_run_metrics`.

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
Author: Boris
Committed: 2025-03-26 18:32:43 +01:00 (via GitHub)
Parent: 9f587a01a4
Commit: ebf1f81b35
45 changed files with 128 additions and 853 deletions

View file

@ -3,12 +3,11 @@ from .api.v1.cognify import cognify
from .api.v1.config.config import config
from .api.v1.datasets.datasets import datasets
from .api.v1.prune import prune
from .api.v1.search import SearchType, get_search_history, search
from .api.v1.search import SearchType, search
from .api.v1.visualize import visualize_graph, start_visualization_server
from cognee.modules.visualization.cognee_network_visualization import (
cognee_network_visualization,
)
from .modules.data.operations.get_pipeline_run_metrics import get_pipeline_run_metrics
# Pipelines
from .modules import pipelines
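
Because `get_search_history` is no longer re-exported from the package root, callers that used `cognee.get_search_history()` can read history through the module-level helpers instead, which is exactly what the updated tests and examples later in this diff do. A minimal sketch, assuming a configured cognee environment with at least one prior search:

```python
import asyncio

from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user


async def show_search_history():
    # Resolve the default user and read their search history directly,
    # mirroring the pattern used in the updated tests and examples below.
    user = await get_default_user()
    history = await get_history(user.id)
    for entry in history:
        print(entry)


asyncio.run(show_search_history())
```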

View file

@ -1 +0,0 @@
from .authenticate_user import authenticate_user

View file

@ -1,27 +0,0 @@
from cognee.infrastructure.databases.relational.user_authentication.users import (
authenticate_user_method,
)
async def authenticate_user(email: str, password: str):
"""
This function is used to authenticate a user.
"""
output = await authenticate_user_method(email=email, password=password)
return output
if __name__ == "__main__":
import asyncio
# Define an example user
example_email = "example@example.com"
example_password = "securepassword123"
example_is_superuser = False
# Create an event loop and run the create_user function
loop = asyncio.get_event_loop()
result = loop.run_until_complete(authenticate_user(example_email, example_password))
# Print the result
print(result)

View file

@ -1 +1 @@
from .cognify_v2 import cognify
from .cognify import cognify
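
The `cognify_v2` module is renamed to `cognify`, so every downstream import in this PR moves with it. A small before/after sketch of the import paths:

```python
# Old import paths (removed in this PR):
# from cognee.api.v1.cognify.cognify_v2 import cognify
# from cognee.api.v1.cognify.cognify_v2 import get_default_tasks

# New import paths after the rename:
from cognee.api.v1.cognify import cognify
from cognee.api.v1.cognify.cognify import get_default_tasks
```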

View file

@ -2,8 +2,7 @@ import asyncio
from cognee.shared.logging_utils import get_logger
from uuid import NAMESPACE_OID, uuid5
from cognee.api.v1.search.search_v2 import search
from cognee.api.v1.search import SearchType
from cognee.api.v1.search import SearchType, search
from cognee.base_config import get_base_config
from cognee.modules.cognify.config import get_cognify_config
from cognee.modules.pipelines import run_tasks
@ -14,11 +13,7 @@ from cognee.shared.utils import render_graph
from cognee.tasks.documents import classify_documents, extract_chunks_from_documents
from cognee.tasks.graph import extract_graph_from_data
from cognee.tasks.ingestion import ingest_data
from cognee.tasks.repo_processor import (
get_data_list_for_user,
get_non_py_files,
get_repo_file_dependencies,
)
from cognee.tasks.repo_processor import get_non_py_files, get_repo_file_dependencies
from cognee.tasks.storage import add_data_points
from cognee.tasks.summarization import summarize_text
@ -31,7 +26,6 @@ if monitoring == MonitoringTool.LANGFUSE:
logger = get_logger("code_graph_pipeline")
update_status_lock = asyncio.Lock()
@observe
@ -49,18 +43,15 @@ async def run_code_graph_pipeline(repo_path, include_docs=False):
tasks = [
Task(get_repo_file_dependencies, detailed_extraction=detailed_extraction),
# Task(enrich_dependency_graph, task_config={"batch_size": 50}),
# Task(expand_dependency_graph, task_config={"batch_size": 50}),
# Task(get_source_code_chunks, task_config={"batch_size": 50}),
# Task(summarize_code, task_config={"batch_size": 50}),
# Task(summarize_code, task_config={"batch_size": 500}), # This task takes a long time to complete
Task(add_data_points, task_config={"batch_size": 500}),
]
if include_docs:
# These tasks take a long time to complete
non_code_tasks = [
Task(get_non_py_files, task_config={"batch_size": 50}),
Task(ingest_data, dataset_name="repo_docs", user=user),
Task(get_data_list_for_user, dataset_name="repo_docs", user=user),
Task(classify_documents),
Task(extract_chunks_from_documents, max_chunk_size=get_max_chunk_tokens()),
Task(

View file

@ -17,7 +17,6 @@ from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User
from cognee.shared.data_models import KnowledgeGraph
from cognee.shared.utils import send_telemetry
from cognee.tasks.documents import (
check_permissions_on_documents,
classify_documents,
@ -28,7 +27,7 @@ from cognee.tasks.storage import add_data_points
from cognee.tasks.summarization import summarize_text
from cognee.modules.chunking.TextChunker import TextChunker
logger = get_logger("cognify.v2")
logger = get_logger("cognify")
update_status_lock = asyncio.Lock()
@ -76,8 +75,6 @@ async def run_cognify_pipeline(dataset: Dataset, user: User, tasks: list[Task]):
dataset_id = dataset.id
dataset_name = generate_dataset_name(dataset.name)
send_telemetry("cognee.cognify EXECUTION STARTED", user.id)
# async with update_status_lock: TODO: Add UI lock to prevent multiple backend requests
task_status = await get_pipeline_status([dataset_id])
@ -88,26 +85,20 @@ async def run_cognify_pipeline(dataset: Dataset, user: User, tasks: list[Task]):
logger.info("Dataset %s is already being processed.", dataset_name)
return
try:
if not isinstance(tasks, list):
raise ValueError("Tasks must be a list")
if not isinstance(tasks, list):
raise ValueError("Tasks must be a list")
for task in tasks:
if not isinstance(task, Task):
raise ValueError(f"Task {task} is not an instance of Task")
for task in tasks:
if not isinstance(task, Task):
raise ValueError(f"Task {task} is not an instance of Task")
pipeline_run = run_tasks(tasks, dataset.id, data_documents, "cognify_pipeline")
pipeline_run_status = None
pipeline_run = run_tasks(tasks, dataset.id, data_documents, "cognify_pipeline")
pipeline_run_status = None
async for run_status in pipeline_run:
pipeline_run_status = run_status
async for run_status in pipeline_run:
pipeline_run_status = run_status
send_telemetry("cognee.cognify EXECUTION COMPLETED", user.id)
return pipeline_run_status
except Exception as error:
send_telemetry("cognee.cognify EXECUTION ERRORED", user.id)
raise error
return pipeline_run_status
def generate_dataset_name(dataset_name: str) -> str:
@ -124,31 +115,30 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
if user is None:
user = await get_default_user()
try:
cognee_config = get_cognify_config()
ontology_adapter = OntologyResolver(ontology_file=ontology_file_path)
default_tasks = [
Task(classify_documents),
Task(check_permissions_on_documents, user=user, permissions=["write"]),
Task(
extract_chunks_from_documents,
max_chunk_size=chunk_size or get_max_chunk_tokens(),
chunker=chunker,
), # Extract text chunks based on the document type.
Task(
extract_graph_from_data,
graph_model=graph_model,
ontology_adapter=ontology_adapter,
task_config={"batch_size": 10},
), # Generate knowledge graphs from the document chunks.
Task(
summarize_text,
summarization_model=cognee_config.summarization_model,
task_config={"batch_size": 10},
),
Task(add_data_points, task_config={"batch_size": 10}),
]
except Exception as error:
send_telemetry("cognee.cognify DEFAULT TASKS CREATION ERRORED", user.id)
raise error
cognee_config = get_cognify_config()
ontology_adapter = OntologyResolver(ontology_file=ontology_file_path)
default_tasks = [
Task(classify_documents),
Task(check_permissions_on_documents, user=user, permissions=["write"]),
Task(
extract_chunks_from_documents,
max_chunk_size=chunk_size or get_max_chunk_tokens(),
chunker=chunker,
), # Extract text chunks based on the document type.
Task(
extract_graph_from_data,
graph_model=graph_model,
ontology_adapter=ontology_adapter,
task_config={"batch_size": 10},
), # Generate knowledge graphs from the document chunks.
Task(
summarize_text,
summarization_model=cognee_config.summarization_model,
task_config={"batch_size": 10},
),
Task(add_data_points, task_config={"batch_size": 10}),
]
return default_tasks
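
A hedged sketch of building this default task list outside the pipeline, assuming every parameter of `get_default_tasks` (user, graph model, chunker, chunk size, ontology path) has a default as the hunk above suggests:

```python
import asyncio

from cognee.api.v1.cognify.cognify import get_default_tasks


async def main():
    # With no arguments the default user, chunker, and graph model are used
    # (assumption based on the defaults handled inside get_default_tasks).
    tasks = await get_default_tasks()
    for task in tasks:
        print(type(task).__name__, task)


asyncio.run(main())
```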

View file

@ -22,7 +22,7 @@ class CodePipelineRetrievePayloadDTO(InDTO):
def get_code_pipeline_router() -> APIRouter:
try:
import run_code_graph_pipeline
import cognee.api.v1.cognify.code_graph_pipeline
except ModuleNotFoundError:
logger.error("codegraph dependencies not found. Skipping codegraph API routes.")
return None
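
For context, a sketch of driving the code graph pipeline that this router guards. The module path matches the import check above; treating `run_code_graph_pipeline` as an async generator of run statuses is an assumption based on how the pipeline is consumed elsewhere in the codebase:

```python
import asyncio

from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline


async def main():
    # Iterate over the pipeline's run statuses for a local repository;
    # include_docs=False skips the non-code ingestion tasks shown earlier.
    async for run_status in run_code_graph_pipeline("/path/to/repo", include_docs=False):
        print(run_status)


asyncio.run(main())
```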

View file

@ -1,10 +1,10 @@
from fastapi import APIRouter
from typing import List, Optional
from pydantic import BaseModel
from cognee.modules.users.models import User
from fastapi.responses import JSONResponse
from cognee.modules.users.methods import get_authenticated_user
from fastapi import Depends
from fastapi import APIRouter
from fastapi.responses import JSONResponse
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.data_models import KnowledgeGraph
@ -19,7 +19,7 @@ def get_cognify_router() -> APIRouter:
@router.post("/", response_model=None)
async def cognify(payload: CognifyPayloadDTO, user: User = Depends(get_authenticated_user)):
"""This endpoint is responsible for the cognitive processing of the content."""
from cognee.api.v1.cognify.cognify_v2 import cognify as cognee_cognify
from cognee.api.v1.cognify import cognify as cognee_cognify
try:
await cognee_cognify(payload.datasets, user, payload.graph_model)
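
A hedged sketch of calling this endpoint over HTTP. The `/api/v1/cognify` prefix, port, and auth header are assumptions; only the `datasets` and `graph_model` payload fields are visible in the DTO usage above:

```python
import requests

# Base URL and route prefix are assumptions about how the router is mounted.
response = requests.post(
    "http://localhost:8000/api/v1/cognify",
    json={"datasets": ["my_dataset"]},  # graph_model can also be supplied
    headers={"Authorization": "Bearer <token>"},  # get_authenticated_user guards the route
)
print(response.status_code, response.text)
```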

View file

@ -1,2 +1 @@
from .search_v2 import search, SearchType
from .get_search_history import get_search_history
from .search import search, SearchType

View file

@ -1,10 +0,0 @@
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User
async def get_search_history(user: User = None) -> list:
if not user:
user = await get_default_user()
return await get_history(user.id)

View file

@ -1,6 +1,6 @@
from enum import Enum
from typing import Callable, Awaitable, List
from cognee.api.v1.cognify.cognify_v2 import get_default_tasks
from cognee.api.v1.cognify.cognify import get_default_tasks
from cognee.modules.pipelines.tasks.Task import Task
from cognee.eval_framework.corpus_builder.task_getters.get_cascade_graph_tasks import (
get_cascade_graph_tasks,

View file

@ -1,5 +1,5 @@
from typing import List, Awaitable, Optional
from cognee.api.v1.cognify.cognify_v2 import get_default_tasks
from typing import List
from cognee.api.v1.cognify.cognify import get_default_tasks
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.chunking.TextChunker import TextChunker

View file

@ -1,10 +1,13 @@
from typing import Union
from uuid import UUID
from sqlalchemy import select
from cognee.infrastructure.databases.relational import get_relational_engine
from ..models import Dataset
async def get_datasets_by_name(dataset_names: list[str], user_id: UUID) -> list[Dataset]:
async def get_datasets_by_name(
dataset_names: Union[str, list[str]], user_id: UUID
) -> list[Dataset]:
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
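
With the widened signature, a single dataset name or a list of names both resolve to a list of `Dataset` rows, which is what `ingest_data` relies on later in this diff. A minimal sketch with placeholder values:

```python
import asyncio
from uuid import UUID

from cognee.modules.data.methods import get_datasets_by_name


async def main():
    user_id = UUID("00000000-0000-0000-0000-000000000000")  # placeholder user id

    # A single name...
    datasets = await get_datasets_by_name("repo_docs", user_id)
    # ...or a list of names are both accepted after this change.
    datasets = await get_datasets_by_name(["repo_docs", "other_dataset"], user_id)

    print(datasets)


asyncio.run(main())
```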

View file

@ -0,0 +1 @@
from .get_pipeline_run_metrics import get_pipeline_run_metrics

View file

@ -3,7 +3,8 @@ from uuid import uuid5, NAMESPACE_OID
from typing import Type
from pydantic import BaseModel
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine, DataPoint
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.engine.models import DataPoint
from cognee.modules.data.extraction.extract_categories import extract_categories
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk

View file

@ -1,29 +0,0 @@
import os
import asyncio
import argparse
from cognee.tasks.repo_processor.get_repo_file_dependencies import get_repo_file_dependencies
from cognee.tasks.repo_processor.enrich_dependency_graph import enrich_dependency_graph
from cognee.tasks.repo_processor.expand_dependency_graph import expand_dependency_graph
def main():
parser = argparse.ArgumentParser()
parser.add_argument("repo_path", help="Path to the repository")
args = parser.parse_args()
repo_path = args.repo_path
if not os.path.exists(repo_path):
print(f"Error: The provided repository path does not exist: {repo_path}")
return
graph = asyncio.run(get_repo_file_dependencies(repo_path))
graph = asyncio.run(enrich_dependency_graph(graph))
graph = expand_dependency_graph(graph)
for node in graph.nodes:
print(f"Node: {node}")
for _, target, data in graph.out_edges(node, data=True):
print(f" Edge to {target}, data: {data}")
if __name__ == "__main__":
main()

View file

@ -1,3 +1,5 @@
from .translate_text import translate_text
from .detect_language import detect_language
from .classify_documents import classify_documents
from .extract_chunks_from_documents import extract_chunks_from_documents
from .check_permissions_on_documents import check_permissions_on_documents

View file

@ -3,15 +3,12 @@ from typing import Any, List
import dlt
import cognee.modules.ingestion as ingestion
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.data.methods import create_dataset
from cognee.modules.data.methods import create_dataset, get_dataset_data, get_datasets_by_name
from cognee.modules.data.models.DatasetData import DatasetData
from cognee.modules.users.models import User
from cognee.modules.users.permissions.methods import give_permission_on_document
from cognee.shared.utils import send_telemetry
from .get_dlt_destination import get_dlt_destination
from .save_data_item_to_storage import (
save_data_item_to_storage,
)
from .save_data_item_to_storage import save_data_item_to_storage
from typing import Union, BinaryIO
import inspect
@ -48,7 +45,7 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
"owner_id": str(user.id),
}
async def data_storing(data: Any, dataset_name: str, user: User):
async def store_data_to_dataset(data: Any, dataset_name: str, user: User):
if not isinstance(data, list):
# Convert data to a list as we work with lists further down.
data = [data]
@ -124,18 +121,16 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
await give_permission_on_document(user, data_id, "write")
return file_paths
send_telemetry("cognee.add EXECUTION STARTED", user_id=user.id)
db_engine = get_relational_engine()
file_paths = await data_storing(data, dataset_name, user)
file_paths = await store_data_to_dataset(data, dataset_name, user)
# Note: DLT pipeline has its own event loop, therefore objects created in another event loop
# can't be used inside the pipeline
if db_engine.engine.dialect.name == "sqlite":
# To use sqlite with dlt dataset_name must be set to "main".
# Sqlite doesn't support schemas
run_info = pipeline.run(
pipeline.run(
data_resources(file_paths, user),
table_name="file_metadata",
dataset_name="main",
@ -143,13 +138,15 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
)
else:
# Data should be stored in the same schema to allow deduplication
run_info = pipeline.run(
pipeline.run(
data_resources(file_paths, user),
table_name="file_metadata",
dataset_name="public",
write_disposition="merge",
)
send_telemetry("cognee.add EXECUTION COMPLETED", user_id=user.id)
datasets = await get_datasets_by_name(dataset_name, user.id)
dataset = datasets[0]
data_documents = await get_dataset_data(dataset_id=dataset.id)
return run_info
return data_documents
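
`ingest_data` now returns the stored `Data` documents for the dataset instead of the DLT run info, so callers can feed its output straight into downstream tasks. A sketch, assuming an initialized cognee environment and a resolvable default user:

```python
import asyncio

from cognee.modules.users.methods import get_default_user
from cognee.tasks.ingestion import ingest_data


async def main():
    user = await get_default_user()
    # The return value is now the dataset's Data rows rather than the
    # DLT pipeline run info.
    data_documents = await ingest_data("Some text to ingest.", "my_dataset", user)
    for document in data_documents:
        print(document.id)


asyncio.run(main())
```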

View file

@ -1,4 +1,2 @@
from .enrich_dependency_graph import enrich_dependency_graph
from .expand_dependency_graph import expand_dependency_graph
from .get_non_code_files import get_data_list_for_user, get_non_py_files
from .get_non_code_files import get_non_py_files
from .get_repo_file_dependencies import get_repo_file_dependencies

View file

@ -1,136 +0,0 @@
import networkx as nx
from typing import AsyncGenerator, Dict, List
from tqdm.asyncio import tqdm
from cognee.infrastructure.engine import DataPoint
from cognee.shared.CodeGraphEntities import CodeFile
from cognee.modules.graph.utils import get_graph_from_model, convert_node_to_data_point
from cognee.infrastructure.databases.graph import get_graph_engine
def topologically_sort_subgraph(
subgraph_node_to_indegree: Dict[str, int], graph: nx.DiGraph
) -> List[str]:
"""Performs a topological sort on a subgraph based on node indegrees."""
results = []
remaining_nodes = subgraph_node_to_indegree.copy()
while remaining_nodes:
next_node = min(remaining_nodes, key=remaining_nodes.get)
results.append(next_node)
for successor in graph.successors(next_node):
if successor in remaining_nodes:
remaining_nodes[successor] -= 1
remaining_nodes.pop(next_node)
return results
def topologically_sort(graph: nx.DiGraph) -> List[str]:
"""Performs a topological sort on the entire graph."""
subgraphs = (graph.subgraph(c).copy() for c in nx.weakly_connected_components(graph))
topological_order = []
for subgraph in subgraphs:
node_to_indegree = {node: len(list(subgraph.successors(node))) for node in subgraph.nodes}
topological_order.extend(topologically_sort_subgraph(node_to_indegree, subgraph))
return topological_order
async def node_enrich_and_connect(
graph: nx.MultiDiGraph,
topological_order: List[str],
node: CodeFile,
data_points_map: Dict[str, DataPoint],
) -> None:
"""Adds 'depends_on' edges to the graph based on topological order."""
topological_rank = topological_order.index(node.id)
node.topological_rank = topological_rank
node_descendants = nx.descendants(graph, node.id)
if graph.has_edge(node.id, node.id):
node_descendants.add(node.id)
new_connections = []
graph_engine = await get_graph_engine()
for desc_id in node_descendants:
if desc_id not in topological_order[: topological_rank + 1]:
continue
desc = None
if desc_id in data_points_map:
desc = data_points_map[desc_id]
else:
node_data = await graph_engine.extract_node(str(desc_id))
try:
desc = convert_node_to_data_point(node_data)
except Exception:
pass
if desc is not None:
new_connections.append(desc)
node.depends_directly_on = node.depends_directly_on or []
node.depends_directly_on.extend(new_connections)
async def enrich_dependency_graph(
data_points: list[DataPoint],
) -> AsyncGenerator[list[DataPoint], None]:
"""Enriches the graph with topological ranks and 'depends_on' edges."""
nodes = []
edges = []
added_nodes = {}
added_edges = {}
visited_properties = {}
for data_point in data_points:
graph_nodes, graph_edges = await get_graph_from_model(
data_point,
added_nodes=added_nodes,
added_edges=added_edges,
visited_properties=visited_properties,
)
nodes.extend(graph_nodes)
edges.extend(graph_edges)
graph = nx.MultiDiGraph()
simple_nodes = [(node.id, node.model_dump()) for node in nodes]
graph.add_nodes_from(simple_nodes)
graph.add_edges_from(edges)
topological_order = topologically_sort(graph)
node_rank_map = {node: idx for idx, node in enumerate(topological_order)}
# for node_id, node in tqdm(graph.nodes(data = True), desc = "Enriching dependency graph", unit = "node"):
# if node_id not in node_rank_map:
# continue
# data_points.append(node_enrich_and_connect(graph, topological_order, node))
data_points_map = {data_point.id: data_point for data_point in data_points}
# data_points_futures = []
for data_point in tqdm(data_points, desc="Enriching dependency graph", unit="data_point"):
if data_point.id not in node_rank_map:
continue
if isinstance(data_point, CodeFile):
# data_points_futures.append(node_enrich_and_connect(graph, topological_order, data_point, data_points_map))
await node_enrich_and_connect(graph, topological_order, data_point, data_points_map)
yield data_point
# await asyncio.gather(*data_points_futures)
# return data_points

View file

@ -1,74 +0,0 @@
from typing import AsyncGenerator
from uuid import NAMESPACE_OID, uuid5
# from tqdm import tqdm
from cognee.infrastructure.engine import DataPoint
from cognee.shared.CodeGraphEntities import CodeFile, CodePart
from cognee.tasks.repo_processor.extract_code_parts import extract_code_parts
from cognee.shared.logging_utils import get_logger
logger = get_logger()
def _add_code_parts_nodes_and_edges(code_file: CodeFile, part_type, code_parts) -> None:
"""Add code part nodes and edges for a specific part type."""
if not code_parts:
logger.debug(f"No code parts to add for node {code_file.id} and part_type {part_type}.")
return
part_nodes = []
for idx, code_part in enumerate(code_parts):
if not code_part.strip():
logger.warning(f"Empty code part in node {code_file.id} and part_type {part_type}.")
continue
part_node_id = uuid5(NAMESPACE_OID, f"{code_file.id}_{part_type}_{idx}")
part_nodes.append(
CodePart(
id=part_node_id,
type=part_type,
# part_of = code_file,
file_path=code_file.file_path[len(code_file.part_of.path) + 1 :],
source_code=code_part,
)
)
# graph.add_node(part_node_id, source_code=code_part, node_type=part_type)
# graph.add_edge(parent_node_id, part_node_id, relation="contains")
code_file.contains = code_file.contains or []
code_file.contains.extend(part_nodes)
def _process_single_node(code_file: CodeFile) -> None:
"""Process a single Python file node."""
node_id = code_file.id
source_code = code_file.source_code
if not source_code.strip():
logger.warning(f"Node {node_id} has no or empty 'source_code'. Skipping.")
return
try:
code_parts_dict = extract_code_parts(source_code)
except Exception as e:
logger.error(f"Error processing node {node_id}: {e}")
return
for part_type, code_parts in code_parts_dict.items():
_add_code_parts_nodes_and_edges(code_file, part_type, code_parts)
async def expand_dependency_graph(
data_points: list[DataPoint],
) -> AsyncGenerator[list[DataPoint], None]:
"""Process Python file nodes, adding code part nodes and edges."""
# for data_point in tqdm(data_points, desc = "Expand dependency graph", unit = "data_point"):
for data_point in data_points:
if isinstance(data_point, CodeFile):
_process_single_node(data_point)
yield data_point
# return data_points

View file

@ -1,61 +0,0 @@
from typing import Dict, List
from cognee.shared.logging_utils import get_logger, ERROR
logger = get_logger(level=ERROR)
def _extract_parts_from_module(module, parts_dict: Dict[str, List[str]]) -> Dict[str, List[str]]:
"""Extract code parts from a parsed module."""
current_top_level_code = []
child_to_code_type = {
"classdef": "classes",
"funcdef": "functions",
"import_name": "imports",
"import_from": "imports",
}
for child in module.children:
if child.type == "simple_stmt":
current_top_level_code.append(child.get_code())
continue
if current_top_level_code:
parts_dict["top_level_code"].append("\n".join(current_top_level_code))
current_top_level_code = []
if child.type in child_to_code_type:
code_type = child_to_code_type[child.type]
parts_dict[code_type].append(child.get_code())
if current_top_level_code:
parts_dict["top_level_code"].append("\n".join(current_top_level_code))
if parts_dict["imports"]:
parts_dict["imports"] = ["\n".join(parts_dict["imports"])]
return parts_dict
def extract_code_parts(source_code: str) -> Dict[str, List[str]]:
"""Extract high-level parts of the source code."""
parts_dict = {"classes": [], "functions": [], "imports": [], "top_level_code": []}
if not source_code.strip():
logger.warning("Empty source_code provided.")
return parts_dict
try:
import parso
module = parso.parse(source_code)
except Exception as e:
logger.error(f"Error parsing source code: {e}")
return parts_dict
if not module.children:
logger.warning("Parsed module has no children (empty or invalid source code).")
return parts_dict
return _extract_parts_from_module(module, parts_dict)

View file

@ -1,17 +1,5 @@
import os
import aiofiles
import cognee.modules.ingestion as ingestion
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.methods import get_datasets
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
from cognee.modules.data.methods.get_datasets_by_name import get_datasets_by_name
from cognee.modules.data.models import Data
from cognee.modules.ingestion.data_types import BinaryData
from cognee.modules.users.methods import get_default_user
from cognee.shared.CodeGraphEntities import Repository
async def get_non_py_files(repo_path):
"""Get files that are not .py files and their contents"""
@ -135,15 +123,3 @@ async def get_non_py_files(repo_path):
if not file.endswith(".py") and should_process(os.path.join(root, file))
]
return non_py_files_paths
async def get_data_list_for_user(_, dataset_name, user):
# Note: This method is meant to be used as a Task in a pipeline.
# By the nature of pipelines, the output of the previous Task will be passed as the first argument here,
# but it is not needed here, hence the "_" input.
datasets = await get_datasets_by_name(dataset_name, user.id)
data_documents: list[Data] = []
for dataset in datasets:
data_docs: list[Data] = await get_dataset_data(dataset_id=dataset.id)
data_documents.extend(data_docs)
return data_documents

View file

@ -1,171 +0,0 @@
from cognee.shared.logging_utils import get_logger
from typing import AsyncGenerator, Generator
from uuid import NAMESPACE_OID, uuid5
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.engine import DataPoint
from cognee.shared.CodeGraphEntities import CodeFile, CodePart, SourceCodeChunk
from cognee.infrastructure.llm import get_max_chunk_tokens
logger = get_logger()
def _get_naive_subchunk_token_counts(
source_code: str, max_subchunk_tokens
) -> list[tuple[str, int]]:
"""Splits source code into subchunks of up to max_subchunk_tokens and counts tokens."""
tokenizer = get_vector_engine().embedding_engine.tokenizer
token_ids = tokenizer.extract_tokens(source_code)
subchunk_token_counts = []
for start_idx in range(0, len(token_ids), max_subchunk_tokens):
subchunk_token_ids = token_ids[start_idx : start_idx + max_subchunk_tokens]
token_count = len(subchunk_token_ids)
# Note: This can't work with Gemini embeddings as they keep their method of encoding text
# to tokens hidden and don't offer a decoder
# TODO: Add support for different tokenizers for this function
subchunk = "".join(
tokenizer.decode_single_token(token_id) for token_id in subchunk_token_ids
)
subchunk_token_counts.append((subchunk, token_count))
return subchunk_token_counts
def _get_subchunk_token_counts(
source_code: str,
max_subchunk_tokens,
depth: int = 0,
max_depth: int = 100,
) -> list[tuple[str, int]]:
"""Splits source code into subchunk and counts tokens for each subchunk."""
if depth > max_depth:
return _get_naive_subchunk_token_counts(source_code, max_subchunk_tokens)
try:
import parso
module = parso.parse(source_code)
except Exception as e:
logger.error(f"Error parsing source code: {e}")
return []
if not module.children:
logger.warning("Parsed module has no children (empty or invalid source code).")
return []
# Handle cases with only one real child and an EndMarker to prevent infinite recursion.
if len(module.children) <= 2:
module = module.children[0]
subchunk_token_counts = []
for child in module.children:
subchunk = child.get_code()
tokenizer = get_vector_engine().embedding_engine.tokenizer
token_count = tokenizer.count_tokens(subchunk)
if token_count == 0:
continue
if token_count <= max_subchunk_tokens:
subchunk_token_counts.append((subchunk, token_count))
continue
if child.type == "string":
subchunk_token_counts.extend(
_get_naive_subchunk_token_counts(subchunk, max_subchunk_tokens)
)
continue
subchunk_token_counts.extend(
_get_subchunk_token_counts(
subchunk, max_subchunk_tokens, depth=depth + 1, max_depth=max_depth
)
)
return subchunk_token_counts
def _get_chunk_source_code(
code_token_counts: list[tuple[str, int]], overlap: float
) -> tuple[list[tuple[str, int]], str]:
"""Generates a chunk of source code from tokenized subchunks with overlap handling."""
current_count = 0
cumulative_counts = []
current_source_code = ""
for i, (child_code, token_count) in enumerate(code_token_counts):
current_count += token_count
cumulative_counts.append(current_count)
if current_count > get_max_chunk_tokens():
break
current_source_code += f"\n{child_code}"
if current_count <= get_max_chunk_tokens():
return [], current_source_code.strip()
cutoff = 1
for i, cum_count in enumerate(cumulative_counts):
if cum_count > (1 - overlap) * get_max_chunk_tokens():
break
cutoff = i
return code_token_counts[cutoff:], current_source_code.strip()
def get_source_code_chunks_from_code_part(
code_file_part: CodePart,
overlap: float = 0.25,
granularity: float = 0.09,
) -> Generator[SourceCodeChunk, None, None]:
"""Yields source code chunks from a CodePart object, with configurable token limits and overlap."""
if not code_file_part.source_code:
logger.error(f"No source code in CodeFile {code_file_part.id}")
return
max_subchunk_tokens = max(1, int(granularity * get_max_chunk_tokens()))
subchunk_token_counts = _get_subchunk_token_counts(
code_file_part.source_code, max_subchunk_tokens
)
previous_chunk = None
while subchunk_token_counts:
subchunk_token_counts, chunk_source_code = _get_chunk_source_code(
subchunk_token_counts, overlap
)
if not chunk_source_code:
continue
current_chunk = SourceCodeChunk(
id=uuid5(NAMESPACE_OID, chunk_source_code),
code_chunk_of=code_file_part,
source_code=chunk_source_code,
previous_chunk=previous_chunk,
)
yield current_chunk
previous_chunk = current_chunk
async def get_source_code_chunks(
data_points: list[DataPoint],
) -> AsyncGenerator[list[DataPoint], None]:
"""Processes code graph datapoints, create SourceCodeChink datapoints."""
for data_point in data_points:
try:
yield data_point
if not isinstance(data_point, CodeFile):
continue
if not data_point.contains:
logger.warning(f"CodeFile {data_point.id} contains no code parts")
continue
for code_part in data_point.contains:
try:
yield code_part
for source_code_chunk in get_source_code_chunks_from_code_part(code_part):
yield source_code_chunk
except Exception as e:
logger.error(f"Error processing code part: {e}")
raise e
except Exception as e:
logger.error(f"Error processing data point: {e}")
raise e

View file

@ -1,183 +0,0 @@
import os
from tqdm import tqdm
from cognee.shared.logging_utils import get_logger
logger = get_logger()
_NODE_TYPE_MAP = {
"funcdef": "func_def",
"classdef": "class_def",
"async_funcdef": "async_func_def",
"async_stmt": "async_func_def",
"simple_stmt": "var_def",
}
def _create_object_dict(name_node, type_name=None):
return {
"name": name_node.value,
"line": name_node.start_pos[0],
"column": name_node.start_pos[1],
"type": type_name,
}
def _parse_node(node):
"""Parse a node to extract importable object details, including async functions and classes."""
node_type = _NODE_TYPE_MAP.get(node.type)
if node.type in {"funcdef", "classdef", "async_funcdef"}:
return [_create_object_dict(node.name, type_name=node_type)]
if node.type == "async_stmt" and len(node.children) > 1:
function_node = node.children[1]
if function_node.type == "funcdef":
return [
_create_object_dict(
function_node.name, type_name=_NODE_TYPE_MAP.get(function_node.type)
)
]
if node.type == "simple_stmt":
# TODO: Handle multi-level/nested unpacking variable definitions in the future
expr_child = node.children[0]
if expr_child.type != "expr_stmt":
return []
if expr_child.children[0].type == "testlist_star_expr":
name_targets = expr_child.children[0].children
else:
name_targets = expr_child.children
return [
_create_object_dict(target, type_name=_NODE_TYPE_MAP.get(target.type))
for target in name_targets
if target.type == "name"
]
return []
def extract_importable_objects_with_positions_from_source_code(source_code):
"""Extract top-level objects in a Python source code string with their positions (line/column)."""
try:
import parso
tree = parso.parse(source_code)
except Exception as e:
logger.error(f"Error parsing source code: {e}")
return []
importable_objects = []
try:
for node in tree.children:
importable_objects.extend(_parse_node(node))
except Exception as e:
logger.error(f"Error extracting nodes from parsed tree: {e}")
return []
return importable_objects
def extract_importable_objects_with_positions(file_path):
"""Extract top-level objects in a Python file with their positions (line/column)."""
try:
with open(file_path, "r") as file:
source_code = file.read()
except Exception as e:
logger.error(f"Error reading file {file_path}: {e}")
return []
return extract_importable_objects_with_positions_from_source_code(source_code)
def find_entity_usages(script, line, column):
"""
Return a list of files in the repo where the entity at module_path:line,column is used.
"""
usages = set()
try:
inferred = script.infer(line, column)
except Exception as e:
logger.error(f"Error inferring entity at {script.path}:{line},{column}: {e}")
return []
if not inferred or not inferred[0]:
logger.info(f"No entity inferred at {script.path}:{line},{column}")
return []
logger.debug(f"Inferred entity: {inferred[0].name}, type: {inferred[0].type}")
try:
references = script.get_references(
line=line, column=column, scope="project", include_builtins=False
)
except Exception as e:
logger.error(
f"Error retrieving references for entity at {script.path}:{line},{column}: {e}"
)
references = []
for ref in references:
if ref.module_path: # Collect unique module paths
usages.add(ref.module_path)
logger.info(f"Entity used in: {ref.module_path}")
return list(usages)
def parse_file_with_references(project, file_path):
"""Parse a file to extract object names and their references within a project."""
try:
importable_objects = extract_importable_objects_with_positions(file_path)
except Exception as e:
logger.error(f"Error extracting objects from {file_path}: {e}")
return []
if not os.path.isfile(file_path):
logger.warning(f"Module file does not exist: {file_path}")
return []
try:
import jedi
script = jedi.Script(path=file_path, project=project)
except Exception as e:
logger.error(f"Error initializing Jedi Script: {e}")
return []
parsed_results = [
{
"name": obj["name"],
"type": obj["type"],
"references": find_entity_usages(script, obj["line"], obj["column"]),
}
for obj in importable_objects
]
return parsed_results
def parse_repo(repo_path):
"""Parse a repository to extract object names, types, and references for all Python files."""
try:
import jedi
project = jedi.Project(path=repo_path)
except Exception as e:
logger.error(f"Error creating Jedi project for repository at {repo_path}: {e}")
return {}
EXCLUDE_DIRS = {"venv", ".git", "__pycache__", "build"}
python_files = [
os.path.join(directory, file)
for directory, _, filenames in os.walk(repo_path)
if not any(excluded in directory for excluded in EXCLUDE_DIRS)
for file in filenames
if file.endswith(".py") and os.path.getsize(os.path.join(directory, file)) > 0
]
results = {
file_path: parse_file_with_references(project, file_path)
for file_path in tqdm(python_files)
}
return results

View file

@ -4,8 +4,9 @@ import pathlib
import cognee
from cognee.modules.data.models import Data
from cognee.modules.search.types import SearchType
from cognee.modules.users.methods import get_default_user
from cognee.modules.search.types import SearchType
from cognee.modules.search.operations import get_history
logger = get_logger()
@ -151,7 +152,8 @@ async def main():
for result in search_results:
print(f"{result}\n")
history = await cognee.get_search_history()
user = await get_default_user()
history = await get_history(user.id)
assert len(history) == 8, "Search history is not correct."
await cognee.prune.prune_data()

View file

@ -1,7 +1,9 @@
import os
from cognee.shared.logging_utils import get_logger
import pathlib
import cognee
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
from cognee.shared.utils import render_graph
from cognee.low_level import DataPoint
@ -93,7 +95,8 @@ async def main():
for chunk in chunks:
print(chunk)
history = await cognee.get_search_history()
user = await get_default_user()
history = await get_history(user.id)
assert len(history) == 8, "Search history is not correct."

View file

@ -1,7 +1,9 @@
import os
from cognee.shared.logging_utils import get_logger
import pathlib
import cognee
import pathlib
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
# from cognee.shared.utils import render_graph
@ -72,7 +74,8 @@ async def main():
for result in search_results:
print(f"{result}\n")
history = await cognee.get_search_history()
user = await get_default_user()
history = await get_history(user.id)
assert len(history) == 6, "Search history is not correct."

View file

@ -1,12 +1,11 @@
import os
from cognee.shared.logging_utils import get_logger
import pathlib
import cognee
import shutil
import cognee
import pathlib
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from cognee.infrastructure.engine import DataPoint
from uuid import uuid4
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
logger = get_logger()
@ -81,7 +80,8 @@ async def main():
for result in search_results:
print(f"{result}\n")
history = await cognee.get_search_history()
user = await get_default_user()
history = await get_history(user.id)
assert len(history) == 6, "Search history is not correct."
await cognee.prune.prune_data()

View file

@ -1,7 +1,9 @@
import os
from cognee.shared.logging_utils import get_logger
import pathlib
import cognee
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
logger = get_logger()
@ -69,7 +71,8 @@ async def main():
for result in search_results:
print(f"{result}\n")
history = await cognee.get_search_history()
user = await get_default_user()
history = await get_history(user.id)
assert len(history) == 6, "Search history is not correct."

View file

@ -1,7 +1,9 @@
import os
from cognee.shared.logging_utils import get_logger
import pathlib
import cognee
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
logger = get_logger()
@ -80,7 +82,8 @@ async def main():
for result in search_results:
print(f"{result}\n")
history = await cognee.get_search_history()
user = await get_default_user()
history = await get_history(user.id)
assert len(history) == 6, "Search history is not correct."
await cognee.prune.prune_data()

View file

@ -1,9 +1,10 @@
import os
from cognee.shared.logging_utils import get_logger
import pathlib
import cognee
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
logger = get_logger()
@ -73,13 +74,11 @@ async def main():
for result in search_results:
print(f"{result}\n")
history = await cognee.get_search_history()
user = await get_default_user()
history = await get_history(user.id)
assert len(history) == 6, "Search history is not correct."
results = await brute_force_triplet_search("What is a quantum computer?")
assert len(results) > 0
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"

View file

@ -1,11 +1,10 @@
import os
from cognee.shared.logging_utils import get_logger
import pathlib
import cognee
from cognee.modules.search.operations import get_history
from cognee.shared.logging_utils import get_logger
from cognee.modules.data.models import Data
from cognee.modules.search.types import SearchType
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.users.methods import get_default_user
logger = get_logger()
@ -159,12 +158,10 @@ async def main():
for result in search_results:
print(f"{result}\n")
history = await cognee.get_search_history()
user = await get_default_user()
history = await get_history(user.id)
assert len(history) == 8, "Search history is not correct."
results = await brute_force_triplet_search("What is a quantum computer?")
assert len(results) > 0
await test_local_file_deletion(text, explanation_file_path)
await cognee.prune.prune_data()

View file

@ -1,9 +1,10 @@
import os
from cognee.shared.logging_utils import get_logger
import pathlib
import cognee
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
logger = get_logger()
@ -73,12 +74,10 @@ async def main():
for result in search_results:
print(f"{result}\n")
history = await cognee.get_search_history()
user = await get_default_user()
history = await get_history(user.id)
assert len(history) == 6, "Search history is not correct."
results = await brute_force_triplet_search("What is a quantum computer?")
assert len(results) > 0
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"

View file

@ -1,9 +1,10 @@
import os
from cognee.shared.logging_utils import get_logger
import pathlib
import cognee
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
logger = get_logger()
@ -73,12 +74,10 @@ async def main():
for result in search_results:
print(f"{result}\n")
history = await cognee.get_search_history()
user = await get_default_user()
history = await get_history(user.id)
assert len(history) == 6, "Search history is not correct."
results = await brute_force_triplet_search("What is a quantum computer?")
assert len(results) > 0
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"

View file

@ -1,6 +1,7 @@
import cognee
import asyncio
from cognee.shared.logging_utils import get_logger, ERROR
from cognee.modules.metrics.operations import get_pipeline_run_metrics
from cognee.api.v1.search import SearchType
@ -184,7 +185,7 @@ async def main(enable_steps):
# Step 4: Calculate descriptive metrics
if enable_steps.get("graph_metrics"):
await cognee.get_pipeline_run_metrics(pipeline_run, include_optional=True)
await get_pipeline_run_metrics(pipeline_run, include_optional=True)
print("Descriptive graph metrics saved to database.")
# Step 5: Query insights

View file

@ -1,5 +1,6 @@
import cognee
import asyncio
from cognee.modules.metrics.operations import get_pipeline_run_metrics
from cognee.shared.logging_utils import get_logger
import os
@ -64,7 +65,7 @@ async def main():
print("Knowledge with ontology created.")
# Step 4: Calculate descriptive metrics
await cognee.get_pipeline_run_metrics(pipeline_run, include_optional=True)
await get_pipeline_run_metrics(pipeline_run, include_optional=True)
print("Descriptive graph metrics saved to database.")
# Step 5: Query insights
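
Both examples above now import `get_pipeline_run_metrics` from `cognee.modules.metrics.operations` rather than calling it through the `cognee` package. A condensed sketch of the updated pattern, where `pipeline_run` stands for the run handle returned by the preceding cognify step:

```python
from cognee.modules.metrics.operations import get_pipeline_run_metrics


async def save_graph_metrics(pipeline_run):
    # include_optional=True also computes the optional descriptive metrics,
    # matching the updated example calls above.
    await get_pipeline_run_metrics(pipeline_run, include_optional=True)
    print("Descriptive graph metrics saved to database.")
```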