Merge pull request #414 from topoteretes/COG-949

Code graph pipeline improvements and fixes
This commit is contained in:
Vasilije 2025-01-10 14:32:06 +01:00 committed by GitHub
commit f7e808eddd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 255 additions and 54 deletions

View file

@ -3,7 +3,6 @@ import logging
from pathlib import Path
from cognee.base_config import get_base_config
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
from cognee.modules.cognify.config import get_cognify_config
from cognee.modules.pipelines import run_tasks
from cognee.modules.pipelines.tasks.Task import Task
@ -54,8 +53,6 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
await cognee.prune.prune_system(metadata=True)
await create_db_and_tables()
embedding_engine = get_embedding_engine()
cognee_config = get_cognify_config()
user = await get_default_user()
@ -63,11 +60,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
Task(get_repo_file_dependencies),
Task(enrich_dependency_graph),
Task(expand_dependency_graph, task_config={"batch_size": 50}),
Task(
get_source_code_chunks,
embedding_model=embedding_engine.model,
task_config={"batch_size": 50},
),
Task(get_source_code_chunks, task_config={"batch_size": 50}),
Task(summarize_code, task_config={"batch_size": 50}),
Task(add_data_points, task_config={"batch_size": 50}),
]
@ -78,7 +71,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
Task(ingest_data_with_metadata, dataset_name="repo_docs", user=user),
Task(get_data_list_for_user, dataset_name="repo_docs", user=user),
Task(classify_documents),
Task(extract_chunks_from_documents),
Task(extract_chunks_from_documents, max_tokens=cognee_config.max_tokens),
Task(
extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50}
),

View file

@ -1,7 +1,9 @@
from uuid import uuid5, NAMESPACE_OID
from typing import Optional
from uuid import NAMESPACE_OID, uuid5
from cognee.tasks.chunks import chunk_by_paragraph
from .models.DocumentChunk import DocumentChunk
from cognee.tasks.chunks import chunk_by_paragraph
class TextChunker:
@ -10,23 +12,36 @@ class TextChunker:
chunk_index = 0
chunk_size = 0
token_count = 0
def __init__(self, document, get_text: callable, chunk_size: int = 1024):
def __init__(
self, document, get_text: callable, max_tokens: Optional[int] = None, chunk_size: int = 1024
):
self.document = document
self.max_chunk_size = chunk_size
self.get_text = get_text
self.max_tokens = max_tokens if max_tokens else float("inf")
def check_word_count_and_token_count(self, word_count_before, token_count_before, chunk_data):
    """Report whether ``chunk_data`` can be appended without exceeding either budget.

    Args:
        word_count_before: Words already accumulated in the current batch.
        token_count_before: Tokens already accumulated in the current batch.
        chunk_data: Mapping that provides the candidate's ``"word_count"``
            and ``"token_count"``.

    Returns:
        True only when the running word total stays within
        ``self.max_chunk_size`` and the running token total stays within
        ``self.max_tokens``.
    """
    # Both totals are computed up front so a missing key fails the same way
    # regardless of which limit would have been hit first.
    words_after = word_count_before + chunk_data["word_count"]
    tokens_after = token_count_before + chunk_data["token_count"]

    within_word_budget = words_after <= self.max_chunk_size
    within_token_budget = tokens_after <= self.max_tokens
    return within_word_budget and within_token_budget
def read(self):
paragraph_chunks = []
for content_text in self.get_text():
for chunk_data in chunk_by_paragraph(
content_text,
self.max_tokens,
self.max_chunk_size,
batch_paragraphs=True,
):
if self.chunk_size + chunk_data["word_count"] <= self.max_chunk_size:
if self.check_word_count_and_token_count(
self.chunk_size, self.token_count, chunk_data
):
paragraph_chunks.append(chunk_data)
self.chunk_size += chunk_data["word_count"]
self.token_count += chunk_data["token_count"]
else:
if len(paragraph_chunks) == 0:
yield DocumentChunk(
@ -66,6 +81,7 @@ class TextChunker:
print(e)
paragraph_chunks = [chunk_data]
self.chunk_size = chunk_data["word_count"]
self.token_count = chunk_data["token_count"]
self.chunk_index += 1

View file

@ -1,12 +1,14 @@
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
from cognee.shared.data_models import DefaultContentPrediction, SummarizedContent
from typing import Optional
import os
class CognifyConfig(BaseSettings):
classification_model: object = DefaultContentPrediction
summarization_model: object = SummarizedContent
max_tokens: Optional[int] = os.getenv("MAX_TOKENS")
model_config = SettingsConfigDict(env_file=".env", extra="allow")
def to_dict(self) -> dict:

View file

@ -1,6 +1,9 @@
from typing import Optional
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from .Document import Document
from .ChunkerMapping import ChunkerConfig
from .Document import Document
class AudioDocument(Document):
@ -10,12 +13,14 @@ class AudioDocument(Document):
result = get_llm_client().create_transcript(self.raw_data_location)
return result.text
def read(self, chunk_size: int, chunker: str):
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
# Transcribe the audio file
text = self.create_transcript()
chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(self, chunk_size=chunk_size, get_text=lambda: [text])
chunker = chunker_func(
self, chunk_size=chunk_size, get_text=lambda: [text], max_tokens=max_tokens
)
yield from chunker.read()

View file

@ -1,3 +1,4 @@
from typing import Optional
from uuid import UUID
from cognee.infrastructure.engine import DataPoint
@ -10,5 +11,5 @@ class Document(DataPoint):
mime_type: str
_metadata: dict = {"index_fields": ["name"], "type": "Document"}
def read(self, chunk_size: int, chunker=str) -> str:
def read(self, chunk_size: int, chunker=str, max_tokens: Optional[int] = None) -> str:
pass

View file

@ -1,6 +1,9 @@
from typing import Optional
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from .Document import Document
from .ChunkerMapping import ChunkerConfig
from .Document import Document
class ImageDocument(Document):
@ -10,11 +13,13 @@ class ImageDocument(Document):
result = get_llm_client().transcribe_image(self.raw_data_location)
return result.choices[0].message.content
def read(self, chunk_size: int, chunker: str):
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
# Transcribe the image file
text = self.transcribe_image()
chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(self, chunk_size=chunk_size, get_text=lambda: [text])
chunker = chunker_func(
self, chunk_size=chunk_size, get_text=lambda: [text], max_tokens=max_tokens
)
yield from chunker.read()

View file

@ -1,12 +1,15 @@
from typing import Optional
from pypdf import PdfReader
from .Document import Document
from .ChunkerMapping import ChunkerConfig
from .Document import Document
class PdfDocument(Document):
type: str = "pdf"
def read(self, chunk_size: int, chunker: str):
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
file = PdfReader(self.raw_data_location)
def get_text():
@ -15,7 +18,9 @@ class PdfDocument(Document):
yield page_text
chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(self, chunk_size=chunk_size, get_text=get_text)
chunker = chunker_func(
self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens
)
yield from chunker.read()

View file

@ -1,11 +1,13 @@
from .Document import Document
from typing import Optional
from .ChunkerMapping import ChunkerConfig
from .Document import Document
class TextDocument(Document):
type: str = "text"
def read(self, chunk_size: int, chunker: str):
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
def get_text():
with open(self.raw_data_location, mode="r", encoding="utf-8") as file:
while True:
@ -18,6 +20,8 @@ class TextDocument(Document):
chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(self, chunk_size=chunk_size, get_text=get_text)
chunker = chunker_func(
self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens
)
yield from chunker.read()

View file

@ -1,14 +1,16 @@
from io import StringIO
from typing import Optional
from cognee.modules.chunking.TextChunker import TextChunker
from .Document import Document
from cognee.modules.data.exceptions import UnstructuredLibraryImportError
from .Document import Document
class UnstructuredDocument(Document):
type: str = "unstructured"
def read(self, chunk_size: int):
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None) -> str:
def get_text():
try:
from unstructured.partition.auto import partition
@ -27,6 +29,6 @@ class UnstructuredDocument(Document):
yield text
chunker = TextChunker(self, chunk_size=chunk_size, get_text=get_text)
chunker = TextChunker(self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens)
yield from chunker.read()

View file

@ -1,10 +1,18 @@
from uuid import uuid5, NAMESPACE_OID
from typing import Dict, Any, Iterator
from typing import Any, Dict, Iterator, Optional, Union
from uuid import NAMESPACE_OID, uuid5
import tiktoken
from cognee.infrastructure.databases.vector import get_vector_engine
from .chunk_by_sentence import chunk_by_sentence
def chunk_by_paragraph(
data: str, paragraph_length: int = 1024, batch_paragraphs: bool = True
data: str,
max_tokens: Optional[Union[int, float]] = None,
paragraph_length: int = 1024,
batch_paragraphs: bool = True,
) -> Iterator[Dict[str, Any]]:
"""
Chunks text by paragraph while preserving exact text reconstruction capability.
@ -15,16 +23,31 @@ def chunk_by_paragraph(
chunk_index = 0
paragraph_ids = []
last_cut_type = None
current_token_count = 0
if not max_tokens:
max_tokens = float("inf")
vector_engine = get_vector_engine()
embedding_model = vector_engine.embedding_engine.model
embedding_model = embedding_model.split("/")[-1]
for paragraph_id, sentence, word_count, end_type in chunk_by_sentence(
data, maximum_length=paragraph_length
):
# Check if this sentence would exceed length limit
if current_word_count > 0 and current_word_count + word_count > paragraph_length:
tokenizer = tiktoken.encoding_for_model(embedding_model)
token_count = len(tokenizer.encode(sentence))
if current_word_count > 0 and (
current_word_count + word_count > paragraph_length
or current_token_count + token_count > max_tokens
):
# Yield current chunk
chunk_dict = {
"text": current_chunk,
"word_count": current_word_count,
"token_count": current_token_count,
"chunk_id": uuid5(NAMESPACE_OID, current_chunk),
"paragraph_ids": paragraph_ids,
"chunk_index": chunk_index,
@ -37,11 +60,13 @@ def chunk_by_paragraph(
paragraph_ids = []
current_chunk = ""
current_word_count = 0
current_token_count = 0
chunk_index += 1
paragraph_ids.append(paragraph_id)
current_chunk += sentence
current_word_count += word_count
current_token_count += token_count
# Handle end of paragraph
if end_type in ("paragraph_end", "sentence_cut") and not batch_paragraphs:
@ -49,6 +74,7 @@ def chunk_by_paragraph(
chunk_dict = {
"text": current_chunk,
"word_count": current_word_count,
"token_count": current_token_count,
"paragraph_ids": paragraph_ids,
"chunk_id": uuid5(NAMESPACE_OID, current_chunk),
"chunk_index": chunk_index,
@ -58,6 +84,7 @@ def chunk_by_paragraph(
paragraph_ids = []
current_chunk = ""
current_word_count = 0
current_token_count = 0
chunk_index += 1
last_cut_type = end_type
@ -67,6 +94,7 @@ def chunk_by_paragraph(
chunk_dict = {
"text": current_chunk,
"word_count": current_word_count,
"token_count": current_token_count,
"chunk_id": uuid5(NAMESPACE_OID, current_chunk),
"paragraph_ids": paragraph_ids,
"chunk_index": chunk_index,

View file

@ -1,9 +1,16 @@
from typing import Optional
from cognee.modules.data.processing.document_types.Document import Document
async def extract_chunks_from_documents(
documents: list[Document], chunk_size: int = 1024, chunker="text_chunker"
documents: list[Document],
chunk_size: int = 1024,
chunker="text_chunker",
max_tokens: Optional[int] = None,
):
for document in documents:
for document_chunk in document.read(chunk_size=chunk_size, chunker=chunker):
for document_chunk in document.read(
chunk_size=chunk_size, chunker=chunker, max_tokens=max_tokens
):
yield document_chunk

View file

@ -29,8 +29,105 @@ async def get_non_py_files(repo_path):
"*.egg-info",
}
# Lower-case keys that mark a non-Python file as worth ingesting.
# Each key is listed exactly once; the previous literal repeated several
# entries (".txt", ".sh", ".sql", ".html", ".css", ".yaml", ".yml"), which a
# set silently collapses but which made the list harder to audit.
ALLOWED_EXTENSIONS = {
    # Plain text, data and notebooks
    ".txt", ".md", ".csv", ".json", ".json5", ".xml", ".yaml", ".yml",
    ".log", ".ini", ".toml", ".properties", ".rmd", ".ipynb",
    # Web front-end sources
    ".html", ".css", ".scss", ".less", ".js", ".ts", ".jsx", ".tsx", ".svelte",
    # Shell, SQL and build/packaging files
    ".sql", ".sh", ".bash", ".dockerfile", ".makefile", ".pyproject",
    ".requirements", ".cabal",
    # NOTE: these are whole dotfile names rather than suffixes.
    # os.path.splitext() reports an empty extension for a name that starts
    # with its only dot, so whether they match depends on how the caller
    # derives the lookup key.
    ".gitignore", ".gitattributes", ".env",
    # Word-processor and document formats
    ".pdf", ".doc", ".docx", ".dot", ".dotx", ".docm", ".dotm", ".rtf",
    ".wps", ".wpd", ".wp", ".odt", ".ott", ".ottx", ".sdw", ".sdx",
    # Additional extensions for other programming languages
    ".java", ".c", ".cpp", ".h", ".cs", ".go", ".php", ".rb", ".swift", ".pl",
    ".lua", ".rs", ".scala", ".kt", ".v", ".asm", ".pas", ".d", ".ml", ".clj",
    ".cljs", ".erl", ".ex", ".exs", ".f", ".fs", ".r", ".pyi", ".pdb", ".hs",
    ".nim", ".vhdl", ".verilog",
}
def should_process(path):
    """Return True when ``path`` is an allowed file type outside every ignored location.

    A path qualifies when its lookup key is in ``ALLOWED_EXTENSIONS`` and no
    entry of ``IGNORED_PATTERNS`` occurs as a substring of the path.

    The lookup key is the file extension, lower-cased so that ``README.MD``
    is treated like ``README.md``. For a bare dotfile such as ``.gitignore``
    ``os.path.splitext`` yields an empty extension (the leading dot belongs
    to the root), so the whole file name is used as the key instead; without
    that fallback the dotfile entries in ``ALLOWED_EXTENSIONS`` could never
    match.

    Args:
        path: File path to test.

    Returns:
        bool: Whether the file should be handed to the docs pipeline.
    """
    file_name = os.path.basename(path)
    extension = os.path.splitext(file_name)[1]
    lookup_key = (extension or file_name).lower()

    # Cheap set lookup first; the ignore scan only runs for allowed types.
    if lookup_key not in ALLOWED_EXTENSIONS:
        return False
    return not any(pattern in path for pattern in IGNORED_PATTERNS)
non_py_files_paths = [
os.path.join(root, file)

View file

@ -5,6 +5,7 @@ from uuid import NAMESPACE_OID, uuid5
import parso
import tiktoken
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.engine import DataPoint
from cognee.shared.CodeGraphEntities import CodeFile, CodePart, SourceCodeChunk
@ -126,6 +127,9 @@ def get_source_code_chunks_from_code_part(
logger.error(f"No source code in CodeFile {code_file_part.id}")
return
vector_engine = get_vector_engine()
embedding_model = vector_engine.embedding_engine.model
model_name = embedding_model.split("/")[-1]
tokenizer = tiktoken.encoding_for_model(model_name)
max_subchunk_tokens = max(1, int(granularity * max_tokens))
subchunk_token_counts = _get_subchunk_token_counts(
@ -150,7 +154,7 @@ def get_source_code_chunks_from_code_part(
async def get_source_code_chunks(
data_points: list[DataPoint], embedding_model="text-embedding-3-large"
data_points: list[DataPoint],
) -> AsyncGenerator[list[DataPoint], None]:
"""Processes code graph datapoints, create SourceCodeChink datapoints."""
# TODO: Add support for other embedding models, with max_token mapping
@ -165,9 +169,7 @@ async def get_source_code_chunks(
for code_part in data_point.contains:
try:
yield code_part
for source_code_chunk in get_source_code_chunks_from_code_part(
code_part, model_name=embedding_model
):
for source_code_chunk in get_source_code_chunks_from_code_part(code_part):
yield source_code_chunk
except Exception as e:
logger.error(f"Error processing code part: {e}")

View file

@ -68,7 +68,7 @@ def test_UnstructuredDocument():
)
# Test PPTX
for paragraph_data in pptx_document.read(chunk_size=1024):
for paragraph_data in pptx_document.read(chunk_size=1024, chunker="text_chunker"):
assert 19 == paragraph_data.word_count, f" 19 != {paragraph_data.word_count = }"
assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }"
assert (
@ -76,7 +76,7 @@ def test_UnstructuredDocument():
), f" sentence_cut != {paragraph_data.cut_type = }"
# Test DOCX
for paragraph_data in docx_document.read(chunk_size=1024):
for paragraph_data in docx_document.read(chunk_size=1024, chunker="text_chunker"):
assert 16 == paragraph_data.word_count, f" 16 != {paragraph_data.word_count = }"
assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }"
assert (
@ -84,7 +84,7 @@ def test_UnstructuredDocument():
), f" sentence_end != {paragraph_data.cut_type = }"
# TEST CSV
for paragraph_data in csv_document.read(chunk_size=1024):
for paragraph_data in csv_document.read(chunk_size=1024, chunker="text_chunker"):
assert 15 == paragraph_data.word_count, f" 15 != {paragraph_data.word_count = }"
assert (
"A A A A A A A A A,A A A A A A,A A" == paragraph_data.text
@ -94,7 +94,7 @@ def test_UnstructuredDocument():
), f" sentence_cut != {paragraph_data.cut_type = }"
# Test XLSX
for paragraph_data in xlsx_document.read(chunk_size=1024):
for paragraph_data in xlsx_document.read(chunk_size=1024, chunker="text_chunker"):
assert 36 == paragraph_data.word_count, f" 36 != {paragraph_data.word_count = }"
assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }"
assert (

View file

@ -27,7 +27,11 @@ def test_chunk_by_paragraph_isomorphism(input_text, paragraph_length, batch_para
list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)),
)
def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs):
chunks = list(chunk_by_paragraph(input_text, paragraph_length, batch_paragraphs))
chunks = list(
chunk_by_paragraph(
data=input_text, paragraph_length=paragraph_length, batch_paragraphs=batch_paragraphs
)
)
chunk_lengths = np.array([len(list(chunk_by_word(chunk["text"]))) for chunk in chunks])
@ -42,7 +46,9 @@ def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs):
list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)),
)
def test_chunk_by_paragraph_chunk_numbering(input_text, paragraph_length, batch_paragraphs):
chunks = chunk_by_paragraph(input_text, paragraph_length, batch_paragraphs)
chunks = chunk_by_paragraph(
data=input_text, paragraph_length=paragraph_length, batch_paragraphs=batch_paragraphs
)
chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks])
assert np.all(
chunk_indices == np.arange(len(chunk_indices))

View file

@ -49,7 +49,9 @@ Third paragraph is cut and is missing the dot at the end""",
def run_chunking_test(test_text, expected_chunks):
chunks = []
for chunk_data in chunk_by_paragraph(test_text, 12, batch_paragraphs=False):
for chunk_data in chunk_by_paragraph(
data=test_text, paragraph_length=12, batch_paragraphs=False
):
chunks.append(chunk_data)
assert len(chunks) == 3

View file

@ -34,9 +34,8 @@ def check_install_package(package_name):
async def generate_patch_with_cognee(instance, llm_client, search_type=SearchType.CHUNKS):
repo_path = download_github_repo(instance, "../RAW_GIT_REPOS")
pipeline = await run_code_graph_pipeline(repo_path)
async for result in pipeline:
async for result in run_code_graph_pipeline(repo_path, include_docs=True):
print(result)
print("Here we have the repo under the repo_path")
@ -47,7 +46,9 @@ async def generate_patch_with_cognee(instance, llm_client, search_type=SearchTyp
instructions = read_query_prompt("patch_gen_kg_instructions.txt")
retrieved_edges = await brute_force_triplet_search(
problem_statement, top_k=3, collections=["data_point_source_code", "data_point_text"]
problem_statement,
top_k=3,
collections=["code_summary_text"],
)
retrieved_edges_str = retrieved_edges_to_string(retrieved_edges)

View file

@ -1,7 +1,9 @@
import argparse
import asyncio
import logging
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
from cognee.shared.utils import setup_logging
async def main(repo_path, include_docs):
@ -9,7 +11,7 @@ async def main(repo_path, include_docs):
print(result)
if __name__ == "__main__":
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--repo_path", type=str, required=True, help="Path to the repository")
parser.add_argument(
@ -18,5 +20,28 @@ if __name__ == "__main__":
default=True,
help="Whether or not to process non-code files",
)
args = parser.parse_args()
asyncio.run(main(args.repo_path, args.include_docs))
parser.add_argument(
"--time",
type=lambda x: x.lower() in ("true", "1"),
default=True,
help="Whether or not to time the pipeline run",
)
return parser.parse_args()
if __name__ == "__main__":
    # Script entry point: parse CLI flags, run the code graph pipeline once,
    # and optionally report how long the run took.
    setup_logging(logging.ERROR)
    args = parse_args()

    import time

    # perf_counter() is a monotonic, high-resolution clock, so the reported
    # duration cannot be skewed by wall-clock adjustments during a long run.
    started_at = time.perf_counter()
    asyncio.run(main(args.repo_path, args.include_docs))
    elapsed_seconds = time.perf_counter() - started_at

    if args.time:
        print("\n" + "=" * 50)
        print(f"Pipeline Execution Time: {elapsed_seconds:.2f} seconds")
        print("=" * 50 + "\n")