cognee/cognee/modules/retrieval/insights_retriever.py
lxobr 9cc357ac1c
Feat/cog 1365 unify retrievers (#572)

## Description
- Created the `BaseRetriever` class to unify all retrievers and
searches.
- Implemented seven specialized retrievers (summaries, chunks,
completions, graph, graph-summary, insights, code) with consistent
`get_context`/`get_completion` interfaces.
- Added a JSON context-dumping feature to the current completion
implementations to enable context comparisons.
- Built a comparison framework to validate the old implementations
against the new ones.
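
The shared interface described in the list above can be sketched roughly as follows. This is a minimal illustration, not the actual `BaseRetriever` from this PR: the abstract-method layout, the name `BaseRetrieverSketch`, and the toy `EchoRetriever` are assumptions made for demonstration. Only the method names `get_context` and `get_completion` and their signatures (as seen in `InsightsRetriever`) come from the source.

```python
import asyncio
from abc import ABC, abstractmethod
from typing import Any, Optional


class BaseRetrieverSketch(ABC):
    """Hypothetical stand-in for BaseRetriever; the real class may differ."""

    @abstractmethod
    async def get_context(self, query: str) -> Any:
        """Fetch the raw context for a query."""

    @abstractmethod
    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Produce a result, fetching the context first if none is supplied."""


class EchoRetriever(BaseRetrieverSketch):
    """Toy retriever (hypothetical) whose context is the query split into words."""

    async def get_context(self, query: str) -> Any:
        return query.split()

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        # Same fallback pattern as InsightsRetriever.get_completion.
        if context is None:
            context = await self.get_context(query)
        return context


result = asyncio.run(EchoRetriever().get_completion("unify all retrievers"))
```

Because every retriever accepts an optional precomputed `context`, a caller can fetch the context once, inspect or store it, and pass it back in without triggering a second lookup.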
## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.


## Summary by CodeRabbit

- **New Features**
  - Introduced multiple retrieval classes for enhanced search
    capabilities, including `BaseRetriever`, `ChunksRetriever`,
    `CodeRetriever`, `CompletionRetriever`, `GraphCompletionRetriever`,
    `GraphSummaryCompletionRetriever`, `InsightsRetriever`, and
    `SummariesRetriever`.
  - Enhanced query completions with optional context saving for improved
    data persistence.
  - Implemented advanced tools to compare retrieval outcomes across
    different implementations.

- **Refactor**
  - Streamlined internal module organization and updated references for
    increased maintainability and consistency.
  - Added comments indicating future maintenance tasks related to code
    merging.

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
2025-02-27 12:13:21 +01:00

66 lines
2.6 KiB
Python

import asyncio
from typing import Any, Optional

from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever


class InsightsRetriever(BaseRetriever):
    """Retriever for handling graph connection-based insights."""

    def __init__(self, exploration_levels: int = 1, top_k: int = 5):
        """Initialize retriever with exploration levels and search parameters."""
        self.exploration_levels = exploration_levels
        self.top_k = top_k

    async def get_context(self, query: str) -> Any:
        """Find the neighbours of a given node in the graph."""
        if query is None:
            return []

        node_id = query
        graph_engine = await get_graph_engine()
        exact_node = await graph_engine.extract_node(node_id)

        if exact_node is not None and "id" in exact_node:
            node_connections = await graph_engine.get_connections(str(exact_node["id"]))
        else:
            vector_engine = get_vector_engine()
            results = await asyncio.gather(
                vector_engine.search("Entity_name", query_text=query, limit=self.top_k),
                vector_engine.search("EntityType_name", query_text=query, limit=self.top_k),
            )
            results = [*results[0], *results[1]]
            relevant_results = [result for result in results if result.score < 0.5][: self.top_k]

            if len(relevant_results) == 0:
                return []

            node_connections_results = await asyncio.gather(
                *[graph_engine.get_connections(result.id) for result in relevant_results]
            )

            node_connections = []
            for neighbours in node_connections_results:
                node_connections.extend(neighbours)

        unique_node_connections_map = {}
        unique_node_connections = []

        for node_connection in node_connections:
            if "id" not in node_connection[0] or "id" not in node_connection[2]:
                continue

            unique_id = f"{node_connection[0]['id']} {node_connection[1]['relationship_name']} {node_connection[2]['id']}"
            if unique_id not in unique_node_connections_map:
                unique_node_connections_map[unique_id] = True
                unique_node_connections.append(node_connection)

        return unique_node_connections

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Returns the graph connections context."""
        if context is None:
            context = await self.get_context(query)
        return context
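
The de-duplication step at the end of `get_context` can be exercised in isolation, without a graph or vector engine. The sketch below mirrors that loop as a standalone function. The function name `dedupe_connections` and the sample triplets are hypothetical, the `(source, edge, target)` shape is inferred from how `get_context` indexes each connection, and a tuple key is used here in place of the formatted string key in the original.

```python
from typing import Any, Dict, List, Tuple

# Assumed shape of one connection: (source node, edge, target node).
Connection = Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]


def dedupe_connections(connections: List[Connection]) -> List[Connection]:
    """Keep the first occurrence of each (source id, relationship, target id)."""
    seen = set()
    unique = []
    for source, edge, target in connections:
        # Skip connections whose endpoints carry no id, as get_context does.
        if "id" not in source or "id" not in target:
            continue
        key = (source["id"], edge["relationship_name"], target["id"])
        if key not in seen:
            seen.add(key)
            unique.append((source, edge, target))
    return unique


# Hypothetical sample data for illustration only.
sample = [
    ({"id": "a"}, {"relationship_name": "knows"}, {"id": "b"}),
    ({"id": "a"}, {"relationship_name": "knows"}, {"id": "b"}),  # exact duplicate
    ({"name": "no-id"}, {"relationship_name": "knows"}, {"id": "b"}),  # no source id
    ({"id": "b"}, {"relationship_name": "knows"}, {"id": "a"}),  # reverse direction
]
unique = dedupe_connections(sample)
```

The key is direction-sensitive, so an edge from `a` to `b` and one from `b` to `a` are both retained, matching the behaviour of the original string key.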