Fixes to search and evals
This commit is contained in:
parent
8ef23731a3
commit
63356f242a
5 changed files with 26 additions and 12 deletions
|
|
@ -258,9 +258,9 @@ if __name__ == "__main__":
|
||||||
from cognee.shared.SourceCodeGraph import SourceCodeGraph
|
from cognee.shared.SourceCodeGraph import SourceCodeGraph
|
||||||
from cognee.api.v1.config import config
|
from cognee.api.v1.config import config
|
||||||
|
|
||||||
config.set_graph_model(SourceCodeGraph)
|
# config.set_graph_model(SourceCodeGraph)
|
||||||
config.set_classification_model(CodeContentPrediction)
|
# config.set_classification_model(CodeContentPrediction)
|
||||||
graph = await cognify()
|
# graph = await cognify()
|
||||||
vector_client = infrastructure_config.get_config("vector_engine")
|
vector_client = infrastructure_config.get_config("vector_engine")
|
||||||
|
|
||||||
out = await vector_client.search(collection_name ="basic_rag", query_text="show_all_processes", limit=10)
|
out = await vector_client.search(collection_name ="basic_rag", query_text="show_all_processes", limit=10)
|
||||||
|
|
|
||||||
|
|
@ -84,10 +84,12 @@ class InfrastructureConfig():
|
||||||
|
|
||||||
if (config_entity is None or config_entity == "llm_engine") and self.llm_engine is None:
|
if (config_entity is None or config_entity == "llm_engine") and self.llm_engine is None:
|
||||||
self.llm_engine = OpenAIAdapter(config.openai_key, config.openai_model)
|
self.llm_engine = OpenAIAdapter(config.openai_key, config.openai_model)
|
||||||
|
|
||||||
if (config_entity is None or config_entity == "database_directory_path") and self.database_directory_path is None:
|
if (config_entity is None or config_entity == "database_directory_path") and self.database_directory_path is None:
|
||||||
self.database_directory_path = self.system_root_directory + "/" + config.db_path
|
self.database_directory_path = self.system_root_directory + "/" + config.db_path
|
||||||
|
|
||||||
|
if self.database_directory_path is None:
|
||||||
|
self.database_directory_path = self.system_root_directory + "/" + config.db_path
|
||||||
|
|
||||||
if (config_entity is None or config_entity == "database_file_path") and self.database_file_path is None:
|
if (config_entity is None or config_entity == "database_file_path") and self.database_file_path is None:
|
||||||
self.database_file_path = self.system_root_directory + "/" + config.db_path + "/" + config.db_name
|
self.database_file_path = self.system_root_directory + "/" + config.db_path + "/" + config.db_name
|
||||||
|
|
||||||
|
|
@ -114,6 +116,9 @@ class InfrastructureConfig():
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
from .databases.vector.lancedb.LanceDBAdapter import LanceDBAdapter
|
from .databases.vector.lancedb.LanceDBAdapter import LanceDBAdapter
|
||||||
|
print("Using LanceDB as vector engine", self.database_directory_path)
|
||||||
|
print("Setting system root directory to", self.system_root_directory)
|
||||||
|
|
||||||
lance_db_path = self.database_directory_path + "/cognee.lancedb"
|
lance_db_path = self.database_directory_path + "/cognee.lancedb"
|
||||||
LocalStorage.ensure_directory_exists(lance_db_path)
|
LocalStorage.ensure_directory_exists(lance_db_path)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,6 @@
|
||||||
from typing import Union, Dict, re
|
from typing import Union, Dict
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
from cognee.modules.search.llm.extraction.categorize_relevant_category import categorize_relevant_category
|
from cognee.modules.search.llm.extraction.categorize_relevant_category import categorize_relevant_category
|
||||||
|
|
||||||
|
|
@ -13,7 +15,7 @@ def strip_exact_regex(s, substring):
|
||||||
# Regex to match the exact substring at the start and end
|
# Regex to match the exact substring at the start and end
|
||||||
return re.sub(f"^{pattern}|{pattern}$", "", s)
|
return re.sub(f"^{pattern}|{pattern}$", "", s)
|
||||||
|
|
||||||
async def search_categories(query:str, graph: Union[nx.Graph, any], query_label: str, infrastructure_config: Dict):
|
async def search_categories(query:str, graph: Union[nx.Graph, any], query_label: str=None, infrastructure_config: Dict=None):
|
||||||
"""
|
"""
|
||||||
Filter nodes in the graph that contain the specified label and return their summary attributes.
|
Filter nodes in the graph that contain the specified label and return their summary attributes.
|
||||||
This function supports both NetworkX graphs and Neo4j graph databases.
|
This function supports both NetworkX graphs and Neo4j graph databases.
|
||||||
|
|
@ -29,6 +31,7 @@ async def search_categories(query:str, graph: Union[nx.Graph, any], query_label:
|
||||||
each representing a node with 'nodeId' and 'summary'.
|
each representing a node with 'nodeId' and 'summary'.
|
||||||
"""
|
"""
|
||||||
# Determine which client is in use based on the configuration
|
# Determine which client is in use based on the configuration
|
||||||
|
from cognee.infrastructure import infrastructure_config
|
||||||
if infrastructure_config.get_config()["graph_engine"] == GraphDBType.NETWORKX:
|
if infrastructure_config.get_config()["graph_engine"] == GraphDBType.NETWORKX:
|
||||||
|
|
||||||
categories_and_ids = [
|
categories_and_ids = [
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||||
async def categorize_relevant_category(query: str, summary, response_model: Type[BaseModel]):
|
async def categorize_relevant_category(query: str, summary, response_model: Type[BaseModel]):
|
||||||
llm_client = get_llm_client()
|
llm_client = get_llm_client()
|
||||||
|
|
||||||
enriched_query= render_prompt("categorize_category.txt", {"query": query, "categories": summary})
|
enriched_query= render_prompt("categorize_categories.txt", {"query": query, "categories": summary})
|
||||||
|
|
||||||
print("enriched_query", enriched_query)
|
print("enriched_query", enriched_query)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -84,17 +84,22 @@ async def run_cognify_base_rag():
|
||||||
|
|
||||||
|
|
||||||
async def cognify_search_base_rag(content:str, context:str):
|
async def cognify_search_base_rag(content:str, context:str):
|
||||||
|
infrastructure_config.set_config({"database_directory_path": "/Users/vasa/Projects/cognee/cognee/.cognee_system/databases/cognee.lancedb"})
|
||||||
|
|
||||||
vector_client = infrastructure_config.get_config("vector_engine")
|
vector_client = infrastructure_config.get_config("vector_engine")
|
||||||
|
|
||||||
return_ = await vector_client.search(collection_name="basic_rag", query_text="show_all_processes", limit=10)
|
return_ = await vector_client.search(collection_name="basic_rag", query_text=content, limit=10)
|
||||||
|
|
||||||
print("results", return_)
|
print("results", return_)
|
||||||
return return_
|
return return_
|
||||||
|
|
||||||
async def cognify_search_graph(content:str, context:str):
|
async def cognify_search_graph(content:str, context:str):
|
||||||
from cognee.api.v1.search.search import search
|
from cognee.api.v1.search.search import search
|
||||||
return_ = await search(content)
|
search_type = 'CATEGORIES'
|
||||||
return return_
|
params = {'query': 'Ministarstvo'}
|
||||||
|
|
||||||
|
results = await search(search_type, params)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -128,8 +133,9 @@ if __name__ == "__main__":
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
await run_cognify_base_rag_and_search()
|
# await run_cognify_base_rag()
|
||||||
|
# await cognify_search_base_rag("show_all_processes", "context")
|
||||||
|
await cognify_search_graph("show_all_processes", "context")
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
# run_cognify_base_rag_and_search()
|
# run_cognify_base_rag_and_search()
|
||||||
# # Data preprocessing before setting the dataset test cases
|
# # Data preprocessing before setting the dataset test cases
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue