feat: implement combined context search (#1341)
<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
parent
ba33dca592
commit
b1643414d2
36 changed files with 706 additions and 642 deletions
|
|
@ -27,6 +27,7 @@ class SearchPayloadDTO(InDTO):
|
||||||
node_name: Optional[list[str]] = Field(default=None, example=[])
|
node_name: Optional[list[str]] = Field(default=None, example=[])
|
||||||
top_k: Optional[int] = Field(default=10)
|
top_k: Optional[int] = Field(default=10)
|
||||||
only_context: bool = Field(default=False)
|
only_context: bool = Field(default=False)
|
||||||
|
use_combined_context: bool = Field(default=False)
|
||||||
|
|
||||||
|
|
||||||
def get_search_router() -> APIRouter:
|
def get_search_router() -> APIRouter:
|
||||||
|
|
@ -115,6 +116,7 @@ def get_search_router() -> APIRouter:
|
||||||
"node_name": payload.node_name,
|
"node_name": payload.node_name,
|
||||||
"top_k": payload.top_k,
|
"top_k": payload.top_k,
|
||||||
"only_context": payload.only_context,
|
"only_context": payload.only_context,
|
||||||
|
"use_combined_context": payload.use_combined_context,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -131,6 +133,7 @@ def get_search_router() -> APIRouter:
|
||||||
node_name=payload.node_name,
|
node_name=payload.node_name,
|
||||||
top_k=payload.top_k,
|
top_k=payload.top_k,
|
||||||
only_context=payload.only_context,
|
only_context=payload.only_context,
|
||||||
|
use_combined_context=payload.use_combined_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
return JSONResponse(content=results)
|
return JSONResponse(content=results)
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from typing import Union, Optional, List, Type
|
||||||
|
|
||||||
from cognee.modules.engine.models.node_set import NodeSet
|
from cognee.modules.engine.models.node_set import NodeSet
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
from cognee.modules.search.types import SearchType
|
from cognee.modules.search.types import SearchResult, SearchType, CombinedSearchResult
|
||||||
from cognee.modules.users.methods import get_default_user
|
from cognee.modules.users.methods import get_default_user
|
||||||
from cognee.modules.search.methods import search as search_function
|
from cognee.modules.search.methods import search as search_function
|
||||||
from cognee.modules.data.methods import get_authorized_existing_datasets
|
from cognee.modules.data.methods import get_authorized_existing_datasets
|
||||||
|
|
@ -13,7 +13,7 @@ from cognee.modules.data.exceptions import DatasetNotFoundError
|
||||||
async def search(
|
async def search(
|
||||||
query_text: str,
|
query_text: str,
|
||||||
query_type: SearchType = SearchType.GRAPH_COMPLETION,
|
query_type: SearchType = SearchType.GRAPH_COMPLETION,
|
||||||
user: User = None,
|
user: Optional[User] = None,
|
||||||
datasets: Optional[Union[list[str], str]] = None,
|
datasets: Optional[Union[list[str], str]] = None,
|
||||||
dataset_ids: Optional[Union[list[UUID], UUID]] = None,
|
dataset_ids: Optional[Union[list[UUID], UUID]] = None,
|
||||||
system_prompt_path: str = "answer_simple_question.txt",
|
system_prompt_path: str = "answer_simple_question.txt",
|
||||||
|
|
@ -24,7 +24,8 @@ async def search(
|
||||||
save_interaction: bool = False,
|
save_interaction: bool = False,
|
||||||
last_k: Optional[int] = None,
|
last_k: Optional[int] = None,
|
||||||
only_context: bool = False,
|
only_context: bool = False,
|
||||||
) -> list:
|
use_combined_context: bool = False,
|
||||||
|
) -> Union[List[SearchResult], CombinedSearchResult]:
|
||||||
"""
|
"""
|
||||||
Search and query the knowledge graph for insights, information, and connections.
|
Search and query the knowledge graph for insights, information, and connections.
|
||||||
|
|
||||||
|
|
@ -193,6 +194,7 @@ async def search(
|
||||||
save_interaction=save_interaction,
|
save_interaction=save_interaction,
|
||||||
last_k=last_k,
|
last_k=last_k,
|
||||||
only_context=only_context,
|
only_context=only_context,
|
||||||
|
use_combined_context=use_combined_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
return filtered_search_results
|
return filtered_search_results
|
||||||
|
|
|
||||||
|
|
@ -180,7 +180,7 @@ class CogneeGraph(CogneeAbstractGraph):
|
||||||
logger.error(f"Error mapping vector distances to edges: {str(ex)}")
|
logger.error(f"Error mapping vector distances to edges: {str(ex)}")
|
||||||
raise ex
|
raise ex
|
||||||
|
|
||||||
async def calculate_top_triplet_importances(self, k: int) -> List:
|
async def calculate_top_triplet_importances(self, k: int) -> List[Edge]:
|
||||||
def score(edge):
|
def score(edge):
|
||||||
n1 = edge.node1.attributes.get("vector_distance", 1)
|
n1 = edge.node1.attributes.get("vector_distance", 1)
|
||||||
n2 = edge.node2.attributes.get("vector_distance", 1)
|
n2 = edge.node2.attributes.get("vector_distance", 1)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,8 @@
|
||||||
async def resolve_edges_to_text(retrieved_edges: list) -> str:
|
from typing import List
|
||||||
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||||
|
|
||||||
|
|
||||||
|
async def resolve_edges_to_text(retrieved_edges: List[Edge]) -> str:
|
||||||
"""
|
"""
|
||||||
Converts retrieved graph edges into a human-readable string format.
|
Converts retrieved graph edges into a human-readable string format.
|
||||||
|
|
||||||
|
|
@ -13,7 +17,7 @@ async def resolve_edges_to_text(retrieved_edges: list) -> str:
|
||||||
- str: A formatted string representation of the nodes and their connections.
|
- str: A formatted string representation of the nodes and their connections.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_nodes(retrieved_edges: list) -> dict:
|
def _get_nodes(retrieved_edges: List[Edge]) -> dict:
|
||||||
def _get_title(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
|
def _get_title(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
|
||||||
def _top_n_words(text, stop_words=None, top_n=3, separator=", "):
|
def _top_n_words(text, stop_words=None, top_n=3, separator=", "):
|
||||||
"""Concatenates the top N frequent words in text."""
|
"""Concatenates the top N frequent words in text."""
|
||||||
|
|
@ -36,9 +40,9 @@ async def resolve_edges_to_text(retrieved_edges: list) -> str:
|
||||||
return separator.join(top_words)
|
return separator.join(top_words)
|
||||||
|
|
||||||
"""Creates a title, by combining first words with most frequent words from the text."""
|
"""Creates a title, by combining first words with most frequent words from the text."""
|
||||||
first_n_words = text.split()[:first_n_words]
|
first_words = text.split()[:first_n_words]
|
||||||
top_n_words = _top_n_words(text, top_n=top_n_words)
|
top_words = _top_n_words(text, top_n=first_n_words)
|
||||||
return f"{' '.join(first_n_words)}... [{top_n_words}]"
|
return f"{' '.join(first_words)}... [{top_words}]"
|
||||||
|
|
||||||
"""Creates a dictionary of nodes with their names and content."""
|
"""Creates a dictionary of nodes with their names and content."""
|
||||||
nodes = {}
|
nodes = {}
|
||||||
|
|
|
||||||
18
cognee/modules/retrieval/base_graph_retriever.py
Normal file
18
cognee/modules/retrieval/base_graph_retriever.py
Normal file
|
|
@ -0,0 +1,18 @@
|
||||||
|
from typing import List, Optional
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||||
|
|
||||||
|
|
||||||
|
class BaseGraphRetriever(ABC):
    """
    Abstract interface for graph based retrievers.

    Concrete retrievers implement both coroutines below: one that fetches
    triplets (graph edges) for a query, and one that turns a query plus
    optional triplets into a textual response.
    """

    @abstractmethod
    async def get_context(self, query: str) -> List[Edge]:
        """
        Retrieve the triplets that serve as context for the query.

        Parameters:
        -----------

            - query (str): The query text to retrieve triplets for.

        Returns:
        --------

            - List[Edge]: The triplets (graph edges) retrieved for the query.
        """

    @abstractmethod
    async def get_completion(self, query: str, context: Optional[List[Edge]] = None) -> str:
        """
        Generate a response from the query and optional context (triplets).

        Parameters:
        -----------

            - query (str): The query text to answer.
            - context (Optional[List[Edge]]): Triplets to use as context; may be
              omitted (None). How a missing context is handled is left to the
              implementing class.

        Returns:
        --------

            - str: The generated response.
        """
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Optional, Callable
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
|
||||||
class BaseRetriever(ABC):
|
class BaseRetriever(ABC):
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,6 @@
|
||||||
|
import asyncio
|
||||||
|
from functools import reduce
|
||||||
|
from typing import List, Optional
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.tasks.codingagents.coding_rule_associations import get_existing_rules
|
from cognee.tasks.codingagents.coding_rule_associations import get_existing_rules
|
||||||
|
|
||||||
|
|
@ -7,16 +10,22 @@ logger = get_logger("CodingRulesRetriever")
|
||||||
class CodingRulesRetriever:
|
class CodingRulesRetriever:
|
||||||
"""Retriever for handling codeing rule based searches."""
|
"""Retriever for handling codeing rule based searches."""
|
||||||
|
|
||||||
def __init__(self, rules_nodeset_name="coding_agent_rules"):
|
def __init__(self, rules_nodeset_name: Optional[List[str]] = None):
|
||||||
if isinstance(rules_nodeset_name, list):
|
if isinstance(rules_nodeset_name, list):
|
||||||
if not rules_nodeset_name:
|
if not rules_nodeset_name:
|
||||||
# If there is no provided nodeset set to coding_agent_rules
|
# If there is no provided nodeset set to coding_agent_rules
|
||||||
rules_nodeset_name = ["coding_agent_rules"]
|
rules_nodeset_name = ["coding_agent_rules"]
|
||||||
rules_nodeset_name = rules_nodeset_name[0]
|
|
||||||
self.rules_nodeset_name = rules_nodeset_name
|
self.rules_nodeset_name = rules_nodeset_name
|
||||||
"""Initialize retriever with search parameters."""
|
"""Initialize retriever with search parameters."""
|
||||||
|
|
||||||
async def get_existing_rules(self, query_text):
|
async def get_existing_rules(self, query_text):
|
||||||
return await get_existing_rules(
|
if self.rules_nodeset_name:
|
||||||
rules_nodeset_name=self.rules_nodeset_name, return_list=True
|
rules_list = await asyncio.gather(
|
||||||
)
|
*[
|
||||||
|
get_existing_rules(rules_nodeset_name=nodeset)
|
||||||
|
for nodeset in self.rules_nodeset_name
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return reduce(lambda x, y: x + y, rules_list, [])
|
||||||
|
|
|
||||||
|
|
@ -23,16 +23,14 @@ class CompletionRetriever(BaseRetriever):
|
||||||
self,
|
self,
|
||||||
user_prompt_path: str = "context_for_question.txt",
|
user_prompt_path: str = "context_for_question.txt",
|
||||||
system_prompt_path: str = "answer_simple_question.txt",
|
system_prompt_path: str = "answer_simple_question.txt",
|
||||||
system_prompt: str = None,
|
system_prompt: Optional[str] = None,
|
||||||
top_k: Optional[int] = 1,
|
top_k: Optional[int] = 1,
|
||||||
only_context: bool = False,
|
|
||||||
):
|
):
|
||||||
"""Initialize retriever with optional custom prompt paths."""
|
"""Initialize retriever with optional custom prompt paths."""
|
||||||
self.user_prompt_path = user_prompt_path
|
self.user_prompt_path = user_prompt_path
|
||||||
self.system_prompt_path = system_prompt_path
|
self.system_prompt_path = system_prompt_path
|
||||||
self.top_k = top_k if top_k is not None else 1
|
self.top_k = top_k if top_k is not None else 1
|
||||||
self.system_prompt = system_prompt
|
self.system_prompt = system_prompt
|
||||||
self.only_context = only_context
|
|
||||||
|
|
||||||
async def get_context(self, query: str) -> str:
|
async def get_context(self, query: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
|
@ -69,7 +67,7 @@ class CompletionRetriever(BaseRetriever):
|
||||||
logger.error("DocumentChunk_text collection not found")
|
logger.error("DocumentChunk_text collection not found")
|
||||||
raise NoDataError("No data found in the system, please add data first.") from error
|
raise NoDataError("No data found in the system, please add data first.") from error
|
||||||
|
|
||||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
async def get_completion(self, query: str, context: Optional[Any] = None) -> str:
|
||||||
"""
|
"""
|
||||||
Generates an LLM completion using the context.
|
Generates an LLM completion using the context.
|
||||||
|
|
||||||
|
|
@ -97,6 +95,5 @@ class CompletionRetriever(BaseRetriever):
|
||||||
user_prompt_path=self.user_prompt_path,
|
user_prompt_path=self.user_prompt_path,
|
||||||
system_prompt_path=self.system_prompt_path,
|
system_prompt_path=self.system_prompt_path,
|
||||||
system_prompt=self.system_prompt,
|
system_prompt=self.system_prompt,
|
||||||
only_context=self.only_context,
|
|
||||||
)
|
)
|
||||||
return [completion]
|
return completion
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,7 @@ class TripletSearchContextProvider(BaseContextProvider):
|
||||||
tasks = [
|
tasks = [
|
||||||
brute_force_triplet_search(
|
brute_force_triplet_search(
|
||||||
query=f"{entity_text} {query}",
|
query=f"{entity_text} {query}",
|
||||||
|
user=user,
|
||||||
top_k=self.top_k,
|
top_k=self.top_k,
|
||||||
collections=self.collections,
|
collections=self.collections,
|
||||||
properties_to_project=self.properties_to_project,
|
properties_to_project=self.properties_to_project,
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from typing import Any, Optional, List, Type
|
from typing import Optional, List, Type
|
||||||
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||||
from cognee.modules.retrieval.utils.completion import generate_completion
|
from cognee.modules.retrieval.utils.completion import generate_completion
|
||||||
|
|
@ -31,7 +32,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
node_type: Optional[Type] = None,
|
node_type: Optional[Type] = None,
|
||||||
node_name: Optional[List[str]] = None,
|
node_name: Optional[List[str]] = None,
|
||||||
save_interaction: bool = False,
|
save_interaction: bool = False,
|
||||||
only_context: bool = False,
|
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
user_prompt_path=user_prompt_path,
|
user_prompt_path=user_prompt_path,
|
||||||
|
|
@ -41,15 +41,14 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
node_name=node_name,
|
node_name=node_name,
|
||||||
save_interaction=save_interaction,
|
save_interaction=save_interaction,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
only_context=only_context,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_completion(
|
async def get_completion(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
context: Optional[Any] = None,
|
context: Optional[List[Edge]] = None,
|
||||||
context_extension_rounds=4,
|
context_extension_rounds=4,
|
||||||
) -> List[str]:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Extends the context for a given query by retrieving related triplets and generating new
|
Extends the context for a given query by retrieving related triplets and generating new
|
||||||
completions based on them.
|
completions based on them.
|
||||||
|
|
@ -74,11 +73,12 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
- List[str]: A list containing the generated answer based on the query and the
|
- List[str]: A list containing the generated answer based on the query and the
|
||||||
extended context.
|
extended context.
|
||||||
"""
|
"""
|
||||||
triplets = []
|
triplets = context
|
||||||
|
|
||||||
if context is None:
|
if triplets is None:
|
||||||
triplets += await self.get_triplets(query)
|
triplets = await self.get_context(query)
|
||||||
context = await self.resolve_edges_to_text(triplets)
|
|
||||||
|
context_text = await self.resolve_edges_to_text(triplets)
|
||||||
|
|
||||||
round_idx = 1
|
round_idx = 1
|
||||||
|
|
||||||
|
|
@ -90,15 +90,15 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
)
|
)
|
||||||
completion = await generate_completion(
|
completion = await generate_completion(
|
||||||
query=query,
|
query=query,
|
||||||
context=context,
|
context=context_text,
|
||||||
user_prompt_path=self.user_prompt_path,
|
user_prompt_path=self.user_prompt_path,
|
||||||
system_prompt_path=self.system_prompt_path,
|
system_prompt_path=self.system_prompt_path,
|
||||||
system_prompt=self.system_prompt,
|
system_prompt=self.system_prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
triplets += await self.get_triplets(completion)
|
triplets += await self.get_context(completion)
|
||||||
triplets = list(set(triplets))
|
triplets = list(set(triplets))
|
||||||
context = await self.resolve_edges_to_text(triplets)
|
context_text = await self.resolve_edges_to_text(triplets)
|
||||||
|
|
||||||
num_triplets = len(triplets)
|
num_triplets = len(triplets)
|
||||||
|
|
||||||
|
|
@ -117,19 +117,15 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
|
|
||||||
completion = await generate_completion(
|
completion = await generate_completion(
|
||||||
query=query,
|
query=query,
|
||||||
context=context,
|
context=context_text,
|
||||||
user_prompt_path=self.user_prompt_path,
|
user_prompt_path=self.user_prompt_path,
|
||||||
system_prompt_path=self.system_prompt_path,
|
system_prompt_path=self.system_prompt_path,
|
||||||
system_prompt=self.system_prompt,
|
system_prompt=self.system_prompt,
|
||||||
only_context=self.only_context,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.save_interaction and context and triplets and completion:
|
if self.save_interaction and context_text and triplets and completion:
|
||||||
await self.save_qa(
|
await self.save_qa(
|
||||||
question=query, answer=completion, context=context, triplets=triplets
|
question=query, answer=completion, context=context_text, triplets=triplets
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.only_context:
|
return completion
|
||||||
return [context]
|
|
||||||
else:
|
|
||||||
return [completion]
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from typing import Any, Optional, List, Tuple, Type
|
from typing import Optional, List, Type
|
||||||
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
||||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||||
|
|
@ -32,18 +33,16 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
validation_system_prompt_path: str = "cot_validation_system_prompt.txt",
|
validation_system_prompt_path: str = "cot_validation_system_prompt.txt",
|
||||||
followup_system_prompt_path: str = "cot_followup_system_prompt.txt",
|
followup_system_prompt_path: str = "cot_followup_system_prompt.txt",
|
||||||
followup_user_prompt_path: str = "cot_followup_user_prompt.txt",
|
followup_user_prompt_path: str = "cot_followup_user_prompt.txt",
|
||||||
system_prompt: str = None,
|
system_prompt: Optional[str] = None,
|
||||||
top_k: Optional[int] = 5,
|
top_k: Optional[int] = 5,
|
||||||
node_type: Optional[Type] = None,
|
node_type: Optional[Type] = None,
|
||||||
node_name: Optional[List[str]] = None,
|
node_name: Optional[List[str]] = None,
|
||||||
save_interaction: bool = False,
|
save_interaction: bool = False,
|
||||||
only_context: bool = False,
|
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
user_prompt_path=user_prompt_path,
|
user_prompt_path=user_prompt_path,
|
||||||
system_prompt_path=system_prompt_path,
|
system_prompt_path=system_prompt_path,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
only_context=only_context,
|
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
node_type=node_type,
|
node_type=node_type,
|
||||||
node_name=node_name,
|
node_name=node_name,
|
||||||
|
|
@ -57,9 +56,9 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
async def get_completion(
|
async def get_completion(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
context: Optional[Any] = None,
|
context: Optional[List[Edge]] = None,
|
||||||
max_iter=4,
|
max_iter=4,
|
||||||
) -> List[str]:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Generate completion responses based on a user query and contextual information.
|
Generate completion responses based on a user query and contextual information.
|
||||||
|
|
||||||
|
|
@ -84,26 +83,29 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
"""
|
"""
|
||||||
followup_question = ""
|
followup_question = ""
|
||||||
triplets = []
|
triplets = []
|
||||||
completion = [""]
|
completion = ""
|
||||||
|
|
||||||
for round_idx in range(max_iter + 1):
|
for round_idx in range(max_iter + 1):
|
||||||
if round_idx == 0:
|
if round_idx == 0:
|
||||||
if context is None:
|
if context is None:
|
||||||
context = await self.get_context(query)
|
triplets = await self.get_context(query)
|
||||||
|
context_text = await self.resolve_edges_to_text(triplets)
|
||||||
|
else:
|
||||||
|
context_text = await self.resolve_edges_to_text(context)
|
||||||
else:
|
else:
|
||||||
triplets += await self.get_triplets(followup_question)
|
triplets += await self.get_context(followup_question)
|
||||||
context = await self.resolve_edges_to_text(list(set(triplets)))
|
context_text = await self.resolve_edges_to_text(list(set(triplets)))
|
||||||
|
|
||||||
completion = await generate_completion(
|
completion = await generate_completion(
|
||||||
query=query,
|
query=query,
|
||||||
context=context,
|
context=context_text,
|
||||||
user_prompt_path=self.user_prompt_path,
|
user_prompt_path=self.user_prompt_path,
|
||||||
system_prompt_path=self.system_prompt_path,
|
system_prompt_path=self.system_prompt_path,
|
||||||
system_prompt=self.system_prompt,
|
system_prompt=self.system_prompt,
|
||||||
)
|
)
|
||||||
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
|
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
|
||||||
if round_idx < max_iter:
|
if round_idx < max_iter:
|
||||||
valid_args = {"query": query, "answer": completion, "context": context}
|
valid_args = {"query": query, "answer": completion, "context": context_text}
|
||||||
valid_user_prompt = LLMGateway.render_prompt(
|
valid_user_prompt = LLMGateway.render_prompt(
|
||||||
filename=self.validation_user_prompt_path, context=valid_args
|
filename=self.validation_user_prompt_path, context=valid_args
|
||||||
)
|
)
|
||||||
|
|
@ -133,10 +135,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
|
|
||||||
if self.save_interaction and context and triplets and completion:
|
if self.save_interaction and context and triplets and completion:
|
||||||
await self.save_qa(
|
await self.save_qa(
|
||||||
question=query, answer=completion, context=context, triplets=triplets
|
question=query, answer=completion, context=context_text, triplets=triplets
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.only_context:
|
return completion
|
||||||
return [context]
|
|
||||||
else:
|
|
||||||
return [completion]
|
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,15 @@
|
||||||
from typing import Any, Optional, Type, List
|
from typing import Any, Optional, Type, List
|
||||||
from collections import Counter
|
|
||||||
from uuid import NAMESPACE_OID, uuid5
|
from uuid import NAMESPACE_OID, uuid5
|
||||||
import string
|
|
||||||
|
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||||
|
from cognee.modules.users.methods import get_default_user
|
||||||
from cognee.tasks.storage import add_data_points
|
from cognee.tasks.storage import add_data_points
|
||||||
from cognee.modules.graph.utils import resolve_edges_to_text
|
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||||
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
|
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
|
||||||
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
from cognee.modules.retrieval.base_graph_retriever import BaseGraphRetriever
|
||||||
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
|
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
|
||||||
from cognee.modules.retrieval.utils.completion import generate_completion
|
from cognee.modules.retrieval.utils.completion import generate_completion
|
||||||
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
|
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.modules.retrieval.utils.extract_uuid_from_node import extract_uuid_from_node
|
from cognee.modules.retrieval.utils.extract_uuid_from_node import extract_uuid_from_node
|
||||||
from cognee.modules.retrieval.utils.models import CogneeUserInteraction
|
from cognee.modules.retrieval.utils.models import CogneeUserInteraction
|
||||||
|
|
@ -20,7 +19,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
logger = get_logger("GraphCompletionRetriever")
|
logger = get_logger("GraphCompletionRetriever")
|
||||||
|
|
||||||
|
|
||||||
class GraphCompletionRetriever(BaseRetriever):
|
class GraphCompletionRetriever(BaseGraphRetriever):
|
||||||
"""
|
"""
|
||||||
Retriever for handling graph-based completion searches.
|
Retriever for handling graph-based completion searches.
|
||||||
|
|
||||||
|
|
@ -37,19 +36,17 @@ class GraphCompletionRetriever(BaseRetriever):
|
||||||
self,
|
self,
|
||||||
user_prompt_path: str = "graph_context_for_question.txt",
|
user_prompt_path: str = "graph_context_for_question.txt",
|
||||||
system_prompt_path: str = "answer_simple_question.txt",
|
system_prompt_path: str = "answer_simple_question.txt",
|
||||||
system_prompt: str = None,
|
system_prompt: Optional[str] = None,
|
||||||
top_k: Optional[int] = 5,
|
top_k: Optional[int] = 5,
|
||||||
node_type: Optional[Type] = None,
|
node_type: Optional[Type] = None,
|
||||||
node_name: Optional[List[str]] = None,
|
node_name: Optional[List[str]] = None,
|
||||||
save_interaction: bool = False,
|
save_interaction: bool = False,
|
||||||
only_context: bool = False,
|
|
||||||
):
|
):
|
||||||
"""Initialize retriever with prompt paths and search parameters."""
|
"""Initialize retriever with prompt paths and search parameters."""
|
||||||
self.save_interaction = save_interaction
|
self.save_interaction = save_interaction
|
||||||
self.user_prompt_path = user_prompt_path
|
self.user_prompt_path = user_prompt_path
|
||||||
self.system_prompt_path = system_prompt_path
|
self.system_prompt_path = system_prompt_path
|
||||||
self.system_prompt = system_prompt
|
self.system_prompt = system_prompt
|
||||||
self.only_context = only_context
|
|
||||||
self.top_k = top_k if top_k is not None else 5
|
self.top_k = top_k if top_k is not None else 5
|
||||||
self.node_type = node_type
|
self.node_type = node_type
|
||||||
self.node_name = node_name
|
self.node_name = node_name
|
||||||
|
|
@ -70,7 +67,7 @@ class GraphCompletionRetriever(BaseRetriever):
|
||||||
"""
|
"""
|
||||||
return await resolve_edges_to_text(retrieved_edges)
|
return await resolve_edges_to_text(retrieved_edges)
|
||||||
|
|
||||||
async def get_triplets(self, query: str) -> list:
|
async def get_triplets(self, query: str) -> List[Edge]:
|
||||||
"""
|
"""
|
||||||
Retrieves relevant graph triplets based on a query string.
|
Retrieves relevant graph triplets based on a query string.
|
||||||
|
|
||||||
|
|
@ -85,7 +82,7 @@ class GraphCompletionRetriever(BaseRetriever):
|
||||||
- list: A list of found triplets that match the query.
|
- list: A list of found triplets that match the query.
|
||||||
"""
|
"""
|
||||||
subclasses = get_all_subclasses(DataPoint)
|
subclasses = get_all_subclasses(DataPoint)
|
||||||
vector_index_collections = []
|
vector_index_collections: List[str] = []
|
||||||
|
|
||||||
for subclass in subclasses:
|
for subclass in subclasses:
|
||||||
if "metadata" in subclass.model_fields:
|
if "metadata" in subclass.model_fields:
|
||||||
|
|
@ -96,8 +93,11 @@ class GraphCompletionRetriever(BaseRetriever):
|
||||||
for field_name in index_fields:
|
for field_name in index_fields:
|
||||||
vector_index_collections.append(f"{subclass.__name__}_{field_name}")
|
vector_index_collections.append(f"{subclass.__name__}_{field_name}")
|
||||||
|
|
||||||
|
user = await get_default_user()
|
||||||
|
|
||||||
found_triplets = await brute_force_triplet_search(
|
found_triplets = await brute_force_triplet_search(
|
||||||
query,
|
query,
|
||||||
|
user=user,
|
||||||
top_k=self.top_k,
|
top_k=self.top_k,
|
||||||
collections=vector_index_collections or None,
|
collections=vector_index_collections or None,
|
||||||
node_type=self.node_type,
|
node_type=self.node_type,
|
||||||
|
|
@ -106,7 +106,7 @@ class GraphCompletionRetriever(BaseRetriever):
|
||||||
|
|
||||||
return found_triplets
|
return found_triplets
|
||||||
|
|
||||||
async def get_context(self, query: str) -> tuple[str, list]:
|
async def get_context(self, query: str) -> List[Edge]:
|
||||||
"""
|
"""
|
||||||
Retrieves and resolves graph triplets into context based on a query.
|
Retrieves and resolves graph triplets into context based on a query.
|
||||||
|
|
||||||
|
|
@ -125,17 +125,17 @@ class GraphCompletionRetriever(BaseRetriever):
|
||||||
|
|
||||||
if len(triplets) == 0:
|
if len(triplets) == 0:
|
||||||
logger.warning("Empty context was provided to the completion")
|
logger.warning("Empty context was provided to the completion")
|
||||||
return "", triplets
|
return []
|
||||||
|
|
||||||
context = await self.resolve_edges_to_text(triplets)
|
# context = await self.resolve_edges_to_text(triplets)
|
||||||
|
|
||||||
return context, triplets
|
return triplets
|
||||||
|
|
||||||
async def get_completion(
|
async def get_completion(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
context: Optional[Any] = None,
|
context: Optional[List[Edge]] = None,
|
||||||
) -> List[str]:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Generates a completion using graph connections context based on a query.
|
Generates a completion using graph connections context based on a query.
|
||||||
|
|
||||||
|
|
@ -151,26 +151,27 @@ class GraphCompletionRetriever(BaseRetriever):
|
||||||
|
|
||||||
- Any: A generated completion based on the query and context provided.
|
- Any: A generated completion based on the query and context provided.
|
||||||
"""
|
"""
|
||||||
triplets = None
|
triplets = context
|
||||||
|
|
||||||
if context is None:
|
if triplets is None:
|
||||||
context, triplets = await self.get_context(query)
|
triplets = await self.get_context(query)
|
||||||
|
|
||||||
|
context_text = await resolve_edges_to_text(triplets)
|
||||||
|
|
||||||
completion = await generate_completion(
|
completion = await generate_completion(
|
||||||
query=query,
|
query=query,
|
||||||
context=context,
|
context=context_text,
|
||||||
user_prompt_path=self.user_prompt_path,
|
user_prompt_path=self.user_prompt_path,
|
||||||
system_prompt_path=self.system_prompt_path,
|
system_prompt_path=self.system_prompt_path,
|
||||||
system_prompt=self.system_prompt,
|
system_prompt=self.system_prompt,
|
||||||
only_context=self.only_context,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.save_interaction and context and triplets and completion:
|
if self.save_interaction and context and triplets and completion:
|
||||||
await self.save_qa(
|
await self.save_qa(
|
||||||
question=query, answer=completion, context=context, triplets=triplets
|
question=query, answer=completion, context=context_text, triplets=triplets
|
||||||
)
|
)
|
||||||
|
|
||||||
return [completion]
|
return completion
|
||||||
|
|
||||||
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
|
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,18 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
|
||||||
|
from cognee.modules.retrieval.base_graph_retriever import BaseGraphRetriever
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
|
||||||
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||||
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
|
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
|
||||||
|
|
||||||
logger = get_logger("InsightsRetriever")
|
logger = get_logger("InsightsRetriever")
|
||||||
|
|
||||||
|
|
||||||
class InsightsRetriever(BaseRetriever):
|
class InsightsRetriever(BaseGraphRetriever):
|
||||||
"""
|
"""
|
||||||
Retriever for handling graph connection-based insights.
|
Retriever for handling graph connection-based insights.
|
||||||
|
|
||||||
|
|
@ -95,7 +96,17 @@ class InsightsRetriever(BaseRetriever):
|
||||||
unique_node_connections_map[unique_id] = True
|
unique_node_connections_map[unique_id] = True
|
||||||
unique_node_connections.append(node_connection)
|
unique_node_connections.append(node_connection)
|
||||||
|
|
||||||
return unique_node_connections
|
return [
|
||||||
|
Edge(
|
||||||
|
node1=Node(node_id=connection[0]["id"], attributes=connection[0]),
|
||||||
|
node2=Node(node_id=connection[2]["id"], attributes=connection[2]),
|
||||||
|
attributes={
|
||||||
|
**connection[1],
|
||||||
|
"relationship_type": connection[1]["relationship_name"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for connection in unique_node_connections
|
||||||
|
]
|
||||||
|
|
||||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import os
|
import os
|
||||||
from typing import Any, Optional, List, Tuple, Type
|
from typing import Any, Optional, List, Type
|
||||||
|
|
||||||
|
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
|
|
@ -113,8 +113,8 @@ class TemporalRetriever(GraphCompletionRetriever):
|
||||||
logger.info(
|
logger.info(
|
||||||
"No timestamps identified based on the query, performing retrieval using triplet search on events and entities."
|
"No timestamps identified based on the query, performing retrieval using triplet search on events and entities."
|
||||||
)
|
)
|
||||||
triplets = await self.get_triplets(query)
|
triplets = await self.get_context(query)
|
||||||
return await self.resolve_edges_to_text(triplets), triplets
|
return await self.resolve_edges_to_text(triplets)
|
||||||
|
|
||||||
if ids:
|
if ids:
|
||||||
relevant_events = await graph_engine.collect_events(ids=ids)
|
relevant_events = await graph_engine.collect_events(ids=ids)
|
||||||
|
|
@ -122,8 +122,8 @@ class TemporalRetriever(GraphCompletionRetriever):
|
||||||
logger.info(
|
logger.info(
|
||||||
"No events identified based on timestamp filtering, performing retrieval using triplet search on events and entities."
|
"No events identified based on timestamp filtering, performing retrieval using triplet search on events and entities."
|
||||||
)
|
)
|
||||||
triplets = await self.get_triplets(query)
|
triplets = await self.get_context(query)
|
||||||
return await self.resolve_edges_to_text(triplets), triplets
|
return await self.resolve_edges_to_text(triplets)
|
||||||
|
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
query_vector = (await vector_engine.embedding_engine.embed_text([query]))[0]
|
query_vector = (await vector_engine.embedding_engine.embed_text([query]))[0]
|
||||||
|
|
@ -134,18 +134,19 @@ class TemporalRetriever(GraphCompletionRetriever):
|
||||||
|
|
||||||
top_k_events = await self.filter_top_k_events(relevant_events, vector_search_results)
|
top_k_events = await self.filter_top_k_events(relevant_events, vector_search_results)
|
||||||
|
|
||||||
return self.descriptions_to_string(top_k_events), triplets
|
return self.descriptions_to_string(top_k_events)
|
||||||
|
|
||||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> List[str]:
|
async def get_completion(self, query: str, context: Optional[str] = None) -> str:
|
||||||
"""Generates a response using the query and optional context."""
|
"""Generates a response using the query and optional context."""
|
||||||
|
if not context:
|
||||||
|
context = await self.get_context(query=query)
|
||||||
|
|
||||||
context, triplets = await self.get_context(query=query)
|
if context:
|
||||||
|
completion = await generate_completion(
|
||||||
|
query=query,
|
||||||
|
context=context,
|
||||||
|
user_prompt_path=self.user_prompt_path,
|
||||||
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
)
|
||||||
|
|
||||||
completion = await generate_completion(
|
return completion
|
||||||
query=query,
|
|
||||||
context=context,
|
|
||||||
user_prompt_path=self.user_prompt_path,
|
|
||||||
system_prompt_path=self.system_prompt_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
return [completion]
|
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from cognee.infrastructure.databases.vector.exceptions import CollectionNotFound
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||||
from cognee.modules.users.methods import get_default_user
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
from cognee.shared.utils import send_telemetry
|
from cognee.shared.utils import send_telemetry
|
||||||
|
|
||||||
|
|
@ -87,41 +87,15 @@ async def get_memory_fragment(
|
||||||
|
|
||||||
|
|
||||||
async def brute_force_triplet_search(
|
async def brute_force_triplet_search(
|
||||||
query: str,
|
|
||||||
user: User = None,
|
|
||||||
top_k: int = 5,
|
|
||||||
collections: List[str] = None,
|
|
||||||
properties_to_project: List[str] = None,
|
|
||||||
memory_fragment: Optional[CogneeGraph] = None,
|
|
||||||
node_type: Optional[Type] = None,
|
|
||||||
node_name: Optional[List[str]] = None,
|
|
||||||
) -> list:
|
|
||||||
if user is None:
|
|
||||||
user = await get_default_user()
|
|
||||||
|
|
||||||
retrieved_results = await brute_force_search(
|
|
||||||
query,
|
|
||||||
user,
|
|
||||||
top_k,
|
|
||||||
collections=collections,
|
|
||||||
properties_to_project=properties_to_project,
|
|
||||||
memory_fragment=memory_fragment,
|
|
||||||
node_type=node_type,
|
|
||||||
node_name=node_name,
|
|
||||||
)
|
|
||||||
return retrieved_results
|
|
||||||
|
|
||||||
|
|
||||||
async def brute_force_search(
|
|
||||||
query: str,
|
query: str,
|
||||||
user: User,
|
user: User,
|
||||||
top_k: int,
|
top_k: int = 5,
|
||||||
collections: List[str] = None,
|
collections: Optional[List[str]] = None,
|
||||||
properties_to_project: List[str] = None,
|
properties_to_project: Optional[List[str]] = None,
|
||||||
memory_fragment: Optional[CogneeGraph] = None,
|
memory_fragment: Optional[CogneeGraph] = None,
|
||||||
node_type: Optional[Type] = None,
|
node_type: Optional[Type] = None,
|
||||||
node_name: Optional[List[str]] = None,
|
node_name: Optional[List[str]] = None,
|
||||||
) -> list:
|
) -> List[Edge]:
|
||||||
"""
|
"""
|
||||||
Performs a brute force search to retrieve the top triplets from the graph.
|
Performs a brute force search to retrieve the top triplets from the graph.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,6 @@ async def generate_completion(
|
||||||
user_prompt_path: str,
|
user_prompt_path: str,
|
||||||
system_prompt_path: str,
|
system_prompt_path: str,
|
||||||
system_prompt: Optional[str] = None,
|
system_prompt: Optional[str] = None,
|
||||||
only_context: bool = False,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generates a completion using LLM with given context and prompts."""
|
"""Generates a completion using LLM with given context and prompts."""
|
||||||
args = {"question": query, "context": context}
|
args = {"question": query, "context": context}
|
||||||
|
|
@ -17,14 +16,11 @@ async def generate_completion(
|
||||||
system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path)
|
system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path)
|
||||||
)
|
)
|
||||||
|
|
||||||
if only_context:
|
return await LLMGateway.acreate_structured_output(
|
||||||
return context
|
text_input=user_prompt,
|
||||||
else:
|
system_prompt=system_prompt,
|
||||||
return await LLMGateway.acreate_structured_output(
|
response_model=str,
|
||||||
text_input=user_prompt,
|
)
|
||||||
system_prompt=system_prompt,
|
|
||||||
response_model=str,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def summarize_text(
|
async def summarize_text(
|
||||||
|
|
|
||||||
168
cognee/modules/search/methods/get_search_type_tools.py
Normal file
168
cognee/modules/search/methods/get_search_type_tools.py
Normal file
|
|
@ -0,0 +1,168 @@
|
||||||
|
from typing import Callable, List, Optional, Type
|
||||||
|
|
||||||
|
from cognee.modules.engine.models.node_set import NodeSet
|
||||||
|
from cognee.modules.search.types import SearchType
|
||||||
|
from cognee.modules.search.operations import select_search_type
|
||||||
|
from cognee.modules.search.exceptions import UnsupportedSearchTypeError
|
||||||
|
|
||||||
|
# Retrievers
|
||||||
|
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
|
||||||
|
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
|
||||||
|
from cognee.modules.retrieval.insights_retriever import InsightsRetriever
|
||||||
|
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
|
||||||
|
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
|
||||||
|
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||||
|
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
|
||||||
|
from cognee.modules.retrieval.coding_rules_retriever import CodingRulesRetriever
|
||||||
|
from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
||||||
|
GraphSummaryCompletionRetriever,
|
||||||
|
)
|
||||||
|
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
||||||
|
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
||||||
|
GraphCompletionContextExtensionRetriever,
|
||||||
|
)
|
||||||
|
from cognee.modules.retrieval.code_retriever import CodeRetriever
|
||||||
|
from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
|
||||||
|
from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
|
||||||
|
|
||||||
|
|
||||||
|
async def get_search_type_tools(
|
||||||
|
query_type: SearchType,
|
||||||
|
query_text: str,
|
||||||
|
system_prompt_path: str = "answer_simple_question.txt",
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
top_k: int = 10,
|
||||||
|
node_type: Optional[Type] = NodeSet,
|
||||||
|
node_name: Optional[List[str]] = None,
|
||||||
|
save_interaction: bool = False,
|
||||||
|
last_k: Optional[int] = None,
|
||||||
|
) -> list:
|
||||||
|
search_tasks: dict[SearchType, List[Callable]] = {
|
||||||
|
SearchType.SUMMARIES: [
|
||||||
|
SummariesRetriever(top_k=top_k).get_completion,
|
||||||
|
SummariesRetriever(top_k=top_k).get_context,
|
||||||
|
],
|
||||||
|
SearchType.INSIGHTS: [
|
||||||
|
InsightsRetriever(top_k=top_k).get_completion,
|
||||||
|
InsightsRetriever(top_k=top_k).get_context,
|
||||||
|
],
|
||||||
|
SearchType.CHUNKS: [
|
||||||
|
ChunksRetriever(top_k=top_k).get_completion,
|
||||||
|
ChunksRetriever(top_k=top_k).get_context,
|
||||||
|
],
|
||||||
|
SearchType.RAG_COMPLETION: [
|
||||||
|
CompletionRetriever(
|
||||||
|
system_prompt_path=system_prompt_path,
|
||||||
|
top_k=top_k,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
).get_completion,
|
||||||
|
CompletionRetriever(
|
||||||
|
system_prompt_path=system_prompt_path,
|
||||||
|
top_k=top_k,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
).get_context,
|
||||||
|
],
|
||||||
|
SearchType.GRAPH_COMPLETION: [
|
||||||
|
GraphCompletionRetriever(
|
||||||
|
system_prompt_path=system_prompt_path,
|
||||||
|
top_k=top_k,
|
||||||
|
node_type=node_type,
|
||||||
|
node_name=node_name,
|
||||||
|
save_interaction=save_interaction,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
).get_completion,
|
||||||
|
GraphCompletionRetriever(
|
||||||
|
system_prompt_path=system_prompt_path,
|
||||||
|
top_k=top_k,
|
||||||
|
node_type=node_type,
|
||||||
|
node_name=node_name,
|
||||||
|
save_interaction=save_interaction,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
).get_context,
|
||||||
|
],
|
||||||
|
SearchType.GRAPH_COMPLETION_COT: [
|
||||||
|
GraphCompletionCotRetriever(
|
||||||
|
system_prompt_path=system_prompt_path,
|
||||||
|
top_k=top_k,
|
||||||
|
node_type=node_type,
|
||||||
|
node_name=node_name,
|
||||||
|
save_interaction=save_interaction,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
).get_completion,
|
||||||
|
GraphCompletionCotRetriever(
|
||||||
|
system_prompt_path=system_prompt_path,
|
||||||
|
top_k=top_k,
|
||||||
|
node_type=node_type,
|
||||||
|
node_name=node_name,
|
||||||
|
save_interaction=save_interaction,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
).get_context,
|
||||||
|
],
|
||||||
|
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: [
|
||||||
|
GraphCompletionContextExtensionRetriever(
|
||||||
|
system_prompt_path=system_prompt_path,
|
||||||
|
top_k=top_k,
|
||||||
|
node_type=node_type,
|
||||||
|
node_name=node_name,
|
||||||
|
save_interaction=save_interaction,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
).get_completion,
|
||||||
|
GraphCompletionContextExtensionRetriever(
|
||||||
|
system_prompt_path=system_prompt_path,
|
||||||
|
top_k=top_k,
|
||||||
|
node_type=node_type,
|
||||||
|
node_name=node_name,
|
||||||
|
save_interaction=save_interaction,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
).get_context,
|
||||||
|
],
|
||||||
|
SearchType.GRAPH_SUMMARY_COMPLETION: [
|
||||||
|
GraphSummaryCompletionRetriever(
|
||||||
|
system_prompt_path=system_prompt_path,
|
||||||
|
top_k=top_k,
|
||||||
|
node_type=node_type,
|
||||||
|
node_name=node_name,
|
||||||
|
save_interaction=save_interaction,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
).get_completion,
|
||||||
|
GraphSummaryCompletionRetriever(
|
||||||
|
system_prompt_path=system_prompt_path,
|
||||||
|
top_k=top_k,
|
||||||
|
node_type=node_type,
|
||||||
|
node_name=node_name,
|
||||||
|
save_interaction=save_interaction,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
).get_context,
|
||||||
|
],
|
||||||
|
SearchType.CODE: [
|
||||||
|
CodeRetriever(top_k=top_k).get_completion,
|
||||||
|
CodeRetriever(top_k=top_k).get_context,
|
||||||
|
],
|
||||||
|
SearchType.CYPHER: [
|
||||||
|
CypherSearchRetriever().get_completion,
|
||||||
|
CypherSearchRetriever().get_context,
|
||||||
|
],
|
||||||
|
SearchType.NATURAL_LANGUAGE: [
|
||||||
|
NaturalLanguageRetriever().get_completion,
|
||||||
|
NaturalLanguageRetriever().get_context,
|
||||||
|
],
|
||||||
|
SearchType.FEEDBACK: [UserQAFeedback(last_k=last_k).add_feedback],
|
||||||
|
SearchType.TEMPORAL: [
|
||||||
|
TemporalRetriever(top_k=top_k).get_completion,
|
||||||
|
TemporalRetriever(top_k=top_k).get_context,
|
||||||
|
],
|
||||||
|
SearchType.CODING_RULES: [
|
||||||
|
CodingRulesRetriever(rules_nodeset_name=node_name).get_existing_rules,
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
# If the query type is FEELING_LUCKY, select the search type intelligently
|
||||||
|
if query_type is SearchType.FEELING_LUCKY:
|
||||||
|
query_type = await select_search_type(query_text)
|
||||||
|
|
||||||
|
search_type_tools = search_tasks.get(query_type)
|
||||||
|
|
||||||
|
if not search_type_tools:
|
||||||
|
raise UnsupportedSearchTypeError(str(query_type))
|
||||||
|
|
||||||
|
return search_type_tools
|
||||||
47
cognee/modules/search/methods/no_access_control_search.py
Normal file
47
cognee/modules/search/methods/no_access_control_search.py
Normal file
|
|
@ -0,0 +1,47 @@
|
||||||
|
from typing import Any, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
|
from cognee.modules.data.models.Dataset import Dataset
|
||||||
|
from cognee.modules.engine.models.node_set import NodeSet
|
||||||
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||||
|
from cognee.modules.search.types import SearchType
|
||||||
|
|
||||||
|
from .get_search_type_tools import get_search_type_tools
|
||||||
|
|
||||||
|
|
||||||
|
async def no_access_control_search(
|
||||||
|
query_type: SearchType,
|
||||||
|
query_text: str,
|
||||||
|
system_prompt_path: str = "answer_simple_question.txt",
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
top_k: int = 10,
|
||||||
|
node_type: Optional[Type] = NodeSet,
|
||||||
|
node_name: Optional[List[str]] = None,
|
||||||
|
save_interaction: bool = False,
|
||||||
|
last_k: Optional[int] = None,
|
||||||
|
only_context: bool = False,
|
||||||
|
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
|
||||||
|
search_tools = await get_search_type_tools(
|
||||||
|
query_type=query_type,
|
||||||
|
query_text=query_text,
|
||||||
|
system_prompt_path=system_prompt_path,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
top_k=top_k,
|
||||||
|
node_type=node_type,
|
||||||
|
node_name=node_name,
|
||||||
|
save_interaction=save_interaction,
|
||||||
|
last_k=last_k,
|
||||||
|
)
|
||||||
|
if len(search_tools) == 2:
|
||||||
|
[get_completion, get_context] = search_tools
|
||||||
|
|
||||||
|
if only_context:
|
||||||
|
return await get_context(query_text)
|
||||||
|
|
||||||
|
context = await get_context(query_text)
|
||||||
|
result = await get_completion(query_text, context)
|
||||||
|
else:
|
||||||
|
unknown_tool = search_tools[0]
|
||||||
|
result = await unknown_tool(query_text)
|
||||||
|
context = ""
|
||||||
|
|
||||||
|
return result, context, []
|
||||||
|
|
@ -3,37 +3,27 @@ import json
|
||||||
import asyncio
|
import asyncio
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from typing import Callable, List, Optional, Type, Union
|
from typing import Any, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
|
from cognee.shared.utils import send_telemetry
|
||||||
|
from cognee.context_global_variables import set_database_global_context_variables
|
||||||
|
|
||||||
from cognee.modules.engine.models.node_set import NodeSet
|
from cognee.modules.engine.models.node_set import NodeSet
|
||||||
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||||
from cognee.modules.search.exceptions import UnsupportedSearchTypeError
|
from cognee.modules.search.types import (
|
||||||
from cognee.context_global_variables import set_database_global_context_variables
|
SearchResult,
|
||||||
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
|
CombinedSearchResult,
|
||||||
from cognee.modules.retrieval.insights_retriever import InsightsRetriever
|
SearchResultDataset,
|
||||||
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
|
SearchType,
|
||||||
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
|
|
||||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
|
||||||
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
|
|
||||||
from cognee.modules.retrieval.coding_rules_retriever import CodingRulesRetriever
|
|
||||||
from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
|
||||||
GraphSummaryCompletionRetriever,
|
|
||||||
)
|
)
|
||||||
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
from cognee.modules.search.operations import log_query, log_result
|
||||||
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
|
||||||
GraphCompletionContextExtensionRetriever,
|
|
||||||
)
|
|
||||||
from cognee.modules.retrieval.code_retriever import CodeRetriever
|
|
||||||
from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
|
|
||||||
from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
|
|
||||||
from cognee.modules.search.types import SearchType
|
|
||||||
from cognee.modules.storage.utils import JSONEncoder
|
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
from cognee.modules.data.models import Dataset
|
from cognee.modules.data.models import Dataset
|
||||||
from cognee.shared.utils import send_telemetry
|
|
||||||
from cognee.modules.users.permissions.methods import get_specific_user_permission_datasets
|
from cognee.modules.users.permissions.methods import get_specific_user_permission_datasets
|
||||||
from cognee.modules.search.operations import log_query, log_result, select_search_type
|
|
||||||
|
from .get_search_type_tools import get_search_type_tools
|
||||||
|
from .no_access_control_search import no_access_control_search
|
||||||
|
from ..utils.prepare_search_result import prepare_search_result
|
||||||
|
|
||||||
|
|
||||||
async def search(
|
async def search(
|
||||||
|
|
@ -46,10 +36,11 @@ async def search(
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
node_type: Optional[Type] = NodeSet,
|
node_type: Optional[Type] = NodeSet,
|
||||||
node_name: Optional[List[str]] = None,
|
node_name: Optional[List[str]] = None,
|
||||||
save_interaction: Optional[bool] = False,
|
save_interaction: bool = False,
|
||||||
last_k: Optional[int] = None,
|
last_k: Optional[int] = None,
|
||||||
only_context: bool = False,
|
only_context: bool = False,
|
||||||
):
|
use_combined_context: bool = False,
|
||||||
|
) -> Union[CombinedSearchResult, List[SearchResult]]:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -65,9 +56,12 @@ async def search(
|
||||||
Notes:
|
Notes:
|
||||||
Searching by dataset is only available in ENABLE_BACKEND_ACCESS_CONTROL mode
|
Searching by dataset is only available in ENABLE_BACKEND_ACCESS_CONTROL mode
|
||||||
"""
|
"""
|
||||||
|
query = await log_query(query_text, query_type.value, user.id)
|
||||||
|
send_telemetry("cognee.search EXECUTION STARTED", user.id)
|
||||||
|
|
||||||
# Use search function filtered by permissions if access control is enabled
|
# Use search function filtered by permissions if access control is enabled
|
||||||
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
|
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
|
||||||
return await authorized_search(
|
search_results = await authorized_search(
|
||||||
query_type=query_type,
|
query_type=query_type,
|
||||||
query_text=query_text,
|
query_text=query_text,
|
||||||
user=user,
|
user=user,
|
||||||
|
|
@ -80,119 +74,68 @@ async def search(
|
||||||
save_interaction=save_interaction,
|
save_interaction=save_interaction,
|
||||||
last_k=last_k,
|
last_k=last_k,
|
||||||
only_context=only_context,
|
only_context=only_context,
|
||||||
|
use_combined_context=use_combined_context,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
search_results = [
|
||||||
|
await no_access_control_search(
|
||||||
|
query_type=query_type,
|
||||||
|
query_text=query_text,
|
||||||
|
system_prompt_path=system_prompt_path,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
top_k=top_k,
|
||||||
|
node_type=node_type,
|
||||||
|
node_name=node_name,
|
||||||
|
save_interaction=save_interaction,
|
||||||
|
last_k=last_k,
|
||||||
|
only_context=only_context,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
query = await log_query(query_text, query_type.value, user.id)
|
send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
|
||||||
|
|
||||||
search_results = await specific_search(
|
|
||||||
query_type=query_type,
|
|
||||||
query_text=query_text,
|
|
||||||
user=user,
|
|
||||||
system_prompt_path=system_prompt_path,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
top_k=top_k,
|
|
||||||
node_type=node_type,
|
|
||||||
node_name=node_name,
|
|
||||||
save_interaction=save_interaction,
|
|
||||||
last_k=last_k,
|
|
||||||
only_context=only_context,
|
|
||||||
)
|
|
||||||
|
|
||||||
await log_result(
|
await log_result(
|
||||||
query.id,
|
query.id,
|
||||||
json.dumps(
|
json.dumps(
|
||||||
search_results if len(search_results) > 1 else search_results[0], cls=JSONEncoder
|
jsonable_encoder(
|
||||||
|
await prepare_search_result(search_results)
|
||||||
|
if use_combined_context
|
||||||
|
else [
|
||||||
|
await prepare_search_result(search_result) for search_result in search_results
|
||||||
|
]
|
||||||
|
)
|
||||||
),
|
),
|
||||||
user.id,
|
user.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return search_results
|
if use_combined_context:
|
||||||
|
prepared_search_results = await prepare_search_result(search_results)
|
||||||
|
result = prepared_search_results["result"]
|
||||||
|
graphs = prepared_search_results["graphs"]
|
||||||
|
context = prepared_search_results["context"]
|
||||||
|
datasets = prepared_search_results["datasets"]
|
||||||
|
|
||||||
|
return CombinedSearchResult(
|
||||||
async def specific_search(
|
result=result,
|
||||||
query_type: SearchType,
|
graphs=graphs,
|
||||||
query_text: str,
|
context=context,
|
||||||
user: User,
|
datasets=[
|
||||||
system_prompt_path: str = "answer_simple_question.txt",
|
SearchResultDataset(
|
||||||
system_prompt: Optional[str] = None,
|
id=dataset.id,
|
||||||
top_k: int = 10,
|
name=dataset.name,
|
||||||
node_type: Optional[Type] = NodeSet,
|
)
|
||||||
node_name: Optional[List[str]] = None,
|
for dataset in datasets
|
||||||
save_interaction: Optional[bool] = False,
|
],
|
||||||
last_k: Optional[int] = None,
|
)
|
||||||
only_context: bool = None,
|
else:
|
||||||
) -> list:
|
return [
|
||||||
search_tasks: dict[SearchType, Callable] = {
|
SearchResult(
|
||||||
SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion,
|
search_result=result,
|
||||||
SearchType.INSIGHTS: InsightsRetriever(top_k=top_k).get_completion,
|
dataset_id=datasets[min(index, len(datasets) - 1)].id if datasets else None,
|
||||||
SearchType.CHUNKS: ChunksRetriever(top_k=top_k).get_completion,
|
dataset_name=datasets[min(index, len(datasets) - 1)].name if datasets else None,
|
||||||
SearchType.RAG_COMPLETION: CompletionRetriever(
|
)
|
||||||
system_prompt_path=system_prompt_path,
|
for index, (result, _, datasets) in enumerate(search_results)
|
||||||
top_k=top_k,
|
]
|
||||||
system_prompt=system_prompt,
|
|
||||||
only_context=only_context,
|
|
||||||
).get_completion,
|
|
||||||
SearchType.GRAPH_COMPLETION: GraphCompletionRetriever(
|
|
||||||
system_prompt_path=system_prompt_path,
|
|
||||||
top_k=top_k,
|
|
||||||
node_type=node_type,
|
|
||||||
node_name=node_name,
|
|
||||||
save_interaction=save_interaction,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
only_context=only_context,
|
|
||||||
).get_completion,
|
|
||||||
SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever(
|
|
||||||
system_prompt_path=system_prompt_path,
|
|
||||||
top_k=top_k,
|
|
||||||
node_type=node_type,
|
|
||||||
node_name=node_name,
|
|
||||||
save_interaction=save_interaction,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
only_context=only_context,
|
|
||||||
).get_completion,
|
|
||||||
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever(
|
|
||||||
system_prompt_path=system_prompt_path,
|
|
||||||
top_k=top_k,
|
|
||||||
node_type=node_type,
|
|
||||||
node_name=node_name,
|
|
||||||
save_interaction=save_interaction,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
only_context=only_context,
|
|
||||||
).get_completion,
|
|
||||||
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
|
|
||||||
system_prompt_path=system_prompt_path,
|
|
||||||
top_k=top_k,
|
|
||||||
node_type=node_type,
|
|
||||||
node_name=node_name,
|
|
||||||
save_interaction=save_interaction,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
).get_completion,
|
|
||||||
SearchType.CODE: CodeRetriever(top_k=top_k).get_completion,
|
|
||||||
SearchType.CYPHER: CypherSearchRetriever().get_completion,
|
|
||||||
SearchType.NATURAL_LANGUAGE: NaturalLanguageRetriever().get_completion,
|
|
||||||
SearchType.FEEDBACK: UserQAFeedback(last_k=last_k).add_feedback,
|
|
||||||
SearchType.TEMPORAL: TemporalRetriever(top_k=top_k).get_completion,
|
|
||||||
SearchType.CODING_RULES: CodingRulesRetriever(
|
|
||||||
rules_nodeset_name=node_name
|
|
||||||
).get_existing_rules,
|
|
||||||
}
|
|
||||||
|
|
||||||
# If the query type is FEELING_LUCKY, select the search type intelligently
|
|
||||||
if query_type is SearchType.FEELING_LUCKY:
|
|
||||||
query_type = await select_search_type(query_text)
|
|
||||||
|
|
||||||
search_task = search_tasks.get(query_type)
|
|
||||||
|
|
||||||
if search_task is None:
|
|
||||||
raise UnsupportedSearchTypeError(str(query_type))
|
|
||||||
|
|
||||||
send_telemetry("cognee.search EXECUTION STARTED", user.id)
|
|
||||||
|
|
||||||
results = await search_task(query_text)
|
|
||||||
|
|
||||||
send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
async def authorized_search(
|
async def authorized_search(
|
||||||
|
|
@ -205,26 +148,85 @@ async def authorized_search(
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
node_type: Optional[Type] = NodeSet,
|
node_type: Optional[Type] = NodeSet,
|
||||||
node_name: Optional[List[str]] = None,
|
node_name: Optional[List[str]] = None,
|
||||||
save_interaction: Optional[bool] = False,
|
save_interaction: bool = False,
|
||||||
last_k: Optional[int] = None,
|
last_k: Optional[int] = None,
|
||||||
only_context: bool = None,
|
only_context: bool = False,
|
||||||
) -> list:
|
use_combined_context: bool = False,
|
||||||
|
) -> Union[
|
||||||
|
Tuple[Any, Union[List[Edge], str], List[Dataset]],
|
||||||
|
List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],
|
||||||
|
]:
|
||||||
"""
|
"""
|
||||||
Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset.
|
Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset.
|
||||||
Not to be used outside of active access control mode.
|
Not to be used outside of active access control mode.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
query = await log_query(query_text, query_type.value, user.id)
|
|
||||||
|
|
||||||
# Find datasets user has read access for (if datasets are provided only return them. Provided user has read access)
|
# Find datasets user has read access for (if datasets are provided only return them. Provided user has read access)
|
||||||
search_datasets = await get_specific_user_permission_datasets(user.id, "read", dataset_ids)
|
search_datasets = await get_specific_user_permission_datasets(user.id, "read", dataset_ids)
|
||||||
|
|
||||||
|
if use_combined_context:
|
||||||
|
search_responses = await search_in_datasets_context(
|
||||||
|
search_datasets=search_datasets,
|
||||||
|
query_type=query_type,
|
||||||
|
query_text=query_text,
|
||||||
|
system_prompt_path=system_prompt_path,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
top_k=top_k,
|
||||||
|
node_type=node_type,
|
||||||
|
node_name=node_name,
|
||||||
|
save_interaction=save_interaction,
|
||||||
|
last_k=last_k,
|
||||||
|
only_context=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Context of every searched dataset, keyed by the dataset id as a string.
context = {}
# All datasets that contributed a context, accumulated over every response.
datasets: List[Dataset] = []

# The loop variable must not be named `datasets`: rebinding the accumulator
# made `extend` append each response's list to itself, and only the last
# response's datasets survived the loop.
for _, search_context, response_datasets in search_responses:
    for dataset in response_datasets:
        context[str(dataset.id)] = search_context

    datasets.extend(response_datasets)
|
||||||
|
|
||||||
|
specific_search_tools = await get_search_type_tools(
|
||||||
|
query_type=query_type,
|
||||||
|
query_text=query_text,
|
||||||
|
system_prompt_path=system_prompt_path,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
top_k=top_k,
|
||||||
|
node_type=node_type,
|
||||||
|
node_name=node_name,
|
||||||
|
save_interaction=save_interaction,
|
||||||
|
last_k=last_k,
|
||||||
|
)
|
||||||
|
search_tools = specific_search_tools
|
||||||
|
if len(search_tools) == 2:
|
||||||
|
[get_completion, _] = search_tools
|
||||||
|
else:
|
||||||
|
get_completion = search_tools[0]
|
||||||
|
|
||||||
|
def prepare_combined_context(
    context,
) -> "Union[List[Edge], str]":
    """Merge the per-dataset contexts into one context for a single completion call.

    Args:
        context: Mapping whose values are the context of one dataset each,
            either a text context (``str``) or a list of items.

    Returns:
        The merged items joined with newlines when the first merged item is
        a string, otherwise the flat list of merged items.
    """
    combined_context = []

    for dataset_context in context.values():
        if isinstance(dataset_context, str):
            # `list += str` would splice in the individual characters, so a
            # text context has to be added as one item.
            combined_context.append(dataset_context)
        else:
            combined_context.extend(dataset_context)

    if combined_context and isinstance(combined_context[0], str):
        return "\n".join(combined_context)

    return combined_context
|
||||||
|
|
||||||
|
combined_context = prepare_combined_context(context)
|
||||||
|
completion = await get_completion(query_text, combined_context)
|
||||||
|
|
||||||
|
return completion, combined_context, datasets
|
||||||
|
|
||||||
# Searches all provided datasets and handles setting up of appropriate database context based on permissions
|
# Searches all provided datasets and handles setting up of appropriate database context based on permissions
|
||||||
search_results = await specific_search_by_context(
|
search_results = await search_in_datasets_context(
|
||||||
search_datasets=search_datasets,
|
search_datasets=search_datasets,
|
||||||
query_type=query_type,
|
query_type=query_type,
|
||||||
query_text=query_text,
|
query_text=query_text,
|
||||||
user=user,
|
|
||||||
system_prompt_path=system_prompt_path,
|
system_prompt_path=system_prompt_path,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
|
|
@ -235,51 +237,48 @@ async def authorized_search(
|
||||||
only_context=only_context,
|
only_context=only_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
await log_result(query.id, json.dumps(jsonable_encoder(search_results)), user.id)
|
|
||||||
|
|
||||||
return search_results
|
return search_results
|
||||||
|
|
||||||
|
|
||||||
async def specific_search_by_context(
|
async def search_in_datasets_context(
|
||||||
search_datasets: list[Dataset],
|
search_datasets: list[Dataset],
|
||||||
query_type: SearchType,
|
query_type: SearchType,
|
||||||
query_text: str,
|
query_text: str,
|
||||||
user: User,
|
|
||||||
system_prompt_path: str = "answer_simple_question.txt",
|
system_prompt_path: str = "answer_simple_question.txt",
|
||||||
system_prompt: Optional[str] = None,
|
system_prompt: Optional[str] = None,
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
node_type: Optional[Type] = NodeSet,
|
node_type: Optional[Type] = NodeSet,
|
||||||
node_name: Optional[List[str]] = None,
|
node_name: Optional[List[str]] = None,
|
||||||
save_interaction: Optional[bool] = False,
|
save_interaction: bool = False,
|
||||||
last_k: Optional[int] = None,
|
last_k: Optional[int] = None,
|
||||||
only_context: bool = None,
|
only_context: bool = False,
|
||||||
):
|
context: Optional[Any] = None,
|
||||||
|
) -> List[Tuple[Any, Union[str, List[Edge]], List[Dataset]]]:
|
||||||
"""
|
"""
|
||||||
Searches all provided datasets and handles setting up of appropriate database context based on permissions.
|
Searches all provided datasets and handles setting up of appropriate database context based on permissions.
|
||||||
Not to be used outside of active access control mode.
|
Not to be used outside of active access control mode.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def _search_by_context(
|
async def _search_in_dataset_context(
|
||||||
dataset: Dataset,
|
dataset: Dataset,
|
||||||
query_type: SearchType,
|
query_type: SearchType,
|
||||||
query_text: str,
|
query_text: str,
|
||||||
user: User,
|
|
||||||
system_prompt_path: str = "answer_simple_question.txt",
|
system_prompt_path: str = "answer_simple_question.txt",
|
||||||
system_prompt: Optional[str] = None,
|
system_prompt: Optional[str] = None,
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
node_type: Optional[Type] = NodeSet,
|
node_type: Optional[Type] = NodeSet,
|
||||||
node_name: Optional[List[str]] = None,
|
node_name: Optional[List[str]] = None,
|
||||||
save_interaction: Optional[bool] = False,
|
save_interaction: bool = False,
|
||||||
last_k: Optional[int] = None,
|
last_k: Optional[int] = None,
|
||||||
only_context: bool = None,
|
only_context: bool = False,
|
||||||
):
|
context: Optional[Any] = None,
|
||||||
|
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
|
||||||
# Set database configuration in async context for each dataset user has access for
|
# Set database configuration in async context for each dataset user has access for
|
||||||
await set_database_global_context_variables(dataset.id, dataset.owner_id)
|
await set_database_global_context_variables(dataset.id, dataset.owner_id)
|
||||||
|
|
||||||
result = await specific_search(
|
specific_search_tools = await get_search_type_tools(
|
||||||
query_type=query_type,
|
query_type=query_type,
|
||||||
query_text=query_text,
|
query_text=query_text,
|
||||||
user=user,
|
|
||||||
system_prompt_path=system_prompt_path,
|
system_prompt_path=system_prompt_path,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
|
|
@ -287,57 +286,31 @@ async def specific_search_by_context(
|
||||||
node_name=node_name,
|
node_name=node_name,
|
||||||
save_interaction=save_interaction,
|
save_interaction=save_interaction,
|
||||||
last_k=last_k,
|
last_k=last_k,
|
||||||
only_context=only_context,
|
|
||||||
)
|
)
|
||||||
|
search_tools = specific_search_tools
|
||||||
|
if len(search_tools) == 2:
|
||||||
|
[get_completion, get_context] = search_tools
|
||||||
|
|
||||||
if isinstance(result, tuple):
|
if only_context:
|
||||||
search_results = result[0]
|
return None, await get_context(query_text), [dataset]
|
||||||
triplets = result[1]
|
|
||||||
|
search_context = context or await get_context(query_text)
|
||||||
|
search_result = await get_completion(query_text, search_context)
|
||||||
|
|
||||||
|
return search_result, search_context, [dataset]
|
||||||
else:
|
else:
|
||||||
search_results = result
|
unknown_tool = search_tools[0]
|
||||||
triplets = []
|
|
||||||
|
|
||||||
return {
|
return await unknown_tool(query_text), "", [dataset]
|
||||||
"search_result": search_results,
|
|
||||||
"graph": [
|
|
||||||
{
|
|
||||||
"source": {
|
|
||||||
"id": triplet.node1.id,
|
|
||||||
"attributes": {
|
|
||||||
"name": triplet.node1.attributes["name"],
|
|
||||||
"type": triplet.node1.attributes["type"],
|
|
||||||
"description": triplet.node1.attributes["description"],
|
|
||||||
"vector_distance": triplet.node1.attributes["vector_distance"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"destination": {
|
|
||||||
"id": triplet.node2.id,
|
|
||||||
"attributes": {
|
|
||||||
"name": triplet.node2.attributes["name"],
|
|
||||||
"type": triplet.node2.attributes["type"],
|
|
||||||
"description": triplet.node2.attributes["description"],
|
|
||||||
"vector_distance": triplet.node2.attributes["vector_distance"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"attributes": {
|
|
||||||
"relationship_name": triplet.attributes["relationship_name"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for triplet in triplets
|
|
||||||
],
|
|
||||||
"dataset_id": dataset.id,
|
|
||||||
"dataset_name": dataset.name,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Search every dataset async based on query and appropriate database configuration
|
# Search every dataset async based on query and appropriate database configuration
|
||||||
tasks = []
|
tasks = []
|
||||||
for dataset in search_datasets:
|
for dataset in search_datasets:
|
||||||
tasks.append(
|
tasks.append(
|
||||||
_search_by_context(
|
_search_in_dataset_context(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
query_type=query_type,
|
query_type=query_type,
|
||||||
query_text=query_text,
|
query_text=query_text,
|
||||||
user=user,
|
|
||||||
system_prompt_path=system_prompt_path,
|
system_prompt_path=system_prompt_path,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
|
|
@ -346,6 +319,7 @@ async def specific_search_by_context(
|
||||||
save_interaction=save_interaction,
|
save_interaction=save_interaction,
|
||||||
last_k=last_k,
|
last_k=last_k,
|
||||||
only_context=only_context,
|
only_context=only_context,
|
||||||
|
context=context,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
21
cognee/modules/search/types/SearchResult.py
Normal file
21
cognee/modules/search/types/SearchResult.py
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
from uuid import UUID
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class SearchResultDataset(BaseModel):
    """Reference to a dataset, carrying only its id and name."""

    # Unique identifier of the dataset.
    id: UUID
    # Name of the dataset.
    name: str
|
||||||
|
|
||||||
|
|
||||||
|
class CombinedSearchResult(BaseModel):
    """Single search result together with its context, graphs and datasets."""

    # The search result itself; any value, including None, is accepted.
    result: Optional[Any]
    # Context entries by string key.
    # NOTE(review): presumably the context the result was produced from - confirm with the producer.
    context: Dict[str, Any]
    # Graph entries by string key; an empty mapping when not supplied.
    graphs: Optional[Dict[str, Any]] = {}
    # Datasets associated with the result; None when not supplied.
    datasets: Optional[List[SearchResultDataset]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class SearchResult(BaseModel):
    """Search result paired with the id and name of a dataset."""

    # The search result payload; its type is not constrained.
    search_result: Any
    # Id of the associated dataset, or None.
    dataset_id: Optional[UUID]
    # Name of the associated dataset, or None.
    dataset_name: Optional[str]
|
||||||
|
|
@ -1 +1,2 @@
|
||||||
from .SearchType import SearchType
|
from .SearchType import SearchType
|
||||||
|
from .SearchResult import SearchResult, SearchResultDataset, CombinedSearchResult
|
||||||
|
|
|
||||||
2
cognee/modules/search/utils/__init__.py
Normal file
2
cognee/modules/search/utils/__init__.py
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
from .prepare_search_result import prepare_search_result
|
||||||
|
from .transform_context_to_graph import transform_context_to_graph
|
||||||
41
cognee/modules/search/utils/prepare_search_result.py
Normal file
41
cognee/modules/search/utils/prepare_search_result.py
Normal file
|
|
@ -0,0 +1,41 @@
|
||||||
|
from typing import List, cast
|
||||||
|
|
||||||
|
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||||
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||||
|
from cognee.modules.search.utils.transform_context_to_graph import transform_context_to_graph
|
||||||
|
|
||||||
|
|
||||||
|
async def prepare_search_result(search_result):
    """Convert a raw ``(result, context, datasets)`` search tuple into a response dictionary.

    Args:
        search_result: Three-item sequence ``(result, context, datasets)``.
            ``context`` may be a string, a list of strings or a list of
            ``Edge`` objects; ``result`` may itself be a list of ``Edge``
            objects.

    Returns:
        A dictionary with the keys ``result``, ``graphs``, ``context`` and
        ``datasets``. Values derived from the context are stored under the
        ``"*"`` key; ``graphs`` stays ``None`` unless the context is a list
        of edges.
    """
    result, context, datasets = search_result

    graphs = None
    context_texts = {}

    if isinstance(context, str):
        context_texts = {"*": context}
    elif isinstance(context, list) and context:
        first_item = context[0]
        if isinstance(first_item, str):
            context_texts = {"*": "\n".join(context)}
        elif isinstance(first_item, Edge):
            graphs = {"*": transform_context_to_graph(context)}
            context_texts = {"*": await resolve_edges_to_text(context)}

    # Only a result that is itself a list of edges is replaced by its graph
    # form. The context graph lives in `graphs` alone, so it can never be
    # returned in place of the actual result.
    result_graph = None
    if isinstance(result, list) and result and isinstance(result[0], Edge):
        result_graph = transform_context_to_graph(result)

    return {
        "result": result_graph or result,
        "graphs": graphs,
        "context": context_texts,
        "datasets": datasets,
    }
|
||||||
38
cognee/modules/search/utils/transform_context_to_graph.py
Normal file
38
cognee/modules/search/utils/transform_context_to_graph.py
Normal file
|
|
@ -0,0 +1,38 @@
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||||
|
|
||||||
|
|
||||||
|
def _graph_node(node):
    """Build the serialisable description of a single graph node."""
    attributes = node.attributes
    return {
        "id": node.id,
        # Fall back to the id when the node has no "name" attribute.
        "label": attributes["name"] if "name" in attributes else node.id,
        "type": attributes["type"],
        "attributes": attributes,
    }


def transform_context_to_graph(context: "List[Edge]"):
    """Convert a list of edges (triplets) into a nodes-and-edges dictionary.

    Nodes are de-duplicated by id and edges by the
    ``"{source id}_{relationship name}_{target id}"`` key; a later duplicate
    overwrites the earlier entry while keeping its position.

    Args:
        context: Edges exposing ``node1``, ``node2`` and an ``attributes``
            mapping containing ``"relationship_name"``. Each node exposes
            ``id`` and an ``attributes`` mapping containing ``"type"``.

    Returns:
        ``{"nodes": [...], "edges": [...]}`` where every node has ``id``,
        ``label``, ``type`` and ``attributes``, and every edge has
        ``source``, ``target`` and ``label``.

    Raises:
        KeyError: If a node has no ``"type"`` attribute or an edge has no
            ``"relationship_name"`` attribute.
    """
    nodes = {}
    edges = {}

    for triplet in context:
        source, target = triplet.node1, triplet.node2

        # Each node is described from its own attributes; the source node
        # previously received the target node's attributes by mistake.
        nodes[source.id] = _graph_node(source)
        nodes[target.id] = _graph_node(target)

        relationship_name = triplet.attributes["relationship_name"]
        edges[f"{source.id}_{relationship_name}_{target.id}"] = {
            "source": source.id,
            "target": target.id,
            "label": relationship_name,
        }

    return {
        "nodes": list(nodes.values()),
        "edges": list(edges.values()),
    }
|
||||||
|
|
@ -31,7 +31,7 @@ class RuleSet(DataPoint):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def get_existing_rules(rules_nodeset_name: str, return_list: bool = False) -> str:
|
async def get_existing_rules(rules_nodeset_name: str) -> List[str]:
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
nodes_data, _ = await graph_engine.get_nodeset_subgraph(
|
nodes_data, _ = await graph_engine.get_nodeset_subgraph(
|
||||||
node_type=NodeSet, node_name=[rules_nodeset_name]
|
node_type=NodeSet, node_name=[rules_nodeset_name]
|
||||||
|
|
@ -46,9 +46,6 @@ async def get_existing_rules(rules_nodeset_name: str, return_list: bool = False)
|
||||||
and "text" in item[1]
|
and "text" in item[1]
|
||||||
]
|
]
|
||||||
|
|
||||||
if not return_list:
|
|
||||||
existing_rules = "\n".join(f"- {rule}" for rule in existing_rules)
|
|
||||||
|
|
||||||
return existing_rules
|
return existing_rules
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -103,6 +100,7 @@ async def add_rule_associations(
|
||||||
|
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
existing_rules = await get_existing_rules(rules_nodeset_name=rules_nodeset_name)
|
existing_rules = await get_existing_rules(rules_nodeset_name=rules_nodeset_name)
|
||||||
|
existing_rules = "\n".join(f"- {rule}" for rule in existing_rules)
|
||||||
|
|
||||||
user_context = {"chat": data, "rules": existing_rules}
|
user_context = {"chat": data, "rules": existing_rules}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -94,21 +94,21 @@ async def main():
|
||||||
|
|
||||||
await cognee.cognify([dataset_name])
|
await cognee.cognify([dataset_name])
|
||||||
|
|
||||||
context_nonempty, _ = await GraphCompletionRetriever(
|
context_nonempty = await GraphCompletionRetriever(
|
||||||
node_type=NodeSet,
|
node_type=NodeSet,
|
||||||
node_name=["first"],
|
node_name=["first"],
|
||||||
).get_context("What is in the context?")
|
).get_context("What is in the context?")
|
||||||
|
|
||||||
context_empty, _ = await GraphCompletionRetriever(
|
context_empty = await GraphCompletionRetriever(
|
||||||
node_type=NodeSet,
|
node_type=NodeSet,
|
||||||
node_name=["nonexistent"],
|
node_name=["nonexistent"],
|
||||||
).get_context("What is in the context?")
|
).get_context("What is in the context?")
|
||||||
|
|
||||||
assert isinstance(context_nonempty, str) and context_nonempty != "", (
|
assert isinstance(context_nonempty, list) and context_nonempty != [], (
|
||||||
f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
|
f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
|
||||||
)
|
)
|
||||||
|
|
||||||
assert context_empty == "", (
|
assert context_empty == [], (
|
||||||
f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
|
f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -98,21 +98,21 @@ async def main():
|
||||||
|
|
||||||
await cognee.cognify([dataset_name])
|
await cognee.cognify([dataset_name])
|
||||||
|
|
||||||
context_nonempty, _ = await GraphCompletionRetriever(
|
context_nonempty = await GraphCompletionRetriever(
|
||||||
node_type=NodeSet,
|
node_type=NodeSet,
|
||||||
node_name=["first"],
|
node_name=["first"],
|
||||||
).get_context("What is in the context?")
|
).get_context("What is in the context?")
|
||||||
|
|
||||||
context_empty, _ = await GraphCompletionRetriever(
|
context_empty = await GraphCompletionRetriever(
|
||||||
node_type=NodeSet,
|
node_type=NodeSet,
|
||||||
node_name=["nonexistent"],
|
node_name=["nonexistent"],
|
||||||
).get_context("What is in the context?")
|
).get_context("What is in the context?")
|
||||||
|
|
||||||
assert isinstance(context_nonempty, str) and context_nonempty != "", (
|
assert isinstance(context_nonempty, list) and context_nonempty != [], (
|
||||||
f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
|
f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
|
||||||
)
|
)
|
||||||
|
|
||||||
assert context_empty == "", (
|
assert context_empty == [], (
|
||||||
f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
|
f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -79,7 +79,7 @@ async def main():
|
||||||
print("\n\nExtracted sentences are:\n")
|
print("\n\nExtracted sentences are:\n")
|
||||||
for result in search_results:
|
for result in search_results:
|
||||||
print(f"{result}\n")
|
print(f"{result}\n")
|
||||||
assert search_results[0]["dataset_name"] == "NLP", (
|
assert search_results[0].dataset_name == "NLP", (
|
||||||
f"Dict must contain dataset name 'NLP': {search_results[0]}"
|
f"Dict must contain dataset name 'NLP': {search_results[0]}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -93,7 +93,7 @@ async def main():
|
||||||
print("\n\nExtracted sentences are:\n")
|
print("\n\nExtracted sentences are:\n")
|
||||||
for result in search_results:
|
for result in search_results:
|
||||||
print(f"{result}\n")
|
print(f"{result}\n")
|
||||||
assert search_results[0]["dataset_name"] == "QUANTUM", (
|
assert search_results[0].dataset_name == "QUANTUM", (
|
||||||
f"Dict must contain dataset name 'QUANTUM': {search_results[0]}"
|
f"Dict must contain dataset name 'QUANTUM': {search_results[0]}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -170,7 +170,7 @@ async def main():
|
||||||
for result in search_results:
|
for result in search_results:
|
||||||
print(f"{result}\n")
|
print(f"{result}\n")
|
||||||
|
|
||||||
assert search_results[0]["dataset_name"] == "QUANTUM", (
|
assert search_results[0].dataset_name == "QUANTUM", (
|
||||||
f"Dict must contain dataset name 'QUANTUM': {search_results[0]}"
|
f"Dict must contain dataset name 'QUANTUM': {search_results[0]}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import json
|
|
||||||
import pathlib
|
import pathlib
|
||||||
import os
|
import os
|
||||||
|
from typing import List
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from cognee.infrastructure.databases.relational import (
|
from cognee.infrastructure.databases.relational import (
|
||||||
get_migration_relational_engine,
|
get_migration_relational_engine,
|
||||||
|
|
@ -10,7 +10,7 @@ from cognee.infrastructure.databases.vector.pgvector import (
|
||||||
create_db_and_tables as create_pgvector_db_and_tables,
|
create_db_and_tables as create_pgvector_db_and_tables,
|
||||||
)
|
)
|
||||||
from cognee.tasks.ingestion import migrate_relational_database
|
from cognee.tasks.ingestion import migrate_relational_database
|
||||||
from cognee.modules.search.types import SearchType
|
from cognee.modules.search.types import SearchResult, SearchType
|
||||||
import cognee
|
import cognee
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -45,13 +45,15 @@ async def relational_db_migration():
|
||||||
await migrate_relational_database(graph_engine, schema=schema)
|
await migrate_relational_database(graph_engine, schema=schema)
|
||||||
|
|
||||||
# 1. Search the graph
|
# 1. Search the graph
|
||||||
search_results = await cognee.search(
|
search_results: List[SearchResult] = await cognee.search(
|
||||||
query_type=SearchType.GRAPH_COMPLETION, query_text="Tell me about the artist AC/DC"
|
query_type=SearchType.GRAPH_COMPLETION, query_text="Tell me about the artist AC/DC"
|
||||||
)
|
) # type: ignore
|
||||||
print("Search results:", search_results)
|
print("Search results:", search_results)
|
||||||
|
|
||||||
# 2. Assert that the search results contain "AC/DC"
|
# 2. Assert that the search results contain "AC/DC"
|
||||||
assert any("AC/DC" in r for r in search_results), "AC/DC not found in search results!"
|
assert any("AC/DC" in r.search_result for r in search_results), (
|
||||||
|
"AC/DC not found in search results!"
|
||||||
|
)
|
||||||
|
|
||||||
migration_db_provider = migration_engine.engine.dialect.name
|
migration_db_provider = migration_engine.engine.dialect.name
|
||||||
if migration_db_provider == "postgresql":
|
if migration_db_provider == "postgresql":
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,7 @@
|
||||||
import os
|
|
||||||
import pathlib
|
|
||||||
|
|
||||||
from dns.e164 import query
|
|
||||||
|
|
||||||
import cognee
|
import cognee
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||||
|
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||||
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
||||||
GraphCompletionContextExtensionRetriever,
|
GraphCompletionContextExtensionRetriever,
|
||||||
|
|
@ -14,11 +10,8 @@ from cognee.modules.retrieval.graph_completion_cot_retriever import GraphComplet
|
||||||
from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
||||||
GraphSummaryCompletionRetriever,
|
GraphSummaryCompletionRetriever,
|
||||||
)
|
)
|
||||||
from cognee.modules.search.operations import get_history
|
|
||||||
from cognee.modules.users.methods import get_default_user
|
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.modules.search.types import SearchType
|
from cognee.modules.search.types import SearchType
|
||||||
from cognee.modules.engine.models import NodeSet
|
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
@ -46,16 +39,16 @@ async def main():
|
||||||
|
|
||||||
await cognee.cognify([dataset_name])
|
await cognee.cognify([dataset_name])
|
||||||
|
|
||||||
context_gk, _ = await GraphCompletionRetriever().get_context(
|
context_gk = await GraphCompletionRetriever().get_context(
|
||||||
query="Next to which country is Germany located?"
|
query="Next to which country is Germany located?"
|
||||||
)
|
)
|
||||||
context_gk_cot, _ = await GraphCompletionCotRetriever().get_context(
|
context_gk_cot = await GraphCompletionCotRetriever().get_context(
|
||||||
query="Next to which country is Germany located?"
|
query="Next to which country is Germany located?"
|
||||||
)
|
)
|
||||||
context_gk_ext, _ = await GraphCompletionContextExtensionRetriever().get_context(
|
context_gk_ext = await GraphCompletionContextExtensionRetriever().get_context(
|
||||||
query="Next to which country is Germany located?"
|
query="Next to which country is Germany located?"
|
||||||
)
|
)
|
||||||
context_gk_sum, _ = await GraphSummaryCompletionRetriever().get_context(
|
context_gk_sum = await GraphSummaryCompletionRetriever().get_context(
|
||||||
query="Next to which country is Germany located?"
|
query="Next to which country is Germany located?"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -65,9 +58,11 @@ async def main():
|
||||||
("GraphCompletionContextExtensionRetriever", context_gk_ext),
|
("GraphCompletionContextExtensionRetriever", context_gk_ext),
|
||||||
("GraphSummaryCompletionRetriever", context_gk_sum),
|
("GraphSummaryCompletionRetriever", context_gk_sum),
|
||||||
]:
|
]:
|
||||||
assert isinstance(context, str), f"{name}: Context should be a string"
|
assert isinstance(context, list), f"{name}: Context should be a list"
|
||||||
assert context.strip(), f"{name}: Context should not be empty"
|
assert len(context) > 0, f"{name}: Context should not be empty"
|
||||||
lower = context.lower()
|
|
||||||
|
context_text = await resolve_edges_to_text(context)
|
||||||
|
lower = context_text.lower()
|
||||||
assert "germany" in lower or "netherlands" in lower, (
|
assert "germany" in lower or "netherlands" in lower, (
|
||||||
f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}"
|
f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}"
|
||||||
)
|
)
|
||||||
|
|
@ -143,20 +138,19 @@ async def main():
|
||||||
last_k=1,
|
last_k=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
for name, completion in [
|
for name, search_results in [
|
||||||
("GRAPH_COMPLETION", completion_gk),
|
("GRAPH_COMPLETION", completion_gk),
|
||||||
("GRAPH_COMPLETION_COT", completion_cot),
|
("GRAPH_COMPLETION_COT", completion_cot),
|
||||||
("GRAPH_COMPLETION_CONTEXT_EXTENSION", completion_ext),
|
("GRAPH_COMPLETION_CONTEXT_EXTENSION", completion_ext),
|
||||||
("GRAPH_SUMMARY_COMPLETION", completion_sum),
|
("GRAPH_SUMMARY_COMPLETION", completion_sum),
|
||||||
]:
|
]:
|
||||||
assert isinstance(completion, list), f"{name}: should return a list"
|
for search_result in search_results:
|
||||||
assert len(completion) == 1, f"{name}: expected single-element list, got {len(completion)}"
|
completion = search_result.search_result
|
||||||
text = completion[0]
|
assert isinstance(completion, str), f"{name}: should return a string"
|
||||||
assert isinstance(text, str), f"{name}: element should be a string"
|
assert completion.strip(), f"{name}: string should not be empty"
|
||||||
assert text.strip(), f"{name}: string should not be empty"
|
assert "netherlands" in completion.lower(), (
|
||||||
assert "netherlands" in text.lower(), (
|
f"{name}: expected 'netherlands' in result, got: {completion!r}"
|
||||||
f"{name}: expected 'netherlands' in result, got: {text!r}"
|
)
|
||||||
)
|
|
||||||
|
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
graph = await graph_engine.get_graph_data()
|
graph = await graph_engine.get_graph_data()
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from typing import Optional, Union
|
||||||
import cognee
|
import cognee
|
||||||
from cognee.low_level import setup, DataPoint
|
from cognee.low_level import setup, DataPoint
|
||||||
from cognee.tasks.storage import add_data_points
|
from cognee.tasks.storage import add_data_points
|
||||||
|
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||||
from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
|
from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
|
||||||
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
||||||
GraphCompletionContextExtensionRetriever,
|
GraphCompletionContextExtensionRetriever,
|
||||||
|
|
@ -51,17 +52,15 @@ class TestGraphCompletionWithContextExtensionRetriever:
|
||||||
|
|
||||||
retriever = GraphCompletionContextExtensionRetriever()
|
retriever = GraphCompletionContextExtensionRetriever()
|
||||||
|
|
||||||
context, _ = await retriever.get_context("Who works at Canva?")
|
context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))
|
||||||
|
|
||||||
assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
|
assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
|
||||||
assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
|
assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
|
||||||
|
|
||||||
answer = await retriever.get_completion("Who works at Canva?")
|
answer = await retriever.get_completion("Who works at Canva?")
|
||||||
|
|
||||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
|
||||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
assert answer.strip(), "Answer must contain only non-empty strings"
|
||||||
"Answer must contain only non-empty strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_graph_completion_extension_context_complex(self):
|
async def test_graph_completion_extension_context_complex(self):
|
||||||
|
|
@ -129,7 +128,9 @@ class TestGraphCompletionWithContextExtensionRetriever:
|
||||||
|
|
||||||
retriever = GraphCompletionContextExtensionRetriever(top_k=20)
|
retriever = GraphCompletionContextExtensionRetriever(top_k=20)
|
||||||
|
|
||||||
context, _ = await retriever.get_context("Who works at Figma?")
|
context = await resolve_edges_to_text(
|
||||||
|
await retriever.get_context("Who works at Figma and drives Tesla?")
|
||||||
|
)
|
||||||
|
|
||||||
print(context)
|
print(context)
|
||||||
|
|
||||||
|
|
@ -139,10 +140,8 @@ class TestGraphCompletionWithContextExtensionRetriever:
|
||||||
|
|
||||||
answer = await retriever.get_completion("Who works at Figma?")
|
answer = await retriever.get_completion("Who works at Figma?")
|
||||||
|
|
||||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
|
||||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
assert answer.strip(), "Answer must contain only non-empty strings"
|
||||||
"Answer must contain only non-empty strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_graph_completion_extension_context_on_empty_graph(self):
|
async def test_get_graph_completion_extension_context_on_empty_graph(self):
|
||||||
|
|
@ -167,12 +166,10 @@ class TestGraphCompletionWithContextExtensionRetriever:
|
||||||
|
|
||||||
await setup()
|
await setup()
|
||||||
|
|
||||||
context, _ = await retriever.get_context("Who works at Figma?")
|
context = await retriever.get_context("Who works at Figma?")
|
||||||
assert context == "", "Context should be empty on an empty graph"
|
assert context == [], "Context should be empty on an empty graph"
|
||||||
|
|
||||||
answer = await retriever.get_completion("Who works at Figma?")
|
answer = await retriever.get_completion("Who works at Figma?")
|
||||||
|
|
||||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
|
||||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
assert answer.strip(), "Answer must contain only non-empty strings"
|
||||||
"Answer must contain only non-empty strings"
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from typing import Optional, Union
|
||||||
|
|
||||||
import cognee
|
import cognee
|
||||||
from cognee.low_level import setup, DataPoint
|
from cognee.low_level import setup, DataPoint
|
||||||
|
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||||
from cognee.tasks.storage import add_data_points
|
from cognee.tasks.storage import add_data_points
|
||||||
from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
|
from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
|
||||||
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
||||||
|
|
@ -47,17 +48,15 @@ class TestGraphCompletionCoTRetriever:
|
||||||
|
|
||||||
retriever = GraphCompletionCotRetriever()
|
retriever = GraphCompletionCotRetriever()
|
||||||
|
|
||||||
context, _ = await retriever.get_context("Who works at Canva?")
|
context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))
|
||||||
|
|
||||||
assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
|
assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
|
||||||
assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
|
assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
|
||||||
|
|
||||||
answer = await retriever.get_completion("Who works at Canva?")
|
answer = await retriever.get_completion("Who works at Canva?")
|
||||||
|
|
||||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
|
||||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
assert answer.strip(), "Answer must contain only non-empty strings"
|
||||||
"Answer must contain only non-empty strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_graph_completion_cot_context_complex(self):
|
async def test_graph_completion_cot_context_complex(self):
|
||||||
|
|
@ -124,7 +123,7 @@ class TestGraphCompletionCoTRetriever:
|
||||||
|
|
||||||
retriever = GraphCompletionCotRetriever(top_k=20)
|
retriever = GraphCompletionCotRetriever(top_k=20)
|
||||||
|
|
||||||
context, _ = await retriever.get_context("Who works at Figma?")
|
context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?"))
|
||||||
|
|
||||||
print(context)
|
print(context)
|
||||||
|
|
||||||
|
|
@ -134,10 +133,8 @@ class TestGraphCompletionCoTRetriever:
|
||||||
|
|
||||||
answer = await retriever.get_completion("Who works at Figma?")
|
answer = await retriever.get_completion("Who works at Figma?")
|
||||||
|
|
||||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
|
||||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
assert answer.strip(), "Answer must contain only non-empty strings"
|
||||||
"Answer must contain only non-empty strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_graph_completion_cot_context_on_empty_graph(self):
|
async def test_get_graph_completion_cot_context_on_empty_graph(self):
|
||||||
|
|
@ -162,12 +159,10 @@ class TestGraphCompletionCoTRetriever:
|
||||||
|
|
||||||
await setup()
|
await setup()
|
||||||
|
|
||||||
context, _ = await retriever.get_context("Who works at Figma?")
|
context = await retriever.get_context("Who works at Figma?")
|
||||||
assert context == "", "Context should be empty on an empty graph"
|
assert context == [], "Context should be empty on an empty graph"
|
||||||
|
|
||||||
answer = await retriever.get_completion("Who works at Figma?")
|
answer = await retriever.get_completion("Who works at Figma?")
|
||||||
|
|
||||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
|
||||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
assert answer.strip(), "Answer must contain only non-empty strings"
|
||||||
"Answer must contain only non-empty strings"
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from typing import Optional, Union
|
||||||
|
|
||||||
import cognee
|
import cognee
|
||||||
from cognee.low_level import setup, DataPoint
|
from cognee.low_level import setup, DataPoint
|
||||||
|
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||||
from cognee.tasks.storage import add_data_points
|
from cognee.tasks.storage import add_data_points
|
||||||
from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
|
from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
|
||||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||||
|
|
@ -67,7 +68,7 @@ class TestGraphCompletionRetriever:
|
||||||
|
|
||||||
retriever = GraphCompletionRetriever()
|
retriever = GraphCompletionRetriever()
|
||||||
|
|
||||||
context, _ = await retriever.get_context("Who works at Canva?")
|
context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))
|
||||||
|
|
||||||
# Ensure the top-level sections are present
|
# Ensure the top-level sections are present
|
||||||
assert "Nodes:" in context, "Missing 'Nodes:' section in context"
|
assert "Nodes:" in context, "Missing 'Nodes:' section in context"
|
||||||
|
|
@ -191,7 +192,7 @@ class TestGraphCompletionRetriever:
|
||||||
|
|
||||||
retriever = GraphCompletionRetriever(top_k=20)
|
retriever = GraphCompletionRetriever(top_k=20)
|
||||||
|
|
||||||
context, _ = await retriever.get_context("Who works at Figma?")
|
context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?"))
|
||||||
|
|
||||||
print(context)
|
print(context)
|
||||||
|
|
||||||
|
|
@ -222,5 +223,5 @@ class TestGraphCompletionRetriever:
|
||||||
|
|
||||||
await setup()
|
await setup()
|
||||||
|
|
||||||
context, _ = await retriever.get_context("Who works at Figma?")
|
context = await retriever.get_context("Who works at Figma?")
|
||||||
assert context == "", "Context should be empty on an empty graph"
|
assert context == [], "Context should be empty on an empty graph"
|
||||||
|
|
|
||||||
|
|
@ -82,7 +82,7 @@ class TestInsightsRetriever:
|
||||||
|
|
||||||
context = await retriever.get_context("Mike")
|
context = await retriever.get_context("Mike")
|
||||||
|
|
||||||
assert context[0][0]["name"] == "Mike Broski", "Failed to get Mike Broski"
|
assert context[0].node1.attributes["name"] == "Mike Broski", "Failed to get Mike Broski"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_insights_context_complex(self):
|
async def test_insights_context_complex(self):
|
||||||
|
|
@ -222,7 +222,9 @@ class TestInsightsRetriever:
|
||||||
|
|
||||||
context = await retriever.get_context("Christina")
|
context = await retriever.get_context("Christina")
|
||||||
|
|
||||||
assert context[0][0]["name"] == "Christina Mayer", "Failed to get Christina Mayer"
|
assert context[0].node1.attributes["name"] == "Christina Mayer", (
|
||||||
|
"Failed to get Christina Mayer"
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_insights_context_on_empty_graph(self):
|
async def test_insights_context_on_empty_graph(self):
|
||||||
|
|
|
||||||
|
|
@ -1,230 +0,0 @@
|
||||||
import json
|
|
||||||
import uuid
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from cognee.modules.engine.models.node_set import NodeSet
|
|
||||||
from cognee.modules.search.exceptions import UnsupportedSearchTypeError
|
|
||||||
from cognee.modules.search.methods.search import search, specific_search
|
|
||||||
from cognee.modules.search.types import SearchType
|
|
||||||
from cognee.modules.users.models import User
|
|
||||||
import sys
|
|
||||||
|
|
||||||
search_module = sys.modules.get("cognee.modules.search.methods.search")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_user():
|
|
||||||
user = MagicMock(spec=User)
|
|
||||||
user.id = uuid.uuid4()
|
|
||||||
return user
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch.object(search_module, "log_query")
|
|
||||||
@patch.object(search_module, "log_result")
|
|
||||||
@patch.object(search_module, "specific_search")
|
|
||||||
async def test_search(
|
|
||||||
mock_specific_search,
|
|
||||||
mock_log_result,
|
|
||||||
mock_log_query,
|
|
||||||
mock_user,
|
|
||||||
):
|
|
||||||
# Setup
|
|
||||||
query_text = "test query"
|
|
||||||
query_type = SearchType.CHUNKS
|
|
||||||
datasets = ["dataset1", "dataset2"]
|
|
||||||
|
|
||||||
# Mock the query logging
|
|
||||||
mock_query = MagicMock()
|
|
||||||
mock_query.id = uuid.uuid4()
|
|
||||||
mock_log_query.return_value = mock_query
|
|
||||||
|
|
||||||
# Mock document IDs
|
|
||||||
doc_id1 = uuid.uuid4()
|
|
||||||
doc_id2 = uuid.uuid4()
|
|
||||||
|
|
||||||
# Mock search results
|
|
||||||
search_results = [
|
|
||||||
{"document_id": str(doc_id1), "content": "Result 1"},
|
|
||||||
{"document_id": str(doc_id2), "content": "Result 2"},
|
|
||||||
]
|
|
||||||
mock_specific_search.return_value = search_results
|
|
||||||
|
|
||||||
# Execute
|
|
||||||
await search(query_text, query_type, datasets, mock_user)
|
|
||||||
|
|
||||||
# Verify
|
|
||||||
mock_log_query.assert_called_once_with(query_text, query_type.value, mock_user.id)
|
|
||||||
mock_specific_search.assert_called_once_with(
|
|
||||||
query_type=query_type,
|
|
||||||
query_text=query_text,
|
|
||||||
user=mock_user,
|
|
||||||
system_prompt_path="answer_simple_question.txt",
|
|
||||||
system_prompt=None,
|
|
||||||
top_k=10,
|
|
||||||
node_type=NodeSet,
|
|
||||||
node_name=None,
|
|
||||||
save_interaction=False,
|
|
||||||
last_k=None,
|
|
||||||
only_context=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify result logging
|
|
||||||
mock_log_result.assert_called_once()
|
|
||||||
# Check that the first argument is the query ID
|
|
||||||
assert mock_log_result.call_args[0][0] == mock_query.id
|
|
||||||
# The second argument should be the JSON string of the filtered results
|
|
||||||
# We can't directly compare the JSON strings due to potential ordering differences
|
|
||||||
# So we parse the JSON and compare the objects
|
|
||||||
logged_results = json.loads(mock_log_result.call_args[0][1])
|
|
||||||
assert len(logged_results) == 2
|
|
||||||
assert logged_results[0]["document_id"] == str(doc_id1)
|
|
||||||
assert logged_results[1]["document_id"] == str(doc_id2)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch.object(search_module, "SummariesRetriever")
|
|
||||||
@patch.object(search_module, "send_telemetry")
|
|
||||||
async def test_specific_search_summaries(mock_send_telemetry, mock_summaries_retriever, mock_user):
|
|
||||||
# Setup
|
|
||||||
query = "test query"
|
|
||||||
query_type = SearchType.SUMMARIES
|
|
||||||
|
|
||||||
# Mock the retriever
|
|
||||||
mock_retriever = MagicMock()
|
|
||||||
mock_retriever.get_completion = AsyncMock()
|
|
||||||
mock_retriever.get_completion.return_value = [{"content": "Summary result"}]
|
|
||||||
mock_summaries_retriever.return_value = mock_retriever
|
|
||||||
|
|
||||||
# Execute
|
|
||||||
results = await specific_search(query_type, query, mock_user)
|
|
||||||
|
|
||||||
# Verify
|
|
||||||
mock_summaries_retriever.assert_called_once()
|
|
||||||
mock_retriever.get_completion.assert_called_once_with(query)
|
|
||||||
mock_send_telemetry.assert_called()
|
|
||||||
assert len(results) == 1
|
|
||||||
assert results[0]["content"] == "Summary result"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch.object(search_module, "InsightsRetriever")
|
|
||||||
@patch.object(search_module, "send_telemetry")
|
|
||||||
async def test_specific_search_insights(mock_send_telemetry, mock_insights_retriever, mock_user):
|
|
||||||
# Setup
|
|
||||||
query = "test query"
|
|
||||||
query_type = SearchType.INSIGHTS
|
|
||||||
|
|
||||||
# Mock the retriever
|
|
||||||
mock_retriever = MagicMock()
|
|
||||||
mock_retriever.get_completion = AsyncMock()
|
|
||||||
mock_retriever.get_completion.return_value = [{"content": "Insight result"}]
|
|
||||||
mock_insights_retriever.return_value = mock_retriever
|
|
||||||
|
|
||||||
# Execute
|
|
||||||
results = await specific_search(query_type, query, mock_user)
|
|
||||||
|
|
||||||
# Verify
|
|
||||||
mock_insights_retriever.assert_called_once()
|
|
||||||
mock_retriever.get_completion.assert_called_once_with(query)
|
|
||||||
mock_send_telemetry.assert_called()
|
|
||||||
assert len(results) == 1
|
|
||||||
assert results[0]["content"] == "Insight result"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch.object(search_module, "ChunksRetriever")
|
|
||||||
@patch.object(search_module, "send_telemetry")
|
|
||||||
async def test_specific_search_chunks(mock_send_telemetry, mock_chunks_retriever, mock_user):
|
|
||||||
# Setup
|
|
||||||
query = "test query"
|
|
||||||
query_type = SearchType.CHUNKS
|
|
||||||
|
|
||||||
# Mock the retriever
|
|
||||||
mock_retriever = MagicMock()
|
|
||||||
mock_retriever.get_completion = AsyncMock()
|
|
||||||
mock_retriever.get_completion.return_value = [{"content": "Chunk result"}]
|
|
||||||
mock_chunks_retriever.return_value = mock_retriever
|
|
||||||
|
|
||||||
# Execute
|
|
||||||
results = await specific_search(query_type, query, mock_user)
|
|
||||||
|
|
||||||
# Verify
|
|
||||||
mock_chunks_retriever.assert_called_once()
|
|
||||||
mock_retriever.get_completion.assert_called_once_with(query)
|
|
||||||
mock_send_telemetry.assert_called()
|
|
||||||
assert len(results) == 1
|
|
||||||
assert results[0]["content"] == "Chunk result"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"selected_type, retriever_name, expected_content, top_k",
|
|
||||||
[
|
|
||||||
(SearchType.RAG_COMPLETION, "CompletionRetriever", "RAG result from lucky search", 10),
|
|
||||||
(SearchType.CHUNKS, "ChunksRetriever", "Chunk result from lucky search", 5),
|
|
||||||
(SearchType.SUMMARIES, "SummariesRetriever", "Summary from lucky search", 15),
|
|
||||||
(SearchType.INSIGHTS, "InsightsRetriever", "Insight result from lucky search", 20),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
@patch.object(search_module, "select_search_type")
|
|
||||||
@patch.object(search_module, "send_telemetry")
|
|
||||||
async def test_specific_search_feeling_lucky(
|
|
||||||
mock_send_telemetry,
|
|
||||||
mock_select_search_type,
|
|
||||||
selected_type,
|
|
||||||
retriever_name,
|
|
||||||
expected_content,
|
|
||||||
top_k,
|
|
||||||
mock_user,
|
|
||||||
):
|
|
||||||
with patch.object(search_module, retriever_name) as mock_retriever_class:
|
|
||||||
# Setup
|
|
||||||
query = f"test query for {retriever_name}"
|
|
||||||
query_type = SearchType.FEELING_LUCKY
|
|
||||||
|
|
||||||
# Mock the intelligent search type selection
|
|
||||||
mock_select_search_type.return_value = selected_type
|
|
||||||
|
|
||||||
# Mock the retriever
|
|
||||||
mock_retriever_instance = MagicMock()
|
|
||||||
mock_retriever_instance.get_completion = AsyncMock(
|
|
||||||
return_value=[{"content": expected_content}]
|
|
||||||
)
|
|
||||||
mock_retriever_class.return_value = mock_retriever_instance
|
|
||||||
|
|
||||||
# Execute
|
|
||||||
results = await specific_search(query_type, query, mock_user, top_k=top_k)
|
|
||||||
|
|
||||||
# Verify
|
|
||||||
mock_select_search_type.assert_called_once_with(query)
|
|
||||||
|
|
||||||
if retriever_name == "CompletionRetriever":
|
|
||||||
mock_retriever_class.assert_called_once_with(
|
|
||||||
system_prompt_path="answer_simple_question.txt",
|
|
||||||
top_k=top_k,
|
|
||||||
system_prompt=None,
|
|
||||||
only_context=None,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
mock_retriever_class.assert_called_once_with(top_k=top_k)
|
|
||||||
|
|
||||||
mock_retriever_instance.get_completion.assert_called_once_with(query)
|
|
||||||
mock_send_telemetry.assert_called()
|
|
||||||
assert len(results) == 1
|
|
||||||
assert results[0]["content"] == expected_content
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_specific_search_invalid_type(mock_user):
|
|
||||||
# Setup
|
|
||||||
query = "test query"
|
|
||||||
query_type = "INVALID_TYPE" # Not a valid SearchType
|
|
||||||
|
|
||||||
# Execute and verify
|
|
||||||
with pytest.raises(UnsupportedSearchTypeError) as excinfo:
|
|
||||||
await specific_search(query_type, query, mock_user)
|
|
||||||
|
|
||||||
assert "Unsupported search type" in str(excinfo.value)
|
|
||||||
|
|
@ -47,6 +47,7 @@ async def main():
|
||||||
query = "When was Kamala Harris in office?"
|
query = "When was Kamala Harris in office?"
|
||||||
triplets = await brute_force_triplet_search(
|
triplets = await brute_force_triplet_search(
|
||||||
query=query,
|
query=query,
|
||||||
|
user=user,
|
||||||
top_k=3,
|
top_k=3,
|
||||||
collections=["graphitinode_content", "graphitinode_name", "graphitinode_summary"],
|
collections=["graphitinode_content", "graphitinode_name", "graphitinode_summary"],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue