Merge branch 'dev' into configurable-path-exclusion-code-graph
This commit is contained in:
commit
4ee807579b
12 changed files with 167 additions and 58 deletions
|
|
@ -38,7 +38,7 @@ class CognifyPayloadDTO(InDTO):
|
|||
dataset_ids: Optional[List[UUID]] = Field(default=None, examples=[[]])
|
||||
run_in_background: Optional[bool] = Field(default=False)
|
||||
custom_prompt: Optional[str] = Field(
|
||||
default=None, description="Custom prompt for entity extraction and graph generation"
|
||||
default="", description="Custom prompt for entity extraction and graph generation"
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
from uuid import UUID
|
||||
import pathlib
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from pydantic import Field
|
||||
from fastapi import Depends, APIRouter
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.api.DTO import InDTO, OutDTO
|
||||
from cognee.modules.users.exceptions.exceptions import PermissionDeniedError
|
||||
|
|
@ -20,8 +22,12 @@ class SearchPayloadDTO(InDTO):
|
|||
datasets: Optional[list[str]] = Field(default=None)
|
||||
dataset_ids: Optional[list[UUID]] = Field(default=None, examples=[[]])
|
||||
query: str = Field(default="What is in the document?")
|
||||
system_prompt: Optional[str] = Field(
|
||||
default="Answer the question using the provided context. Be as brief as possible."
|
||||
)
|
||||
node_name: Optional[list[str]] = Field(default=None, example=[])
|
||||
top_k: Optional[int] = Field(default=10)
|
||||
only_context: bool = Field(default=False)
|
||||
|
||||
|
||||
def get_search_router() -> APIRouter:
|
||||
|
|
@ -80,8 +86,10 @@ def get_search_router() -> APIRouter:
|
|||
- **datasets** (Optional[List[str]]): List of dataset names to search within
|
||||
- **dataset_ids** (Optional[List[UUID]]): List of dataset UUIDs to search within
|
||||
- **query** (str): The search query string
|
||||
- **system_prompt** (Optional[str]): System prompt to use for Completion type searches in Cognee
|
||||
- **node_name** (Optional[list[str]]): Filter results to specific node_sets defined in the add pipeline (for targeted search)
|
||||
- **top_k** (Optional[int]): Maximum number of results to return (default: 10)
|
||||
- **only_context** (bool): Set to true to return only the context Cognee would send to the LLM for Completion type searches; the context is returned instead of the LLM-generated answer.
|
||||
|
||||
## Response
|
||||
Returns a list of search results containing relevant nodes from the graph.
|
||||
|
|
@ -104,8 +112,10 @@ def get_search_router() -> APIRouter:
|
|||
"datasets": payload.datasets,
|
||||
"dataset_ids": [str(dataset_id) for dataset_id in payload.dataset_ids or []],
|
||||
"query": payload.query,
|
||||
"system_prompt": payload.system_prompt,
|
||||
"node_name": payload.node_name,
|
||||
"top_k": payload.top_k,
|
||||
"only_context": payload.only_context,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -118,8 +128,10 @@ def get_search_router() -> APIRouter:
|
|||
user=user,
|
||||
datasets=payload.datasets,
|
||||
dataset_ids=payload.dataset_ids,
|
||||
system_prompt=payload.system_prompt,
|
||||
node_name=payload.node_name,
|
||||
top_k=payload.top_k,
|
||||
only_context=payload.only_context,
|
||||
)
|
||||
|
||||
return results
|
||||
|
|
|
|||
|
|
@ -17,11 +17,13 @@ async def search(
|
|||
datasets: Optional[Union[list[str], str]] = None,
|
||||
dataset_ids: Optional[Union[list[UUID], UUID]] = None,
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
system_prompt: Optional[str] = None,
|
||||
top_k: int = 10,
|
||||
node_type: Optional[Type] = NodeSet,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: bool = False,
|
||||
last_k: Optional[int] = None,
|
||||
only_context: bool = False,
|
||||
) -> list:
|
||||
"""
|
||||
Search and query the knowledge graph for insights, information, and connections.
|
||||
|
|
@ -184,11 +186,13 @@ async def search(
|
|||
dataset_ids=dataset_ids if dataset_ids else datasets,
|
||||
user=user,
|
||||
system_prompt_path=system_prompt_path,
|
||||
system_prompt=system_prompt,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
only_context=only_context,
|
||||
)
|
||||
|
||||
return filtered_search_results
|
||||
|
|
|
|||
|
|
@ -23,12 +23,16 @@ class CompletionRetriever(BaseRetriever):
|
|||
self,
|
||||
user_prompt_path: str = "context_for_question.txt",
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
system_prompt: str = None,
|
||||
top_k: Optional[int] = 1,
|
||||
only_context: bool = False,
|
||||
):
|
||||
"""Initialize retriever with optional custom prompt paths."""
|
||||
self.user_prompt_path = user_prompt_path
|
||||
self.system_prompt_path = system_prompt_path
|
||||
self.top_k = top_k if top_k is not None else 1
|
||||
self.system_prompt = system_prompt
|
||||
self.only_context = only_context
|
||||
|
||||
async def get_context(self, query: str) -> str:
|
||||
"""
|
||||
|
|
@ -88,6 +92,11 @@ class CompletionRetriever(BaseRetriever):
|
|||
context = await self.get_context(query)
|
||||
|
||||
completion = await generate_completion(
|
||||
query, context, self.user_prompt_path, self.system_prompt_path
|
||||
query=query,
|
||||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
only_context=self.only_context,
|
||||
)
|
||||
return [completion]
|
||||
|
|
|
|||
|
|
@ -26,10 +26,12 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
self,
|
||||
user_prompt_path: str = "graph_context_for_question.txt",
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
system_prompt: Optional[str] = None,
|
||||
top_k: Optional[int] = 5,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: bool = False,
|
||||
only_context: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
user_prompt_path=user_prompt_path,
|
||||
|
|
@ -38,10 +40,15 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
only_context=only_context,
|
||||
)
|
||||
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[Any] = None, context_extension_rounds=4
|
||||
self,
|
||||
query: str,
|
||||
context: Optional[Any] = None,
|
||||
context_extension_rounds=4,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Extends the context for a given query by retrieving related triplets and generating new
|
||||
|
|
@ -86,6 +93,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
)
|
||||
|
||||
triplets += await self.get_triplets(completion)
|
||||
|
|
@ -112,6 +120,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
only_context=self.only_context,
|
||||
)
|
||||
|
||||
if self.save_interaction and context and triplets and completion:
|
||||
|
|
@ -119,4 +129,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
question=query, answer=completion, context=context, triplets=triplets
|
||||
)
|
||||
|
||||
return [completion]
|
||||
if self.only_context:
|
||||
return [context]
|
||||
else:
|
||||
return [completion]
|
||||
|
|
|
|||
|
|
@ -32,14 +32,18 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
validation_system_prompt_path: str = "cot_validation_system_prompt.txt",
|
||||
followup_system_prompt_path: str = "cot_followup_system_prompt.txt",
|
||||
followup_user_prompt_path: str = "cot_followup_user_prompt.txt",
|
||||
system_prompt: str = None,
|
||||
top_k: Optional[int] = 5,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: bool = False,
|
||||
only_context: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
user_prompt_path=user_prompt_path,
|
||||
system_prompt_path=system_prompt_path,
|
||||
system_prompt=system_prompt,
|
||||
only_context=only_context,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
|
|
@ -51,7 +55,10 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
self.followup_user_prompt_path = followup_user_prompt_path
|
||||
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[Any] = None, max_iter=4
|
||||
self,
|
||||
query: str,
|
||||
context: Optional[Any] = None,
|
||||
max_iter=4,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Generate completion responses based on a user query and contextual information.
|
||||
|
|
@ -92,6 +99,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
)
|
||||
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
|
||||
if round_idx < max_iter:
|
||||
|
|
@ -128,4 +136,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
question=query, answer=completion, context=context, triplets=triplets
|
||||
)
|
||||
|
||||
return [completion]
|
||||
if self.only_context:
|
||||
return [context]
|
||||
else:
|
||||
return [completion]
|
||||
|
|
|
|||
|
|
@ -36,15 +36,19 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
self,
|
||||
user_prompt_path: str = "graph_context_for_question.txt",
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
system_prompt: str = None,
|
||||
top_k: Optional[int] = 5,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: bool = False,
|
||||
only_context: bool = False,
|
||||
):
|
||||
"""Initialize retriever with prompt paths and search parameters."""
|
||||
self.save_interaction = save_interaction
|
||||
self.user_prompt_path = user_prompt_path
|
||||
self.system_prompt_path = system_prompt_path
|
||||
self.system_prompt = system_prompt
|
||||
self.only_context = only_context
|
||||
self.top_k = top_k if top_k is not None else 5
|
||||
self.node_type = node_type
|
||||
self.node_name = node_name
|
||||
|
|
@ -151,7 +155,11 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
|
||||
return context, triplets
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||
async def get_completion(
|
||||
self,
|
||||
query: str,
|
||||
context: Optional[Any] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Generates a completion using graph connections context based on a query.
|
||||
|
||||
|
|
@ -177,6 +185,8 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
context=context,
|
||||
user_prompt_path=self.user_prompt_path,
|
||||
system_prompt_path=self.system_prompt_path,
|
||||
system_prompt=self.system_prompt,
|
||||
only_context=self.only_context,
|
||||
)
|
||||
|
||||
if self.save_interaction and context and triplets and completion:
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
|
|||
user_prompt_path: str = "graph_context_for_question.txt",
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
summarize_prompt_path: str = "summarize_search_results.txt",
|
||||
system_prompt: Optional[str] = None,
|
||||
top_k: Optional[int] = 5,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
|
|
@ -34,6 +35,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
|
|||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
)
|
||||
self.summarize_prompt_path = summarize_prompt_path
|
||||
|
||||
|
|
@ -57,4 +59,4 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
|
|||
- str: A summary string representing the content of the retrieved edges.
|
||||
"""
|
||||
direct_text = await super().resolve_edges_to_text(retrieved_edges)
|
||||
return await summarize_text(direct_text, self.summarize_prompt_path)
|
||||
return await summarize_text(direct_text, self.summarize_prompt_path, self.system_prompt)
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ class SummariesRetriever(BaseRetriever):
|
|||
logger.info(f"Returning {len(summary_payloads)} summary payloads")
|
||||
return summary_payloads
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None, **kwargs) -> Any:
|
||||
"""
|
||||
Generates a completion using summaries context.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from typing import Optional
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
|
||||
|
||||
|
|
@ -6,25 +7,35 @@ async def generate_completion(
|
|||
context: str,
|
||||
user_prompt_path: str,
|
||||
system_prompt_path: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
only_context: bool = False,
|
||||
) -> str:
|
||||
"""Generates a completion using LLM with given context and prompts."""
|
||||
args = {"question": query, "context": context}
|
||||
user_prompt = LLMGateway.render_prompt(user_prompt_path, args)
|
||||
system_prompt = LLMGateway.read_query_prompt(system_prompt_path)
|
||||
|
||||
return await LLMGateway.acreate_structured_output(
|
||||
text_input=user_prompt,
|
||||
system_prompt=system_prompt,
|
||||
response_model=str,
|
||||
system_prompt = (
|
||||
system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path)
|
||||
)
|
||||
|
||||
if only_context:
|
||||
return context
|
||||
else:
|
||||
return await LLMGateway.acreate_structured_output(
|
||||
text_input=user_prompt,
|
||||
system_prompt=system_prompt,
|
||||
response_model=str,
|
||||
)
|
||||
|
||||
|
||||
async def summarize_text(
|
||||
text: str,
|
||||
prompt_path: str = "summarize_search_results.txt",
|
||||
system_prompt_path: str = "summarize_search_results.txt",
|
||||
system_prompt: str = None,
|
||||
) -> str:
|
||||
"""Summarizes text using LLM with the specified prompt."""
|
||||
system_prompt = LLMGateway.read_query_prompt(prompt_path)
|
||||
system_prompt = (
|
||||
system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path)
|
||||
)
|
||||
|
||||
return await LLMGateway.acreate_structured_output(
|
||||
text_input=text,
|
||||
|
|
|
|||
|
|
@ -38,11 +38,13 @@ async def search(
|
|||
dataset_ids: Union[list[UUID], None],
|
||||
user: User,
|
||||
system_prompt_path="answer_simple_question.txt",
|
||||
system_prompt: Optional[str] = None,
|
||||
top_k: int = 10,
|
||||
node_type: Optional[Type] = NodeSet,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: Optional[bool] = False,
|
||||
last_k: Optional[int] = None,
|
||||
only_context: bool = False,
|
||||
):
|
||||
"""
|
||||
|
||||
|
|
@ -62,30 +64,34 @@ async def search(
|
|||
# Use search function filtered by permissions if access control is enabled
|
||||
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
|
||||
return await authorized_search(
|
||||
query_text=query_text,
|
||||
query_type=query_type,
|
||||
query_text=query_text,
|
||||
user=user,
|
||||
dataset_ids=dataset_ids,
|
||||
system_prompt_path=system_prompt_path,
|
||||
system_prompt=system_prompt,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
only_context=only_context,
|
||||
)
|
||||
|
||||
query = await log_query(query_text, query_type.value, user.id)
|
||||
|
||||
search_results = await specific_search(
|
||||
query_type,
|
||||
query_text,
|
||||
user,
|
||||
query_type=query_type,
|
||||
query_text=query_text,
|
||||
user=user,
|
||||
system_prompt_path=system_prompt_path,
|
||||
system_prompt=system_prompt,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
only_context=only_context,
|
||||
)
|
||||
|
||||
await log_result(
|
||||
|
|
@ -101,21 +107,26 @@ async def search(
|
|||
|
||||
async def specific_search(
|
||||
query_type: SearchType,
|
||||
query: str,
|
||||
query_text: str,
|
||||
user: User,
|
||||
system_prompt_path="answer_simple_question.txt",
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
system_prompt: Optional[str] = None,
|
||||
top_k: int = 10,
|
||||
node_type: Optional[Type] = NodeSet,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: Optional[bool] = False,
|
||||
last_k: Optional[int] = None,
|
||||
only_context: bool = None,
|
||||
) -> list:
|
||||
search_tasks: dict[SearchType, Callable] = {
|
||||
SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion,
|
||||
SearchType.INSIGHTS: InsightsRetriever(top_k=top_k).get_completion,
|
||||
SearchType.CHUNKS: ChunksRetriever(top_k=top_k).get_completion,
|
||||
SearchType.RAG_COMPLETION: CompletionRetriever(
|
||||
system_prompt_path=system_prompt_path, top_k=top_k
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
system_prompt=system_prompt,
|
||||
only_context=only_context,
|
||||
).get_completion,
|
||||
SearchType.GRAPH_COMPLETION: GraphCompletionRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
|
|
@ -123,6 +134,8 @@ async def specific_search(
|
|||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
only_context=only_context,
|
||||
).get_completion,
|
||||
SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
|
|
@ -130,6 +143,8 @@ async def specific_search(
|
|||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
only_context=only_context,
|
||||
).get_completion,
|
||||
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
|
|
@ -137,6 +152,8 @@ async def specific_search(
|
|||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
only_context=only_context,
|
||||
).get_completion,
|
||||
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
|
|
@ -144,6 +161,7 @@ async def specific_search(
|
|||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
).get_completion,
|
||||
SearchType.CODE: CodeRetriever(top_k=top_k).get_completion,
|
||||
SearchType.CYPHER: CypherSearchRetriever().get_completion,
|
||||
|
|
@ -153,7 +171,7 @@ async def specific_search(
|
|||
|
||||
# If the query type is FEELING_LUCKY, select the search type intelligently
|
||||
if query_type is SearchType.FEELING_LUCKY:
|
||||
query_type = await select_search_type(query)
|
||||
query_type = await select_search_type(query_text)
|
||||
|
||||
search_task = search_tasks.get(query_type)
|
||||
|
||||
|
|
@ -162,7 +180,7 @@ async def specific_search(
|
|||
|
||||
send_telemetry("cognee.search EXECUTION STARTED", user.id)
|
||||
|
||||
results = await search_task(query)
|
||||
results = await search_task(query_text)
|
||||
|
||||
send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
|
||||
|
||||
|
|
@ -170,16 +188,18 @@ async def specific_search(
|
|||
|
||||
|
||||
async def authorized_search(
|
||||
query_text: str,
|
||||
query_type: SearchType,
|
||||
user: User = None,
|
||||
query_text: str,
|
||||
user: User,
|
||||
dataset_ids: Optional[list[UUID]] = None,
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
system_prompt: Optional[str] = None,
|
||||
top_k: int = 10,
|
||||
node_type: Optional[Type] = NodeSet,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: bool = False,
|
||||
save_interaction: Optional[bool] = False,
|
||||
last_k: Optional[int] = None,
|
||||
only_context: bool = None,
|
||||
) -> list:
|
||||
"""
|
||||
Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset.
|
||||
|
|
@ -193,16 +213,18 @@ async def authorized_search(
|
|||
|
||||
# Searches all provided datasets and handles setting up of appropriate database context based on permissions
|
||||
search_results = await specific_search_by_context(
|
||||
search_datasets,
|
||||
query_text,
|
||||
query_type,
|
||||
user,
|
||||
system_prompt_path,
|
||||
top_k,
|
||||
search_datasets=search_datasets,
|
||||
query_type=query_type,
|
||||
query_text=query_text,
|
||||
user=user,
|
||||
system_prompt_path=system_prompt_path,
|
||||
system_prompt=system_prompt,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
only_context=only_context,
|
||||
)
|
||||
|
||||
await log_result(query.id, json.dumps(search_results, cls=JSONEncoder), user.id)
|
||||
|
|
@ -212,15 +234,17 @@ async def authorized_search(
|
|||
|
||||
async def specific_search_by_context(
|
||||
search_datasets: list[Dataset],
|
||||
query_text: str,
|
||||
query_type: SearchType,
|
||||
query_text: str,
|
||||
user: User,
|
||||
system_prompt_path: str,
|
||||
top_k: int,
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
system_prompt: Optional[str] = None,
|
||||
top_k: int = 10,
|
||||
node_type: Optional[Type] = NodeSet,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: bool = False,
|
||||
save_interaction: Optional[bool] = False,
|
||||
last_k: Optional[int] = None,
|
||||
only_context: bool = None,
|
||||
):
|
||||
"""
|
||||
Searches all provided datasets and handles setting up of appropriate database context based on permissions.
|
||||
|
|
@ -228,28 +252,33 @@ async def specific_search_by_context(
|
|||
"""
|
||||
|
||||
async def _search_by_context(
|
||||
dataset,
|
||||
user,
|
||||
query_type,
|
||||
query_text,
|
||||
system_prompt_path,
|
||||
top_k,
|
||||
dataset: Dataset,
|
||||
query_type: SearchType,
|
||||
query_text: str,
|
||||
user: User,
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
system_prompt: Optional[str] = None,
|
||||
top_k: int = 10,
|
||||
node_type: Optional[Type] = NodeSet,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: Optional[bool] = False,
|
||||
last_k: Optional[int] = None,
|
||||
only_context: bool = None,
|
||||
):
|
||||
# Set database configuration in async context for each dataset user has access for
|
||||
await set_database_global_context_variables(dataset.id, dataset.owner_id)
|
||||
search_results = await specific_search(
|
||||
query_type,
|
||||
query_text,
|
||||
user,
|
||||
query_type=query_type,
|
||||
query_text=query_text,
|
||||
user=user,
|
||||
system_prompt_path=system_prompt_path,
|
||||
system_prompt=system_prompt,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
only_context=only_context,
|
||||
)
|
||||
return {
|
||||
"search_result": search_results,
|
||||
|
|
@ -262,15 +291,18 @@ async def specific_search_by_context(
|
|||
for dataset in search_datasets:
|
||||
tasks.append(
|
||||
_search_by_context(
|
||||
dataset,
|
||||
user,
|
||||
query_type,
|
||||
query_text,
|
||||
system_prompt_path,
|
||||
top_k,
|
||||
dataset=dataset,
|
||||
query_type=query_type,
|
||||
query_text=query_text,
|
||||
user=user,
|
||||
system_prompt_path=system_prompt_path,
|
||||
system_prompt=system_prompt,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
only_context=only_context,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -58,15 +58,17 @@ async def test_search(
|
|||
# Verify
|
||||
mock_log_query.assert_called_once_with(query_text, query_type.value, mock_user.id)
|
||||
mock_specific_search.assert_called_once_with(
|
||||
query_type,
|
||||
query_text,
|
||||
mock_user,
|
||||
query_type=query_type,
|
||||
query_text=query_text,
|
||||
user=mock_user,
|
||||
system_prompt_path="answer_simple_question.txt",
|
||||
system_prompt=None,
|
||||
top_k=10,
|
||||
node_type=NodeSet,
|
||||
node_name=None,
|
||||
save_interaction=False,
|
||||
last_k=None,
|
||||
only_context=False,
|
||||
)
|
||||
|
||||
# Verify result logging
|
||||
|
|
@ -201,7 +203,10 @@ async def test_specific_search_feeling_lucky(
|
|||
|
||||
if retriever_name == "CompletionRetriever":
|
||||
mock_retriever_class.assert_called_once_with(
|
||||
system_prompt_path="answer_simple_question.txt", top_k=top_k
|
||||
system_prompt_path="answer_simple_question.txt",
|
||||
top_k=top_k,
|
||||
system_prompt=None,
|
||||
only_context=None,
|
||||
)
|
||||
else:
|
||||
mock_retriever_class.assert_called_once_with(top_k=top_k)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue