Transition to new retrievers, update searches (#585)

<!-- .github/pull_request_template.md -->

## Description
Delete legacy search implementations after migrating to new retriever
classes

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Enhanced search and retrieval capabilities, providing improved context
resolution for code queries, completions, summaries, and graph
connections.
  
- **Refactor**
- Shifted to a modular, object-oriented approach that consolidates query
logic and streamlines error management for a more robust and scalable
experience.
  
- **Bug Fixes**
- Improved error handling for unsupported search types and retrieval
operations.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
Daniel Molnar 2025-02-27 15:25:24 +01:00 committed by GitHub
parent f9b6630024
commit d27f847753
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 38 additions and 676 deletions

View file

@ -4,7 +4,7 @@ from fastapi import APIRouter
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from cognee.api.DTO import InDTO from cognee.api.DTO import InDTO
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
from cognee.modules.retrieval import code_graph_retrieval from cognee.modules.retrieval.code_retriever import CodeRetriever
from cognee.modules.storage.utils import JSONEncoder from cognee.modules.storage.utils import JSONEncoder
@ -43,7 +43,8 @@ def get_code_pipeline_router() -> APIRouter:
else payload.full_input else payload.full_input
) )
retrieved_files = await code_graph_retrieval(query) retriever = CodeRetriever()
retrieved_files = await retriever.get_context(query)
return json.dumps(retrieved_files, cls=JSONEncoder) return json.dumps(retrieved_files, cls=JSONEncoder)
except Exception as error: except Exception as error:

View file

@ -1 +1 @@
from cognee.modules.retrieval.utils.code_graph_retrieval import code_graph_retrieval from cognee.modules.retrieval.code_retriever import CodeRetriever

View file

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Optional from typing import Any, Optional, Callable
class BaseRetriever(ABC): class BaseRetriever(ABC):
@ -14,3 +14,8 @@ class BaseRetriever(ABC):
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
"""Generates a response using the query and optional context.""" """Generates a response using the query and optional context."""
pass pass
@classmethod
def as_search(cls) -> Callable:
"""Creates a search function from the retriever class."""
return lambda query: cls().get_completion(query)

View file

@ -1,128 +0,0 @@
import asyncio
import aiofiles
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from typing import List, Dict, Any
from pydantic import BaseModel
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt
class CodeQueryInfo(BaseModel):
    """Response model for information extraction from the query"""

    # File names the LLM detected in the query; empty when none were found.
    filenames: List[str] = []
    # Source-code fragment the LLM extracted from the query text.
    sourcecode: str
async def code_graph_retrieval(query: str) -> list[dict[str, Any]]:
    """Retrieve source files relevant to a natural-language code query.

    Extracts file names and a source-code fragment from the query via the LLM,
    searches the vector collections for similar file names and code pieces,
    expands the matches through graph connections to collect file paths, and
    returns the contents of those files.

    Parameters:
    - query (str): Non-empty natural-language query about the code base.

    Returns:
    - list[dict[str, Any]]: One dict per file with "name", "description" and
      "content" keys ("content" is "" when the file could not be read).

    Raises:
    - ValueError: If the query is empty or not a string.
    - RuntimeError: If database initialization or the LLM call fails.
    """
    if not query or not isinstance(query, str):
        raise ValueError("The query must be a non-empty string.")

    file_name_collections = ["CodeFile_name"]
    classes_and_functions_collections = [
        "ClassDefinition_source_code",
        "FunctionDefinition_source_code",
    ]

    try:
        vector_engine = get_vector_engine()
        graph_engine = await get_graph_engine()
    except Exception as e:
        raise RuntimeError("Database initialization error in code_graph_retriever, ") from e

    system_prompt = read_query_prompt("codegraph_retriever_system.txt")
    llm_client = get_llm_client()
    try:
        files_and_codeparts = await llm_client.acreate_structured_output(
            text_input=query,
            system_prompt=system_prompt,
            response_model=CodeQueryInfo,
        )
    except Exception as e:
        raise RuntimeError("Failed to retrieve structured output from LLM") from e

    def _to_hits(search_results):
        # Normalize vector search results into plain dicts.
        return [
            {"id": res.id, "score": res.score, "payload": res.payload} for res in search_results
        ]

    similar_filenames = []
    similar_codepieces = []

    if not files_and_codeparts.filenames or not files_and_codeparts.sourcecode:
        # LLM extraction was incomplete: fall back to searching with the raw query.
        for collection in file_name_collections:
            similar_filenames.extend(
                _to_hits(await vector_engine.search(collection, query, limit=3))
            )
        for collection in classes_and_functions_collections:
            similar_codepieces.extend(
                _to_hits(await vector_engine.search(collection, query, limit=3))
            )
    else:
        for collection in file_name_collections:
            for file_from_query in files_and_codeparts.filenames:
                similar_filenames.extend(
                    _to_hits(await vector_engine.search(collection, file_from_query, limit=3))
                )
        # BUG FIX: `sourcecode` is a single string; the previous code iterated it
        # character by character, issuing one vector search per character.
        for collection in classes_and_functions_collections:
            similar_codepieces.extend(
                _to_hits(
                    await vector_engine.search(
                        collection, files_and_codeparts.sourcecode, limit=3
                    )
                )
            )

    file_ids = [str(item["id"]) for item in similar_filenames]
    code_ids = [str(item["id"]) for item in similar_codepieces]

    relevant_triplets = await asyncio.gather(
        *[graph_engine.get_connections(node_id) for node_id in code_ids + file_ids]
    )

    # Collect every file path appearing in a source or target node of a triplet.
    paths = set()
    for sublist in relevant_triplets:
        for tpl in sublist:
            if isinstance(tpl, tuple) and len(tpl) >= 3:
                if "file_path" in tpl[0]:
                    paths.add(tpl[0]["file_path"])
                if "file_path" in tpl[2]:  # Third tuple element
                    paths.add(tpl[2]["file_path"])

    retrieved_files = {}

    async def _read_file(fp):
        # Best-effort read: unreadable files yield "" instead of failing the query.
        try:
            async with aiofiles.open(fp, "r", encoding="utf-8") as f:
                retrieved_files[fp] = await f.read()
        except Exception as e:
            print(f"Error reading {fp}: {e}")
            retrieved_files[fp] = ""

    await asyncio.gather(*[_read_file(file_path) for file_path in paths])

    return [
        {
            "name": file_path,
            "description": file_path,
            "content": retrieved_files[file_path],
        }
        for file_path in paths
    ]

View file

@ -1,221 +0,0 @@
# TODO: delete after merging COG-1365, see COG-1403
import asyncio
import json
import logging
import os
from typing import Any, Callable, Dict, Type
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
from cognee.modules.retrieval.code_retriever import CodeRetriever
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.graph_summary_completion_retriever import (
GraphSummaryCompletionRetriever,
)
from cognee.modules.retrieval.insights_retriever import InsightsRetriever
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
from cognee.modules.retrieval.utils.code_graph_retrieval import code_graph_retrieval
from cognee.tasks.chunks import query_chunks
from cognee.tasks.completion import (
query_completion,
graph_query_completion,
graph_query_summary_completion,
)
from cognee.tasks.graph import query_graph_connections
from cognee.tasks.summarization import query_summaries
from examples.python.dynamic_steps_example import main as setup_main
# Directory where legacy implementations dump their retrieval context to disk.
CONTEXT_DUMP_DIR = "context_dumps"

# Define retriever configurations
# Completion-style retrievers: each entry maps a legacy query function to the
# retriever class replacing it; "type" is the key used in the comparison
# configuration passed to main().
COMPLETION_RETRIEVERS = [
    {
        "name": "completion",
        "old_implementation": query_completion,
        "retriever_class": CompletionRetriever,
        "type": "completion",
    },
    {
        "name": "graph completion",
        "old_implementation": graph_query_completion,
        "retriever_class": GraphCompletionRetriever,
        "type": "graph_completion",
    },
    {
        "name": "graph summary completion",
        "old_implementation": graph_query_summary_completion,
        "retriever_class": GraphSummaryCompletionRetriever,
        "type": "graph_summary_completion",
    },
]

# Basic retrievers whose raw results are compared directly (no context dump step).
BASIC_RETRIEVERS = [
    {
        "name": "summaries search",
        "old_implementation": query_summaries,
        "retriever_class": SummariesRetriever,
    },
    {
        "name": "chunks search",
        "old_implementation": query_chunks,
        "retriever_class": ChunksRetriever,
    },
    {
        "name": "insights search",
        "old_implementation": query_graph_connections,
        "retriever_class": InsightsRetriever,
    },
    {
        "name": "code search",
        "old_implementation": code_graph_retrieval,
        "retriever_class": CodeRetriever,
    },
]
async def compare_completion(old_results: list, new_results: list) -> Dict:
    """Print a diff between two lists of completion results and return a report.

    The report contains both input lists, whether their lengths match, and a
    per-index boolean match list (empty when the lengths differ).
    """
    lengths_match = len(old_results) == len(new_results)

    if not lengths_match:
        # Length mismatch: dump both lists wholesale, no element comparison.
        print(f"Results length mismatch: {len(old_results)} vs {len(new_results)}")
        print("\nOld results:", json.dumps(old_results, indent=2))
        print("\nNew results:", json.dumps(new_results, indent=2))
        return {
            "old_results": old_results,
            "new_results": new_results,
            "lengths_match": lengths_match,
            "element_matches": [],
        }

    print("Results length match")
    element_matches = [a == b for a, b in zip(old_results, new_results)]

    if all(element_matches):
        print("All entries match")
    else:
        mismatched_indices = [i for i, ok in enumerate(element_matches) if not ok]
        print(f"Differences found at indices: {mismatched_indices}")
        print("\nDifferences:")
        for index, (a, b) in enumerate(zip(old_results, new_results)):
            if a == b:
                continue
            print(f"\nIndex {index}:")
            print("Old:", json.dumps(a, indent=2))
            print("New:", json.dumps(b, indent=2))

    return {
        "old_results": old_results,
        "new_results": new_results,
        "lengths_match": lengths_match,
        "element_matches": element_matches,
    }
async def compare_retriever(
    query: str, old_implementation: Callable, new_retriever: Any, name: str
) -> Dict:
    """Run the legacy function and the new retriever on *query* and diff them.

    Returns the comparison report produced by compare_completion.
    """
    print(f"\nComparing {name}...")

    # Legacy implementation runs first, then the new retriever object.
    legacy_results = await old_implementation(query)
    retriever_results = await new_retriever.get_completion(query)

    return await compare_completion(legacy_results, retriever_results)
async def compare_completion_context(
    query: str, old_implementation: Callable, retriever_class: Type, name: str, retriever_type: str
) -> Dict:
    """Diff the retrieval context of a legacy completion against the new retriever.

    The legacy implementation dumps its context to a JSON file as a side effect;
    the new retriever returns its context directly. Both are compared and a
    report dict is returned.
    """
    print(f"\nComparing {name} contexts...")

    # Legacy path: context is persisted to disk while answering.
    os.makedirs(CONTEXT_DUMP_DIR, exist_ok=True)
    dump_path = f"{CONTEXT_DUMP_DIR}/{retriever_type}_{hash(query)}_context.json"
    await old_implementation(query, save_context_path=dump_path)

    # New path: context comes straight back from the retriever instance.
    new_context = await retriever_class().get_context(query)

    with open(dump_path, "r") as dump_file:
        old_context = json.load(dump_file)

    contexts_match = old_context == new_context
    if contexts_match:
        print("Contexts match exactly")
    else:
        print("Contexts differ:")
        print("\nOld context:", json.dumps(old_context, indent=2))
        print("\nNew context:", json.dumps(new_context, indent=2))

    return {
        "old_context": old_context,
        "new_context": new_context,
        "contexts_match": contexts_match,
    }
async def main(query: str, comparisons: Dict[str, bool], setup_steps: Dict[str, bool]):
    """Run comparison tests for selected retrievers with the given setup configuration."""
    # Ensure retriever is always False in setup steps
    setup_steps["retriever"] = False
    await setup_main(setup_steps)

    # Compare contexts for completion-based retrievers
    for retriever in COMPLETION_RETRIEVERS:
        # Context toggles use the "<type>_context" key in `comparisons`.
        context_key = f"{retriever['type']}_context"
        if comparisons.get(context_key, False):
            await compare_completion_context(
                query=query,
                old_implementation=retriever["old_implementation"],
                retriever_class=retriever["retriever_class"],
                name=retriever["name"],
                retriever_type=retriever["type"],
            )

    # Run completion comparisons
    for retriever in COMPLETION_RETRIEVERS:
        if comparisons.get(retriever["type"], False):
            await compare_retriever(
                query=query,
                old_implementation=retriever["old_implementation"],
                new_retriever=retriever["retriever_class"](),
                name=retriever["name"],
            )

    # Run basic retriever comparisons
    for retriever in BASIC_RETRIEVERS:
        # Comparison key is the first word of the display name (e.g. "summaries").
        retriever_type = retriever["name"].split()[0]
        if comparisons.get(retriever_type, False):
            await compare_retriever(
                query=query,
                old_implementation=retriever["old_implementation"],
                new_retriever=retriever["retriever_class"](),
                name=retriever["name"],
            )
if __name__ == "__main__":
    logging.basicConfig(level=logging.ERROR)

    test_query = "Who has experience in data science?"

    # Toggle which old-vs-new comparisons to run.
    comparisons = {
        # Context comparisons
        "completion_context": True,
        "graph_completion_context": True,
        "graph_summary_completion_context": True,
        # Result comparisons
        "summaries": True,
        "chunks": True,
        "insights": True,
        "code": False,  # code search comparison is disabled by default
        "completion": True,
        "graph_completion": True,
        "graph_summary_completion": True,
    }

    # Pipeline setup steps executed before any comparison runs.
    setup_steps = {
        "prune_data": True,
        "prune_system": True,
        "add_text": True,
        "cognify": True,
    }

    asyncio.run(main(test_query, comparisons, setup_steps))

View file

@ -3,18 +3,20 @@ from typing import Callable
from cognee.exceptions import InvalidValueError from cognee.exceptions import InvalidValueError
from cognee.infrastructure.engine.utils import parse_id from cognee.infrastructure.engine.utils import parse_id
from cognee.modules.retrieval.utils.code_graph_retrieval import code_graph_retrieval from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
from cognee.modules.retrieval.insights_retriever import InsightsRetriever
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.graph_summary_completion_retriever import (
GraphSummaryCompletionRetriever,
)
from cognee.modules.retrieval.code_retriever import CodeRetriever
from cognee.modules.search.types import SearchType from cognee.modules.search.types import SearchType
from cognee.modules.storage.utils import JSONEncoder from cognee.modules.storage.utils import JSONEncoder
from cognee.modules.users.models import User from cognee.modules.users.models import User
from cognee.modules.users.permissions.methods import get_document_ids_for_user from cognee.modules.users.permissions.methods import get_document_ids_for_user
from cognee.shared.utils import send_telemetry from cognee.shared.utils import send_telemetry
from cognee.tasks.chunks import query_chunks
from cognee.tasks.graph import query_graph_connections
from cognee.tasks.summarization import query_summaries
from cognee.tasks.completion import query_completion
from cognee.tasks.completion import graph_query_completion
from cognee.tasks.completion import graph_query_summary_completion
from ..operations import log_query, log_result from ..operations import log_query, log_result
@ -44,15 +46,14 @@ async def search(
async def specific_search(query_type: SearchType, query: str, user: User) -> list: async def specific_search(query_type: SearchType, query: str, user: User) -> list:
# TODO: update after merging COG-1365, see COG-1403
search_tasks: dict[SearchType, Callable] = { search_tasks: dict[SearchType, Callable] = {
SearchType.SUMMARIES: query_summaries, SearchType.SUMMARIES: SummariesRetriever.as_search(),
SearchType.INSIGHTS: query_graph_connections, SearchType.INSIGHTS: InsightsRetriever.as_search(),
SearchType.CHUNKS: query_chunks, SearchType.CHUNKS: ChunksRetriever.as_search(),
SearchType.COMPLETION: query_completion, SearchType.COMPLETION: CompletionRetriever.as_search(),
SearchType.GRAPH_COMPLETION: graph_query_completion, SearchType.GRAPH_COMPLETION: GraphCompletionRetriever.as_search(),
SearchType.GRAPH_SUMMARY_COMPLETION: graph_query_summary_completion, SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever.as_search(),
SearchType.CODE: code_graph_retrieval, SearchType.CODE: CodeRetriever.as_search(),
} }
search_task = search_tasks.get(query_type) search_task = search_tasks.get(query_type)

View file

@ -1,4 +1,3 @@
from .query_chunks import query_chunks
from .chunk_by_word import chunk_by_word from .chunk_by_word import chunk_by_word
from .chunk_by_sentence import chunk_by_sentence from .chunk_by_sentence import chunk_by_sentence
from .chunk_by_paragraph import chunk_by_paragraph from .chunk_by_paragraph import chunk_by_paragraph

View file

@ -1,27 +0,0 @@
# TODO: delete after merging COG-1365, see COG-1403
from cognee.infrastructure.databases.vector import get_vector_engine
async def query_chunks(query: str) -> list[dict]:
    """
    Retrieve document chunks related to *query* from the vector database.

    Parameters:
    - query (str): The query string to match chunks against.

    Returns:
    - list[dict]: Payloads of the top matching chunks.

    Notes:
    - Uses the vector engine's `search` method on the "DocumentChunk_text"
      collection, limited to the five best matches to balance performance
      and relevance.
    """
    engine = get_vector_engine()
    matched = await engine.search("DocumentChunk_text", query, limit=5)
    return [entry.payload for entry in matched]

View file

@ -1,3 +1 @@
from .query_completion import query_completion from cognee.tasks.completion.exceptions import NoRelevantDataFound
from .graph_query_completion import graph_query_completion
from .graph_query_summary_completion import graph_query_summary_completion

View file

@ -1,95 +0,0 @@
# TODO: delete after merging COG-1365, see COG-1403
import json
import logging
import os
from cognee.infrastructure.engine import ExtendableDataPoint
from cognee.infrastructure.engine.models.DataPoint import DataPoint
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.tasks.completion.exceptions import NoRelevantDataFound
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from typing import Callable
logger = logging.getLogger(__name__)
async def retrieved_edges_to_string(retrieved_edges: list) -> str:
    """
    Converts a list of retrieved graph edges into a human-readable string format.
    """

    def node_label(node):
        # Prefer the node's text attribute; fall back to its name.
        return node.attributes.get("text") or node.attributes.get("name")

    return "\n---\n".join(
        f"{node_label(edge.node1)} -- {edge.attributes['relationship_type']} -- {node_label(edge.node2)}"
        for edge in retrieved_edges
    )
async def graph_query_completion(
    query: str, context_resolver: Callable = None, save_context_path: str = None
) -> list:
    """
    Executes a query on the graph database and retrieves a relevant completion based on the found data.

    Parameters:
    - query (str): The query string to compute.
    - context_resolver (Callable): A function to convert retrieved edges to a string;
      defaults to retrieved_edges_to_string when not provided.
    - save_context_path (str): Path to save the retrieved context.

    Returns:
    - list: Answer to the query (single-element list).

    Raises:
    - NoRelevantDataFound: If the triplet search returns no results.

    Notes:
    - The `brute_force_triplet_search` is used to retrieve relevant graph data.
    - Prompts are dynamically rendered and provided to the LLM for contextual understanding.
    - Ensure that the LLM client and graph database are properly configured and accessible.
    """
    # Build the list of vector collections from every DataPoint subclass's
    # declared index fields so the triplet search spans all of them.
    subclasses = get_all_subclasses(DataPoint)
    vector_index_collections = []
    for subclass in subclasses:
        index_fields = subclass.model_fields["metadata"].default.get("index_fields", [])
        for field_name in index_fields:
            vector_index_collections.append(f"{subclass.__name__}_{field_name}")

    # Empty collection list falls back to None (engine-wide search).
    found_triplets = await brute_force_triplet_search(
        query, top_k=5, collections=vector_index_collections or None
    )

    if len(found_triplets) == 0:
        raise NoRelevantDataFound

    if not context_resolver:
        context_resolver = retrieved_edges_to_string

    # Get context and optionally dump it
    context = await context_resolver(found_triplets)
    if save_context_path:
        try:
            os.makedirs(os.path.dirname(save_context_path), exist_ok=True)
            with open(save_context_path, "w") as f:
                json.dump(context, f, indent=2)
        except (OSError, TypeError, ValueError) as e:
            # Dumping the context is best-effort; failures are logged, not raised.
            logger.error(f"Failed to save context to {save_context_path}: {str(e)}")
            # Consider whether to raise or continue silently

    args = {
        "question": query,
        "context": context,
    }

    user_prompt = render_prompt("graph_context_for_question.txt", args)
    system_prompt = read_query_prompt("answer_simple_question.txt")

    llm_client = get_llm_client()
    computed_answer = await llm_client.acreate_structured_output(
        text_input=user_prompt,
        system_prompt=system_prompt,
        response_model=str,
    )

    return [computed_answer]

View file

@ -1,30 +0,0 @@
# TODO: delete after merging COG-1365, see COG-1403
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.tasks.completion.graph_query_completion import (
graph_query_completion,
retrieved_edges_to_string,
)
async def retrieved_edges_to_summary(retrieved_edges: list) -> str:
    """
    Summarizes a list of retrieved graph edges into text without redundancies.
    """
    # Render the edges to plain text first, then ask the LLM to condense them.
    edges_text = await retrieved_edges_to_string(retrieved_edges)
    summary_prompt = read_query_prompt("summarize_search_results.txt")
    client = get_llm_client()
    return await client.acreate_structured_output(
        text_input=edges_text,
        system_prompt=summary_prompt,
        response_model=str,
    )
async def graph_query_summary_completion(query: str, save_context_path: str = None) -> list:
    """Executes a query on the graph database and retrieves a summarized completion with optional context saving."""
    # Delegates to graph_query_completion, swapping in the summarizing context
    # resolver so the retrieved edges are condensed before answering.
    return await graph_query_completion(
        query, context_resolver=retrieved_edges_to_summary, save_context_path=save_context_path
    )

View file

@ -1,63 +0,0 @@
# TODO: delete after merging COG-1365, see COG-1403
import json
import logging
import os
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.tasks.completion.exceptions import NoRelevantDataFound
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
logger = logging.getLogger(__name__)
async def query_completion(query: str, save_context_path: str = None) -> list:
    """
    Executes a query against a vector database and computes a relevant response using an LLM.

    Parameters:
    - query (str): The query string to compute.
    - save_context_path (str): The path to save the context.

    Returns:
    - list: Answer to the query (single-element list).

    Raises:
    - NoRelevantDataFound: If no chunk matches the query.

    Notes:
    - Limits the search to the top 1 matching chunk for simplicity and relevance.
    - Ensure that the vector database and LLM client are properly configured and accessible.
    - The response model used for the LLM output is expected to be a string.
    """
    vector_engine = get_vector_engine()
    found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=1)

    if len(found_chunks) == 0:
        raise NoRelevantDataFound

    # Get context and optionally dump it
    context = found_chunks[0].payload["text"]
    if save_context_path:
        try:
            os.makedirs(os.path.dirname(save_context_path), exist_ok=True)
            with open(save_context_path, "w", encoding="utf-8") as f:
                json.dump(context, f, indent=2, ensure_ascii=False)
        except OSError as e:
            logger.error(f"Failed to save context to {save_context_path}: {str(e)}")
            # Continue execution as context saving is optional

    args = {
        "question": query,
        "context": context,
    }

    user_prompt = render_prompt("context_for_question.txt", args)
    system_prompt = read_query_prompt("answer_simple_question.txt")

    llm_client = get_llm_client()
    computed_answer = await llm_client.acreate_structured_output(
        text_input=user_prompt,
        system_prompt=system_prompt,
        response_model=str,
    )

    return [computed_answer]

View file

@ -1,3 +1,2 @@
from .extract_graph_from_data import extract_graph_from_data from .extract_graph_from_data import extract_graph_from_data
from .extract_graph_from_code import extract_graph_from_code from .extract_graph_from_code import extract_graph_from_code
from .query_graph_connections import query_graph_connections

View file

@ -1,62 +0,0 @@
# TODO: delete after merging COG-1365, see COG-1403
import asyncio
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
async def query_graph_connections(query: str, exploration_levels=1) -> list[tuple[str, str, str]]:
    """
    Find the neighbours of a given node in the graph and return formed sentences.

    Parameters:
    - query (str): The query string to filter nodes by.
    - exploration_levels (int): The number of jumps through edges to perform.
      NOTE(review): currently unused — only a single hop is performed.

    Returns:
    - list[tuple[str, str, str]]: A list containing the source and destination nodes and relationship.
    """
    if query is None:
        return []

    node_id = query
    graph_engine = await get_graph_engine()
    exact_node = await graph_engine.extract_node(node_id)

    if exact_node is not None and "id" in exact_node:
        # The query matched a node id directly; expand from that node.
        node_connections = await graph_engine.get_connections(str(exact_node["id"]))
    else:
        # No exact match: fall back to vector search over entity and type names.
        vector_engine = get_vector_engine()
        results = await asyncio.gather(
            vector_engine.search("Entity_name", query_text=query, limit=5),
            vector_engine.search("EntityType_name", query_text=query, limit=5),
        )
        results = [*results[0], *results[1]]
        # Lower score means a closer match; keep at most five close hits.
        relevant_results = [result for result in results if result.score < 0.5][:5]

        if len(relevant_results) == 0:
            return []

        node_connections_results = await asyncio.gather(
            *[graph_engine.get_connections(result.id) for result in relevant_results]
        )

        node_connections = []
        for neighbours in node_connections_results:
            node_connections.extend(neighbours)

    # Deduplicate triplets by "(source id, relationship name, target id)".
    unique_node_connections_map = {}
    unique_node_connections = []
    for node_connection in node_connections:
        if "id" not in node_connection[0] or "id" not in node_connection[2]:
            continue

        unique_id = f"{node_connection[0]['id']} {node_connection[1]['relationship_name']} {node_connection[2]['id']}"
        if unique_id not in unique_node_connections_map:
            unique_node_connections_map[unique_id] = True
            unique_node_connections.append(node_connection)

    return unique_node_connections

View file

@ -1,3 +1,2 @@
from .query_summaries import query_summaries
from .summarize_code import summarize_code from .summarize_code import summarize_code
from .summarize_text import summarize_text from .summarize_text import summarize_text

View file

@ -1,19 +0,0 @@
# TODO: delete after merging COG-1365, see COG-1403
from cognee.infrastructure.databases.vector import get_vector_engine
async def query_summaries(query: str) -> list:
    """
    Retrieve text summaries related to *query* from the vector database.

    Parameters:
    - query (str): The query string to filter summaries by.

    Returns:
    - list: Payloads of the top five matching summaries.
    """
    engine = get_vector_engine()
    hits = await engine.search("TextSummary_text", query, limit=5)
    return [hit.payload for hit in hits]

View file

@ -2,7 +2,7 @@ import cognee
from cognee.modules.search.types import SearchType from cognee.modules.search.types import SearchType
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from cognee.tasks.completion.graph_query_completion import retrieved_edges_to_string from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from functools import partial from functools import partial
from cognee.api.v1.cognify.cognify_v2 import get_default_tasks from cognee.api.v1.cognify.cognify_v2 import get_default_tasks
import logging import logging
@ -122,7 +122,8 @@ async def get_context_with_brute_force_triplet_search(instance: dict) -> str:
found_triplets = await brute_force_triplet_search(instance["question"], top_k=5) found_triplets = await brute_force_triplet_search(instance["question"], top_k=5)
search_results_str = await retrieved_edges_to_string(found_triplets) retriever = GraphCompletionRetriever()
search_results_str = await retriever.resolve_edges_to_text(found_triplets)
return search_results_str return search_results_str

View file

@ -12,7 +12,7 @@ from cognee.tasks.temporal_awareness.index_graphiti_objects import (
index_and_transform_graphiti_nodes_and_edges, index_and_transform_graphiti_nodes_and_edges,
) )
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from cognee.tasks.completion.graph_query_completion import retrieved_edges_to_string from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.get_llm_client import get_llm_client
@ -49,9 +49,12 @@ async def main():
collections=["graphitinode_content", "graphitinode_name", "graphitinode_summary"], collections=["graphitinode_content", "graphitinode_name", "graphitinode_summary"],
) )
retriever = GraphCompletionRetriever()
context = await retriever.resolve_edges_to_text(triplets)
args = { args = {
"question": query, "question": query,
"context": await retrieved_edges_to_string(triplets), "context": context,
} }
user_prompt = render_prompt("graph_context_for_question.txt", args) user_prompt = render_prompt("graph_context_for_question.txt", args)

View file

@ -37,7 +37,7 @@
" index_and_transform_graphiti_nodes_and_edges,\n", " index_and_transform_graphiti_nodes_and_edges,\n",
")\n", ")\n",
"from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search\n", "from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search\n",
"from cognee.tasks.completion.graph_query_completion import retrieved_edges_to_string\n", "from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever\n",
"from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt\n", "from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt\n",
"from cognee.infrastructure.llm.get_llm_client import get_llm_client" "from cognee.infrastructure.llm.get_llm_client import get_llm_client"
] ]
@ -186,7 +186,8 @@
")\n", ")\n",
"\n", "\n",
"# Step 3: Preparing the Context for the LLM\n", "# Step 3: Preparing the Context for the LLM\n",
"context = await retrieved_edges_to_string(triplets)\n", "retriever = GraphCompletionRetriever()\n",
"context = await retriever.resolve_edges_to_text(triplets)\n",
"\n", "\n",
"args = {\"question\": query, \"context\": context}\n", "args = {\"question\": query, \"context\": context}\n",
"\n", "\n",