Fixes to search and evals

This commit is contained in:
Vasilije 2024-05-21 10:03:52 +02:00
parent 8ef23731a3
commit 63356f242a
5 changed files with 26 additions and 12 deletions

View file

@@ -258,9 +258,9 @@ if __name__ == "__main__":
from cognee.shared.SourceCodeGraph import SourceCodeGraph from cognee.shared.SourceCodeGraph import SourceCodeGraph
from cognee.api.v1.config import config from cognee.api.v1.config import config
config.set_graph_model(SourceCodeGraph) # config.set_graph_model(SourceCodeGraph)
config.set_classification_model(CodeContentPrediction) # config.set_classification_model(CodeContentPrediction)
graph = await cognify() # graph = await cognify()
vector_client = infrastructure_config.get_config("vector_engine") vector_client = infrastructure_config.get_config("vector_engine")
out = await vector_client.search(collection_name ="basic_rag", query_text="show_all_processes", limit=10) out = await vector_client.search(collection_name ="basic_rag", query_text="show_all_processes", limit=10)

View file

@@ -84,10 +84,12 @@ class InfrastructureConfig():
if (config_entity is None or config_entity == "llm_engine") and self.llm_engine is None: if (config_entity is None or config_entity == "llm_engine") and self.llm_engine is None:
self.llm_engine = OpenAIAdapter(config.openai_key, config.openai_model) self.llm_engine = OpenAIAdapter(config.openai_key, config.openai_model)
if (config_entity is None or config_entity == "database_directory_path") and self.database_directory_path is None: if (config_entity is None or config_entity == "database_directory_path") and self.database_directory_path is None:
self.database_directory_path = self.system_root_directory + "/" + config.db_path self.database_directory_path = self.system_root_directory + "/" + config.db_path
if self.database_directory_path is None:
self.database_directory_path = self.system_root_directory + "/" + config.db_path
if (config_entity is None or config_entity == "database_file_path") and self.database_file_path is None: if (config_entity is None or config_entity == "database_file_path") and self.database_file_path is None:
self.database_file_path = self.system_root_directory + "/" + config.db_path + "/" + config.db_name self.database_file_path = self.system_root_directory + "/" + config.db_path + "/" + config.db_name
@@ -114,6 +116,9 @@ class InfrastructureConfig():
) )
else: else:
from .databases.vector.lancedb.LanceDBAdapter import LanceDBAdapter from .databases.vector.lancedb.LanceDBAdapter import LanceDBAdapter
print("Using LanceDB as vector engine", self.database_directory_path)
print("Setting system root directory to", self.system_root_directory)
lance_db_path = self.database_directory_path + "/cognee.lancedb" lance_db_path = self.database_directory_path + "/cognee.lancedb"
LocalStorage.ensure_directory_exists(lance_db_path) LocalStorage.ensure_directory_exists(lance_db_path)

View file

@@ -1,4 +1,6 @@
from typing import Union, Dict, re from typing import Union, Dict
import re
from cognee.modules.search.llm.extraction.categorize_relevant_category import categorize_relevant_category from cognee.modules.search.llm.extraction.categorize_relevant_category import categorize_relevant_category
@@ -13,7 +15,7 @@ def strip_exact_regex(s, substring):
# Regex to match the exact substring at the start and end # Regex to match the exact substring at the start and end
return re.sub(f"^{pattern}|{pattern}$", "", s) return re.sub(f"^{pattern}|{pattern}$", "", s)
async def search_categories(query:str, graph: Union[nx.Graph, any], query_label: str, infrastructure_config: Dict): async def search_categories(query:str, graph: Union[nx.Graph, any], query_label: str=None, infrastructure_config: Dict=None):
""" """
Filter nodes in the graph that contain the specified label and return their summary attributes. Filter nodes in the graph that contain the specified label and return their summary attributes.
This function supports both NetworkX graphs and Neo4j graph databases. This function supports both NetworkX graphs and Neo4j graph databases.
@@ -29,6 +31,7 @@ async def search_categories(query:str, graph: Union[nx.Graph, any], query_label:
each representing a node with 'nodeId' and 'summary'. each representing a node with 'nodeId' and 'summary'.
""" """
# Determine which client is in use based on the configuration # Determine which client is in use based on the configuration
from cognee.infrastructure import infrastructure_config
if infrastructure_config.get_config()["graph_engine"] == GraphDBType.NETWORKX: if infrastructure_config.get_config()["graph_engine"] == GraphDBType.NETWORKX:
categories_and_ids = [ categories_and_ids = [

View file

@@ -6,7 +6,7 @@ from cognee.infrastructure.llm.get_llm_client import get_llm_client
async def categorize_relevant_category(query: str, summary, response_model: Type[BaseModel]): async def categorize_relevant_category(query: str, summary, response_model: Type[BaseModel]):
llm_client = get_llm_client() llm_client = get_llm_client()
enriched_query= render_prompt("categorize_category.txt", {"query": query, "categories": summary}) enriched_query= render_prompt("categorize_categories.txt", {"query": query, "categories": summary})
print("enriched_query", enriched_query) print("enriched_query", enriched_query)

View file

@@ -84,17 +84,22 @@ async def run_cognify_base_rag():
async def cognify_search_base_rag(content:str, context:str): async def cognify_search_base_rag(content:str, context:str):
infrastructure_config.set_config({"database_directory_path": "/Users/vasa/Projects/cognee/cognee/.cognee_system/databases/cognee.lancedb"})
vector_client = infrastructure_config.get_config("vector_engine") vector_client = infrastructure_config.get_config("vector_engine")
return_ = await vector_client.search(collection_name="basic_rag", query_text="show_all_processes", limit=10) return_ = await vector_client.search(collection_name="basic_rag", query_text=content, limit=10)
print("results", return_) print("results", return_)
return return_ return return_
async def cognify_search_graph(content:str, context:str): async def cognify_search_graph(content:str, context:str):
from cognee.api.v1.search.search import search from cognee.api.v1.search.search import search
return_ = await search(content) search_type = 'CATEGORIES'
return return_ params = {'query': 'Ministarstvo'}
results = await search(search_type, params)
return results
@@ -128,8 +133,9 @@ if __name__ == "__main__":
import asyncio import asyncio
async def main(): async def main():
await run_cognify_base_rag_and_search() # await run_cognify_base_rag()
# await cognify_search_base_rag("show_all_processes", "context")
await cognify_search_graph("show_all_processes", "context")
asyncio.run(main()) asyncio.run(main())
# run_cognify_base_rag_and_search() # run_cognify_base_rag_and_search()
# # Data preprocessing before setting the dataset test cases # # Data preprocessing before setting the dataset test cases