Fixes for searches
This commit is contained in:
parent
63356f242a
commit
3fadb277cb
11 changed files with 89 additions and 47 deletions
|
|
@ -111,11 +111,11 @@ async def cognify(datasets: Union[str, List[str]] = None):
|
||||||
added__basic_rag_chunks = await add_data_chunks_basic_rag(data_chunks)
|
added__basic_rag_chunks = await add_data_chunks_basic_rag(data_chunks)
|
||||||
|
|
||||||
|
|
||||||
# await asyncio.gather(
|
await asyncio.gather(
|
||||||
# *[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"],chunk['document_id']) for chunk in
|
*[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"],chunk['document_id']) for chunk in
|
||||||
# added_chunks]
|
added_chunks]
|
||||||
# )
|
)
|
||||||
#
|
|
||||||
batch_size = 20
|
batch_size = 20
|
||||||
file_count = 0
|
file_count = 0
|
||||||
files_batch = []
|
files_batch = []
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,8 @@ import asyncio
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, Any, Callable, List
|
from typing import Dict, Any, Callable, List
|
||||||
from pydantic import BaseModel, field_validator
|
from pydantic import BaseModel, field_validator
|
||||||
|
|
||||||
|
from cognee.modules.search.graph import search_cypher
|
||||||
from cognee.modules.search.graph.search_adjacent import search_adjacent
|
from cognee.modules.search.graph.search_adjacent import search_adjacent
|
||||||
from cognee.modules.search.vector.search_similarity import search_similarity
|
from cognee.modules.search.vector.search_similarity import search_similarity
|
||||||
from cognee.modules.search.graph.search_categories import search_categories
|
from cognee.modules.search.graph.search_categories import search_categories
|
||||||
|
|
@ -20,7 +22,8 @@ class SearchType(Enum):
|
||||||
SUMMARY = 'SUMMARY'
|
SUMMARY = 'SUMMARY'
|
||||||
SUMMARY_CLASSIFICATION = 'SUMMARY_CLASSIFICATION'
|
SUMMARY_CLASSIFICATION = 'SUMMARY_CLASSIFICATION'
|
||||||
NODE_CLASSIFICATION = 'NODE_CLASSIFICATION'
|
NODE_CLASSIFICATION = 'NODE_CLASSIFICATION'
|
||||||
DOCUMENT_CLASSIFICATION = 'DOCUMENT_CLASSIFICATION'
|
DOCUMENT_CLASSIFICATION = 'DOCUMENT_CLASSIFICATION',
|
||||||
|
CYPHER = 'CYPHER'
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_str(name: str):
|
def from_str(name: str):
|
||||||
|
|
@ -54,7 +57,9 @@ async def specific_search(query_params: List[SearchParameters]) -> List:
|
||||||
SearchType.SIMILARITY: search_similarity,
|
SearchType.SIMILARITY: search_similarity,
|
||||||
SearchType.CATEGORIES: search_categories,
|
SearchType.CATEGORIES: search_categories,
|
||||||
SearchType.NEIGHBOR: search_neighbour,
|
SearchType.NEIGHBOR: search_neighbour,
|
||||||
SearchType.SUMMARY: search_summary
|
SearchType.SUMMARY: search_summary,
|
||||||
|
SearchType.CYPHER: search_cypher
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ async def search_adjacent(graph: Union[nx.Graph, any], query: str, infrastructur
|
||||||
|
|
||||||
if node_id is None:
|
if node_id is None:
|
||||||
return {}
|
return {}
|
||||||
|
from cognee.infrastructure import infrastructure_config
|
||||||
if infrastructure_config.get_config()["graph_engine"] == GraphDBType.NETWORKX:
|
if infrastructure_config.get_config()["graph_engine"] == GraphDBType.NETWORKX:
|
||||||
if node_id not in graph:
|
if node_id not in graph:
|
||||||
return {}
|
return {}
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
from typing import Union, Dict
|
from typing import Union, Dict
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from cognee.modules.search.llm.extraction.categorize_relevant_category import categorize_relevant_category
|
from cognee.modules.search.llm.extraction.categorize_relevant_category import categorize_relevant_category
|
||||||
|
|
||||||
|
|
@ -15,6 +16,10 @@ def strip_exact_regex(s, substring):
|
||||||
# Regex to match the exact substring at the start and end
|
# Regex to match the exact substring at the start and end
|
||||||
return re.sub(f"^{pattern}|{pattern}$", "", s)
|
return re.sub(f"^{pattern}|{pattern}$", "", s)
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultResponseModel(BaseModel):
|
||||||
|
document_id: str
|
||||||
|
|
||||||
async def search_categories(query:str, graph: Union[nx.Graph, any], query_label: str=None, infrastructure_config: Dict=None):
|
async def search_categories(query:str, graph: Union[nx.Graph, any], query_label: str=None, infrastructure_config: Dict=None):
|
||||||
"""
|
"""
|
||||||
Filter nodes in the graph that contain the specified label and return their summary attributes.
|
Filter nodes in the graph that contain the specified label and return their summary attributes.
|
||||||
|
|
@ -39,20 +44,15 @@ async def search_categories(query:str, graph: Union[nx.Graph, any], query_label:
|
||||||
for _, data in graph.nodes(data=True)
|
for _, data in graph.nodes(data=True)
|
||||||
if 'summary' in data
|
if 'summary' in data
|
||||||
]
|
]
|
||||||
print("summaries_and_ids", categories_and_ids)
|
connected_nodes = []
|
||||||
check_relevant_category = await categorize_relevant_category(query, categories_and_ids, response_model= infrastructure_config.get_config()["classification_model"])
|
for id in categories_and_ids:
|
||||||
print("check_relevant_summary", check_relevant_category)
|
print("id", id)
|
||||||
|
connected_nodes.append(list(graph.neighbors(id['document_id'])))
|
||||||
|
check_relevant_category = await categorize_relevant_category(query, categories_and_ids, response_model=DefaultResponseModel )
|
||||||
connected_nodes = list(graph.neighbors(check_relevant_category['document_id']))
|
connected_nodes = list(graph.neighbors(check_relevant_category['document_id']))
|
||||||
print("connected_nodes", connected_nodes)
|
|
||||||
descriptions = {node: graph.nodes[node].get('description', 'No desc available') for node in connected_nodes}
|
descriptions = {node: graph.nodes[node].get('description', 'No desc available') for node in connected_nodes}
|
||||||
print("descs", descriptions)
|
|
||||||
return descriptions
|
return descriptions
|
||||||
|
|
||||||
#
|
|
||||||
# # Logic for NetworkX
|
|
||||||
# return {node: data.get('content_labels') for node, data in graph.nodes(data=True) if query_label in node and 'content_labels' in data}
|
|
||||||
|
|
||||||
elif infrastructure_config.get_config()["graph_engine"] == GraphDBType.NEO4J:
|
elif infrastructure_config.get_config()["graph_engine"] == GraphDBType.NEO4J:
|
||||||
# Logic for Neo4j
|
# Logic for Neo4j
|
||||||
cypher_query = """
|
cypher_query = """
|
||||||
|
|
|
||||||
24
cognee/modules/search/graph/search_cypher.py
Normal file
24
cognee/modules/search/graph/search_cypher.py
Normal file
|
|
@ -0,0 +1,24 @@
|
||||||
|
|
||||||
|
from typing import Union, Dict
|
||||||
|
import re
|
||||||
|
|
||||||
|
import networkx as nx
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from cognee.modules.search.llm.extraction.categorize_relevant_category import categorize_relevant_category
|
||||||
|
from cognee.shared.data_models import GraphDBType
|
||||||
|
|
||||||
|
|
||||||
|
async def search_cypher(query:str, graph: Union[nx.Graph, any]):
|
||||||
|
"""
|
||||||
|
Use a Cypher query to search the graph and return the results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
from cognee.infrastructure import infrastructure_config
|
||||||
|
if infrastructure_config.get_config()["graph_engine"] == GraphDBType.NEO4J:
|
||||||
|
result = await graph.run(query)
|
||||||
|
return result
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported graph engine type.")
|
||||||
|
|
@ -1,11 +1,13 @@
|
||||||
""" Fetches the context of a given node in the graph"""
|
""" Fetches the context of a given node in the graph"""
|
||||||
from typing import Union, Dict
|
from typing import Union, Dict
|
||||||
|
|
||||||
|
from neo4j import AsyncSession
|
||||||
|
|
||||||
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from cognee.shared.data_models import GraphDBType
|
from cognee.shared.data_models import GraphDBType
|
||||||
|
|
||||||
async def search_neighbour(graph: Union[nx.Graph, any], id: str, infrastructure_config: Dict,
|
async def search_neighbour(graph: Union[nx.Graph, any], node_id: str,
|
||||||
other_param: dict = None):
|
other_param: dict = None):
|
||||||
"""
|
"""
|
||||||
Search for nodes that share the same 'layer_uuid' as the specified node and return their descriptions.
|
Search for nodes that share the same 'layer_uuid' as the specified node and return their descriptions.
|
||||||
|
|
@ -20,26 +22,23 @@ async def search_neighbour(graph: Union[nx.Graph, any], id: str, infrastructure_
|
||||||
Returns:
|
Returns:
|
||||||
- List[str]: A list of 'description' attributes of nodes that share the same 'layer_uuid' with the specified node.
|
- List[str]: A list of 'description' attributes of nodes that share the same 'layer_uuid' with the specified node.
|
||||||
"""
|
"""
|
||||||
node_id = other_param.get('node_id') if other_param else None
|
from cognee.infrastructure import infrastructure_config
|
||||||
|
if node_id is None:
|
||||||
|
node_id = other_param.get('node_id') if other_param else None
|
||||||
|
|
||||||
if node_id is None:
|
if node_id is None:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
if infrastructure_config.get_config()["graph_engine"] == GraphDBType.NETWORKX:
|
if infrastructure_config.get_config()["graph_engine"] == GraphDBType.NETWORKX:
|
||||||
if isinstance(graph, nx.Graph):
|
relevant_context = []
|
||||||
if node_id not in graph:
|
target_layer_uuid = graph.nodes[node_id].get('layer_uuid')
|
||||||
return []
|
|
||||||
|
|
||||||
relevant_context = []
|
for n, attr in graph.nodes(data=True):
|
||||||
target_layer_uuid = graph.nodes[node_id].get('layer_uuid')
|
if attr.get('layer_uuid') == target_layer_uuid and 'description' in attr:
|
||||||
|
relevant_context.append(attr['description'])
|
||||||
|
|
||||||
for n, attr in graph.nodes(data=True):
|
return relevant_context
|
||||||
if attr.get('layer_uuid') == target_layer_uuid and 'description' in attr:
|
|
||||||
relevant_context.append(attr['description'])
|
|
||||||
|
|
||||||
return relevant_context
|
|
||||||
else:
|
|
||||||
raise ValueError("Graph object does not match the specified graph engine type in the configuration.")
|
|
||||||
|
|
||||||
elif infrastructure_config.get_config()["graph_engine"] == GraphDBType.NEO4J:
|
elif infrastructure_config.get_config()["graph_engine"] == GraphDBType.NEO4J:
|
||||||
if isinstance(graph, AsyncSession):
|
if isinstance(graph, AsyncSession):
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,6 @@ async def categorize_relevant_category(query: str, summary, response_model: Type
|
||||||
|
|
||||||
enriched_query= render_prompt("categorize_categories.txt", {"query": query, "categories": summary})
|
enriched_query= render_prompt("categorize_categories.txt", {"query": query, "categories": summary})
|
||||||
|
|
||||||
print("enriched_query", enriched_query)
|
|
||||||
|
|
||||||
system_prompt = " Choose the relevant categories and return appropriate output based on the model"
|
system_prompt = " Choose the relevant categories and return appropriate output based on the model"
|
||||||
|
|
||||||
|
|
|
||||||
1
cognee/modules/search/vector/bm25.py
Normal file
1
cognee/modules/search/vector/bm25.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
""" Placeholder for BM25 implementation"""
|
||||||
1
cognee/modules/search/vector/fusion.py
Normal file
1
cognee/modules/search/vector/fusion.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
"""Placeholder for fusions search implementation"""
|
||||||
|
|
@ -11,6 +11,8 @@ async def search_similarity(query: str, graph):
|
||||||
layer_nodes = await graph_client.get_layer_nodes()
|
layer_nodes = await graph_client.get_layer_nodes()
|
||||||
|
|
||||||
unique_layer_uuids = set(node["layer_id"] for node in layer_nodes)
|
unique_layer_uuids = set(node["layer_id"] for node in layer_nodes)
|
||||||
|
print("unique_layer_uuids", unique_layer_uuids)
|
||||||
|
|
||||||
|
|
||||||
graph_nodes = []
|
graph_nodes = []
|
||||||
|
|
||||||
|
|
@ -18,6 +20,8 @@ async def search_similarity(query: str, graph):
|
||||||
vector_engine = infrastructure_config.get_config()["vector_engine"]
|
vector_engine = infrastructure_config.get_config()["vector_engine"]
|
||||||
|
|
||||||
results = await vector_engine.search(layer_id, query_text = query, limit = 10)
|
results = await vector_engine.search(layer_id, query_text = query, limit = 10)
|
||||||
|
print("results", results)
|
||||||
|
print("len_rs", len(results))
|
||||||
|
|
||||||
if len(results) > 0:
|
if len(results) > 0:
|
||||||
graph_nodes.extend([
|
graph_nodes.extend([
|
||||||
|
|
@ -25,25 +29,33 @@ async def search_similarity(query: str, graph):
|
||||||
layer_id = result.payload["references"]["cognitive_layer"],
|
layer_id = result.payload["references"]["cognitive_layer"],
|
||||||
node_id = result.payload["references"]["node_id"],
|
node_id = result.payload["references"]["node_id"],
|
||||||
score = result.score,
|
score = result.score,
|
||||||
) for result in results if result.score > 0.8
|
) for result in results if result.score > 0.3
|
||||||
])
|
])
|
||||||
|
|
||||||
if len(graph_nodes) == 0:
|
if len(graph_nodes) == 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
relevant_context = []
|
|
||||||
|
|
||||||
for graph_node_data in graph_nodes:
|
return graph_nodes
|
||||||
graph_node = await graph_client.extract_node(graph_node_data["node_id"])
|
|
||||||
|
|
||||||
if "chunk_collection" not in graph_node and "chunk_id" not in graph_node:
|
|
||||||
continue
|
|
||||||
|
|
||||||
vector_point = await vector_engine.retrieve(
|
|
||||||
graph_node["chunk_collection"],
|
|
||||||
graph_node["chunk_id"],
|
|
||||||
)
|
|
||||||
|
|
||||||
relevant_context.append(vector_point.payload["text"])
|
# for graph_node_data in graph_nodes:
|
||||||
|
# if graph_node_data['score'] >0.8:
|
||||||
return deduplicate(relevant_context)
|
# graph_node = await graph_client.extract_node(graph_node_data["node_id"])
|
||||||
|
#
|
||||||
|
# if "chunk_collection" not in graph_node and "chunk_id" not in graph_node:
|
||||||
|
# continue
|
||||||
|
#
|
||||||
|
# vector_point = await vector_engine.retrieve(
|
||||||
|
# graph_node["chunk_collection"],
|
||||||
|
# graph_node["chunk_id"],
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# print("vector_point", vector_point.payload["text"])
|
||||||
|
#
|
||||||
|
# relevant_context.append(vector_point.payload["text"])
|
||||||
|
#
|
||||||
|
# print(relevant_context)
|
||||||
|
#
|
||||||
|
# return deduplicate(relevant_context)
|
||||||
|
|
|
||||||
|
|
@ -95,10 +95,11 @@ async def cognify_search_base_rag(content:str, context:str):
|
||||||
|
|
||||||
async def cognify_search_graph(content:str, context:str):
|
async def cognify_search_graph(content:str, context:str):
|
||||||
from cognee.api.v1.search.search import search
|
from cognee.api.v1.search.search import search
|
||||||
search_type = 'CATEGORIES'
|
search_type = 'SIMILARITY'
|
||||||
params = {'query': 'Ministarstvo'}
|
params = {'query': 'Donald Trump'}
|
||||||
|
|
||||||
results = await search(search_type, params)
|
results = await search(search_type, params)
|
||||||
|
print("results", results)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue