Fixes for searches
This commit is contained in:
parent
63356f242a
commit
3fadb277cb
11 changed files with 89 additions and 47 deletions
|
|
@ -111,11 +111,11 @@ async def cognify(datasets: Union[str, List[str]] = None):
|
|||
added__basic_rag_chunks = await add_data_chunks_basic_rag(data_chunks)
|
||||
|
||||
|
||||
# await asyncio.gather(
|
||||
# *[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"],chunk['document_id']) for chunk in
|
||||
# added_chunks]
|
||||
# )
|
||||
#
|
||||
await asyncio.gather(
|
||||
*[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"],chunk['document_id']) for chunk in
|
||||
added_chunks]
|
||||
)
|
||||
|
||||
batch_size = 20
|
||||
file_count = 0
|
||||
files_batch = []
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ import asyncio
|
|||
from enum import Enum
|
||||
from typing import Dict, Any, Callable, List
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from cognee.modules.search.graph import search_cypher
|
||||
from cognee.modules.search.graph.search_adjacent import search_adjacent
|
||||
from cognee.modules.search.vector.search_similarity import search_similarity
|
||||
from cognee.modules.search.graph.search_categories import search_categories
|
||||
|
|
@ -20,7 +22,8 @@ class SearchType(Enum):
|
|||
SUMMARY = 'SUMMARY'
|
||||
SUMMARY_CLASSIFICATION = 'SUMMARY_CLASSIFICATION'
|
||||
NODE_CLASSIFICATION = 'NODE_CLASSIFICATION'
|
||||
DOCUMENT_CLASSIFICATION = 'DOCUMENT_CLASSIFICATION'
|
||||
DOCUMENT_CLASSIFICATION = 'DOCUMENT_CLASSIFICATION',
|
||||
CYPHER = 'CYPHER'
|
||||
|
||||
@staticmethod
|
||||
def from_str(name: str):
|
||||
|
|
@ -54,7 +57,9 @@ async def specific_search(query_params: List[SearchParameters]) -> List:
|
|||
SearchType.SIMILARITY: search_similarity,
|
||||
SearchType.CATEGORIES: search_categories,
|
||||
SearchType.NEIGHBOR: search_neighbour,
|
||||
SearchType.SUMMARY: search_summary
|
||||
SearchType.SUMMARY: search_summary,
|
||||
SearchType.CYPHER: search_cypher
|
||||
|
||||
}
|
||||
|
||||
results = []
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ async def search_adjacent(graph: Union[nx.Graph, any], query: str, infrastructur
|
|||
|
||||
if node_id is None:
|
||||
return {}
|
||||
|
||||
from cognee.infrastructure import infrastructure_config
|
||||
if infrastructure_config.get_config()["graph_engine"] == GraphDBType.NETWORKX:
|
||||
if node_id not in graph:
|
||||
return {}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from typing import Union, Dict
|
||||
import re
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cognee.modules.search.llm.extraction.categorize_relevant_category import categorize_relevant_category
|
||||
|
||||
|
|
@ -15,6 +16,10 @@ def strip_exact_regex(s, substring):
|
|||
# Regex to match the exact substring at the start and end
|
||||
return re.sub(f"^{pattern}|{pattern}$", "", s)
|
||||
|
||||
|
||||
class DefaultResponseModel(BaseModel):
|
||||
document_id: str
|
||||
|
||||
async def search_categories(query:str, graph: Union[nx.Graph, any], query_label: str=None, infrastructure_config: Dict=None):
|
||||
"""
|
||||
Filter nodes in the graph that contain the specified label and return their summary attributes.
|
||||
|
|
@ -39,20 +44,15 @@ async def search_categories(query:str, graph: Union[nx.Graph, any], query_label:
|
|||
for _, data in graph.nodes(data=True)
|
||||
if 'summary' in data
|
||||
]
|
||||
print("summaries_and_ids", categories_and_ids)
|
||||
check_relevant_category = await categorize_relevant_category(query, categories_and_ids, response_model= infrastructure_config.get_config()["classification_model"])
|
||||
print("check_relevant_summary", check_relevant_category)
|
||||
|
||||
connected_nodes = []
|
||||
for id in categories_and_ids:
|
||||
print("id", id)
|
||||
connected_nodes.append(list(graph.neighbors(id['document_id'])))
|
||||
check_relevant_category = await categorize_relevant_category(query, categories_and_ids, response_model=DefaultResponseModel )
|
||||
connected_nodes = list(graph.neighbors(check_relevant_category['document_id']))
|
||||
print("connected_nodes", connected_nodes)
|
||||
descriptions = {node: graph.nodes[node].get('description', 'No desc available') for node in connected_nodes}
|
||||
print("descs", descriptions)
|
||||
return descriptions
|
||||
|
||||
#
|
||||
# # Logic for NetworkX
|
||||
# return {node: data.get('content_labels') for node, data in graph.nodes(data=True) if query_label in node and 'content_labels' in data}
|
||||
|
||||
elif infrastructure_config.get_config()["graph_engine"] == GraphDBType.NEO4J:
|
||||
# Logic for Neo4j
|
||||
cypher_query = """
|
||||
|
|
|
|||
24
cognee/modules/search/graph/search_cypher.py
Normal file
24
cognee/modules/search/graph/search_cypher.py
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
|
||||
from typing import Union, Dict
|
||||
import re
|
||||
|
||||
import networkx as nx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cognee.modules.search.llm.extraction.categorize_relevant_category import categorize_relevant_category
|
||||
from cognee.shared.data_models import GraphDBType
|
||||
|
||||
|
||||
async def search_cypher(query:str, graph: Union[nx.Graph, any]):
|
||||
"""
|
||||
Use a Cypher query to search the graph and return the results.
|
||||
"""
|
||||
|
||||
|
||||
from cognee.infrastructure import infrastructure_config
|
||||
if infrastructure_config.get_config()["graph_engine"] == GraphDBType.NEO4J:
|
||||
result = await graph.run(query)
|
||||
return result
|
||||
|
||||
else:
|
||||
raise ValueError("Unsupported graph engine type.")
|
||||
|
|
@ -1,11 +1,13 @@
|
|||
""" Fetches the context of a given node in the graph"""
|
||||
from typing import Union, Dict
|
||||
|
||||
from neo4j import AsyncSession
|
||||
|
||||
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
||||
import networkx as nx
|
||||
from cognee.shared.data_models import GraphDBType
|
||||
|
||||
async def search_neighbour(graph: Union[nx.Graph, any], id: str, infrastructure_config: Dict,
|
||||
async def search_neighbour(graph: Union[nx.Graph, any], node_id: str,
|
||||
other_param: dict = None):
|
||||
"""
|
||||
Search for nodes that share the same 'layer_uuid' as the specified node and return their descriptions.
|
||||
|
|
@ -20,26 +22,23 @@ async def search_neighbour(graph: Union[nx.Graph, any], id: str, infrastructure_
|
|||
Returns:
|
||||
- List[str]: A list of 'description' attributes of nodes that share the same 'layer_uuid' with the specified node.
|
||||
"""
|
||||
node_id = other_param.get('node_id') if other_param else None
|
||||
from cognee.infrastructure import infrastructure_config
|
||||
if node_id is None:
|
||||
node_id = other_param.get('node_id') if other_param else None
|
||||
|
||||
if node_id is None:
|
||||
return []
|
||||
|
||||
if infrastructure_config.get_config()["graph_engine"] == GraphDBType.NETWORKX:
|
||||
if isinstance(graph, nx.Graph):
|
||||
if node_id not in graph:
|
||||
return []
|
||||
relevant_context = []
|
||||
target_layer_uuid = graph.nodes[node_id].get('layer_uuid')
|
||||
|
||||
relevant_context = []
|
||||
target_layer_uuid = graph.nodes[node_id].get('layer_uuid')
|
||||
for n, attr in graph.nodes(data=True):
|
||||
if attr.get('layer_uuid') == target_layer_uuid and 'description' in attr:
|
||||
relevant_context.append(attr['description'])
|
||||
|
||||
for n, attr in graph.nodes(data=True):
|
||||
if attr.get('layer_uuid') == target_layer_uuid and 'description' in attr:
|
||||
relevant_context.append(attr['description'])
|
||||
return relevant_context
|
||||
|
||||
return relevant_context
|
||||
else:
|
||||
raise ValueError("Graph object does not match the specified graph engine type in the configuration.")
|
||||
|
||||
elif infrastructure_config.get_config()["graph_engine"] == GraphDBType.NEO4J:
|
||||
if isinstance(graph, AsyncSession):
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ async def categorize_relevant_category(query: str, summary, response_model: Type
|
|||
|
||||
enriched_query= render_prompt("categorize_categories.txt", {"query": query, "categories": summary})
|
||||
|
||||
print("enriched_query", enriched_query)
|
||||
|
||||
system_prompt = " Choose the relevant categories and return appropriate output based on the model"
|
||||
|
||||
|
|
|
|||
1
cognee/modules/search/vector/bm25.py
Normal file
1
cognee/modules/search/vector/bm25.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
""" Placeholder for BM25 implementation"""
|
||||
1
cognee/modules/search/vector/fusion.py
Normal file
1
cognee/modules/search/vector/fusion.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Placeholder for fusions search implementation"""
|
||||
|
|
@ -11,6 +11,8 @@ async def search_similarity(query: str, graph):
|
|||
layer_nodes = await graph_client.get_layer_nodes()
|
||||
|
||||
unique_layer_uuids = set(node["layer_id"] for node in layer_nodes)
|
||||
print("unique_layer_uuids", unique_layer_uuids)
|
||||
|
||||
|
||||
graph_nodes = []
|
||||
|
||||
|
|
@ -18,6 +20,8 @@ async def search_similarity(query: str, graph):
|
|||
vector_engine = infrastructure_config.get_config()["vector_engine"]
|
||||
|
||||
results = await vector_engine.search(layer_id, query_text = query, limit = 10)
|
||||
print("results", results)
|
||||
print("len_rs", len(results))
|
||||
|
||||
if len(results) > 0:
|
||||
graph_nodes.extend([
|
||||
|
|
@ -25,25 +29,33 @@ async def search_similarity(query: str, graph):
|
|||
layer_id = result.payload["references"]["cognitive_layer"],
|
||||
node_id = result.payload["references"]["node_id"],
|
||||
score = result.score,
|
||||
) for result in results if result.score > 0.8
|
||||
) for result in results if result.score > 0.3
|
||||
])
|
||||
|
||||
if len(graph_nodes) == 0:
|
||||
return []
|
||||
|
||||
relevant_context = []
|
||||
|
||||
for graph_node_data in graph_nodes:
|
||||
graph_node = await graph_client.extract_node(graph_node_data["node_id"])
|
||||
return graph_nodes
|
||||
|
||||
if "chunk_collection" not in graph_node and "chunk_id" not in graph_node:
|
||||
continue
|
||||
|
||||
vector_point = await vector_engine.retrieve(
|
||||
graph_node["chunk_collection"],
|
||||
graph_node["chunk_id"],
|
||||
)
|
||||
|
||||
relevant_context.append(vector_point.payload["text"])
|
||||
|
||||
return deduplicate(relevant_context)
|
||||
# for graph_node_data in graph_nodes:
|
||||
# if graph_node_data['score'] >0.8:
|
||||
# graph_node = await graph_client.extract_node(graph_node_data["node_id"])
|
||||
#
|
||||
# if "chunk_collection" not in graph_node and "chunk_id" not in graph_node:
|
||||
# continue
|
||||
#
|
||||
# vector_point = await vector_engine.retrieve(
|
||||
# graph_node["chunk_collection"],
|
||||
# graph_node["chunk_id"],
|
||||
# )
|
||||
#
|
||||
# print("vector_point", vector_point.payload["text"])
|
||||
#
|
||||
# relevant_context.append(vector_point.payload["text"])
|
||||
#
|
||||
# print(relevant_context)
|
||||
#
|
||||
# return deduplicate(relevant_context)
|
||||
|
|
|
|||
|
|
@ -95,10 +95,11 @@ async def cognify_search_base_rag(content:str, context:str):
|
|||
|
||||
async def cognify_search_graph(content:str, context:str):
|
||||
from cognee.api.v1.search.search import search
|
||||
search_type = 'CATEGORIES'
|
||||
params = {'query': 'Ministarstvo'}
|
||||
search_type = 'SIMILARITY'
|
||||
params = {'query': 'Donald Trump'}
|
||||
|
||||
results = await search(search_type, params)
|
||||
print("results", results)
|
||||
return results
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue