From 3fadb277cbf0a44e70bb0186feb8bcdfa8d419b9 Mon Sep 17 00:00:00 2001 From: Vasilije <8619304+Vasilije1990@users.noreply.github.com> Date: Tue, 21 May 2024 19:34:18 +0200 Subject: [PATCH] Fixes for searches --- cognee/api/v1/cognify/cognify.py | 10 ++--- cognee/api/v1/search/search.py | 9 ++++- .../modules/search/graph/search_adjacent.py | 2 +- .../modules/search/graph/search_categories.py | 20 +++++----- cognee/modules/search/graph/search_cypher.py | 24 ++++++++++++ .../modules/search/graph/search_neighbour.py | 25 ++++++------ .../categorize_relevant_category.py | 1 - cognee/modules/search/vector/bm25.py | 1 + cognee/modules/search/vector/fusion.py | 1 + .../search/vector/search_similarity.py | 38 ++++++++++++------- evals/simple_rag_vs_cognee_eval.py | 5 ++- 11 files changed, 89 insertions(+), 47 deletions(-) create mode 100644 cognee/modules/search/graph/search_cypher.py create mode 100644 cognee/modules/search/vector/bm25.py create mode 100644 cognee/modules/search/vector/fusion.py diff --git a/cognee/api/v1/cognify/cognify.py b/cognee/api/v1/cognify/cognify.py index 2cac909c2..324f25ffe 100644 --- a/cognee/api/v1/cognify/cognify.py +++ b/cognee/api/v1/cognify/cognify.py @@ -111,11 +111,11 @@ async def cognify(datasets: Union[str, List[str]] = None): added__basic_rag_chunks = await add_data_chunks_basic_rag(data_chunks) - # await asyncio.gather( - # *[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"],chunk['document_id']) for chunk in - # added_chunks] - # ) - # + await asyncio.gather( + *[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"],chunk['document_id']) for chunk in + added_chunks] + ) + batch_size = 20 file_count = 0 files_batch = [] diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py index af8a7b728..4fd78e5a0 100644 --- a/cognee/api/v1/search/search.py +++ b/cognee/api/v1/search/search.py @@ -3,6 +3,8 @@ import asyncio from enum import Enum from typing import Dict, Any, Callable, List from pydantic import BaseModel, field_validator + +from cognee.modules.search.graph import search_cypher from cognee.modules.search.graph.search_adjacent import search_adjacent from cognee.modules.search.vector.search_similarity import search_similarity from cognee.modules.search.graph.search_categories import search_categories @@ -20,7 +22,8 @@ class SearchType(Enum): SUMMARY = 'SUMMARY' SUMMARY_CLASSIFICATION = 'SUMMARY_CLASSIFICATION' NODE_CLASSIFICATION = 'NODE_CLASSIFICATION' - DOCUMENT_CLASSIFICATION = 'DOCUMENT_CLASSIFICATION' + DOCUMENT_CLASSIFICATION = 'DOCUMENT_CLASSIFICATION', + CYPHER = 'CYPHER' @staticmethod def from_str(name: str): @@ -54,7 +57,9 @@ async def specific_search(query_params: List[SearchParameters]) -> List: SearchType.SIMILARITY: search_similarity, SearchType.CATEGORIES: search_categories, SearchType.NEIGHBOR: search_neighbour, - SearchType.SUMMARY: search_summary + SearchType.SUMMARY: search_summary, + SearchType.CYPHER: search_cypher + } results = [] diff --git a/cognee/modules/search/graph/search_adjacent.py b/cognee/modules/search/graph/search_adjacent.py index 8f3886305..97477a93a 100644 --- a/cognee/modules/search/graph/search_adjacent.py +++ b/cognee/modules/search/graph/search_adjacent.py @@ -22,7 +22,7 @@ async def search_adjacent(graph: Union[nx.Graph, any], query: str, infrastructur if node_id is None: return {} - + from cognee.infrastructure import infrastructure_config if infrastructure_config.get_config()["graph_engine"] == GraphDBType.NETWORKX: if node_id not in graph: return {} diff --git a/cognee/modules/search/graph/search_categories.py b/cognee/modules/search/graph/search_categories.py index 7447a3998..7ff07554e 100644 --- a/cognee/modules/search/graph/search_categories.py +++ b/cognee/modules/search/graph/search_categories.py @@ -1,6 +1,7 @@ from typing import Union, Dict import re +from pydantic import BaseModel from cognee.modules.search.llm.extraction.categorize_relevant_category import categorize_relevant_category @@ -15,6 +16,10 @@ def strip_exact_regex(s, substring): # Regex to match the exact substring at the start and end return re.sub(f"^{pattern}|{pattern}$", "", s) + +class DefaultResponseModel(BaseModel): + document_id: str + async def search_categories(query:str, graph: Union[nx.Graph, any], query_label: str=None, infrastructure_config: Dict=None): """ Filter nodes in the graph that contain the specified label and return their summary attributes. @@ -39,20 +44,15 @@ async def search_categories(query:str, graph: Union[nx.Graph, any], query_label: for _, data in graph.nodes(data=True) if 'summary' in data ] - print("summaries_and_ids", categories_and_ids) - check_relevant_category = await categorize_relevant_category(query, categories_and_ids, response_model= infrastructure_config.get_config()["classification_model"]) - print("check_relevant_summary", check_relevant_category) - + connected_nodes = [] + for id in categories_and_ids: + print("id", id) + connected_nodes.append(list(graph.neighbors(id['document_id']))) + check_relevant_category = await categorize_relevant_category(query, categories_and_ids, response_model=DefaultResponseModel ) connected_nodes = list(graph.neighbors(check_relevant_category['document_id'])) - print("connected_nodes", connected_nodes) descriptions = {node: graph.nodes[node].get('description', 'No desc available') for node in connected_nodes} - print("descs", descriptions) return descriptions - # - # # Logic for NetworkX - # return {node: data.get('content_labels') for node, data in graph.nodes(data=True) if query_label in node and 'content_labels' in data} - elif infrastructure_config.get_config()["graph_engine"] == GraphDBType.NEO4J: # Logic for Neo4j cypher_query = """ diff --git a/cognee/modules/search/graph/search_cypher.py b/cognee/modules/search/graph/search_cypher.py new file mode 100644 index 000000000..1022004c7 --- /dev/null +++ b/cognee/modules/search/graph/search_cypher.py @@ -0,0 +1,24 @@ + +from typing import Union, Dict +import re + +import networkx as nx +from pydantic import BaseModel + +from cognee.modules.search.llm.extraction.categorize_relevant_category import categorize_relevant_category +from cognee.shared.data_models import GraphDBType + + +async def search_cypher(query:str, graph: Union[nx.Graph, any]): + """ + Use a Cypher query to search the graph and return the results. + """ + + + from cognee.infrastructure import infrastructure_config + if infrastructure_config.get_config()["graph_engine"] == GraphDBType.NEO4J: + result = await graph.run(query) + return result + + else: + raise ValueError("Unsupported graph engine type.") \ No newline at end of file diff --git a/cognee/modules/search/graph/search_neighbour.py b/cognee/modules/search/graph/search_neighbour.py index fc7b55df8..9faf5ec30 100644 --- a/cognee/modules/search/graph/search_neighbour.py +++ b/cognee/modules/search/graph/search_neighbour.py @@ -1,11 +1,13 @@ """ Fetches the context of a given node in the graph""" from typing import Union, Dict +from neo4j import AsyncSession + from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client import networkx as nx from cognee.shared.data_models import GraphDBType -async def search_neighbour(graph: Union[nx.Graph, any], id: str, infrastructure_config: Dict, +async def search_neighbour(graph: Union[nx.Graph, any], node_id: str, other_param: dict = None): """ Search for nodes that share the same 'layer_uuid' as the specified node and return their descriptions. @@ -20,26 +22,23 @@ async def search_neighbour(graph: Union[nx.Graph, any], id: str, infrastructure_ Returns: - List[str]: A list of 'description' attributes of nodes that share the same 'layer_uuid' with the specified node. """ - node_id = other_param.get('node_id') if other_param else None + from cognee.infrastructure import infrastructure_config + if node_id is None: + node_id = other_param.get('node_id') if other_param else None if node_id is None: return [] if infrastructure_config.get_config()["graph_engine"] == GraphDBType.NETWORKX: - if isinstance(graph, nx.Graph): - if node_id not in graph: - return [] + relevant_context = [] + target_layer_uuid = graph.nodes[node_id].get('layer_uuid') - relevant_context = [] - target_layer_uuid = graph.nodes[node_id].get('layer_uuid') + for n, attr in graph.nodes(data=True): + if attr.get('layer_uuid') == target_layer_uuid and 'description' in attr: + relevant_context.append(attr['description']) - for n, attr in graph.nodes(data=True): - if attr.get('layer_uuid') == target_layer_uuid and 'description' in attr: - relevant_context.append(attr['description']) + return relevant_context - return relevant_context - else: - raise ValueError("Graph object does not match the specified graph engine type in the configuration.") elif infrastructure_config.get_config()["graph_engine"] == GraphDBType.NEO4J: if isinstance(graph, AsyncSession): diff --git a/cognee/modules/search/llm/extraction/categorize_relevant_category.py b/cognee/modules/search/llm/extraction/categorize_relevant_category.py index b10781183..2134780ed 100644 --- a/cognee/modules/search/llm/extraction/categorize_relevant_category.py +++ b/cognee/modules/search/llm/extraction/categorize_relevant_category.py @@ -8,7 +8,6 @@ async def categorize_relevant_category(query: str, summary, response_model: Type enriched_query= render_prompt("categorize_categories.txt", {"query": query, "categories": summary}) - print("enriched_query", enriched_query) system_prompt = " Choose the relevant categories and return appropriate output based on the model" diff --git a/cognee/modules/search/vector/bm25.py b/cognee/modules/search/vector/bm25.py new file mode 100644 index 000000000..134feb819 --- /dev/null +++ b/cognee/modules/search/vector/bm25.py @@ -0,0 +1 @@ +""" Placeholder for BM25 implementation""" \ No newline at end of file diff --git a/cognee/modules/search/vector/fusion.py b/cognee/modules/search/vector/fusion.py new file mode 100644 index 000000000..48ecb7eda --- /dev/null +++ b/cognee/modules/search/vector/fusion.py @@ -0,0 +1 @@ +"""Placeholder for fusions search implementation""" \ No newline at end of file diff --git a/cognee/modules/search/vector/search_similarity.py b/cognee/modules/search/vector/search_similarity.py index 2b288996c..309d98575 100644 --- a/cognee/modules/search/vector/search_similarity.py +++ b/cognee/modules/search/vector/search_similarity.py @@ -11,6 +11,8 @@ async def search_similarity(query: str, graph): layer_nodes = await graph_client.get_layer_nodes() unique_layer_uuids = set(node["layer_id"] for node in layer_nodes) + print("unique_layer_uuids", unique_layer_uuids) + graph_nodes = [] @@ -18,6 +20,8 @@ async def search_similarity(query: str, graph): vector_engine = infrastructure_config.get_config()["vector_engine"] results = await vector_engine.search(layer_id, query_text = query, limit = 10) + print("results", results) + print("len_rs", len(results)) if len(results) > 0: graph_nodes.extend([ @@ -25,25 +29,33 @@ async def search_similarity(query: str, graph): layer_id = result.payload["references"]["cognitive_layer"], node_id = result.payload["references"]["node_id"], score = result.score, - ) for result in results if result.score > 0.8 + ) for result in results if result.score > 0.3 ]) if len(graph_nodes) == 0: return [] - relevant_context = [] - for graph_node_data in graph_nodes: - graph_node = await graph_client.extract_node(graph_node_data["node_id"]) + return graph_nodes - if "chunk_collection" not in graph_node and "chunk_id" not in graph_node: - continue - vector_point = await vector_engine.retrieve( - graph_node["chunk_collection"], - graph_node["chunk_id"], - ) - relevant_context.append(vector_point.payload["text"]) - - return deduplicate(relevant_context) + # for graph_node_data in graph_nodes: + # if graph_node_data['score'] >0.8: + # graph_node = await graph_client.extract_node(graph_node_data["node_id"]) + # + # if "chunk_collection" not in graph_node and "chunk_id" not in graph_node: + # continue + # + # vector_point = await vector_engine.retrieve( + # graph_node["chunk_collection"], + # graph_node["chunk_id"], + # ) + # + # print("vector_point", vector_point.payload["text"]) + # + # relevant_context.append(vector_point.payload["text"]) + # + # print(relevant_context) + # + # return deduplicate(relevant_context) diff --git a/evals/simple_rag_vs_cognee_eval.py b/evals/simple_rag_vs_cognee_eval.py index 247520464..6b9b84672 100644 --- a/evals/simple_rag_vs_cognee_eval.py +++ b/evals/simple_rag_vs_cognee_eval.py @@ -95,10 +95,11 @@ async def cognify_search_base_rag(content:str, context:str): async def cognify_search_graph(content:str, context:str): from cognee.api.v1.search.search import search - search_type = 'CATEGORIES' - params = {'query': 'Ministarstvo'} + search_type = 'SIMILARITY' + params = {'query': 'Donald Trump'} results = await search(search_type, params) + print("results", results) return results