Fixes to search and evals

This commit is contained in:
Vasilije 2024-05-21 10:03:52 +02:00
parent 8ef23731a3
commit 63356f242a
5 changed files with 26 additions and 12 deletions

View file

@ -258,9 +258,9 @@ if __name__ == "__main__":
from cognee.shared.SourceCodeGraph import SourceCodeGraph
from cognee.api.v1.config import config
config.set_graph_model(SourceCodeGraph)
config.set_classification_model(CodeContentPrediction)
graph = await cognify()
# config.set_graph_model(SourceCodeGraph)
# config.set_classification_model(CodeContentPrediction)
# graph = await cognify()
vector_client = infrastructure_config.get_config("vector_engine")
out = await vector_client.search(collection_name ="basic_rag", query_text="show_all_processes", limit=10)

View file

@ -84,10 +84,12 @@ class InfrastructureConfig():
if (config_entity is None or config_entity == "llm_engine") and self.llm_engine is None:
self.llm_engine = OpenAIAdapter(config.openai_key, config.openai_model)
if (config_entity is None or config_entity == "database_directory_path") and self.database_directory_path is None:
self.database_directory_path = self.system_root_directory + "/" + config.db_path
if self.database_directory_path is None:
self.database_directory_path = self.system_root_directory + "/" + config.db_path
if (config_entity is None or config_entity == "database_file_path") and self.database_file_path is None:
self.database_file_path = self.system_root_directory + "/" + config.db_path + "/" + config.db_name
@ -114,6 +116,9 @@ class InfrastructureConfig():
)
else:
from .databases.vector.lancedb.LanceDBAdapter import LanceDBAdapter
print("Using LanceDB as vector engine", self.database_directory_path)
print("Setting system root directory to", self.system_root_directory)
lance_db_path = self.database_directory_path + "/cognee.lancedb"
LocalStorage.ensure_directory_exists(lance_db_path)

View file

@ -1,4 +1,6 @@
from typing import Union, Dict, re
from typing import Union, Dict
import re
from cognee.modules.search.llm.extraction.categorize_relevant_category import categorize_relevant_category
@ -13,7 +15,7 @@ def strip_exact_regex(s, substring):
# Regex to match the exact substring at the start and end
return re.sub(f"^{pattern}|{pattern}$", "", s)
async def search_categories(query:str, graph: Union[nx.Graph, any], query_label: str, infrastructure_config: Dict):
async def search_categories(query:str, graph: Union[nx.Graph, any], query_label: str=None, infrastructure_config: Dict=None):
"""
Filter nodes in the graph that contain the specified label and return their summary attributes.
This function supports both NetworkX graphs and Neo4j graph databases.
@ -29,6 +31,7 @@ async def search_categories(query:str, graph: Union[nx.Graph, any], query_label:
each representing a node with 'nodeId' and 'summary'.
"""
# Determine which client is in use based on the configuration
from cognee.infrastructure import infrastructure_config
if infrastructure_config.get_config()["graph_engine"] == GraphDBType.NETWORKX:
categories_and_ids = [

View file

@ -6,7 +6,7 @@ from cognee.infrastructure.llm.get_llm_client import get_llm_client
async def categorize_relevant_category(query: str, summary, response_model: Type[BaseModel]):
llm_client = get_llm_client()
enriched_query= render_prompt("categorize_category.txt", {"query": query, "categories": summary})
enriched_query= render_prompt("categorize_categories.txt", {"query": query, "categories": summary})
print("enriched_query", enriched_query)

View file

@ -84,17 +84,22 @@ async def run_cognify_base_rag():
async def cognify_search_base_rag(content:str, context:str):
infrastructure_config.set_config({"database_directory_path": "/Users/vasa/Projects/cognee/cognee/.cognee_system/databases/cognee.lancedb"})
vector_client = infrastructure_config.get_config("vector_engine")
return_ = await vector_client.search(collection_name="basic_rag", query_text="show_all_processes", limit=10)
return_ = await vector_client.search(collection_name="basic_rag", query_text=content, limit=10)
print("results", return_)
return return_
async def cognify_search_graph(content:str, context:str):
from cognee.api.v1.search.search import search
return_ = await search(content)
return return_
search_type = 'CATEGORIES'
params = {'query': 'Ministarstvo'}
results = await search(search_type, params)
return results
@ -128,8 +133,9 @@ if __name__ == "__main__":
import asyncio
async def main():
await run_cognify_base_rag_and_search()
# await run_cognify_base_rag()
# await cognify_search_base_rag("show_all_processes", "context")
await cognify_search_graph("show_all_processes", "context")
asyncio.run(main())
# run_cognify_base_rag_and_search()
# # Data preprocessing before setting the dataset test cases