Merge pull request #414 from topoteretes/COG-949
Code graph pipeline improvements and fixes
This commit is contained in:
commit
f7e808eddd
18 changed files with 255 additions and 54 deletions
|
|
@ -3,7 +3,6 @@ import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from cognee.base_config import get_base_config
|
from cognee.base_config import get_base_config
|
||||||
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
|
|
||||||
from cognee.modules.cognify.config import get_cognify_config
|
from cognee.modules.cognify.config import get_cognify_config
|
||||||
from cognee.modules.pipelines import run_tasks
|
from cognee.modules.pipelines import run_tasks
|
||||||
from cognee.modules.pipelines.tasks.Task import Task
|
from cognee.modules.pipelines.tasks.Task import Task
|
||||||
|
|
@ -54,8 +53,6 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
|
||||||
await cognee.prune.prune_system(metadata=True)
|
await cognee.prune.prune_system(metadata=True)
|
||||||
await create_db_and_tables()
|
await create_db_and_tables()
|
||||||
|
|
||||||
embedding_engine = get_embedding_engine()
|
|
||||||
|
|
||||||
cognee_config = get_cognify_config()
|
cognee_config = get_cognify_config()
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
|
|
||||||
|
|
@ -63,11 +60,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
|
||||||
Task(get_repo_file_dependencies),
|
Task(get_repo_file_dependencies),
|
||||||
Task(enrich_dependency_graph),
|
Task(enrich_dependency_graph),
|
||||||
Task(expand_dependency_graph, task_config={"batch_size": 50}),
|
Task(expand_dependency_graph, task_config={"batch_size": 50}),
|
||||||
Task(
|
Task(get_source_code_chunks, task_config={"batch_size": 50}),
|
||||||
get_source_code_chunks,
|
|
||||||
embedding_model=embedding_engine.model,
|
|
||||||
task_config={"batch_size": 50},
|
|
||||||
),
|
|
||||||
Task(summarize_code, task_config={"batch_size": 50}),
|
Task(summarize_code, task_config={"batch_size": 50}),
|
||||||
Task(add_data_points, task_config={"batch_size": 50}),
|
Task(add_data_points, task_config={"batch_size": 50}),
|
||||||
]
|
]
|
||||||
|
|
@ -78,7 +71,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
|
||||||
Task(ingest_data_with_metadata, dataset_name="repo_docs", user=user),
|
Task(ingest_data_with_metadata, dataset_name="repo_docs", user=user),
|
||||||
Task(get_data_list_for_user, dataset_name="repo_docs", user=user),
|
Task(get_data_list_for_user, dataset_name="repo_docs", user=user),
|
||||||
Task(classify_documents),
|
Task(classify_documents),
|
||||||
Task(extract_chunks_from_documents),
|
Task(extract_chunks_from_documents, max_tokens=cognee_config.max_tokens),
|
||||||
Task(
|
Task(
|
||||||
extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50}
|
extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50}
|
||||||
),
|
),
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
from uuid import uuid5, NAMESPACE_OID
|
from typing import Optional
|
||||||
|
from uuid import NAMESPACE_OID, uuid5
|
||||||
|
|
||||||
|
from cognee.tasks.chunks import chunk_by_paragraph
|
||||||
|
|
||||||
from .models.DocumentChunk import DocumentChunk
|
from .models.DocumentChunk import DocumentChunk
|
||||||
from cognee.tasks.chunks import chunk_by_paragraph
|
|
||||||
|
|
||||||
|
|
||||||
class TextChunker:
|
class TextChunker:
|
||||||
|
|
@ -10,23 +12,36 @@ class TextChunker:
|
||||||
|
|
||||||
chunk_index = 0
|
chunk_index = 0
|
||||||
chunk_size = 0
|
chunk_size = 0
|
||||||
|
token_count = 0
|
||||||
|
|
||||||
def __init__(self, document, get_text: callable, chunk_size: int = 1024):
|
def __init__(
|
||||||
|
self, document, get_text: callable, max_tokens: Optional[int] = None, chunk_size: int = 1024
|
||||||
|
):
|
||||||
self.document = document
|
self.document = document
|
||||||
self.max_chunk_size = chunk_size
|
self.max_chunk_size = chunk_size
|
||||||
self.get_text = get_text
|
self.get_text = get_text
|
||||||
|
self.max_tokens = max_tokens if max_tokens else float("inf")
|
||||||
|
|
||||||
|
def check_word_count_and_token_count(self, word_count_before, token_count_before, chunk_data):
|
||||||
|
word_count_fits = word_count_before + chunk_data["word_count"] <= self.max_chunk_size
|
||||||
|
token_count_fits = token_count_before + chunk_data["token_count"] <= self.max_tokens
|
||||||
|
return word_count_fits and token_count_fits
|
||||||
|
|
||||||
def read(self):
|
def read(self):
|
||||||
paragraph_chunks = []
|
paragraph_chunks = []
|
||||||
for content_text in self.get_text():
|
for content_text in self.get_text():
|
||||||
for chunk_data in chunk_by_paragraph(
|
for chunk_data in chunk_by_paragraph(
|
||||||
content_text,
|
content_text,
|
||||||
|
self.max_tokens,
|
||||||
self.max_chunk_size,
|
self.max_chunk_size,
|
||||||
batch_paragraphs=True,
|
batch_paragraphs=True,
|
||||||
):
|
):
|
||||||
if self.chunk_size + chunk_data["word_count"] <= self.max_chunk_size:
|
if self.check_word_count_and_token_count(
|
||||||
|
self.chunk_size, self.token_count, chunk_data
|
||||||
|
):
|
||||||
paragraph_chunks.append(chunk_data)
|
paragraph_chunks.append(chunk_data)
|
||||||
self.chunk_size += chunk_data["word_count"]
|
self.chunk_size += chunk_data["word_count"]
|
||||||
|
self.token_count += chunk_data["token_count"]
|
||||||
else:
|
else:
|
||||||
if len(paragraph_chunks) == 0:
|
if len(paragraph_chunks) == 0:
|
||||||
yield DocumentChunk(
|
yield DocumentChunk(
|
||||||
|
|
@ -66,6 +81,7 @@ class TextChunker:
|
||||||
print(e)
|
print(e)
|
||||||
paragraph_chunks = [chunk_data]
|
paragraph_chunks = [chunk_data]
|
||||||
self.chunk_size = chunk_data["word_count"]
|
self.chunk_size = chunk_data["word_count"]
|
||||||
|
self.token_count = chunk_data["token_count"]
|
||||||
|
|
||||||
self.chunk_index += 1
|
self.chunk_index += 1
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,14 @@
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
from cognee.shared.data_models import DefaultContentPrediction, SummarizedContent
|
from cognee.shared.data_models import DefaultContentPrediction, SummarizedContent
|
||||||
|
from typing import Optional
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
class CognifyConfig(BaseSettings):
|
class CognifyConfig(BaseSettings):
|
||||||
classification_model: object = DefaultContentPrediction
|
classification_model: object = DefaultContentPrediction
|
||||||
summarization_model: object = SummarizedContent
|
summarization_model: object = SummarizedContent
|
||||||
|
max_tokens: Optional[int] = os.getenv("MAX_TOKENS")
|
||||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,9 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||||
from .Document import Document
|
|
||||||
from .ChunkerMapping import ChunkerConfig
|
from .ChunkerMapping import ChunkerConfig
|
||||||
|
from .Document import Document
|
||||||
|
|
||||||
|
|
||||||
class AudioDocument(Document):
|
class AudioDocument(Document):
|
||||||
|
|
@ -10,12 +13,14 @@ class AudioDocument(Document):
|
||||||
result = get_llm_client().create_transcript(self.raw_data_location)
|
result = get_llm_client().create_transcript(self.raw_data_location)
|
||||||
return result.text
|
return result.text
|
||||||
|
|
||||||
def read(self, chunk_size: int, chunker: str):
|
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
|
||||||
# Transcribe the audio file
|
# Transcribe the audio file
|
||||||
|
|
||||||
text = self.create_transcript()
|
text = self.create_transcript()
|
||||||
|
|
||||||
chunker_func = ChunkerConfig.get_chunker(chunker)
|
chunker_func = ChunkerConfig.get_chunker(chunker)
|
||||||
chunker = chunker_func(self, chunk_size=chunk_size, get_text=lambda: [text])
|
chunker = chunker_func(
|
||||||
|
self, chunk_size=chunk_size, get_text=lambda: [text], max_tokens=max_tokens
|
||||||
|
)
|
||||||
|
|
||||||
yield from chunker.read()
|
yield from chunker.read()
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
from typing import Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
|
@ -10,5 +11,5 @@ class Document(DataPoint):
|
||||||
mime_type: str
|
mime_type: str
|
||||||
_metadata: dict = {"index_fields": ["name"], "type": "Document"}
|
_metadata: dict = {"index_fields": ["name"], "type": "Document"}
|
||||||
|
|
||||||
def read(self, chunk_size: int, chunker=str) -> str:
|
def read(self, chunk_size: int, chunker=str, max_tokens: Optional[int] = None) -> str:
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,9 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||||
from .Document import Document
|
|
||||||
from .ChunkerMapping import ChunkerConfig
|
from .ChunkerMapping import ChunkerConfig
|
||||||
|
from .Document import Document
|
||||||
|
|
||||||
|
|
||||||
class ImageDocument(Document):
|
class ImageDocument(Document):
|
||||||
|
|
@ -10,11 +13,13 @@ class ImageDocument(Document):
|
||||||
result = get_llm_client().transcribe_image(self.raw_data_location)
|
result = get_llm_client().transcribe_image(self.raw_data_location)
|
||||||
return result.choices[0].message.content
|
return result.choices[0].message.content
|
||||||
|
|
||||||
def read(self, chunk_size: int, chunker: str):
|
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
|
||||||
# Transcribe the image file
|
# Transcribe the image file
|
||||||
text = self.transcribe_image()
|
text = self.transcribe_image()
|
||||||
|
|
||||||
chunker_func = ChunkerConfig.get_chunker(chunker)
|
chunker_func = ChunkerConfig.get_chunker(chunker)
|
||||||
chunker = chunker_func(self, chunk_size=chunk_size, get_text=lambda: [text])
|
chunker = chunker_func(
|
||||||
|
self, chunk_size=chunk_size, get_text=lambda: [text], max_tokens=max_tokens
|
||||||
|
)
|
||||||
|
|
||||||
yield from chunker.read()
|
yield from chunker.read()
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,15 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from pypdf import PdfReader
|
from pypdf import PdfReader
|
||||||
from .Document import Document
|
|
||||||
from .ChunkerMapping import ChunkerConfig
|
from .ChunkerMapping import ChunkerConfig
|
||||||
|
from .Document import Document
|
||||||
|
|
||||||
|
|
||||||
class PdfDocument(Document):
|
class PdfDocument(Document):
|
||||||
type: str = "pdf"
|
type: str = "pdf"
|
||||||
|
|
||||||
def read(self, chunk_size: int, chunker: str):
|
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
|
||||||
file = PdfReader(self.raw_data_location)
|
file = PdfReader(self.raw_data_location)
|
||||||
|
|
||||||
def get_text():
|
def get_text():
|
||||||
|
|
@ -15,7 +18,9 @@ class PdfDocument(Document):
|
||||||
yield page_text
|
yield page_text
|
||||||
|
|
||||||
chunker_func = ChunkerConfig.get_chunker(chunker)
|
chunker_func = ChunkerConfig.get_chunker(chunker)
|
||||||
chunker = chunker_func(self, chunk_size=chunk_size, get_text=get_text)
|
chunker = chunker_func(
|
||||||
|
self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens
|
||||||
|
)
|
||||||
|
|
||||||
yield from chunker.read()
|
yield from chunker.read()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,13 @@
|
||||||
from .Document import Document
|
from typing import Optional
|
||||||
|
|
||||||
from .ChunkerMapping import ChunkerConfig
|
from .ChunkerMapping import ChunkerConfig
|
||||||
|
from .Document import Document
|
||||||
|
|
||||||
|
|
||||||
class TextDocument(Document):
|
class TextDocument(Document):
|
||||||
type: str = "text"
|
type: str = "text"
|
||||||
|
|
||||||
def read(self, chunk_size: int, chunker: str):
|
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
|
||||||
def get_text():
|
def get_text():
|
||||||
with open(self.raw_data_location, mode="r", encoding="utf-8") as file:
|
with open(self.raw_data_location, mode="r", encoding="utf-8") as file:
|
||||||
while True:
|
while True:
|
||||||
|
|
@ -18,6 +20,8 @@ class TextDocument(Document):
|
||||||
|
|
||||||
chunker_func = ChunkerConfig.get_chunker(chunker)
|
chunker_func = ChunkerConfig.get_chunker(chunker)
|
||||||
|
|
||||||
chunker = chunker_func(self, chunk_size=chunk_size, get_text=get_text)
|
chunker = chunker_func(
|
||||||
|
self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens
|
||||||
|
)
|
||||||
|
|
||||||
yield from chunker.read()
|
yield from chunker.read()
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,16 @@
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from cognee.modules.chunking.TextChunker import TextChunker
|
from cognee.modules.chunking.TextChunker import TextChunker
|
||||||
from .Document import Document
|
|
||||||
from cognee.modules.data.exceptions import UnstructuredLibraryImportError
|
from cognee.modules.data.exceptions import UnstructuredLibraryImportError
|
||||||
|
|
||||||
|
from .Document import Document
|
||||||
|
|
||||||
|
|
||||||
class UnstructuredDocument(Document):
|
class UnstructuredDocument(Document):
|
||||||
type: str = "unstructured"
|
type: str = "unstructured"
|
||||||
|
|
||||||
def read(self, chunk_size: int):
|
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None) -> str:
|
||||||
def get_text():
|
def get_text():
|
||||||
try:
|
try:
|
||||||
from unstructured.partition.auto import partition
|
from unstructured.partition.auto import partition
|
||||||
|
|
@ -27,6 +29,6 @@ class UnstructuredDocument(Document):
|
||||||
|
|
||||||
yield text
|
yield text
|
||||||
|
|
||||||
chunker = TextChunker(self, chunk_size=chunk_size, get_text=get_text)
|
chunker = TextChunker(self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens)
|
||||||
|
|
||||||
yield from chunker.read()
|
yield from chunker.read()
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,18 @@
|
||||||
from uuid import uuid5, NAMESPACE_OID
|
from typing import Any, Dict, Iterator, Optional, Union
|
||||||
from typing import Dict, Any, Iterator
|
from uuid import NAMESPACE_OID, uuid5
|
||||||
|
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
|
||||||
from .chunk_by_sentence import chunk_by_sentence
|
from .chunk_by_sentence import chunk_by_sentence
|
||||||
|
|
||||||
|
|
||||||
def chunk_by_paragraph(
|
def chunk_by_paragraph(
|
||||||
data: str, paragraph_length: int = 1024, batch_paragraphs: bool = True
|
data: str,
|
||||||
|
max_tokens: Optional[Union[int, float]] = None,
|
||||||
|
paragraph_length: int = 1024,
|
||||||
|
batch_paragraphs: bool = True,
|
||||||
) -> Iterator[Dict[str, Any]]:
|
) -> Iterator[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Chunks text by paragraph while preserving exact text reconstruction capability.
|
Chunks text by paragraph while preserving exact text reconstruction capability.
|
||||||
|
|
@ -15,16 +23,31 @@ def chunk_by_paragraph(
|
||||||
chunk_index = 0
|
chunk_index = 0
|
||||||
paragraph_ids = []
|
paragraph_ids = []
|
||||||
last_cut_type = None
|
last_cut_type = None
|
||||||
|
current_token_count = 0
|
||||||
|
if not max_tokens:
|
||||||
|
max_tokens = float("inf")
|
||||||
|
|
||||||
|
vector_engine = get_vector_engine()
|
||||||
|
embedding_model = vector_engine.embedding_engine.model
|
||||||
|
embedding_model = embedding_model.split("/")[-1]
|
||||||
|
|
||||||
for paragraph_id, sentence, word_count, end_type in chunk_by_sentence(
|
for paragraph_id, sentence, word_count, end_type in chunk_by_sentence(
|
||||||
data, maximum_length=paragraph_length
|
data, maximum_length=paragraph_length
|
||||||
):
|
):
|
||||||
# Check if this sentence would exceed length limit
|
# Check if this sentence would exceed length limit
|
||||||
if current_word_count > 0 and current_word_count + word_count > paragraph_length:
|
|
||||||
|
tokenizer = tiktoken.encoding_for_model(embedding_model)
|
||||||
|
token_count = len(tokenizer.encode(sentence))
|
||||||
|
|
||||||
|
if current_word_count > 0 and (
|
||||||
|
current_word_count + word_count > paragraph_length
|
||||||
|
or current_token_count + token_count > max_tokens
|
||||||
|
):
|
||||||
# Yield current chunk
|
# Yield current chunk
|
||||||
chunk_dict = {
|
chunk_dict = {
|
||||||
"text": current_chunk,
|
"text": current_chunk,
|
||||||
"word_count": current_word_count,
|
"word_count": current_word_count,
|
||||||
|
"token_count": current_token_count,
|
||||||
"chunk_id": uuid5(NAMESPACE_OID, current_chunk),
|
"chunk_id": uuid5(NAMESPACE_OID, current_chunk),
|
||||||
"paragraph_ids": paragraph_ids,
|
"paragraph_ids": paragraph_ids,
|
||||||
"chunk_index": chunk_index,
|
"chunk_index": chunk_index,
|
||||||
|
|
@ -37,11 +60,13 @@ def chunk_by_paragraph(
|
||||||
paragraph_ids = []
|
paragraph_ids = []
|
||||||
current_chunk = ""
|
current_chunk = ""
|
||||||
current_word_count = 0
|
current_word_count = 0
|
||||||
|
current_token_count = 0
|
||||||
chunk_index += 1
|
chunk_index += 1
|
||||||
|
|
||||||
paragraph_ids.append(paragraph_id)
|
paragraph_ids.append(paragraph_id)
|
||||||
current_chunk += sentence
|
current_chunk += sentence
|
||||||
current_word_count += word_count
|
current_word_count += word_count
|
||||||
|
current_token_count += token_count
|
||||||
|
|
||||||
# Handle end of paragraph
|
# Handle end of paragraph
|
||||||
if end_type in ("paragraph_end", "sentence_cut") and not batch_paragraphs:
|
if end_type in ("paragraph_end", "sentence_cut") and not batch_paragraphs:
|
||||||
|
|
@ -49,6 +74,7 @@ def chunk_by_paragraph(
|
||||||
chunk_dict = {
|
chunk_dict = {
|
||||||
"text": current_chunk,
|
"text": current_chunk,
|
||||||
"word_count": current_word_count,
|
"word_count": current_word_count,
|
||||||
|
"token_count": current_token_count,
|
||||||
"paragraph_ids": paragraph_ids,
|
"paragraph_ids": paragraph_ids,
|
||||||
"chunk_id": uuid5(NAMESPACE_OID, current_chunk),
|
"chunk_id": uuid5(NAMESPACE_OID, current_chunk),
|
||||||
"chunk_index": chunk_index,
|
"chunk_index": chunk_index,
|
||||||
|
|
@ -58,6 +84,7 @@ def chunk_by_paragraph(
|
||||||
paragraph_ids = []
|
paragraph_ids = []
|
||||||
current_chunk = ""
|
current_chunk = ""
|
||||||
current_word_count = 0
|
current_word_count = 0
|
||||||
|
current_token_count = 0
|
||||||
chunk_index += 1
|
chunk_index += 1
|
||||||
|
|
||||||
last_cut_type = end_type
|
last_cut_type = end_type
|
||||||
|
|
@ -67,6 +94,7 @@ def chunk_by_paragraph(
|
||||||
chunk_dict = {
|
chunk_dict = {
|
||||||
"text": current_chunk,
|
"text": current_chunk,
|
||||||
"word_count": current_word_count,
|
"word_count": current_word_count,
|
||||||
|
"token_count": current_token_count,
|
||||||
"chunk_id": uuid5(NAMESPACE_OID, current_chunk),
|
"chunk_id": uuid5(NAMESPACE_OID, current_chunk),
|
||||||
"paragraph_ids": paragraph_ids,
|
"paragraph_ids": paragraph_ids,
|
||||||
"chunk_index": chunk_index,
|
"chunk_index": chunk_index,
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,16 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from cognee.modules.data.processing.document_types.Document import Document
|
from cognee.modules.data.processing.document_types.Document import Document
|
||||||
|
|
||||||
|
|
||||||
async def extract_chunks_from_documents(
|
async def extract_chunks_from_documents(
|
||||||
documents: list[Document], chunk_size: int = 1024, chunker="text_chunker"
|
documents: list[Document],
|
||||||
|
chunk_size: int = 1024,
|
||||||
|
chunker="text_chunker",
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
):
|
):
|
||||||
for document in documents:
|
for document in documents:
|
||||||
for document_chunk in document.read(chunk_size=chunk_size, chunker=chunker):
|
for document_chunk in document.read(
|
||||||
|
chunk_size=chunk_size, chunker=chunker, max_tokens=max_tokens
|
||||||
|
):
|
||||||
yield document_chunk
|
yield document_chunk
|
||||||
|
|
|
||||||
|
|
@ -29,8 +29,105 @@ async def get_non_py_files(repo_path):
|
||||||
"*.egg-info",
|
"*.egg-info",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ALLOWED_EXTENSIONS = {
|
||||||
|
".txt",
|
||||||
|
".md",
|
||||||
|
".csv",
|
||||||
|
".json",
|
||||||
|
".xml",
|
||||||
|
".yaml",
|
||||||
|
".yml",
|
||||||
|
".html",
|
||||||
|
".css",
|
||||||
|
".js",
|
||||||
|
".ts",
|
||||||
|
".jsx",
|
||||||
|
".tsx",
|
||||||
|
".sql",
|
||||||
|
".log",
|
||||||
|
".ini",
|
||||||
|
".toml",
|
||||||
|
".properties",
|
||||||
|
".sh",
|
||||||
|
".bash",
|
||||||
|
".dockerfile",
|
||||||
|
".gitignore",
|
||||||
|
".gitattributes",
|
||||||
|
".makefile",
|
||||||
|
".pyproject",
|
||||||
|
".requirements",
|
||||||
|
".env",
|
||||||
|
".pdf",
|
||||||
|
".doc",
|
||||||
|
".docx",
|
||||||
|
".dot",
|
||||||
|
".dotx",
|
||||||
|
".rtf",
|
||||||
|
".wps",
|
||||||
|
".wpd",
|
||||||
|
".odt",
|
||||||
|
".ott",
|
||||||
|
".ottx",
|
||||||
|
".txt",
|
||||||
|
".wp",
|
||||||
|
".sdw",
|
||||||
|
".sdx",
|
||||||
|
".docm",
|
||||||
|
".dotm",
|
||||||
|
# Additional extensions for other programming languages
|
||||||
|
".java",
|
||||||
|
".c",
|
||||||
|
".cpp",
|
||||||
|
".h",
|
||||||
|
".cs",
|
||||||
|
".go",
|
||||||
|
".php",
|
||||||
|
".rb",
|
||||||
|
".swift",
|
||||||
|
".pl",
|
||||||
|
".lua",
|
||||||
|
".rs",
|
||||||
|
".scala",
|
||||||
|
".kt",
|
||||||
|
".sh",
|
||||||
|
".sql",
|
||||||
|
".v",
|
||||||
|
".asm",
|
||||||
|
".pas",
|
||||||
|
".d",
|
||||||
|
".ml",
|
||||||
|
".clj",
|
||||||
|
".cljs",
|
||||||
|
".erl",
|
||||||
|
".ex",
|
||||||
|
".exs",
|
||||||
|
".f",
|
||||||
|
".fs",
|
||||||
|
".r",
|
||||||
|
".pyi",
|
||||||
|
".pdb",
|
||||||
|
".ipynb",
|
||||||
|
".rmd",
|
||||||
|
".cabal",
|
||||||
|
".hs",
|
||||||
|
".nim",
|
||||||
|
".vhdl",
|
||||||
|
".verilog",
|
||||||
|
".svelte",
|
||||||
|
".html",
|
||||||
|
".css",
|
||||||
|
".scss",
|
||||||
|
".less",
|
||||||
|
".json5",
|
||||||
|
".yaml",
|
||||||
|
".yml",
|
||||||
|
}
|
||||||
|
|
||||||
def should_process(path):
|
def should_process(path):
|
||||||
return not any(pattern in path for pattern in IGNORED_PATTERNS)
|
_, ext = os.path.splitext(path)
|
||||||
|
return ext in ALLOWED_EXTENSIONS and not any(
|
||||||
|
pattern in path for pattern in IGNORED_PATTERNS
|
||||||
|
)
|
||||||
|
|
||||||
non_py_files_paths = [
|
non_py_files_paths = [
|
||||||
os.path.join(root, file)
|
os.path.join(root, file)
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from uuid import NAMESPACE_OID, uuid5
|
||||||
import parso
|
import parso
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.shared.CodeGraphEntities import CodeFile, CodePart, SourceCodeChunk
|
from cognee.shared.CodeGraphEntities import CodeFile, CodePart, SourceCodeChunk
|
||||||
|
|
||||||
|
|
@ -126,6 +127,9 @@ def get_source_code_chunks_from_code_part(
|
||||||
logger.error(f"No source code in CodeFile {code_file_part.id}")
|
logger.error(f"No source code in CodeFile {code_file_part.id}")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
vector_engine = get_vector_engine()
|
||||||
|
embedding_model = vector_engine.embedding_engine.model
|
||||||
|
model_name = embedding_model.split("/")[-1]
|
||||||
tokenizer = tiktoken.encoding_for_model(model_name)
|
tokenizer = tiktoken.encoding_for_model(model_name)
|
||||||
max_subchunk_tokens = max(1, int(granularity * max_tokens))
|
max_subchunk_tokens = max(1, int(granularity * max_tokens))
|
||||||
subchunk_token_counts = _get_subchunk_token_counts(
|
subchunk_token_counts = _get_subchunk_token_counts(
|
||||||
|
|
@ -150,7 +154,7 @@ def get_source_code_chunks_from_code_part(
|
||||||
|
|
||||||
|
|
||||||
async def get_source_code_chunks(
|
async def get_source_code_chunks(
|
||||||
data_points: list[DataPoint], embedding_model="text-embedding-3-large"
|
data_points: list[DataPoint],
|
||||||
) -> AsyncGenerator[list[DataPoint], None]:
|
) -> AsyncGenerator[list[DataPoint], None]:
|
||||||
"""Processes code graph datapoints, create SourceCodeChink datapoints."""
|
"""Processes code graph datapoints, create SourceCodeChink datapoints."""
|
||||||
# TODO: Add support for other embedding models, with max_token mapping
|
# TODO: Add support for other embedding models, with max_token mapping
|
||||||
|
|
@ -165,9 +169,7 @@ async def get_source_code_chunks(
|
||||||
for code_part in data_point.contains:
|
for code_part in data_point.contains:
|
||||||
try:
|
try:
|
||||||
yield code_part
|
yield code_part
|
||||||
for source_code_chunk in get_source_code_chunks_from_code_part(
|
for source_code_chunk in get_source_code_chunks_from_code_part(code_part):
|
||||||
code_part, model_name=embedding_model
|
|
||||||
):
|
|
||||||
yield source_code_chunk
|
yield source_code_chunk
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing code part: {e}")
|
logger.error(f"Error processing code part: {e}")
|
||||||
|
|
|
||||||
|
|
@ -68,7 +68,7 @@ def test_UnstructuredDocument():
|
||||||
)
|
)
|
||||||
|
|
||||||
# Test PPTX
|
# Test PPTX
|
||||||
for paragraph_data in pptx_document.read(chunk_size=1024):
|
for paragraph_data in pptx_document.read(chunk_size=1024, chunker="text_chunker"):
|
||||||
assert 19 == paragraph_data.word_count, f" 19 != {paragraph_data.word_count = }"
|
assert 19 == paragraph_data.word_count, f" 19 != {paragraph_data.word_count = }"
|
||||||
assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }"
|
assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }"
|
||||||
assert (
|
assert (
|
||||||
|
|
@ -76,7 +76,7 @@ def test_UnstructuredDocument():
|
||||||
), f" sentence_cut != {paragraph_data.cut_type = }"
|
), f" sentence_cut != {paragraph_data.cut_type = }"
|
||||||
|
|
||||||
# Test DOCX
|
# Test DOCX
|
||||||
for paragraph_data in docx_document.read(chunk_size=1024):
|
for paragraph_data in docx_document.read(chunk_size=1024, chunker="text_chunker"):
|
||||||
assert 16 == paragraph_data.word_count, f" 16 != {paragraph_data.word_count = }"
|
assert 16 == paragraph_data.word_count, f" 16 != {paragraph_data.word_count = }"
|
||||||
assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }"
|
assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }"
|
||||||
assert (
|
assert (
|
||||||
|
|
@ -84,7 +84,7 @@ def test_UnstructuredDocument():
|
||||||
), f" sentence_end != {paragraph_data.cut_type = }"
|
), f" sentence_end != {paragraph_data.cut_type = }"
|
||||||
|
|
||||||
# TEST CSV
|
# TEST CSV
|
||||||
for paragraph_data in csv_document.read(chunk_size=1024):
|
for paragraph_data in csv_document.read(chunk_size=1024, chunker="text_chunker"):
|
||||||
assert 15 == paragraph_data.word_count, f" 15 != {paragraph_data.word_count = }"
|
assert 15 == paragraph_data.word_count, f" 15 != {paragraph_data.word_count = }"
|
||||||
assert (
|
assert (
|
||||||
"A A A A A A A A A,A A A A A A,A A" == paragraph_data.text
|
"A A A A A A A A A,A A A A A A,A A" == paragraph_data.text
|
||||||
|
|
@ -94,7 +94,7 @@ def test_UnstructuredDocument():
|
||||||
), f" sentence_cut != {paragraph_data.cut_type = }"
|
), f" sentence_cut != {paragraph_data.cut_type = }"
|
||||||
|
|
||||||
# Test XLSX
|
# Test XLSX
|
||||||
for paragraph_data in xlsx_document.read(chunk_size=1024):
|
for paragraph_data in xlsx_document.read(chunk_size=1024, chunker="text_chunker"):
|
||||||
assert 36 == paragraph_data.word_count, f" 36 != {paragraph_data.word_count = }"
|
assert 36 == paragraph_data.word_count, f" 36 != {paragraph_data.word_count = }"
|
||||||
assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }"
|
assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }"
|
||||||
assert (
|
assert (
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,11 @@ def test_chunk_by_paragraph_isomorphism(input_text, paragraph_length, batch_para
|
||||||
list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)),
|
list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)),
|
||||||
)
|
)
|
||||||
def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs):
|
def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs):
|
||||||
chunks = list(chunk_by_paragraph(input_text, paragraph_length, batch_paragraphs))
|
chunks = list(
|
||||||
|
chunk_by_paragraph(
|
||||||
|
data=input_text, paragraph_length=paragraph_length, batch_paragraphs=batch_paragraphs
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
chunk_lengths = np.array([len(list(chunk_by_word(chunk["text"]))) for chunk in chunks])
|
chunk_lengths = np.array([len(list(chunk_by_word(chunk["text"]))) for chunk in chunks])
|
||||||
|
|
||||||
|
|
@ -42,7 +46,9 @@ def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs):
|
||||||
list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)),
|
list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)),
|
||||||
)
|
)
|
||||||
def test_chunk_by_paragraph_chunk_numbering(input_text, paragraph_length, batch_paragraphs):
|
def test_chunk_by_paragraph_chunk_numbering(input_text, paragraph_length, batch_paragraphs):
|
||||||
chunks = chunk_by_paragraph(input_text, paragraph_length, batch_paragraphs)
|
chunks = chunk_by_paragraph(
|
||||||
|
data=input_text, paragraph_length=paragraph_length, batch_paragraphs=batch_paragraphs
|
||||||
|
)
|
||||||
chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks])
|
chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks])
|
||||||
assert np.all(
|
assert np.all(
|
||||||
chunk_indices == np.arange(len(chunk_indices))
|
chunk_indices == np.arange(len(chunk_indices))
|
||||||
|
|
|
||||||
|
|
@ -49,7 +49,9 @@ Third paragraph is cut and is missing the dot at the end""",
|
||||||
|
|
||||||
def run_chunking_test(test_text, expected_chunks):
|
def run_chunking_test(test_text, expected_chunks):
|
||||||
chunks = []
|
chunks = []
|
||||||
for chunk_data in chunk_by_paragraph(test_text, 12, batch_paragraphs=False):
|
for chunk_data in chunk_by_paragraph(
|
||||||
|
data=test_text, paragraph_length=12, batch_paragraphs=False
|
||||||
|
):
|
||||||
chunks.append(chunk_data)
|
chunks.append(chunk_data)
|
||||||
|
|
||||||
assert len(chunks) == 3
|
assert len(chunks) == 3
|
||||||
|
|
|
||||||
|
|
@ -34,9 +34,8 @@ def check_install_package(package_name):
|
||||||
|
|
||||||
async def generate_patch_with_cognee(instance, llm_client, search_type=SearchType.CHUNKS):
|
async def generate_patch_with_cognee(instance, llm_client, search_type=SearchType.CHUNKS):
|
||||||
repo_path = download_github_repo(instance, "../RAW_GIT_REPOS")
|
repo_path = download_github_repo(instance, "../RAW_GIT_REPOS")
|
||||||
pipeline = await run_code_graph_pipeline(repo_path)
|
|
||||||
|
|
||||||
async for result in pipeline:
|
async for result in run_code_graph_pipeline(repo_path, include_docs=True):
|
||||||
print(result)
|
print(result)
|
||||||
|
|
||||||
print("Here we have the repo under the repo_path")
|
print("Here we have the repo under the repo_path")
|
||||||
|
|
@ -47,7 +46,9 @@ async def generate_patch_with_cognee(instance, llm_client, search_type=SearchTyp
|
||||||
instructions = read_query_prompt("patch_gen_kg_instructions.txt")
|
instructions = read_query_prompt("patch_gen_kg_instructions.txt")
|
||||||
|
|
||||||
retrieved_edges = await brute_force_triplet_search(
|
retrieved_edges = await brute_force_triplet_search(
|
||||||
problem_statement, top_k=3, collections=["data_point_source_code", "data_point_text"]
|
problem_statement,
|
||||||
|
top_k=3,
|
||||||
|
collections=["code_summary_text"],
|
||||||
)
|
)
|
||||||
|
|
||||||
retrieved_edges_str = retrieved_edges_to_string(retrieved_edges)
|
retrieved_edges_str = retrieved_edges_to_string(retrieved_edges)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
|
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
|
||||||
|
from cognee.shared.utils import setup_logging
|
||||||
|
|
||||||
|
|
||||||
async def main(repo_path, include_docs):
|
async def main(repo_path, include_docs):
|
||||||
|
|
@ -9,7 +11,7 @@ async def main(repo_path, include_docs):
|
||||||
print(result)
|
print(result)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--repo_path", type=str, required=True, help="Path to the repository")
|
parser.add_argument("--repo_path", type=str, required=True, help="Path to the repository")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|
@ -18,5 +20,28 @@ if __name__ == "__main__":
|
||||||
default=True,
|
default=True,
|
||||||
help="Whether or not to process non-code files",
|
help="Whether or not to process non-code files",
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
parser.add_argument(
|
||||||
asyncio.run(main(args.repo_path, args.include_docs))
|
"--time",
|
||||||
|
type=lambda x: x.lower() in ("true", "1"),
|
||||||
|
default=True,
|
||||||
|
help="Whether or not to time the pipeline run",
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
setup_logging(logging.ERROR)
|
||||||
|
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
if args.time:
|
||||||
|
import time
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
asyncio.run(main(args.repo_path, args.include_docs))
|
||||||
|
end_time = time.time()
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print(f"Pipeline Execution Time: {end_time - start_time:.2f} seconds")
|
||||||
|
print("=" * 50 + "\n")
|
||||||
|
else:
|
||||||
|
asyncio.run(main(args.repo_path, args.include_docs))
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue