Merge remote-tracking branch 'origin/dev' into feat/COG-1058-fastmcp
Commit: e577276d91
40 changed files with 358 additions and 431 deletions

@@ -2,7 +2,7 @@ from typing import Union, BinaryIO
 from cognee.modules.users.models import User
 from cognee.modules.users.methods import get_default_user
 from cognee.modules.pipelines import run_tasks, Task
-from cognee.tasks.ingestion import ingest_data_with_metadata, resolve_data_directories
+from cognee.tasks.ingestion import ingest_data, resolve_data_directories
 from cognee.infrastructure.databases.relational import (
     create_db_and_tables as create_relational_db_and_tables,
 )
@@ -22,7 +22,7 @@ async def add(
     if user is None:
         user = await get_default_user()

-    tasks = [Task(resolve_data_directories), Task(ingest_data_with_metadata, dataset_name, user)]
+    tasks = [Task(resolve_data_directories), Task(ingest_data, dataset_name, user)]

     pipeline = run_tasks(tasks, data, "add_pipeline")

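Illustrative sketch (not part of the diff): how the reworked add pipeline is wired once ingest_data replaces ingest_data_with_metadata. The async iteration over the pipeline is an assumption about run_tasks, not something shown in these hunks.

# Hedged usage sketch; mirrors the tasks list from the hunks above.
from cognee.modules.pipelines import run_tasks, Task
from cognee.tasks.ingestion import ingest_data, resolve_data_directories


async def add_items(data, dataset_name, user):
    # Directories are expanded into file paths first, then every item is
    # ingested under the given dataset and owner.
    tasks = [Task(resolve_data_directories), Task(ingest_data, dataset_name, user)]
    pipeline = run_tasks(tasks, data, "add_pipeline")
    async for run_status in pipeline:  # assumption: run_tasks yields status objects
        print(run_status)
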
@@ -10,7 +10,7 @@ from cognee.modules.users.methods import get_default_user
 from cognee.shared.data_models import KnowledgeGraph, MonitoringTool
 from cognee.tasks.documents import classify_documents, extract_chunks_from_documents
 from cognee.tasks.graph import extract_graph_from_data
-from cognee.tasks.ingestion import ingest_data_with_metadata
+from cognee.tasks.ingestion import ingest_data
 from cognee.tasks.repo_processor import (
     enrich_dependency_graph,
     expand_dependency_graph,
@@ -68,7 +68,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
     if include_docs:
         non_code_tasks = [
             Task(get_non_py_files, task_config={"batch_size": 50}),
-            Task(ingest_data_with_metadata, dataset_name="repo_docs", user=user),
+            Task(ingest_data, dataset_name="repo_docs", user=user),
             Task(get_data_list_for_user, dataset_name="repo_docs", user=user),
             Task(classify_documents),
             Task(extract_chunks_from_documents, max_tokens=cognee_config.max_tokens),

@@ -95,7 +95,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):

                 return await self.embed_text(text)

-        except (litellm.exceptions.BadRequestError, litellm.llms.OpenAI.openai.OpenAIError):
+        except litellm.exceptions.BadRequestError:
             raise EmbeddingException("Failed to index data points.")

         except Exception as error:

@@ -54,7 +54,6 @@ class TextChunker:
                     contains=[],
                     _metadata={
                         "index_fields": ["text"],
-                        "metadata_id": self.document.metadata_id,
                     },
                 )
                 paragraph_chunks = []
@@ -74,7 +73,6 @@ class TextChunker:
                         contains=[],
                         _metadata={
                             "index_fields": ["text"],
-                            "metadata_id": self.document.metadata_id,
                         },
                     )
                 except Exception as e:
@@ -95,7 +93,7 @@ class TextChunker:
                     chunk_index=self.chunk_index,
                     cut_type=paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
                     contains=[],
-                    _metadata={"index_fields": ["text"], "metadata_id": self.document.metadata_id},
+                    _metadata={"index_fields": ["text"]},
                 )
             except Exception as e:
                 print(e)

@@ -1,13 +1,11 @@
 from datetime import datetime, timezone
 from typing import List
 from uuid import uuid4
-from sqlalchemy import UUID, Column, DateTime, String
-from sqlalchemy.orm import Mapped, relationship
+from sqlalchemy import UUID, Column, DateTime, String, JSON
+from sqlalchemy.orm import relationship

 from cognee.infrastructure.databases.relational import Base

 from .DatasetData import DatasetData
-from .Metadata import Metadata


 class Data(Base):
@@ -21,6 +19,7 @@ class Data(Base):
     raw_data_location = Column(String)
     owner_id = Column(UUID, index=True)
     content_hash = Column(String)
+    external_metadata = Column(JSON)
     created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
     updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))

@@ -32,13 +31,6 @@ class Data(Base):
         cascade="all, delete",
     )

-    metadata_relationship = relationship(
-        "Metadata",
-        back_populates="data",
-        lazy="noload",
-        cascade="all, delete",
-    )
-
     def to_json(self) -> dict:
         return {
             "id": str(self.id),

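A minimal sketch (not from the commit) of how a Data row can carry the new external_metadata JSON column in place of the removed Metadata table; the async session handling is assumed to match the engine used elsewhere in cognee.

# Sketch only: persisting a Data row with the new JSON column.
from uuid import uuid4
from cognee.modules.data.models import Data


async def store_row(session, file_metadata: dict, owner_id, external: dict):
    data_point = Data(
        id=uuid4(),
        name=file_metadata["name"],
        raw_data_location=file_metadata["file_path"],
        extension=file_metadata["extension"],
        mime_type=file_metadata["mime_type"],
        owner_id=owner_id,
        content_hash=file_metadata["content_hash"],
        external_metadata=external,  # arbitrary dict, stored as JSON
    )
    session.add(data_point)
    await session.commit()
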
@@ -1,21 +0,0 @@
-from datetime import datetime, timezone
-from uuid import uuid4
-
-from sqlalchemy import UUID, Column, DateTime, String, ForeignKey
-from sqlalchemy.orm import relationship
-
-from cognee.infrastructure.databases.relational import Base
-
-
-class Metadata(Base):
-    __tablename__ = "metadata_table"
-
-    id = Column(UUID, primary_key=True, default=uuid4)
-    metadata_repr = Column(String)
-    metadata_source = Column(String)
-
-    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
-    updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
-
-    data_id = Column(UUID, ForeignKey("data.id", ondelete="CASCADE"), primary_key=False)
-    data = relationship("Data", back_populates="metadata_relationship")

@@ -1,19 +0,0 @@
-import warnings
-from uuid import UUID
-
-from sqlalchemy import select
-
-from cognee.infrastructure.databases.relational import get_relational_engine
-
-from ..models.Metadata import Metadata
-
-
-async def delete_metadata(metadata_id: UUID):
-    db_engine = get_relational_engine()
-    async with db_engine.get_async_session() as session:
-        metadata = await session.get(Metadata, metadata_id)
-        if metadata is None:
-            warnings.warn(f"metadata for metadata_id: {metadata_id} not found")
-
-        session.delete(metadata)
-        session.commit()

@@ -1,17 +0,0 @@
-import json
-from uuid import UUID
-
-from sqlalchemy import select
-
-from cognee.infrastructure.databases.relational import get_relational_engine
-
-from ..models.Metadata import Metadata
-
-
-async def get_metadata(metadata_id: UUID) -> Metadata:
-    db_engine = get_relational_engine()
-
-    async with db_engine.get_async_session() as session:
-        metadata = await session.get(Metadata, metadata_id)
-
-        return metadata

@@ -1,65 +0,0 @@
-import inspect
-import json
-import re
-import warnings
-from uuid import UUID
-from sqlalchemy import select
-from typing import Any, BinaryIO, Union
-
-from cognee.infrastructure.databases.relational import get_relational_engine
-from cognee.infrastructure.files.utils.get_file_metadata import FileMetadata
-from ..models.Metadata import Metadata
-
-
-async def write_metadata(
-    data_item: Union[BinaryIO, str, Any], data_id: UUID, file_metadata: FileMetadata
-) -> UUID:
-    metadata_dict = get_metadata_dict(data_item, file_metadata)
-    db_engine = get_relational_engine()
-    async with db_engine.get_async_session() as session:
-        metadata = (
-            await session.execute(select(Metadata).filter(Metadata.data_id == data_id))
-        ).scalar_one_or_none()
-
-        if metadata is not None:
-            metadata.metadata_repr = json.dumps(metadata_dict)
-            metadata.metadata_source = parse_type(type(data_item))
-            await session.merge(metadata)
-        else:
-            metadata = Metadata(
-                id=data_id,
-                metadata_repr=json.dumps(metadata_dict),
-                metadata_source=parse_type(type(data_item)),
-                data_id=data_id,
-            )
-            session.add(metadata)
-
-        await session.commit()
-
-
-def parse_type(type_: Any) -> str:
-    pattern = r".+'([\w_\.]+)'"
-    match = re.search(pattern, str(type_))
-    if match:
-        return match.group(1)
-    else:
-        raise Exception(f"type: {type_} could not be parsed")
-
-
-def get_metadata_dict(
-    data_item: Union[BinaryIO, str, Any], file_metadata: FileMetadata
-) -> dict[str, Any]:
-    if isinstance(data_item, str):
-        return file_metadata
-    elif isinstance(data_item, BinaryIO):
-        return file_metadata
-    elif hasattr(data_item, "dict") and inspect.ismethod(getattr(data_item, "dict")):
-        return {**file_metadata, **data_item.dict()}
-    else:
-        warnings.warn(
-            f"metadata of type {type(data_item)}: {str(data_item)[:20]}... does not have dict method. Defaulting to string method"
-        )
-        try:
-            return {**dict(file_metadata), "content": str(data_item)}
-        except Exception as e:
-            raise Exception(f"Could not cast metadata to string: {e}")

@@ -7,7 +7,7 @@ from cognee.infrastructure.engine import DataPoint
 class Document(DataPoint):
     name: str
     raw_data_location: str
-    metadata_id: UUID
+    external_metadata: Optional[str]
     mime_type: str
     _metadata: dict = {"index_fields": ["name"], "type": "Document"}

@@ -10,7 +10,29 @@ from cognee.modules.chunking.models.DocumentChunk import DocumentChunk

 async def chunk_naive_llm_classifier(
     data_chunks: list[DocumentChunk], classification_model: Type[BaseModel]
-):
+) -> list[DocumentChunk]:
+    """
+    Classifies a list of document chunks using a specified classification model and updates vector and graph databases with the classification results.
+
+    Vector Database Structure:
+    - Collection Name: `classification`
+    - Payload Schema:
+        - uuid (str): Unique identifier for the classification.
+        - text (str): Text label of the classification.
+        - chunk_id (str): Identifier of the chunk associated with this classification.
+        - document_id (str): Identifier of the document associated with this classification.
+
+    Graph Database Structure:
+    - Nodes:
+        - Represent document chunks, classification types, and classification subtypes.
+    - Edges:
+        - `is_media_type`: Links document chunks to their classification type.
+        - `is_subtype_of`: Links classification subtypes to their parent type.
+        - `is_classified_as`: Links document chunks to their classification subtypes.
+    Notes:
+    - The function assumes that vector and graph database engines (`get_vector_engine` and `get_graph_engine`) are properly initialized and accessible.
+    - Classification labels are processed to ensure uniqueness using UUIDs based on their values.
+    """
     if len(data_chunks) == 0:
         return data_chunks

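For orientation only: a hypothetical payload following the schema the docstring above describes for the classification collection (all values invented).

# Hypothetical payload for the `classification` vector collection; field names
# come from the docstring above, the values are made up.
example_classification_payload = {
    "uuid": "derived-from-the-label-value",
    "text": "scientific_article",
    "chunk_id": "id-of-the-classified-chunk",
    "document_id": "id-of-the-parent-document",
}
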
@@ -17,6 +17,12 @@ def chunk_by_paragraph(
     """
     Chunks text by paragraph while preserving exact text reconstruction capability.
     When chunks are joined with empty string "", they reproduce the original text exactly.
+
+    Notes:
+    - Tokenization is handled using the `tiktoken` library, ensuring compatibility with the vector engine's embedding model.
+    - If `batch_paragraphs` is False, each paragraph will be yielded as a separate chunk.
+    - Handles cases where paragraphs exceed the specified token or word limits by splitting them as needed.
+    - Remaining text at the end of the input will be yielded as a final chunk.
     """
     current_chunk = ""
     current_word_count = 0

@@ -1,9 +1,19 @@
-from uuid import uuid4
-from typing import Optional
+from uuid import uuid4, UUID
+from typing import Optional, Iterator, Tuple
 from .chunk_by_word import chunk_by_word


-def chunk_by_sentence(data: str, maximum_length: Optional[int] = None):
+def chunk_by_sentence(
+    data: str, maximum_length: Optional[int] = None
+) -> Iterator[Tuple[UUID, str, int, Optional[str]]]:
+    """
+    Splits the input text into sentences based on word-level processing, with optional sentence length constraints.
+
+    Notes:
+    - Relies on the `chunk_by_word` function for word-level tokenization and classification.
+    - Ensures sentences within paragraphs are uniquely identifiable using UUIDs.
+    - Handles cases where the text ends mid-sentence by appending a special "sentence_cut" type.
+    """
     sentence = ""
     paragraph_id = uuid4()
     word_count = 0

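Illustrative sketch (not part of the diff) of consuming the newly annotated generator; the import path is an assumption based on where cognee keeps its chunking tasks.

# Sketch: each yielded 4-tuple follows the new return annotation
# (paragraph UUID, sentence text, word count, ending type or None).
from cognee.tasks.chunks.chunk_by_sentence import chunk_by_sentence  # assumed path

for paragraph_id, sentence, word_count, end_type in chunk_by_sentence(
    "First sentence. Second one?", maximum_length=64
):
    print(paragraph_id, repr(sentence), word_count, end_type)
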
@@ -1,4 +1,6 @@
 import re
+from typing import Iterator, Tuple


 SENTENCE_ENDINGS = r"[.;!?…]"
 PARAGRAPH_ENDINGS = r"[\n\r]"
@@ -34,7 +36,7 @@ def is_real_paragraph_end(last_char: str, current_pos: int, text: str) -> bool:
     return False


-def chunk_by_word(data: str):
+def chunk_by_word(data: str) -> Iterator[Tuple[str, str]]:
     """
     Chunks text into words and endings while preserving whitespace.
     Whitespace is included with the preceding word.

@@ -3,11 +3,19 @@ from cognee.infrastructure.databases.vector import get_vector_engine

 async def query_chunks(query: str) -> list[dict]:
     """

     Queries the vector database to retrieve chunks related to the given query string.

     Parameters:
     - query (str): The query string to filter nodes by.

     Returns:
     - list(dict): A list of objects providing information about the chunks related to query.

+    Notes:
+    - The function uses the `search` method of the vector engine to find matches.
+    - Limits the results to the top 5 matching chunks to balance performance and relevance.
+    - Ensure that the vector database is properly initialized and contains the "document_chunk_text" collection.
     """
     vector_engine = get_vector_engine()

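A short, hedged usage sketch of the retrieval helper documented above; it assumes the vector database already contains the document_chunk_text collection, as the docstring notes, and the import path is an assumption.

# Sketch: retrieve up to five related chunks for a query, per the docstring above.
import asyncio
from cognee.tasks.chunks.query_chunks import query_chunks  # assumed import path


async def main():
    results = await query_chunks("What is a knowledge graph?")
    for chunk in results:  # each item is a dict describing one matching chunk
        print(chunk)


asyncio.run(main())
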
@@ -3,6 +3,14 @@ from cognee.modules.chunking.models.DocumentChunk import DocumentChunk


 async def remove_disconnected_chunks(data_chunks: list[DocumentChunk]) -> list[DocumentChunk]:
+    """
+    Removes disconnected or obsolete chunks from the graph database.
+
+    Notes:
+    - Obsolete chunks are defined as chunks with no "next_chunk" predecessor.
+    - Fully disconnected nodes are identified and deleted separately.
+    - This function assumes that the graph database is properly initialized and accessible.
+    """
     graph_engine = await get_graph_engine()

     document_ids = set((data_chunk.document_id for data_chunk in data_chunks))

@@ -6,6 +6,10 @@ from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search


 def retrieved_edges_to_string(retrieved_edges: list) -> str:
+    """
+    Converts a list of retrieved graph edges into a human-readable string format.
+
+    """
     edge_strings = []
     for edge in retrieved_edges:
         node1_string = edge.node1.attributes.get("text") or edge.node1.attributes.get("name")
@@ -18,11 +22,19 @@ def retrieved_edges_to_string(retrieved_edges: list) -> str:

 async def graph_query_completion(query: str) -> list:
     """
     Executes a query on the graph database and retrieves a relevant completion based on the found data.

     Parameters:
     - query (str): The query string to compute.

     Returns:
     - list: Answer to the query.

+    Notes:
+    - The `brute_force_triplet_search` is used to retrieve relevant graph data.
+    - Prompts are dynamically rendered and provided to the LLM for contextual understanding.
+    - Ensure that the LLM client and graph database are properly configured and accessible.
+
     """
     found_triplets = await brute_force_triplet_search(query, top_k=5)

@@ -6,11 +6,20 @@ from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt

 async def query_completion(query: str) -> list:
     """

     Executes a query against a vector database and computes a relevant response using an LLM.

     Parameters:
     - query (str): The query string to compute.

     Returns:
     - list: Answer to the query.

+    Notes:
+    - Limits the search to the top 1 matching chunk for simplicity and relevance.
+    - Ensure that the vector database and LLM client are properly configured and accessible.
+    - The response model used for the LLM output is expected to be a string.
+
     """
     vector_engine = get_vector_engine()

@@ -1,8 +1,19 @@
 from cognee.modules.data.processing.document_types import Document
 from cognee.modules.users.permissions.methods import check_permission_on_documents
+from typing import List


-async def check_permissions_on_documents(documents: list[Document], user, permissions):
+async def check_permissions_on_documents(
+    documents: list[Document], user, permissions
+) -> List[Document]:
+    """
+    Validates a user's permissions on a list of documents.
+
+    Notes:
+    - This function assumes that `check_permission_on_documents` raises an exception if the permission check fails.
+    - It is designed to validate multiple permissions in a sequential manner for the same set of documents.
+    - Ensure that the `Document` and `user` objects conform to the expected structure and interfaces.
+    """
     document_ids = [document.id for document in documents]

     for permission in permissions:

@@ -1,4 +1,5 @@
 from cognee.modules.data.models import Data
+import json
 from cognee.modules.data.processing.document_types import (
     Document,
     PdfDocument,
@@ -7,7 +8,6 @@ from cognee.modules.data.processing.document_types import (
     TextDocument,
     UnstructuredDocument,
 )
-from cognee.modules.data.operations.get_metadata import get_metadata

 EXTENSION_TO_DOCUMENT_CLASS = {
     "pdf": PdfDocument,  # Text documents
@@ -50,16 +50,22 @@ EXTENSION_TO_DOCUMENT_CLASS = {


 async def classify_documents(data_documents: list[Data]) -> list[Document]:
+    """
+    Classifies a list of data items into specific document types based on file extensions.
+
+    Notes:
+    - The function relies on `get_metadata` to retrieve metadata information for each data item.
+    - Ensure the `Data` objects and their attributes (e.g., `extension`, `id`) are valid before calling this function.
+    """
     documents = []
     for data_item in data_documents:
-        metadata = await get_metadata(data_item.id)
         document = EXTENSION_TO_DOCUMENT_CLASS[data_item.extension](
             id=data_item.id,
             title=f"{data_item.name}.{data_item.extension}",
             raw_data_location=data_item.raw_data_location,
             name=data_item.name,
             mime_type=data_item.mime_type,
-            metadata_id=metadata.id,
+            external_metadata=json.dumps(data_item.external_metadata, indent=4),
         )
         documents.append(document)

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, AsyncGenerator

 from cognee.modules.data.processing.document_types.Document import Document

@@ -8,7 +8,14 @@ async def extract_chunks_from_documents(
     chunk_size: int = 1024,
     chunker="text_chunker",
     max_tokens: Optional[int] = None,
-):
+) -> AsyncGenerator:
+    """
+    Extracts chunks of data from a list of documents based on the specified chunking parameters.
+
+    Notes:
+    - The `read` method of the `Document` class must be implemented to support the chunking operation.
+    - The `chunker` parameter determines the chunking logic and should align with the document type.
+    """
     for document in documents:
         for document_chunk in document.read(
             chunk_size=chunk_size, chunker=chunker, max_tokens=max_tokens

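Sketch (not from the diff) of consuming the task now that it is annotated as an AsyncGenerator; the chunker and chunk_size values just repeat the defaults visible above.

# Sketch: chunks are pulled with `async for` because the task is an async generator.
async def collect_chunks(documents):
    chunks = []
    async for chunk in extract_chunks_from_documents(
        documents, chunk_size=1024, chunker="text_chunker"
    ):
        chunks.append(chunk)
    return chunks
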
@@ -1,12 +1,21 @@
 import asyncio
-from typing import Type
+from typing import Type, List
 from pydantic import BaseModel
 from cognee.modules.data.extraction.knowledge_graph import extract_content_graph
 from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
 from cognee.tasks.storage import add_data_points


-async def extract_graph_from_code(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel]):
+async def extract_graph_from_code(
+    data_chunks: list[DocumentChunk], graph_model: Type[BaseModel]
+) -> List[DocumentChunk]:
+    """
+    Extracts a knowledge graph from the text content of document chunks using a specified graph model.
+
+    Notes:
+    - The `extract_content_graph` function processes each chunk's text to extract graph information.
+    - Graph nodes are stored using the `add_data_points` function for later retrieval or analysis.
+    """
     chunk_graphs = await asyncio.gather(
         *[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
     )

@@ -1,5 +1,5 @@
 import asyncio
-from typing import Type
+from typing import Type, List

 from pydantic import BaseModel

@@ -13,7 +13,14 @@ from cognee.modules.graph.utils import (
 from cognee.tasks.storage import add_data_points


-async def extract_graph_from_data(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel]):
+async def extract_graph_from_data(
+    data_chunks: list[DocumentChunk], graph_model: Type[BaseModel]
+) -> List[DocumentChunk]:
+    """
+    Extracts and integrates a knowledge graph from the text content of document chunks using a specified graph model.
+
+    """
+
     chunk_graphs = await asyncio.gather(
         *[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
     )

@@ -1,6 +1,3 @@
+from .ingest_data import ingest_data
-from .save_data_to_storage import save_data_to_storage
 from .save_data_item_to_storage import save_data_item_to_storage
-from .save_data_item_with_metadata_to_storage import save_data_item_with_metadata_to_storage
-from .ingest_data_with_metadata import ingest_data_with_metadata
-from .ingest_data import ingest_data
 from .resolve_data_directories import resolve_data_directories

@@ -1,16 +1,24 @@
 from typing import Any, List

 import dlt
 import cognee.modules.ingestion as ingestion

+from uuid import UUID
+from cognee.shared.utils import send_telemetry
+from cognee.modules.users.models import User
 from cognee.infrastructure.databases.relational import get_relational_engine
 from cognee.modules.data.methods import create_dataset
 from cognee.modules.data.models.DatasetData import DatasetData
-from cognee.modules.users.models import User
 from cognee.modules.users.permissions.methods import give_permission_on_document
-from cognee.shared.utils import send_telemetry
 from .get_dlt_destination import get_dlt_destination
+from .save_data_item_to_storage import (
+    save_data_item_to_storage,
+)
+
+from typing import Union, BinaryIO
+import inspect
+import warnings


-async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
+async def ingest_data(data: Any, dataset_name: str, user: User):
     destination = get_dlt_destination()

     pipeline = dlt.pipeline(
@@ -18,12 +26,21 @@ async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
         destination=destination,
     )

-    @dlt.resource(standalone=True, merge_key="id")
-    async def data_resources(file_paths: str):
+    def get_external_metadata_dict(data_item: Union[BinaryIO, str, Any]) -> dict[str, Any]:
+        if hasattr(data_item, "dict") and inspect.ismethod(getattr(data_item, "dict")):
+            return {"metadata": data_item.dict(), "origin": str(type(data_item))}
+        else:
+            warnings.warn(
+                f"Data of type {type(data_item)}... does not have dict method. Returning empty metadata."
+            )
+            return {}
+
+    @dlt.resource(standalone=True, primary_key="id", merge_key="id")
+    async def data_resources(file_paths: List[str], user: User):
         for file_path in file_paths:
             with open(file_path.replace("file://", ""), mode="rb") as file:
                 classified_data = ingestion.classify(file)
-                data_id = ingestion.identify(classified_data)
+                data_id = ingestion.identify(classified_data, user)
                 file_metadata = classified_data.get_metadata()
                 yield {
                     "id": data_id,
@@ -31,71 +48,111 @@ async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
                     "file_path": file_metadata["file_path"],
                     "extension": file_metadata["extension"],
                     "mime_type": file_metadata["mime_type"],
+                    "content_hash": file_metadata["content_hash"],
+                    "owner_id": str(user.id),
                 }

-    async def data_storing(table_name, dataset_name, user: User):
-        db_engine = get_relational_engine()
+    async def data_storing(data: Any, dataset_name: str, user: User):
+        if not isinstance(data, list):
+            # Convert data to a list as we work with lists further down.
+            data = [data]
+
+        file_paths = []
+
+        # Process data
+        for data_item in data:
+            file_path = await save_data_item_to_storage(data_item, dataset_name)
+
+            file_paths.append(file_path)
+
+            # Ingest data and add metadata
+            with open(file_path.replace("file://", ""), mode="rb") as file:
+                classified_data = ingestion.classify(file)
+
+                # data_id is the hash of file contents + owner id to avoid duplicate data
+                data_id = ingestion.identify(classified_data, user)
+
+                file_metadata = classified_data.get_metadata()

-        async with db_engine.get_async_session() as session:
-            # Read metadata stored with dlt
-            files_metadata = await db_engine.get_all_data_from_table(table_name, dataset_name)
-            for file_metadata in files_metadata:
                 from sqlalchemy import select

                 from cognee.modules.data.models import Data

-                dataset = await create_dataset(dataset_name, user.id, session)
+                db_engine = get_relational_engine()

-                data = (
-                    await session.execute(select(Data).filter(Data.id == UUID(file_metadata["id"])))
-                ).scalar_one_or_none()
+                async with db_engine.get_async_session() as session:
+                    dataset = await create_dataset(dataset_name, user.id, session)

-                if data is not None:
-                    data.name = file_metadata["name"]
-                    data.raw_data_location = file_metadata["file_path"]
-                    data.extension = file_metadata["extension"]
-                    data.mime_type = file_metadata["mime_type"]
+                    # Check to see if data should be updated
+                    data_point = (
+                        await session.execute(select(Data).filter(Data.id == data_id))
+                    ).scalar_one_or_none()

-                    await session.merge(data)
-                    await session.commit()
-                else:
-                    data = Data(
-                        id=UUID(file_metadata["id"]),
-                        name=file_metadata["name"],
-                        raw_data_location=file_metadata["file_path"],
-                        extension=file_metadata["extension"],
-                        mime_type=file_metadata["mime_type"],
-                    )
+                    if data_point is not None:
+                        data_point.name = file_metadata["name"]
+                        data_point.raw_data_location = file_metadata["file_path"]
+                        data_point.extension = file_metadata["extension"]
+                        data_point.mime_type = file_metadata["mime_type"]
+                        data_point.owner_id = user.id
+                        data_point.content_hash = file_metadata["content_hash"]
+                        data_point.external_metadata = (get_external_metadata_dict(data_item),)
+                        await session.merge(data_point)
+                    else:
+                        data_point = Data(
+                            id=data_id,
+                            name=file_metadata["name"],
+                            raw_data_location=file_metadata["file_path"],
+                            extension=file_metadata["extension"],
+                            mime_type=file_metadata["mime_type"],
+                            owner_id=user.id,
+                            content_hash=file_metadata["content_hash"],
+                            external_metadata=get_external_metadata_dict(data_item),
+                        )

+                    # Check if data is already in dataset
+                    dataset_data = (
+                        await session.execute(
+                            select(DatasetData).filter(
+                                DatasetData.data_id == data_id, DatasetData.dataset_id == dataset.id
+                            )
+                        )
+                    ).scalar_one_or_none()
+                    # If data is not present in dataset add it
+                    if dataset_data is None:
+                        dataset.data.append(data_point)

-                dataset.data.append(data)
                     await session.commit()

-                await give_permission_on_document(user, UUID(file_metadata["id"]), "read")
-                await give_permission_on_document(user, UUID(file_metadata["id"]), "write")
+                    await give_permission_on_document(user, data_id, "read")
+                    await give_permission_on_document(user, data_id, "write")
+        return file_paths

     send_telemetry("cognee.add EXECUTION STARTED", user_id=user.id)

     db_engine = get_relational_engine()

+    file_paths = await data_storing(data, dataset_name, user)
+
+    # Note: DLT pipeline has its own event loop, therefore objects created in another event loop
+    # can't be used inside the pipeline
     if db_engine.engine.dialect.name == "sqlite":
         # To use sqlite with dlt dataset_name must be set to "main".
         # Sqlite doesn't support schemas
         run_info = pipeline.run(
-            data_resources(file_paths),
+            data_resources(file_paths, user),
             table_name="file_metadata",
             dataset_name="main",
             write_disposition="merge",
         )
     else:
         # Data should be stored in the same schema to allow deduplication
         run_info = pipeline.run(
-            data_resources(file_paths),
+            data_resources(file_paths, user),
             table_name="file_metadata",
-            dataset_name=dataset_name,
+            dataset_name="public",
             write_disposition="merge",
         )

-    await data_storing("file_metadata", dataset_name, user)
     send_telemetry("cognee.add EXECUTION COMPLETED", user_id=user.id)

     return run_info

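A hedged illustration of what the new get_external_metadata_dict helper returns for its two branches; ExampleItem is a hypothetical pydantic model used only for this sketch.

# Sketch of the helper's two branches (shape follows the hunk above).
from pydantic import BaseModel


class ExampleItem(BaseModel):  # hypothetical stand-in for e.g. a llama_index document
    title: str


item = ExampleItem(title="quarterly report")
# Objects exposing a bound .dict() method yield their fields plus the origin type:
#   {"metadata": {"title": "quarterly report"}, "origin": "<class '...ExampleItem'>"}
# Plain strings and file handles take the warning branch and yield {}.
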
@@ -1,145 +0,0 @@
-from typing import Any, List
-
-import dlt
-import cognee.modules.ingestion as ingestion
-from cognee.infrastructure.databases.relational import get_relational_engine
-from cognee.modules.data.methods import create_dataset
-from cognee.modules.data.models.DatasetData import DatasetData
-from cognee.modules.users.models import User
-from cognee.modules.users.permissions.methods import give_permission_on_document
-from cognee.shared.utils import send_telemetry
-from cognee.modules.data.operations.write_metadata import write_metadata
-from .get_dlt_destination import get_dlt_destination
-from .save_data_item_with_metadata_to_storage import (
-    save_data_item_with_metadata_to_storage,
-)
-
-
-async def ingest_data_with_metadata(data: Any, dataset_name: str, user: User):
-    destination = get_dlt_destination()
-
-    pipeline = dlt.pipeline(
-        pipeline_name="file_load_from_filesystem",
-        destination=destination,
-    )
-
-    @dlt.resource(standalone=True, primary_key="id", merge_key="id")
-    async def data_resources(file_paths: List[str], user: User):
-        for file_path in file_paths:
-            with open(file_path.replace("file://", ""), mode="rb") as file:
-                classified_data = ingestion.classify(file)
-                data_id = ingestion.identify(classified_data, user)
-                file_metadata = classified_data.get_metadata()
-                yield {
-                    "id": data_id,
-                    "name": file_metadata["name"],
-                    "file_path": file_metadata["file_path"],
-                    "extension": file_metadata["extension"],
-                    "mime_type": file_metadata["mime_type"],
-                    "content_hash": file_metadata["content_hash"],
-                    "owner_id": str(user.id),
-                }
-
-    async def data_storing(data: Any, dataset_name: str, user: User):
-        if not isinstance(data, list):
-            # Convert data to a list as we work with lists further down.
-            data = [data]
-
-        file_paths = []
-
-        # Process data
-        for data_item in data:
-            file_path = await save_data_item_with_metadata_to_storage(data_item, dataset_name)
-
-            file_paths.append(file_path)
-
-            # Ingest data and add metadata
-            with open(file_path.replace("file://", ""), mode="rb") as file:
-                classified_data = ingestion.classify(file)
-
-                # data_id is the hash of file contents + owner id to avoid duplicate data
-                data_id = ingestion.identify(classified_data, user)
-
-                file_metadata = classified_data.get_metadata()
-
-                from sqlalchemy import select
-
-                from cognee.modules.data.models import Data
-
-                db_engine = get_relational_engine()
-
-                async with db_engine.get_async_session() as session:
-                    dataset = await create_dataset(dataset_name, user.id, session)
-
-                    # Check to see if data should be updated
-                    data_point = (
-                        await session.execute(select(Data).filter(Data.id == data_id))
-                    ).scalar_one_or_none()
-
-                    if data_point is not None:
-                        data_point.name = file_metadata["name"]
-                        data_point.raw_data_location = file_metadata["file_path"]
-                        data_point.extension = file_metadata["extension"]
-                        data_point.mime_type = file_metadata["mime_type"]
-                        data_point.owner_id = user.id
-                        data_point.content_hash = file_metadata["content_hash"]
-                        await session.merge(data_point)
-                    else:
-                        data_point = Data(
-                            id=data_id,
-                            name=file_metadata["name"],
-                            raw_data_location=file_metadata["file_path"],
-                            extension=file_metadata["extension"],
-                            mime_type=file_metadata["mime_type"],
-                            owner_id=user.id,
-                            content_hash=file_metadata["content_hash"],
-                        )
-
-                    # Check if data is already in dataset
-                    dataset_data = (
-                        await session.execute(
-                            select(DatasetData).filter(
-                                DatasetData.data_id == data_id, DatasetData.dataset_id == dataset.id
-                            )
-                        )
-                    ).scalar_one_or_none()
-                    # If data is not present in dataset add it
-                    if dataset_data is None:
-                        dataset.data.append(data_point)
-
-                    await session.commit()
-                    await write_metadata(data_item, data_point.id, file_metadata)
-
-                    await give_permission_on_document(user, data_id, "read")
-                    await give_permission_on_document(user, data_id, "write")
-        return file_paths
-
-    send_telemetry("cognee.add EXECUTION STARTED", user_id=user.id)
-
-    db_engine = get_relational_engine()
-
-    file_paths = await data_storing(data, dataset_name, user)
-
-    # Note: DLT pipeline has its own event loop, therefore objects created in another event loop
-    # can't be used inside the pipeline
-    if db_engine.engine.dialect.name == "sqlite":
-        # To use sqlite with dlt dataset_name must be set to "main".
-        # Sqlite doesn't support schemas
-        run_info = pipeline.run(
-            data_resources(file_paths, user),
-            table_name="file_metadata",
-            dataset_name="main",
-            write_disposition="merge",
-        )
-    else:
-        # Data should be stored in the same schema to allow deduplication
-        run_info = pipeline.run(
-            data_resources(file_paths, user),
-            table_name="file_metadata",
-            dataset_name="public",
-            write_disposition="merge",
-        )
-
-    send_telemetry("cognee.add EXECUTION COMPLETED", user_id=user.id)
-
-    return run_info

@@ -1,12 +1,18 @@
-from typing import Union, BinaryIO
+from typing import Union, BinaryIO, Any

 from cognee.modules.ingestion.exceptions import IngestionError
 from cognee.modules.ingestion import save_data_to_file


-def save_data_item_to_storage(data_item: Union[BinaryIO, str], dataset_name: str) -> str:
+async def save_data_item_to_storage(data_item: Union[BinaryIO, str, Any], dataset_name: str) -> str:
+    if "llama_index" in str(type(data_item)):
+        # Dynamic import is used because the llama_index module is optional.
+        from .transform_data import get_data_from_llama_index
+
+        file_path = get_data_from_llama_index(data_item, dataset_name)
+
     # data is a file object coming from upload.
-    if hasattr(data_item, "file"):
+    elif hasattr(data_item, "file"):
         file_path = save_data_to_file(data_item.file, filename=data_item.filename)

     elif isinstance(data_item, str):

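Illustrative calls (not part of the diff) covering the plain-Python branches handled above; the helper is now awaited and accepts Any, so the calls are wrapped in an async function.

# Sketch: string inputs either reuse an existing path or get written to storage.
async def demo(dataset_name: str):
    # existing path is reused as-is; raw text is written to a new file under storage
    path_from_file = await save_data_item_to_storage("/data/report.pdf", dataset_name)
    path_from_text = await save_data_item_to_storage("raw text to ingest", dataset_name)
    # upload-like objects (with .file / .filename) and llama_index documents are
    # handled by the other branches shown in the hunk above.
    return path_from_file, path_from_text
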
@@ -1,30 +0,0 @@
-from typing import Union, BinaryIO, Any
-
-from cognee.modules.ingestion.exceptions import IngestionError
-from cognee.modules.ingestion import save_data_to_file
-
-
-async def save_data_item_with_metadata_to_storage(
-    data_item: Union[BinaryIO, str, Any], dataset_name: str
-) -> str:
-    if "llama_index" in str(type(data_item)):
-        # Dynamic import is used because the llama_index module is optional.
-        from .transform_data import get_data_from_llama_index
-
-        file_path = get_data_from_llama_index(data_item, dataset_name)
-
-    # data is a file object coming from upload.
-    elif hasattr(data_item, "file"):
-        file_path = save_data_to_file(data_item.file, filename=data_item.filename)
-
-    elif isinstance(data_item, str):
-        # data is a file path
-        if data_item.startswith("file://") or data_item.startswith("/"):
-            file_path = data_item.replace("file://", "")
-        # data is text
-        else:
-            file_path = save_data_to_file(data_item)
-    else:
-        raise IngestionError(message=f"Data type not supported: {type(data_item)}")
-
-    return file_path

@@ -1,16 +0,0 @@
-from typing import Union, BinaryIO
-from cognee.tasks.ingestion.save_data_item_to_storage import save_data_item_to_storage
-
-
-def save_data_to_storage(data: Union[BinaryIO, str], dataset_name) -> list[str]:
-    if not isinstance(data, list):
-        # Convert data to a list as we work with lists further down.
-        data = [data]
-
-    file_paths = []
-
-    for data_item in data:
-        file_path = save_data_item_to_storage(data_item, dataset_name)
-        file_paths.append(file_path)
-
-    return file_paths

@@ -29,7 +29,7 @@ def test_AudioDocument():
         id=uuid.uuid4(),
         name="audio-dummy-test",
         raw_data_location="",
-        metadata_id=uuid.uuid4(),
+        external_metadata="",
         mime_type="",
     )
     with patch.object(AudioDocument, "create_transcript", return_value=TEST_TEXT):

@@ -18,7 +18,7 @@ def test_ImageDocument():
         id=uuid.uuid4(),
         name="image-dummy-test",
         raw_data_location="",
-        metadata_id=uuid.uuid4(),
+        external_metadata="",
         mime_type="",
     )
     with patch.object(ImageDocument, "transcribe_image", return_value=TEST_TEXT):

@@ -20,7 +20,7 @@ def test_PdfDocument():
         id=uuid.uuid4(),
         name="Test document.pdf",
         raw_data_location=test_file_path,
-        metadata_id=uuid.uuid4(),
+        external_metadata="",
         mime_type="",
     )

@@ -32,7 +32,7 @@ def test_TextDocument(input_file, chunk_size):
         id=uuid.uuid4(),
         name=input_file,
         raw_data_location=test_file_path,
-        metadata_id=uuid.uuid4(),
+        external_metadata="",
         mime_type="",
     )

@@ -39,7 +39,7 @@ def test_UnstructuredDocument():
         id=uuid.uuid4(),
         name="example.pptx",
         raw_data_location=pptx_file_path,
-        metadata_id=uuid.uuid4(),
+        external_metadata="",
         mime_type="application/vnd.openxmlformats-officedocument.presentationml.presentation",
     )

@@ -47,7 +47,7 @@ def test_UnstructuredDocument():
         id=uuid.uuid4(),
         name="example.docx",
         raw_data_location=docx_file_path,
-        metadata_id=uuid.uuid4(),
+        external_metadata="",
         mime_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
     )

@@ -55,7 +55,7 @@ def test_UnstructuredDocument():
         id=uuid.uuid4(),
         name="example.csv",
         raw_data_location=csv_file_path,
-        metadata_id=uuid.uuid4(),
+        external_metadata="",
         mime_type="text/csv",
     )

@@ -63,7 +63,7 @@ def test_UnstructuredDocument():
         id=uuid.uuid4(),
         name="example.xlsx",
         raw_data_location=xlsx_file_path,
-        metadata_id=uuid.uuid4(),
+        external_metadata="",
         mime_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
     )

@@ -10,12 +10,29 @@ from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
 from evals.qa_dataset_utils import load_qa_dataset
 from evals.qa_metrics_utils import get_metrics
 from evals.qa_context_provider_utils import qa_context_providers, valid_pipeline_slices
+import random
+import os
+import json
+from pathlib import Path

 logger = logging.getLogger(__name__)


-async def answer_qa_instance(instance, context_provider):
-    context = await context_provider(instance)
+async def answer_qa_instance(instance, context_provider, contexts_filename):
+    if os.path.exists(contexts_filename):
+        with open(contexts_filename, "r") as file:
+            preloaded_contexts = json.load(file)
+    else:
+        preloaded_contexts = {}
+
+    if instance["_id"] in preloaded_contexts:
+        context = preloaded_contexts[instance["_id"]]
+    else:
+        context = await context_provider(instance)
+        preloaded_contexts[instance["_id"]] = context
+
+    with open(contexts_filename, "w") as file:
+        json.dump(preloaded_contexts, file)

     args = {
         "question": instance["question"],
@@ -49,12 +66,27 @@ async def deepeval_answers(instances, answers, eval_metrics):
     return eval_results


-async def deepeval_on_instances(instances, context_provider, eval_metrics):
+async def deepeval_on_instances(
+    instances, context_provider, eval_metrics, answers_filename, contexts_filename
+):
+    if os.path.exists(answers_filename):
+        with open(answers_filename, "r") as file:
+            preloaded_answers = json.load(file)
+    else:
+        preloaded_answers = {}
+
     answers = []
     for instance in tqdm(instances, desc="Getting answers"):
-        answer = await answer_qa_instance(instance, context_provider)
+        if instance["_id"] in preloaded_answers:
+            answer = preloaded_answers[instance["_id"]]
+        else:
+            answer = await answer_qa_instance(instance, context_provider, contexts_filename)
+            preloaded_answers[instance["_id"]] = answer
         answers.append(answer)

+    with open(answers_filename, "w") as file:
+        json.dump(preloaded_answers, file)
+
     eval_results = await deepeval_answers(instances, answers, eval_metrics)
     score_lists_dict = {}
     for instance_result in eval_results.test_results:
@@ -72,21 +104,38 @@ async def deepeval_on_instances(instances, context_provider, eval_metrics):


 async def eval_on_QA_dataset(
-    dataset_name_or_filename: str, context_provider_name, num_samples, metric_name_list
+    dataset_name_or_filename: str, context_provider_name, num_samples, metric_name_list, out_path
 ):
     dataset = load_qa_dataset(dataset_name_or_filename)
     context_provider = qa_context_providers[context_provider_name]
     eval_metrics = get_metrics(metric_name_list)
-    instances = dataset if not num_samples else dataset[:num_samples]

+    out_path = Path(out_path)
+    if not out_path.exists():
+        out_path.mkdir(parents=True, exist_ok=True)
+
+    random.seed(42)
+    instances = dataset if not num_samples else random.sample(dataset, num_samples)
+
+    contexts_filename = out_path / Path(
+        f"contexts_{dataset_name_or_filename.split('.')[0]}_{context_provider_name}.json"
+    )
     if "promptfoo_metrics" in eval_metrics:
         promptfoo_results = await eval_metrics["promptfoo_metrics"].measure(
-            instances, context_provider
+            instances, context_provider, contexts_filename
         )
     else:
         promptfoo_results = {}

+    answers_filename = out_path / Path(
+        f"answers_{dataset_name_or_filename.split('.')[0]}_{context_provider_name}.json"
+    )
     deepeval_results = await deepeval_on_instances(
-        instances, context_provider, eval_metrics["deepeval_metrics"]
+        instances,
+        context_provider,
+        eval_metrics["deepeval_metrics"],
+        answers_filename,
+        contexts_filename,
     )

     results = promptfoo_results | deepeval_results
@@ -95,14 +144,14 @@ async def eval_on_QA_dataset(


 async def incremental_eval_on_QA_dataset(
-    dataset_name_or_filename: str, num_samples, metric_name_list
+    dataset_name_or_filename: str, num_samples, metric_name_list, out_path
 ):
     pipeline_slice_names = valid_pipeline_slices.keys()

     incremental_results = {}
     for pipeline_slice_name in pipeline_slice_names:
         results = await eval_on_QA_dataset(
-            dataset_name_or_filename, pipeline_slice_name, num_samples, metric_name_list
+            dataset_name_or_filename, pipeline_slice_name, num_samples, metric_name_list, out_path
         )
         incremental_results[pipeline_slice_name] = results

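The caching idea used throughout these eval changes, condensed into a generic hedged sketch: look a value up by the instance _id, compute it on a miss, and persist the cache back to the same JSON file. The names here are generic, not the module's own.

import json
import os


def cached_or_compute(cache_file: str, key: str, compute):
    # Load the JSON cache if present, compute and persist on a miss.
    cache = {}
    if os.path.exists(cache_file):
        with open(cache_file, "r") as file:
            cache = json.load(file)
    if key not in cache:
        cache[key] = compute()
        with open(cache_file, "w") as file:
            json.dump(cache, file)
    return cache[key]
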
@@ -29,7 +29,7 @@ class PromptfooMetric:
         else:
             raise Exception(f"{metric_name} is not a valid promptfoo metric")

-    async def measure(self, instances, context_provider):
+    async def measure(self, instances, context_provider, contexts_filename):
         with open(os.path.join(os.getcwd(), "evals/promptfoo_config_template.yaml"), "r") as file:
             config = yaml.safe_load(file)

@@ -40,10 +40,20 @@ class PromptfooMetric:
             ]
         }

         # Fill config file with test cases
         tests = []
+        if os.path.exists(contexts_filename):
+            with open(contexts_filename, "r") as file:
+                preloaded_contexts = json.load(file)
+        else:
+            preloaded_contexts = {}
+
         for instance in instances:
-            context = await context_provider(instance)
+            if instance["_id"] in preloaded_contexts:
+                context = preloaded_contexts[instance["_id"]]
+            else:
+                context = await context_provider(instance)
+                preloaded_contexts[instance["_id"]] = context
+
             test = {
                 "vars": {
                     "name": instance["question"][:15],
@@ -52,7 +62,10 @@ class PromptfooMetric:
                 }
             }
             tests.append(test)

         config["tests"] = tests
+        with open(contexts_filename, "w") as file:
+            json.dump(preloaded_contexts, file)

         # Write the updated YAML back, preserving formatting and structure
         updated_yaml_file_path = os.path.join(os.getcwd(), "config_with_context.yaml")

@@ -39,10 +39,22 @@ def _insight_to_string(triplet: tuple) -> str:
         return ""

     node1_name = node1["name"] if "name" in node1 else "N/A"
-    node1_description = node1["description"] if "description" in node1 else node1["text"]
+    node1_description = (
+        node1["description"]
+        if "description" in node1
+        else node1["text"]
+        if "text" in node1
+        else "N/A"
+    )
     node1_string = f"name: {node1_name}, description: {node1_description}"
     node2_name = node2["name"] if "name" in node2 else "N/A"
-    node2_description = node2["description"] if "description" in node2 else node2["text"]
+    node2_description = (
+        node2["description"]
+        if "description" in node2
+        else node2["text"]
+        if "text" in node2
+        else "N/A"
+    )
     node2_string = f"name: {node2_name}, description: {node2_description}"

     edge_string = edge.get("relationship_name", "")
@@ -58,7 +70,7 @@ def _insight_to_string(triplet: tuple) -> str:
 async def get_context_with_cognee(
     instance: dict,
     task_indices: list[int] = None,
-    search_types: list[SearchType] = [SearchType.SUMMARIES, SearchType.CHUNKS],
+    search_types: list[SearchType] = [SearchType.INSIGHTS, SearchType.SUMMARIES, SearchType.CHUNKS],
 ) -> str:
     await cognify_instance(instance, task_indices)

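The nested conditional added above reads more easily as a dict.get chain; a small equivalent sketch (the helper name is invented):

def describe_node(node: dict) -> str:
    # Same fallback order as the diff: description, then text, then "N/A".
    name = node.get("name", "N/A")
    description = node.get("description", node.get("text", "N/A"))
    return f"name: {name}, description: {description}"
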
@@ -14,6 +14,10 @@
     ],
     "metric_names": [
         "Correctness",
-        "Comprehensiveness"
+        "Comprehensiveness",
+        "Directness",
+        "Diversity",
+        "Empowerment",
+        "promptfoo.directness"
     ]
 }

@@ -22,17 +22,12 @@ async def run_evals_on_paramset(paramset: dict, out_path: str):

             if rag_option == "cognee_incremental":
                 result = await incremental_eval_on_QA_dataset(
-                    dataset,
-                    num_samples,
-                    paramset["metric_names"],
+                    dataset, num_samples, paramset["metric_names"], out_path
                 )
                 results[dataset][num_samples] |= result
             else:
                 result = await eval_on_QA_dataset(
-                    dataset,
-                    rag_option,
-                    num_samples,
-                    paramset["metric_names"],
+                    dataset, rag_option, num_samples, paramset["metric_names"], out_path
                 )
                 results[dataset][num_samples][rag_option] = result

@@ -118,10 +118,10 @@
     ]
    },
    {
-    "cell_type": "code",
-    "execution_count": null,
     "metadata": {},
+    "cell_type": "code",
     "outputs": [],
+    "execution_count": null,
     "source": [
      "from typing import Union, BinaryIO\n",
      "\n",
@@ -133,7 +133,7 @@
      ")\n",
      "from cognee.modules.users.models import User\n",
      "from cognee.modules.users.methods import get_default_user\n",
-     "from cognee.tasks.ingestion.ingest_data_with_metadata import ingest_data_with_metadata\n",
+     "from cognee.tasks.ingestion.ingest_data import ingest_data\n",
      "import cognee\n",
      "\n",
      "# Create a clean slate for cognee -- reset data and system state\n",
@@ -153,7 +153,7 @@
      "    if user is None:\n",
      "        user = await get_default_user()\n",
      "\n",
-     "    await ingest_data_with_metadata(data, dataset_name, user)\n",
+     "    await ingest_data(data, dataset_name, user)\n",
      "\n",
      "\n",
      "await add(documents)\n",