Merge branch 'dev' into llama-index-notebook

This commit is contained in:
Igor Ilic 2025-01-10 17:26:05 +01:00 committed by GitHub
commit a5c91e8f0e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
41 changed files with 459 additions and 202 deletions

View file

@ -10,7 +10,7 @@ repos:
- id: check-added-large-files - id: check-added-large-files
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version. # Ruff version.
rev: v0.8.3 rev: v0.9.0
hooks: hooks:
# Run the linter. # Run the linter.
- id: ruff - id: ruff

View file

@ -3,7 +3,6 @@ import logging
from pathlib import Path from pathlib import Path
from cognee.base_config import get_base_config from cognee.base_config import get_base_config
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
from cognee.modules.cognify.config import get_cognify_config from cognee.modules.cognify.config import get_cognify_config
from cognee.modules.pipelines import run_tasks from cognee.modules.pipelines import run_tasks
from cognee.modules.pipelines.tasks.Task import Task from cognee.modules.pipelines.tasks.Task import Task
@ -54,8 +53,6 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
await cognee.prune.prune_system(metadata=True) await cognee.prune.prune_system(metadata=True)
await create_db_and_tables() await create_db_and_tables()
embedding_engine = get_embedding_engine()
cognee_config = get_cognify_config() cognee_config = get_cognify_config()
user = await get_default_user() user = await get_default_user()
@ -63,11 +60,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
Task(get_repo_file_dependencies), Task(get_repo_file_dependencies),
Task(enrich_dependency_graph), Task(enrich_dependency_graph),
Task(expand_dependency_graph, task_config={"batch_size": 50}), Task(expand_dependency_graph, task_config={"batch_size": 50}),
Task( Task(get_source_code_chunks, task_config={"batch_size": 50}),
get_source_code_chunks,
embedding_model=embedding_engine.model,
task_config={"batch_size": 50},
),
Task(summarize_code, task_config={"batch_size": 50}), Task(summarize_code, task_config={"batch_size": 50}),
Task(add_data_points, task_config={"batch_size": 50}), Task(add_data_points, task_config={"batch_size": 50}),
] ]
@ -78,7 +71,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
Task(ingest_data_with_metadata, dataset_name="repo_docs", user=user), Task(ingest_data_with_metadata, dataset_name="repo_docs", user=user),
Task(get_data_list_for_user, dataset_name="repo_docs", user=user), Task(get_data_list_for_user, dataset_name="repo_docs", user=user),
Task(classify_documents), Task(classify_documents),
Task(extract_chunks_from_documents), Task(extract_chunks_from_documents, max_tokens=cognee_config.max_tokens),
Task( Task(
extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50} extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50}
), ),

View file

@ -493,7 +493,7 @@ class Neo4jAdapter(GraphDBInterface):
query_edges = f""" query_edges = f"""
MATCH (n)-[r]->(m) MATCH (n)-[r]->(m)
WHERE {where_clause} AND {where_clause.replace('n.', 'm.')} WHERE {where_clause} AND {where_clause.replace("n.", "m.")}
RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
""" """
result_edges = await self.query(query_edges) result_edges = await self.query(query_edges)

View file

@ -1,3 +1,6 @@
I need you to solve this issue by looking at the provided edges retrieved from a knowledge graph and You are a senior software engineer. I need you to solve this issue by looking at the provided context and
generate a single patch file that I can apply directly to this repository using git apply. generate a single patch file that I can apply directly to this repository using git apply.
Please respond with a single patch file in the following format. Additionally, please make sure that you provide code only with correct syntax and
you apply the patch on the relevant files (together with their path that you can try to find out from the github issue). Don't change the names of existing
functions or classes, as they may be referenced from other code.
Please respond only with a single patch file in the following format without adding any additional context or string.

View file

@ -1,7 +1,9 @@
from uuid import uuid5, NAMESPACE_OID from typing import Optional
from uuid import NAMESPACE_OID, uuid5
from cognee.tasks.chunks import chunk_by_paragraph
from .models.DocumentChunk import DocumentChunk from .models.DocumentChunk import DocumentChunk
from cognee.tasks.chunks import chunk_by_paragraph
class TextChunker: class TextChunker:
@ -10,23 +12,36 @@ class TextChunker:
chunk_index = 0 chunk_index = 0
chunk_size = 0 chunk_size = 0
token_count = 0
def __init__(self, document, get_text: callable, chunk_size: int = 1024): def __init__(
self, document, get_text: callable, max_tokens: Optional[int] = None, chunk_size: int = 1024
):
self.document = document self.document = document
self.max_chunk_size = chunk_size self.max_chunk_size = chunk_size
self.get_text = get_text self.get_text = get_text
self.max_tokens = max_tokens if max_tokens else float("inf")
def check_word_count_and_token_count(self, word_count_before, token_count_before, chunk_data):
word_count_fits = word_count_before + chunk_data["word_count"] <= self.max_chunk_size
token_count_fits = token_count_before + chunk_data["token_count"] <= self.max_tokens
return word_count_fits and token_count_fits
def read(self): def read(self):
paragraph_chunks = [] paragraph_chunks = []
for content_text in self.get_text(): for content_text in self.get_text():
for chunk_data in chunk_by_paragraph( for chunk_data in chunk_by_paragraph(
content_text, content_text,
self.max_tokens,
self.max_chunk_size, self.max_chunk_size,
batch_paragraphs=True, batch_paragraphs=True,
): ):
if self.chunk_size + chunk_data["word_count"] <= self.max_chunk_size: if self.check_word_count_and_token_count(
self.chunk_size, self.token_count, chunk_data
):
paragraph_chunks.append(chunk_data) paragraph_chunks.append(chunk_data)
self.chunk_size += chunk_data["word_count"] self.chunk_size += chunk_data["word_count"]
self.token_count += chunk_data["token_count"]
else: else:
if len(paragraph_chunks) == 0: if len(paragraph_chunks) == 0:
yield DocumentChunk( yield DocumentChunk(
@ -66,6 +81,7 @@ class TextChunker:
print(e) print(e)
paragraph_chunks = [chunk_data] paragraph_chunks = [chunk_data]
self.chunk_size = chunk_data["word_count"] self.chunk_size = chunk_data["word_count"]
self.token_count = chunk_data["token_count"]
self.chunk_index += 1 self.chunk_index += 1

View file

@ -12,6 +12,7 @@ class DocumentChunk(DataPoint):
chunk_index: int chunk_index: int
cut_type: str cut_type: str
is_part_of: Document is_part_of: Document
pydantic_type: str = "DocumentChunk"
contains: List[Entity] = None contains: List[Entity] = None
_metadata: dict = {"index_fields": ["text"], "type": "DocumentChunk"} _metadata: dict = {"index_fields": ["text"], "type": "DocumentChunk"}

View file

@ -1,12 +1,14 @@
from functools import lru_cache from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
from cognee.shared.data_models import DefaultContentPrediction, SummarizedContent from cognee.shared.data_models import DefaultContentPrediction, SummarizedContent
from typing import Optional
import os
class CognifyConfig(BaseSettings): class CognifyConfig(BaseSettings):
classification_model: object = DefaultContentPrediction classification_model: object = DefaultContentPrediction
summarization_model: object = SummarizedContent summarization_model: object = SummarizedContent
max_tokens: Optional[int] = os.getenv("MAX_TOKENS")
model_config = SettingsConfigDict(env_file=".env", extra="allow") model_config = SettingsConfigDict(env_file=".env", extra="allow")
def to_dict(self) -> dict: def to_dict(self) -> dict:

View file

@ -1,6 +1,9 @@
from typing import Optional
from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.get_llm_client import get_llm_client
from .Document import Document
from .ChunkerMapping import ChunkerConfig from .ChunkerMapping import ChunkerConfig
from .Document import Document
class AudioDocument(Document): class AudioDocument(Document):
@ -10,12 +13,14 @@ class AudioDocument(Document):
result = get_llm_client().create_transcript(self.raw_data_location) result = get_llm_client().create_transcript(self.raw_data_location)
return result.text return result.text
def read(self, chunk_size: int, chunker: str): def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
# Transcribe the audio file # Transcribe the audio file
text = self.create_transcript() text = self.create_transcript()
chunker_func = ChunkerConfig.get_chunker(chunker) chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(self, chunk_size=chunk_size, get_text=lambda: [text]) chunker = chunker_func(
self, chunk_size=chunk_size, get_text=lambda: [text], max_tokens=max_tokens
)
yield from chunker.read() yield from chunker.read()

View file

@ -1,3 +1,4 @@
from typing import Optional
from uuid import UUID from uuid import UUID
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
@ -10,5 +11,5 @@ class Document(DataPoint):
mime_type: str mime_type: str
_metadata: dict = {"index_fields": ["name"], "type": "Document"} _metadata: dict = {"index_fields": ["name"], "type": "Document"}
def read(self, chunk_size: int, chunker=str) -> str: def read(self, chunk_size: int, chunker=str, max_tokens: Optional[int] = None) -> str:
pass pass

View file

@ -1,6 +1,9 @@
from typing import Optional
from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.get_llm_client import get_llm_client
from .Document import Document
from .ChunkerMapping import ChunkerConfig from .ChunkerMapping import ChunkerConfig
from .Document import Document
class ImageDocument(Document): class ImageDocument(Document):
@ -10,11 +13,13 @@ class ImageDocument(Document):
result = get_llm_client().transcribe_image(self.raw_data_location) result = get_llm_client().transcribe_image(self.raw_data_location)
return result.choices[0].message.content return result.choices[0].message.content
def read(self, chunk_size: int, chunker: str): def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
# Transcribe the image file # Transcribe the image file
text = self.transcribe_image() text = self.transcribe_image()
chunker_func = ChunkerConfig.get_chunker(chunker) chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(self, chunk_size=chunk_size, get_text=lambda: [text]) chunker = chunker_func(
self, chunk_size=chunk_size, get_text=lambda: [text], max_tokens=max_tokens
)
yield from chunker.read() yield from chunker.read()

View file

@ -1,12 +1,15 @@
from typing import Optional
from pypdf import PdfReader from pypdf import PdfReader
from .Document import Document
from .ChunkerMapping import ChunkerConfig from .ChunkerMapping import ChunkerConfig
from .Document import Document
class PdfDocument(Document): class PdfDocument(Document):
type: str = "pdf" type: str = "pdf"
def read(self, chunk_size: int, chunker: str): def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
file = PdfReader(self.raw_data_location) file = PdfReader(self.raw_data_location)
def get_text(): def get_text():
@ -15,7 +18,9 @@ class PdfDocument(Document):
yield page_text yield page_text
chunker_func = ChunkerConfig.get_chunker(chunker) chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(self, chunk_size=chunk_size, get_text=get_text) chunker = chunker_func(
self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens
)
yield from chunker.read() yield from chunker.read()

View file

@ -1,11 +1,13 @@
from .Document import Document from typing import Optional
from .ChunkerMapping import ChunkerConfig from .ChunkerMapping import ChunkerConfig
from .Document import Document
class TextDocument(Document): class TextDocument(Document):
type: str = "text" type: str = "text"
def read(self, chunk_size: int, chunker: str): def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
def get_text(): def get_text():
with open(self.raw_data_location, mode="r", encoding="utf-8") as file: with open(self.raw_data_location, mode="r", encoding="utf-8") as file:
while True: while True:
@ -18,6 +20,8 @@ class TextDocument(Document):
chunker_func = ChunkerConfig.get_chunker(chunker) chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(self, chunk_size=chunk_size, get_text=get_text) chunker = chunker_func(
self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens
)
yield from chunker.read() yield from chunker.read()

View file

@ -1,14 +1,16 @@
from io import StringIO from io import StringIO
from typing import Optional
from cognee.modules.chunking.TextChunker import TextChunker from cognee.modules.chunking.TextChunker import TextChunker
from .Document import Document
from cognee.modules.data.exceptions import UnstructuredLibraryImportError from cognee.modules.data.exceptions import UnstructuredLibraryImportError
from .Document import Document
class UnstructuredDocument(Document): class UnstructuredDocument(Document):
type: str = "unstructured" type: str = "unstructured"
def read(self, chunk_size: int): def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None) -> str:
def get_text(): def get_text():
try: try:
from unstructured.partition.auto import partition from unstructured.partition.auto import partition
@ -27,6 +29,6 @@ class UnstructuredDocument(Document):
yield text yield text
chunker = TextChunker(self, chunk_size=chunk_size, get_text=get_text) chunker = TextChunker(self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens)
yield from chunker.read() yield from chunker.read()

View file

@ -7,5 +7,6 @@ class Entity(DataPoint):
name: str name: str
is_a: EntityType is_a: EntityType
description: str description: str
pydantic_type: str = "Entity"
_metadata: dict = {"index_fields": ["name"], "type": "Entity"} _metadata: dict = {"index_fields": ["name"], "type": "Entity"}

View file

@ -5,5 +5,6 @@ class EntityType(DataPoint):
__tablename__ = "entity_type" __tablename__ = "entity_type"
name: str name: str
description: str description: str
pydantic_type: str = "EntityType"
_metadata: dict = {"index_fields": ["name"], "type": "EntityType"} _metadata: dict = {"index_fields": ["name"], "type": "EntityType"}

View file

@ -43,7 +43,7 @@ def format_triplets(edges):
edge_info = {key: value for key, value in edge_attributes.items() if value is not None} edge_info = {key: value for key, value in edge_attributes.items() if value is not None}
# Create the formatted triplet # Create the formatted triplet
triplet = f"Node1: {node1_info}\n" f"Edge: {edge_info}\n" f"Node2: {node2_info}\n\n\n" triplet = f"Node1: {node1_info}\nEdge: {edge_info}\nNode2: {node2_info}\n\n\n"
triplets.append(triplet) triplets.append(triplet)
return "".join(triplets) return "".join(triplets)

View file

@ -8,20 +8,27 @@ from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.users.methods import get_default_user from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User from cognee.modules.users.models import User
from cognee.shared.utils import send_telemetry from cognee.shared.utils import send_telemetry
from cognee.api.v1.search import SearchType
from cognee.api.v1.search.search_v2 import search
from cognee.infrastructure.llm.get_llm_client import get_llm_client
async def code_description_to_code_part_search(query: str, user: User = None, top_k=2) -> list: async def code_description_to_code_part_search(
query: str, include_docs=False, user: User = None, top_k=5
) -> list:
if user is None: if user is None:
user = await get_default_user() user = await get_default_user()
if user is None: if user is None:
raise PermissionError("No user found in the system. Please create a user.") raise PermissionError("No user found in the system. Please create a user.")
retrieved_codeparts = await code_description_to_code_part(query, user, top_k) retrieved_codeparts = await code_description_to_code_part(query, user, top_k, include_docs)
return retrieved_codeparts return retrieved_codeparts
async def code_description_to_code_part(query: str, user: User, top_k: int) -> List[str]: async def code_description_to_code_part(
query: str, user: User, top_k: int, include_docs: bool = False
) -> List[str]:
""" """
Maps a code description query to relevant code parts using a CodeGraph pipeline. Maps a code description query to relevant code parts using a CodeGraph pipeline.
@ -29,6 +36,7 @@ async def code_description_to_code_part(query: str, user: User, top_k: int) -> L
query (str): The search query describing the code parts. query (str): The search query describing the code parts.
user (User): The user performing the search. user (User): The user performing the search.
top_k (int): Number of codegraph descriptions to match ( num of corresponding codeparts will be higher) top_k (int): Number of codegraph descriptions to match ( num of corresponding codeparts will be higher)
include_docs(bool): Boolean showing whether we have the docs in the graph or not
Returns: Returns:
Set[str]: A set of unique code parts matching the query. Set[str]: A set of unique code parts matching the query.
@ -55,21 +63,48 @@ async def code_description_to_code_part(query: str, user: User, top_k: int) -> L
) )
try: try:
results = await vector_engine.search("code_summary_text", query_text=query, limit=top_k) if include_docs:
if not results: search_results = await search(SearchType.INSIGHTS, query_text=query)
concatenated_descriptions = " ".join(
obj["description"]
for tpl in search_results
for obj in tpl
if isinstance(obj, dict) and "description" in obj
)
llm_client = get_llm_client()
context_from_documents = await llm_client.acreate_structured_output(
text_input=f"The retrieved context from documents is {concatenated_descriptions}.",
system_prompt="You are a Senior Software Engineer, summarize the context from documents"
f" in a way that it is gonna be provided next to codeparts as context"
f" while trying to solve this github issue connected to the project: {query}]",
response_model=str,
)
code_summaries = await vector_engine.search(
"code_summary_text", query_text=query, limit=top_k
)
if not code_summaries:
logging.warning("No results found for query: '%s' by user: %s", query, user.id) logging.warning("No results found for query: '%s' by user: %s", query, user.id)
return [] return []
memory_fragment = CogneeGraph() memory_fragment = CogneeGraph()
await memory_fragment.project_graph_from_db( await memory_fragment.project_graph_from_db(
graph_engine, graph_engine,
node_properties_to_project=["id", "type", "text", "source_code"], node_properties_to_project=[
"id",
"type",
"text",
"source_code",
"pydantic_type",
],
edge_properties_to_project=["relationship_name"], edge_properties_to_project=["relationship_name"],
) )
code_pieces_to_return = set() code_pieces_to_return = set()
for node in results: for node in code_summaries:
node_id = str(node.id) node_id = str(node.id)
node_to_search_from = memory_fragment.get_node(node_id) node_to_search_from = memory_fragment.get_node(node_id)
@ -78,9 +113,16 @@ async def code_description_to_code_part(query: str, user: User, top_k: int) -> L
continue continue
for code_file in node_to_search_from.get_skeleton_neighbours(): for code_file in node_to_search_from.get_skeleton_neighbours():
for code_file_edge in code_file.get_skeleton_edges(): if code_file.get_attribute("pydantic_type") == "SourceCodeChunk":
if code_file_edge.get_attribute("relationship_name") == "contains": for code_file_edge in code_file.get_skeleton_edges():
code_pieces_to_return.add(code_file_edge.get_destination_node()) if code_file_edge.get_attribute("relationship_name") == "code_chunk_of":
code_pieces_to_return.add(code_file_edge.get_destination_node())
elif code_file.get_attribute("pydantic_type") == "CodePart":
code_pieces_to_return.add(code_file)
elif code_file.get_attribute("pydantic_type") == "CodeFile":
for code_file_edge in code_file.get_skeleton_edges():
if code_file_edge.get_attribute("relationship_name") == "contains":
code_pieces_to_return.add(code_file_edge.get_destination_node())
logging.info( logging.info(
"Search completed for user: %s, query: '%s'. Found %d code pieces.", "Search completed for user: %s, query: '%s'. Found %d code pieces.",
@ -89,7 +131,14 @@ async def code_description_to_code_part(query: str, user: User, top_k: int) -> L
len(code_pieces_to_return), len(code_pieces_to_return),
) )
return list(code_pieces_to_return) context = ""
for code_piece in code_pieces_to_return:
context = context + code_piece.get_attribute("source_code")
if include_docs:
context = context_from_documents + context
return context
except Exception as exec_error: except Exception as exec_error:
logging.error( logging.error(

View file

@ -5,12 +5,14 @@ from cognee.infrastructure.engine import DataPoint
class Repository(DataPoint): class Repository(DataPoint):
__tablename__ = "Repository" __tablename__ = "Repository"
path: str path: str
pydantic_type: str = "Repository"
_metadata: dict = {"index_fields": [], "type": "Repository"} _metadata: dict = {"index_fields": [], "type": "Repository"}
class CodeFile(DataPoint): class CodeFile(DataPoint):
__tablename__ = "codefile" __tablename__ = "codefile"
extracted_id: str # actually file path extracted_id: str # actually file path
pydantic_type: str = "CodeFile"
source_code: Optional[str] = None source_code: Optional[str] = None
part_of: Optional[Repository] = None part_of: Optional[Repository] = None
depends_on: Optional[List["CodeFile"]] = None depends_on: Optional[List["CodeFile"]] = None
@ -22,6 +24,7 @@ class CodeFile(DataPoint):
class CodePart(DataPoint): class CodePart(DataPoint):
__tablename__ = "codepart" __tablename__ = "codepart"
# part_of: Optional[CodeFile] = None # part_of: Optional[CodeFile] = None
pydantic_type: str = "CodePart"
source_code: Optional[str] = None source_code: Optional[str] = None
_metadata: dict = {"index_fields": [], "type": "CodePart"} _metadata: dict = {"index_fields": [], "type": "CodePart"}
@ -30,6 +33,7 @@ class SourceCodeChunk(DataPoint):
__tablename__ = "sourcecodechunk" __tablename__ = "sourcecodechunk"
code_chunk_of: Optional[CodePart] = None code_chunk_of: Optional[CodePart] = None
source_code: Optional[str] = None source_code: Optional[str] = None
pydantic_type: str = "SourceCodeChunk"
previous_chunk: Optional["SourceCodeChunk"] = None previous_chunk: Optional["SourceCodeChunk"] = None
_metadata: dict = {"index_fields": ["source_code"], "type": "SourceCodeChunk"} _metadata: dict = {"index_fields": ["source_code"], "type": "SourceCodeChunk"}

View file

@ -231,6 +231,7 @@ class SummarizedContent(BaseModel):
summary: str summary: str
description: str description: str
pydantic_type: str = "SummarizedContent"
class SummarizedFunction(BaseModel): class SummarizedFunction(BaseModel):
@ -239,6 +240,7 @@ class SummarizedFunction(BaseModel):
inputs: Optional[List[str]] = None inputs: Optional[List[str]] = None
outputs: Optional[List[str]] = None outputs: Optional[List[str]] = None
decorators: Optional[List[str]] = None decorators: Optional[List[str]] = None
pydantic_type: str = "SummarizedFunction"
class SummarizedClass(BaseModel): class SummarizedClass(BaseModel):
@ -246,6 +248,7 @@ class SummarizedClass(BaseModel):
description: str description: str
methods: Optional[List[SummarizedFunction]] = None methods: Optional[List[SummarizedFunction]] = None
decorators: Optional[List[str]] = None decorators: Optional[List[str]] = None
pydantic_type: str = "SummarizedClass"
class SummarizedCode(BaseModel): class SummarizedCode(BaseModel):
@ -256,6 +259,7 @@ class SummarizedCode(BaseModel):
classes: List[SummarizedClass] = [] classes: List[SummarizedClass] = []
functions: List[SummarizedFunction] = [] functions: List[SummarizedFunction] = []
workflow_description: Optional[str] = None workflow_description: Optional[str] = None
pydantic_type: str = "SummarizedCode"
class GraphDBType(Enum): class GraphDBType(Enum):

View file

@ -1,10 +1,18 @@
from uuid import uuid5, NAMESPACE_OID from typing import Any, Dict, Iterator, Optional, Union
from typing import Dict, Any, Iterator from uuid import NAMESPACE_OID, uuid5
import tiktoken
from cognee.infrastructure.databases.vector import get_vector_engine
from .chunk_by_sentence import chunk_by_sentence from .chunk_by_sentence import chunk_by_sentence
def chunk_by_paragraph( def chunk_by_paragraph(
data: str, paragraph_length: int = 1024, batch_paragraphs: bool = True data: str,
max_tokens: Optional[Union[int, float]] = None,
paragraph_length: int = 1024,
batch_paragraphs: bool = True,
) -> Iterator[Dict[str, Any]]: ) -> Iterator[Dict[str, Any]]:
""" """
Chunks text by paragraph while preserving exact text reconstruction capability. Chunks text by paragraph while preserving exact text reconstruction capability.
@ -15,16 +23,31 @@ def chunk_by_paragraph(
chunk_index = 0 chunk_index = 0
paragraph_ids = [] paragraph_ids = []
last_cut_type = None last_cut_type = None
current_token_count = 0
if not max_tokens:
max_tokens = float("inf")
vector_engine = get_vector_engine()
embedding_model = vector_engine.embedding_engine.model
embedding_model = embedding_model.split("/")[-1]
for paragraph_id, sentence, word_count, end_type in chunk_by_sentence( for paragraph_id, sentence, word_count, end_type in chunk_by_sentence(
data, maximum_length=paragraph_length data, maximum_length=paragraph_length
): ):
# Check if this sentence would exceed length limit # Check if this sentence would exceed length limit
if current_word_count > 0 and current_word_count + word_count > paragraph_length:
tokenizer = tiktoken.encoding_for_model(embedding_model)
token_count = len(tokenizer.encode(sentence))
if current_word_count > 0 and (
current_word_count + word_count > paragraph_length
or current_token_count + token_count > max_tokens
):
# Yield current chunk # Yield current chunk
chunk_dict = { chunk_dict = {
"text": current_chunk, "text": current_chunk,
"word_count": current_word_count, "word_count": current_word_count,
"token_count": current_token_count,
"chunk_id": uuid5(NAMESPACE_OID, current_chunk), "chunk_id": uuid5(NAMESPACE_OID, current_chunk),
"paragraph_ids": paragraph_ids, "paragraph_ids": paragraph_ids,
"chunk_index": chunk_index, "chunk_index": chunk_index,
@ -37,11 +60,13 @@ def chunk_by_paragraph(
paragraph_ids = [] paragraph_ids = []
current_chunk = "" current_chunk = ""
current_word_count = 0 current_word_count = 0
current_token_count = 0
chunk_index += 1 chunk_index += 1
paragraph_ids.append(paragraph_id) paragraph_ids.append(paragraph_id)
current_chunk += sentence current_chunk += sentence
current_word_count += word_count current_word_count += word_count
current_token_count += token_count
# Handle end of paragraph # Handle end of paragraph
if end_type in ("paragraph_end", "sentence_cut") and not batch_paragraphs: if end_type in ("paragraph_end", "sentence_cut") and not batch_paragraphs:
@ -49,6 +74,7 @@ def chunk_by_paragraph(
chunk_dict = { chunk_dict = {
"text": current_chunk, "text": current_chunk,
"word_count": current_word_count, "word_count": current_word_count,
"token_count": current_token_count,
"paragraph_ids": paragraph_ids, "paragraph_ids": paragraph_ids,
"chunk_id": uuid5(NAMESPACE_OID, current_chunk), "chunk_id": uuid5(NAMESPACE_OID, current_chunk),
"chunk_index": chunk_index, "chunk_index": chunk_index,
@ -58,6 +84,7 @@ def chunk_by_paragraph(
paragraph_ids = [] paragraph_ids = []
current_chunk = "" current_chunk = ""
current_word_count = 0 current_word_count = 0
current_token_count = 0
chunk_index += 1 chunk_index += 1
last_cut_type = end_type last_cut_type = end_type
@ -67,6 +94,7 @@ def chunk_by_paragraph(
chunk_dict = { chunk_dict = {
"text": current_chunk, "text": current_chunk,
"word_count": current_word_count, "word_count": current_word_count,
"token_count": current_token_count,
"chunk_id": uuid5(NAMESPACE_OID, current_chunk), "chunk_id": uuid5(NAMESPACE_OID, current_chunk),
"paragraph_ids": paragraph_ids, "paragraph_ids": paragraph_ids,
"chunk_index": chunk_index, "chunk_index": chunk_index,

View file

@ -1,9 +1,16 @@
from typing import Optional
from cognee.modules.data.processing.document_types.Document import Document from cognee.modules.data.processing.document_types.Document import Document
async def extract_chunks_from_documents( async def extract_chunks_from_documents(
documents: list[Document], chunk_size: int = 1024, chunker="text_chunker" documents: list[Document],
chunk_size: int = 1024,
chunker="text_chunker",
max_tokens: Optional[int] = None,
): ):
for document in documents: for document in documents:
for document_chunk in document.read(chunk_size=chunk_size, chunker=chunker): for document_chunk in document.read(
chunk_size=chunk_size, chunker=chunker, max_tokens=max_tokens
):
yield document_chunk yield document_chunk

View file

@ -1,6 +1,5 @@
from typing import Dict, List from typing import Dict, List
import parso import parso
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View file

@ -9,7 +9,6 @@ import aiofiles
import jedi import jedi
import parso import parso
from parso.tree import BaseNode from parso.tree import BaseNode
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View file

@ -29,8 +29,105 @@ async def get_non_py_files(repo_path):
"*.egg-info", "*.egg-info",
} }
ALLOWED_EXTENSIONS = {
".txt",
".md",
".csv",
".json",
".xml",
".yaml",
".yml",
".html",
".css",
".js",
".ts",
".jsx",
".tsx",
".sql",
".log",
".ini",
".toml",
".properties",
".sh",
".bash",
".dockerfile",
".gitignore",
".gitattributes",
".makefile",
".pyproject",
".requirements",
".env",
".pdf",
".doc",
".docx",
".dot",
".dotx",
".rtf",
".wps",
".wpd",
".odt",
".ott",
".ottx",
".txt",
".wp",
".sdw",
".sdx",
".docm",
".dotm",
# Additional extensions for other programming languages
".java",
".c",
".cpp",
".h",
".cs",
".go",
".php",
".rb",
".swift",
".pl",
".lua",
".rs",
".scala",
".kt",
".sh",
".sql",
".v",
".asm",
".pas",
".d",
".ml",
".clj",
".cljs",
".erl",
".ex",
".exs",
".f",
".fs",
".r",
".pyi",
".pdb",
".ipynb",
".rmd",
".cabal",
".hs",
".nim",
".vhdl",
".verilog",
".svelte",
".html",
".css",
".scss",
".less",
".json5",
".yaml",
".yml",
}
def should_process(path): def should_process(path):
return not any(pattern in path for pattern in IGNORED_PATTERNS) _, ext = os.path.splitext(path)
return ext in ALLOWED_EXTENSIONS and not any(
pattern in path for pattern in IGNORED_PATTERNS
)
non_py_files_paths = [ non_py_files_paths = [
os.path.join(root, file) os.path.join(root, file)

View file

@ -5,6 +5,7 @@ from uuid import NAMESPACE_OID, uuid5
import parso import parso
import tiktoken import tiktoken
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.shared.CodeGraphEntities import CodeFile, CodePart, SourceCodeChunk from cognee.shared.CodeGraphEntities import CodeFile, CodePart, SourceCodeChunk
@ -126,6 +127,9 @@ def get_source_code_chunks_from_code_part(
logger.error(f"No source code in CodeFile {code_file_part.id}") logger.error(f"No source code in CodeFile {code_file_part.id}")
return return
vector_engine = get_vector_engine()
embedding_model = vector_engine.embedding_engine.model
model_name = embedding_model.split("/")[-1]
tokenizer = tiktoken.encoding_for_model(model_name) tokenizer = tiktoken.encoding_for_model(model_name)
max_subchunk_tokens = max(1, int(granularity * max_tokens)) max_subchunk_tokens = max(1, int(granularity * max_tokens))
subchunk_token_counts = _get_subchunk_token_counts( subchunk_token_counts = _get_subchunk_token_counts(
@ -150,7 +154,7 @@ def get_source_code_chunks_from_code_part(
async def get_source_code_chunks( async def get_source_code_chunks(
data_points: list[DataPoint], embedding_model="text-embedding-3-large" data_points: list[DataPoint],
) -> AsyncGenerator[list[DataPoint], None]: ) -> AsyncGenerator[list[DataPoint], None]:
"""Processes code graph datapoints, create SourceCodeChink datapoints.""" """Processes code graph datapoints, create SourceCodeChink datapoints."""
# TODO: Add support for other embedding models, with max_token mapping # TODO: Add support for other embedding models, with max_token mapping
@ -165,9 +169,7 @@ async def get_source_code_chunks(
for code_part in data_point.contains: for code_part in data_point.contains:
try: try:
yield code_part yield code_part
for source_code_chunk in get_source_code_chunks_from_code_part( for source_code_chunk in get_source_code_chunks_from_code_part(code_part):
code_part, model_name=embedding_model
):
yield source_code_chunk yield source_code_chunk
except Exception as e: except Exception as e:
logger.error(f"Error processing code part: {e}") logger.error(f"Error processing code part: {e}")

View file

@ -17,5 +17,6 @@ class CodeSummary(DataPoint):
__tablename__ = "code_summary" __tablename__ = "code_summary"
text: str text: str
summarizes: Union[CodeFile, CodePart, SourceCodeChunk] summarizes: Union[CodeFile, CodePart, SourceCodeChunk]
pydantic_type: str = "CodeSummary"
_metadata: dict = {"index_fields": ["text"], "type": "CodeSummary"} _metadata: dict = {"index_fields": ["text"], "type": "CodeSummary"}

View file

@ -36,12 +36,12 @@ def test_AudioDocument():
for ground_truth, paragraph_data in zip( for ground_truth, paragraph_data in zip(
GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker") GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker")
): ):
assert ( assert ground_truth["word_count"] == paragraph_data.word_count, (
ground_truth["word_count"] == paragraph_data.word_count f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }' )
assert ground_truth["len_text"] == len( assert ground_truth["len_text"] == len(paragraph_data.text), (
paragraph_data.text f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }' )
assert ( assert ground_truth["cut_type"] == paragraph_data.cut_type, (
ground_truth["cut_type"] == paragraph_data.cut_type f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }' )

View file

@ -25,12 +25,12 @@ def test_ImageDocument():
for ground_truth, paragraph_data in zip( for ground_truth, paragraph_data in zip(
GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker") GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker")
): ):
assert ( assert ground_truth["word_count"] == paragraph_data.word_count, (
ground_truth["word_count"] == paragraph_data.word_count f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }' )
assert ground_truth["len_text"] == len( assert ground_truth["len_text"] == len(paragraph_data.text), (
paragraph_data.text f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }' )
assert ( assert ground_truth["cut_type"] == paragraph_data.cut_type, (
ground_truth["cut_type"] == paragraph_data.cut_type f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }' )

View file

@ -27,12 +27,12 @@ def test_PdfDocument():
for ground_truth, paragraph_data in zip( for ground_truth, paragraph_data in zip(
GROUND_TRUTH, document.read(chunk_size=1024, chunker="text_chunker") GROUND_TRUTH, document.read(chunk_size=1024, chunker="text_chunker")
): ):
assert ( assert ground_truth["word_count"] == paragraph_data.word_count, (
ground_truth["word_count"] == paragraph_data.word_count f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }' )
assert ground_truth["len_text"] == len( assert ground_truth["len_text"] == len(paragraph_data.text), (
paragraph_data.text f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }' )
assert ( assert ground_truth["cut_type"] == paragraph_data.cut_type, (
ground_truth["cut_type"] == paragraph_data.cut_type f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }' )

View file

@ -39,12 +39,12 @@ def test_TextDocument(input_file, chunk_size):
for ground_truth, paragraph_data in zip( for ground_truth, paragraph_data in zip(
GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size, chunker="text_chunker") GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size, chunker="text_chunker")
): ):
assert ( assert ground_truth["word_count"] == paragraph_data.word_count, (
ground_truth["word_count"] == paragraph_data.word_count f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }' )
assert ground_truth["len_text"] == len( assert ground_truth["len_text"] == len(paragraph_data.text), (
paragraph_data.text f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }' )
assert ( assert ground_truth["cut_type"] == paragraph_data.cut_type, (
ground_truth["cut_type"] == paragraph_data.cut_type f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }' )

View file

@ -68,35 +68,35 @@ def test_UnstructuredDocument():
) )
# Test PPTX # Test PPTX
for paragraph_data in pptx_document.read(chunk_size=1024): for paragraph_data in pptx_document.read(chunk_size=1024, chunker="text_chunker"):
assert 19 == paragraph_data.word_count, f" 19 != {paragraph_data.word_count = }" assert 19 == paragraph_data.word_count, f" 19 != {paragraph_data.word_count = }"
assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }" assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }"
assert ( assert "sentence_cut" == paragraph_data.cut_type, (
"sentence_cut" == paragraph_data.cut_type f" sentence_cut != {paragraph_data.cut_type = }"
), f" sentence_cut != {paragraph_data.cut_type = }" )
# Test DOCX # Test DOCX
for paragraph_data in docx_document.read(chunk_size=1024): for paragraph_data in docx_document.read(chunk_size=1024, chunker="text_chunker"):
assert 16 == paragraph_data.word_count, f" 16 != {paragraph_data.word_count = }" assert 16 == paragraph_data.word_count, f" 16 != {paragraph_data.word_count = }"
assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }" assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }"
assert ( assert "sentence_end" == paragraph_data.cut_type, (
"sentence_end" == paragraph_data.cut_type f" sentence_end != {paragraph_data.cut_type = }"
), f" sentence_end != {paragraph_data.cut_type = }" )
# TEST CSV # TEST CSV
for paragraph_data in csv_document.read(chunk_size=1024): for paragraph_data in csv_document.read(chunk_size=1024, chunker="text_chunker"):
assert 15 == paragraph_data.word_count, f" 15 != {paragraph_data.word_count = }" assert 15 == paragraph_data.word_count, f" 15 != {paragraph_data.word_count = }"
assert ( assert "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text, (
"A A A A A A A A A,A A A A A A,A A" == paragraph_data.text f"Read text doesn't match expected text: {paragraph_data.text}"
), f"Read text doesn't match expected text: {paragraph_data.text}" )
assert ( assert "sentence_cut" == paragraph_data.cut_type, (
"sentence_cut" == paragraph_data.cut_type f" sentence_cut != {paragraph_data.cut_type = }"
), f" sentence_cut != {paragraph_data.cut_type = }" )
# Test XLSX # Test XLSX
for paragraph_data in xlsx_document.read(chunk_size=1024): for paragraph_data in xlsx_document.read(chunk_size=1024, chunker="text_chunker"):
assert 36 == paragraph_data.word_count, f" 36 != {paragraph_data.word_count = }" assert 36 == paragraph_data.word_count, f" 36 != {paragraph_data.word_count = }"
assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }" assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }"
assert ( assert "sentence_cut" == paragraph_data.cut_type, (
"sentence_cut" == paragraph_data.cut_type f" sentence_cut != {paragraph_data.cut_type = }"
), f" sentence_cut != {paragraph_data.cut_type = }" )

View file

@ -30,9 +30,9 @@ async def test_deduplication():
result = await relational_engine.get_all_data_from_table("data") result = await relational_engine.get_all_data_from_table("data")
assert len(result) == 1, "More than one data entity was found." assert len(result) == 1, "More than one data entity was found."
assert ( assert result[0]["name"] == "Natural_language_processing_copy", (
result[0]["name"] == "Natural_language_processing_copy" "Result name does not match expected value."
), "Result name does not match expected value." )
result = await relational_engine.get_all_data_from_table("datasets") result = await relational_engine.get_all_data_from_table("datasets")
assert len(result) == 2, "Unexpected number of datasets found." assert len(result) == 2, "Unexpected number of datasets found."
@ -61,9 +61,9 @@ async def test_deduplication():
result = await relational_engine.get_all_data_from_table("data") result = await relational_engine.get_all_data_from_table("data")
assert len(result) == 1, "More than one data entity was found." assert len(result) == 1, "More than one data entity was found."
assert ( assert hashlib.md5(text.encode("utf-8")).hexdigest() in result[0]["name"], (
hashlib.md5(text.encode("utf-8")).hexdigest() in result[0]["name"] "Content hash is not a part of file name."
), "Content hash is not a part of file name." )
await cognee.prune.prune_data() await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True) await cognee.prune.prune_system(metadata=True)

View file

@ -85,9 +85,9 @@ async def main():
from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.relational import get_relational_engine
assert not os.path.exists( assert not os.path.exists(get_relational_engine().db_path), (
get_relational_engine().db_path "SQLite relational database is not empty"
), "SQLite relational database is not empty" )
from cognee.infrastructure.databases.graph import get_graph_config from cognee.infrastructure.databases.graph import get_graph_config

View file

@ -82,9 +82,9 @@ async def main():
from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.relational import get_relational_engine
assert not os.path.exists( assert not os.path.exists(get_relational_engine().db_path), (
get_relational_engine().db_path "SQLite relational database is not empty"
), "SQLite relational database is not empty" )
from cognee.infrastructure.databases.graph import get_graph_config from cognee.infrastructure.databases.graph import get_graph_config

View file

@ -24,28 +24,28 @@ async def test_local_file_deletion(data_text, file_location):
data_hash = hashlib.md5(encoded_text).hexdigest() data_hash = hashlib.md5(encoded_text).hexdigest()
# Get data entry from database based on hash contents # Get data entry from database based on hash contents
data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one() data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one()
assert os.path.isfile( assert os.path.isfile(data.raw_data_location), (
data.raw_data_location f"Data location doesn't exist: {data.raw_data_location}"
), f"Data location doesn't exist: {data.raw_data_location}" )
# Test deletion of data along with local files created by cognee # Test deletion of data along with local files created by cognee
await engine.delete_data_entity(data.id) await engine.delete_data_entity(data.id)
assert not os.path.exists( assert not os.path.exists(data.raw_data_location), (
data.raw_data_location f"Data location still exists after deletion: {data.raw_data_location}"
), f"Data location still exists after deletion: {data.raw_data_location}" )
async with engine.get_async_session() as session: async with engine.get_async_session() as session:
# Get data entry from database based on file path # Get data entry from database based on file path
data = ( data = (
await session.scalars(select(Data).where(Data.raw_data_location == file_location)) await session.scalars(select(Data).where(Data.raw_data_location == file_location))
).one() ).one()
assert os.path.isfile( assert os.path.isfile(data.raw_data_location), (
data.raw_data_location f"Data location doesn't exist: {data.raw_data_location}"
), f"Data location doesn't exist: {data.raw_data_location}" )
# Test local files not created by cognee won't get deleted # Test local files not created by cognee won't get deleted
await engine.delete_data_entity(data.id) await engine.delete_data_entity(data.id)
assert os.path.exists( assert os.path.exists(data.raw_data_location), (
data.raw_data_location f"Data location doesn't exists: {data.raw_data_location}"
), f"Data location doesn't exists: {data.raw_data_location}" )
async def test_getting_of_documents(dataset_name_1): async def test_getting_of_documents(dataset_name_1):
@ -54,16 +54,16 @@ async def test_getting_of_documents(dataset_name_1):
user = await get_default_user() user = await get_default_user()
document_ids = await get_document_ids_for_user(user.id, [dataset_name_1]) document_ids = await get_document_ids_for_user(user.id, [dataset_name_1])
assert ( assert len(document_ids) == 1, (
len(document_ids) == 1 f"Number of expected documents doesn't match {len(document_ids)} != 1"
), f"Number of expected documents doesn't match {len(document_ids)} != 1" )
# Test getting of documents for search when no dataset is provided # Test getting of documents for search when no dataset is provided
user = await get_default_user() user = await get_default_user()
document_ids = await get_document_ids_for_user(user.id) document_ids = await get_document_ids_for_user(user.id)
assert ( assert len(document_ids) == 2, (
len(document_ids) == 2 f"Number of expected documents doesn't match {len(document_ids)} != 2"
), f"Number of expected documents doesn't match {len(document_ids)} != 2" )
async def main(): async def main():

View file

@ -17,9 +17,9 @@ batch_paragraphs_vals = [True, False]
def test_chunk_by_paragraph_isomorphism(input_text, paragraph_length, batch_paragraphs): def test_chunk_by_paragraph_isomorphism(input_text, paragraph_length, batch_paragraphs):
chunks = chunk_by_paragraph(input_text, paragraph_length, batch_paragraphs) chunks = chunk_by_paragraph(input_text, paragraph_length, batch_paragraphs)
reconstructed_text = "".join([chunk["text"] for chunk in chunks]) reconstructed_text = "".join([chunk["text"] for chunk in chunks])
assert ( assert reconstructed_text == input_text, (
reconstructed_text == input_text f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }" )
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -27,14 +27,18 @@ def test_chunk_by_paragraph_isomorphism(input_text, paragraph_length, batch_para
list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)), list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)),
) )
def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs): def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs):
chunks = list(chunk_by_paragraph(input_text, paragraph_length, batch_paragraphs)) chunks = list(
chunk_by_paragraph(
data=input_text, paragraph_length=paragraph_length, batch_paragraphs=batch_paragraphs
)
)
chunk_lengths = np.array([len(list(chunk_by_word(chunk["text"]))) for chunk in chunks]) chunk_lengths = np.array([len(list(chunk_by_word(chunk["text"]))) for chunk in chunks])
larger_chunks = chunk_lengths[chunk_lengths > paragraph_length] larger_chunks = chunk_lengths[chunk_lengths > paragraph_length]
assert np.all( assert np.all(chunk_lengths <= paragraph_length), (
chunk_lengths <= paragraph_length f"{paragraph_length = }: {larger_chunks} are too large"
), f"{paragraph_length = }: {larger_chunks} are too large" )
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -42,8 +46,10 @@ def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs):
list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)), list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)),
) )
def test_chunk_by_paragraph_chunk_numbering(input_text, paragraph_length, batch_paragraphs): def test_chunk_by_paragraph_chunk_numbering(input_text, paragraph_length, batch_paragraphs):
chunks = chunk_by_paragraph(input_text, paragraph_length, batch_paragraphs) chunks = chunk_by_paragraph(
data=input_text, paragraph_length=paragraph_length, batch_paragraphs=batch_paragraphs
)
chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks]) chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks])
assert np.all( assert np.all(chunk_indices == np.arange(len(chunk_indices))), (
chunk_indices == np.arange(len(chunk_indices)) f"{chunk_indices = } are not monotonically increasing"
), f"{chunk_indices = } are not monotonically increasing" )

View file

@ -49,16 +49,18 @@ Third paragraph is cut and is missing the dot at the end""",
def run_chunking_test(test_text, expected_chunks): def run_chunking_test(test_text, expected_chunks):
chunks = [] chunks = []
for chunk_data in chunk_by_paragraph(test_text, 12, batch_paragraphs=False): for chunk_data in chunk_by_paragraph(
data=test_text, paragraph_length=12, batch_paragraphs=False
):
chunks.append(chunk_data) chunks.append(chunk_data)
assert len(chunks) == 3 assert len(chunks) == 3
for expected_chunks_item, chunk in zip(expected_chunks, chunks): for expected_chunks_item, chunk in zip(expected_chunks, chunks):
for key in ["text", "word_count", "cut_type"]: for key in ["text", "word_count", "cut_type"]:
assert ( assert chunk[key] == expected_chunks_item[key], (
chunk[key] == expected_chunks_item[key] f"{key = }: {chunk[key] = } != {expected_chunks_item[key] = }"
), f"{key = }: {chunk[key] = } != {expected_chunks_item[key] = }" )
def test_chunking_whole_text(): def test_chunking_whole_text():

View file

@ -16,9 +16,9 @@ maximum_length_vals = [None, 8, 64]
def test_chunk_by_sentence_isomorphism(input_text, maximum_length): def test_chunk_by_sentence_isomorphism(input_text, maximum_length):
chunks = chunk_by_sentence(input_text, maximum_length) chunks = chunk_by_sentence(input_text, maximum_length)
reconstructed_text = "".join([chunk[1] for chunk in chunks]) reconstructed_text = "".join([chunk[1] for chunk in chunks])
assert ( assert reconstructed_text == input_text, (
reconstructed_text == input_text f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }" )
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -36,6 +36,6 @@ def test_paragraph_chunk_length(input_text, maximum_length):
chunk_lengths = np.array([len(list(chunk_by_word(chunk[1]))) for chunk in chunks]) chunk_lengths = np.array([len(list(chunk_by_word(chunk[1]))) for chunk in chunks])
larger_chunks = chunk_lengths[chunk_lengths > maximum_length] larger_chunks = chunk_lengths[chunk_lengths > maximum_length]
assert np.all( assert np.all(chunk_lengths <= maximum_length), (
chunk_lengths <= maximum_length f"{maximum_length = }: {larger_chunks} are too large"
), f"{maximum_length = }: {larger_chunks} are too large" )

View file

@ -17,9 +17,9 @@ from cognee.tests.unit.processing.chunks.test_input import INPUT_TEXTS
def test_chunk_by_word_isomorphism(input_text): def test_chunk_by_word_isomorphism(input_text):
chunks = chunk_by_word(input_text) chunks = chunk_by_word(input_text)
reconstructed_text = "".join([chunk[0] for chunk in chunks]) reconstructed_text = "".join([chunk[0] for chunk in chunks])
assert ( assert reconstructed_text == input_text, (
reconstructed_text == input_text f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }" )
@pytest.mark.parametrize( @pytest.mark.parametrize(

View file

@ -11,8 +11,9 @@ from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
from cognee.api.v1.search import SearchType from cognee.api.v1.search import SearchType
from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search from cognee.modules.retrieval.description_to_codepart_search import (
from cognee.shared.utils import render_graph code_description_to_code_part_search,
)
from evals.eval_utils import download_github_repo, retrieved_edges_to_string from evals.eval_utils import download_github_repo, retrieved_edges_to_string
@ -32,25 +33,18 @@ def check_install_package(package_name):
return False return False
async def generate_patch_with_cognee(instance, llm_client, search_type=SearchType.CHUNKS): async def generate_patch_with_cognee(instance):
repo_path = download_github_repo(instance, "../RAW_GIT_REPOS") repo_path = download_github_repo(instance, "../RAW_GIT_REPOS")
pipeline = await run_code_graph_pipeline(repo_path) include_docs = True
async for result in pipeline:
print(result)
print("Here we have the repo under the repo_path")
await render_graph(None, include_labels=True, include_nodes=True)
problem_statement = instance["problem_statement"] problem_statement = instance["problem_statement"]
instructions = read_query_prompt("patch_gen_kg_instructions.txt") instructions = read_query_prompt("patch_gen_kg_instructions.txt")
retrieved_edges = await brute_force_triplet_search( async for result in run_code_graph_pipeline(repo_path, include_docs=include_docs):
problem_statement, top_k=3, collections=["data_point_source_code", "data_point_text"] print(result)
)
retrieved_edges_str = retrieved_edges_to_string(retrieved_edges) retrieved_codeparts = await code_description_to_code_part_search(
problem_statement, include_docs=include_docs
)
prompt = "\n".join( prompt = "\n".join(
[ [
@ -58,8 +52,8 @@ async def generate_patch_with_cognee(instance, llm_client, search_type=SearchTyp
"<patch>", "<patch>",
PATCH_EXAMPLE, PATCH_EXAMPLE,
"</patch>", "</patch>",
"These are the retrieved edges:", "This is the additional context to solve the problem (description from documentation together with codeparts):",
retrieved_edges_str, retrieved_codeparts,
] ]
) )
@ -85,8 +79,6 @@ async def generate_patch_without_cognee(instance, llm_client):
async def get_preds(dataset, with_cognee=True): async def get_preds(dataset, with_cognee=True):
llm_client = get_llm_client()
if with_cognee: if with_cognee:
model_name = "with_cognee" model_name = "with_cognee"
pred_func = generate_patch_with_cognee pred_func = generate_patch_with_cognee
@ -94,17 +86,18 @@ async def get_preds(dataset, with_cognee=True):
model_name = "without_cognee" model_name = "without_cognee"
pred_func = generate_patch_without_cognee pred_func = generate_patch_without_cognee
futures = [(instance["instance_id"], pred_func(instance, llm_client)) for instance in dataset] preds = []
model_patches = await asyncio.gather(*[x[1] for x in futures])
preds = [ for instance in dataset:
{ instance_id = instance["instance_id"]
"instance_id": instance_id, model_patch = await pred_func(instance) # Sequentially await the async function
"model_patch": model_patch, preds.append(
"model_name_or_path": model_name, {
} "instance_id": instance_id,
for (instance_id, _), model_patch in zip(futures, model_patches) "model_patch": model_patch,
] "model_name_or_path": model_name,
}
)
return preds return preds
@ -134,6 +127,7 @@ async def main():
with open(predictions_path, "w") as file: with open(predictions_path, "w") as file:
json.dump(preds, file) json.dump(preds, file)
""" This part is for the evaluation
subprocess.run( subprocess.run(
[ [
"python", "python",
@ -151,6 +145,7 @@ async def main():
"test_run", "test_run",
] ]
) )
"""
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -1,7 +1,9 @@
import argparse import argparse
import asyncio import asyncio
import logging
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
from cognee.shared.utils import setup_logging
async def main(repo_path, include_docs): async def main(repo_path, include_docs):
@ -9,7 +11,7 @@ async def main(repo_path, include_docs):
print(result) print(result)
if __name__ == "__main__": def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--repo_path", type=str, required=True, help="Path to the repository") parser.add_argument("--repo_path", type=str, required=True, help="Path to the repository")
parser.add_argument( parser.add_argument(
@ -18,5 +20,28 @@ if __name__ == "__main__":
default=True, default=True,
help="Whether or not to process non-code files", help="Whether or not to process non-code files",
) )
args = parser.parse_args() parser.add_argument(
asyncio.run(main(args.repo_path, args.include_docs)) "--time",
type=lambda x: x.lower() in ("true", "1"),
default=True,
help="Whether or not to time the pipeline run",
)
return parser.parse_args()
if __name__ == "__main__":
setup_logging(logging.ERROR)
args = parse_args()
if args.time:
import time
start_time = time.time()
asyncio.run(main(args.repo_path, args.include_docs))
end_time = time.time()
print("\n" + "=" * 50)
print(f"Pipeline Execution Time: {end_time - start_time:.2f} seconds")
print("=" * 50 + "\n")
else:
asyncio.run(main(args.repo_path, args.include_docs))