Merge branch 'dev' into llama-index-notebook

Igor Ilic 2025-01-10 17:26:05 +01:00 committed by GitHub
commit a5c91e8f0e
41 changed files with 459 additions and 202 deletions

View file

@ -10,7 +10,7 @@ repos:
- id: check-added-large-files
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.8.3
rev: v0.9.0
hooks:
# Run the linter.
- id: ruff

View file

@ -3,7 +3,6 @@ import logging
from pathlib import Path
from cognee.base_config import get_base_config
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
from cognee.modules.cognify.config import get_cognify_config
from cognee.modules.pipelines import run_tasks
from cognee.modules.pipelines.tasks.Task import Task
@ -54,8 +53,6 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
await cognee.prune.prune_system(metadata=True)
await create_db_and_tables()
embedding_engine = get_embedding_engine()
cognee_config = get_cognify_config()
user = await get_default_user()
@ -63,11 +60,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
Task(get_repo_file_dependencies),
Task(enrich_dependency_graph),
Task(expand_dependency_graph, task_config={"batch_size": 50}),
Task(
get_source_code_chunks,
embedding_model=embedding_engine.model,
task_config={"batch_size": 50},
),
Task(get_source_code_chunks, task_config={"batch_size": 50}),
Task(summarize_code, task_config={"batch_size": 50}),
Task(add_data_points, task_config={"batch_size": 50}),
]
@ -78,7 +71,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
Task(ingest_data_with_metadata, dataset_name="repo_docs", user=user),
Task(get_data_list_for_user, dataset_name="repo_docs", user=user),
Task(classify_documents),
Task(extract_chunks_from_documents),
Task(extract_chunks_from_documents, max_tokens=cognee_config.max_tokens),
Task(
extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50}
),
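
A minimal sketch of driving this pipeline end to end, assuming only the entry point shown in this file (the repository path is illustrative):

import asyncio

from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline


async def main():
    # include_docs=True also routes non-code files through the document tasks,
    # which is where the configured max_tokens now takes effect.
    async for status in run_code_graph_pipeline("/path/to/repo", include_docs=True):
        print(status)


if __name__ == "__main__":
    asyncio.run(main())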

View file

@ -493,7 +493,7 @@ class Neo4jAdapter(GraphDBInterface):
query_edges = f"""
MATCH (n)-[r]->(m)
WHERE {where_clause} AND {where_clause.replace('n.', 'm.')}
WHERE {where_clause} AND {where_clause.replace("n.", "m.")}
RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
"""
result_edges = await self.query(query_edges)
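
Only the quote style inside the f-string changes here; the node filter is still mirrored onto the target node with a plain string replace. A small sketch of what that produces, with an illustrative where_clause:

# Illustrative node filter; the real clause is assembled elsewhere in the adapter.
where_clause = "n.updated_at >= $from AND n.type IN $types"

query_edges = f"""
    MATCH (n)-[r]->(m)
    WHERE {where_clause} AND {where_clause.replace("n.", "m.")}
    RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
"""
print(query_edges)  # the second condition reads: m.updated_at >= $from AND m.type IN $types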

View file

@ -1,3 +1,6 @@
I need you to solve this issue by looking at the provided edges retrieved from a knowledge graph and
generate a single patch file that I can apply directly to this repository using git apply.
Please respond with a single patch file in the following format.
You are a senior software engineer. I need you to solve this issue by looking at the provided context and
generate a single patch file that I can apply directly to this repository using git apply.
Additionally, please make sure that you provide code only with correct syntax and
you apply the patch on the relevant files (together with their path that you can try to find out from the github issue). Don't change the names of existing
functions or classes, as they may be referenced from other code.
Please respond only with a single patch file in the following format without adding any additional context or string.

View file

@ -1,7 +1,9 @@
from uuid import uuid5, NAMESPACE_OID
from typing import Optional
from uuid import NAMESPACE_OID, uuid5
from cognee.tasks.chunks import chunk_by_paragraph
from .models.DocumentChunk import DocumentChunk
from cognee.tasks.chunks import chunk_by_paragraph
class TextChunker:
@ -10,23 +12,36 @@ class TextChunker:
chunk_index = 0
chunk_size = 0
token_count = 0
def __init__(self, document, get_text: callable, chunk_size: int = 1024):
def __init__(
self, document, get_text: callable, max_tokens: Optional[int] = None, chunk_size: int = 1024
):
self.document = document
self.max_chunk_size = chunk_size
self.get_text = get_text
self.max_tokens = max_tokens if max_tokens else float("inf")
def check_word_count_and_token_count(self, word_count_before, token_count_before, chunk_data):
word_count_fits = word_count_before + chunk_data["word_count"] <= self.max_chunk_size
token_count_fits = token_count_before + chunk_data["token_count"] <= self.max_tokens
return word_count_fits and token_count_fits
def read(self):
paragraph_chunks = []
for content_text in self.get_text():
for chunk_data in chunk_by_paragraph(
content_text,
self.max_tokens,
self.max_chunk_size,
batch_paragraphs=True,
):
if self.chunk_size + chunk_data["word_count"] <= self.max_chunk_size:
if self.check_word_count_and_token_count(
self.chunk_size, self.token_count, chunk_data
):
paragraph_chunks.append(chunk_data)
self.chunk_size += chunk_data["word_count"]
self.token_count += chunk_data["token_count"]
else:
if len(paragraph_chunks) == 0:
yield DocumentChunk(
@ -66,6 +81,7 @@ class TextChunker:
print(e)
paragraph_chunks = [chunk_data]
self.chunk_size = chunk_data["word_count"]
self.token_count = chunk_data["token_count"]
self.chunk_index += 1
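
The chunker now tracks a token budget alongside the word budget and only appends a paragraph when both fit. A standalone sketch of that dual check, with names mirroring the diff (the helper itself is illustrative, not library API):

from typing import Optional


def fits_both_budgets(
    words_so_far: int,
    tokens_so_far: int,
    chunk_data: dict,
    max_chunk_size: int = 1024,
    max_tokens: Optional[int] = None,
) -> bool:
    """Return True when adding chunk_data keeps the chunk within both limits."""
    token_budget = max_tokens if max_tokens is not None else float("inf")
    word_count_fits = words_so_far + chunk_data["word_count"] <= max_chunk_size
    token_count_fits = tokens_so_far + chunk_data["token_count"] <= token_budget
    return word_count_fits and token_count_fits


# A 900-word / 1200-token chunk cannot absorb a 200-word / 300-token paragraph when
# max_tokens is 1400, even though the word budget alone would still allow it.
print(fits_both_budgets(900, 1200, {"word_count": 200, "token_count": 300}, max_tokens=1400))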

View file

@ -12,6 +12,7 @@ class DocumentChunk(DataPoint):
chunk_index: int
cut_type: str
is_part_of: Document
pydantic_type: str = "DocumentChunk"
contains: List[Entity] = None
_metadata: dict = {"index_fields": ["text"], "type": "DocumentChunk"}

View file

@ -1,12 +1,14 @@
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
from cognee.shared.data_models import DefaultContentPrediction, SummarizedContent
from typing import Optional
import os
class CognifyConfig(BaseSettings):
classification_model: object = DefaultContentPrediction
summarization_model: object = SummarizedContent
max_tokens: Optional[int] = os.getenv("MAX_TOKENS")
model_config = SettingsConfigDict(env_file=".env", extra="allow")
def to_dict(self) -> dict:
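
The new max_tokens setting is read from the MAX_TOKENS environment variable (or .env). A hedged sketch of picking it up through the cached accessor used elsewhere in the codebase; the value is illustrative:

import os

os.environ["MAX_TOKENS"] = "512"  # normally set in .env

from cognee.modules.cognify.config import get_cognify_config

config = get_cognify_config()
print(config.max_tokens)  # expected 512; exact type (str vs int) depends on pydantic-settings coercion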

View file

@ -1,6 +1,9 @@
from typing import Optional
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from .Document import Document
from .ChunkerMapping import ChunkerConfig
from .Document import Document
class AudioDocument(Document):
@ -10,12 +13,14 @@ class AudioDocument(Document):
result = get_llm_client().create_transcript(self.raw_data_location)
return result.text
def read(self, chunk_size: int, chunker: str):
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
# Transcribe the audio file
text = self.create_transcript()
chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(self, chunk_size=chunk_size, get_text=lambda: [text])
chunker = chunker_func(
self, chunk_size=chunk_size, get_text=lambda: [text], max_tokens=max_tokens
)
yield from chunker.read()

View file

@ -1,3 +1,4 @@
from typing import Optional
from uuid import UUID
from cognee.infrastructure.engine import DataPoint
@ -10,5 +11,5 @@ class Document(DataPoint):
mime_type: str
_metadata: dict = {"index_fields": ["name"], "type": "Document"}
def read(self, chunk_size: int, chunker=str) -> str:
def read(self, chunk_size: int, chunker=str, max_tokens: Optional[int] = None) -> str:
pass

View file

@ -1,6 +1,9 @@
from typing import Optional
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from .Document import Document
from .ChunkerMapping import ChunkerConfig
from .Document import Document
class ImageDocument(Document):
@ -10,11 +13,13 @@ class ImageDocument(Document):
result = get_llm_client().transcribe_image(self.raw_data_location)
return result.choices[0].message.content
def read(self, chunk_size: int, chunker: str):
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
# Transcribe the image file
text = self.transcribe_image()
chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(self, chunk_size=chunk_size, get_text=lambda: [text])
chunker = chunker_func(
self, chunk_size=chunk_size, get_text=lambda: [text], max_tokens=max_tokens
)
yield from chunker.read()

View file

@ -1,12 +1,15 @@
from typing import Optional
from pypdf import PdfReader
from .Document import Document
from .ChunkerMapping import ChunkerConfig
from .Document import Document
class PdfDocument(Document):
type: str = "pdf"
def read(self, chunk_size: int, chunker: str):
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
file = PdfReader(self.raw_data_location)
def get_text():
@ -15,7 +18,9 @@ class PdfDocument(Document):
yield page_text
chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(self, chunk_size=chunk_size, get_text=get_text)
chunker = chunker_func(
self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens
)
yield from chunker.read()

View file

@ -1,11 +1,13 @@
from .Document import Document
from typing import Optional
from .ChunkerMapping import ChunkerConfig
from .Document import Document
class TextDocument(Document):
type: str = "text"
def read(self, chunk_size: int, chunker: str):
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
def get_text():
with open(self.raw_data_location, mode="r", encoding="utf-8") as file:
while True:
@ -18,6 +20,8 @@ class TextDocument(Document):
chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(self, chunk_size=chunk_size, get_text=get_text)
chunker = chunker_func(
self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens
)
yield from chunker.read()
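
Every Document subclass now accepts max_tokens in read() and forwards it to the configured chunker. A hedged sketch of the shared contract; constructing a concrete document is omitted because the required DataPoint fields are not part of this diff:

from typing import Optional


def collect_chunks(document, chunk_size: int = 1024, max_tokens: Optional[int] = 512):
    """Drain a document's chunk generator; the call is the same for text, PDF, audio and image documents."""
    return list(
        document.read(chunk_size=chunk_size, chunker="text_chunker", max_tokens=max_tokens)
    )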

View file

@ -1,14 +1,16 @@
from io import StringIO
from typing import Optional
from cognee.modules.chunking.TextChunker import TextChunker
from .Document import Document
from cognee.modules.data.exceptions import UnstructuredLibraryImportError
from .Document import Document
class UnstructuredDocument(Document):
type: str = "unstructured"
def read(self, chunk_size: int):
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None) -> str:
def get_text():
try:
from unstructured.partition.auto import partition
@ -27,6 +29,6 @@ class UnstructuredDocument(Document):
yield text
chunker = TextChunker(self, chunk_size=chunk_size, get_text=get_text)
chunker = TextChunker(self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens)
yield from chunker.read()

View file

@ -7,5 +7,6 @@ class Entity(DataPoint):
name: str
is_a: EntityType
description: str
pydantic_type: str = "Entity"
_metadata: dict = {"index_fields": ["name"], "type": "Entity"}

View file

@ -5,5 +5,6 @@ class EntityType(DataPoint):
__tablename__ = "entity_type"
name: str
description: str
pydantic_type: str = "EntityType"
_metadata: dict = {"index_fields": ["name"], "type": "EntityType"}

View file

@ -43,7 +43,7 @@ def format_triplets(edges):
edge_info = {key: value for key, value in edge_attributes.items() if value is not None}
# Create the formatted triplet
triplet = f"Node1: {node1_info}\n" f"Edge: {edge_info}\n" f"Node2: {node2_info}\n\n\n"
triplet = f"Node1: {node1_info}\nEdge: {edge_info}\nNode2: {node2_info}\n\n\n"
triplets.append(triplet)
return "".join(triplets)

View file

@ -8,20 +8,27 @@ from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User
from cognee.shared.utils import send_telemetry
from cognee.api.v1.search import SearchType
from cognee.api.v1.search.search_v2 import search
from cognee.infrastructure.llm.get_llm_client import get_llm_client
async def code_description_to_code_part_search(query: str, user: User = None, top_k=2) -> list:
async def code_description_to_code_part_search(
query: str, include_docs=False, user: User = None, top_k=5
) -> list:
if user is None:
user = await get_default_user()
if user is None:
raise PermissionError("No user found in the system. Please create a user.")
retrieved_codeparts = await code_description_to_code_part(query, user, top_k)
retrieved_codeparts = await code_description_to_code_part(query, user, top_k, include_docs)
return retrieved_codeparts
async def code_description_to_code_part(query: str, user: User, top_k: int) -> List[str]:
async def code_description_to_code_part(
query: str, user: User, top_k: int, include_docs: bool = False
) -> List[str]:
"""
Maps a code description query to relevant code parts using a CodeGraph pipeline.
@ -29,6 +36,7 @@ async def code_description_to_code_part(query: str, user: User, top_k: int) -> L
query (str): The search query describing the code parts.
user (User): The user performing the search.
top_k (int): Number of code-graph descriptions to match (the number of corresponding code parts returned will be higher)
include_docs (bool): Whether documentation has been ingested into the graph and should be summarized into the context
Returns:
Set[str]: A set of unique code parts matching the query.
@ -55,21 +63,48 @@ async def code_description_to_code_part(query: str, user: User, top_k: int) -> L
)
try:
results = await vector_engine.search("code_summary_text", query_text=query, limit=top_k)
if not results:
if include_docs:
search_results = await search(SearchType.INSIGHTS, query_text=query)
concatenated_descriptions = " ".join(
obj["description"]
for tpl in search_results
for obj in tpl
if isinstance(obj, dict) and "description" in obj
)
llm_client = get_llm_client()
context_from_documents = await llm_client.acreate_structured_output(
text_input=f"The retrieved context from documents is {concatenated_descriptions}.",
system_prompt="You are a Senior Software Engineer, summarize the context from documents"
f" in a way that it is gonna be provided next to codeparts as context"
f" while trying to solve this github issue connected to the project: {query}]",
response_model=str,
)
code_summaries = await vector_engine.search(
"code_summary_text", query_text=query, limit=top_k
)
if not code_summaries:
logging.warning("No results found for query: '%s' by user: %s", query, user.id)
return []
memory_fragment = CogneeGraph()
await memory_fragment.project_graph_from_db(
graph_engine,
node_properties_to_project=["id", "type", "text", "source_code"],
node_properties_to_project=[
"id",
"type",
"text",
"source_code",
"pydantic_type",
],
edge_properties_to_project=["relationship_name"],
)
code_pieces_to_return = set()
for node in results:
for node in code_summaries:
node_id = str(node.id)
node_to_search_from = memory_fragment.get_node(node_id)
@ -78,9 +113,16 @@ async def code_description_to_code_part(query: str, user: User, top_k: int) -> L
continue
for code_file in node_to_search_from.get_skeleton_neighbours():
for code_file_edge in code_file.get_skeleton_edges():
if code_file_edge.get_attribute("relationship_name") == "contains":
code_pieces_to_return.add(code_file_edge.get_destination_node())
if code_file.get_attribute("pydantic_type") == "SourceCodeChunk":
for code_file_edge in code_file.get_skeleton_edges():
if code_file_edge.get_attribute("relationship_name") == "code_chunk_of":
code_pieces_to_return.add(code_file_edge.get_destination_node())
elif code_file.get_attribute("pydantic_type") == "CodePart":
code_pieces_to_return.add(code_file)
elif code_file.get_attribute("pydantic_type") == "CodeFile":
for code_file_edge in code_file.get_skeleton_edges():
if code_file_edge.get_attribute("relationship_name") == "contains":
code_pieces_to_return.add(code_file_edge.get_destination_node())
logging.info(
"Search completed for user: %s, query: '%s'. Found %d code pieces.",
@ -89,7 +131,14 @@ async def code_description_to_code_part(query: str, user: User, top_k: int) -> L
len(code_pieces_to_return),
)
return list(code_pieces_to_return)
context = ""
for code_piece in code_pieces_to_return:
context = context + code_piece.get_attribute("source_code")
if include_docs:
context = context_from_documents + context
return context
except Exception as exec_error:
logging.error(
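
The retrieval now returns an assembled context string: the concatenated source code of the matched parts, optionally prefixed with an LLM summary of document insights when include_docs is set. A hedged usage sketch (the issue text is illustrative, and a populated code graph is assumed):

from cognee.modules.retrieval.description_to_codepart_search import (
    code_description_to_code_part_search,
)


async def build_context(issue_text: str) -> str:
    # include_docs=True assumes the code graph pipeline ran with docs ingested,
    # so INSIGHTS search results exist to summarize alongside the code parts.
    return await code_description_to_code_part_search(issue_text, include_docs=True, top_k=5)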

View file

@ -5,12 +5,14 @@ from cognee.infrastructure.engine import DataPoint
class Repository(DataPoint):
__tablename__ = "Repository"
path: str
pydantic_type: str = "Repository"
_metadata: dict = {"index_fields": [], "type": "Repository"}
class CodeFile(DataPoint):
__tablename__ = "codefile"
extracted_id: str # actually file path
pydantic_type: str = "CodeFile"
source_code: Optional[str] = None
part_of: Optional[Repository] = None
depends_on: Optional[List["CodeFile"]] = None
@ -22,6 +24,7 @@ class CodeFile(DataPoint):
class CodePart(DataPoint):
__tablename__ = "codepart"
# part_of: Optional[CodeFile] = None
pydantic_type: str = "CodePart"
source_code: Optional[str] = None
_metadata: dict = {"index_fields": [], "type": "CodePart"}
@ -30,6 +33,7 @@ class SourceCodeChunk(DataPoint):
__tablename__ = "sourcecodechunk"
code_chunk_of: Optional[CodePart] = None
source_code: Optional[str] = None
pydantic_type: str = "SourceCodeChunk"
previous_chunk: Optional["SourceCodeChunk"] = None
_metadata: dict = {"index_fields": ["source_code"], "type": "SourceCodeChunk"}

View file

@ -231,6 +231,7 @@ class SummarizedContent(BaseModel):
summary: str
description: str
pydantic_type: str = "SummarizedContent"
class SummarizedFunction(BaseModel):
@ -239,6 +240,7 @@ class SummarizedFunction(BaseModel):
inputs: Optional[List[str]] = None
outputs: Optional[List[str]] = None
decorators: Optional[List[str]] = None
pydantic_type: str = "SummarizedFunction"
class SummarizedClass(BaseModel):
@ -246,6 +248,7 @@ class SummarizedClass(BaseModel):
description: str
methods: Optional[List[SummarizedFunction]] = None
decorators: Optional[List[str]] = None
pydantic_type: str = "SummarizedClass"
class SummarizedCode(BaseModel):
@ -256,6 +259,7 @@ class SummarizedCode(BaseModel):
classes: List[SummarizedClass] = []
functions: List[SummarizedFunction] = []
workflow_description: Optional[str] = None
pydantic_type: str = "SummarizedCode"
class GraphDBType(Enum):

View file

@ -1,10 +1,18 @@
from uuid import uuid5, NAMESPACE_OID
from typing import Dict, Any, Iterator
from typing import Any, Dict, Iterator, Optional, Union
from uuid import NAMESPACE_OID, uuid5
import tiktoken
from cognee.infrastructure.databases.vector import get_vector_engine
from .chunk_by_sentence import chunk_by_sentence
def chunk_by_paragraph(
data: str, paragraph_length: int = 1024, batch_paragraphs: bool = True
data: str,
max_tokens: Optional[Union[int, float]] = None,
paragraph_length: int = 1024,
batch_paragraphs: bool = True,
) -> Iterator[Dict[str, Any]]:
"""
Chunks text by paragraph while preserving exact text reconstruction capability.
@ -15,16 +23,31 @@ def chunk_by_paragraph(
chunk_index = 0
paragraph_ids = []
last_cut_type = None
current_token_count = 0
if not max_tokens:
max_tokens = float("inf")
vector_engine = get_vector_engine()
embedding_model = vector_engine.embedding_engine.model
embedding_model = embedding_model.split("/")[-1]
for paragraph_id, sentence, word_count, end_type in chunk_by_sentence(
data, maximum_length=paragraph_length
):
# Check if this sentence would exceed length limit
if current_word_count > 0 and current_word_count + word_count > paragraph_length:
tokenizer = tiktoken.encoding_for_model(embedding_model)
token_count = len(tokenizer.encode(sentence))
if current_word_count > 0 and (
current_word_count + word_count > paragraph_length
or current_token_count + token_count > max_tokens
):
# Yield current chunk
chunk_dict = {
"text": current_chunk,
"word_count": current_word_count,
"token_count": current_token_count,
"chunk_id": uuid5(NAMESPACE_OID, current_chunk),
"paragraph_ids": paragraph_ids,
"chunk_index": chunk_index,
@ -37,11 +60,13 @@ def chunk_by_paragraph(
paragraph_ids = []
current_chunk = ""
current_word_count = 0
current_token_count = 0
chunk_index += 1
paragraph_ids.append(paragraph_id)
current_chunk += sentence
current_word_count += word_count
current_token_count += token_count
# Handle end of paragraph
if end_type in ("paragraph_end", "sentence_cut") and not batch_paragraphs:
@ -49,6 +74,7 @@ def chunk_by_paragraph(
chunk_dict = {
"text": current_chunk,
"word_count": current_word_count,
"token_count": current_token_count,
"paragraph_ids": paragraph_ids,
"chunk_id": uuid5(NAMESPACE_OID, current_chunk),
"chunk_index": chunk_index,
@ -58,6 +84,7 @@ def chunk_by_paragraph(
paragraph_ids = []
current_chunk = ""
current_word_count = 0
current_token_count = 0
chunk_index += 1
last_cut_type = end_type
@ -67,6 +94,7 @@ def chunk_by_paragraph(
chunk_dict = {
"text": current_chunk,
"word_count": current_word_count,
"token_count": current_token_count,
"chunk_id": uuid5(NAMESPACE_OID, current_chunk),
"paragraph_ids": paragraph_ids,
"chunk_index": chunk_index,

View file

@ -1,9 +1,16 @@
from typing import Optional
from cognee.modules.data.processing.document_types.Document import Document
async def extract_chunks_from_documents(
documents: list[Document], chunk_size: int = 1024, chunker="text_chunker"
documents: list[Document],
chunk_size: int = 1024,
chunker="text_chunker",
max_tokens: Optional[int] = None,
):
for document in documents:
for document_chunk in document.read(chunk_size=chunk_size, chunker=chunker):
for document_chunk in document.read(
chunk_size=chunk_size, chunker=chunker, max_tokens=max_tokens
):
yield document_chunk
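
extract_chunks_from_documents stays an async generator, so max_tokens simply rides along to each document's read(). A hedged consumption sketch; the import path is assumed from the task layout, and the documents list comes from the earlier pipeline steps:

from cognee.tasks.documents.extract_chunks_from_documents import extract_chunks_from_documents


async def collect_document_chunks(documents, max_tokens=512):
    chunks = []
    async for chunk in extract_chunks_from_documents(
        documents, chunk_size=1024, chunker="text_chunker", max_tokens=max_tokens
    ):
        chunks.append(chunk)
    return chunks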

View file

@ -1,6 +1,5 @@
from typing import Dict, List
import parso
import logging
logger = logging.getLogger(__name__)

View file

@ -9,7 +9,6 @@ import aiofiles
import jedi
import parso
from parso.tree import BaseNode
import logging
logger = logging.getLogger(__name__)

View file

@ -29,8 +29,105 @@ async def get_non_py_files(repo_path):
"*.egg-info",
}
ALLOWED_EXTENSIONS = {
".txt",
".md",
".csv",
".json",
".xml",
".yaml",
".yml",
".html",
".css",
".js",
".ts",
".jsx",
".tsx",
".sql",
".log",
".ini",
".toml",
".properties",
".sh",
".bash",
".dockerfile",
".gitignore",
".gitattributes",
".makefile",
".pyproject",
".requirements",
".env",
".pdf",
".doc",
".docx",
".dot",
".dotx",
".rtf",
".wps",
".wpd",
".odt",
".ott",
".ottx",
".txt",
".wp",
".sdw",
".sdx",
".docm",
".dotm",
# Additional extensions for other programming languages
".java",
".c",
".cpp",
".h",
".cs",
".go",
".php",
".rb",
".swift",
".pl",
".lua",
".rs",
".scala",
".kt",
".sh",
".sql",
".v",
".asm",
".pas",
".d",
".ml",
".clj",
".cljs",
".erl",
".ex",
".exs",
".f",
".fs",
".r",
".pyi",
".pdb",
".ipynb",
".rmd",
".cabal",
".hs",
".nim",
".vhdl",
".verilog",
".svelte",
".html",
".css",
".scss",
".less",
".json5",
".yaml",
".yml",
}
def should_process(path):
return not any(pattern in path for pattern in IGNORED_PATTERNS)
_, ext = os.path.splitext(path)
return ext in ALLOWED_EXTENSIONS and not any(
pattern in path for pattern in IGNORED_PATTERNS
)
non_py_files_paths = [
os.path.join(root, file)
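
should_process now combines an extension allow-list with the existing ignore patterns. The same filter in isolation, with the sets abbreviated from the diff:

import os

ALLOWED_EXTENSIONS = {".md", ".json", ".yaml", ".ts"}        # abbreviated
IGNORED_PATTERNS = {"node_modules", ".venv", "*.egg-info"}   # abbreviated


def should_process(path):
    _, ext = os.path.splitext(path)
    return ext in ALLOWED_EXTENSIONS and not any(
        pattern in path for pattern in IGNORED_PATTERNS
    )


print(should_process("docs/setup.md"))               # True
print(should_process("node_modules/pkg/readme.md"))  # False: ignored directory
print(should_process("src/main.py"))                 # False: .py files go through the Python dependency tasks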

View file

@ -5,6 +5,7 @@ from uuid import NAMESPACE_OID, uuid5
import parso
import tiktoken
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.engine import DataPoint
from cognee.shared.CodeGraphEntities import CodeFile, CodePart, SourceCodeChunk
@ -126,6 +127,9 @@ def get_source_code_chunks_from_code_part(
logger.error(f"No source code in CodeFile {code_file_part.id}")
return
vector_engine = get_vector_engine()
embedding_model = vector_engine.embedding_engine.model
model_name = embedding_model.split("/")[-1]
tokenizer = tiktoken.encoding_for_model(model_name)
max_subchunk_tokens = max(1, int(granularity * max_tokens))
subchunk_token_counts = _get_subchunk_token_counts(
@ -150,7 +154,7 @@ def get_source_code_chunks_from_code_part(
async def get_source_code_chunks(
data_points: list[DataPoint], embedding_model="text-embedding-3-large"
data_points: list[DataPoint],
) -> AsyncGenerator[list[DataPoint], None]:
"""Processes code graph datapoints, create SourceCodeChink datapoints."""
# TODO: Add support for other embedding models, with max_token mapping
@ -165,9 +169,7 @@ async def get_source_code_chunks(
for code_part in data_point.contains:
try:
yield code_part
for source_code_chunk in get_source_code_chunks_from_code_part(
code_part, model_name=embedding_model
):
for source_code_chunk in get_source_code_chunks_from_code_part(code_part):
yield source_code_chunk
except Exception as e:
logger.error(f"Error processing code part: {e}")

View file

@ -17,5 +17,6 @@ class CodeSummary(DataPoint):
__tablename__ = "code_summary"
text: str
summarizes: Union[CodeFile, CodePart, SourceCodeChunk]
pydantic_type: str = "CodeSummary"
_metadata: dict = {"index_fields": ["text"], "type": "CodeSummary"}

View file

@ -36,12 +36,12 @@ def test_AudioDocument():
for ground_truth, paragraph_data in zip(
GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker")
):
assert (
ground_truth["word_count"] == paragraph_data.word_count
), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
assert ground_truth["len_text"] == len(
paragraph_data.text
), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
assert (
ground_truth["cut_type"] == paragraph_data.cut_type
), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
assert ground_truth["word_count"] == paragraph_data.word_count, (
f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
)
assert ground_truth["len_text"] == len(paragraph_data.text), (
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
)
assert ground_truth["cut_type"] == paragraph_data.cut_type, (
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
)

View file

@ -25,12 +25,12 @@ def test_ImageDocument():
for ground_truth, paragraph_data in zip(
GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker")
):
assert (
ground_truth["word_count"] == paragraph_data.word_count
), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
assert ground_truth["len_text"] == len(
paragraph_data.text
), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
assert (
ground_truth["cut_type"] == paragraph_data.cut_type
), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
assert ground_truth["word_count"] == paragraph_data.word_count, (
f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
)
assert ground_truth["len_text"] == len(paragraph_data.text), (
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
)
assert ground_truth["cut_type"] == paragraph_data.cut_type, (
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
)

View file

@ -27,12 +27,12 @@ def test_PdfDocument():
for ground_truth, paragraph_data in zip(
GROUND_TRUTH, document.read(chunk_size=1024, chunker="text_chunker")
):
assert (
ground_truth["word_count"] == paragraph_data.word_count
), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
assert ground_truth["len_text"] == len(
paragraph_data.text
), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
assert (
ground_truth["cut_type"] == paragraph_data.cut_type
), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
assert ground_truth["word_count"] == paragraph_data.word_count, (
f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
)
assert ground_truth["len_text"] == len(paragraph_data.text), (
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
)
assert ground_truth["cut_type"] == paragraph_data.cut_type, (
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
)

View file

@ -39,12 +39,12 @@ def test_TextDocument(input_file, chunk_size):
for ground_truth, paragraph_data in zip(
GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size, chunker="text_chunker")
):
assert (
ground_truth["word_count"] == paragraph_data.word_count
), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
assert ground_truth["len_text"] == len(
paragraph_data.text
), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
assert (
ground_truth["cut_type"] == paragraph_data.cut_type
), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
assert ground_truth["word_count"] == paragraph_data.word_count, (
f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
)
assert ground_truth["len_text"] == len(paragraph_data.text), (
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
)
assert ground_truth["cut_type"] == paragraph_data.cut_type, (
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
)

View file

@ -68,35 +68,35 @@ def test_UnstructuredDocument():
)
# Test PPTX
for paragraph_data in pptx_document.read(chunk_size=1024):
for paragraph_data in pptx_document.read(chunk_size=1024, chunker="text_chunker"):
assert 19 == paragraph_data.word_count, f" 19 != {paragraph_data.word_count = }"
assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }"
assert (
"sentence_cut" == paragraph_data.cut_type
), f" sentence_cut != {paragraph_data.cut_type = }"
assert "sentence_cut" == paragraph_data.cut_type, (
f" sentence_cut != {paragraph_data.cut_type = }"
)
# Test DOCX
for paragraph_data in docx_document.read(chunk_size=1024):
for paragraph_data in docx_document.read(chunk_size=1024, chunker="text_chunker"):
assert 16 == paragraph_data.word_count, f" 16 != {paragraph_data.word_count = }"
assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }"
assert (
"sentence_end" == paragraph_data.cut_type
), f" sentence_end != {paragraph_data.cut_type = }"
assert "sentence_end" == paragraph_data.cut_type, (
f" sentence_end != {paragraph_data.cut_type = }"
)
# TEST CSV
for paragraph_data in csv_document.read(chunk_size=1024):
for paragraph_data in csv_document.read(chunk_size=1024, chunker="text_chunker"):
assert 15 == paragraph_data.word_count, f" 15 != {paragraph_data.word_count = }"
assert (
"A A A A A A A A A,A A A A A A,A A" == paragraph_data.text
), f"Read text doesn't match expected text: {paragraph_data.text}"
assert (
"sentence_cut" == paragraph_data.cut_type
), f" sentence_cut != {paragraph_data.cut_type = }"
assert "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text, (
f"Read text doesn't match expected text: {paragraph_data.text}"
)
assert "sentence_cut" == paragraph_data.cut_type, (
f" sentence_cut != {paragraph_data.cut_type = }"
)
# Test XLSX
for paragraph_data in xlsx_document.read(chunk_size=1024):
for paragraph_data in xlsx_document.read(chunk_size=1024, chunker="text_chunker"):
assert 36 == paragraph_data.word_count, f" 36 != {paragraph_data.word_count = }"
assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }"
assert (
"sentence_cut" == paragraph_data.cut_type
), f" sentence_cut != {paragraph_data.cut_type = }"
assert "sentence_cut" == paragraph_data.cut_type, (
f" sentence_cut != {paragraph_data.cut_type = }"
)

View file

@ -30,9 +30,9 @@ async def test_deduplication():
result = await relational_engine.get_all_data_from_table("data")
assert len(result) == 1, "More than one data entity was found."
assert (
result[0]["name"] == "Natural_language_processing_copy"
), "Result name does not match expected value."
assert result[0]["name"] == "Natural_language_processing_copy", (
"Result name does not match expected value."
)
result = await relational_engine.get_all_data_from_table("datasets")
assert len(result) == 2, "Unexpected number of datasets found."
@ -61,9 +61,9 @@ async def test_deduplication():
result = await relational_engine.get_all_data_from_table("data")
assert len(result) == 1, "More than one data entity was found."
assert (
hashlib.md5(text.encode("utf-8")).hexdigest() in result[0]["name"]
), "Content hash is not a part of file name."
assert hashlib.md5(text.encode("utf-8")).hexdigest() in result[0]["name"], (
"Content hash is not a part of file name."
)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)

View file

@ -85,9 +85,9 @@ async def main():
from cognee.infrastructure.databases.relational import get_relational_engine
assert not os.path.exists(
get_relational_engine().db_path
), "SQLite relational database is not empty"
assert not os.path.exists(get_relational_engine().db_path), (
"SQLite relational database is not empty"
)
from cognee.infrastructure.databases.graph import get_graph_config

View file

@ -82,9 +82,9 @@ async def main():
from cognee.infrastructure.databases.relational import get_relational_engine
assert not os.path.exists(
get_relational_engine().db_path
), "SQLite relational database is not empty"
assert not os.path.exists(get_relational_engine().db_path), (
"SQLite relational database is not empty"
)
from cognee.infrastructure.databases.graph import get_graph_config

View file

@ -24,28 +24,28 @@ async def test_local_file_deletion(data_text, file_location):
data_hash = hashlib.md5(encoded_text).hexdigest()
# Get data entry from database based on hash contents
data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one()
assert os.path.isfile(
data.raw_data_location
), f"Data location doesn't exist: {data.raw_data_location}"
assert os.path.isfile(data.raw_data_location), (
f"Data location doesn't exist: {data.raw_data_location}"
)
# Test deletion of data along with local files created by cognee
await engine.delete_data_entity(data.id)
assert not os.path.exists(
data.raw_data_location
), f"Data location still exists after deletion: {data.raw_data_location}"
assert not os.path.exists(data.raw_data_location), (
f"Data location still exists after deletion: {data.raw_data_location}"
)
async with engine.get_async_session() as session:
# Get data entry from database based on file path
data = (
await session.scalars(select(Data).where(Data.raw_data_location == file_location))
).one()
assert os.path.isfile(
data.raw_data_location
), f"Data location doesn't exist: {data.raw_data_location}"
assert os.path.isfile(data.raw_data_location), (
f"Data location doesn't exist: {data.raw_data_location}"
)
# Test local files not created by cognee won't get deleted
await engine.delete_data_entity(data.id)
assert os.path.exists(
data.raw_data_location
), f"Data location doesn't exists: {data.raw_data_location}"
assert os.path.exists(data.raw_data_location), (
f"Data location doesn't exists: {data.raw_data_location}"
)
async def test_getting_of_documents(dataset_name_1):
@ -54,16 +54,16 @@ async def test_getting_of_documents(dataset_name_1):
user = await get_default_user()
document_ids = await get_document_ids_for_user(user.id, [dataset_name_1])
assert (
len(document_ids) == 1
), f"Number of expected documents doesn't match {len(document_ids)} != 1"
assert len(document_ids) == 1, (
f"Number of expected documents doesn't match {len(document_ids)} != 1"
)
# Test getting of documents for search when no dataset is provided
user = await get_default_user()
document_ids = await get_document_ids_for_user(user.id)
assert (
len(document_ids) == 2
), f"Number of expected documents doesn't match {len(document_ids)} != 2"
assert len(document_ids) == 2, (
f"Number of expected documents doesn't match {len(document_ids)} != 2"
)
async def main():

View file

@ -17,9 +17,9 @@ batch_paragraphs_vals = [True, False]
def test_chunk_by_paragraph_isomorphism(input_text, paragraph_length, batch_paragraphs):
chunks = chunk_by_paragraph(input_text, paragraph_length, batch_paragraphs)
reconstructed_text = "".join([chunk["text"] for chunk in chunks])
assert (
reconstructed_text == input_text
), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
assert reconstructed_text == input_text, (
f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
)
@pytest.mark.parametrize(
@ -27,14 +27,18 @@ def test_chunk_by_paragraph_isomorphism(input_text, paragraph_length, batch_para
list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)),
)
def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs):
chunks = list(chunk_by_paragraph(input_text, paragraph_length, batch_paragraphs))
chunks = list(
chunk_by_paragraph(
data=input_text, paragraph_length=paragraph_length, batch_paragraphs=batch_paragraphs
)
)
chunk_lengths = np.array([len(list(chunk_by_word(chunk["text"]))) for chunk in chunks])
larger_chunks = chunk_lengths[chunk_lengths > paragraph_length]
assert np.all(
chunk_lengths <= paragraph_length
), f"{paragraph_length = }: {larger_chunks} are too large"
assert np.all(chunk_lengths <= paragraph_length), (
f"{paragraph_length = }: {larger_chunks} are too large"
)
@pytest.mark.parametrize(
@ -42,8 +46,10 @@ def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs):
list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)),
)
def test_chunk_by_paragraph_chunk_numbering(input_text, paragraph_length, batch_paragraphs):
chunks = chunk_by_paragraph(input_text, paragraph_length, batch_paragraphs)
chunks = chunk_by_paragraph(
data=input_text, paragraph_length=paragraph_length, batch_paragraphs=batch_paragraphs
)
chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks])
assert np.all(
chunk_indices == np.arange(len(chunk_indices))
), f"{chunk_indices = } are not monotonically increasing"
assert np.all(chunk_indices == np.arange(len(chunk_indices))), (
f"{chunk_indices = } are not monotonically increasing"
)

View file

@ -49,16 +49,18 @@ Third paragraph is cut and is missing the dot at the end""",
def run_chunking_test(test_text, expected_chunks):
chunks = []
for chunk_data in chunk_by_paragraph(test_text, 12, batch_paragraphs=False):
for chunk_data in chunk_by_paragraph(
data=test_text, paragraph_length=12, batch_paragraphs=False
):
chunks.append(chunk_data)
assert len(chunks) == 3
for expected_chunks_item, chunk in zip(expected_chunks, chunks):
for key in ["text", "word_count", "cut_type"]:
assert (
chunk[key] == expected_chunks_item[key]
), f"{key = }: {chunk[key] = } != {expected_chunks_item[key] = }"
assert chunk[key] == expected_chunks_item[key], (
f"{key = }: {chunk[key] = } != {expected_chunks_item[key] = }"
)
def test_chunking_whole_text():

View file

@ -16,9 +16,9 @@ maximum_length_vals = [None, 8, 64]
def test_chunk_by_sentence_isomorphism(input_text, maximum_length):
chunks = chunk_by_sentence(input_text, maximum_length)
reconstructed_text = "".join([chunk[1] for chunk in chunks])
assert (
reconstructed_text == input_text
), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
assert reconstructed_text == input_text, (
f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
)
@pytest.mark.parametrize(
@ -36,6 +36,6 @@ def test_paragraph_chunk_length(input_text, maximum_length):
chunk_lengths = np.array([len(list(chunk_by_word(chunk[1]))) for chunk in chunks])
larger_chunks = chunk_lengths[chunk_lengths > maximum_length]
assert np.all(
chunk_lengths <= maximum_length
), f"{maximum_length = }: {larger_chunks} are too large"
assert np.all(chunk_lengths <= maximum_length), (
f"{maximum_length = }: {larger_chunks} are too large"
)

View file

@ -17,9 +17,9 @@ from cognee.tests.unit.processing.chunks.test_input import INPUT_TEXTS
def test_chunk_by_word_isomorphism(input_text):
chunks = chunk_by_word(input_text)
reconstructed_text = "".join([chunk[0] for chunk in chunks])
assert (
reconstructed_text == input_text
), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
assert reconstructed_text == input_text, (
f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
)
@pytest.mark.parametrize(

View file

@ -11,8 +11,9 @@ from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
from cognee.api.v1.search import SearchType
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
from cognee.shared.utils import render_graph
from cognee.modules.retrieval.description_to_codepart_search import (
code_description_to_code_part_search,
)
from evals.eval_utils import download_github_repo, retrieved_edges_to_string
@ -32,25 +33,18 @@ def check_install_package(package_name):
return False
async def generate_patch_with_cognee(instance, llm_client, search_type=SearchType.CHUNKS):
async def generate_patch_with_cognee(instance):
repo_path = download_github_repo(instance, "../RAW_GIT_REPOS")
pipeline = await run_code_graph_pipeline(repo_path)
async for result in pipeline:
print(result)
print("Here we have the repo under the repo_path")
await render_graph(None, include_labels=True, include_nodes=True)
include_docs = True
problem_statement = instance["problem_statement"]
instructions = read_query_prompt("patch_gen_kg_instructions.txt")
retrieved_edges = await brute_force_triplet_search(
problem_statement, top_k=3, collections=["data_point_source_code", "data_point_text"]
)
async for result in run_code_graph_pipeline(repo_path, include_docs=include_docs):
print(result)
retrieved_edges_str = retrieved_edges_to_string(retrieved_edges)
retrieved_codeparts = await code_description_to_code_part_search(
problem_statement, include_docs=include_docs
)
prompt = "\n".join(
[
@ -58,8 +52,8 @@ async def generate_patch_with_cognee(instance, llm_client, search_type=SearchTyp
"<patch>",
PATCH_EXAMPLE,
"</patch>",
"These are the retrieved edges:",
retrieved_edges_str,
"This is the additional context to solve the problem (description from documentation together with codeparts):",
retrieved_codeparts,
]
)
@ -85,8 +79,6 @@ async def generate_patch_without_cognee(instance, llm_client):
async def get_preds(dataset, with_cognee=True):
llm_client = get_llm_client()
if with_cognee:
model_name = "with_cognee"
pred_func = generate_patch_with_cognee
@ -94,17 +86,18 @@ async def get_preds(dataset, with_cognee=True):
model_name = "without_cognee"
pred_func = generate_patch_without_cognee
futures = [(instance["instance_id"], pred_func(instance, llm_client)) for instance in dataset]
model_patches = await asyncio.gather(*[x[1] for x in futures])
preds = []
preds = [
{
"instance_id": instance_id,
"model_patch": model_patch,
"model_name_or_path": model_name,
}
for (instance_id, _), model_patch in zip(futures, model_patches)
]
for instance in dataset:
instance_id = instance["instance_id"]
model_patch = await pred_func(instance) # Sequentially await the async function
preds.append(
{
"instance_id": instance_id,
"model_patch": model_patch,
"model_name_or_path": model_name,
}
)
return preds
@ -134,6 +127,7 @@ async def main():
with open(predictions_path, "w") as file:
json.dump(preds, file)
""" This part is for the evaluation
subprocess.run(
[
"python",
@ -151,6 +145,7 @@ async def main():
"test_run",
]
)
"""
if __name__ == "__main__":

View file

@ -1,7 +1,9 @@
import argparse
import asyncio
import logging
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
from cognee.shared.utils import setup_logging
async def main(repo_path, include_docs):
@ -9,7 +11,7 @@ async def main(repo_path, include_docs):
print(result)
if __name__ == "__main__":
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--repo_path", type=str, required=True, help="Path to the repository")
parser.add_argument(
@ -18,5 +20,28 @@ if __name__ == "__main__":
default=True,
help="Whether or not to process non-code files",
)
args = parser.parse_args()
asyncio.run(main(args.repo_path, args.include_docs))
parser.add_argument(
"--time",
type=lambda x: x.lower() in ("true", "1"),
default=True,
help="Whether or not to time the pipeline run",
)
return parser.parse_args()
if __name__ == "__main__":
setup_logging(logging.ERROR)
args = parse_args()
if args.time:
import time
start_time = time.time()
asyncio.run(main(args.repo_path, args.include_docs))
end_time = time.time()
print("\n" + "=" * 50)
print(f"Pipeline Execution Time: {end_time - start_time:.2f} seconds")
print("=" * 50 + "\n")
else:
asyncio.run(main(args.repo_path, args.include_docs))
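
The --time flag treats only the strings "true" and "1" (case-insensitive) as truthy. A quick check of that conversion in isolation:

def to_bool(value: str) -> bool:
    return value.lower() in ("true", "1")


print(to_bool("True"))   # True
print(to_bool("1"))      # True
print(to_bool("false"))  # False
print(to_bool("yes"))    # False: only "true"/"1" count as truthy here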