Feat/cog 1365 unify retrievers (#572)
## Description

- Created the `BaseRetriever` class to unify all the retrievers and searches.
- Implemented seven specialized retrievers (summaries, chunks, completions, graph, graph-summary, insights, code) with consistent `get_context`/`get_completion` interfaces.
- Added a JSON context-dumping feature to the current completion implementations to enable context comparisons.
- Built a comparison framework to validate the old implementations against the new ones.

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

## Summary by CodeRabbit

- **New Features**
  - Introduced multiple retrieval classes for enhanced search capabilities, including `BaseRetriever`, `ChunksRetriever`, `CodeRetriever`, `CompletionRetriever`, `GraphCompletionRetriever`, `GraphSummaryCompletionRetriever`, `InsightsRetriever`, and `SummariesRetriever`.
  - Enhanced query completions with optional context saving for improved data persistence.
  - Implemented tools to compare retrieval outcomes across the old and new implementations.
- **Refactor**
  - Streamlined internal module organization and updated references for increased maintainability and consistency.
  - Added comments indicating future maintenance tasks related to code merging.

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
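For orientation, here is a small usage sketch (not part of the diff below) of how the unified interface introduced in this PR is meant to be consumed. It assumes data has already been cognified and that the vector/graph backends are configured; the query string is only an example.

```python
# Hypothetical usage sketch: every retriever in this PR exposes the same
# BaseRetriever interface, so callers can swap implementations freely.
import asyncio

from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever


async def demo():
    query = "Who has experience in data science?"
    for retriever in (ChunksRetriever(), GraphCompletionRetriever(top_k=5)):
        # get_context fetches the raw context; get_completion turns it into an answer
        # (or returns the context directly for the non-LLM retrievers).
        context = await retriever.get_context(query)
        answer = await retriever.get_completion(query, context)
        print(type(retriever).__name__, answer)


if __name__ == "__main__":
    # Requires a configured cognee vector/graph backend and ingested data.
    asyncio.run(demo())
```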
Parent: 86b34657aa
Commit: 9cc357ac1c

31 changed files with 740 additions and 31 deletions
@@ -1 +1 @@
-from .code_graph_retrieval import code_graph_retrieval
+from cognee.modules.retrieval.utils.code_graph_retrieval import code_graph_retrieval
cognee/modules/retrieval/base_retriever.py (new file, 16 lines)
@@ -0,0 +1,16 @@
from abc import ABC, abstractmethod
from typing import Any, Optional


class BaseRetriever(ABC):
    """Base class for all retrieval operations."""

    @abstractmethod
    async def get_context(self, query: str) -> Any:
        """Retrieves context based on the query."""
        pass

    @abstractmethod
    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Generates a response using the query and optional context."""
        pass
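As an illustration of the extension point this base class creates, a hypothetical custom retriever might look like the sketch below; the `Note_text` collection name is invented for the example and is not part of this PR. The concrete retrievers that follow in this diff all use the same pattern.

```python
# Hypothetical subclass sketch (not in this PR); mirrors the pattern of the
# retrievers added below, but against a made-up "Note_text" vector collection.
from typing import Any, Optional

from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever


class MyNotesRetriever(BaseRetriever):
    """Example retriever for a hypothetical notes collection."""

    def __init__(self, limit: int = 5):
        self.limit = limit

    async def get_context(self, query: str) -> Any:
        vector_engine = get_vector_engine()
        results = await vector_engine.search("Note_text", query, limit=self.limit)
        return [result.payload for result in results]

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        if context is None:
            context = await self.get_context(query)
        return context
```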
cognee/modules/retrieval/chunks_retriever.py (new file, 20 lines)
@@ -0,0 +1,20 @@
from typing import Any, Optional

from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever


class ChunksRetriever(BaseRetriever):
    """Retriever for handling document chunk-based searches."""

    async def get_context(self, query: str) -> Any:
        """Retrieves document chunks context based on the query."""
        vector_engine = get_vector_engine()
        found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=5)
        return [result.payload for result in found_chunks]

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Generates a completion using document chunks context."""
        if context is None:
            context = await self.get_context(query)
        return context
cognee/modules/retrieval/code_retriever.py (new file, 146 lines)
@@ -0,0 +1,146 @@
from typing import Any, Optional, List, Dict
import asyncio
import aiofiles
from pydantic import BaseModel

from cognee.low_level import DataPoint
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt


class CodeRetriever(BaseRetriever):
    """Retriever for handling code-based searches."""

    class CodeQueryInfo(BaseModel):
        """Response model for information extraction from the query"""

        filenames: List[str] = []
        sourcecode: str

    def __init__(self, limit: int = 3):
        """Initialize retriever with search parameters."""
        self.limit = limit
        self.file_name_collections = ["CodeFile_name"]
        self.classes_and_functions_collections = [
            "ClassDefinition_source_code",
            "FunctionDefinition_source_code",
        ]

    async def _process_query(self, query: str) -> "CodeRetriever.CodeQueryInfo":
        """Process the query using LLM to extract file names and source code parts."""
        system_prompt = read_query_prompt("codegraph_retriever_system.txt")
        llm_client = get_llm_client()
        try:
            return await llm_client.acreate_structured_output(
                text_input=query,
                system_prompt=system_prompt,
                response_model=self.CodeQueryInfo,
            )
        except Exception as e:
            raise RuntimeError("Failed to retrieve structured output from LLM") from e

    async def get_context(self, query: str) -> Any:
        """Find relevant code files based on the query."""
        if not query or not isinstance(query, str):
            raise ValueError("The query must be a non-empty string.")

        try:
            vector_engine = get_vector_engine()
            graph_engine = await get_graph_engine()
        except Exception as e:
            raise RuntimeError("Database initialization error in code_graph_retriever, ") from e

        files_and_codeparts = await self._process_query(query)

        similar_filenames = []
        similar_codepieces = []

        if not files_and_codeparts.filenames or not files_and_codeparts.sourcecode:
            for collection in self.file_name_collections:
                search_results_file = await vector_engine.search(
                    collection, query, limit=self.limit
                )
                for res in search_results_file:
                    similar_filenames.append(
                        {"id": res.id, "score": res.score, "payload": res.payload}
                    )

            for collection in self.classes_and_functions_collections:
                search_results_code = await vector_engine.search(
                    collection, query, limit=self.limit
                )
                for res in search_results_code:
                    similar_codepieces.append(
                        {"id": res.id, "score": res.score, "payload": res.payload}
                    )
        else:
            for collection in self.file_name_collections:
                for file_from_query in files_and_codeparts.filenames:
                    search_results_file = await vector_engine.search(
                        collection, file_from_query, limit=self.limit
                    )
                    for res in search_results_file:
                        similar_filenames.append(
                            {"id": res.id, "score": res.score, "payload": res.payload}
                        )

            for collection in self.classes_and_functions_collections:
                search_results_code = await vector_engine.search(
                    collection, files_and_codeparts.sourcecode, limit=self.limit
                )
                for res in search_results_code:
                    similar_codepieces.append(
                        {"id": res.id, "score": res.score, "payload": res.payload}
                    )

        file_ids = [str(item["id"]) for item in similar_filenames]
        code_ids = [str(item["id"]) for item in similar_codepieces]

        relevant_triplets = await asyncio.gather(
            *[graph_engine.get_connections(node_id) for node_id in code_ids + file_ids]
        )

        paths = set()
        for sublist in relevant_triplets:
            for tpl in sublist:
                if isinstance(tpl, tuple) and len(tpl) >= 3:
                    if "file_path" in tpl[0]:
                        paths.add(tpl[0]["file_path"])
                    if "file_path" in tpl[2]:
                        paths.add(tpl[2]["file_path"])

        retrieved_files = {}
        read_tasks = []
        for file_path in paths:

            async def read_file(fp):
                try:
                    async with aiofiles.open(fp, "r", encoding="utf-8") as f:
                        retrieved_files[fp] = await f.read()
                except Exception as e:
                    print(f"Error reading {fp}: {e}")
                    retrieved_files[fp] = ""

            read_tasks.append(read_file(file_path))

        await asyncio.gather(*read_tasks)

        return [
            {
                "name": file_path,
                "description": file_path,
                "content": retrieved_files[file_path],
            }
            for file_path in paths
        ]

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Returns the code files context."""
        if context is None:
            context = await self.get_context(query)
        return context
cognee/modules/retrieval/completion_retriever.py (new file, 40 lines)
@@ -0,0 +1,40 @@
from typing import Any, Optional

from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.tasks.completion.exceptions import NoRelevantDataFound


class CompletionRetriever(BaseRetriever):
    """Retriever for handling LLM-based completion searches."""

    def __init__(
        self,
        user_prompt_path: str = "context_for_question.txt",
        system_prompt_path: str = "answer_simple_question.txt",
    ):
        """Initialize retriever with optional custom prompt paths."""
        self.user_prompt_path = user_prompt_path
        self.system_prompt_path = system_prompt_path

    async def get_context(self, query: str) -> Any:
        """Retrieves relevant document chunks as context."""
        vector_engine = get_vector_engine()
        found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=1)
        if len(found_chunks) == 0:
            raise NoRelevantDataFound
        return found_chunks[0].payload["text"]

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Generates an LLM completion using the context."""
        if context is None:
            context = await self.get_context(query)

        completion = await generate_completion(
            query=query,
            context=context,
            user_prompt_path=self.user_prompt_path,
            system_prompt_path=self.system_prompt_path,
        )
        return [completion]
cognee/modules/retrieval/graph_completion_retriever.py (new file, 71 lines)
@@ -0,0 +1,71 @@
from typing import Any, Optional

from cognee.infrastructure.engine import ExtendableDataPoint
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.tasks.completion.exceptions import NoRelevantDataFound


class GraphCompletionRetriever(BaseRetriever):
    """Retriever for handling graph-based completion searches."""

    def __init__(
        self,
        user_prompt_path: str = "graph_context_for_question.txt",
        system_prompt_path: str = "answer_simple_question.txt",
        top_k: int = 5,
    ):
        """Initialize retriever with prompt paths and search parameters."""
        self.user_prompt_path = user_prompt_path
        self.system_prompt_path = system_prompt_path
        self.top_k = top_k

    async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
        """Converts retrieved graph edges into a human-readable string format."""
        edge_strings = []
        for edge in retrieved_edges:
            node1_string = edge.node1.attributes.get("text") or edge.node1.attributes.get("name")
            node2_string = edge.node2.attributes.get("text") or edge.node2.attributes.get("name")
            edge_string = edge.attributes["relationship_type"]
            edge_str = f"{node1_string} -- {edge_string} -- {node2_string}"
            edge_strings.append(edge_str)
        return "\n---\n".join(edge_strings)

    async def get_triplets(self, query: str) -> list:
        """Retrieves relevant graph triplets."""
        subclasses = get_all_subclasses(ExtendableDataPoint)
        vector_index_collections = []

        for subclass in subclasses:
            index_fields = subclass.model_fields["metadata"].default.get("index_fields", [])
            for field_name in index_fields:
                vector_index_collections.append(f"{subclass.__name__}_{field_name}")

        found_triplets = await brute_force_triplet_search(
            query, top_k=self.top_k, collections=vector_index_collections or None
        )

        if len(found_triplets) == 0:
            raise NoRelevantDataFound

        return found_triplets

    async def get_context(self, query: str) -> Any:
        """Retrieves and resolves graph triplets into context."""
        triplets = await self.get_triplets(query)
        return await self.resolve_edges_to_text(triplets)

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Generates a completion using graph connections context."""
        if context is None:
            context = await self.get_context(query)

        completion = await generate_completion(
            query=query,
            context=context,
            user_prompt_path=self.user_prompt_path,
            system_prompt_path=self.system_prompt_path,
        )
        return [completion]
cognee/modules/retrieval/graph_summary_completion_retriever.py (new file, 36 lines)
@@ -0,0 +1,36 @@
from typing import Optional

from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever


class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
    """Retriever for handling graph-based completion searches with summarized context."""

    def __init__(
        self,
        user_prompt_path: str = "graph_context_for_question.txt",
        system_prompt_path: str = "answer_simple_question.txt",
        summarize_prompt_path: str = "summarize_search_results.txt",
        top_k: int = 5,
    ):
        """Initialize retriever with default prompt paths and search parameters."""
        super().__init__(
            user_prompt_path=user_prompt_path,
            system_prompt_path=system_prompt_path,
            top_k=top_k,
        )
        self.summarize_prompt_path = summarize_prompt_path

    async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
        """Converts retrieved graph edges into a summary without redundancies."""
        direct_text = await super().resolve_edges_to_text(retrieved_edges)
        system_prompt = read_query_prompt(self.summarize_prompt_path)

        llm_client = get_llm_client()
        return await llm_client.acreate_structured_output(
            text_input=direct_text,
            system_prompt=system_prompt,
            response_model=str,
        )
cognee/modules/retrieval/insights_retriever.py (new file, 66 lines)
@@ -0,0 +1,66 @@
import asyncio
from typing import Any, Optional

from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever


class InsightsRetriever(BaseRetriever):
    """Retriever for handling graph connection-based insights."""

    def __init__(self, exploration_levels: int = 1, top_k: int = 5):
        """Initialize retriever with exploration levels and search parameters."""
        self.exploration_levels = exploration_levels
        self.top_k = top_k

    async def get_context(self, query: str) -> Any:
        """Find the neighbours of a given node in the graph."""
        if query is None:
            return []

        node_id = query
        graph_engine = await get_graph_engine()
        exact_node = await graph_engine.extract_node(node_id)

        if exact_node is not None and "id" in exact_node:
            node_connections = await graph_engine.get_connections(str(exact_node["id"]))
        else:
            vector_engine = get_vector_engine()
            results = await asyncio.gather(
                vector_engine.search("Entity_name", query_text=query, limit=self.top_k),
                vector_engine.search("EntityType_name", query_text=query, limit=self.top_k),
            )
            results = [*results[0], *results[1]]
            relevant_results = [result for result in results if result.score < 0.5][: self.top_k]

            if len(relevant_results) == 0:
                return []

            node_connections_results = await asyncio.gather(
                *[graph_engine.get_connections(result.id) for result in relevant_results]
            )

            node_connections = []
            for neighbours in node_connections_results:
                node_connections.extend(neighbours)

        unique_node_connections_map = {}
        unique_node_connections = []

        for node_connection in node_connections:
            if "id" not in node_connection[0] or "id" not in node_connection[2]:
                continue

            unique_id = f"{node_connection[0]['id']} {node_connection[1]['relationship_name']} {node_connection[2]['id']}"
            if unique_id not in unique_node_connections_map:
                unique_node_connections_map[unique_id] = True
                unique_node_connections.append(node_connection)

        return unique_node_connections

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Returns the graph connections context."""
        if context is None:
            context = await self.get_context(query)
        return context
cognee/modules/retrieval/summaries_retriever.py (new file, 24 lines)
@@ -0,0 +1,24 @@
from typing import Any, Optional

from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever


class SummariesRetriever(BaseRetriever):
    """Retriever for handling summary-based searches."""

    def __init__(self, limit: int = 5):
        """Initialize retriever with search parameters."""
        self.limit = limit

    async def get_context(self, query: str) -> Any:
        """Retrieves summary context based on the query."""
        vector_engine = get_vector_engine()
        summaries_results = await vector_engine.search("TextSummary_text", query, limit=self.limit)
        return [summary.payload for summary in summaries_results]

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Generates a completion using summaries context."""
        if context is None:
            context = await self.get_context(query)
        return context
cognee/modules/retrieval/utils/__init__.py (new empty file)
cognee/modules/retrieval/utils/completion.py (new file, 23 lines)
@@ -0,0 +1,23 @@
from typing import Optional

from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt


async def generate_completion(
    query: str,
    context: str,
    user_prompt_path: str,
    system_prompt_path: str,
) -> str:
    """Generates a completion using LLM with given context and prompts."""
    args = {"question": query, "context": context}
    user_prompt = render_prompt(user_prompt_path, args)
    system_prompt = read_query_prompt(system_prompt_path)

    llm_client = get_llm_client()
    return await llm_client.acreate_structured_output(
        text_input=user_prompt,
        system_prompt=system_prompt,
        response_model=str,
    )
cognee/modules/retrieval/utils/run_search_comparisons.py (new file, 221 lines)
@@ -0,0 +1,221 @@
# TODO: delete after merging COG-1365, see COG-1403
import asyncio
import json
import logging
import os
from typing import Any, Callable, Dict, Type

from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
from cognee.modules.retrieval.code_retriever import CodeRetriever
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.graph_summary_completion_retriever import (
    GraphSummaryCompletionRetriever,
)
from cognee.modules.retrieval.insights_retriever import InsightsRetriever
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
from cognee.modules.retrieval.utils.code_graph_retrieval import code_graph_retrieval
from cognee.tasks.chunks import query_chunks
from cognee.tasks.completion import (
    query_completion,
    graph_query_completion,
    graph_query_summary_completion,
)
from cognee.tasks.graph import query_graph_connections
from cognee.tasks.summarization import query_summaries
from examples.python.dynamic_steps_example import main as setup_main


CONTEXT_DUMP_DIR = "context_dumps"

# Define retriever configurations
COMPLETION_RETRIEVERS = [
    {
        "name": "completion",
        "old_implementation": query_completion,
        "retriever_class": CompletionRetriever,
        "type": "completion",
    },
    {
        "name": "graph completion",
        "old_implementation": graph_query_completion,
        "retriever_class": GraphCompletionRetriever,
        "type": "graph_completion",
    },
    {
        "name": "graph summary completion",
        "old_implementation": graph_query_summary_completion,
        "retriever_class": GraphSummaryCompletionRetriever,
        "type": "graph_summary_completion",
    },
]

BASIC_RETRIEVERS = [
    {
        "name": "summaries search",
        "old_implementation": query_summaries,
        "retriever_class": SummariesRetriever,
    },
    {
        "name": "chunks search",
        "old_implementation": query_chunks,
        "retriever_class": ChunksRetriever,
    },
    {
        "name": "insights search",
        "old_implementation": query_graph_connections,
        "retriever_class": InsightsRetriever,
    },
    {
        "name": "code search",
        "old_implementation": code_graph_retrieval,
        "retriever_class": CodeRetriever,
    },
]


async def compare_completion(old_results: list, new_results: list) -> Dict:
    """Compare two lists of completion results and print differences."""
    lengths_match = len(old_results) == len(new_results)
    matches = []

    if lengths_match:
        print("Results length match")
        matches = [old == new for old, new in zip(old_results, new_results)]
        if all(matches):
            print("All entries match")
        else:
            print(f"Differences found at indices: {[i for i, m in enumerate(matches) if not m]}")
            print("\nDifferences:")
            for i, (old, new) in enumerate(zip(old_results, new_results)):
                if old != new:
                    print(f"\nIndex {i}:")
                    print("Old:", json.dumps(old, indent=2))
                    print("New:", json.dumps(new, indent=2))
    else:
        print(f"Results length mismatch: {len(old_results)} vs {len(new_results)}")
        print("\nOld results:", json.dumps(old_results, indent=2))
        print("\nNew results:", json.dumps(new_results, indent=2))

    return {
        "old_results": old_results,
        "new_results": new_results,
        "lengths_match": lengths_match,
        "element_matches": matches,
    }


async def compare_retriever(
    query: str, old_implementation: Callable, new_retriever: Any, name: str
) -> Dict:
    """Compare old and new retriever implementations."""
    print(f"\nComparing {name}...")

    # Get results from both implementations
    old_results = await old_implementation(query)
    new_results = await new_retriever.get_completion(query)

    return await compare_completion(old_results, new_results)


async def compare_completion_context(
    query: str, old_implementation: Callable, retriever_class: Type, name: str, retriever_type: str
) -> Dict:
    """Compare context between old completion implementation and new retriever."""
    print(f"\nComparing {name} contexts...")

    # Get context from old implementation with dumping
    context_path = f"{CONTEXT_DUMP_DIR}/{retriever_type}_{hash(query)}_context.json"
    os.makedirs(CONTEXT_DUMP_DIR, exist_ok=True)
    await old_implementation(query, save_context_path=context_path)

    # Get context from new implementation
    retriever = retriever_class()
    new_context = await retriever.get_context(query)

    # Read dumped context
    with open(context_path, "r") as f:
        old_context = json.load(f)

    # Compare contexts
    contexts_match = old_context == new_context
    if contexts_match:
        print("Contexts match exactly")
    else:
        print("Contexts differ:")
        print("\nOld context:", json.dumps(old_context, indent=2))
        print("\nNew context:", json.dumps(new_context, indent=2))

    return {
        "old_context": old_context,
        "new_context": new_context,
        "contexts_match": contexts_match,
    }


async def main(query: str, comparisons: Dict[str, bool], setup_steps: Dict[str, bool]):
    """Run comparison tests for selected retrievers with the given setup configuration."""
    # Ensure retriever is always False in setup steps
    setup_steps["retriever"] = False
    await setup_main(setup_steps)

    # Compare contexts for completion-based retrievers
    for retriever in COMPLETION_RETRIEVERS:
        context_key = f"{retriever['type']}_context"
        if comparisons.get(context_key, False):
            await compare_completion_context(
                query=query,
                old_implementation=retriever["old_implementation"],
                retriever_class=retriever["retriever_class"],
                name=retriever["name"],
                retriever_type=retriever["type"],
            )

    # Run completion comparisons
    for retriever in COMPLETION_RETRIEVERS:
        if comparisons.get(retriever["type"], False):
            await compare_retriever(
                query=query,
                old_implementation=retriever["old_implementation"],
                new_retriever=retriever["retriever_class"](),
                name=retriever["name"],
            )

    # Run basic retriever comparisons
    for retriever in BASIC_RETRIEVERS:
        retriever_type = retriever["name"].split()[0]
        if comparisons.get(retriever_type, False):
            await compare_retriever(
                query=query,
                old_implementation=retriever["old_implementation"],
                new_retriever=retriever["retriever_class"](),
                name=retriever["name"],
            )


if __name__ == "__main__":
    logging.basicConfig(level=logging.ERROR)

    test_query = "Who has experience in data science?"
    comparisons = {
        # Context comparisons
        "completion_context": True,
        "graph_completion_context": True,
        "graph_summary_completion_context": True,
        # Result comparisons
        "summaries": True,
        "chunks": True,
        "insights": True,
        "code": False,
        "completion": True,
        "graph_completion": True,
        "graph_summary_completion": True,
    }
    setup_steps = {
        "prune_data": True,
        "prune_system": True,
        "add_text": True,
        "cognify": True,
    }

    asyncio.run(main(test_query, comparisons, setup_steps))
@@ -3,7 +3,7 @@ from typing import Callable

 from cognee.exceptions import InvalidValueError
 from cognee.infrastructure.engine.utils import parse_id
-from cognee.modules.retrieval.code_graph_retrieval import code_graph_retrieval
+from cognee.modules.retrieval.utils.code_graph_retrieval import code_graph_retrieval
 from cognee.modules.search.types import SearchType
 from cognee.modules.storage.utils import JSONEncoder
 from cognee.modules.users.models import User

@@ -44,6 +44,7 @@ async def search(


 async def specific_search(query_type: SearchType, query: str, user: User) -> list:
+    # TODO: update after merging COG-1365, see COG-1403
     search_tasks: dict[SearchType, Callable] = {
         SearchType.SUMMARIES: query_summaries,
         SearchType.INSIGHTS: query_graph_connections,
@@ -1,3 +1,4 @@
+# TODO: delete after merging COG-1365, see COG-1403
 from cognee.infrastructure.databases.vector import get_vector_engine


@@ -1,13 +1,20 @@
+# TODO: delete after merging COG-1365, see COG-1403
+import json
+import logging
+import os
 from cognee.infrastructure.engine import ExtendableDataPoint
 from cognee.infrastructure.engine.models.DataPoint import DataPoint
 from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
 from cognee.tasks.completion.exceptions import NoRelevantDataFound
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
 from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
-from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
+from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
 from typing import Callable


+logger = logging.getLogger(__name__)
+
+
 async def retrieved_edges_to_string(retrieved_edges: list) -> str:
     """
     Converts a list of retrieved graph edges into a human-readable string format.

@@ -23,12 +30,16 @@ async def retrieved_edges_to_string(retrieved_edges: list) -> str:
     return "\n---\n".join(edge_strings)


-async def graph_query_completion(query: str, context_resolver: Callable = None) -> list:
+async def graph_query_completion(
+    query: str, context_resolver: Callable = None, save_context_path: str = None
+) -> list:
     """
     Executes a query on the graph database and retrieves a relevant completion based on the found data.

     Parameters:
     - query (str): The query string to compute.
+    - context_resolver (Callable): A function to convert retrieved edges to a string.
+    - save_context_path (str): Path to save the retrieved context.

     Returns:
     - list: Answer to the query.

@@ -38,7 +49,6 @@ async def graph_query_completion(query: str, context_resolver: Callable = None)
     - Prompts are dynamically rendered and provided to the LLM for contextual understanding.
     - Ensure that the LLM client and graph database are properly configured and accessible.
     """

     subclasses = get_all_subclasses(DataPoint)

     vector_index_collections = []

@@ -58,9 +68,19 @@
     if not context_resolver:
         context_resolver = retrieved_edges_to_string

+    # Get context and optionally dump it
+    context = await context_resolver(found_triplets)
+    if save_context_path:
+        try:
+            os.makedirs(os.path.dirname(save_context_path), exist_ok=True)
+            with open(save_context_path, "w") as f:
+                json.dump(context, f, indent=2)
+        except (OSError, TypeError, ValueError) as e:
+            logger.error(f"Failed to save context to {save_context_path}: {str(e)}")
+            # Consider whether to raise or continue silently
     args = {
         "question": query,
-        "context": await context_resolver(found_triplets),
+        "context": context,
     }
     user_prompt = render_prompt("graph_context_for_question.txt", args)
     system_prompt = read_query_prompt("answer_simple_question.txt")
@@ -1,3 +1,4 @@
+# TODO: delete after merging COG-1365, see COG-1403
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
 from cognee.infrastructure.llm.prompts import read_query_prompt
 from cognee.tasks.completion.graph_query_completion import (

@@ -22,5 +23,8 @@ async def retrieved_edges_to_summary(retrieved_edges: list) -> str:
     return summarized_context


-async def graph_query_summary_completion(query: str) -> list:
-    return await graph_query_completion(query, context_resolver=retrieved_edges_to_summary)
+async def graph_query_summary_completion(query: str, save_context_path: str = None) -> list:
+    """Executes a query on the graph database and retrieves a summarized completion with optional context saving."""
+    return await graph_query_completion(
+        query, context_resolver=retrieved_edges_to_summary, save_context_path=save_context_path
+    )
@@ -1,16 +1,24 @@
+# TODO: delete after merging COG-1365, see COG-1403
+import json
+import logging
+import os
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.tasks.completion.exceptions import NoRelevantDataFound
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
 from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt


-async def query_completion(query: str) -> list:
+logger = logging.getLogger(__name__)
+
+
+async def query_completion(query: str, save_context_path: str = None) -> list:
     """

     Executes a query against a vector database and computes a relevant response using an LLM.

     Parameters:
     - query (str): The query string to compute.
+    - save_context_path (str): The path to save the context.

     Returns:
     - list: Answer to the query.

@@ -28,9 +36,19 @@
     if len(found_chunks) == 0:
         raise NoRelevantDataFound

+    # Get context and optionally dump it
+    context = found_chunks[0].payload["text"]
+    if save_context_path:
+        try:
+            os.makedirs(os.path.dirname(save_context_path), exist_ok=True)
+            with open(save_context_path, "w", encoding="utf-8") as f:
+                json.dump(context, f, indent=2, ensure_ascii=False)
+        except OSError as e:
+            logger.error(f"Failed to save context to {save_context_path}: {str(e)}")
+            # Continue execution as context saving is optional
     args = {
         "question": query,
-        "context": found_chunks[0].payload["text"],
+        "context": context,
     }
     user_prompt = render_prompt("context_for_question.txt", args)
     system_prompt = read_query_prompt("answer_simple_question.txt")
@@ -1,3 +1,4 @@
+# TODO: delete after merging COG-1365, see COG-1403
 import asyncio
 from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.infrastructure.databases.vector import get_vector_engine
@@ -1,3 +1,4 @@
+# TODO: delete after merging COG-1365, see COG-1403
 from cognee.infrastructure.databases.vector import get_vector_engine


@@ -3,7 +3,7 @@ import logging
 import pathlib
 import cognee
 from cognee.modules.search.types import SearchType
-from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
+from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search

 logging.basicConfig(level=logging.DEBUG)

@@ -5,7 +5,7 @@ import cognee

 from cognee.modules.data.models import Data
 from cognee.modules.search.types import SearchType
-from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
+from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
 from cognee.modules.users.methods import get_default_user

 logging.basicConfig(level=logging.DEBUG)
@@ -3,7 +3,7 @@ import logging
 import pathlib
 import cognee
 from cognee.modules.search.types import SearchType
-from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
+from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search

 logging.basicConfig(level=logging.DEBUG)

@@ -3,7 +3,7 @@ import logging
 import pathlib
 import cognee
 from cognee.modules.search.types import SearchType
-from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
+from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search

 logging.basicConfig(level=logging.DEBUG)

@@ -13,19 +13,19 @@ async def test_code_description_to_code_part_no_results():

     with (
         patch(
-            "cognee.modules.retrieval.description_to_codepart_search.get_vector_engine",
+            "cognee.modules.retrieval.utils.description_to_codepart_search.get_vector_engine",
             return_value=mock_vector_engine,
         ),
         patch(
-            "cognee.modules.retrieval.description_to_codepart_search.get_graph_engine",
+            "cognee.modules.retrieval.utils.description_to_codepart_search.get_graph_engine",
             return_value=AsyncMock(),
         ),
         patch(
-            "cognee.modules.retrieval.description_to_codepart_search.CogneeGraph",
+            "cognee.modules.retrieval.utils.description_to_codepart_search.CogneeGraph",
             return_value=AsyncMock(),
         ),
     ):
-        from cognee.modules.retrieval.description_to_codepart_search import (
+        from cognee.modules.retrieval.utils.description_to_codepart_search import (
             code_description_to_code_part,
         )

@@ -41,7 +41,7 @@ async def test_code_description_to_code_part_invalid_query():
     mock_user = AsyncMock()

     with pytest.raises(ValueError, match="The query must be a non-empty string."):
-        from cognee.modules.retrieval.description_to_codepart_search import (
+        from cognee.modules.retrieval.utils.description_to_codepart_search import (
             code_description_to_code_part,
         )

@@ -55,7 +55,7 @@ async def test_code_description_to_code_part_invalid_top_k():
     mock_user = AsyncMock()

     with pytest.raises(ValueError, match="top_k must be a positive integer."):
-        from cognee.modules.retrieval.description_to_codepart_search import (
+        from cognee.modules.retrieval.utils.description_to_codepart_search import (
             code_description_to_code_part,
         )

@@ -70,15 +70,15 @@ async def test_code_description_to_code_part_initialization_error():

     with (
         patch(
-            "cognee.modules.retrieval.description_to_codepart_search.get_vector_engine",
+            "cognee.modules.retrieval.utils.description_to_codepart_search.get_vector_engine",
             side_effect=Exception("Engine init failed"),
         ),
         patch(
-            "cognee.modules.retrieval.description_to_codepart_search.get_graph_engine",
+            "cognee.modules.retrieval.utils.description_to_codepart_search.get_graph_engine",
             return_value=AsyncMock(),
         ),
     ):
-        from cognee.modules.retrieval.description_to_codepart_search import (
+        from cognee.modules.retrieval.utils.description_to_codepart_search import (
             code_description_to_code_part,
         )

@@ -99,19 +99,19 @@ async def test_code_description_to_code_part_execution_error():

     with (
         patch(
-            "cognee.modules.retrieval.description_to_codepart_search.get_vector_engine",
+            "cognee.modules.retrieval.utils.description_to_codepart_search.get_vector_engine",
             return_value=mock_vector_engine,
         ),
        patch(
-            "cognee.modules.retrieval.description_to_codepart_search.get_graph_engine",
+            "cognee.modules.retrieval.utils.description_to_codepart_search.get_graph_engine",
             return_value=AsyncMock(),
         ),
         patch(
-            "cognee.modules.retrieval.description_to_codepart_search.CogneeGraph",
+            "cognee.modules.retrieval.utils.description_to_codepart_search.CogneeGraph",
             return_value=AsyncMock(),
         ),
     ):
-        from cognee.modules.retrieval.description_to_codepart_search import (
+        from cognee.modules.retrieval.utils.description_to_codepart_search import (
             code_description_to_code_part,
         )

@@ -10,7 +10,7 @@ from swebench.inference.make_datasets.create_instance import PATCH_EXAMPLE
 from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
 from cognee.infrastructure.llm.prompts import read_query_prompt
-from cognee.modules.retrieval.description_to_codepart_search import (
+from cognee.modules.retrieval.utils.description_to_codepart_search import (
     code_description_to_code_part_search,
 )
 from evals.eval_utils import download_github_repo
@@ -1,7 +1,7 @@
 import cognee
 from cognee.modules.search.types import SearchType
 from cognee.infrastructure.databases.vector import get_vector_engine
-from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
+from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
 from cognee.tasks.completion.graph_query_completion import retrieved_edges_to_string
 from functools import partial
 from cognee.api.v1.cognify.cognify_v2 import get_default_tasks
@@ -11,7 +11,7 @@ from cognee.infrastructure.databases.relational import (
 from cognee.tasks.temporal_awareness.index_graphiti_objects import (
     index_and_transform_graphiti_nodes_and_edges,
 )
-from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
+from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
 from cognee.tasks.completion.graph_query_completion import retrieved_edges_to_string
 from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
@@ -36,7 +36,7 @@
     "from cognee.tasks.temporal_awareness.index_graphiti_objects import (\n",
     "    index_and_transform_graphiti_nodes_and_edges,\n",
     ")\n",
-    "from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search\n",
+    "from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search\n",
     "from cognee.tasks.completion.graph_query_completion import retrieved_edges_to_string\n",
     "from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt\n",
     "from cognee.infrastructure.llm.get_llm_client import get_llm_client"