Merge pull request #414 from topoteretes/COG-949

Code graph pipeline improvements and fixes
This commit is contained in:
Vasilije 2025-01-10 14:32:06 +01:00 committed by GitHub
commit f7e808eddd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 255 additions and 54 deletions

View file

@ -3,7 +3,6 @@ import logging
from pathlib import Path from pathlib import Path
from cognee.base_config import get_base_config from cognee.base_config import get_base_config
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
from cognee.modules.cognify.config import get_cognify_config from cognee.modules.cognify.config import get_cognify_config
from cognee.modules.pipelines import run_tasks from cognee.modules.pipelines import run_tasks
from cognee.modules.pipelines.tasks.Task import Task from cognee.modules.pipelines.tasks.Task import Task
@ -54,8 +53,6 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
await cognee.prune.prune_system(metadata=True) await cognee.prune.prune_system(metadata=True)
await create_db_and_tables() await create_db_and_tables()
embedding_engine = get_embedding_engine()
cognee_config = get_cognify_config() cognee_config = get_cognify_config()
user = await get_default_user() user = await get_default_user()
@ -63,11 +60,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
Task(get_repo_file_dependencies), Task(get_repo_file_dependencies),
Task(enrich_dependency_graph), Task(enrich_dependency_graph),
Task(expand_dependency_graph, task_config={"batch_size": 50}), Task(expand_dependency_graph, task_config={"batch_size": 50}),
Task( Task(get_source_code_chunks, task_config={"batch_size": 50}),
get_source_code_chunks,
embedding_model=embedding_engine.model,
task_config={"batch_size": 50},
),
Task(summarize_code, task_config={"batch_size": 50}), Task(summarize_code, task_config={"batch_size": 50}),
Task(add_data_points, task_config={"batch_size": 50}), Task(add_data_points, task_config={"batch_size": 50}),
] ]
@ -78,7 +71,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
Task(ingest_data_with_metadata, dataset_name="repo_docs", user=user), Task(ingest_data_with_metadata, dataset_name="repo_docs", user=user),
Task(get_data_list_for_user, dataset_name="repo_docs", user=user), Task(get_data_list_for_user, dataset_name="repo_docs", user=user),
Task(classify_documents), Task(classify_documents),
Task(extract_chunks_from_documents), Task(extract_chunks_from_documents, max_tokens=cognee_config.max_tokens),
Task( Task(
extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50} extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50}
), ),

View file

@ -1,7 +1,9 @@
from uuid import uuid5, NAMESPACE_OID from typing import Optional
from uuid import NAMESPACE_OID, uuid5
from cognee.tasks.chunks import chunk_by_paragraph
from .models.DocumentChunk import DocumentChunk from .models.DocumentChunk import DocumentChunk
from cognee.tasks.chunks import chunk_by_paragraph
class TextChunker: class TextChunker:
@ -10,23 +12,36 @@ class TextChunker:
chunk_index = 0 chunk_index = 0
chunk_size = 0 chunk_size = 0
token_count = 0
def __init__(self, document, get_text: callable, chunk_size: int = 1024): def __init__(
self, document, get_text: callable, max_tokens: Optional[int] = None, chunk_size: int = 1024
):
self.document = document self.document = document
self.max_chunk_size = chunk_size self.max_chunk_size = chunk_size
self.get_text = get_text self.get_text = get_text
self.max_tokens = max_tokens if max_tokens else float("inf")
def check_word_count_and_token_count(self, word_count_before, token_count_before, chunk_data):
word_count_fits = word_count_before + chunk_data["word_count"] <= self.max_chunk_size
token_count_fits = token_count_before + chunk_data["token_count"] <= self.max_tokens
return word_count_fits and token_count_fits
def read(self): def read(self):
paragraph_chunks = [] paragraph_chunks = []
for content_text in self.get_text(): for content_text in self.get_text():
for chunk_data in chunk_by_paragraph( for chunk_data in chunk_by_paragraph(
content_text, content_text,
self.max_tokens,
self.max_chunk_size, self.max_chunk_size,
batch_paragraphs=True, batch_paragraphs=True,
): ):
if self.chunk_size + chunk_data["word_count"] <= self.max_chunk_size: if self.check_word_count_and_token_count(
self.chunk_size, self.token_count, chunk_data
):
paragraph_chunks.append(chunk_data) paragraph_chunks.append(chunk_data)
self.chunk_size += chunk_data["word_count"] self.chunk_size += chunk_data["word_count"]
self.token_count += chunk_data["token_count"]
else: else:
if len(paragraph_chunks) == 0: if len(paragraph_chunks) == 0:
yield DocumentChunk( yield DocumentChunk(
@ -66,6 +81,7 @@ class TextChunker:
print(e) print(e)
paragraph_chunks = [chunk_data] paragraph_chunks = [chunk_data]
self.chunk_size = chunk_data["word_count"] self.chunk_size = chunk_data["word_count"]
self.token_count = chunk_data["token_count"]
self.chunk_index += 1 self.chunk_index += 1

View file

@ -1,12 +1,14 @@
from functools import lru_cache from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
from cognee.shared.data_models import DefaultContentPrediction, SummarizedContent from cognee.shared.data_models import DefaultContentPrediction, SummarizedContent
from typing import Optional
import os
class CognifyConfig(BaseSettings): class CognifyConfig(BaseSettings):
classification_model: object = DefaultContentPrediction classification_model: object = DefaultContentPrediction
summarization_model: object = SummarizedContent summarization_model: object = SummarizedContent
max_tokens: Optional[int] = os.getenv("MAX_TOKENS")
model_config = SettingsConfigDict(env_file=".env", extra="allow") model_config = SettingsConfigDict(env_file=".env", extra="allow")
def to_dict(self) -> dict: def to_dict(self) -> dict:

View file

@ -1,6 +1,9 @@
from typing import Optional
from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.get_llm_client import get_llm_client
from .Document import Document
from .ChunkerMapping import ChunkerConfig from .ChunkerMapping import ChunkerConfig
from .Document import Document
class AudioDocument(Document): class AudioDocument(Document):
@ -10,12 +13,14 @@ class AudioDocument(Document):
result = get_llm_client().create_transcript(self.raw_data_location) result = get_llm_client().create_transcript(self.raw_data_location)
return result.text return result.text
def read(self, chunk_size: int, chunker: str): def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
# Transcribe the audio file # Transcribe the audio file
text = self.create_transcript() text = self.create_transcript()
chunker_func = ChunkerConfig.get_chunker(chunker) chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(self, chunk_size=chunk_size, get_text=lambda: [text]) chunker = chunker_func(
self, chunk_size=chunk_size, get_text=lambda: [text], max_tokens=max_tokens
)
yield from chunker.read() yield from chunker.read()

View file

@ -1,3 +1,4 @@
from typing import Optional
from uuid import UUID from uuid import UUID
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
@ -10,5 +11,5 @@ class Document(DataPoint):
mime_type: str mime_type: str
_metadata: dict = {"index_fields": ["name"], "type": "Document"} _metadata: dict = {"index_fields": ["name"], "type": "Document"}
def read(self, chunk_size: int, chunker=str) -> str: def read(self, chunk_size: int, chunker=str, max_tokens: Optional[int] = None) -> str:
pass pass

View file

@ -1,6 +1,9 @@
from typing import Optional
from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.get_llm_client import get_llm_client
from .Document import Document
from .ChunkerMapping import ChunkerConfig from .ChunkerMapping import ChunkerConfig
from .Document import Document
class ImageDocument(Document): class ImageDocument(Document):
@ -10,11 +13,13 @@ class ImageDocument(Document):
result = get_llm_client().transcribe_image(self.raw_data_location) result = get_llm_client().transcribe_image(self.raw_data_location)
return result.choices[0].message.content return result.choices[0].message.content
def read(self, chunk_size: int, chunker: str): def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
# Transcribe the image file # Transcribe the image file
text = self.transcribe_image() text = self.transcribe_image()
chunker_func = ChunkerConfig.get_chunker(chunker) chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(self, chunk_size=chunk_size, get_text=lambda: [text]) chunker = chunker_func(
self, chunk_size=chunk_size, get_text=lambda: [text], max_tokens=max_tokens
)
yield from chunker.read() yield from chunker.read()

View file

@ -1,12 +1,15 @@
from typing import Optional
from pypdf import PdfReader from pypdf import PdfReader
from .Document import Document
from .ChunkerMapping import ChunkerConfig from .ChunkerMapping import ChunkerConfig
from .Document import Document
class PdfDocument(Document): class PdfDocument(Document):
type: str = "pdf" type: str = "pdf"
def read(self, chunk_size: int, chunker: str): def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
file = PdfReader(self.raw_data_location) file = PdfReader(self.raw_data_location)
def get_text(): def get_text():
@ -15,7 +18,9 @@ class PdfDocument(Document):
yield page_text yield page_text
chunker_func = ChunkerConfig.get_chunker(chunker) chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(self, chunk_size=chunk_size, get_text=get_text) chunker = chunker_func(
self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens
)
yield from chunker.read() yield from chunker.read()

View file

@ -1,11 +1,13 @@
from .Document import Document from typing import Optional
from .ChunkerMapping import ChunkerConfig from .ChunkerMapping import ChunkerConfig
from .Document import Document
class TextDocument(Document): class TextDocument(Document):
type: str = "text" type: str = "text"
def read(self, chunk_size: int, chunker: str): def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
def get_text(): def get_text():
with open(self.raw_data_location, mode="r", encoding="utf-8") as file: with open(self.raw_data_location, mode="r", encoding="utf-8") as file:
while True: while True:
@ -18,6 +20,8 @@ class TextDocument(Document):
chunker_func = ChunkerConfig.get_chunker(chunker) chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(self, chunk_size=chunk_size, get_text=get_text) chunker = chunker_func(
self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens
)
yield from chunker.read() yield from chunker.read()

View file

@ -1,14 +1,16 @@
from io import StringIO from io import StringIO
from typing import Optional
from cognee.modules.chunking.TextChunker import TextChunker from cognee.modules.chunking.TextChunker import TextChunker
from .Document import Document
from cognee.modules.data.exceptions import UnstructuredLibraryImportError from cognee.modules.data.exceptions import UnstructuredLibraryImportError
from .Document import Document
class UnstructuredDocument(Document): class UnstructuredDocument(Document):
type: str = "unstructured" type: str = "unstructured"
def read(self, chunk_size: int): def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None) -> str:
def get_text(): def get_text():
try: try:
from unstructured.partition.auto import partition from unstructured.partition.auto import partition
@ -27,6 +29,6 @@ class UnstructuredDocument(Document):
yield text yield text
chunker = TextChunker(self, chunk_size=chunk_size, get_text=get_text) chunker = TextChunker(self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens)
yield from chunker.read() yield from chunker.read()

View file

@ -1,10 +1,18 @@
from uuid import uuid5, NAMESPACE_OID from typing import Any, Dict, Iterator, Optional, Union
from typing import Dict, Any, Iterator from uuid import NAMESPACE_OID, uuid5
import tiktoken
from cognee.infrastructure.databases.vector import get_vector_engine
from .chunk_by_sentence import chunk_by_sentence from .chunk_by_sentence import chunk_by_sentence
def chunk_by_paragraph( def chunk_by_paragraph(
data: str, paragraph_length: int = 1024, batch_paragraphs: bool = True data: str,
max_tokens: Optional[Union[int, float]] = None,
paragraph_length: int = 1024,
batch_paragraphs: bool = True,
) -> Iterator[Dict[str, Any]]: ) -> Iterator[Dict[str, Any]]:
""" """
Chunks text by paragraph while preserving exact text reconstruction capability. Chunks text by paragraph while preserving exact text reconstruction capability.
@ -15,16 +23,31 @@ def chunk_by_paragraph(
chunk_index = 0 chunk_index = 0
paragraph_ids = [] paragraph_ids = []
last_cut_type = None last_cut_type = None
current_token_count = 0
if not max_tokens:
max_tokens = float("inf")
vector_engine = get_vector_engine()
embedding_model = vector_engine.embedding_engine.model
embedding_model = embedding_model.split("/")[-1]
for paragraph_id, sentence, word_count, end_type in chunk_by_sentence( for paragraph_id, sentence, word_count, end_type in chunk_by_sentence(
data, maximum_length=paragraph_length data, maximum_length=paragraph_length
): ):
# Check if this sentence would exceed length limit # Check if this sentence would exceed length limit
if current_word_count > 0 and current_word_count + word_count > paragraph_length:
tokenizer = tiktoken.encoding_for_model(embedding_model)
token_count = len(tokenizer.encode(sentence))
if current_word_count > 0 and (
current_word_count + word_count > paragraph_length
or current_token_count + token_count > max_tokens
):
# Yield current chunk # Yield current chunk
chunk_dict = { chunk_dict = {
"text": current_chunk, "text": current_chunk,
"word_count": current_word_count, "word_count": current_word_count,
"token_count": current_token_count,
"chunk_id": uuid5(NAMESPACE_OID, current_chunk), "chunk_id": uuid5(NAMESPACE_OID, current_chunk),
"paragraph_ids": paragraph_ids, "paragraph_ids": paragraph_ids,
"chunk_index": chunk_index, "chunk_index": chunk_index,
@ -37,11 +60,13 @@ def chunk_by_paragraph(
paragraph_ids = [] paragraph_ids = []
current_chunk = "" current_chunk = ""
current_word_count = 0 current_word_count = 0
current_token_count = 0
chunk_index += 1 chunk_index += 1
paragraph_ids.append(paragraph_id) paragraph_ids.append(paragraph_id)
current_chunk += sentence current_chunk += sentence
current_word_count += word_count current_word_count += word_count
current_token_count += token_count
# Handle end of paragraph # Handle end of paragraph
if end_type in ("paragraph_end", "sentence_cut") and not batch_paragraphs: if end_type in ("paragraph_end", "sentence_cut") and not batch_paragraphs:
@ -49,6 +74,7 @@ def chunk_by_paragraph(
chunk_dict = { chunk_dict = {
"text": current_chunk, "text": current_chunk,
"word_count": current_word_count, "word_count": current_word_count,
"token_count": current_token_count,
"paragraph_ids": paragraph_ids, "paragraph_ids": paragraph_ids,
"chunk_id": uuid5(NAMESPACE_OID, current_chunk), "chunk_id": uuid5(NAMESPACE_OID, current_chunk),
"chunk_index": chunk_index, "chunk_index": chunk_index,
@ -58,6 +84,7 @@ def chunk_by_paragraph(
paragraph_ids = [] paragraph_ids = []
current_chunk = "" current_chunk = ""
current_word_count = 0 current_word_count = 0
current_token_count = 0
chunk_index += 1 chunk_index += 1
last_cut_type = end_type last_cut_type = end_type
@ -67,6 +94,7 @@ def chunk_by_paragraph(
chunk_dict = { chunk_dict = {
"text": current_chunk, "text": current_chunk,
"word_count": current_word_count, "word_count": current_word_count,
"token_count": current_token_count,
"chunk_id": uuid5(NAMESPACE_OID, current_chunk), "chunk_id": uuid5(NAMESPACE_OID, current_chunk),
"paragraph_ids": paragraph_ids, "paragraph_ids": paragraph_ids,
"chunk_index": chunk_index, "chunk_index": chunk_index,

View file

@ -1,9 +1,16 @@
from typing import Optional
from cognee.modules.data.processing.document_types.Document import Document from cognee.modules.data.processing.document_types.Document import Document
async def extract_chunks_from_documents( async def extract_chunks_from_documents(
documents: list[Document], chunk_size: int = 1024, chunker="text_chunker" documents: list[Document],
chunk_size: int = 1024,
chunker="text_chunker",
max_tokens: Optional[int] = None,
): ):
for document in documents: for document in documents:
for document_chunk in document.read(chunk_size=chunk_size, chunker=chunker): for document_chunk in document.read(
chunk_size=chunk_size, chunker=chunker, max_tokens=max_tokens
):
yield document_chunk yield document_chunk

View file

@ -29,8 +29,105 @@ async def get_non_py_files(repo_path):
"*.egg-info", "*.egg-info",
} }
ALLOWED_EXTENSIONS = {
".txt",
".md",
".csv",
".json",
".xml",
".yaml",
".yml",
".html",
".css",
".js",
".ts",
".jsx",
".tsx",
".sql",
".log",
".ini",
".toml",
".properties",
".sh",
".bash",
".dockerfile",
".gitignore",
".gitattributes",
".makefile",
".pyproject",
".requirements",
".env",
".pdf",
".doc",
".docx",
".dot",
".dotx",
".rtf",
".wps",
".wpd",
".odt",
".ott",
".ottx",
".txt",
".wp",
".sdw",
".sdx",
".docm",
".dotm",
# Additional extensions for other programming languages
".java",
".c",
".cpp",
".h",
".cs",
".go",
".php",
".rb",
".swift",
".pl",
".lua",
".rs",
".scala",
".kt",
".sh",
".sql",
".v",
".asm",
".pas",
".d",
".ml",
".clj",
".cljs",
".erl",
".ex",
".exs",
".f",
".fs",
".r",
".pyi",
".pdb",
".ipynb",
".rmd",
".cabal",
".hs",
".nim",
".vhdl",
".verilog",
".svelte",
".html",
".css",
".scss",
".less",
".json5",
".yaml",
".yml",
}
def should_process(path): def should_process(path):
return not any(pattern in path for pattern in IGNORED_PATTERNS) _, ext = os.path.splitext(path)
return ext in ALLOWED_EXTENSIONS and not any(
pattern in path for pattern in IGNORED_PATTERNS
)
non_py_files_paths = [ non_py_files_paths = [
os.path.join(root, file) os.path.join(root, file)

View file

@ -5,6 +5,7 @@ from uuid import NAMESPACE_OID, uuid5
import parso import parso
import tiktoken import tiktoken
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.shared.CodeGraphEntities import CodeFile, CodePart, SourceCodeChunk from cognee.shared.CodeGraphEntities import CodeFile, CodePart, SourceCodeChunk
@ -126,6 +127,9 @@ def get_source_code_chunks_from_code_part(
logger.error(f"No source code in CodeFile {code_file_part.id}") logger.error(f"No source code in CodeFile {code_file_part.id}")
return return
vector_engine = get_vector_engine()
embedding_model = vector_engine.embedding_engine.model
model_name = embedding_model.split("/")[-1]
tokenizer = tiktoken.encoding_for_model(model_name) tokenizer = tiktoken.encoding_for_model(model_name)
max_subchunk_tokens = max(1, int(granularity * max_tokens)) max_subchunk_tokens = max(1, int(granularity * max_tokens))
subchunk_token_counts = _get_subchunk_token_counts( subchunk_token_counts = _get_subchunk_token_counts(
@ -150,7 +154,7 @@ def get_source_code_chunks_from_code_part(
async def get_source_code_chunks( async def get_source_code_chunks(
data_points: list[DataPoint], embedding_model="text-embedding-3-large" data_points: list[DataPoint],
) -> AsyncGenerator[list[DataPoint], None]: ) -> AsyncGenerator[list[DataPoint], None]:
"""Processes code graph datapoints, create SourceCodeChink datapoints.""" """Processes code graph datapoints, create SourceCodeChink datapoints."""
# TODO: Add support for other embedding models, with max_token mapping # TODO: Add support for other embedding models, with max_token mapping
@ -165,9 +169,7 @@ async def get_source_code_chunks(
for code_part in data_point.contains: for code_part in data_point.contains:
try: try:
yield code_part yield code_part
for source_code_chunk in get_source_code_chunks_from_code_part( for source_code_chunk in get_source_code_chunks_from_code_part(code_part):
code_part, model_name=embedding_model
):
yield source_code_chunk yield source_code_chunk
except Exception as e: except Exception as e:
logger.error(f"Error processing code part: {e}") logger.error(f"Error processing code part: {e}")

View file

@ -68,7 +68,7 @@ def test_UnstructuredDocument():
) )
# Test PPTX # Test PPTX
for paragraph_data in pptx_document.read(chunk_size=1024): for paragraph_data in pptx_document.read(chunk_size=1024, chunker="text_chunker"):
assert 19 == paragraph_data.word_count, f" 19 != {paragraph_data.word_count = }" assert 19 == paragraph_data.word_count, f" 19 != {paragraph_data.word_count = }"
assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }" assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }"
assert ( assert (
@ -76,7 +76,7 @@ def test_UnstructuredDocument():
), f" sentence_cut != {paragraph_data.cut_type = }" ), f" sentence_cut != {paragraph_data.cut_type = }"
# Test DOCX # Test DOCX
for paragraph_data in docx_document.read(chunk_size=1024): for paragraph_data in docx_document.read(chunk_size=1024, chunker="text_chunker"):
assert 16 == paragraph_data.word_count, f" 16 != {paragraph_data.word_count = }" assert 16 == paragraph_data.word_count, f" 16 != {paragraph_data.word_count = }"
assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }" assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }"
assert ( assert (
@ -84,7 +84,7 @@ def test_UnstructuredDocument():
), f" sentence_end != {paragraph_data.cut_type = }" ), f" sentence_end != {paragraph_data.cut_type = }"
# TEST CSV # TEST CSV
for paragraph_data in csv_document.read(chunk_size=1024): for paragraph_data in csv_document.read(chunk_size=1024, chunker="text_chunker"):
assert 15 == paragraph_data.word_count, f" 15 != {paragraph_data.word_count = }" assert 15 == paragraph_data.word_count, f" 15 != {paragraph_data.word_count = }"
assert ( assert (
"A A A A A A A A A,A A A A A A,A A" == paragraph_data.text "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text
@ -94,7 +94,7 @@ def test_UnstructuredDocument():
), f" sentence_cut != {paragraph_data.cut_type = }" ), f" sentence_cut != {paragraph_data.cut_type = }"
# Test XLSX # Test XLSX
for paragraph_data in xlsx_document.read(chunk_size=1024): for paragraph_data in xlsx_document.read(chunk_size=1024, chunker="text_chunker"):
assert 36 == paragraph_data.word_count, f" 36 != {paragraph_data.word_count = }" assert 36 == paragraph_data.word_count, f" 36 != {paragraph_data.word_count = }"
assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }" assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }"
assert ( assert (

View file

@ -27,7 +27,11 @@ def test_chunk_by_paragraph_isomorphism(input_text, paragraph_length, batch_para
list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)), list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)),
) )
def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs): def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs):
chunks = list(chunk_by_paragraph(input_text, paragraph_length, batch_paragraphs)) chunks = list(
chunk_by_paragraph(
data=input_text, paragraph_length=paragraph_length, batch_paragraphs=batch_paragraphs
)
)
chunk_lengths = np.array([len(list(chunk_by_word(chunk["text"]))) for chunk in chunks]) chunk_lengths = np.array([len(list(chunk_by_word(chunk["text"]))) for chunk in chunks])
@ -42,7 +46,9 @@ def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs):
list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)), list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)),
) )
def test_chunk_by_paragraph_chunk_numbering(input_text, paragraph_length, batch_paragraphs): def test_chunk_by_paragraph_chunk_numbering(input_text, paragraph_length, batch_paragraphs):
chunks = chunk_by_paragraph(input_text, paragraph_length, batch_paragraphs) chunks = chunk_by_paragraph(
data=input_text, paragraph_length=paragraph_length, batch_paragraphs=batch_paragraphs
)
chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks]) chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks])
assert np.all( assert np.all(
chunk_indices == np.arange(len(chunk_indices)) chunk_indices == np.arange(len(chunk_indices))

View file

@ -49,7 +49,9 @@ Third paragraph is cut and is missing the dot at the end""",
def run_chunking_test(test_text, expected_chunks): def run_chunking_test(test_text, expected_chunks):
chunks = [] chunks = []
for chunk_data in chunk_by_paragraph(test_text, 12, batch_paragraphs=False): for chunk_data in chunk_by_paragraph(
data=test_text, paragraph_length=12, batch_paragraphs=False
):
chunks.append(chunk_data) chunks.append(chunk_data)
assert len(chunks) == 3 assert len(chunks) == 3

View file

@ -34,9 +34,8 @@ def check_install_package(package_name):
async def generate_patch_with_cognee(instance, llm_client, search_type=SearchType.CHUNKS): async def generate_patch_with_cognee(instance, llm_client, search_type=SearchType.CHUNKS):
repo_path = download_github_repo(instance, "../RAW_GIT_REPOS") repo_path = download_github_repo(instance, "../RAW_GIT_REPOS")
pipeline = await run_code_graph_pipeline(repo_path)
async for result in pipeline: async for result in run_code_graph_pipeline(repo_path, include_docs=True):
print(result) print(result)
print("Here we have the repo under the repo_path") print("Here we have the repo under the repo_path")
@ -47,7 +46,9 @@ async def generate_patch_with_cognee(instance, llm_client, search_type=SearchTyp
instructions = read_query_prompt("patch_gen_kg_instructions.txt") instructions = read_query_prompt("patch_gen_kg_instructions.txt")
retrieved_edges = await brute_force_triplet_search( retrieved_edges = await brute_force_triplet_search(
problem_statement, top_k=3, collections=["data_point_source_code", "data_point_text"] problem_statement,
top_k=3,
collections=["code_summary_text"],
) )
retrieved_edges_str = retrieved_edges_to_string(retrieved_edges) retrieved_edges_str = retrieved_edges_to_string(retrieved_edges)

View file

@ -1,7 +1,9 @@
import argparse import argparse
import asyncio import asyncio
import logging
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
from cognee.shared.utils import setup_logging
async def main(repo_path, include_docs): async def main(repo_path, include_docs):
@ -9,7 +11,7 @@ async def main(repo_path, include_docs):
print(result) print(result)
if __name__ == "__main__": def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--repo_path", type=str, required=True, help="Path to the repository") parser.add_argument("--repo_path", type=str, required=True, help="Path to the repository")
parser.add_argument( parser.add_argument(
@ -18,5 +20,28 @@ if __name__ == "__main__":
default=True, default=True,
help="Whether or not to process non-code files", help="Whether or not to process non-code files",
) )
args = parser.parse_args() parser.add_argument(
asyncio.run(main(args.repo_path, args.include_docs)) "--time",
type=lambda x: x.lower() in ("true", "1"),
default=True,
help="Whether or not to time the pipeline run",
)
return parser.parse_args()
if __name__ == "__main__":
setup_logging(logging.ERROR)
args = parse_args()
if args.time:
import time
start_time = time.time()
asyncio.run(main(args.repo_path, args.include_docs))
end_time = time.time()
print("\n" + "=" * 50)
print(f"Pipeline Execution Time: {end_time - start_time:.2f} seconds")
print("=" * 50 + "\n")
else:
asyncio.run(main(args.repo_path, args.include_docs))