Fixes for searches

Vasilije 2024-05-21 19:34:18 +02:00
parent 63356f242a
commit 3fadb277cb
11 changed files with 89 additions and 47 deletions

View file

@@ -111,11 +111,11 @@ async def cognify(datasets: Union[str, List[str]] = None):
added__basic_rag_chunks = await add_data_chunks_basic_rag(data_chunks)
# await asyncio.gather(
# *[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"],chunk['document_id']) for chunk in
# added_chunks]
# )
#
await asyncio.gather(
*[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"],chunk['document_id']) for chunk in
added_chunks]
)
batch_size = 20
file_count = 0
files_batch = []
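The hunk above re-enables fan-out over the added chunks with asyncio.gather, replacing the commented-out block. A minimal, self-contained sketch of that pattern, with a stub process_text standing in for cognee's coroutine (the stub's simplified signature is hypothetical):

import asyncio

async def process_text(collection: str, chunk_id: str, text: str) -> str:
    # Stand-in for the real coroutine, which embeds the chunk and writes
    # it into the graph; here we only simulate asynchronous work.
    await asyncio.sleep(0)
    return f"{collection}:{chunk_id} processed ({len(text)} chars)"

async def main():
    chunks = [
        {"collection": "docs", "chunk_id": "c1", "text": "first chunk"},
        {"collection": "docs", "chunk_id": "c2", "text": "second chunk"},
    ]
    # One task per chunk; gather runs them concurrently and returns
    # results in the same order as the input list.
    results = await asyncio.gather(
        *[process_text(c["collection"], c["chunk_id"], c["text"]) for c in chunks]
    )
    print(results)

asyncio.run(main())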

View file

@@ -3,6 +3,8 @@ import asyncio
from enum import Enum
from typing import Dict, Any, Callable, List
from pydantic import BaseModel, field_validator
from cognee.modules.search.graph import search_cypher
from cognee.modules.search.graph.search_adjacent import search_adjacent
from cognee.modules.search.vector.search_similarity import search_similarity
from cognee.modules.search.graph.search_categories import search_categories
@@ -20,7 +22,8 @@ class SearchType(Enum):
SUMMARY = 'SUMMARY'
SUMMARY_CLASSIFICATION = 'SUMMARY_CLASSIFICATION'
NODE_CLASSIFICATION = 'NODE_CLASSIFICATION'
DOCUMENT_CLASSIFICATION = 'DOCUMENT_CLASSIFICATION'
DOCUMENT_CLASSIFICATION = 'DOCUMENT_CLASSIFICATION'
CYPHER = 'CYPHER'
@staticmethod
def from_str(name: str):
@@ -54,7 +57,9 @@ async def specific_search(query_params: List[SearchParameters]) -> List:
SearchType.SIMILARITY: search_similarity,
SearchType.CATEGORIES: search_categories,
SearchType.NEIGHBOR: search_neighbour,
SearchType.SUMMARY: search_summary
SearchType.SUMMARY: search_summary,
SearchType.CYPHER: search_cypher
}
results = []
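The new CYPHER member is wired into the handler map, which dispatches each SearchType to its coroutine. A minimal sketch of that dispatch pattern with stub handlers (the names mirror the real ones, but the bodies are placeholders):

import asyncio
from enum import Enum

class SearchType(Enum):
    SIMILARITY = 'SIMILARITY'
    CYPHER = 'CYPHER'

async def search_similarity(query: str) -> list:
    return [f"similarity hit for {query!r}"]   # placeholder body

async def search_cypher(query: str) -> list:
    return [f"cypher result for {query!r}"]    # placeholder body

search_functions = {
    SearchType.SIMILARITY: search_similarity,
    SearchType.CYPHER: search_cypher,
}

async def dispatch(search_type: SearchType, query: str) -> list:
    # Look up the coroutine by enum member and await it.
    return await search_functions[search_type](query)

print(asyncio.run(dispatch(SearchType.CYPHER, "MATCH (n) RETURN n LIMIT 1")))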

View file

@@ -22,7 +22,7 @@ async def search_adjacent(graph: Union[nx.Graph, any], query: str, infrastructur
if node_id is None:
return {}
from cognee.infrastructure import infrastructure_config
if infrastructure_config.get_config()["graph_engine"] == GraphDBType.NETWORKX:
if node_id not in graph:
return {}
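The added membership check guards against a KeyError when the node is missing from the NetworkX graph. A small sketch of the same guard around a neighbour lookup (the attribute name is illustrative):

import networkx as nx

def adjacent_descriptions(graph: nx.Graph, node_id: str) -> dict:
    # Same guard as in the hunk: bail out early if the node is absent.
    if node_id not in graph:
        return {}
    # Return a chosen attribute for each neighbour of the node.
    return {
        n: graph.nodes[n].get("description", "No description available")
        for n in graph.neighbors(node_id)
    }

g = nx.Graph()
g.add_node("a", description="root")
g.add_node("b", description="child")
g.add_edge("a", "b")
print(adjacent_descriptions(g, "a"))    # {'b': 'child'}
print(adjacent_descriptions(g, "zzz"))  # {}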

View file

@@ -1,6 +1,7 @@
from typing import Union, Dict
import re
from pydantic import BaseModel
from cognee.modules.search.llm.extraction.categorize_relevant_category import categorize_relevant_category
@@ -15,6 +16,10 @@ def strip_exact_regex(s, substring):
# Regex to match the exact substring at the start and end
return re.sub(f"^{pattern}|{pattern}$", "", s)
class DefaultResponseModel(BaseModel):
document_id: str
async def search_categories(query:str, graph: Union[nx.Graph, any], query_label: str=None, infrastructure_config: Dict=None):
"""
Filter nodes in the graph that contain the specified label and return their summary attributes.
@@ -39,20 +44,15 @@ async def search_categories(query:str, graph: Union[nx.Graph, any], query_label:
for _, data in graph.nodes(data=True)
if 'summary' in data
]
print("summaries_and_ids", categories_and_ids)
check_relevant_category = await categorize_relevant_category(query, categories_and_ids, response_model= infrastructure_config.get_config()["classification_model"])
print("check_relevant_summary", check_relevant_category)
connected_nodes = []
for id in categories_and_ids:
print("id", id)
connected_nodes.append(list(graph.neighbors(id['document_id'])))
check_relevant_category = await categorize_relevant_category(query, categories_and_ids, response_model=DefaultResponseModel )
connected_nodes = list(graph.neighbors(check_relevant_category['document_id']))
print("connected_nodes", connected_nodes)
descriptions = {node: graph.nodes[node].get('description', 'No desc available') for node in connected_nodes}
print("descs", descriptions)
return descriptions
#
# # Logic for NetworkX
# return {node: data.get('content_labels') for node, data in graph.nodes(data=True) if query_label in node and 'content_labels' in data}
elif infrastructure_config.get_config()["graph_engine"] == GraphDBType.NEO4J:
# Logic for Neo4j
cypher_query = """

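The rewrite above stops expanding the neighbours of every category and instead expands only the single category the LLM deems relevant. A sketch of the new flow with the LLM call replaced by a trivial keyword match (pick_relevant_category is a stand-in for categorize_relevant_category):

import networkx as nx

def pick_relevant_category(query: str, categories: list) -> dict:
    # Stand-in for the LLM call; pick the first category whose summary
    # shares a word with the query, falling back to the first category.
    for cat in categories:
        if any(w.lower() in cat["summary"].lower() for w in query.split()):
            return cat
    return categories[0]

def search_categories_sketch(graph: nx.Graph, query: str) -> dict:
    categories = [
        {"document_id": n, "summary": data["summary"]}
        for n, data in graph.nodes(data=True) if "summary" in data
    ]
    chosen = pick_relevant_category(query, categories)
    # Expand only the chosen category's neighbours, as in the new code path.
    neighbours = list(graph.neighbors(chosen["document_id"]))
    return {n: graph.nodes[n].get("description", "No desc available") for n in neighbours}

g = nx.Graph()
g.add_node("doc1", summary="politics news")
g.add_node("n1", description="an article on elections")
g.add_edge("doc1", "n1")
print(search_categories_sketch(g, "politics"))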
View file

@@ -0,0 +1,24 @@
from typing import Union, Dict
import re
import networkx as nx
from pydantic import BaseModel
from cognee.modules.search.llm.extraction.categorize_relevant_category import categorize_relevant_category
from cognee.shared.data_models import GraphDBType
async def search_cypher(query:str, graph: Union[nx.Graph, any]):
"""
Use a Cypher query to search the graph and return the results.
"""
from cognee.infrastructure import infrastructure_config
if infrastructure_config.get_config()["graph_engine"] == GraphDBType.NEO4J:
result = await graph.run(query)
return result
else:
raise ValueError("Unsupported graph engine type.")
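A usage sketch for the new search_cypher with the official neo4j async driver; the URI and credentials are placeholders. Note that session.run returns a result cursor, so records should be consumed (for example with single() or data()) before the session closes:

import asyncio
from neo4j import AsyncGraphDatabase

async def main():
    driver = AsyncGraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))
    async with driver.session() as session:
        # The session's run() coroutine is what search_cypher awaits.
        result = await session.run("MATCH (n) RETURN count(n) AS nodes")
        record = await result.single()
        print(record["nodes"])
    await driver.close()

asyncio.run(main())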

View file

@@ -1,11 +1,13 @@
""" Fetches the context of a given node in the graph"""
from typing import Union, Dict
from neo4j import AsyncSession
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
import networkx as nx
from cognee.shared.data_models import GraphDBType
async def search_neighbour(graph: Union[nx.Graph, any], id: str, infrastructure_config: Dict,
async def search_neighbour(graph: Union[nx.Graph, any], node_id: str,
other_param: dict = None):
"""
Search for nodes that share the same 'layer_uuid' as the specified node and return their descriptions.
@@ -20,26 +22,23 @@ async def search_neighbour(graph: Union[nx.Graph, any], id: str, infrastructure_
Returns:
- List[str]: A list of 'description' attributes of nodes that share the same 'layer_uuid' with the specified node.
"""
node_id = other_param.get('node_id') if other_param else None
from cognee.infrastructure import infrastructure_config
if node_id is None:
node_id = other_param.get('node_id') if other_param else None
if node_id is None:
return []
if infrastructure_config.get_config()["graph_engine"] == GraphDBType.NETWORKX:
if isinstance(graph, nx.Graph):
if node_id not in graph:
return []
relevant_context = []
target_layer_uuid = graph.nodes[node_id].get('layer_uuid')
relevant_context = []
target_layer_uuid = graph.nodes[node_id].get('layer_uuid')
for n, attr in graph.nodes(data=True):
if attr.get('layer_uuid') == target_layer_uuid and 'description' in attr:
relevant_context.append(attr['description'])
for n, attr in graph.nodes(data=True):
if attr.get('layer_uuid') == target_layer_uuid and 'description' in attr:
relevant_context.append(attr['description'])
return relevant_context
return relevant_context
else:
raise ValueError("Graph object does not match the specified graph engine type in the configuration.")
elif infrastructure_config.get_config()["graph_engine"] == GraphDBType.NEO4J:
if isinstance(graph, AsyncSession):

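The NetworkX branch now keys on isinstance rather than the global engine config, and collects descriptions from every node that shares the target's layer_uuid. A self-contained sketch of that lookup:

import networkx as nx

def same_layer_descriptions(graph: nx.Graph, node_id: str) -> list:
    if node_id not in graph:
        return []
    target_layer = graph.nodes[node_id].get("layer_uuid")
    # Gather the description of every node on the same layer.
    return [
        attr["description"]
        for _, attr in graph.nodes(data=True)
        if attr.get("layer_uuid") == target_layer and "description" in attr
    ]

g = nx.Graph()
g.add_node("a", layer_uuid="L1", description="first")
g.add_node("b", layer_uuid="L1", description="second")
g.add_node("c", layer_uuid="L2", description="other layer")
print(same_layer_descriptions(g, "a"))  # ['first', 'second']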
View file

@@ -8,7 +8,6 @@ async def categorize_relevant_category(query: str, summary, response_model: Type
enriched_query= render_prompt("categorize_categories.txt", {"query": query, "categories": summary})
print("enriched_query", enriched_query)
system_prompt = " Choose the relevant categories and return appropriate output based on the model"
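categorize_relevant_category renders the query and candidate categories into a prompt and expects a reply that validates against the supplied response model. A sketch of that contract with the template and LLM call stubbed out (the real template lives in categorize_categories.txt and is not reproduced here):

from pydantic import BaseModel

class CategoryChoice(BaseModel):
    document_id: str

def render_prompt_sketch(query: str, categories: list) -> str:
    # Stand-in for render_prompt("categorize_categories.txt", ...).
    lines = [f"- {c['document_id']}: {c['summary']}" for c in categories]
    return f"Query: {query}\nCategories:\n" + "\n".join(lines)

prompt = render_prompt_sketch("politics", [{"document_id": "doc1", "summary": "politics news"}])
# A real call would send the prompt to an LLM; here we fake its JSON reply
# and validate it against the response model, as the pipeline does.
fake_llm_reply = '{"document_id": "doc1"}'
choice = CategoryChoice.model_validate_json(fake_llm_reply)
print(prompt)
print(choice.document_id)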

View file

@@ -0,0 +1 @@
""" Placeholder for BM25 implementation"""

View file

@@ -0,0 +1 @@
"""Placeholder for fusions search implementation"""

View file

@@ -11,6 +11,8 @@ async def search_similarity(query: str, graph):
layer_nodes = await graph_client.get_layer_nodes()
unique_layer_uuids = set(node["layer_id"] for node in layer_nodes)
print("unique_layer_uuids", unique_layer_uuids)
graph_nodes = []
@@ -18,6 +20,8 @@ async def search_similarity(query: str, graph):
vector_engine = infrastructure_config.get_config()["vector_engine"]
results = await vector_engine.search(layer_id, query_text = query, limit = 10)
print("results", results)
print("len_rs", len(results))
if len(results) > 0:
graph_nodes.extend([
@@ -25,25 +29,33 @@ async def search_similarity(query: str, graph):
layer_id = result.payload["references"]["cognitive_layer"],
node_id = result.payload["references"]["node_id"],
score = result.score,
) for result in results if result.score > 0.8
) for result in results if result.score > 0.3
])
if len(graph_nodes) == 0:
return []
relevant_context = []
for graph_node_data in graph_nodes:
graph_node = await graph_client.extract_node(graph_node_data["node_id"])
return graph_nodes
if "chunk_collection" not in graph_node and "chunk_id" not in graph_node:
continue
vector_point = await vector_engine.retrieve(
graph_node["chunk_collection"],
graph_node["chunk_id"],
)
relevant_context.append(vector_point.payload["text"])
return deduplicate(relevant_context)
# for graph_node_data in graph_nodes:
# if graph_node_data['score'] >0.8:
# graph_node = await graph_client.extract_node(graph_node_data["node_id"])
#
# if "chunk_collection" not in graph_node and "chunk_id" not in graph_node:
# continue
#
# vector_point = await vector_engine.retrieve(
# graph_node["chunk_collection"],
# graph_node["chunk_id"],
# )
#
# print("vector_point", vector_point.payload["text"])
#
# relevant_context.append(vector_point.payload["text"])
#
# print(relevant_context)
#
# return deduplicate(relevant_context)
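The hunk lowers the similarity cutoff from 0.8 to 0.3, trading precision for recall, and returns the matched graph nodes directly instead of resolving them to chunk text. A sketch of the cutoff logic over scored hits (the ScoredNode shape is illustrative, not cognee's model):

from dataclasses import dataclass

@dataclass
class ScoredNode:
    layer_id: str
    node_id: str
    score: float

def filter_by_score(results: list, threshold: float = 0.3) -> list:
    # Keep only hits scoring above the threshold, as in the new code path.
    return [r for r in results if r.score > threshold]

hits = [
    ScoredNode("layer1", "n1", 0.92),
    ScoredNode("layer1", "n2", 0.41),
    ScoredNode("layer2", "n3", 0.12),  # dropped at the 0.3 cutoff
]
print(filter_by_score(hits))  # n1 and n2 survive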

View file

@@ -95,10 +95,11 @@ async def cognify_search_base_rag(content:str, context:str):
async def cognify_search_graph(content:str, context:str):
from cognee.api.v1.search.search import search
search_type = 'CATEGORIES'
params = {'query': 'Ministarstvo'}
search_type = 'SIMILARITY'
params = {'query': 'Donald Trump'}
results = await search(search_type, params)
print("results", results)
return results
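The helper now exercises the SIMILARITY path; the same entry point can drive any SearchType. A sketch of that call shape with search() stubbed (the real function resolves the type and dispatches as shown earlier):

import asyncio

async def search(search_type: str, params: dict) -> list:
    # Stub for cognee's search(); it would parse search_type via
    # SearchType.from_str and dispatch to the matching handler.
    return [f"{search_type} results for {params['query']!r}"]

async def main():
    for search_type, query in [
        ("SIMILARITY", "Donald Trump"),
        ("CYPHER", "MATCH (n) RETURN n LIMIT 5"),
    ]:
        print(await search(search_type, {"query": query}))

asyncio.run(main())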