<!-- .github/pull_request_template.md --> ## Description <!-- Provide a clear description of the changes in this PR --> ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Enhanced code search and dependency analysis for improved accuracy. - Introduced a new high-performance text embedding option. - Added an additional execution entry point for code graph processing. - New optional parameters for flexible property selection in retrieval functions. - Introduced new classes for handling import statements, function definitions, and class definitions. - Updated embedding engine selection based on configuration options. - **Bug Fixes** - Improved error handling in search operations and database queries for a more stable user experience. - Enhanced error logging for source code parsing. - **Refactor** - Streamlined asynchronous processing and refactored internal dependency extraction. - Updated configuration and integration settings to enhance overall reliability. - Restructured functions for simplified dependency handling. - **Chores** - Upgraded and reorganized dependency management with optional libraries for extended functionality. - Added new secret parameters for embedding configuration in workflow settings. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: vasilije <vas.markovic@gmail.com>
151 lines
5.1 KiB
Python
151 lines
5.1 KiB
Python
import asyncio
|
|
import logging
|
|
from typing import List
|
|
|
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
|
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
|
from cognee.modules.users.methods import get_default_user
|
|
from cognee.modules.users.models import User
|
|
from cognee.shared.utils import send_telemetry
|
|
|
|
|
|
def format_triplets(edges):
    """Render graph edges as a plain-text listing of (node1, edge, node2) triplets.

    Every edge must expose ``node1``, ``node2`` and ``attributes``; each node
    must expose ``attributes``. All three ``attributes`` objects are read via
    ``.items()``, and entries whose value is ``None`` are left out of the
    rendered output.

    Args:
        edges: Iterable of edge objects as described above.

    Returns:
        str: One ``Node1: ... / Edge: ... / Node2: ...`` section per edge, each
        terminated by three newlines, concatenated in input order. An empty
        string when ``edges`` is empty.
    """
    # NOTE(review): emits blank lines to stdout on every call. Kept so the
    # observable behaviour is unchanged; looks like debug residue - confirm
    # whether any caller relies on it before removing.
    print("\n\n\n")

    def _without_none(attributes):
        """Return a copy of ``attributes`` without the ``None``-valued entries."""
        return {key: value for key, value in attributes.items() if value is not None}

    return "".join(
        f"Node1: {_without_none(edge.node1.attributes)}\n"
        f"Edge: {_without_none(edge.attributes)}\n"
        f"Node2: {_without_none(edge.node2.attributes)}\n\n\n"
        for edge in edges
    )
|
|
|
|
|
|
async def brute_force_triplet_search(
    query: str,
    user: User = None,
    top_k: int = 5,
    collections: List[str] = None,
    properties_to_project: List[str] = None,
) -> list:
    """Resolve the acting user, then delegate to ``brute_force_search``.

    Args:
        query (str): The search query, forwarded unchanged.
        user (User): The user performing the search. When ``None``, the result
            of ``get_default_user()`` is used instead.
        top_k (int): The number of top results to retrieve.
        collections (Optional[List[str]]): Forwarded to ``brute_force_search``.
        properties_to_project (Optional[List[str]]): Forwarded to
            ``brute_force_search``.

    Returns:
        list: Whatever ``brute_force_search`` returns.

    Raises:
        PermissionError: If no user was passed and ``get_default_user()`` also
            yields ``None``.
    """
    # Only hit the default-user lookup when the caller did not supply a user.
    acting_user = user if user is not None else await get_default_user()

    if acting_user is None:
        raise PermissionError("No user found in the system. Please create a user.")

    return await brute_force_search(
        query,
        acting_user,
        top_k,
        collections=collections,
        properties_to_project=properties_to_project,
    )
|
|
|
|
|
|
async def brute_force_search(
    query: str,
    user: User,
    top_k: int,
    collections: List[str] = None,
    properties_to_project: List[str] = None,
) -> list:
    """
    Performs a brute force search to retrieve the top triplets from the graph.

    Args:
        query (str): The search query.
        user (User): The user performing the search.
        top_k (int): The number of top results to retrieve.
        collections (Optional[List[str]]): List of collections to query. Defaults to predefined collections.
        properties_to_project (Optional[List[str]]): Node properties to load when projecting
            the graph. When ``None`` or empty, falls back to
            ``["id", "description", "name", "type", "text"]``.

    Returns:
        list: The top triplet results.

    Raises:
        ValueError: If ``query`` is empty or not a string, or ``top_k`` is not positive.
        RuntimeError: If the vector or graph engine cannot be obtained, or if any
            step of the search itself fails (the original exception is chained).
    """
    if not query or not isinstance(query, str):
        raise ValueError("The query must be a non-empty string.")
    if top_k <= 0:
        raise ValueError("top_k must be a positive integer.")

    if collections is None:
        collections = [
            "Entity_name",
            "TextSummary_text",
            "EntityType_name",
            "DocumentChunk_text",
        ]

    try:
        vector_engine = get_vector_engine()
        graph_engine = await get_graph_engine()
    except Exception as e:
        logging.error("Failed to initialize engines: %s", e)
        raise RuntimeError("Initialization error") from e

    send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id)

    try:
        # Query every collection concurrently; gather keeps the input order,
        # so results can be paired back with their collection names below.
        results = await asyncio.gather(
            *[
                vector_engine.get_distance_from_collection_elements(collection, query_text=query)
                for collection in collections
            ]
        )

        node_distances = dict(zip(collections, results))

        memory_fragment = CogneeGraph()

        await memory_fragment.project_graph_from_db(
            graph_engine,
            node_properties_to_project=properties_to_project
            or ["id", "description", "name", "type", "text"],
            edge_properties_to_project=["relationship_name"],
        )

        await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)

        await memory_fragment.map_vector_distances_to_graph_edges(vector_engine, query)

        top_triplets = await memory_fragment.calculate_top_triplet_importances(k=top_k)

        # Report completion; previously this re-sent the "STARTED" event, which
        # made successful runs indistinguishable from merely started ones.
        send_telemetry("cognee.brute_force_triplet_search EXECUTION COMPLETED", user.id)

        return top_triplets

    except Exception as error:
        logging.error(
            "Error during brute force search for user: %s, query: %s. Error: %s",
            user.id,
            query,
            error,
        )
        send_telemetry(
            "cognee.brute_force_triplet_search EXECUTION FAILED", user.id, {"error": str(error)}
        )
        raise RuntimeError("An error occurred during brute force search") from error
|