Transition to new retrievers, update searches (#585)

<!-- .github/pull_request_template.md -->

## Description
Delete legacy search implementations after migrating to new retriever
classes

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Enhanced search and retrieval capabilities, providing improved context
resolution for code queries, completions, summaries, and graph
connections.
  
- **Refactor**
- Shifted to a modular, object-oriented approach that consolidates query
logic and streamlines error management for a more robust and scalable
experience.
  
- **Bug Fixes**
- Improved error handling for unsupported search types and retrieval
operations.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
Daniel Molnar 2025-02-27 15:25:24 +01:00 committed by GitHub
parent f9b6630024
commit d27f847753
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 38 additions and 676 deletions

View file

@ -4,7 +4,7 @@ from fastapi import APIRouter
from fastapi.responses import JSONResponse
from cognee.api.DTO import InDTO
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
from cognee.modules.retrieval import code_graph_retrieval
from cognee.modules.retrieval.code_retriever import CodeRetriever
from cognee.modules.storage.utils import JSONEncoder
@ -43,7 +43,8 @@ def get_code_pipeline_router() -> APIRouter:
else payload.full_input
)
retrieved_files = await code_graph_retrieval(query)
retriever = CodeRetriever()
retrieved_files = await retriever.get_context(query)
return json.dumps(retrieved_files, cls=JSONEncoder)
except Exception as error:

View file

@ -1 +1 @@
from cognee.modules.retrieval.utils.code_graph_retrieval import code_graph_retrieval
from cognee.modules.retrieval.code_retriever import CodeRetriever

View file

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Optional
from typing import Any, Optional, Callable
class BaseRetriever(ABC):
@ -14,3 +14,8 @@ class BaseRetriever(ABC):
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
"""Generates a response using the query and optional context."""
pass
@classmethod
def as_search(cls) -> Callable:
"""Creates a search function from the retriever class."""
return lambda query: cls().get_completion(query)

View file

@ -1,128 +0,0 @@
import asyncio
import aiofiles
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from typing import List, Dict, Any
from pydantic import BaseModel
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt
class CodeQueryInfo(BaseModel):
"""Response model for information extraction from the query"""
filenames: List[str] = []
sourcecode: str
async def code_graph_retrieval(query: str) -> list[dict[str, Any]]:
if not query or not isinstance(query, str):
raise ValueError("The query must be a non-empty string.")
file_name_collections = ["CodeFile_name"]
classes_and_functions_collections = [
"ClassDefinition_source_code",
"FunctionDefinition_source_code",
]
try:
vector_engine = get_vector_engine()
graph_engine = await get_graph_engine()
except Exception as e:
raise RuntimeError("Database initialization error in code_graph_retriever, ") from e
system_prompt = read_query_prompt("codegraph_retriever_system.txt")
llm_client = get_llm_client()
try:
files_and_codeparts = await llm_client.acreate_structured_output(
text_input=query,
system_prompt=system_prompt,
response_model=CodeQueryInfo,
)
except Exception as e:
raise RuntimeError("Failed to retrieve structured output from LLM") from e
similar_filenames = []
similar_codepieces = []
if not files_and_codeparts.filenames or not files_and_codeparts.sourcecode:
for collection in file_name_collections:
search_results_file = await vector_engine.search(collection, query, limit=3)
for res in search_results_file:
similar_filenames.append({"id": res.id, "score": res.score, "payload": res.payload})
for collection in classes_and_functions_collections:
search_results_code = await vector_engine.search(collection, query, limit=3)
for res in search_results_code:
similar_codepieces.append(
{"id": res.id, "score": res.score, "payload": res.payload}
)
else:
for collection in file_name_collections:
for file_from_query in files_and_codeparts.filenames:
search_results_file = await vector_engine.search(
collection, file_from_query, limit=3
)
for res in search_results_file:
similar_filenames.append(
{"id": res.id, "score": res.score, "payload": res.payload}
)
for collection in classes_and_functions_collections:
for code_from_query in files_and_codeparts.sourcecode:
search_results_code = await vector_engine.search(
collection, code_from_query, limit=3
)
for res in search_results_code:
similar_codepieces.append(
{"id": res.id, "score": res.score, "payload": res.payload}
)
file_ids = [str(item["id"]) for item in similar_filenames]
code_ids = [str(item["id"]) for item in similar_codepieces]
relevant_triplets = await asyncio.gather(
*[graph_engine.get_connections(node_id) for node_id in code_ids + file_ids]
)
paths = set()
for sublist in relevant_triplets:
for tpl in sublist:
if isinstance(tpl, tuple) and len(tpl) >= 3:
if "file_path" in tpl[0]:
paths.add(tpl[0]["file_path"])
if "file_path" in tpl[2]: # Third tuple element
paths.add(tpl[2]["file_path"])
retrieved_files = {}
read_tasks = []
for file_path in paths:
async def read_file(fp):
try:
async with aiofiles.open(fp, "r", encoding="utf-8") as f:
retrieved_files[fp] = await f.read()
except Exception as e:
print(f"Error reading {fp}: {e}")
retrieved_files[fp] = ""
read_tasks.append(read_file(file_path))
await asyncio.gather(*read_tasks)
result = [
{
"name": file_path,
"description": file_path,
"content": retrieved_files[file_path],
}
for file_path in paths
]
return result

View file

@ -1,221 +0,0 @@
# TODO: delete after merging COG-1365, see COG-1403
import asyncio
import json
import logging
import os
from typing import Any, Callable, Dict, Type
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
from cognee.modules.retrieval.code_retriever import CodeRetriever
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.graph_summary_completion_retriever import (
GraphSummaryCompletionRetriever,
)
from cognee.modules.retrieval.insights_retriever import InsightsRetriever
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
from cognee.modules.retrieval.utils.code_graph_retrieval import code_graph_retrieval
from cognee.tasks.chunks import query_chunks
from cognee.tasks.completion import (
query_completion,
graph_query_completion,
graph_query_summary_completion,
)
from cognee.tasks.graph import query_graph_connections
from cognee.tasks.summarization import query_summaries
from examples.python.dynamic_steps_example import main as setup_main
CONTEXT_DUMP_DIR = "context_dumps"
# Define retriever configurations
COMPLETION_RETRIEVERS = [
{
"name": "completion",
"old_implementation": query_completion,
"retriever_class": CompletionRetriever,
"type": "completion",
},
{
"name": "graph completion",
"old_implementation": graph_query_completion,
"retriever_class": GraphCompletionRetriever,
"type": "graph_completion",
},
{
"name": "graph summary completion",
"old_implementation": graph_query_summary_completion,
"retriever_class": GraphSummaryCompletionRetriever,
"type": "graph_summary_completion",
},
]
BASIC_RETRIEVERS = [
{
"name": "summaries search",
"old_implementation": query_summaries,
"retriever_class": SummariesRetriever,
},
{
"name": "chunks search",
"old_implementation": query_chunks,
"retriever_class": ChunksRetriever,
},
{
"name": "insights search",
"old_implementation": query_graph_connections,
"retriever_class": InsightsRetriever,
},
{
"name": "code search",
"old_implementation": code_graph_retrieval,
"retriever_class": CodeRetriever,
},
]
async def compare_completion(old_results: list, new_results: list) -> Dict:
"""Compare two lists of completion results and print differences."""
lengths_match = len(old_results) == len(new_results)
matches = []
if lengths_match:
print("Results length match")
matches = [old == new for old, new in zip(old_results, new_results)]
if all(matches):
print("All entries match")
else:
print(f"Differences found at indices: {[i for i, m in enumerate(matches) if not m]}")
print("\nDifferences:")
for i, (old, new) in enumerate(zip(old_results, new_results)):
if old != new:
print(f"\nIndex {i}:")
print("Old:", json.dumps(old, indent=2))
print("New:", json.dumps(new, indent=2))
else:
print(f"Results length mismatch: {len(old_results)} vs {len(new_results)}")
print("\nOld results:", json.dumps(old_results, indent=2))
print("\nNew results:", json.dumps(new_results, indent=2))
return {
"old_results": old_results,
"new_results": new_results,
"lengths_match": lengths_match,
"element_matches": matches,
}
async def compare_retriever(
query: str, old_implementation: Callable, new_retriever: Any, name: str
) -> Dict:
"""Compare old and new retriever implementations."""
print(f"\nComparing {name}...")
# Get results from both implementations
old_results = await old_implementation(query)
new_results = await new_retriever.get_completion(query)
return await compare_completion(old_results, new_results)
async def compare_completion_context(
query: str, old_implementation: Callable, retriever_class: Type, name: str, retriever_type: str
) -> Dict:
"""Compare context between old completion implementation and new retriever."""
print(f"\nComparing {name} contexts...")
# Get context from old implementation with dumping
context_path = f"{CONTEXT_DUMP_DIR}/{retriever_type}_{hash(query)}_context.json"
os.makedirs(CONTEXT_DUMP_DIR, exist_ok=True)
await old_implementation(query, save_context_path=context_path)
# Get context from new implementation
retriever = retriever_class()
new_context = await retriever.get_context(query)
# Read dumped context
with open(context_path, "r") as f:
old_context = json.load(f)
# Compare contexts
contexts_match = old_context == new_context
if contexts_match:
print("Contexts match exactly")
else:
print("Contexts differ:")
print("\nOld context:", json.dumps(old_context, indent=2))
print("\nNew context:", json.dumps(new_context, indent=2))
return {
"old_context": old_context,
"new_context": new_context,
"contexts_match": contexts_match,
}
async def main(query: str, comparisons: Dict[str, bool], setup_steps: Dict[str, bool]):
"""Run comparison tests for selected retrievers with the given setup configuration."""
# Ensure retriever is always False in setup steps
setup_steps["retriever"] = False
await setup_main(setup_steps)
# Compare contexts for completion-based retrievers
for retriever in COMPLETION_RETRIEVERS:
context_key = f"{retriever['type']}_context"
if comparisons.get(context_key, False):
await compare_completion_context(
query=query,
old_implementation=retriever["old_implementation"],
retriever_class=retriever["retriever_class"],
name=retriever["name"],
retriever_type=retriever["type"],
)
# Run completion comparisons
for retriever in COMPLETION_RETRIEVERS:
if comparisons.get(retriever["type"], False):
await compare_retriever(
query=query,
old_implementation=retriever["old_implementation"],
new_retriever=retriever["retriever_class"](),
name=retriever["name"],
)
# Run basic retriever comparisons
for retriever in BASIC_RETRIEVERS:
retriever_type = retriever["name"].split()[0]
if comparisons.get(retriever_type, False):
await compare_retriever(
query=query,
old_implementation=retriever["old_implementation"],
new_retriever=retriever["retriever_class"](),
name=retriever["name"],
)
if __name__ == "__main__":
logging.basicConfig(level=logging.ERROR)
test_query = "Who has experience in data science?"
comparisons = {
# Context comparisons
"completion_context": True,
"graph_completion_context": True,
"graph_summary_completion_context": True,
# Result comparisons
"summaries": True,
"chunks": True,
"insights": True,
"code": False,
"completion": True,
"graph_completion": True,
"graph_summary_completion": True,
}
setup_steps = {
"prune_data": True,
"prune_system": True,
"add_text": True,
"cognify": True,
}
asyncio.run(main(test_query, comparisons, setup_steps))

View file

@ -3,18 +3,20 @@ from typing import Callable
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.engine.utils import parse_id
from cognee.modules.retrieval.utils.code_graph_retrieval import code_graph_retrieval
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
from cognee.modules.retrieval.insights_retriever import InsightsRetriever
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.graph_summary_completion_retriever import (
GraphSummaryCompletionRetriever,
)
from cognee.modules.retrieval.code_retriever import CodeRetriever
from cognee.modules.search.types import SearchType
from cognee.modules.storage.utils import JSONEncoder
from cognee.modules.users.models import User
from cognee.modules.users.permissions.methods import get_document_ids_for_user
from cognee.shared.utils import send_telemetry
from cognee.tasks.chunks import query_chunks
from cognee.tasks.graph import query_graph_connections
from cognee.tasks.summarization import query_summaries
from cognee.tasks.completion import query_completion
from cognee.tasks.completion import graph_query_completion
from cognee.tasks.completion import graph_query_summary_completion
from ..operations import log_query, log_result
@ -44,15 +46,14 @@ async def search(
async def specific_search(query_type: SearchType, query: str, user: User) -> list:
# TODO: update after merging COG-1365, see COG-1403
search_tasks: dict[SearchType, Callable] = {
SearchType.SUMMARIES: query_summaries,
SearchType.INSIGHTS: query_graph_connections,
SearchType.CHUNKS: query_chunks,
SearchType.COMPLETION: query_completion,
SearchType.GRAPH_COMPLETION: graph_query_completion,
SearchType.GRAPH_SUMMARY_COMPLETION: graph_query_summary_completion,
SearchType.CODE: code_graph_retrieval,
SearchType.SUMMARIES: SummariesRetriever.as_search(),
SearchType.INSIGHTS: InsightsRetriever.as_search(),
SearchType.CHUNKS: ChunksRetriever.as_search(),
SearchType.COMPLETION: CompletionRetriever.as_search(),
SearchType.GRAPH_COMPLETION: GraphCompletionRetriever.as_search(),
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever.as_search(),
SearchType.CODE: CodeRetriever.as_search(),
}
search_task = search_tasks.get(query_type)

View file

@ -1,4 +1,3 @@
from .query_chunks import query_chunks
from .chunk_by_word import chunk_by_word
from .chunk_by_sentence import chunk_by_sentence
from .chunk_by_paragraph import chunk_by_paragraph

View file

@ -1,27 +0,0 @@
# TODO: delete after merging COG-1365, see COG-1403
from cognee.infrastructure.databases.vector import get_vector_engine
async def query_chunks(query: str) -> list[dict]:
"""
Queries the vector database to retrieve chunks related to the given query string.
Parameters:
- query (str): The query string to filter nodes by.
Returns:
- list(dict): A list of objects providing information about the chunks related to query.
Notes:
- The function uses the `search` method of the vector engine to find matches.
- Limits the results to the top 5 matching chunks to balance performance and relevance.
- Ensure that the vector database is properly initialized and contains the "DocumentChunk_text" collection.
"""
vector_engine = get_vector_engine()
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=5)
chunks = [result.payload for result in found_chunks]
return chunks

View file

@ -1,3 +1 @@
from .query_completion import query_completion
from .graph_query_completion import graph_query_completion
from .graph_query_summary_completion import graph_query_summary_completion
from cognee.tasks.completion.exceptions import NoRelevantDataFound

View file

@ -1,95 +0,0 @@
# TODO: delete after merging COG-1365, see COG-1403
import json
import logging
import os
from cognee.infrastructure.engine import ExtendableDataPoint
from cognee.infrastructure.engine.models.DataPoint import DataPoint
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.tasks.completion.exceptions import NoRelevantDataFound
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from typing import Callable
logger = logging.getLogger(__name__)
async def retrieved_edges_to_string(retrieved_edges: list) -> str:
"""
Converts a list of retrieved graph edges into a human-readable string format.
"""
edge_strings = []
for edge in retrieved_edges:
node1_string = edge.node1.attributes.get("text") or edge.node1.attributes.get("name")
node2_string = edge.node2.attributes.get("text") or edge.node2.attributes.get("name")
edge_string = edge.attributes["relationship_type"]
edge_str = f"{node1_string} -- {edge_string} -- {node2_string}"
edge_strings.append(edge_str)
return "\n---\n".join(edge_strings)
async def graph_query_completion(
query: str, context_resolver: Callable = None, save_context_path: str = None
) -> list:
"""
Executes a query on the graph database and retrieves a relevant completion based on the found data.
Parameters:
- query (str): The query string to compute.
- context_resolver (Callable): A function to convert retrieved edges to a string.
- save_context_path (str): Path to save the retrieved context.
Returns:
- list: Answer to the query.
Notes:
- The `brute_force_triplet_search` is used to retrieve relevant graph data.
- Prompts are dynamically rendered and provided to the LLM for contextual understanding.
- Ensure that the LLM client and graph database are properly configured and accessible.
"""
subclasses = get_all_subclasses(DataPoint)
vector_index_collections = []
for subclass in subclasses:
index_fields = subclass.model_fields["metadata"].default.get("index_fields", [])
for field_name in index_fields:
vector_index_collections.append(f"{subclass.__name__}_{field_name}")
found_triplets = await brute_force_triplet_search(
query, top_k=5, collections=vector_index_collections or None
)
if len(found_triplets) == 0:
raise NoRelevantDataFound
if not context_resolver:
context_resolver = retrieved_edges_to_string
# Get context and optionally dump it
context = await context_resolver(found_triplets)
if save_context_path:
try:
os.makedirs(os.path.dirname(save_context_path), exist_ok=True)
with open(save_context_path, "w") as f:
json.dump(context, f, indent=2)
except (OSError, TypeError, ValueError) as e:
logger.error(f"Failed to save context to {save_context_path}: {str(e)}")
# Consider whether to raise or continue silently
args = {
"question": query,
"context": context,
}
user_prompt = render_prompt("graph_context_for_question.txt", args)
system_prompt = read_query_prompt("answer_simple_question.txt")
llm_client = get_llm_client()
computed_answer = await llm_client.acreate_structured_output(
text_input=user_prompt,
system_prompt=system_prompt,
response_model=str,
)
return [computed_answer]

View file

@ -1,30 +0,0 @@
# TODO: delete after merging COG-1365, see COG-1403
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.tasks.completion.graph_query_completion import (
graph_query_completion,
retrieved_edges_to_string,
)
async def retrieved_edges_to_summary(retrieved_edges: list) -> str:
"""
Converts a list of retrieved graph edges into a summary without redundancies.
"""
edges_string = await retrieved_edges_to_string(retrieved_edges)
system_prompt = read_query_prompt("summarize_search_results.txt")
llm_client = get_llm_client()
summarized_context = await llm_client.acreate_structured_output(
text_input=edges_string,
system_prompt=system_prompt,
response_model=str,
)
return summarized_context
async def graph_query_summary_completion(query: str, save_context_path: str = None) -> list:
"""Executes a query on the graph database and retrieves a summarized completion with optional context saving."""
return await graph_query_completion(
query, context_resolver=retrieved_edges_to_summary, save_context_path=save_context_path
)

View file

@ -1,63 +0,0 @@
# TODO: delete after merging COG-1365, see COG-1403
import json
import logging
import os
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.tasks.completion.exceptions import NoRelevantDataFound
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
logger = logging.getLogger(__name__)
async def query_completion(query: str, save_context_path: str = None) -> list:
"""
Executes a query against a vector database and computes a relevant response using an LLM.
Parameters:
- query (str): The query string to compute.
- save_context_path (str): The path to save the context.
Returns:
- list: Answer to the query.
Notes:
- Limits the search to the top 1 matching chunk for simplicity and relevance.
- Ensure that the vector database and LLM client are properly configured and accessible.
- The response model used for the LLM output is expected to be a string.
"""
vector_engine = get_vector_engine()
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=1)
if len(found_chunks) == 0:
raise NoRelevantDataFound
# Get context and optionally dump it
context = found_chunks[0].payload["text"]
if save_context_path:
try:
os.makedirs(os.path.dirname(save_context_path), exist_ok=True)
with open(save_context_path, "w", encoding="utf-8") as f:
json.dump(context, f, indent=2, ensure_ascii=False)
except OSError as e:
logger.error(f"Failed to save context to {save_context_path}: {str(e)}")
# Continue execution as context saving is optional
args = {
"question": query,
"context": context,
}
user_prompt = render_prompt("context_for_question.txt", args)
system_prompt = read_query_prompt("answer_simple_question.txt")
llm_client = get_llm_client()
computed_answer = await llm_client.acreate_structured_output(
text_input=user_prompt,
system_prompt=system_prompt,
response_model=str,
)
return [computed_answer]

View file

@ -1,3 +1,2 @@
from .extract_graph_from_data import extract_graph_from_data
from .extract_graph_from_code import extract_graph_from_code
from .query_graph_connections import query_graph_connections

View file

@ -1,62 +0,0 @@
# TODO: delete after merging COG-1365, see COG-1403
import asyncio
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
async def query_graph_connections(query: str, exploration_levels=1) -> list[(str, str, str)]:
"""
Find the neighbours of a given node in the graph and return formed sentences.
Parameters:
- query (str): The query string to filter nodes by.
- exploration_levels (int): The number of jumps through edges to perform.
Returns:
- list[(str, str, str)]: A list containing the source and destination nodes and relationship.
"""
if query is None:
return []
node_id = query
graph_engine = await get_graph_engine()
exact_node = await graph_engine.extract_node(node_id)
if exact_node is not None and "id" in exact_node:
node_connections = await graph_engine.get_connections(str(exact_node["id"]))
else:
vector_engine = get_vector_engine()
results = await asyncio.gather(
vector_engine.search("Entity_name", query_text=query, limit=5),
vector_engine.search("EntityType_name", query_text=query, limit=5),
)
results = [*results[0], *results[1]]
relevant_results = [result for result in results if result.score < 0.5][:5]
if len(relevant_results) == 0:
return []
node_connections_results = await asyncio.gather(
*[graph_engine.get_connections(result.id) for result in relevant_results]
)
node_connections = []
for neighbours in node_connections_results:
node_connections.extend(neighbours)
unique_node_connections_map = {}
unique_node_connections = []
for node_connection in node_connections:
if "id" not in node_connection[0] or "id" not in node_connection[2]:
continue
unique_id = f"{node_connection[0]['id']} {node_connection[1]['relationship_name']} {node_connection[2]['id']}"
if unique_id not in unique_node_connections_map:
unique_node_connections_map[unique_id] = True
unique_node_connections.append(node_connection)
return unique_node_connections

View file

@ -1,3 +1,2 @@
from .query_summaries import query_summaries
from .summarize_code import summarize_code
from .summarize_text import summarize_text

View file

@ -1,19 +0,0 @@
# TODO: delete after merging COG-1365, see COG-1403
from cognee.infrastructure.databases.vector import get_vector_engine
async def query_summaries(query: str) -> list:
"""
Parameters:
- query (str): The query string to filter summaries by.
Returns:
- list[str, UUID]: A list of objects providing information about the summaries related to query.
"""
vector_engine = get_vector_engine()
summaries_results = await vector_engine.search("TextSummary_text", query, limit=5)
summaries = [summary.payload for summary in summaries_results]
return summaries

View file

@ -2,7 +2,7 @@ import cognee
from cognee.modules.search.types import SearchType
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from cognee.tasks.completion.graph_query_completion import retrieved_edges_to_string
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from functools import partial
from cognee.api.v1.cognify.cognify_v2 import get_default_tasks
import logging
@ -122,7 +122,8 @@ async def get_context_with_brute_force_triplet_search(instance: dict) -> str:
found_triplets = await brute_force_triplet_search(instance["question"], top_k=5)
search_results_str = await retrieved_edges_to_string(found_triplets)
retriever = GraphCompletionRetriever()
search_results_str = await retriever.resolve_edges_to_text(found_triplets)
return search_results_str

View file

@ -12,7 +12,7 @@ from cognee.tasks.temporal_awareness.index_graphiti_objects import (
index_and_transform_graphiti_nodes_and_edges,
)
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from cognee.tasks.completion.graph_query_completion import retrieved_edges_to_string
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
from cognee.infrastructure.llm.get_llm_client import get_llm_client
@ -49,9 +49,12 @@ async def main():
collections=["graphitinode_content", "graphitinode_name", "graphitinode_summary"],
)
retriever = GraphCompletionRetriever()
context = await retriever.resolve_edges_to_text(triplets)
args = {
"question": query,
"context": await retrieved_edges_to_string(triplets),
"context": context,
}
user_prompt = render_prompt("graph_context_for_question.txt", args)

View file

@ -37,7 +37,7 @@
" index_and_transform_graphiti_nodes_and_edges,\n",
")\n",
"from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search\n",
"from cognee.tasks.completion.graph_query_completion import retrieved_edges_to_string\n",
"from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever\n",
"from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt\n",
"from cognee.infrastructure.llm.get_llm_client import get_llm_client"
]
@ -186,7 +186,8 @@
")\n",
"\n",
"# Step 3: Preparing the Context for the LLM\n",
"context = await retrieved_edges_to_string(triplets)\n",
"retriever = GraphCompletionRetriever()\n",
"context = await retriever.resolve_edges_to_text(triplets)\n",
"\n",
"args = {\"question\": query, \"context\": context}\n",
"\n",