<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->
- Created the `BaseRetriever` class to unify all the retrievers and searches.
- Implemented seven specialized retrievers (summaries, chunks, completions, graph, graph-summary, insights, code) with consistent `get_context`/`get_completion` interfaces.
- Added a JSON context-dumping feature to the current completion implementations to enable context comparisons.
- Built a comparison framework to validate old vs. new implementations.

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

- **New Features**
  - Introduced multiple retrieval classes for enhanced search capabilities, including `BaseRetriever`, `ChunksRetriever`, `CodeRetriever`, `CompletionRetriever`, `GraphCompletionRetriever`, `GraphSummaryCompletionRetriever`, `InsightsRetriever`, and `SummariesRetriever`.
  - Enhanced query completions with optional context saving for improved data persistence.
  - Implemented advanced tools to compare retrieval outcomes across different implementations.
- **Refactor**
  - Streamlined internal module organization and updated references for increased maintainability and consistency.
  - Added comments indicating future maintenance tasks related to code merging.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
66 lines
2.6 KiB
Python
66 lines
2.6 KiB
Python
import asyncio
|
|
from typing import Any, Optional
|
|
|
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
|
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
|
|
|
|
|
class InsightsRetriever(BaseRetriever):
    """Retriever for handling graph connection-based insights.

    The query is first looked up as an exact graph node. If no such node is
    found, a vector similarity search over the ``Entity_name`` and
    ``EntityType_name`` collections selects seed nodes instead. The context is
    the de-duplicated list of graph connections of the selected node(s).
    """

    def __init__(self, exploration_levels: int = 1, top_k: int = 5):
        """Initialize retriever with exploration levels and search parameters.

        Args:
            exploration_levels: Stored on the instance; not currently read by
                any method of this class (only direct connections are fetched).
            top_k: Result limit for each vector-collection search and the
                maximum number of seed nodes whose connections are fetched.
        """
        self.exploration_levels = exploration_levels
        self.top_k = top_k

    async def get_context(self, query: Optional[str]) -> Any:
        """Find the neighbours of the node(s) matching ``query`` in the graph.

        Args:
            query: An exact node id, or free text matched against entity names
                and entity-type names. ``None`` yields an empty list.

        Returns:
            A list of unique connections, in first-seen order. Each connection
            is indexable, with elements 0 and 2 being node dicts (keyed by
            ``"id"``) and element 1 an edge dict (keyed by
            ``"relationship_name"``). Returns an empty list when nothing
            relevant is found.
        """
        if query is None:
            return []

        graph_engine = await get_graph_engine()

        # Prefer an exact node match; fall back to similarity search otherwise.
        exact_node = await graph_engine.extract_node(query)
        if exact_node is not None and "id" in exact_node:
            node_connections = await graph_engine.get_connections(str(exact_node["id"]))
        else:
            node_connections = await self._find_connections_by_similarity(graph_engine, query)

        return self._deduplicate_connections(node_connections)

    async def _find_connections_by_similarity(self, graph_engine: Any, query: str) -> list:
        """Return the connections of the closest entity / entity-type vector hits."""
        vector_engine = get_vector_engine()
        entity_hits, entity_type_hits = await asyncio.gather(
            vector_engine.search("Entity_name", query_text=query, limit=self.top_k),
            vector_engine.search("EntityType_name", query_text=query, limit=self.top_k),
        )

        # The `< 0.5` cut-off treats `score` as a distance (lower = closer);
        # NOTE(review): confirm this against the vector engine in use.
        # Sorting before truncating ensures the closest hits from BOTH
        # collections win, rather than always favouring Entity_name hits.
        # `sorted` is stable, so equal scores keep their original order.
        relevant_results = sorted(
            (result for result in [*entity_hits, *entity_type_hits] if result.score < 0.5),
            key=lambda result: result.score,
        )[: self.top_k]

        if not relevant_results:
            return []

        neighbours_per_node = await asyncio.gather(
            *[graph_engine.get_connections(result.id) for result in relevant_results]
        )
        return [connection for neighbours in neighbours_per_node for connection in neighbours]

    @staticmethod
    def _deduplicate_connections(node_connections: list) -> list:
        """Drop connections lacking endpoint ids and repeated ones, keeping first-seen order."""
        seen_keys = set()
        unique_node_connections = []

        for node_connection in node_connections:
            # Connections whose endpoints have no id cannot be keyed; skip them.
            if "id" not in node_connection[0] or "id" not in node_connection[2]:
                continue

            # A tuple key avoids the ambiguity of a space-joined string when
            # ids or relationship names themselves contain spaces.
            key = (
                str(node_connection[0]["id"]),
                str(node_connection[1]["relationship_name"]),
                str(node_connection[2]["id"]),
            )
            if key in seen_keys:
                continue

            seen_keys.add(key)
            unique_node_connections.append(node_connection)

        return unique_node_connections

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Return the graph-connection context for ``query``.

        Args:
            query: Passed to ``get_context`` when ``context`` is not supplied.
            context: Precomputed context; if given, it is returned unchanged.

        Returns:
            The supplied ``context``, or the result of ``get_context(query)``.
        """
        if context is None:
            context = await self.get_context(query)
        return context
|