From ee67af5562547a4adee4641e77dedefbbf39ff92 Mon Sep 17 00:00:00 2001 From: Vasilije <8619304+Vasilije1990@users.noreply.github.com> Date: Wed, 13 Mar 2024 14:44:47 +0100 Subject: [PATCH] Fix poetry deps --- cognitive_architecture/__init__.py | 1 + .../api/v1/search/search.py | 39 +++++++++++++++++-- .../search/vector/search_similarity.py | 9 ++++- 3 files changed, 45 insertions(+), 4 deletions(-) diff --git a/cognitive_architecture/__init__.py b/cognitive_architecture/__init__.py index 4b290afe5..1720f6f8e 100644 --- a/cognitive_architecture/__init__.py +++ b/cognitive_architecture/__init__.py @@ -1,3 +1,4 @@ from .api.v1.add.add import add from .api.v1.cognify.cognify import cognify from .api.v1.list_datasets.list_datasets import list_datasets +from .api.v1.search import search diff --git a/cognitive_architecture/api/v1/search/search.py b/cognitive_architecture/api/v1/search/search.py index 5ab34e5c2..30e970c2d 100644 --- a/cognitive_architecture/api/v1/search/search.py +++ b/cognitive_architecture/api/v1/search/search.py @@ -1,10 +1,13 @@ """ This module contains the search function that is used to search for nodes in the graph.""" from enum import Enum, auto from typing import Dict, Any, Callable, List + +from cognitive_architecture.infrastructure.databases.graph.get_graph_client import get_graph_client from cognitive_architecture.modules.search.graph.search_adjacent import search_adjacent from cognitive_architecture.modules.search.vector.search_similarity import search_similarity from cognitive_architecture.modules.search.graph.search_categories import search_categories from cognitive_architecture.modules.search.graph.search_neighbour import search_neighbour +from cognitive_architecture.shared.data_models import GraphDBType class SearchType(Enum): @@ -14,7 +17,7 @@ class SearchType(Enum): NEIGHBOR = auto() -def complex_search(graph, query_params: Dict[SearchType, Dict[str, Any]]) -> List: +async def complex_search(graph, query_params: Dict[SearchType, Dict[str, Any]]) -> List: search_functions: Dict[SearchType, Callable] = { SearchType.ADJACENT: search_adjacent, SearchType.SIMILARITY: search_similarity, @@ -24,10 +27,40 @@ def complex_search(graph, query_params: Dict[SearchType, Dict[str, Any]]) -> Lis results = set() + # Create a list to hold all the coroutine objects + search_tasks = [] + for search_type, params in query_params.items(): search_func = search_functions.get(search_type) if search_func: - search_result = search_func(graph, **params) - results.update(search_result) + # Schedule the coroutine for execution and store the task + task = search_func(graph, **params) + search_tasks.append(task) + + # Use asyncio.gather to run all scheduled tasks concurrently + search_results = await asyncio.gather(*search_tasks) + + # Update the results set with the results from all tasks + for search_result in search_results: + results.update(search_result) return list(results) + +if __name__ == "__main__": + import asyncio + + + + + query_params = { + SearchType.SIMILARITY: {'query': 'your search query here'} + } + async def main(): + graph_client = get_graph_client(GraphDBType.NETWORKX) + + await graph_client.load_graph_from_file() + graph = graph_client.graph + results = await complex_search(graph, query_params) + print(results) + + asyncio.run(main()) diff --git a/cognitive_architecture/modules/search/vector/search_similarity.py b/cognitive_architecture/modules/search/vector/search_similarity.py index ca78d8666..8f34c370b 100644 --- a/cognitive_architecture/modules/search/vector/search_similarity.py +++ b/cognitive_architecture/modules/search/vector/search_similarity.py @@ -1,7 +1,14 @@ from cognitive_architecture.infrastructure.llm.get_llm_client import get_llm_client +from cognitive_architecture.modules.cognify.graph.add_node_connections import extract_node_descriptions -async def search_similarity(query ,unique_layer_uuids): + +async def search_similarity(query:str ,graph): + + node_descriptions = await extract_node_descriptions(graph.nodes(data = True)) + # print(node_descriptions) + + unique_layer_uuids = set(node["layer_decomposition_uuid"] for node in node_descriptions) client = get_llm_client() out = []