diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py index 6cbe45655..0b68cb3b7 100644 --- a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py @@ -142,12 +142,11 @@ class LanceDBAdapter(VectorDBInterface): score = 0, ) for result in results.to_dict("index").values()] - async def get_distances_of_collection( - self, - collection_name: str, - query_text: str = None, - query_vector: List[float] = None, - with_vector: bool = False + async def get_distance_from_collection_elements( + self, + collection_name: str, + query_text: str = None, + query_vector: List[float] = None ): if query_text is None and query_vector is None: raise ValueError("One of query_text or query_vector must be provided!") diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index 97571a274..fd0fd493c 100644 --- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -176,7 +176,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): ) for result in results ] - async def get_distances_of_collection( + async def get_distance_from_collection_elements( self, collection_name: str, query_text: str = None, @@ -192,8 +192,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): # Get PGVectorDataPoint Table from database PGVectorDataPoint = await self.get_table(collection_name) - closest_items = [] - # Use async session to connect to the database async with self.get_async_session() as session: # Find closest vectors to query_vector diff --git a/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py b/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py index 1efcd47b3..c340928f4 100644 --- a/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py +++ b/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py @@ -142,6 +142,41 @@ class QDrantAdapter(VectorDBInterface): await client.close() return results + async def get_distance_from_collection_elements( + self, + collection_name: str, + query_text: str = None, + query_vector: List[float] = None, + with_vector: bool = False + ) -> List[ScoredResult]: + + if query_text is None and query_vector is None: + raise ValueError("One of query_text or query_vector must be provided!") + + client = self.get_qdrant_client() + + results = await client.search( + collection_name = collection_name, + query_vector = models.NamedVector( + name = "text", + vector = query_vector if query_vector is not None else (await self.embed_data([query_text]))[0], + ), + with_vectors = with_vector + ) + + await client.close() + + return [ + ScoredResult( + id = UUID(result.id), + payload = { + **result.payload, + "id": UUID(result.id), + }, + score = 1 - result.score, + ) for result in results + ] + async def search( self, collection_name: str, diff --git a/cognee/infrastructure/databases/vector/utils.py b/cognee/infrastructure/databases/vector/utils.py index ced161ea3..d5a5897a3 100644 --- a/cognee/infrastructure/databases/vector/utils.py +++ b/cognee/infrastructure/databases/vector/utils.py @@ -1,18 +1,6 @@ from typing import List - def normalize_distances(result_values: List[dict]) -> List[float]: - min_value = 100 - max_value = 0 - - for result in result_values: - value = float(result["_distance"]) - if value > max_value: - max_value = value - if value < min_value: - min_value = value - - normalized_values = [] min_value = min(result["_distance"] for result in result_values) max_value = max(result["_distance"] for result in result_values) @@ -23,4 +11,4 @@ def normalize_distances(result_values: List[dict]) -> List[float]: normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in result_values] - return normalized_values \ No newline at end of file + return normalized_values diff --git a/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py b/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py index be356740f..c9848e02c 100644 --- a/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py +++ b/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py @@ -153,6 +153,36 @@ class WeaviateAdapter(VectorDBInterface): return await future + async def get_distance_from_collection_elements( + self, + collection_name: str, + query_text: str = None, + query_vector: List[float] = None, + with_vector: bool = False + ) -> List[ScoredResult]: + import weaviate.classes as wvc + + if query_text is None and query_vector is None: + raise ValueError("One of query_text or query_vector must be provided!") + + if query_vector is None: + query_vector = (await self.embed_data([query_text]))[0] + + search_result = self.get_collection(collection_name).query.hybrid( + query=None, + vector=query_vector, + include_vector=with_vector, + return_metadata=wvc.query.MetadataQuery(score=True), + ) + + return [ + ScoredResult( + id=UUID(str(result.uuid)), + payload=result.properties, + score=1 - float(result.metadata.score) + ) for result in search_result.objects + ] + async def search( self, collection_name: str, diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index 158fb9d07..edc449db4 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -42,7 +42,7 @@ class CogneeGraph(CogneeAbstractGraph): def get_node(self, node_id: str) -> Node: return self.nodes.get(node_id, None) - def get_edges_of_node(self, node_id: str) -> List[Edge]: + def get_edges_from_node(self, node_id: str) -> List[Edge]: node = self.get_node(node_id) if node: return node.skeleton_edges @@ -50,16 +50,18 @@ class CogneeGraph(CogneeAbstractGraph): raise ValueError(f"Node with id {node_id} does not exist.") def get_edges(self)-> List[Edge]: - return edges + return self.edges - async def project_graph_from_db(self, - adapter: Union[GraphDBInterface], - node_properties_to_project: List[str], - edge_properties_to_project: List[str], - directed = True, - node_dimension = 1, - edge_dimension = 1, - memory_fragment_filter = []) -> None: + async def project_graph_from_db( + self, + adapter: Union[GraphDBInterface], + node_properties_to_project: List[str], + edge_properties_to_project: List[str], + directed = True, + node_dimension = 1, + edge_dimension = 1, + memory_fragment_filter = [], + ) -> None: if node_dimension < 1 or edge_dimension < 1: raise ValueError("Dimensions must be positive integers") @@ -158,15 +160,15 @@ class CogneeGraph(CogneeAbstractGraph): print(f"Error mapping vector distances to edges: {ex}") - async def calculate_top_triplet_importances(self, k = int) -> List: + async def calculate_top_triplet_importances(self, k: int) -> List: min_heap = [] for i, edge in enumerate(self.edges): source_node = self.get_node(edge.node1.id) target_node = self.get_node(edge.node2.id) - source_distance = source_node.attributes.get("vector_distance", 0) if source_node else 0 - target_distance = target_node.attributes.get("vector_distance", 0) if target_node else 0 - edge_distance = edge.attributes.get("vector_distance", 0) + source_distance = source_node.attributes.get("vector_distance", 1) if source_node else 1 + target_distance = target_node.attributes.get("vector_distance", 1) if target_node else 1 + edge_distance = edge.attributes.get("vector_distance", 1) total_distance = source_distance + target_distance + edge_distance diff --git a/cognee/modules/retrieval/__init__.py b/cognee/modules/retrieval/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/cognee/modules/retrieval/brute_force_triplet_search.py b/cognee/modules/retrieval/brute_force_triplet_search.py new file mode 100644 index 000000000..0a4e9dea5 --- /dev/null +++ b/cognee/modules/retrieval/brute_force_triplet_search.py @@ -0,0 +1,150 @@ +import asyncio +import logging +from typing import List +from cognee.modules.users.models import User +from cognee.modules.users.methods import get_default_user +from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.shared.utils import send_telemetry + +def format_triplets(edges): + print("\n\n\n") + def filter_attributes(obj, attributes): + """Helper function to filter out non-None properties, including nested dicts.""" + result = {} + for attr in attributes: + value = getattr(obj, attr, None) + if value is not None: + # If the value is a dict, extract relevant keys from it + if isinstance(value, dict): + nested_values = {k: v for k, v in value.items() if k in attributes and v is not None} + result[attr] = nested_values + else: + result[attr] = value + return result + + triplets = [] + for edge in edges: + node1 = edge.node1 + node2 = edge.node2 + edge_attributes = edge.attributes + node1_attributes = node1.attributes + node2_attributes = node2.attributes + + # Filter only non-None properties + node1_info = {key: value for key, value in node1_attributes.items() if value is not None} + node2_info = {key: value for key, value in node2_attributes.items() if value is not None} + edge_info = {key: value for key, value in edge_attributes.items() if value is not None} + + # Create the formatted triplet + triplet = ( + f"Node1: {node1_info}\n" + f"Edge: {edge_info}\n" + f"Node2: {node2_info}\n\n\n" + ) + triplets.append(triplet) + + return "".join(triplets) + + +async def brute_force_triplet_search(query: str, user: User = None, top_k = 5) -> list: + if user is None: + user = await get_default_user() + + if user is None: + raise PermissionError("No user found in the system. Please create a user.") + + retrieved_results = await brute_force_search(query, user, top_k) + + + return retrieved_results + + +def delete_duplicated_vector_db_elements(collections, results): #:TODO: This is just for now to fix vector db duplicates + results_dict = {} + for collection, results in zip(collections, results): + seen_ids = set() + unique_results = [] + for result in results: + if result.id not in seen_ids: + unique_results.append(result) + seen_ids.add(result.id) + else: + print(f"Duplicate found in collection '{collection}': {result.id}") + results_dict[collection] = unique_results + + return results_dict + + +async def brute_force_search( + query: str, + user: User, + top_k: int, + collections: List[str] = None +) -> list: + """ + Performs a brute force search to retrieve the top triplets from the graph. + + Args: + query (str): The search query. + user (User): The user performing the search. + top_k (int): The number of top results to retrieve. + collections (Optional[List[str]]): List of collections to query. Defaults to predefined collections. + + Returns: + list: The top triplet results. + """ + if not query or not isinstance(query, str): + raise ValueError("The query must be a non-empty string.") + if top_k <= 0: + raise ValueError("top_k must be a positive integer.") + + if collections is None: + collections = ["entity_name", "text_summary_text", "entity_type_name", "document_chunk_text"] + + try: + vector_engine = get_vector_engine() + graph_engine = await get_graph_engine() + except Exception as e: + logging.error("Failed to initialize engines: %s", e) + raise RuntimeError("Initialization error") from e + + send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id) + + try: + results = await asyncio.gather( + *[vector_engine.get_distance_from_collection_elements(collection, query_text=query) for collection in collections] + ) + + ############################################# :TODO: Change when vector db does not contain duplicates + node_distances = delete_duplicated_vector_db_elements(collections, results) + # node_distances = {collection: result for collection, result in zip(collections, results)} + ############################################## + + memory_fragment = CogneeGraph() + + await memory_fragment.project_graph_from_db(graph_engine, + node_properties_to_project=['id', + 'description', + 'name', + 'type', + 'text'], + edge_properties_to_project=['relationship_name']) + + await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances) + + #:TODO: Change when vectordb contains edge embeddings + await memory_fragment.map_vector_distances_to_graph_edges(vector_engine, query) + + results = await memory_fragment.calculate_top_triplet_importances(k=top_k) + + send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id) + + #:TODO: Once we have Edge pydantic models we should retrieve the exact edge and node objects from graph db + return results + + except Exception as e: + logging.error("Error during brute force search for user: %s, query: %s. Error: %s", user.id, query, e) + send_telemetry("cognee.brute_force_triplet_search EXECUTION FAILED", user.id) + raise RuntimeError("An error occurred during brute force search") from e diff --git a/cognee/tests/test_neo4j.py b/cognee/tests/test_neo4j.py index 02f3eaccd..92e5b5f05 100644 --- a/cognee/tests/test_neo4j.py +++ b/cognee/tests/test_neo4j.py @@ -4,6 +4,7 @@ import logging import pathlib import cognee from cognee.api.v1.search import SearchType +from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search logging.basicConfig(level=logging.DEBUG) @@ -61,6 +62,9 @@ async def main(): assert len(history) == 6, "Search history is not correct." + results = await brute_force_triplet_search('What is a quantum computer?') + assert len(results) > 0 + await cognee.prune.prune_data() assert not os.path.isdir(data_directory_path), "Local data files are not deleted" diff --git a/cognee/tests/test_pgvector.py b/cognee/tests/test_pgvector.py index bd6584cbc..3b4fa19c5 100644 --- a/cognee/tests/test_pgvector.py +++ b/cognee/tests/test_pgvector.py @@ -3,6 +3,7 @@ import logging import pathlib import cognee from cognee.api.v1.search import SearchType +from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search logging.basicConfig(level=logging.DEBUG) @@ -89,6 +90,9 @@ async def main(): history = await cognee.get_search_history() assert len(history) == 6, "Search history is not correct." + results = await brute_force_triplet_search('What is a quantum computer?') + assert len(results) > 0 + await cognee.prune.prune_data() assert not os.path.isdir(data_directory_path), "Local data files are not deleted" diff --git a/cognee/tests/test_qdrant.py b/cognee/tests/test_qdrant.py index 4c2462c3b..f32e0b4a4 100644 --- a/cognee/tests/test_qdrant.py +++ b/cognee/tests/test_qdrant.py @@ -5,6 +5,7 @@ import logging import pathlib import cognee from cognee.api.v1.search import SearchType +from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search logging.basicConfig(level=logging.DEBUG) @@ -61,6 +62,9 @@ async def main(): history = await cognee.get_search_history() assert len(history) == 6, "Search history is not correct." + results = await brute_force_triplet_search('What is a quantum computer?') + assert len(results) > 0 + await cognee.prune.prune_data() assert not os.path.isdir(data_directory_path), "Local data files are not deleted" diff --git a/cognee/tests/test_weaviate.py b/cognee/tests/test_weaviate.py index c352df13e..43ec30aaf 100644 --- a/cognee/tests/test_weaviate.py +++ b/cognee/tests/test_weaviate.py @@ -3,6 +3,7 @@ import logging import pathlib import cognee from cognee.api.v1.search import SearchType +from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search logging.basicConfig(level=logging.DEBUG) @@ -59,6 +60,9 @@ async def main(): history = await cognee.get_search_history() assert len(history) == 6, "Search history is not correct." + results = await brute_force_triplet_search('What is a quantum computer?') + assert len(results) > 0 + await cognee.prune.prune_data() assert not os.path.isdir(data_directory_path), "Local data files are not deleted" diff --git a/cognee/tests/unit/modules/graph/cognee_graph_test.py b/cognee/tests/unit/modules/graph/cognee_graph_test.py index bad474023..e3b748dab 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_test.py @@ -77,11 +77,11 @@ def test_get_edges_success(setup_graph): graph.add_node(node2) edge = Edge(node1, node2) graph.add_edge(edge) - assert edge in graph.get_edges_of_node("node1") + assert edge in graph.get_edges_from_node("node1") def test_get_edges_nonexistent_node(setup_graph): """Test retrieving edges for a nonexistent node raises an exception.""" graph = setup_graph with pytest.raises(ValueError, match="Node with id nonexistent does not exist."): - graph.get_edges_of_node("nonexistent") + graph.get_edges_from_node("nonexistent") diff --git a/examples/python/dynamic_steps_example.py b/examples/python/dynamic_steps_example.py index 49b41db1c..ed5c97561 100644 --- a/examples/python/dynamic_steps_example.py +++ b/examples/python/dynamic_steps_example.py @@ -1,6 +1,7 @@ import cognee import asyncio -from cognee.pipelines.retriever.two_steps_retriever import two_step_retriever +from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search +from cognee.modules.retrieval.brute_force_triplet_search import format_triplets job_1 = """ CV 1: Relevant @@ -181,8 +182,8 @@ async def main(enable_steps): # Step 4: Query insights if enable_steps.get("retriever"): - await two_step_retriever('Who has Phd?') - + results = await brute_force_triplet_search('Who has the most experience with graphic design?') + print(format_triplets(results)) if __name__ == '__main__': # Flags to enable/disable steps