cognee/cognee/modules/retrieval/brute_force_triplet_search.py
Boris f9e6dcf837
fix: simplify code pipeline (#529)
<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit


- **New Features**
  - Enhanced code search and dependency analysis for improved accuracy.
  - Introduced a new high-performance text embedding option.
  - Added an additional execution entry point for code graph processing.
  - Added new optional parameters for flexible property selection in
    retrieval functions.
  - Introduced new classes for handling import statements, function
    definitions, and class definitions.
  - Updated embedding engine selection based on configuration options.

- **Bug Fixes**
  - Improved error handling in search operations and database queries for
    a more stable user experience.
  - Enhanced error logging for source code parsing.

- **Refactor**
  - Streamlined asynchronous processing and refactored internal dependency
    extraction.
  - Updated configuration and integration settings to enhance overall
    reliability.
  - Restructured functions for simplified dependency handling.

- **Chores**
  - Upgraded and reorganized dependency management with optional libraries
    for extended functionality.
  - Added new secret parameters for embedding configuration in workflow
    settings.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: vasilije <vas.markovic@gmail.com>
2025-02-12 23:58:48 +01:00

151 lines
5.1 KiB
Python

import asyncio
import logging
from typing import List
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User
from cognee.shared.utils import send_telemetry
def format_triplets(edges):
    """Render graph edges as human-readable ``Node1 / Edge / Node2`` text blocks.

    Each edge must expose ``node1``, ``node2`` and an ``attributes`` mapping,
    and each node must expose an ``attributes`` mapping. Attributes whose value
    is ``None`` are left out of the rendered text.

    Args:
        edges: Iterable of edge objects shaped as described above.

    Returns:
        str: One block per edge, each terminated by three newlines; an empty
        string when ``edges`` is empty.
    """

    def _non_none(attributes):
        # Keep only properties that are actually set, so the rendered text is compact.
        return {key: value for key, value in attributes.items() if value is not None}

    triplets = []
    for edge in edges:
        node1_info = _non_none(edge.node1.attributes)
        node2_info = _non_none(edge.node2.attributes)
        edge_info = _non_none(edge.attributes)
        triplets.append(f"Node1: {node1_info}\nEdge: {edge_info}\nNode2: {node2_info}\n\n\n")
    return "".join(triplets)
async def brute_force_triplet_search(
    query: str,
    user: User = None,
    top_k: int = 5,
    collections: List[str] = None,
    properties_to_project: List[str] = None,
) -> list:
    """Public entry point for the brute-force triplet search.

    Falls back to ``get_default_user()`` when no ``user`` is given, then
    delegates the actual work to ``brute_force_search``.

    Args:
        query: The search query.
        user: The user performing the search; ``None`` means "use the default user".
        top_k: Number of top triplets to return.
        collections: Vector collections to query; ``None`` lets
            ``brute_force_search`` apply its defaults.
        properties_to_project: Node properties to project; ``None`` lets
            ``brute_force_search`` apply its defaults.

    Returns:
        list: Whatever ``brute_force_search`` returns.

    Raises:
        PermissionError: If no user was given and no default user exists.
    """
    acting_user = user if user is not None else await get_default_user()

    if acting_user is None:
        raise PermissionError("No user found in the system. Please create a user.")

    return await brute_force_search(
        query,
        acting_user,
        top_k,
        collections=collections,
        properties_to_project=properties_to_project,
    )
async def brute_force_search(
    query: str,
    user: User,
    top_k: int,
    collections: List[str] = None,
    properties_to_project: List[str] = None,
) -> list:
    """
    Performs a brute force search to retrieve the top triplets from the graph.

    Args:
        query (str): The search query.
        user (User): The user performing the search.
        top_k (int): The number of top results to retrieve.
        collections (Optional[List[str]]): List of collections to query. Defaults to
            Entity_name, TextSummary_text, EntityType_name and DocumentChunk_text.
        properties_to_project (Optional[List[str]]): Node properties to project from
            the graph database. Falls back to id, description, name, type and text
            when None or empty.

    Returns:
        list: The top triplet results.

    Raises:
        ValueError: If ``query`` is not a non-empty string or ``top_k`` is not positive.
        RuntimeError: If the vector/graph engines cannot be initialized, or if any step
            of the search fails (the original exception is chained as ``__cause__``).
    """
    if not query or not isinstance(query, str):
        raise ValueError("The query must be a non-empty string.")
    if top_k <= 0:
        raise ValueError("top_k must be a positive integer.")

    if collections is None:
        collections = [
            "Entity_name",
            "TextSummary_text",
            "EntityType_name",
            "DocumentChunk_text",
        ]

    try:
        vector_engine = get_vector_engine()
        graph_engine = await get_graph_engine()
    except Exception as e:
        logging.error("Failed to initialize engines: %s", e)
        raise RuntimeError("Initialization error") from e

    send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id)

    try:
        # Query every collection concurrently. asyncio.gather returns results in
        # the same order as its arguments, so they can be zipped back to names.
        distances_per_collection = await asyncio.gather(
            *[
                vector_engine.get_distance_from_collection_elements(collection, query_text=query)
                for collection in collections
            ]
        )
        node_distances = dict(zip(collections, distances_per_collection))

        memory_fragment = CogneeGraph()
        await memory_fragment.project_graph_from_db(
            graph_engine,
            # An empty list also falls back to the defaults (``or`` semantics).
            node_properties_to_project=properties_to_project
            or ["id", "description", "name", "type", "text"],
            edge_properties_to_project=["relationship_name"],
        )

        await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
        await memory_fragment.map_vector_distances_to_graph_edges(vector_engine, query)

        top_triplets = await memory_fragment.calculate_top_triplet_importances(k=top_k)

        # Report completion; this previously re-sent the "STARTED" event by mistake.
        send_telemetry("cognee.brute_force_triplet_search EXECUTION COMPLETED", user.id)

        return top_triplets
    except Exception as error:
        logging.error(
            "Error during brute force search for user: %s, query: %s. Error: %s",
            user.id,
            query,
            error,
        )
        send_telemetry(
            "cognee.brute_force_triplet_search EXECUTION FAILED", user.id, {"error": str(error)}
        )
        raise RuntimeError("An error occurred during brute force search") from error