feat: entity brute force triplet search [COG-1325] (#589)

<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->
- Refactored `brute_force_triplet_search`, extracting memory projection.
- Built **TripletSearchContextProvider** (extends
**BaseContextProvider**) to create a single memory projection and
perform a triplet search for each entity.
- Refactored `entity_completion` into **EntityCompletionRetriever**
(extends **BaseRetriever**).
- Added **SummarizedTripletSearchContextProvider** (extends
**TripletSearchContextProvider**) for an alternative summarized output
format.
- Developed and tested an example showcasing both context providers,
comparing raw triplets, summaries, and standard search results.
## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Enhanced text summarization now delivers clearer, more concise
overviews of search results.
- Improved search performance with optimized context retrieval and
memory reuse for faster, more reliable results.
- Introduced advanced entity-based completion for generating more
relevant, context-aware responses.

- **Refactor**
- Streamlined internal workflows and error handling to ensure a smoother
overall experience.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Boris <boris@topoteretes.com>
This commit is contained in:
lxobr 2025-03-05 11:17:58 +01:00 committed by GitHub
parent 7bac2303cc
commit f033f733b5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 403 additions and 129 deletions

View file

@ -0,0 +1,68 @@
from typing import Any, Optional, List
import logging
from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor
from cognee.infrastructure.context.BaseContextProvider import BaseContextProvider
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.completion import generate_completion
logger = logging.getLogger("entity_completion_retriever")
class EntityCompletionRetriever(BaseRetriever):
"""Retriever that uses entity-based completion for generating responses."""
def __init__(
self,
extractor: BaseEntityExtractor,
context_provider: BaseContextProvider,
user_prompt_path: str = "context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt",
):
self.extractor = extractor
self.context_provider = context_provider
self.user_prompt_path = user_prompt_path
self.system_prompt_path = system_prompt_path
async def get_context(self, query: str) -> Any:
"""Get context using entity extraction and context provider."""
try:
logger.info(f"Processing query: {query[:100]}")
entities = await self.extractor.extract_entities(query)
if not entities:
logger.info("No entities extracted")
return None
context = await self.context_provider.get_context(entities, query)
if not context:
logger.info("No context retrieved")
return None
return context
except Exception as e:
logger.error(f"Context retrieval failed: {str(e)}")
return None
async def get_completion(self, query: str, context: Optional[Any] = None) -> List[str]:
"""Generate completion using provided context or fetch new context."""
try:
if context is None:
context = await self.get_context(query)
if context is None:
return ["No relevant entities found for the query."]
completion = await generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
)
return [completion]
except Exception as e:
logger.error(f"Completion generation failed: {str(e)}")
return ["Completion generation failed"]

View file

@ -0,0 +1,22 @@
from typing import List, Optional
from cognee.modules.retrieval.utils.completion import summarize_text
from cognee.modules.retrieval.context_providers.TripletSearchContextProvider import (
TripletSearchContextProvider,
)
class SummarizedTripletSearchContextProvider(TripletSearchContextProvider):
"""Context provider that uses summarized triplet search results."""
async def _format_triplets(
self, triplets: List, entity_name: str, summarize_prompt_path: Optional[str] = None
) -> str:
"""Format triplets into a summarized text."""
direct_text = await super()._format_triplets(triplets, entity_name)
if summarize_prompt_path is None:
summarize_prompt_path = "summarize_search_results.txt"
summary = await summarize_text(direct_text, summarize_prompt_path)
return f"Summary for {entity_name}:\n{summary}\n---\n"

View file

@ -0,0 +1,97 @@
from typing import List, Optional
import asyncio
from cognee.infrastructure.context.BaseContextProvider import BaseContextProvider
from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.modules.retrieval.utils.brute_force_triplet_search import (
brute_force_triplet_search,
format_triplets,
get_memory_fragment,
)
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User
class TripletSearchContextProvider(BaseContextProvider):
"""Context provider that uses brute force triplet search for each entity."""
def __init__(
self,
top_k: int = 3,
collections: List[str] = None,
properties_to_project: List[str] = None,
):
self.top_k = top_k
self.collections = collections
self.properties_to_project = properties_to_project
def _get_entity_text(self, entity: DataPoint) -> Optional[str]:
"""Concatenates available entity text fields with graceful fallback."""
texts = []
if hasattr(entity, "name") and entity.name:
texts.append(entity.name)
if hasattr(entity, "description") and entity.description:
texts.append(entity.description)
if hasattr(entity, "text") and entity.text:
texts.append(entity.text)
return " ".join(texts) if texts else None
def _get_search_tasks(
self,
entities: List[DataPoint],
query: str,
user: User,
memory_fragment: CogneeGraph,
) -> List:
"""Creates search tasks for valid entities."""
tasks = [
brute_force_triplet_search(
query=f"{entity_text} {query}",
user=user,
top_k=self.top_k,
collections=self.collections,
properties_to_project=self.properties_to_project,
memory_fragment=memory_fragment,
)
for entity in entities
if (entity_text := self._get_entity_text(entity)) is not None
]
return tasks
async def _format_triplets(self, triplets: List, entity_name: str) -> str:
"""Format triplets into readable text."""
direct_text = format_triplets(triplets)
return f"Context for {entity_name}:\n{direct_text}\n---\n"
async def _results_to_context(self, entities: List[DataPoint], results: List) -> str:
"""Formats search results into context string."""
triplets = []
for entity, entity_triplets in zip(entities, results):
entity_name = (
getattr(entity, "name", None)
or getattr(entity, "description", None)
or getattr(entity, "text", str(entity))
)
triplets.append(await self._format_triplets(entity_triplets, entity_name))
return "\n".join(triplets) if triplets else "No relevant context found."
async def get_context(self, entities: List[DataPoint], query: str) -> str:
"""Get context for each entity using brute force triplet search."""
if not entities:
return "No entities provided for context search."
user = await get_default_user()
memory_fragment = await get_memory_fragment(self.properties_to_project)
search_tasks = self._get_search_tasks(entities, query, user, memory_fragment)
if not search_tasks:
return "No valid entities found for context search."
results = await asyncio.gather(*search_tasks)
return await self._results_to_context(entities, results)

View file

@ -1,8 +1,7 @@
from typing import Optional
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.utils.completion import summarize_text
class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
@ -26,11 +25,4 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
"""Converts retrieved graph edges into a summary without redundancies."""
direct_text = await super().resolve_edges_to_text(retrieved_edges)
system_prompt = read_query_prompt(self.summarize_prompt_path)
llm_client = get_llm_client()
return await llm_client.acreate_structured_output(
text_input=direct_text,
system_prompt=system_prompt,
response_model=str,
)
return await summarize_text(direct_text, self.summarize_prompt_path)

View file

@ -1,6 +1,6 @@
import asyncio
import logging
from typing import List
from typing import List, Optional
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
@ -49,12 +49,32 @@ def format_triplets(edges):
return "".join(triplets)
async def get_memory_fragment(
properties_to_project: Optional[List[str]] = None,
) -> CogneeGraph:
"""Creates and initializes a CogneeGraph memory fragment with optional property projections."""
graph_engine = await get_graph_engine()
memory_fragment = CogneeGraph()
if properties_to_project is None:
properties_to_project = ["id", "description", "name", "type", "text"]
await memory_fragment.project_graph_from_db(
graph_engine,
node_properties_to_project=properties_to_project,
edge_properties_to_project=["relationship_name"],
)
return memory_fragment
async def brute_force_triplet_search(
query: str,
user: User = None,
top_k: int = 5,
collections: List[str] = None,
properties_to_project: List[str] = None,
memory_fragment: Optional[CogneeGraph] = None,
) -> list:
if user is None:
user = await get_default_user()
@ -63,7 +83,12 @@ async def brute_force_triplet_search(
raise PermissionError("No user found in the system. Please create a user.")
retrieved_results = await brute_force_search(
query, user, top_k, collections=collections, properties_to_project=properties_to_project
query,
user,
top_k,
collections=collections,
properties_to_project=properties_to_project,
memory_fragment=memory_fragment,
)
return retrieved_results
@ -74,6 +99,7 @@ async def brute_force_search(
top_k: int,
collections: List[str] = None,
properties_to_project: List[str] = None,
memory_fragment: Optional[CogneeGraph] = None,
) -> list:
"""
Performs a brute force search to retrieve the top triplets from the graph.
@ -82,7 +108,9 @@ async def brute_force_search(
query (str): The search query.
user (User): The user performing the search.
top_k (int): The number of top results to retrieve.
collections (Optional[List[str]]): List of collections to query. Defaults to predefined collections.
collections (Optional[List[str]]): List of collections to query.
properties_to_project (Optional[List[str]]): List of properties to project.
memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse.
Returns:
list: The top triplet results.
@ -92,6 +120,9 @@ async def brute_force_search(
if top_k <= 0:
raise ValueError("top_k must be a positive integer.")
if memory_fragment is None:
memory_fragment = await get_memory_fragment(properties_to_project)
if collections is None:
collections = [
"Entity_name",
@ -102,9 +133,8 @@ async def brute_force_search(
try:
vector_engine = get_vector_engine()
graph_engine = await get_graph_engine()
except Exception as e:
logging.error("Failed to initialize engines: %s", e)
logging.error("Failed to initialize vector engine: %s", e)
raise RuntimeError("Initialization error") from e
send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id)
@ -119,22 +149,12 @@ async def brute_force_search(
node_distances = {collection: result for collection, result in zip(collections, results)}
memory_fragment = CogneeGraph()
await memory_fragment.project_graph_from_db(
graph_engine,
node_properties_to_project=properties_to_project
or ["id", "description", "name", "type", "text"],
edge_properties_to_project=["relationship_name"],
)
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
await memory_fragment.map_vector_distances_to_graph_edges(vector_engine, query)
results = await memory_fragment.calculate_top_triplet_importances(k=top_k)
send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id)
send_telemetry("cognee.brute_force_triplet_search EXECUTION COMPLETED", user.id)
return results

View file

@ -21,3 +21,18 @@ async def generate_completion(
system_prompt=system_prompt,
response_model=str,
)
async def summarize_text(
text: str,
prompt_path: str = "summarize_search_results.txt",
) -> str:
"""Summarizes text using LLM with the specified prompt."""
system_prompt = read_query_prompt(prompt_path)
llm_client = get_llm_client()
return await llm_client.acreate_structured_output(
text_input=text,
system_prompt=system_prompt,
response_model=str,
)

View file

@ -1,103 +0,0 @@
from typing import List
import logging
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
from cognee.infrastructure.entities.BaseEntityExtractor import (
BaseEntityExtractor,
)
from cognee.infrastructure.context.BaseContextProvider import (
BaseContextProvider,
)
logger = logging.getLogger("entity_completion")
# Default prompt template paths
DEFAULT_SYSTEM_PROMPT_TEMPLATE = "answer_simple_question.txt"
DEFAULT_USER_PROMPT_TEMPLATE = "context_for_question.txt"
async def get_llm_response(
query: str,
context: str,
system_prompt_template: str = None,
user_prompt_template: str = None,
) -> str:
"""Generate LLM response based on query and context."""
try:
args = {
"question": query,
"context": context,
}
user_prompt = render_prompt(user_prompt_template or DEFAULT_USER_PROMPT_TEMPLATE, args)
system_prompt = read_query_prompt(system_prompt_template or DEFAULT_SYSTEM_PROMPT_TEMPLATE)
llm_client = get_llm_client()
return await llm_client.acreate_structured_output(
text_input=user_prompt,
system_prompt=system_prompt,
response_model=str,
)
except Exception as e:
logger.error(f"LLM response generation failed: {str(e)}")
raise
async def entity_completion(
query: str,
extractor: BaseEntityExtractor,
context_provider: BaseContextProvider,
system_prompt_template: str = None,
user_prompt_template: str = None,
) -> List[str]:
"""Execute entity-based completion using provided components."""
if not query or not isinstance(query, str):
logger.error("Invalid query type or empty query")
return ["Invalid query input"]
try:
logger.info(f"Processing query: {query[:100]}")
entities = await extractor.extract_entities(query)
if not entities:
logger.info("No entities extracted")
return ["No entities found"]
context = await context_provider.get_context(entities, query)
if not context:
logger.info("No context retrieved")
return ["No context found"]
response = await get_llm_response(
query, context, system_prompt_template, user_prompt_template
)
return [response]
except Exception as e:
logger.error(f"Entity completion failed: {str(e)}")
return ["Entity completion failed"]
if __name__ == "__main__":
# For testing purposes, will be removed by the end of the sprint
import asyncio
import logging
from cognee.tasks.entity_completion.entity_extractors.dummy_entity_extractor import (
DummyEntityExtractor,
)
from cognee.tasks.entity_completion.context_providers.dummy_context_provider import (
DummyContextProvider,
)
logging.basicConfig(level=logging.INFO)
async def run_entity_completion():
# Uses config defaults
result = await entity_completion(
"Tell me about Einstein",
DummyEntityExtractor(),
DummyContextProvider(),
)
print(f"Query Response: {result[0]}")
asyncio.run(run_entity_completion())

View file

@ -0,0 +1,163 @@
import cognee
import asyncio
import logging
from cognee.api.v1.search import SearchType
from cognee.shared.utils import setup_logging
from cognee.modules.retrieval.EntityCompletionRetriever import EntityCompletionRetriever
from cognee.modules.retrieval.context_providers.TripletSearchContextProvider import (
TripletSearchContextProvider,
)
from cognee.modules.retrieval.context_providers.SummarizedTripletSearchContextProvider import (
SummarizedTripletSearchContextProvider,
)
from cognee.modules.retrieval.entity_extractors.DummyEntityExtractor import DummyEntityExtractor
article_1 = """
Title: The Theory of Relativity: A Revolutionary Breakthrough
Author: Dr. Sarah Chen
Albert Einstein's theory of relativity fundamentally changed our understanding of space, time, and gravity. Published in 1915, the general theory of relativity describes gravity as a consequence of the curvature of spacetime caused by mass and energy. This groundbreaking work built upon his special theory of relativity from 1905, which introduced the famous equation E=mc².
Einstein's work at the Swiss Patent Office gave him time to develop these revolutionary ideas. His mathematical framework predicted several phenomena that were later confirmed, including:
- The bending of light by gravity
- The precession of Mercury's orbit
- The existence of black holes
The theory continues to be tested and validated today, most recently through the detection of gravitational waves by LIGO in 2015, exactly 100 years after its publication.
"""
article_2 = """
Title: The Manhattan Project and Its Scientific Director
Author: Prof. Michael Werner
J. Robert Oppenheimer's leadership of the Manhattan Project marked a pivotal moment in scientific history. As scientific director of the Los Alamos Laboratory, he assembled and led an extraordinary team of physicists in the development of the atomic bomb during World War II.
Oppenheimer's journey to Los Alamos began at Harvard and continued through his groundbreaking work in quantum mechanics and nuclear physics at Berkeley. His expertise in theoretical physics and exceptional leadership abilities made him the ideal candidate to head the secret weapons laboratory.
Key aspects of his directorship included:
- Recruitment of top scientific talent from across the country
- Integration of theoretical physics with practical engineering challenges
- Development of implosion-type nuclear weapons
- Management of complex security and ethical considerations
After witnessing the first nuclear test, codenamed Trinity, Oppenheimer famously quoted the Bhagavad Gita: "Now I am become Death, the destroyer of worlds." This moment reflected the profound moral implications of scientific advancement that would shape his later advocacy for international atomic controls.
"""
article_3 = """
Title: The Birth of Quantum Physics
Author: Dr. Lisa Martinez
The early 20th century witnessed a revolutionary transformation in our understanding of the microscopic world. The development of quantum mechanics emerged from the collaborative efforts of numerous brilliant physicists grappling with phenomena that classical physics couldn't explain.
Key contributors and their insights included:
- Max Planck's discovery of energy quantization (1900)
- Niels Bohr's model of the atom with discrete energy levels (1913)
- Werner Heisenberg's uncertainty principle (1927)
- Erwin Schrödinger's wave equation (1926)
- Paul Dirac's quantum theory of the electron (1928)
Einstein's 1905 paper on the photoelectric effect, which demonstrated light's particle nature, was a crucial contribution to this field. The Copenhagen interpretation, developed primarily by Bohr and Heisenberg, became the standard understanding of quantum mechanics, despite ongoing debates about its philosophical implications. These foundational developments continue to influence modern physics, from quantum computing to quantum field theory.
"""
async def main(enable_steps):
# Step 1: Reset data and system state
if enable_steps.get("prune_data"):
await cognee.prune.prune_data()
print("Data pruned.")
if enable_steps.get("prune_system"):
await cognee.prune.prune_system(metadata=True)
print("System pruned.")
# Step 2: Add text
if enable_steps.get("add_text"):
text_list = [article_1, article_2, article_3]
for text in text_list:
await cognee.add(text)
print(f"Added text: {text[:50]}...")
# Step 3: Create knowledge graph
if enable_steps.get("cognify"):
await cognee.cognify()
print("Knowledge graph created.")
# Step 4: Query insights using our new retrievers
if enable_steps.get("retriever"):
# Common settings
search_settings = {
"top_k": 5,
"collections": ["Entity_name", "TextSummary_text"],
"properties_to_project": ["name", "description", "text"],
}
# Create both context providers
direct_provider = TripletSearchContextProvider(**search_settings)
summary_provider = SummarizedTripletSearchContextProvider(**search_settings)
# Create retrievers with different providers
direct_retriever = EntityCompletionRetriever(
extractor=DummyEntityExtractor(),
context_provider=direct_provider,
system_prompt_path="answer_simple_question.txt",
user_prompt_path="context_for_question.txt",
)
summary_retriever = EntityCompletionRetriever(
extractor=DummyEntityExtractor(),
context_provider=summary_provider,
system_prompt_path="answer_simple_question.txt",
user_prompt_path="context_for_question.txt",
)
query = "What were the early contributions to quantum physics?"
print("\nQuery:", query)
# Try with direct triplets
print("\n=== Direct Triplets ===")
context = await direct_retriever.get_context(query)
print("\nEntity Context:")
print(context)
result = await direct_retriever.get_completion(query)
print("\nEntity Completion:")
print(result)
# Try with summarized triplets
print("\n=== Summarized Triplets ===")
context = await summary_retriever.get_context(query)
print("\nEntity Context:")
print(context)
result = await summary_retriever.get_completion(query)
print("\nEntity Completion:")
print(result)
# Compare with standard search
print("\n=== Standard Search ===")
search_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION, query_text=query
)
print(search_results)
if __name__ == "__main__":
setup_logging(logging.ERROR)
rebuild_kg = True
retrieve = True
steps_to_enable = {
"prune_data": rebuild_kg,
"prune_system": rebuild_kg,
"add_text": rebuild_kg,
"cognify": rebuild_kg,
"retriever": retrieve,
}
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(main(steps_to_enable))
finally:
loop.run_until_complete(loop.shutdown_asyncgens())