feat: implement combined context search (#1341)
<!-- .github/pull_request_template.md --> ## Description <!-- Provide a clear description of the changes in this PR --> ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
parent
ba33dca592
commit
b1643414d2
36 changed files with 706 additions and 642 deletions
|
|
@ -27,6 +27,7 @@ class SearchPayloadDTO(InDTO):
|
|||
node_name: Optional[list[str]] = Field(default=None, example=[])
|
||||
top_k: Optional[int] = Field(default=10)
|
||||
only_context: bool = Field(default=False)
|
||||
use_combined_context: bool = Field(default=False)
|
||||
|
||||
|
||||
def get_search_router() -> APIRouter:
|
||||
|
|
@ -115,6 +116,7 @@ def get_search_router() -> APIRouter:
|
|||
"node_name": payload.node_name,
|
||||
"top_k": payload.top_k,
|
||||
"only_context": payload.only_context,
|
||||
"use_combined_context": payload.use_combined_context,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -131,6 +133,7 @@ def get_search_router() -> APIRouter:
|
|||
node_name=payload.node_name,
|
||||
top_k=payload.top_k,
|
||||
only_context=payload.only_context,
|
||||
use_combined_context=payload.use_combined_context,
|
||||
)
|
||||
|
||||
return JSONResponse(content=results)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import Union, Optional, List, Type
|
|||
|
||||
from cognee.modules.engine.models.node_set import NodeSet
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.modules.search.types import SearchResult, SearchType, CombinedSearchResult
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.search.methods import search as search_function
|
||||
from cognee.modules.data.methods import get_authorized_existing_datasets
|
||||
|
|
@ -13,7 +13,7 @@ from cognee.modules.data.exceptions import DatasetNotFoundError
|
|||
async def search(
|
||||
query_text: str,
|
||||
query_type: SearchType = SearchType.GRAPH_COMPLETION,
|
||||
user: User = None,
|
||||
user: Optional[User] = None,
|
||||
datasets: Optional[Union[list[str], str]] = None,
|
||||
dataset_ids: Optional[Union[list[UUID], UUID]] = None,
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
|
|
@ -24,7 +24,8 @@ async def search(
|
|||
save_interaction: bool = False,
|
||||
last_k: Optional[int] = None,
|
||||
only_context: bool = False,
|
||||
) -> list:
|
||||
use_combined_context: bool = False,
|
||||
) -> Union[List[SearchResult], CombinedSearchResult]:
|
||||
"""
|
||||
Search and query the knowledge graph for insights, information, and connections.
|
||||
|
||||
|
|
@ -193,6 +194,7 @@ async def search(
|
|||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
only_context=only_context,
|
||||
use_combined_context=use_combined_context,
|
||||
)
|
||||
|
||||
return filtered_search_results
|
||||
|
|
|
|||
|
|
@ -180,7 +180,7 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
logger.error(f"Error mapping vector distances to edges: {str(ex)}")
|
||||
raise ex
|
||||
|
||||
async def calculate_top_triplet_importances(self, k: int) -> List:
|
||||
async def calculate_top_triplet_importances(self, k: int) -> List[Edge]:
|
||||
def score(edge):
|
||||
n1 = edge.node1.attributes.get("vector_distance", 1)
|
||||
n2 = edge.node2.attributes.get("vector_distance", 1)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,8 @@
|
|||
async def resolve_edges_to_text(retrieved_edges: list) -> str:
|
||||
from typing import List
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
|
||||
|
||||
async def resolve_edges_to_text(retrieved_edges: List[Edge]) -> str:
|
||||
"""
|
||||
Converts retrieved graph edges into a human-readable string format.
|
||||
|
||||
|
|
@ -13,7 +17,7 @@ async def resolve_edges_to_text(retrieved_edges: list) -> str:
|
|||
- str: A formatted string representation of the nodes and their connections.
|
||||
"""
|
||||
|
||||
def _get_nodes(retrieved_edges: list) -> dict:
|
||||
def _get_nodes(retrieved_edges: List[Edge]) -> dict:
|
||||
def _get_title(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
|
||||
def _top_n_words(text, stop_words=None, top_n=3, separator=", "):
|
||||
"""Concatenates the top N frequent words in text."""
|
||||
|
|
@ -36,9 +40,9 @@ async def resolve_edges_to_text(retrieved_edges: list) -> str:
|
|||
return separator.join(top_words)
|
||||
|
||||
"""Creates a title, by combining first words with most frequent words from the text."""
|
||||
first_n_words = text.split()[:first_n_words]
|
||||
top_n_words = _top_n_words(text, top_n=top_n_words)
|
||||
return f"{' '.join(first_n_words)}... [{top_n_words}]"
|
||||
first_words = text.split()[:first_n_words]
|
||||
top_words = _top_n_words(text, top_n=first_n_words)
|
||||
return f"{' '.join(first_words)}... [{top_words}]"
|
||||
|
||||
"""Creates a dictionary of nodes with their names and content."""
|
||||
nodes = {}
|
||||
|
|
|
|||
18
cognee/modules/retrieval/base_graph_retriever.py
Normal file
18
cognee/modules/retrieval/base_graph_retriever.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
from typing import List, Optional
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
|
||||
|
||||
class BaseGraphRetriever(ABC):
|
||||
"""Base class for all graph based retrievers."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_context(self, query: str) -> List[Edge]:
|
||||
"""Retrieves triplets based on the query."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_completion(self, query: str, context: Optional[List[Edge]] = None) -> str:
|
||||
"""Generates a response using the query and optional context (triplets)."""
|
||||
pass
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional, Callable
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class BaseRetriever(ABC):
|
||||
|
|
|
|||
|
|
@ -1,3 +1,6 @@
|
|||
import asyncio
|
||||
from functools import reduce
|
||||
from typing import List, Optional
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.tasks.codingagents.coding_rule_associations import get_existing_rules
|
||||
|
||||
|
|
@ -7,16 +10,22 @@ logger = get_logger("CodingRulesRetriever")
|
|||
class CodingRulesRetriever:
|
||||
"""Retriever for handling codeing rule based searches."""
|
||||
|
||||
def __init__(self, rules_nodeset_name="coding_agent_rules"):
|
||||
def __init__(self, rules_nodeset_name: Optional[List[str]] = None):
|
||||
if isinstance(rules_nodeset_name, list):
|
||||
if not rules_nodeset_name:
|
||||
# If there is no provided nodeset set to coding_agent_rules
|
||||
rules_nodeset_name = ["coding_agent_rules"]
|
||||
rules_nodeset_name = rules_nodeset_name[0]
|
||||
|
||||
self.rules_nodeset_name = rules_nodeset_name
|
||||
"""Initialize retriever with search parameters."""
|
||||
|
||||
async def get_existing_rules(self, query_text):
|
||||
return await get_existing_rules(
|
||||
rules_nodeset_name=self.rules_nodeset_name, return_list=True
|
||||
)
|
||||
if self.rules_nodeset_name:
|
||||
rules_list = await asyncio.gather(
|
||||
*[
|
||||
get_existing_rules(rules_nodeset_name=nodeset)
|
||||
for nodeset in self.rules_nodeset_name
|
||||
]
|
||||
)
|
||||
|
||||
return reduce(lambda x, y: x + y, rules_list, [])
|
||||
|
|
|
|||
|
|
@ -23,16 +23,14 @@ class CompletionRetriever(BaseRetriever):
|
|||
self,
|
||||
user_prompt_path: str = "context_for_question.txt",
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
system_prompt: str = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
top_k: Optional[int] = 1,
|
||||
only_context: bool = False,
|
||||
):
|
||||
"""Initialize retriever with optional custom prompt paths."""
|
||||
self.user_prompt_path = user_prompt_path
|
||||
self.system_prompt_path = system_prompt_path
|
||||
self.top_k = top_k if top_k is not None else 1
|
||||
self.system_prompt = system_prompt
|
||||
self.only_context = only_context
|
||||
|
||||
async def get_context(self, query: str) -> str:
|
||||
"""
|
||||
|
|
@ -69,7 +67,7 @@ class CompletionRetriever(BaseRetriever):
|
|||
logger.error("DocumentChunk_text collection not found")
|
||||
raise NoDataError("No data found in the system, please add data first.") from error
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> str:
|
||||
"""
|
||||
Generates an LLM completion using the context.
|
||||
|
||||
|
|
@ -97,6 +95,5 @@ class CompletionRetriever(BaseRetriever):
|
|||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
only_context=self.only_context,
|
||||
)
|
||||
return [completion]
|
||||
return completion
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ class TripletSearchContextProvider(BaseContextProvider):
|
|||
tasks = [
|
||||
brute_force_triplet_search(
|
||||
query=f"{entity_text} {query}",
|
||||
user=user,
|
||||
top_k=self.top_k,
|
||||
collections=self.collections,
|
||||
properties_to_project=self.properties_to_project,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from typing import Any, Optional, List, Type
|
||||
from typing import Optional, List, Type
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
from cognee.modules.retrieval.utils.completion import generate_completion
|
||||
|
|
@ -31,7 +32,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: bool = False,
|
||||
only_context: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
user_prompt_path=user_prompt_path,
|
||||
|
|
@ -41,15 +41,14 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
only_context=only_context,
|
||||
)
|
||||
|
||||
async def get_completion(
|
||||
self,
|
||||
query: str,
|
||||
context: Optional[Any] = None,
|
||||
context: Optional[List[Edge]] = None,
|
||||
context_extension_rounds=4,
|
||||
) -> List[str]:
|
||||
) -> str:
|
||||
"""
|
||||
Extends the context for a given query by retrieving related triplets and generating new
|
||||
completions based on them.
|
||||
|
|
@ -74,11 +73,12 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
- List[str]: A list containing the generated answer based on the query and the
|
||||
extended context.
|
||||
"""
|
||||
triplets = []
|
||||
triplets = context
|
||||
|
||||
if context is None:
|
||||
triplets += await self.get_triplets(query)
|
||||
context = await self.resolve_edges_to_text(triplets)
|
||||
if triplets is None:
|
||||
triplets = await self.get_context(query)
|
||||
|
||||
context_text = await self.resolve_edges_to_text(triplets)
|
||||
|
||||
round_idx = 1
|
||||
|
||||
|
|
@ -90,15 +90,15 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
)
|
||||
completion = await generate_completion(
|
||||
query=query,
|
||||
context=context,
|
||||
context=context_text,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
)
|
||||
|
||||
triplets += await self.get_triplets(completion)
|
||||
triplets += await self.get_context(completion)
|
||||
triplets = list(set(triplets))
|
||||
context = await self.resolve_edges_to_text(triplets)
|
||||
context_text = await self.resolve_edges_to_text(triplets)
|
||||
|
||||
num_triplets = len(triplets)
|
||||
|
||||
|
|
@ -117,19 +117,15 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
|
||||
completion = await generate_completion(
|
||||
query=query,
|
||||
context=context,
|
||||
context=context_text,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
only_context=self.only_context,
|
||||
)
|
||||
|
||||
if self.save_interaction and context and triplets and completion:
|
||||
if self.save_interaction and context_text and triplets and completion:
|
||||
await self.save_qa(
|
||||
question=query, answer=completion, context=context, triplets=triplets
|
||||
question=query, answer=completion, context=context_text, triplets=triplets
|
||||
)
|
||||
|
||||
if self.only_context:
|
||||
return [context]
|
||||
else:
|
||||
return [completion]
|
||||
return completion
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from typing import Any, Optional, List, Tuple, Type
|
||||
from typing import Optional, List, Type
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
|
|
@ -32,18 +33,16 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
validation_system_prompt_path: str = "cot_validation_system_prompt.txt",
|
||||
followup_system_prompt_path: str = "cot_followup_system_prompt.txt",
|
||||
followup_user_prompt_path: str = "cot_followup_user_prompt.txt",
|
||||
system_prompt: str = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
top_k: Optional[int] = 5,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: bool = False,
|
||||
only_context: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
user_prompt_path=user_prompt_path,
|
||||
system_prompt_path=system_prompt_path,
|
||||
system_prompt=system_prompt,
|
||||
only_context=only_context,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
|
|
@ -57,9 +56,9 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
async def get_completion(
|
||||
self,
|
||||
query: str,
|
||||
context: Optional[Any] = None,
|
||||
context: Optional[List[Edge]] = None,
|
||||
max_iter=4,
|
||||
) -> List[str]:
|
||||
) -> str:
|
||||
"""
|
||||
Generate completion responses based on a user query and contextual information.
|
||||
|
||||
|
|
@ -84,26 +83,29 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
"""
|
||||
followup_question = ""
|
||||
triplets = []
|
||||
completion = [""]
|
||||
completion = ""
|
||||
|
||||
for round_idx in range(max_iter + 1):
|
||||
if round_idx == 0:
|
||||
if context is None:
|
||||
context = await self.get_context(query)
|
||||
triplets = await self.get_context(query)
|
||||
context_text = await self.resolve_edges_to_text(triplets)
|
||||
else:
|
||||
context_text = await self.resolve_edges_to_text(context)
|
||||
else:
|
||||
triplets += await self.get_triplets(followup_question)
|
||||
context = await self.resolve_edges_to_text(list(set(triplets)))
|
||||
triplets += await self.get_context(followup_question)
|
||||
context_text = await self.resolve_edges_to_text(list(set(triplets)))
|
||||
|
||||
completion = await generate_completion(
|
||||
query=query,
|
||||
context=context,
|
||||
context=context_text,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
)
|
||||
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
|
||||
if round_idx < max_iter:
|
||||
valid_args = {"query": query, "answer": completion, "context": context}
|
||||
valid_args = {"query": query, "answer": completion, "context": context_text}
|
||||
valid_user_prompt = LLMGateway.render_prompt(
|
||||
filename=self.validation_user_prompt_path, context=valid_args
|
||||
)
|
||||
|
|
@ -133,10 +135,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
|
||||
if self.save_interaction and context and triplets and completion:
|
||||
await self.save_qa(
|
||||
question=query, answer=completion, context=context, triplets=triplets
|
||||
question=query, answer=completion, context=context_text, triplets=triplets
|
||||
)
|
||||
|
||||
if self.only_context:
|
||||
return [context]
|
||||
else:
|
||||
return [completion]
|
||||
return completion
|
||||
|
|
|
|||
|
|
@ -1,16 +1,15 @@
|
|||
from typing import Any, Optional, Type, List
|
||||
from collections import Counter
|
||||
from uuid import NAMESPACE_OID, uuid5
|
||||
import string
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
|
||||
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||
from cognee.modules.retrieval.base_graph_retriever import BaseGraphRetriever
|
||||
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
|
||||
from cognee.modules.retrieval.utils.completion import generate_completion
|
||||
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.retrieval.utils.extract_uuid_from_node import extract_uuid_from_node
|
||||
from cognee.modules.retrieval.utils.models import CogneeUserInteraction
|
||||
|
|
@ -20,7 +19,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine
|
|||
logger = get_logger("GraphCompletionRetriever")
|
||||
|
||||
|
||||
class GraphCompletionRetriever(BaseRetriever):
|
||||
class GraphCompletionRetriever(BaseGraphRetriever):
|
||||
"""
|
||||
Retriever for handling graph-based completion searches.
|
||||
|
||||
|
|
@ -37,19 +36,17 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
self,
|
||||
user_prompt_path: str = "graph_context_for_question.txt",
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
system_prompt: str = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
top_k: Optional[int] = 5,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: bool = False,
|
||||
only_context: bool = False,
|
||||
):
|
||||
"""Initialize retriever with prompt paths and search parameters."""
|
||||
self.save_interaction = save_interaction
|
||||
self.user_prompt_path = user_prompt_path
|
||||
self.system_prompt_path = system_prompt_path
|
||||
self.system_prompt = system_prompt
|
||||
self.only_context = only_context
|
||||
self.top_k = top_k if top_k is not None else 5
|
||||
self.node_type = node_type
|
||||
self.node_name = node_name
|
||||
|
|
@ -70,7 +67,7 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
"""
|
||||
return await resolve_edges_to_text(retrieved_edges)
|
||||
|
||||
async def get_triplets(self, query: str) -> list:
|
||||
async def get_triplets(self, query: str) -> List[Edge]:
|
||||
"""
|
||||
Retrieves relevant graph triplets based on a query string.
|
||||
|
||||
|
|
@ -85,7 +82,7 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
- list: A list of found triplets that match the query.
|
||||
"""
|
||||
subclasses = get_all_subclasses(DataPoint)
|
||||
vector_index_collections = []
|
||||
vector_index_collections: List[str] = []
|
||||
|
||||
for subclass in subclasses:
|
||||
if "metadata" in subclass.model_fields:
|
||||
|
|
@ -96,8 +93,11 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
for field_name in index_fields:
|
||||
vector_index_collections.append(f"{subclass.__name__}_{field_name}")
|
||||
|
||||
user = await get_default_user()
|
||||
|
||||
found_triplets = await brute_force_triplet_search(
|
||||
query,
|
||||
user=user,
|
||||
top_k=self.top_k,
|
||||
collections=vector_index_collections or None,
|
||||
node_type=self.node_type,
|
||||
|
|
@ -106,7 +106,7 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
|
||||
return found_triplets
|
||||
|
||||
async def get_context(self, query: str) -> tuple[str, list]:
|
||||
async def get_context(self, query: str) -> List[Edge]:
|
||||
"""
|
||||
Retrieves and resolves graph triplets into context based on a query.
|
||||
|
||||
|
|
@ -125,17 +125,17 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
|
||||
if len(triplets) == 0:
|
||||
logger.warning("Empty context was provided to the completion")
|
||||
return "", triplets
|
||||
return []
|
||||
|
||||
context = await self.resolve_edges_to_text(triplets)
|
||||
# context = await self.resolve_edges_to_text(triplets)
|
||||
|
||||
return context, triplets
|
||||
return triplets
|
||||
|
||||
async def get_completion(
|
||||
self,
|
||||
query: str,
|
||||
context: Optional[Any] = None,
|
||||
) -> List[str]:
|
||||
context: Optional[List[Edge]] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Generates a completion using graph connections context based on a query.
|
||||
|
||||
|
|
@ -151,26 +151,27 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
|
||||
- Any: A generated completion based on the query and context provided.
|
||||
"""
|
||||
triplets = None
|
||||
triplets = context
|
||||
|
||||
if context is None:
|
||||
context, triplets = await self.get_context(query)
|
||||
if triplets is None:
|
||||
triplets = await self.get_context(query)
|
||||
|
||||
context_text = await resolve_edges_to_text(triplets)
|
||||
|
||||
completion = await generate_completion(
|
||||
query=query,
|
||||
context=context,
|
||||
context=context_text,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
only_context=self.only_context,
|
||||
)
|
||||
|
||||
if self.save_interaction and context and triplets and completion:
|
||||
await self.save_qa(
|
||||
question=query, answer=completion, context=context, triplets=triplets
|
||||
question=query, answer=completion, context=context_text, triplets=triplets
|
||||
)
|
||||
|
||||
return [completion]
|
||||
return completion
|
||||
|
||||
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,17 +1,18 @@
|
|||
import asyncio
|
||||
from typing import Any, Optional
|
||||
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
|
||||
from cognee.modules.retrieval.base_graph_retriever import BaseGraphRetriever
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
|
||||
|
||||
logger = get_logger("InsightsRetriever")
|
||||
|
||||
|
||||
class InsightsRetriever(BaseRetriever):
|
||||
class InsightsRetriever(BaseGraphRetriever):
|
||||
"""
|
||||
Retriever for handling graph connection-based insights.
|
||||
|
||||
|
|
@ -95,7 +96,17 @@ class InsightsRetriever(BaseRetriever):
|
|||
unique_node_connections_map[unique_id] = True
|
||||
unique_node_connections.append(node_connection)
|
||||
|
||||
return unique_node_connections
|
||||
return [
|
||||
Edge(
|
||||
node1=Node(node_id=connection[0]["id"], attributes=connection[0]),
|
||||
node2=Node(node_id=connection[2]["id"], attributes=connection[2]),
|
||||
attributes={
|
||||
**connection[1],
|
||||
"relationship_type": connection[1]["relationship_name"],
|
||||
},
|
||||
)
|
||||
for connection in unique_node_connections
|
||||
]
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import os
|
||||
from typing import Any, Optional, List, Tuple, Type
|
||||
from typing import Any, Optional, List, Type
|
||||
|
||||
|
||||
from operator import itemgetter
|
||||
|
|
@ -113,8 +113,8 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
logger.info(
|
||||
"No timestamps identified based on the query, performing retrieval using triplet search on events and entities."
|
||||
)
|
||||
triplets = await self.get_triplets(query)
|
||||
return await self.resolve_edges_to_text(triplets), triplets
|
||||
triplets = await self.get_context(query)
|
||||
return await self.resolve_edges_to_text(triplets)
|
||||
|
||||
if ids:
|
||||
relevant_events = await graph_engine.collect_events(ids=ids)
|
||||
|
|
@ -122,8 +122,8 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
logger.info(
|
||||
"No events identified based on timestamp filtering, performing retrieval using triplet search on events and entities."
|
||||
)
|
||||
triplets = await self.get_triplets(query)
|
||||
return await self.resolve_edges_to_text(triplets), triplets
|
||||
triplets = await self.get_context(query)
|
||||
return await self.resolve_edges_to_text(triplets)
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
query_vector = (await vector_engine.embedding_engine.embed_text([query]))[0]
|
||||
|
|
@ -134,18 +134,19 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
|
||||
top_k_events = await self.filter_top_k_events(relevant_events, vector_search_results)
|
||||
|
||||
return self.descriptions_to_string(top_k_events), triplets
|
||||
return self.descriptions_to_string(top_k_events)
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> List[str]:
|
||||
async def get_completion(self, query: str, context: Optional[str] = None) -> str:
|
||||
"""Generates a response using the query and optional context."""
|
||||
if not context:
|
||||
context = await self.get_context(query=query)
|
||||
|
||||
context, triplets = await self.get_context(query=query)
|
||||
if context:
|
||||
completion = await generate_completion(
|
||||
query=query,
|
||||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
)
|
||||
|
||||
completion = await generate_completion(
|
||||
query=query,
|
||||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
)
|
||||
|
||||
return [completion]
|
||||
return completion
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from cognee.infrastructure.databases.vector.exceptions import CollectionNotFound
|
|||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.shared.utils import send_telemetry
|
||||
|
||||
|
|
@ -87,41 +87,15 @@ async def get_memory_fragment(
|
|||
|
||||
|
||||
async def brute_force_triplet_search(
|
||||
query: str,
|
||||
user: User = None,
|
||||
top_k: int = 5,
|
||||
collections: List[str] = None,
|
||||
properties_to_project: List[str] = None,
|
||||
memory_fragment: Optional[CogneeGraph] = None,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
) -> list:
|
||||
if user is None:
|
||||
user = await get_default_user()
|
||||
|
||||
retrieved_results = await brute_force_search(
|
||||
query,
|
||||
user,
|
||||
top_k,
|
||||
collections=collections,
|
||||
properties_to_project=properties_to_project,
|
||||
memory_fragment=memory_fragment,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
)
|
||||
return retrieved_results
|
||||
|
||||
|
||||
async def brute_force_search(
|
||||
query: str,
|
||||
user: User,
|
||||
top_k: int,
|
||||
collections: List[str] = None,
|
||||
properties_to_project: List[str] = None,
|
||||
top_k: int = 5,
|
||||
collections: Optional[List[str]] = None,
|
||||
properties_to_project: Optional[List[str]] = None,
|
||||
memory_fragment: Optional[CogneeGraph] = None,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
) -> list:
|
||||
) -> List[Edge]:
|
||||
"""
|
||||
Performs a brute force search to retrieve the top triplets from the graph.
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ async def generate_completion(
|
|||
user_prompt_path: str,
|
||||
system_prompt_path: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
only_context: bool = False,
|
||||
) -> str:
|
||||
"""Generates a completion using LLM with given context and prompts."""
|
||||
args = {"question": query, "context": context}
|
||||
|
|
@ -17,14 +16,11 @@ async def generate_completion(
|
|||
system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path)
|
||||
)
|
||||
|
||||
if only_context:
|
||||
return context
|
||||
else:
|
||||
return await LLMGateway.acreate_structured_output(
|
||||
text_input=user_prompt,
|
||||
system_prompt=system_prompt,
|
||||
response_model=str,
|
||||
)
|
||||
return await LLMGateway.acreate_structured_output(
|
||||
text_input=user_prompt,
|
||||
system_prompt=system_prompt,
|
||||
response_model=str,
|
||||
)
|
||||
|
||||
|
||||
async def summarize_text(
|
||||
|
|
|
|||
168
cognee/modules/search/methods/get_search_type_tools.py
Normal file
168
cognee/modules/search/methods/get_search_type_tools.py
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
from typing import Callable, List, Optional, Type
|
||||
|
||||
from cognee.modules.engine.models.node_set import NodeSet
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.modules.search.operations import select_search_type
|
||||
from cognee.modules.search.exceptions import UnsupportedSearchTypeError
|
||||
|
||||
# Retrievers
|
||||
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
|
||||
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
|
||||
from cognee.modules.retrieval.insights_retriever import InsightsRetriever
|
||||
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
|
||||
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
|
||||
from cognee.modules.retrieval.coding_rules_retriever import CodingRulesRetriever
|
||||
from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
||||
GraphSummaryCompletionRetriever,
|
||||
)
|
||||
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
||||
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
||||
GraphCompletionContextExtensionRetriever,
|
||||
)
|
||||
from cognee.modules.retrieval.code_retriever import CodeRetriever
|
||||
from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
|
||||
from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
|
||||
|
||||
|
||||
async def get_search_type_tools(
|
||||
query_type: SearchType,
|
||||
query_text: str,
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
system_prompt: Optional[str] = None,
|
||||
top_k: int = 10,
|
||||
node_type: Optional[Type] = NodeSet,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: bool = False,
|
||||
last_k: Optional[int] = None,
|
||||
) -> list:
|
||||
search_tasks: dict[SearchType, List[Callable]] = {
|
||||
SearchType.SUMMARIES: [
|
||||
SummariesRetriever(top_k=top_k).get_completion,
|
||||
SummariesRetriever(top_k=top_k).get_context,
|
||||
],
|
||||
SearchType.INSIGHTS: [
|
||||
InsightsRetriever(top_k=top_k).get_completion,
|
||||
InsightsRetriever(top_k=top_k).get_context,
|
||||
],
|
||||
SearchType.CHUNKS: [
|
||||
ChunksRetriever(top_k=top_k).get_completion,
|
||||
ChunksRetriever(top_k=top_k).get_context,
|
||||
],
|
||||
SearchType.RAG_COMPLETION: [
|
||||
CompletionRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
system_prompt=system_prompt,
|
||||
).get_completion,
|
||||
CompletionRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
system_prompt=system_prompt,
|
||||
).get_context,
|
||||
],
|
||||
SearchType.GRAPH_COMPLETION: [
|
||||
GraphCompletionRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
).get_completion,
|
||||
GraphCompletionRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
).get_context,
|
||||
],
|
||||
SearchType.GRAPH_COMPLETION_COT: [
|
||||
GraphCompletionCotRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
).get_completion,
|
||||
GraphCompletionCotRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
).get_context,
|
||||
],
|
||||
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: [
|
||||
GraphCompletionContextExtensionRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
).get_completion,
|
||||
GraphCompletionContextExtensionRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
).get_context,
|
||||
],
|
||||
SearchType.GRAPH_SUMMARY_COMPLETION: [
|
||||
GraphSummaryCompletionRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
).get_completion,
|
||||
GraphSummaryCompletionRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
).get_context,
|
||||
],
|
||||
SearchType.CODE: [
|
||||
CodeRetriever(top_k=top_k).get_completion,
|
||||
CodeRetriever(top_k=top_k).get_context,
|
||||
],
|
||||
SearchType.CYPHER: [
|
||||
CypherSearchRetriever().get_completion,
|
||||
CypherSearchRetriever().get_context,
|
||||
],
|
||||
SearchType.NATURAL_LANGUAGE: [
|
||||
NaturalLanguageRetriever().get_completion,
|
||||
NaturalLanguageRetriever().get_context,
|
||||
],
|
||||
SearchType.FEEDBACK: [UserQAFeedback(last_k=last_k).add_feedback],
|
||||
SearchType.TEMPORAL: [
|
||||
TemporalRetriever(top_k=top_k).get_completion,
|
||||
TemporalRetriever(top_k=top_k).get_context,
|
||||
],
|
||||
SearchType.CODING_RULES: [
|
||||
CodingRulesRetriever(rules_nodeset_name=node_name).get_existing_rules,
|
||||
],
|
||||
}
|
||||
|
||||
# If the query type is FEELING_LUCKY, select the search type intelligently
|
||||
if query_type is SearchType.FEELING_LUCKY:
|
||||
query_type = await select_search_type(query_text)
|
||||
|
||||
search_type_tools = search_tasks.get(query_type)
|
||||
|
||||
if not search_type_tools:
|
||||
raise UnsupportedSearchTypeError(str(query_type))
|
||||
|
||||
return search_type_tools
|
||||
47
cognee/modules/search/methods/no_access_control_search.py
Normal file
47
cognee/modules/search/methods/no_access_control_search.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
from typing import Any, List, Optional, Tuple, Type, Union
|
||||
|
||||
from cognee.modules.data.models.Dataset import Dataset
|
||||
from cognee.modules.engine.models.node_set import NodeSet
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.modules.search.types import SearchType
|
||||
|
||||
from .get_search_type_tools import get_search_type_tools
|
||||
|
||||
|
||||
async def no_access_control_search(
|
||||
query_type: SearchType,
|
||||
query_text: str,
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
system_prompt: Optional[str] = None,
|
||||
top_k: int = 10,
|
||||
node_type: Optional[Type] = NodeSet,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: bool = False,
|
||||
last_k: Optional[int] = None,
|
||||
only_context: bool = False,
|
||||
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
|
||||
search_tools = await get_search_type_tools(
|
||||
query_type=query_type,
|
||||
query_text=query_text,
|
||||
system_prompt_path=system_prompt_path,
|
||||
system_prompt=system_prompt,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
)
|
||||
if len(search_tools) == 2:
|
||||
[get_completion, get_context] = search_tools
|
||||
|
||||
if only_context:
|
||||
return await get_context(query_text)
|
||||
|
||||
context = await get_context(query_text)
|
||||
result = await get_completion(query_text, context)
|
||||
else:
|
||||
unknown_tool = search_tools[0]
|
||||
result = await unknown_tool(query_text)
|
||||
context = ""
|
||||
|
||||
return result, context, []
|
||||
|
|
@ -3,37 +3,27 @@ import json
|
|||
import asyncio
|
||||
from uuid import UUID
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from typing import Callable, List, Optional, Type, Union
|
||||
from typing import Any, List, Optional, Tuple, Type, Union
|
||||
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.context_global_variables import set_database_global_context_variables
|
||||
|
||||
from cognee.modules.engine.models.node_set import NodeSet
|
||||
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
|
||||
from cognee.modules.search.exceptions import UnsupportedSearchTypeError
|
||||
from cognee.context_global_variables import set_database_global_context_variables
|
||||
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
|
||||
from cognee.modules.retrieval.insights_retriever import InsightsRetriever
|
||||
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
|
||||
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
|
||||
from cognee.modules.retrieval.coding_rules_retriever import CodingRulesRetriever
|
||||
from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
||||
GraphSummaryCompletionRetriever,
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.modules.search.types import (
|
||||
SearchResult,
|
||||
CombinedSearchResult,
|
||||
SearchResultDataset,
|
||||
SearchType,
|
||||
)
|
||||
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
||||
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
||||
GraphCompletionContextExtensionRetriever,
|
||||
)
|
||||
from cognee.modules.retrieval.code_retriever import CodeRetriever
|
||||
from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
|
||||
from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.modules.storage.utils import JSONEncoder
|
||||
from cognee.modules.search.operations import log_query, log_result
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.data.models import Dataset
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.modules.users.permissions.methods import get_specific_user_permission_datasets
|
||||
from cognee.modules.search.operations import log_query, log_result, select_search_type
|
||||
|
||||
from .get_search_type_tools import get_search_type_tools
|
||||
from .no_access_control_search import no_access_control_search
|
||||
from ..utils.prepare_search_result import prepare_search_result
|
||||
|
||||
|
||||
async def search(
|
||||
|
|
@ -46,10 +36,11 @@ async def search(
|
|||
top_k: int = 10,
|
||||
node_type: Optional[Type] = NodeSet,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: Optional[bool] = False,
|
||||
save_interaction: bool = False,
|
||||
last_k: Optional[int] = None,
|
||||
only_context: bool = False,
|
||||
):
|
||||
use_combined_context: bool = False,
|
||||
) -> Union[CombinedSearchResult, List[SearchResult]]:
|
||||
"""
|
||||
|
||||
Args:
|
||||
|
|
@ -65,9 +56,12 @@ async def search(
|
|||
Notes:
|
||||
Searching by dataset is only available in ENABLE_BACKEND_ACCESS_CONTROL mode
|
||||
"""
|
||||
query = await log_query(query_text, query_type.value, user.id)
|
||||
send_telemetry("cognee.search EXECUTION STARTED", user.id)
|
||||
|
||||
# Use search function filtered by permissions if access control is enabled
|
||||
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
|
||||
return await authorized_search(
|
||||
search_results = await authorized_search(
|
||||
query_type=query_type,
|
||||
query_text=query_text,
|
||||
user=user,
|
||||
|
|
@ -80,119 +74,68 @@ async def search(
|
|||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
only_context=only_context,
|
||||
use_combined_context=use_combined_context,
|
||||
)
|
||||
else:
|
||||
search_results = [
|
||||
await no_access_control_search(
|
||||
query_type=query_type,
|
||||
query_text=query_text,
|
||||
system_prompt_path=system_prompt_path,
|
||||
system_prompt=system_prompt,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
only_context=only_context,
|
||||
)
|
||||
]
|
||||
|
||||
query = await log_query(query_text, query_type.value, user.id)
|
||||
|
||||
search_results = await specific_search(
|
||||
query_type=query_type,
|
||||
query_text=query_text,
|
||||
user=user,
|
||||
system_prompt_path=system_prompt_path,
|
||||
system_prompt=system_prompt,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
only_context=only_context,
|
||||
)
|
||||
send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
|
||||
|
||||
await log_result(
|
||||
query.id,
|
||||
json.dumps(
|
||||
search_results if len(search_results) > 1 else search_results[0], cls=JSONEncoder
|
||||
jsonable_encoder(
|
||||
await prepare_search_result(search_results)
|
||||
if use_combined_context
|
||||
else [
|
||||
await prepare_search_result(search_result) for search_result in search_results
|
||||
]
|
||||
)
|
||||
),
|
||||
user.id,
|
||||
)
|
||||
|
||||
return search_results
|
||||
if use_combined_context:
|
||||
prepared_search_results = await prepare_search_result(search_results)
|
||||
result = prepared_search_results["result"]
|
||||
graphs = prepared_search_results["graphs"]
|
||||
context = prepared_search_results["context"]
|
||||
datasets = prepared_search_results["datasets"]
|
||||
|
||||
|
||||
async def specific_search(
|
||||
query_type: SearchType,
|
||||
query_text: str,
|
||||
user: User,
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
system_prompt: Optional[str] = None,
|
||||
top_k: int = 10,
|
||||
node_type: Optional[Type] = NodeSet,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: Optional[bool] = False,
|
||||
last_k: Optional[int] = None,
|
||||
only_context: bool = None,
|
||||
) -> list:
|
||||
search_tasks: dict[SearchType, Callable] = {
|
||||
SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion,
|
||||
SearchType.INSIGHTS: InsightsRetriever(top_k=top_k).get_completion,
|
||||
SearchType.CHUNKS: ChunksRetriever(top_k=top_k).get_completion,
|
||||
SearchType.RAG_COMPLETION: CompletionRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
system_prompt=system_prompt,
|
||||
only_context=only_context,
|
||||
).get_completion,
|
||||
SearchType.GRAPH_COMPLETION: GraphCompletionRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
only_context=only_context,
|
||||
).get_completion,
|
||||
SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
only_context=only_context,
|
||||
).get_completion,
|
||||
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
only_context=only_context,
|
||||
).get_completion,
|
||||
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
).get_completion,
|
||||
SearchType.CODE: CodeRetriever(top_k=top_k).get_completion,
|
||||
SearchType.CYPHER: CypherSearchRetriever().get_completion,
|
||||
SearchType.NATURAL_LANGUAGE: NaturalLanguageRetriever().get_completion,
|
||||
SearchType.FEEDBACK: UserQAFeedback(last_k=last_k).add_feedback,
|
||||
SearchType.TEMPORAL: TemporalRetriever(top_k=top_k).get_completion,
|
||||
SearchType.CODING_RULES: CodingRulesRetriever(
|
||||
rules_nodeset_name=node_name
|
||||
).get_existing_rules,
|
||||
}
|
||||
|
||||
# If the query type is FEELING_LUCKY, select the search type intelligently
|
||||
if query_type is SearchType.FEELING_LUCKY:
|
||||
query_type = await select_search_type(query_text)
|
||||
|
||||
search_task = search_tasks.get(query_type)
|
||||
|
||||
if search_task is None:
|
||||
raise UnsupportedSearchTypeError(str(query_type))
|
||||
|
||||
send_telemetry("cognee.search EXECUTION STARTED", user.id)
|
||||
|
||||
results = await search_task(query_text)
|
||||
|
||||
send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
|
||||
|
||||
return results
|
||||
return CombinedSearchResult(
|
||||
result=result,
|
||||
graphs=graphs,
|
||||
context=context,
|
||||
datasets=[
|
||||
SearchResultDataset(
|
||||
id=dataset.id,
|
||||
name=dataset.name,
|
||||
)
|
||||
for dataset in datasets
|
||||
],
|
||||
)
|
||||
else:
|
||||
return [
|
||||
SearchResult(
|
||||
search_result=result,
|
||||
dataset_id=datasets[min(index, len(datasets) - 1)].id if datasets else None,
|
||||
dataset_name=datasets[min(index, len(datasets) - 1)].name if datasets else None,
|
||||
)
|
||||
for index, (result, _, datasets) in enumerate(search_results)
|
||||
]
|
||||
|
||||
|
||||
async def authorized_search(
|
||||
|
|
@ -205,26 +148,85 @@ async def authorized_search(
|
|||
top_k: int = 10,
|
||||
node_type: Optional[Type] = NodeSet,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: Optional[bool] = False,
|
||||
save_interaction: bool = False,
|
||||
last_k: Optional[int] = None,
|
||||
only_context: bool = None,
|
||||
) -> list:
|
||||
only_context: bool = False,
|
||||
use_combined_context: bool = False,
|
||||
) -> Union[
|
||||
Tuple[Any, Union[List[Edge], str], List[Dataset]],
|
||||
List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],
|
||||
]:
|
||||
"""
|
||||
Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset.
|
||||
Not to be used outside of active access control mode.
|
||||
"""
|
||||
|
||||
query = await log_query(query_text, query_type.value, user.id)
|
||||
|
||||
# Find datasets user has read access for (if datasets are provided only return them. Provided user has read access)
|
||||
search_datasets = await get_specific_user_permission_datasets(user.id, "read", dataset_ids)
|
||||
|
||||
if use_combined_context:
|
||||
search_responses = await search_in_datasets_context(
|
||||
search_datasets=search_datasets,
|
||||
query_type=query_type,
|
||||
query_text=query_text,
|
||||
system_prompt_path=system_prompt_path,
|
||||
system_prompt=system_prompt,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
only_context=True,
|
||||
)
|
||||
|
||||
context = {}
|
||||
datasets: List[Dataset] = []
|
||||
|
||||
for _, search_context, datasets in search_responses:
|
||||
for dataset in datasets:
|
||||
context[str(dataset.id)] = search_context
|
||||
|
||||
datasets.extend(datasets)
|
||||
|
||||
specific_search_tools = await get_search_type_tools(
|
||||
query_type=query_type,
|
||||
query_text=query_text,
|
||||
system_prompt_path=system_prompt_path,
|
||||
system_prompt=system_prompt,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
)
|
||||
search_tools = specific_search_tools
|
||||
if len(search_tools) == 2:
|
||||
[get_completion, _] = search_tools
|
||||
else:
|
||||
get_completion = search_tools[0]
|
||||
|
||||
def prepare_combined_context(
|
||||
context,
|
||||
) -> Union[List[Edge], str]:
|
||||
combined_context = []
|
||||
|
||||
for dataset_context in context.values():
|
||||
combined_context += dataset_context
|
||||
|
||||
if combined_context and isinstance(combined_context[0], str):
|
||||
return "\n".join(combined_context)
|
||||
|
||||
return combined_context
|
||||
|
||||
combined_context = prepare_combined_context(context)
|
||||
completion = await get_completion(query_text, combined_context)
|
||||
|
||||
return completion, combined_context, datasets
|
||||
|
||||
# Searches all provided datasets and handles setting up of appropriate database context based on permissions
|
||||
search_results = await specific_search_by_context(
|
||||
search_results = await search_in_datasets_context(
|
||||
search_datasets=search_datasets,
|
||||
query_type=query_type,
|
||||
query_text=query_text,
|
||||
user=user,
|
||||
system_prompt_path=system_prompt_path,
|
||||
system_prompt=system_prompt,
|
||||
top_k=top_k,
|
||||
|
|
@ -235,51 +237,48 @@ async def authorized_search(
|
|||
only_context=only_context,
|
||||
)
|
||||
|
||||
await log_result(query.id, json.dumps(jsonable_encoder(search_results)), user.id)
|
||||
|
||||
return search_results
|
||||
|
||||
|
||||
async def specific_search_by_context(
|
||||
async def search_in_datasets_context(
|
||||
search_datasets: list[Dataset],
|
||||
query_type: SearchType,
|
||||
query_text: str,
|
||||
user: User,
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
system_prompt: Optional[str] = None,
|
||||
top_k: int = 10,
|
||||
node_type: Optional[Type] = NodeSet,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: Optional[bool] = False,
|
||||
save_interaction: bool = False,
|
||||
last_k: Optional[int] = None,
|
||||
only_context: bool = None,
|
||||
):
|
||||
only_context: bool = False,
|
||||
context: Optional[Any] = None,
|
||||
) -> List[Tuple[Any, Union[str, List[Edge]], List[Dataset]]]:
|
||||
"""
|
||||
Searches all provided datasets and handles setting up of appropriate database context based on permissions.
|
||||
Not to be used outside of active access control mode.
|
||||
"""
|
||||
|
||||
async def _search_by_context(
|
||||
async def _search_in_dataset_context(
|
||||
dataset: Dataset,
|
||||
query_type: SearchType,
|
||||
query_text: str,
|
||||
user: User,
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
system_prompt: Optional[str] = None,
|
||||
top_k: int = 10,
|
||||
node_type: Optional[Type] = NodeSet,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: Optional[bool] = False,
|
||||
save_interaction: bool = False,
|
||||
last_k: Optional[int] = None,
|
||||
only_context: bool = None,
|
||||
):
|
||||
only_context: bool = False,
|
||||
context: Optional[Any] = None,
|
||||
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
|
||||
# Set database configuration in async context for each dataset user has access for
|
||||
await set_database_global_context_variables(dataset.id, dataset.owner_id)
|
||||
|
||||
result = await specific_search(
|
||||
specific_search_tools = await get_search_type_tools(
|
||||
query_type=query_type,
|
||||
query_text=query_text,
|
||||
user=user,
|
||||
system_prompt_path=system_prompt_path,
|
||||
system_prompt=system_prompt,
|
||||
top_k=top_k,
|
||||
|
|
@ -287,57 +286,31 @@ async def specific_search_by_context(
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
only_context=only_context,
|
||||
)
|
||||
search_tools = specific_search_tools
|
||||
if len(search_tools) == 2:
|
||||
[get_completion, get_context] = search_tools
|
||||
|
||||
if isinstance(result, tuple):
|
||||
search_results = result[0]
|
||||
triplets = result[1]
|
||||
if only_context:
|
||||
return None, await get_context(query_text), [dataset]
|
||||
|
||||
search_context = context or await get_context(query_text)
|
||||
search_result = await get_completion(query_text, search_context)
|
||||
|
||||
return search_result, search_context, [dataset]
|
||||
else:
|
||||
search_results = result
|
||||
triplets = []
|
||||
unknown_tool = search_tools[0]
|
||||
|
||||
return {
|
||||
"search_result": search_results,
|
||||
"graph": [
|
||||
{
|
||||
"source": {
|
||||
"id": triplet.node1.id,
|
||||
"attributes": {
|
||||
"name": triplet.node1.attributes["name"],
|
||||
"type": triplet.node1.attributes["type"],
|
||||
"description": triplet.node1.attributes["description"],
|
||||
"vector_distance": triplet.node1.attributes["vector_distance"],
|
||||
},
|
||||
},
|
||||
"destination": {
|
||||
"id": triplet.node2.id,
|
||||
"attributes": {
|
||||
"name": triplet.node2.attributes["name"],
|
||||
"type": triplet.node2.attributes["type"],
|
||||
"description": triplet.node2.attributes["description"],
|
||||
"vector_distance": triplet.node2.attributes["vector_distance"],
|
||||
},
|
||||
},
|
||||
"attributes": {
|
||||
"relationship_name": triplet.attributes["relationship_name"],
|
||||
},
|
||||
}
|
||||
for triplet in triplets
|
||||
],
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
}
|
||||
return await unknown_tool(query_text), "", [dataset]
|
||||
|
||||
# Search every dataset async based on query and appropriate database configuration
|
||||
tasks = []
|
||||
for dataset in search_datasets:
|
||||
tasks.append(
|
||||
_search_by_context(
|
||||
_search_in_dataset_context(
|
||||
dataset=dataset,
|
||||
query_type=query_type,
|
||||
query_text=query_text,
|
||||
user=user,
|
||||
system_prompt_path=system_prompt_path,
|
||||
system_prompt=system_prompt,
|
||||
top_k=top_k,
|
||||
|
|
@ -346,6 +319,7 @@ async def specific_search_by_context(
|
|||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
only_context=only_context,
|
||||
context=context,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
21
cognee/modules/search/types/SearchResult.py
Normal file
21
cognee/modules/search/types/SearchResult.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class SearchResultDataset(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
|
||||
|
||||
class CombinedSearchResult(BaseModel):
|
||||
result: Optional[Any]
|
||||
context: Dict[str, Any]
|
||||
graphs: Optional[Dict[str, Any]] = {}
|
||||
datasets: Optional[List[SearchResultDataset]] = None
|
||||
|
||||
|
||||
class SearchResult(BaseModel):
|
||||
search_result: Any
|
||||
dataset_id: Optional[UUID]
|
||||
dataset_name: Optional[str]
|
||||
|
|
@ -1 +1,2 @@
|
|||
from .SearchType import SearchType
|
||||
from .SearchResult import SearchResult, SearchResultDataset, CombinedSearchResult
|
||||
|
|
|
|||
2
cognee/modules/search/utils/__init__.py
Normal file
2
cognee/modules/search/utils/__init__.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
from .prepare_search_result import prepare_search_result
|
||||
from .transform_context_to_graph import transform_context_to_graph
|
||||
41
cognee/modules/search/utils/prepare_search_result.py
Normal file
41
cognee/modules/search/utils/prepare_search_result.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
from typing import List, cast
|
||||
|
||||
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.modules.search.utils.transform_context_to_graph import transform_context_to_graph
|
||||
|
||||
|
||||
async def prepare_search_result(search_result):
|
||||
result, context, datasets = search_result
|
||||
|
||||
graphs = None
|
||||
result_graph = None
|
||||
context_texts = {}
|
||||
|
||||
if isinstance(context, List) and len(context) > 0 and isinstance(context[0], Edge):
|
||||
result_graph = transform_context_to_graph(context)
|
||||
|
||||
graphs = {
|
||||
"*": result_graph,
|
||||
}
|
||||
context_texts = {
|
||||
"*": await resolve_edges_to_text(context),
|
||||
}
|
||||
elif isinstance(context, str):
|
||||
context_texts = {
|
||||
"*": context,
|
||||
}
|
||||
elif isinstance(context, List) and len(context) > 0 and isinstance(context[0], str):
|
||||
context_texts = {
|
||||
"*": "\n".join(cast(List[str], context)),
|
||||
}
|
||||
|
||||
if isinstance(result, List) and len(result) > 0 and isinstance(result[0], Edge):
|
||||
result_graph = transform_context_to_graph(result)
|
||||
|
||||
return {
|
||||
"result": result_graph or result,
|
||||
"graphs": graphs,
|
||||
"context": context_texts,
|
||||
"datasets": datasets,
|
||||
}
|
||||
38
cognee/modules/search/utils/transform_context_to_graph.py
Normal file
38
cognee/modules/search/utils/transform_context_to_graph.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
from typing import List
|
||||
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
|
||||
|
||||
def transform_context_to_graph(context: List[Edge]):
|
||||
nodes = {}
|
||||
edges = {}
|
||||
|
||||
for triplet in context:
|
||||
nodes[triplet.node1.id] = {
|
||||
"id": triplet.node1.id,
|
||||
"label": triplet.node1.attributes["name"]
|
||||
if "name" in triplet.node1.attributes
|
||||
else triplet.node1.id,
|
||||
"type": triplet.node1.attributes["type"],
|
||||
"attributes": triplet.node2.attributes,
|
||||
}
|
||||
nodes[triplet.node2.id] = {
|
||||
"id": triplet.node2.id,
|
||||
"label": triplet.node2.attributes["name"]
|
||||
if "name" in triplet.node2.attributes
|
||||
else triplet.node2.id,
|
||||
"type": triplet.node2.attributes["type"],
|
||||
"attributes": triplet.node2.attributes,
|
||||
}
|
||||
edges[
|
||||
f"{triplet.node1.id}_{triplet.attributes['relationship_name']}_{triplet.node2.id}"
|
||||
] = {
|
||||
"source": triplet.node1.id,
|
||||
"target": triplet.node2.id,
|
||||
"label": triplet.attributes["relationship_name"],
|
||||
}
|
||||
|
||||
return {
|
||||
"nodes": list(nodes.values()),
|
||||
"edges": list(edges.values()),
|
||||
}
|
||||
|
|
@ -31,7 +31,7 @@ class RuleSet(DataPoint):
|
|||
)
|
||||
|
||||
|
||||
async def get_existing_rules(rules_nodeset_name: str, return_list: bool = False) -> str:
|
||||
async def get_existing_rules(rules_nodeset_name: str) -> List[str]:
|
||||
graph_engine = await get_graph_engine()
|
||||
nodes_data, _ = await graph_engine.get_nodeset_subgraph(
|
||||
node_type=NodeSet, node_name=[rules_nodeset_name]
|
||||
|
|
@ -46,9 +46,6 @@ async def get_existing_rules(rules_nodeset_name: str, return_list: bool = False)
|
|||
and "text" in item[1]
|
||||
]
|
||||
|
||||
if not return_list:
|
||||
existing_rules = "\n".join(f"- {rule}" for rule in existing_rules)
|
||||
|
||||
return existing_rules
|
||||
|
||||
|
||||
|
|
@ -103,6 +100,7 @@ async def add_rule_associations(
|
|||
|
||||
graph_engine = await get_graph_engine()
|
||||
existing_rules = await get_existing_rules(rules_nodeset_name=rules_nodeset_name)
|
||||
existing_rules = "\n".join(f"- {rule}" for rule in existing_rules)
|
||||
|
||||
user_context = {"chat": data, "rules": existing_rules}
|
||||
|
||||
|
|
|
|||
|
|
@ -94,21 +94,21 @@ async def main():
|
|||
|
||||
await cognee.cognify([dataset_name])
|
||||
|
||||
context_nonempty, _ = await GraphCompletionRetriever(
|
||||
context_nonempty = await GraphCompletionRetriever(
|
||||
node_type=NodeSet,
|
||||
node_name=["first"],
|
||||
).get_context("What is in the context?")
|
||||
|
||||
context_empty, _ = await GraphCompletionRetriever(
|
||||
context_empty = await GraphCompletionRetriever(
|
||||
node_type=NodeSet,
|
||||
node_name=["nonexistent"],
|
||||
).get_context("What is in the context?")
|
||||
|
||||
assert isinstance(context_nonempty, str) and context_nonempty != "", (
|
||||
assert isinstance(context_nonempty, list) and context_nonempty != [], (
|
||||
f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
|
||||
)
|
||||
|
||||
assert context_empty == "", (
|
||||
assert context_empty == [], (
|
||||
f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -98,21 +98,21 @@ async def main():
|
|||
|
||||
await cognee.cognify([dataset_name])
|
||||
|
||||
context_nonempty, _ = await GraphCompletionRetriever(
|
||||
context_nonempty = await GraphCompletionRetriever(
|
||||
node_type=NodeSet,
|
||||
node_name=["first"],
|
||||
).get_context("What is in the context?")
|
||||
|
||||
context_empty, _ = await GraphCompletionRetriever(
|
||||
context_empty = await GraphCompletionRetriever(
|
||||
node_type=NodeSet,
|
||||
node_name=["nonexistent"],
|
||||
).get_context("What is in the context?")
|
||||
|
||||
assert isinstance(context_nonempty, str) and context_nonempty != "", (
|
||||
assert isinstance(context_nonempty, list) and context_nonempty != [], (
|
||||
f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
|
||||
)
|
||||
|
||||
assert context_empty == "", (
|
||||
assert context_empty == [], (
|
||||
f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ async def main():
|
|||
print("\n\nExtracted sentences are:\n")
|
||||
for result in search_results:
|
||||
print(f"{result}\n")
|
||||
assert search_results[0]["dataset_name"] == "NLP", (
|
||||
assert search_results[0].dataset_name == "NLP", (
|
||||
f"Dict must contain dataset name 'NLP': {search_results[0]}"
|
||||
)
|
||||
|
||||
|
|
@ -93,7 +93,7 @@ async def main():
|
|||
print("\n\nExtracted sentences are:\n")
|
||||
for result in search_results:
|
||||
print(f"{result}\n")
|
||||
assert search_results[0]["dataset_name"] == "QUANTUM", (
|
||||
assert search_results[0].dataset_name == "QUANTUM", (
|
||||
f"Dict must contain dataset name 'QUANTUM': {search_results[0]}"
|
||||
)
|
||||
|
||||
|
|
@ -170,7 +170,7 @@ async def main():
|
|||
for result in search_results:
|
||||
print(f"{result}\n")
|
||||
|
||||
assert search_results[0]["dataset_name"] == "QUANTUM", (
|
||||
assert search_results[0].dataset_name == "QUANTUM", (
|
||||
f"Dict must contain dataset name 'QUANTUM': {search_results[0]}"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import json
|
||||
import pathlib
|
||||
import os
|
||||
from typing import List
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.relational import (
|
||||
get_migration_relational_engine,
|
||||
|
|
@ -10,7 +10,7 @@ from cognee.infrastructure.databases.vector.pgvector import (
|
|||
create_db_and_tables as create_pgvector_db_and_tables,
|
||||
)
|
||||
from cognee.tasks.ingestion import migrate_relational_database
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.modules.search.types import SearchResult, SearchType
|
||||
import cognee
|
||||
|
||||
|
||||
|
|
@ -45,13 +45,15 @@ async def relational_db_migration():
|
|||
await migrate_relational_database(graph_engine, schema=schema)
|
||||
|
||||
# 1. Search the graph
|
||||
search_results = await cognee.search(
|
||||
search_results: List[SearchResult] = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION, query_text="Tell me about the artist AC/DC"
|
||||
)
|
||||
) # type: ignore
|
||||
print("Search results:", search_results)
|
||||
|
||||
# 2. Assert that the search results contain "AC/DC"
|
||||
assert any("AC/DC" in r for r in search_results), "AC/DC not found in search results!"
|
||||
assert any("AC/DC" in r.search_result for r in search_results), (
|
||||
"AC/DC not found in search results!"
|
||||
)
|
||||
|
||||
migration_db_provider = migration_engine.engine.dialect.name
|
||||
if migration_db_provider == "postgresql":
|
||||
|
|
|
|||
|
|
@ -1,11 +1,7 @@
|
|||
import os
|
||||
import pathlib
|
||||
|
||||
from dns.e164 import query
|
||||
|
||||
import cognee
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
||||
GraphCompletionContextExtensionRetriever,
|
||||
|
|
@ -14,11 +10,8 @@ from cognee.modules.retrieval.graph_completion_cot_retriever import GraphComplet
|
|||
from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
||||
GraphSummaryCompletionRetriever,
|
||||
)
|
||||
from cognee.modules.search.operations import get_history
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.modules.engine.models import NodeSet
|
||||
from collections import Counter
|
||||
|
||||
logger = get_logger()
|
||||
|
|
@ -46,16 +39,16 @@ async def main():
|
|||
|
||||
await cognee.cognify([dataset_name])
|
||||
|
||||
context_gk, _ = await GraphCompletionRetriever().get_context(
|
||||
context_gk = await GraphCompletionRetriever().get_context(
|
||||
query="Next to which country is Germany located?"
|
||||
)
|
||||
context_gk_cot, _ = await GraphCompletionCotRetriever().get_context(
|
||||
context_gk_cot = await GraphCompletionCotRetriever().get_context(
|
||||
query="Next to which country is Germany located?"
|
||||
)
|
||||
context_gk_ext, _ = await GraphCompletionContextExtensionRetriever().get_context(
|
||||
context_gk_ext = await GraphCompletionContextExtensionRetriever().get_context(
|
||||
query="Next to which country is Germany located?"
|
||||
)
|
||||
context_gk_sum, _ = await GraphSummaryCompletionRetriever().get_context(
|
||||
context_gk_sum = await GraphSummaryCompletionRetriever().get_context(
|
||||
query="Next to which country is Germany located?"
|
||||
)
|
||||
|
||||
|
|
@ -65,9 +58,11 @@ async def main():
|
|||
("GraphCompletionContextExtensionRetriever", context_gk_ext),
|
||||
("GraphSummaryCompletionRetriever", context_gk_sum),
|
||||
]:
|
||||
assert isinstance(context, str), f"{name}: Context should be a string"
|
||||
assert context.strip(), f"{name}: Context should not be empty"
|
||||
lower = context.lower()
|
||||
assert isinstance(context, list), f"{name}: Context should be a list"
|
||||
assert len(context) > 0, f"{name}: Context should not be empty"
|
||||
|
||||
context_text = await resolve_edges_to_text(context)
|
||||
lower = context_text.lower()
|
||||
assert "germany" in lower or "netherlands" in lower, (
|
||||
f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}"
|
||||
)
|
||||
|
|
@ -143,20 +138,19 @@ async def main():
|
|||
last_k=1,
|
||||
)
|
||||
|
||||
for name, completion in [
|
||||
for name, search_results in [
|
||||
("GRAPH_COMPLETION", completion_gk),
|
||||
("GRAPH_COMPLETION_COT", completion_cot),
|
||||
("GRAPH_COMPLETION_CONTEXT_EXTENSION", completion_ext),
|
||||
("GRAPH_SUMMARY_COMPLETION", completion_sum),
|
||||
]:
|
||||
assert isinstance(completion, list), f"{name}: should return a list"
|
||||
assert len(completion) == 1, f"{name}: expected single-element list, got {len(completion)}"
|
||||
text = completion[0]
|
||||
assert isinstance(text, str), f"{name}: element should be a string"
|
||||
assert text.strip(), f"{name}: string should not be empty"
|
||||
assert "netherlands" in text.lower(), (
|
||||
f"{name}: expected 'netherlands' in result, got: {text!r}"
|
||||
)
|
||||
for search_result in search_results:
|
||||
completion = search_result.search_result
|
||||
assert isinstance(completion, str), f"{name}: should return a string"
|
||||
assert completion.strip(), f"{name}: string should not be empty"
|
||||
assert "netherlands" in completion.lower(), (
|
||||
f"{name}: expected 'netherlands' in result, got: {completion!r}"
|
||||
)
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
graph = await graph_engine.get_graph_data()
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from typing import Optional, Union
|
|||
import cognee
|
||||
from cognee.low_level import setup, DataPoint
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||
from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
|
||||
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
||||
GraphCompletionContextExtensionRetriever,
|
||||
|
|
@ -51,17 +52,15 @@ class TestGraphCompletionWithContextExtensionRetriever:
|
|||
|
||||
retriever = GraphCompletionContextExtensionRetriever()
|
||||
|
||||
context, _ = await retriever.get_context("Who works at Canva?")
|
||||
context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))
|
||||
|
||||
assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
|
||||
assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
|
||||
|
||||
answer = await retriever.get_completion("Who works at Canva?")
|
||||
|
||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||
"Answer must contain only non-empty strings"
|
||||
)
|
||||
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
|
||||
assert answer.strip(), "Answer must contain only non-empty strings"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_completion_extension_context_complex(self):
|
||||
|
|
@ -129,7 +128,9 @@ class TestGraphCompletionWithContextExtensionRetriever:
|
|||
|
||||
retriever = GraphCompletionContextExtensionRetriever(top_k=20)
|
||||
|
||||
context, _ = await retriever.get_context("Who works at Figma?")
|
||||
context = await resolve_edges_to_text(
|
||||
await retriever.get_context("Who works at Figma and drives Tesla?")
|
||||
)
|
||||
|
||||
print(context)
|
||||
|
||||
|
|
@ -139,10 +140,8 @@ class TestGraphCompletionWithContextExtensionRetriever:
|
|||
|
||||
answer = await retriever.get_completion("Who works at Figma?")
|
||||
|
||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||
"Answer must contain only non-empty strings"
|
||||
)
|
||||
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
|
||||
assert answer.strip(), "Answer must contain only non-empty strings"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_graph_completion_extension_context_on_empty_graph(self):
|
||||
|
|
@ -167,12 +166,10 @@ class TestGraphCompletionWithContextExtensionRetriever:
|
|||
|
||||
await setup()
|
||||
|
||||
context, _ = await retriever.get_context("Who works at Figma?")
|
||||
assert context == "", "Context should be empty on an empty graph"
|
||||
context = await retriever.get_context("Who works at Figma?")
|
||||
assert context == [], "Context should be empty on an empty graph"
|
||||
|
||||
answer = await retriever.get_completion("Who works at Figma?")
|
||||
|
||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||
"Answer must contain only non-empty strings"
|
||||
)
|
||||
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
|
||||
assert answer.strip(), "Answer must contain only non-empty strings"
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from typing import Optional, Union
|
|||
|
||||
import cognee
|
||||
from cognee.low_level import setup, DataPoint
|
||||
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
|
||||
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
||||
|
|
@ -47,17 +48,15 @@ class TestGraphCompletionCoTRetriever:
|
|||
|
||||
retriever = GraphCompletionCotRetriever()
|
||||
|
||||
context, _ = await retriever.get_context("Who works at Canva?")
|
||||
context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))
|
||||
|
||||
assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
|
||||
assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
|
||||
|
||||
answer = await retriever.get_completion("Who works at Canva?")
|
||||
|
||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||
"Answer must contain only non-empty strings"
|
||||
)
|
||||
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
|
||||
assert answer.strip(), "Answer must contain only non-empty strings"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_completion_cot_context_complex(self):
|
||||
|
|
@ -124,7 +123,7 @@ class TestGraphCompletionCoTRetriever:
|
|||
|
||||
retriever = GraphCompletionCotRetriever(top_k=20)
|
||||
|
||||
context, _ = await retriever.get_context("Who works at Figma?")
|
||||
context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?"))
|
||||
|
||||
print(context)
|
||||
|
||||
|
|
@ -134,10 +133,8 @@ class TestGraphCompletionCoTRetriever:
|
|||
|
||||
answer = await retriever.get_completion("Who works at Figma?")
|
||||
|
||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||
"Answer must contain only non-empty strings"
|
||||
)
|
||||
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
|
||||
assert answer.strip(), "Answer must contain only non-empty strings"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_graph_completion_cot_context_on_empty_graph(self):
|
||||
|
|
@ -162,12 +159,10 @@ class TestGraphCompletionCoTRetriever:
|
|||
|
||||
await setup()
|
||||
|
||||
context, _ = await retriever.get_context("Who works at Figma?")
|
||||
assert context == "", "Context should be empty on an empty graph"
|
||||
context = await retriever.get_context("Who works at Figma?")
|
||||
assert context == [], "Context should be empty on an empty graph"
|
||||
|
||||
answer = await retriever.get_completion("Who works at Figma?")
|
||||
|
||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||
"Answer must contain only non-empty strings"
|
||||
)
|
||||
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
|
||||
assert answer.strip(), "Answer must contain only non-empty strings"
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from typing import Optional, Union
|
|||
|
||||
import cognee
|
||||
from cognee.low_level import setup, DataPoint
|
||||
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
|
|
@ -67,7 +68,7 @@ class TestGraphCompletionRetriever:
|
|||
|
||||
retriever = GraphCompletionRetriever()
|
||||
|
||||
context, _ = await retriever.get_context("Who works at Canva?")
|
||||
context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))
|
||||
|
||||
# Ensure the top-level sections are present
|
||||
assert "Nodes:" in context, "Missing 'Nodes:' section in context"
|
||||
|
|
@ -191,7 +192,7 @@ class TestGraphCompletionRetriever:
|
|||
|
||||
retriever = GraphCompletionRetriever(top_k=20)
|
||||
|
||||
context, _ = await retriever.get_context("Who works at Figma?")
|
||||
context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?"))
|
||||
|
||||
print(context)
|
||||
|
||||
|
|
@ -222,5 +223,5 @@ class TestGraphCompletionRetriever:
|
|||
|
||||
await setup()
|
||||
|
||||
context, _ = await retriever.get_context("Who works at Figma?")
|
||||
assert context == "", "Context should be empty on an empty graph"
|
||||
context = await retriever.get_context("Who works at Figma?")
|
||||
assert context == [], "Context should be empty on an empty graph"
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ class TestInsightsRetriever:
|
|||
|
||||
context = await retriever.get_context("Mike")
|
||||
|
||||
assert context[0][0]["name"] == "Mike Broski", "Failed to get Mike Broski"
|
||||
assert context[0].node1.attributes["name"] == "Mike Broski", "Failed to get Mike Broski"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insights_context_complex(self):
|
||||
|
|
@ -222,7 +222,9 @@ class TestInsightsRetriever:
|
|||
|
||||
context = await retriever.get_context("Christina")
|
||||
|
||||
assert context[0][0]["name"] == "Christina Mayer", "Failed to get Christina Mayer"
|
||||
assert context[0].node1.attributes["name"] == "Christina Mayer", (
|
||||
"Failed to get Christina Mayer"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insights_context_on_empty_graph(self):
|
||||
|
|
|
|||
|
|
@ -1,230 +0,0 @@
|
|||
import json
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from cognee.modules.engine.models.node_set import NodeSet
|
||||
from cognee.modules.search.exceptions import UnsupportedSearchTypeError
|
||||
from cognee.modules.search.methods.search import search, specific_search
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.modules.users.models import User
|
||||
import sys
|
||||
|
||||
search_module = sys.modules.get("cognee.modules.search.methods.search")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
user = MagicMock(spec=User)
|
||||
user.id = uuid.uuid4()
|
||||
return user
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.object(search_module, "log_query")
|
||||
@patch.object(search_module, "log_result")
|
||||
@patch.object(search_module, "specific_search")
|
||||
async def test_search(
|
||||
mock_specific_search,
|
||||
mock_log_result,
|
||||
mock_log_query,
|
||||
mock_user,
|
||||
):
|
||||
# Setup
|
||||
query_text = "test query"
|
||||
query_type = SearchType.CHUNKS
|
||||
datasets = ["dataset1", "dataset2"]
|
||||
|
||||
# Mock the query logging
|
||||
mock_query = MagicMock()
|
||||
mock_query.id = uuid.uuid4()
|
||||
mock_log_query.return_value = mock_query
|
||||
|
||||
# Mock document IDs
|
||||
doc_id1 = uuid.uuid4()
|
||||
doc_id2 = uuid.uuid4()
|
||||
|
||||
# Mock search results
|
||||
search_results = [
|
||||
{"document_id": str(doc_id1), "content": "Result 1"},
|
||||
{"document_id": str(doc_id2), "content": "Result 2"},
|
||||
]
|
||||
mock_specific_search.return_value = search_results
|
||||
|
||||
# Execute
|
||||
await search(query_text, query_type, datasets, mock_user)
|
||||
|
||||
# Verify
|
||||
mock_log_query.assert_called_once_with(query_text, query_type.value, mock_user.id)
|
||||
mock_specific_search.assert_called_once_with(
|
||||
query_type=query_type,
|
||||
query_text=query_text,
|
||||
user=mock_user,
|
||||
system_prompt_path="answer_simple_question.txt",
|
||||
system_prompt=None,
|
||||
top_k=10,
|
||||
node_type=NodeSet,
|
||||
node_name=None,
|
||||
save_interaction=False,
|
||||
last_k=None,
|
||||
only_context=False,
|
||||
)
|
||||
|
||||
# Verify result logging
|
||||
mock_log_result.assert_called_once()
|
||||
# Check that the first argument is the query ID
|
||||
assert mock_log_result.call_args[0][0] == mock_query.id
|
||||
# The second argument should be the JSON string of the filtered results
|
||||
# We can't directly compare the JSON strings due to potential ordering differences
|
||||
# So we parse the JSON and compare the objects
|
||||
logged_results = json.loads(mock_log_result.call_args[0][1])
|
||||
assert len(logged_results) == 2
|
||||
assert logged_results[0]["document_id"] == str(doc_id1)
|
||||
assert logged_results[1]["document_id"] == str(doc_id2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.object(search_module, "SummariesRetriever")
|
||||
@patch.object(search_module, "send_telemetry")
|
||||
async def test_specific_search_summaries(mock_send_telemetry, mock_summaries_retriever, mock_user):
|
||||
# Setup
|
||||
query = "test query"
|
||||
query_type = SearchType.SUMMARIES
|
||||
|
||||
# Mock the retriever
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.get_completion = AsyncMock()
|
||||
mock_retriever.get_completion.return_value = [{"content": "Summary result"}]
|
||||
mock_summaries_retriever.return_value = mock_retriever
|
||||
|
||||
# Execute
|
||||
results = await specific_search(query_type, query, mock_user)
|
||||
|
||||
# Verify
|
||||
mock_summaries_retriever.assert_called_once()
|
||||
mock_retriever.get_completion.assert_called_once_with(query)
|
||||
mock_send_telemetry.assert_called()
|
||||
assert len(results) == 1
|
||||
assert results[0]["content"] == "Summary result"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.object(search_module, "InsightsRetriever")
|
||||
@patch.object(search_module, "send_telemetry")
|
||||
async def test_specific_search_insights(mock_send_telemetry, mock_insights_retriever, mock_user):
|
||||
# Setup
|
||||
query = "test query"
|
||||
query_type = SearchType.INSIGHTS
|
||||
|
||||
# Mock the retriever
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.get_completion = AsyncMock()
|
||||
mock_retriever.get_completion.return_value = [{"content": "Insight result"}]
|
||||
mock_insights_retriever.return_value = mock_retriever
|
||||
|
||||
# Execute
|
||||
results = await specific_search(query_type, query, mock_user)
|
||||
|
||||
# Verify
|
||||
mock_insights_retriever.assert_called_once()
|
||||
mock_retriever.get_completion.assert_called_once_with(query)
|
||||
mock_send_telemetry.assert_called()
|
||||
assert len(results) == 1
|
||||
assert results[0]["content"] == "Insight result"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.object(search_module, "ChunksRetriever")
|
||||
@patch.object(search_module, "send_telemetry")
|
||||
async def test_specific_search_chunks(mock_send_telemetry, mock_chunks_retriever, mock_user):
|
||||
# Setup
|
||||
query = "test query"
|
||||
query_type = SearchType.CHUNKS
|
||||
|
||||
# Mock the retriever
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.get_completion = AsyncMock()
|
||||
mock_retriever.get_completion.return_value = [{"content": "Chunk result"}]
|
||||
mock_chunks_retriever.return_value = mock_retriever
|
||||
|
||||
# Execute
|
||||
results = await specific_search(query_type, query, mock_user)
|
||||
|
||||
# Verify
|
||||
mock_chunks_retriever.assert_called_once()
|
||||
mock_retriever.get_completion.assert_called_once_with(query)
|
||||
mock_send_telemetry.assert_called()
|
||||
assert len(results) == 1
|
||||
assert results[0]["content"] == "Chunk result"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"selected_type, retriever_name, expected_content, top_k",
|
||||
[
|
||||
(SearchType.RAG_COMPLETION, "CompletionRetriever", "RAG result from lucky search", 10),
|
||||
(SearchType.CHUNKS, "ChunksRetriever", "Chunk result from lucky search", 5),
|
||||
(SearchType.SUMMARIES, "SummariesRetriever", "Summary from lucky search", 15),
|
||||
(SearchType.INSIGHTS, "InsightsRetriever", "Insight result from lucky search", 20),
|
||||
],
|
||||
)
|
||||
@patch.object(search_module, "select_search_type")
|
||||
@patch.object(search_module, "send_telemetry")
|
||||
async def test_specific_search_feeling_lucky(
|
||||
mock_send_telemetry,
|
||||
mock_select_search_type,
|
||||
selected_type,
|
||||
retriever_name,
|
||||
expected_content,
|
||||
top_k,
|
||||
mock_user,
|
||||
):
|
||||
with patch.object(search_module, retriever_name) as mock_retriever_class:
|
||||
# Setup
|
||||
query = f"test query for {retriever_name}"
|
||||
query_type = SearchType.FEELING_LUCKY
|
||||
|
||||
# Mock the intelligent search type selection
|
||||
mock_select_search_type.return_value = selected_type
|
||||
|
||||
# Mock the retriever
|
||||
mock_retriever_instance = MagicMock()
|
||||
mock_retriever_instance.get_completion = AsyncMock(
|
||||
return_value=[{"content": expected_content}]
|
||||
)
|
||||
mock_retriever_class.return_value = mock_retriever_instance
|
||||
|
||||
# Execute
|
||||
results = await specific_search(query_type, query, mock_user, top_k=top_k)
|
||||
|
||||
# Verify
|
||||
mock_select_search_type.assert_called_once_with(query)
|
||||
|
||||
if retriever_name == "CompletionRetriever":
|
||||
mock_retriever_class.assert_called_once_with(
|
||||
system_prompt_path="answer_simple_question.txt",
|
||||
top_k=top_k,
|
||||
system_prompt=None,
|
||||
only_context=None,
|
||||
)
|
||||
else:
|
||||
mock_retriever_class.assert_called_once_with(top_k=top_k)
|
||||
|
||||
mock_retriever_instance.get_completion.assert_called_once_with(query)
|
||||
mock_send_telemetry.assert_called()
|
||||
assert len(results) == 1
|
||||
assert results[0]["content"] == expected_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_specific_search_invalid_type(mock_user):
|
||||
# Setup
|
||||
query = "test query"
|
||||
query_type = "INVALID_TYPE" # Not a valid SearchType
|
||||
|
||||
# Execute and verify
|
||||
with pytest.raises(UnsupportedSearchTypeError) as excinfo:
|
||||
await specific_search(query_type, query, mock_user)
|
||||
|
||||
assert "Unsupported search type" in str(excinfo.value)
|
||||
|
|
@ -47,6 +47,7 @@ async def main():
|
|||
query = "When was Kamala Harris in office?"
|
||||
triplets = await brute_force_triplet_search(
|
||||
query=query,
|
||||
user=user,
|
||||
top_k=3,
|
||||
collections=["graphitinode_content", "graphitinode_name", "graphitinode_summary"],
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue