Merge branch 'dev' into feature/cog-2746-time-graph-to-cognify
This commit is contained in:
commit
4e9c0810c2
17 changed files with 274 additions and 115 deletions
|
|
@ -25,6 +25,7 @@ def get_add_router() -> APIRouter:
|
||||||
data: List[UploadFile] = File(default=None),
|
data: List[UploadFile] = File(default=None),
|
||||||
datasetName: Optional[str] = Form(default=None),
|
datasetName: Optional[str] = Form(default=None),
|
||||||
datasetId: Union[UUID, Literal[""], None] = Form(default=None, examples=[""]),
|
datasetId: Union[UUID, Literal[""], None] = Form(default=None, examples=[""]),
|
||||||
|
node_set: Optional[List[str]] = Form(default=[""], example=[""]),
|
||||||
user: User = Depends(get_authenticated_user),
|
user: User = Depends(get_authenticated_user),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
@ -41,6 +42,8 @@ def get_add_router() -> APIRouter:
|
||||||
- Regular file uploads
|
- Regular file uploads
|
||||||
- **datasetName** (Optional[str]): Name of the dataset to add data to
|
- **datasetName** (Optional[str]): Name of the dataset to add data to
|
||||||
- **datasetId** (Optional[UUID]): UUID of an already existing dataset
|
- **datasetId** (Optional[UUID]): UUID of an already existing dataset
|
||||||
|
- **node_set** (Optional[list[str]]): List of node identifiers for graph organization and access control.
|
||||||
|
Used for grouping related data points in the knowledge graph.
|
||||||
|
|
||||||
Either datasetName or datasetId must be provided.
|
Either datasetName or datasetId must be provided.
|
||||||
|
|
||||||
|
|
@ -65,9 +68,7 @@ def get_add_router() -> APIRouter:
|
||||||
send_telemetry(
|
send_telemetry(
|
||||||
"Add API Endpoint Invoked",
|
"Add API Endpoint Invoked",
|
||||||
user.id,
|
user.id,
|
||||||
additional_properties={
|
additional_properties={"endpoint": "POST /v1/add", "node_set": node_set},
|
||||||
"endpoint": "POST /v1/add",
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from cognee.api.v1.add import add as cognee_add
|
from cognee.api.v1.add import add as cognee_add
|
||||||
|
|
@ -76,34 +77,13 @@ def get_add_router() -> APIRouter:
|
||||||
raise ValueError("Either datasetId or datasetName must be provided.")
|
raise ValueError("Either datasetId or datasetName must be provided.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if (
|
add_run = await cognee_add(
|
||||||
isinstance(data, str)
|
data, datasetName, user=user, dataset_id=datasetId, node_set=node_set
|
||||||
and data.startswith("http")
|
)
|
||||||
and (os.getenv("ALLOW_HTTP_REQUESTS", "true").lower() == "true")
|
|
||||||
):
|
|
||||||
if "github" in data:
|
|
||||||
# Perform git clone if the URL is from GitHub
|
|
||||||
repo_name = data.split("/")[-1].replace(".git", "")
|
|
||||||
subprocess.run(["git", "clone", data, f".data/{repo_name}"], check=True)
|
|
||||||
# TODO: Update add call with dataset info
|
|
||||||
await cognee_add(
|
|
||||||
"data://.data/",
|
|
||||||
f"{repo_name}",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Fetch and store the data from other types of URL using curl
|
|
||||||
response = requests.get(data)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
file_data = await response.content()
|
if isinstance(add_run, PipelineRunErrored):
|
||||||
# TODO: Update add call with dataset info
|
return JSONResponse(status_code=420, content=add_run.model_dump(mode="json"))
|
||||||
return await cognee_add(file_data)
|
return add_run.model_dump()
|
||||||
else:
|
|
||||||
add_run = await cognee_add(data, datasetName, user=user, dataset_id=datasetId)
|
|
||||||
|
|
||||||
if isinstance(add_run, PipelineRunErrored):
|
|
||||||
return JSONResponse(status_code=420, content=add_run.model_dump(mode="json"))
|
|
||||||
return add_run.model_dump()
|
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
return JSONResponse(status_code=409, content={"error": str(error)})
|
return JSONResponse(status_code=409, content={"error": str(error)})
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from typing import Optional
|
||||||
from cognee.shared.logging_utils import get_logger, setup_logging
|
from cognee.shared.logging_utils import get_logger, setup_logging
|
||||||
from cognee.modules.observability.get_observe import get_observe
|
from cognee.modules.observability.get_observe import get_observe
|
||||||
|
|
||||||
|
|
@ -28,7 +29,12 @@ logger = get_logger("code_graph_pipeline")
|
||||||
|
|
||||||
|
|
||||||
@observe
|
@observe
|
||||||
async def run_code_graph_pipeline(repo_path, include_docs=False):
|
async def run_code_graph_pipeline(
|
||||||
|
repo_path,
|
||||||
|
include_docs=False,
|
||||||
|
excluded_paths: Optional[list[str]] = None,
|
||||||
|
supported_languages: Optional[list[str]] = None,
|
||||||
|
):
|
||||||
import cognee
|
import cognee
|
||||||
from cognee.low_level import setup
|
from cognee.low_level import setup
|
||||||
|
|
||||||
|
|
@ -40,13 +46,12 @@ async def run_code_graph_pipeline(repo_path, include_docs=False):
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
detailed_extraction = True
|
detailed_extraction = True
|
||||||
|
|
||||||
# Multi-language support: allow passing supported_languages
|
|
||||||
supported_languages = None # defer to task defaults
|
|
||||||
tasks = [
|
tasks = [
|
||||||
Task(
|
Task(
|
||||||
get_repo_file_dependencies,
|
get_repo_file_dependencies,
|
||||||
detailed_extraction=detailed_extraction,
|
detailed_extraction=detailed_extraction,
|
||||||
supported_languages=supported_languages,
|
supported_languages=supported_languages,
|
||||||
|
excluded_paths=excluded_paths,
|
||||||
),
|
),
|
||||||
# Task(summarize_code, task_config={"batch_size": 500}), # This task takes a long time to complete
|
# Task(summarize_code, task_config={"batch_size": 500}), # This task takes a long time to complete
|
||||||
Task(add_data_points, task_config={"batch_size": 30}),
|
Task(add_data_points, task_config={"batch_size": 30}),
|
||||||
|
|
@ -95,7 +100,7 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
async for run_status in run_code_graph_pipeline("REPO_PATH"):
|
async for run_status in run_code_graph_pipeline("REPO_PATH"):
|
||||||
print(f"{run_status.pipeline_name}: {run_status.status}")
|
print(f"{run_status.pipeline_run_id}: {run_status.status}")
|
||||||
|
|
||||||
file_path = os.path.join(
|
file_path = os.path.join(
|
||||||
pathlib.Path(__file__).parent, ".artifacts", "graph_visualization.html"
|
pathlib.Path(__file__).parent, ".artifacts", "graph_visualization.html"
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ class CognifyPayloadDTO(InDTO):
|
||||||
dataset_ids: Optional[List[UUID]] = Field(default=None, examples=[[]])
|
dataset_ids: Optional[List[UUID]] = Field(default=None, examples=[[]])
|
||||||
run_in_background: Optional[bool] = Field(default=False)
|
run_in_background: Optional[bool] = Field(default=False)
|
||||||
custom_prompt: Optional[str] = Field(
|
custom_prompt: Optional[str] = Field(
|
||||||
default=None, description="Custom prompt for entity extraction and graph generation"
|
default="", description="Custom prompt for entity extraction and graph generation"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,11 @@
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
import pathlib
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from fastapi import Depends, APIRouter
|
from fastapi import Depends, APIRouter
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
from cognee.modules.search.types import SearchType
|
from cognee.modules.search.types import SearchType
|
||||||
from cognee.api.DTO import InDTO, OutDTO
|
from cognee.api.DTO import InDTO, OutDTO
|
||||||
from cognee.modules.users.exceptions.exceptions import PermissionDeniedError
|
from cognee.modules.users.exceptions.exceptions import PermissionDeniedError
|
||||||
|
|
@ -20,7 +22,12 @@ class SearchPayloadDTO(InDTO):
|
||||||
datasets: Optional[list[str]] = Field(default=None)
|
datasets: Optional[list[str]] = Field(default=None)
|
||||||
dataset_ids: Optional[list[UUID]] = Field(default=None, examples=[[]])
|
dataset_ids: Optional[list[UUID]] = Field(default=None, examples=[[]])
|
||||||
query: str = Field(default="What is in the document?")
|
query: str = Field(default="What is in the document?")
|
||||||
|
system_prompt: Optional[str] = Field(
|
||||||
|
default="Answer the question using the provided context. Be as brief as possible."
|
||||||
|
)
|
||||||
|
node_name: Optional[list[str]] = Field(default=None, example=[])
|
||||||
top_k: Optional[int] = Field(default=10)
|
top_k: Optional[int] = Field(default=10)
|
||||||
|
only_context: bool = Field(default=False)
|
||||||
|
|
||||||
|
|
||||||
def get_search_router() -> APIRouter:
|
def get_search_router() -> APIRouter:
|
||||||
|
|
@ -79,7 +86,10 @@ def get_search_router() -> APIRouter:
|
||||||
- **datasets** (Optional[List[str]]): List of dataset names to search within
|
- **datasets** (Optional[List[str]]): List of dataset names to search within
|
||||||
- **dataset_ids** (Optional[List[UUID]]): List of dataset UUIDs to search within
|
- **dataset_ids** (Optional[List[UUID]]): List of dataset UUIDs to search within
|
||||||
- **query** (str): The search query string
|
- **query** (str): The search query string
|
||||||
|
- **system_prompt** (Optional[str]): System prompt to use for Completion-type searches in Cognee
|
||||||
|
- **node_name** (Optional[list[str]]): Filter results to specific node_sets defined in the add pipeline (for targeted search).
|
||||||
- **top_k** (Optional[int]): Maximum number of results to return (default: 10)
|
- **top_k** (Optional[int]): Maximum number of results to return (default: 10)
|
||||||
|
- **only_context** (bool): Set to true to return only the context that Cognee would send to the LLM for Completion-type searches; the context is returned instead of making the LLM call.
|
||||||
|
|
||||||
## Response
|
## Response
|
||||||
Returns a list of search results containing relevant nodes from the graph.
|
Returns a list of search results containing relevant nodes from the graph.
|
||||||
|
|
@ -102,7 +112,10 @@ def get_search_router() -> APIRouter:
|
||||||
"datasets": payload.datasets,
|
"datasets": payload.datasets,
|
||||||
"dataset_ids": [str(dataset_id) for dataset_id in payload.dataset_ids or []],
|
"dataset_ids": [str(dataset_id) for dataset_id in payload.dataset_ids or []],
|
||||||
"query": payload.query,
|
"query": payload.query,
|
||||||
|
"system_prompt": payload.system_prompt,
|
||||||
|
"node_name": payload.node_name,
|
||||||
"top_k": payload.top_k,
|
"top_k": payload.top_k,
|
||||||
|
"only_context": payload.only_context,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -115,7 +128,10 @@ def get_search_router() -> APIRouter:
|
||||||
user=user,
|
user=user,
|
||||||
datasets=payload.datasets,
|
datasets=payload.datasets,
|
||||||
dataset_ids=payload.dataset_ids,
|
dataset_ids=payload.dataset_ids,
|
||||||
|
system_prompt=payload.system_prompt,
|
||||||
|
node_name=payload.node_name,
|
||||||
top_k=payload.top_k,
|
top_k=payload.top_k,
|
||||||
|
only_context=payload.only_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import Union, Optional, List, Type
|
from typing import Union, Optional, List, Type
|
||||||
|
|
||||||
|
from cognee.modules.engine.models.node_set import NodeSet
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
from cognee.modules.search.types import SearchType
|
from cognee.modules.search.types import SearchType
|
||||||
from cognee.modules.users.methods import get_default_user
|
from cognee.modules.users.methods import get_default_user
|
||||||
|
|
@ -16,11 +17,13 @@ async def search(
|
||||||
datasets: Optional[Union[list[str], str]] = None,
|
datasets: Optional[Union[list[str], str]] = None,
|
||||||
dataset_ids: Optional[Union[list[UUID], UUID]] = None,
|
dataset_ids: Optional[Union[list[UUID], UUID]] = None,
|
||||||
system_prompt_path: str = "answer_simple_question.txt",
|
system_prompt_path: str = "answer_simple_question.txt",
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
node_type: Optional[Type] = None,
|
node_type: Optional[Type] = NodeSet,
|
||||||
node_name: Optional[List[str]] = None,
|
node_name: Optional[List[str]] = None,
|
||||||
save_interaction: bool = False,
|
save_interaction: bool = False,
|
||||||
last_k: Optional[int] = None,
|
last_k: Optional[int] = None,
|
||||||
|
only_context: bool = False,
|
||||||
) -> list:
|
) -> list:
|
||||||
"""
|
"""
|
||||||
Search and query the knowledge graph for insights, information, and connections.
|
Search and query the knowledge graph for insights, information, and connections.
|
||||||
|
|
@ -183,11 +186,13 @@ async def search(
|
||||||
dataset_ids=dataset_ids if dataset_ids else datasets,
|
dataset_ids=dataset_ids if dataset_ids else datasets,
|
||||||
user=user,
|
user=user,
|
||||||
system_prompt_path=system_prompt_path,
|
system_prompt_path=system_prompt_path,
|
||||||
|
system_prompt=system_prompt,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
node_type=node_type,
|
node_type=node_type,
|
||||||
node_name=node_name,
|
node_name=node_name,
|
||||||
save_interaction=save_interaction,
|
save_interaction=save_interaction,
|
||||||
last_k=last_k,
|
last_k=last_k,
|
||||||
|
only_context=only_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
return filtered_search_results
|
return filtered_search_results
|
||||||
|
|
|
||||||
|
|
@ -76,7 +76,7 @@ class CogneeGraph(CogneeAbstractGraph):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# Determine projection strategy
|
# Determine projection strategy
|
||||||
if node_type is not None and node_name is not None:
|
if node_type is not None and node_name not in [None, []]:
|
||||||
nodes_data, edges_data = await adapter.get_nodeset_subgraph(
|
nodes_data, edges_data = await adapter.get_nodeset_subgraph(
|
||||||
node_type=node_type, node_name=node_name
|
node_type=node_type, node_name=node_name
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -94,7 +94,15 @@ class CodeRetriever(BaseRetriever):
|
||||||
{"id": res.id, "score": res.score, "payload": res.payload}
|
{"id": res.id, "score": res.score, "payload": res.payload}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
existing_collection = []
|
||||||
for collection in self.classes_and_functions_collections:
|
for collection in self.classes_and_functions_collections:
|
||||||
|
if await vector_engine.has_collection(collection):
|
||||||
|
existing_collection.append(collection)
|
||||||
|
|
||||||
|
if not existing_collection:
|
||||||
|
raise RuntimeError("No collection found for code retriever")
|
||||||
|
|
||||||
|
for collection in existing_collection:
|
||||||
logger.debug(f"Searching {collection} collection with general query")
|
logger.debug(f"Searching {collection} collection with general query")
|
||||||
search_results_code = await vector_engine.search(
|
search_results_code = await vector_engine.search(
|
||||||
collection, query, limit=self.top_k
|
collection, query, limit=self.top_k
|
||||||
|
|
|
||||||
|
|
@ -23,12 +23,16 @@ class CompletionRetriever(BaseRetriever):
|
||||||
self,
|
self,
|
||||||
user_prompt_path: str = "context_for_question.txt",
|
user_prompt_path: str = "context_for_question.txt",
|
||||||
system_prompt_path: str = "answer_simple_question.txt",
|
system_prompt_path: str = "answer_simple_question.txt",
|
||||||
|
system_prompt: str = None,
|
||||||
top_k: Optional[int] = 1,
|
top_k: Optional[int] = 1,
|
||||||
|
only_context: bool = False,
|
||||||
):
|
):
|
||||||
"""Initialize retriever with optional custom prompt paths."""
|
"""Initialize retriever with optional custom prompt paths."""
|
||||||
self.user_prompt_path = user_prompt_path
|
self.user_prompt_path = user_prompt_path
|
||||||
self.system_prompt_path = system_prompt_path
|
self.system_prompt_path = system_prompt_path
|
||||||
self.top_k = top_k if top_k is not None else 1
|
self.top_k = top_k if top_k is not None else 1
|
||||||
|
self.system_prompt = system_prompt
|
||||||
|
self.only_context = only_context
|
||||||
|
|
||||||
async def get_context(self, query: str) -> str:
|
async def get_context(self, query: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
|
@ -88,6 +92,11 @@ class CompletionRetriever(BaseRetriever):
|
||||||
context = await self.get_context(query)
|
context = await self.get_context(query)
|
||||||
|
|
||||||
completion = await generate_completion(
|
completion = await generate_completion(
|
||||||
query, context, self.user_prompt_path, self.system_prompt_path
|
query=query,
|
||||||
|
context=context,
|
||||||
|
user_prompt_path=self.user_prompt_path,
|
||||||
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
system_prompt=self.system_prompt,
|
||||||
|
only_context=self.only_context,
|
||||||
)
|
)
|
||||||
return [completion]
|
return [completion]
|
||||||
|
|
|
||||||
|
|
@ -26,10 +26,12 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
self,
|
self,
|
||||||
user_prompt_path: str = "graph_context_for_question.txt",
|
user_prompt_path: str = "graph_context_for_question.txt",
|
||||||
system_prompt_path: str = "answer_simple_question.txt",
|
system_prompt_path: str = "answer_simple_question.txt",
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
top_k: Optional[int] = 5,
|
top_k: Optional[int] = 5,
|
||||||
node_type: Optional[Type] = None,
|
node_type: Optional[Type] = None,
|
||||||
node_name: Optional[List[str]] = None,
|
node_name: Optional[List[str]] = None,
|
||||||
save_interaction: bool = False,
|
save_interaction: bool = False,
|
||||||
|
only_context: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
user_prompt_path=user_prompt_path,
|
user_prompt_path=user_prompt_path,
|
||||||
|
|
@ -38,10 +40,15 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
node_type=node_type,
|
node_type=node_type,
|
||||||
node_name=node_name,
|
node_name=node_name,
|
||||||
save_interaction=save_interaction,
|
save_interaction=save_interaction,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
only_context=only_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_completion(
|
async def get_completion(
|
||||||
self, query: str, context: Optional[Any] = None, context_extension_rounds=4
|
self,
|
||||||
|
query: str,
|
||||||
|
context: Optional[Any] = None,
|
||||||
|
context_extension_rounds=4,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Extends the context for a given query by retrieving related triplets and generating new
|
Extends the context for a given query by retrieving related triplets and generating new
|
||||||
|
|
@ -86,6 +93,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
context=context,
|
context=context,
|
||||||
user_prompt_path=self.user_prompt_path,
|
user_prompt_path=self.user_prompt_path,
|
||||||
system_prompt_path=self.system_prompt_path,
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
system_prompt=self.system_prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
triplets += await self.get_triplets(completion)
|
triplets += await self.get_triplets(completion)
|
||||||
|
|
@ -112,6 +120,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
context=context,
|
context=context,
|
||||||
user_prompt_path=self.user_prompt_path,
|
user_prompt_path=self.user_prompt_path,
|
||||||
system_prompt_path=self.system_prompt_path,
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
system_prompt=self.system_prompt,
|
||||||
|
only_context=self.only_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.save_interaction and context and triplets and completion:
|
if self.save_interaction and context and triplets and completion:
|
||||||
|
|
@ -119,4 +129,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
||||||
question=query, answer=completion, context=context, triplets=triplets
|
question=query, answer=completion, context=context, triplets=triplets
|
||||||
)
|
)
|
||||||
|
|
||||||
return [completion]
|
if self.only_context:
|
||||||
|
return [context]
|
||||||
|
else:
|
||||||
|
return [completion]
|
||||||
|
|
|
||||||
|
|
@ -32,14 +32,18 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
validation_system_prompt_path: str = "cot_validation_system_prompt.txt",
|
validation_system_prompt_path: str = "cot_validation_system_prompt.txt",
|
||||||
followup_system_prompt_path: str = "cot_followup_system_prompt.txt",
|
followup_system_prompt_path: str = "cot_followup_system_prompt.txt",
|
||||||
followup_user_prompt_path: str = "cot_followup_user_prompt.txt",
|
followup_user_prompt_path: str = "cot_followup_user_prompt.txt",
|
||||||
|
system_prompt: str = None,
|
||||||
top_k: Optional[int] = 5,
|
top_k: Optional[int] = 5,
|
||||||
node_type: Optional[Type] = None,
|
node_type: Optional[Type] = None,
|
||||||
node_name: Optional[List[str]] = None,
|
node_name: Optional[List[str]] = None,
|
||||||
save_interaction: bool = False,
|
save_interaction: bool = False,
|
||||||
|
only_context: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
user_prompt_path=user_prompt_path,
|
user_prompt_path=user_prompt_path,
|
||||||
system_prompt_path=system_prompt_path,
|
system_prompt_path=system_prompt_path,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
only_context=only_context,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
node_type=node_type,
|
node_type=node_type,
|
||||||
node_name=node_name,
|
node_name=node_name,
|
||||||
|
|
@ -51,7 +55,10 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
self.followup_user_prompt_path = followup_user_prompt_path
|
self.followup_user_prompt_path = followup_user_prompt_path
|
||||||
|
|
||||||
async def get_completion(
|
async def get_completion(
|
||||||
self, query: str, context: Optional[Any] = None, max_iter=4
|
self,
|
||||||
|
query: str,
|
||||||
|
context: Optional[Any] = None,
|
||||||
|
max_iter=4,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Generate completion responses based on a user query and contextual information.
|
Generate completion responses based on a user query and contextual information.
|
||||||
|
|
@ -92,6 +99,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
context=context,
|
context=context,
|
||||||
user_prompt_path=self.user_prompt_path,
|
user_prompt_path=self.user_prompt_path,
|
||||||
system_prompt_path=self.system_prompt_path,
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
system_prompt=self.system_prompt,
|
||||||
)
|
)
|
||||||
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
|
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
|
||||||
if round_idx < max_iter:
|
if round_idx < max_iter:
|
||||||
|
|
@ -128,4 +136,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
question=query, answer=completion, context=context, triplets=triplets
|
question=query, answer=completion, context=context, triplets=triplets
|
||||||
)
|
)
|
||||||
|
|
||||||
return [completion]
|
if self.only_context:
|
||||||
|
return [context]
|
||||||
|
else:
|
||||||
|
return [completion]
|
||||||
|
|
|
||||||
|
|
@ -36,15 +36,19 @@ class GraphCompletionRetriever(BaseRetriever):
|
||||||
self,
|
self,
|
||||||
user_prompt_path: str = "graph_context_for_question.txt",
|
user_prompt_path: str = "graph_context_for_question.txt",
|
||||||
system_prompt_path: str = "answer_simple_question.txt",
|
system_prompt_path: str = "answer_simple_question.txt",
|
||||||
|
system_prompt: str = None,
|
||||||
top_k: Optional[int] = 5,
|
top_k: Optional[int] = 5,
|
||||||
node_type: Optional[Type] = None,
|
node_type: Optional[Type] = None,
|
||||||
node_name: Optional[List[str]] = None,
|
node_name: Optional[List[str]] = None,
|
||||||
save_interaction: bool = False,
|
save_interaction: bool = False,
|
||||||
|
only_context: bool = False,
|
||||||
):
|
):
|
||||||
"""Initialize retriever with prompt paths and search parameters."""
|
"""Initialize retriever with prompt paths and search parameters."""
|
||||||
self.save_interaction = save_interaction
|
self.save_interaction = save_interaction
|
||||||
self.user_prompt_path = user_prompt_path
|
self.user_prompt_path = user_prompt_path
|
||||||
self.system_prompt_path = system_prompt_path
|
self.system_prompt_path = system_prompt_path
|
||||||
|
self.system_prompt = system_prompt
|
||||||
|
self.only_context = only_context
|
||||||
self.top_k = top_k if top_k is not None else 5
|
self.top_k = top_k if top_k is not None else 5
|
||||||
self.node_type = node_type
|
self.node_type = node_type
|
||||||
self.node_name = node_name
|
self.node_name = node_name
|
||||||
|
|
@ -151,7 +155,11 @@ class GraphCompletionRetriever(BaseRetriever):
|
||||||
|
|
||||||
return context, triplets
|
return context, triplets
|
||||||
|
|
||||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
async def get_completion(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
context: Optional[Any] = None,
|
||||||
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Generates a completion using graph connections context based on a query.
|
Generates a completion using graph connections context based on a query.
|
||||||
|
|
||||||
|
|
@ -177,6 +185,8 @@ class GraphCompletionRetriever(BaseRetriever):
|
||||||
context=context,
|
context=context,
|
||||||
user_prompt_path=self.user_prompt_path,
|
user_prompt_path=self.user_prompt_path,
|
||||||
system_prompt_path=self.system_prompt_path,
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
system_prompt=self.system_prompt,
|
||||||
|
only_context=self.only_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.save_interaction and context and triplets and completion:
|
if self.save_interaction and context and triplets and completion:
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
|
||||||
user_prompt_path: str = "graph_context_for_question.txt",
|
user_prompt_path: str = "graph_context_for_question.txt",
|
||||||
system_prompt_path: str = "answer_simple_question.txt",
|
system_prompt_path: str = "answer_simple_question.txt",
|
||||||
summarize_prompt_path: str = "summarize_search_results.txt",
|
summarize_prompt_path: str = "summarize_search_results.txt",
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
top_k: Optional[int] = 5,
|
top_k: Optional[int] = 5,
|
||||||
node_type: Optional[Type] = None,
|
node_type: Optional[Type] = None,
|
||||||
node_name: Optional[List[str]] = None,
|
node_name: Optional[List[str]] = None,
|
||||||
|
|
@ -34,6 +35,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
|
||||||
node_type=node_type,
|
node_type=node_type,
|
||||||
node_name=node_name,
|
node_name=node_name,
|
||||||
save_interaction=save_interaction,
|
save_interaction=save_interaction,
|
||||||
|
system_prompt=system_prompt,
|
||||||
)
|
)
|
||||||
self.summarize_prompt_path = summarize_prompt_path
|
self.summarize_prompt_path = summarize_prompt_path
|
||||||
|
|
||||||
|
|
@ -57,4 +59,4 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
|
||||||
- str: A summary string representing the content of the retrieved edges.
|
- str: A summary string representing the content of the retrieved edges.
|
||||||
"""
|
"""
|
||||||
direct_text = await super().resolve_edges_to_text(retrieved_edges)
|
direct_text = await super().resolve_edges_to_text(retrieved_edges)
|
||||||
return await summarize_text(direct_text, self.summarize_prompt_path)
|
return await summarize_text(direct_text, self.summarize_prompt_path, self.system_prompt)
|
||||||
|
|
|
||||||
|
|
@ -62,7 +62,7 @@ class SummariesRetriever(BaseRetriever):
|
||||||
logger.info(f"Returning {len(summary_payloads)} summary payloads")
|
logger.info(f"Returning {len(summary_payloads)} summary payloads")
|
||||||
return summary_payloads
|
return summary_payloads
|
||||||
|
|
||||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
async def get_completion(self, query: str, context: Optional[Any] = None, **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
Generates a completion using summaries context.
|
Generates a completion using summaries context.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
from typing import Optional
|
||||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -6,25 +7,35 @@ async def generate_completion(
|
||||||
context: str,
|
context: str,
|
||||||
user_prompt_path: str,
|
user_prompt_path: str,
|
||||||
system_prompt_path: str,
|
system_prompt_path: str,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
only_context: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generates a completion using LLM with given context and prompts."""
|
"""Generates a completion using LLM with given context and prompts."""
|
||||||
args = {"question": query, "context": context}
|
args = {"question": query, "context": context}
|
||||||
user_prompt = LLMGateway.render_prompt(user_prompt_path, args)
|
user_prompt = LLMGateway.render_prompt(user_prompt_path, args)
|
||||||
system_prompt = LLMGateway.read_query_prompt(system_prompt_path)
|
system_prompt = (
|
||||||
|
system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path)
|
||||||
return await LLMGateway.acreate_structured_output(
|
|
||||||
text_input=user_prompt,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
response_model=str,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if only_context:
|
||||||
|
return context
|
||||||
|
else:
|
||||||
|
return await LLMGateway.acreate_structured_output(
|
||||||
|
text_input=user_prompt,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
response_model=str,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def summarize_text(
|
async def summarize_text(
|
||||||
text: str,
|
text: str,
|
||||||
prompt_path: str = "summarize_search_results.txt",
|
system_prompt_path: str = "summarize_search_results.txt",
|
||||||
|
system_prompt: str = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Summarizes text using LLM with the specified prompt."""
|
"""Summarizes text using LLM with the specified prompt."""
|
||||||
system_prompt = LLMGateway.read_query_prompt(prompt_path)
|
system_prompt = (
|
||||||
|
system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path)
|
||||||
|
)
|
||||||
|
|
||||||
return await LLMGateway.acreate_structured_output(
|
return await LLMGateway.acreate_structured_output(
|
||||||
text_input=text,
|
text_input=text,
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ import asyncio
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import Callable, List, Optional, Type, Union
|
from typing import Callable, List, Optional, Type, Union
|
||||||
|
|
||||||
|
from cognee.modules.engine.models.node_set import NodeSet
|
||||||
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
|
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
|
||||||
from cognee.modules.search.exceptions import UnsupportedSearchTypeError
|
from cognee.modules.search.exceptions import UnsupportedSearchTypeError
|
||||||
from cognee.context_global_variables import set_database_global_context_variables
|
from cognee.context_global_variables import set_database_global_context_variables
|
||||||
|
|
@ -38,11 +39,13 @@ async def search(
|
||||||
dataset_ids: Union[list[UUID], None],
|
dataset_ids: Union[list[UUID], None],
|
||||||
user: User,
|
user: User,
|
||||||
system_prompt_path="answer_simple_question.txt",
|
system_prompt_path="answer_simple_question.txt",
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
node_type: Optional[Type] = None,
|
node_type: Optional[Type] = NodeSet,
|
||||||
node_name: Optional[List[str]] = None,
|
node_name: Optional[List[str]] = None,
|
||||||
save_interaction: Optional[bool] = False,
|
save_interaction: Optional[bool] = False,
|
||||||
last_k: Optional[int] = None,
|
last_k: Optional[int] = None,
|
||||||
|
only_context: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -62,28 +65,34 @@ async def search(
|
||||||
# Use search function filtered by permissions if access control is enabled
|
# Use search function filtered by permissions if access control is enabled
|
||||||
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
|
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
|
||||||
return await authorized_search(
|
return await authorized_search(
|
||||||
query_text=query_text,
|
|
||||||
query_type=query_type,
|
query_type=query_type,
|
||||||
|
query_text=query_text,
|
||||||
user=user,
|
user=user,
|
||||||
dataset_ids=dataset_ids,
|
dataset_ids=dataset_ids,
|
||||||
system_prompt_path=system_prompt_path,
|
system_prompt_path=system_prompt_path,
|
||||||
|
system_prompt=system_prompt,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
|
node_type=node_type,
|
||||||
|
node_name=node_name,
|
||||||
save_interaction=save_interaction,
|
save_interaction=save_interaction,
|
||||||
last_k=last_k,
|
last_k=last_k,
|
||||||
|
only_context=only_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
query = await log_query(query_text, query_type.value, user.id)
|
query = await log_query(query_text, query_type.value, user.id)
|
||||||
|
|
||||||
search_results = await specific_search(
|
search_results = await specific_search(
|
||||||
query_type,
|
query_type=query_type,
|
||||||
query_text,
|
query_text=query_text,
|
||||||
user,
|
user=user,
|
||||||
system_prompt_path=system_prompt_path,
|
system_prompt_path=system_prompt_path,
|
||||||
|
system_prompt=system_prompt,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
node_type=node_type,
|
node_type=node_type,
|
||||||
node_name=node_name,
|
node_name=node_name,
|
||||||
save_interaction=save_interaction,
|
save_interaction=save_interaction,
|
||||||
last_k=last_k,
|
last_k=last_k,
|
||||||
|
only_context=only_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
await log_result(
|
await log_result(
|
||||||
|
|
@ -99,21 +108,26 @@ async def search(
|
||||||
|
|
||||||
async def specific_search(
|
async def specific_search(
|
||||||
query_type: SearchType,
|
query_type: SearchType,
|
||||||
query: str,
|
query_text: str,
|
||||||
user: User,
|
user: User,
|
||||||
system_prompt_path="answer_simple_question.txt",
|
system_prompt_path: str = "answer_simple_question.txt",
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
node_type: Optional[Type] = None,
|
node_type: Optional[Type] = NodeSet,
|
||||||
node_name: Optional[List[str]] = None,
|
node_name: Optional[List[str]] = None,
|
||||||
save_interaction: Optional[bool] = False,
|
save_interaction: Optional[bool] = False,
|
||||||
last_k: Optional[int] = None,
|
last_k: Optional[int] = None,
|
||||||
|
only_context: bool = None,
|
||||||
) -> list:
|
) -> list:
|
||||||
search_tasks: dict[SearchType, Callable] = {
|
search_tasks: dict[SearchType, Callable] = {
|
||||||
SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion,
|
SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion,
|
||||||
SearchType.INSIGHTS: InsightsRetriever(top_k=top_k).get_completion,
|
SearchType.INSIGHTS: InsightsRetriever(top_k=top_k).get_completion,
|
||||||
SearchType.CHUNKS: ChunksRetriever(top_k=top_k).get_completion,
|
SearchType.CHUNKS: ChunksRetriever(top_k=top_k).get_completion,
|
||||||
SearchType.RAG_COMPLETION: CompletionRetriever(
|
SearchType.RAG_COMPLETION: CompletionRetriever(
|
||||||
system_prompt_path=system_prompt_path, top_k=top_k
|
system_prompt_path=system_prompt_path,
|
||||||
|
top_k=top_k,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
only_context=only_context,
|
||||||
).get_completion,
|
).get_completion,
|
||||||
SearchType.GRAPH_COMPLETION: GraphCompletionRetriever(
|
SearchType.GRAPH_COMPLETION: GraphCompletionRetriever(
|
||||||
system_prompt_path=system_prompt_path,
|
system_prompt_path=system_prompt_path,
|
||||||
|
|
@ -121,6 +135,8 @@ async def specific_search(
|
||||||
node_type=node_type,
|
node_type=node_type,
|
||||||
node_name=node_name,
|
node_name=node_name,
|
||||||
save_interaction=save_interaction,
|
save_interaction=save_interaction,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
only_context=only_context,
|
||||||
).get_completion,
|
).get_completion,
|
||||||
SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever(
|
SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever(
|
||||||
system_prompt_path=system_prompt_path,
|
system_prompt_path=system_prompt_path,
|
||||||
|
|
@ -128,6 +144,8 @@ async def specific_search(
|
||||||
node_type=node_type,
|
node_type=node_type,
|
||||||
node_name=node_name,
|
node_name=node_name,
|
||||||
save_interaction=save_interaction,
|
save_interaction=save_interaction,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
only_context=only_context,
|
||||||
).get_completion,
|
).get_completion,
|
||||||
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever(
|
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever(
|
||||||
system_prompt_path=system_prompt_path,
|
system_prompt_path=system_prompt_path,
|
||||||
|
|
@ -135,6 +153,8 @@ async def specific_search(
|
||||||
node_type=node_type,
|
node_type=node_type,
|
||||||
node_name=node_name,
|
node_name=node_name,
|
||||||
save_interaction=save_interaction,
|
save_interaction=save_interaction,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
only_context=only_context,
|
||||||
).get_completion,
|
).get_completion,
|
||||||
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
|
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
|
||||||
system_prompt_path=system_prompt_path,
|
system_prompt_path=system_prompt_path,
|
||||||
|
|
@ -142,6 +162,7 @@ async def specific_search(
|
||||||
node_type=node_type,
|
node_type=node_type,
|
||||||
node_name=node_name,
|
node_name=node_name,
|
||||||
save_interaction=save_interaction,
|
save_interaction=save_interaction,
|
||||||
|
system_prompt=system_prompt,
|
||||||
).get_completion,
|
).get_completion,
|
||||||
SearchType.CODE: CodeRetriever(top_k=top_k).get_completion,
|
SearchType.CODE: CodeRetriever(top_k=top_k).get_completion,
|
||||||
SearchType.CYPHER: CypherSearchRetriever().get_completion,
|
SearchType.CYPHER: CypherSearchRetriever().get_completion,
|
||||||
|
|
@ -152,7 +173,7 @@ async def specific_search(
|
||||||
|
|
||||||
# If the query type is FEELING_LUCKY, select the search type intelligently
|
# If the query type is FEELING_LUCKY, select the search type intelligently
|
||||||
if query_type is SearchType.FEELING_LUCKY:
|
if query_type is SearchType.FEELING_LUCKY:
|
||||||
query_type = await select_search_type(query)
|
query_type = await select_search_type(query_text)
|
||||||
|
|
||||||
search_task = search_tasks.get(query_type)
|
search_task = search_tasks.get(query_type)
|
||||||
|
|
||||||
|
|
@ -161,7 +182,7 @@ async def specific_search(
|
||||||
|
|
||||||
send_telemetry("cognee.search EXECUTION STARTED", user.id)
|
send_telemetry("cognee.search EXECUTION STARTED", user.id)
|
||||||
|
|
||||||
results = await search_task(query)
|
results = await search_task(query_text)
|
||||||
|
|
||||||
send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
|
send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
|
||||||
|
|
||||||
|
|
@ -169,14 +190,18 @@ async def specific_search(
|
||||||
|
|
||||||
|
|
||||||
async def authorized_search(
|
async def authorized_search(
|
||||||
query_text: str,
|
|
||||||
query_type: SearchType,
|
query_type: SearchType,
|
||||||
user: User = None,
|
query_text: str,
|
||||||
|
user: User,
|
||||||
dataset_ids: Optional[list[UUID]] = None,
|
dataset_ids: Optional[list[UUID]] = None,
|
||||||
system_prompt_path: str = "answer_simple_question.txt",
|
system_prompt_path: str = "answer_simple_question.txt",
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
save_interaction: bool = False,
|
node_type: Optional[Type] = NodeSet,
|
||||||
|
node_name: Optional[List[str]] = None,
|
||||||
|
save_interaction: Optional[bool] = False,
|
||||||
last_k: Optional[int] = None,
|
last_k: Optional[int] = None,
|
||||||
|
only_context: bool = None,
|
||||||
) -> list:
|
) -> list:
|
||||||
"""
|
"""
|
||||||
Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset.
|
Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset.
|
||||||
|
|
@ -190,14 +215,18 @@ async def authorized_search(
|
||||||
|
|
||||||
# Searches all provided datasets and handles setting up of appropriate database context based on permissions
|
# Searches all provided datasets and handles setting up of appropriate database context based on permissions
|
||||||
search_results = await specific_search_by_context(
|
search_results = await specific_search_by_context(
|
||||||
search_datasets,
|
search_datasets=search_datasets,
|
||||||
query_text,
|
query_type=query_type,
|
||||||
query_type,
|
query_text=query_text,
|
||||||
user,
|
user=user,
|
||||||
system_prompt_path,
|
system_prompt_path=system_prompt_path,
|
||||||
top_k,
|
system_prompt=system_prompt,
|
||||||
save_interaction,
|
top_k=top_k,
|
||||||
|
node_type=node_type,
|
||||||
|
node_name=node_name,
|
||||||
|
save_interaction=save_interaction,
|
||||||
last_k=last_k,
|
last_k=last_k,
|
||||||
|
only_context=only_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
await log_result(query.id, json.dumps(search_results, cls=JSONEncoder), user.id)
|
await log_result(query.id, json.dumps(search_results, cls=JSONEncoder), user.id)
|
||||||
|
|
@ -207,13 +236,17 @@ async def authorized_search(
|
||||||
|
|
||||||
async def specific_search_by_context(
|
async def specific_search_by_context(
|
||||||
search_datasets: list[Dataset],
|
search_datasets: list[Dataset],
|
||||||
query_text: str,
|
|
||||||
query_type: SearchType,
|
query_type: SearchType,
|
||||||
|
query_text: str,
|
||||||
user: User,
|
user: User,
|
||||||
system_prompt_path: str,
|
system_prompt_path: str = "answer_simple_question.txt",
|
||||||
top_k: int,
|
system_prompt: Optional[str] = None,
|
||||||
save_interaction: bool = False,
|
top_k: int = 10,
|
||||||
|
node_type: Optional[Type] = NodeSet,
|
||||||
|
node_name: Optional[List[str]] = None,
|
||||||
|
save_interaction: Optional[bool] = False,
|
||||||
last_k: Optional[int] = None,
|
last_k: Optional[int] = None,
|
||||||
|
only_context: bool = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Searches all provided datasets and handles setting up of appropriate database context based on permissions.
|
Searches all provided datasets and handles setting up of appropriate database context based on permissions.
|
||||||
|
|
@ -221,18 +254,33 @@ async def specific_search_by_context(
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def _search_by_context(
|
async def _search_by_context(
|
||||||
dataset, user, query_type, query_text, system_prompt_path, top_k, last_k
|
dataset: Dataset,
|
||||||
|
query_type: SearchType,
|
||||||
|
query_text: str,
|
||||||
|
user: User,
|
||||||
|
system_prompt_path: str = "answer_simple_question.txt",
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
top_k: int = 10,
|
||||||
|
node_type: Optional[Type] = NodeSet,
|
||||||
|
node_name: Optional[List[str]] = None,
|
||||||
|
save_interaction: Optional[bool] = False,
|
||||||
|
last_k: Optional[int] = None,
|
||||||
|
only_context: bool = None,
|
||||||
):
|
):
|
||||||
# Set database configuration in async context for each dataset user has access for
|
# Set database configuration in async context for each dataset user has access for
|
||||||
await set_database_global_context_variables(dataset.id, dataset.owner_id)
|
await set_database_global_context_variables(dataset.id, dataset.owner_id)
|
||||||
search_results = await specific_search(
|
search_results = await specific_search(
|
||||||
query_type,
|
query_type=query_type,
|
||||||
query_text,
|
query_text=query_text,
|
||||||
user,
|
user=user,
|
||||||
system_prompt_path=system_prompt_path,
|
system_prompt_path=system_prompt_path,
|
||||||
|
system_prompt=system_prompt,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
|
node_type=node_type,
|
||||||
|
node_name=node_name,
|
||||||
save_interaction=save_interaction,
|
save_interaction=save_interaction,
|
||||||
last_k=last_k,
|
last_k=last_k,
|
||||||
|
only_context=only_context,
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
"search_result": search_results,
|
"search_result": search_results,
|
||||||
|
|
@ -245,7 +293,18 @@ async def specific_search_by_context(
|
||||||
for dataset in search_datasets:
|
for dataset in search_datasets:
|
||||||
tasks.append(
|
tasks.append(
|
||||||
_search_by_context(
|
_search_by_context(
|
||||||
dataset, user, query_type, query_text, system_prompt_path, top_k, last_k
|
dataset=dataset,
|
||||||
|
query_type=query_type,
|
||||||
|
query_text=query_text,
|
||||||
|
user=user,
|
||||||
|
system_prompt_path=system_prompt_path,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
top_k=top_k,
|
||||||
|
node_type=node_type,
|
||||||
|
node_name=node_name,
|
||||||
|
save_interaction=save_interaction,
|
||||||
|
last_k=last_k,
|
||||||
|
only_context=only_context,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,24 +1,48 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
# from concurrent.futures import ProcessPoolExecutor
|
from typing import Set
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator, Optional, List
|
||||||
from uuid import NAMESPACE_OID, uuid5
|
from uuid import NAMESPACE_OID, uuid5
|
||||||
|
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.shared.CodeGraphEntities import CodeFile, Repository
|
from cognee.shared.CodeGraphEntities import CodeFile, Repository
|
||||||
|
|
||||||
|
# constant, declared only once
|
||||||
|
EXCLUDED_DIRS: Set[str] = {
|
||||||
|
".venv",
|
||||||
|
"venv",
|
||||||
|
"env",
|
||||||
|
".env",
|
||||||
|
"site-packages",
|
||||||
|
"node_modules",
|
||||||
|
"dist",
|
||||||
|
"build",
|
||||||
|
".git",
|
||||||
|
"tests",
|
||||||
|
"test",
|
||||||
|
}
|
||||||
|
|
||||||
async def get_source_code_files(repo_path, language_config: dict[str, list[str]] | None = None):
|
|
||||||
|
async def get_source_code_files(
|
||||||
|
repo_path,
|
||||||
|
language_config: dict[str, list[str]] | None = None,
|
||||||
|
excluded_paths: Optional[List[str]] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Retrieve source code files from the specified repository path for multiple languages.
|
Retrieve Python source code files from the specified repository path.
|
||||||
|
|
||||||
|
This function scans the given repository path for files that have the .py extension
|
||||||
|
while excluding test files and files within a virtual environment. It returns a list of
|
||||||
|
absolute paths to the source code files that are not empty.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
-----------
|
-----------
|
||||||
- repo_path: The file path to the repository to search for source files.
|
- repo_path: Root path of the repository to search
|
||||||
- language_config: dict mapping language names to file extensions, e.g.,
|
- language_config: dict mapping language names to file extensions, e.g.,
|
||||||
{'python': ['.py'], 'javascript': ['.js', '.jsx'], ...}
|
{'python': ['.py'], 'javascript': ['.js', '.jsx'], ...}
|
||||||
|
- excluded_paths: Optional list of path fragments or glob patterns to exclude
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
--------
|
--------
|
||||||
|
|
@ -54,28 +78,23 @@ async def get_source_code_files(repo_path, language_config: dict[str, list[str]]
|
||||||
lang = _get_language_from_extension(file, language_config)
|
lang = _get_language_from_extension(file, language_config)
|
||||||
if lang is None:
|
if lang is None:
|
||||||
continue
|
continue
|
||||||
# Exclude tests and common build/venv directories
|
# Exclude tests, common build/venv directories and files provided in exclude_paths
|
||||||
excluded_dirs = {
|
excluded_dirs = EXCLUDED_DIRS
|
||||||
".venv",
|
excluded_paths = {Path(p).resolve() for p in (excluded_paths or [])} # full paths
|
||||||
"venv",
|
|
||||||
"env",
|
root_path = Path(root).resolve()
|
||||||
".env",
|
root_parts = set(root_path.parts) # same as before
|
||||||
"site-packages",
|
|
||||||
"node_modules",
|
|
||||||
"dist",
|
|
||||||
"build",
|
|
||||||
".git",
|
|
||||||
"tests",
|
|
||||||
"test",
|
|
||||||
}
|
|
||||||
root_parts = set(os.path.normpath(root).split(os.sep))
|
|
||||||
base_name, _ext = os.path.splitext(file)
|
base_name, _ext = os.path.splitext(file)
|
||||||
if (
|
if (
|
||||||
base_name.startswith("test_")
|
base_name.startswith("test_")
|
||||||
or base_name.endswith("_test") # catches Go's *_test.go and similar
|
or base_name.endswith("_test")
|
||||||
or ".test." in file
|
or ".test." in file
|
||||||
or ".spec." in file
|
or ".spec." in file
|
||||||
or (excluded_dirs & root_parts)
|
or (excluded_dirs & root_parts) # name match
|
||||||
|
or any(
|
||||||
|
root_path.is_relative_to(p) # full-path match
|
||||||
|
for p in excluded_paths
|
||||||
|
)
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
file_path = os.path.abspath(os.path.join(root, file))
|
file_path = os.path.abspath(os.path.join(root, file))
|
||||||
|
|
@ -115,7 +134,10 @@ def run_coroutine(coroutine_func, *args, **kwargs):
|
||||||
|
|
||||||
|
|
||||||
async def get_repo_file_dependencies(
|
async def get_repo_file_dependencies(
|
||||||
repo_path: str, detailed_extraction: bool = False, supported_languages: list = None
|
repo_path: str,
|
||||||
|
detailed_extraction: bool = False,
|
||||||
|
supported_languages: list = None,
|
||||||
|
excluded_paths: Optional[List[str]] = None,
|
||||||
) -> AsyncGenerator[DataPoint, None]:
|
) -> AsyncGenerator[DataPoint, None]:
|
||||||
"""
|
"""
|
||||||
Generate a dependency graph for source files (multi-language) in the given repository path.
|
Generate a dependency graph for source files (multi-language) in the given repository path.
|
||||||
|
|
@ -150,6 +172,7 @@ async def get_repo_file_dependencies(
|
||||||
"go": [".go"],
|
"go": [".go"],
|
||||||
"rust": [".rs"],
|
"rust": [".rs"],
|
||||||
"cpp": [".cpp", ".c", ".h", ".hpp"],
|
"cpp": [".cpp", ".c", ".h", ".hpp"],
|
||||||
|
"c": [".c", ".h"],
|
||||||
}
|
}
|
||||||
if supported_languages is not None:
|
if supported_languages is not None:
|
||||||
language_config = {
|
language_config = {
|
||||||
|
|
@ -158,7 +181,9 @@ async def get_repo_file_dependencies(
|
||||||
else:
|
else:
|
||||||
language_config = default_language_config
|
language_config = default_language_config
|
||||||
|
|
||||||
source_code_files = await get_source_code_files(repo_path, language_config=language_config)
|
source_code_files = await get_source_code_files(
|
||||||
|
repo_path, language_config=language_config, excluded_paths=excluded_paths
|
||||||
|
)
|
||||||
|
|
||||||
repo = Repository(
|
repo = Repository(
|
||||||
id=uuid5(NAMESPACE_OID, repo_path),
|
id=uuid5(NAMESPACE_OID, repo_path),
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,8 @@ import uuid
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pylint.checkers.utils import node_type
|
|
||||||
|
|
||||||
|
from cognee.modules.engine.models.node_set import NodeSet
|
||||||
from cognee.modules.search.exceptions import UnsupportedSearchTypeError
|
from cognee.modules.search.exceptions import UnsupportedSearchTypeError
|
||||||
from cognee.modules.search.methods.search import search, specific_search
|
from cognee.modules.search.methods.search import search, specific_search
|
||||||
from cognee.modules.search.types import SearchType
|
from cognee.modules.search.types import SearchType
|
||||||
|
|
@ -58,15 +58,17 @@ async def test_search(
|
||||||
# Verify
|
# Verify
|
||||||
mock_log_query.assert_called_once_with(query_text, query_type.value, mock_user.id)
|
mock_log_query.assert_called_once_with(query_text, query_type.value, mock_user.id)
|
||||||
mock_specific_search.assert_called_once_with(
|
mock_specific_search.assert_called_once_with(
|
||||||
query_type,
|
query_type=query_type,
|
||||||
query_text,
|
query_text=query_text,
|
||||||
mock_user,
|
user=mock_user,
|
||||||
system_prompt_path="answer_simple_question.txt",
|
system_prompt_path="answer_simple_question.txt",
|
||||||
|
system_prompt=None,
|
||||||
top_k=10,
|
top_k=10,
|
||||||
node_type=None,
|
node_type=NodeSet,
|
||||||
node_name=None,
|
node_name=None,
|
||||||
save_interaction=False,
|
save_interaction=False,
|
||||||
last_k=None,
|
last_k=None,
|
||||||
|
only_context=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify result logging
|
# Verify result logging
|
||||||
|
|
@ -201,7 +203,10 @@ async def test_specific_search_feeling_lucky(
|
||||||
|
|
||||||
if retriever_name == "CompletionRetriever":
|
if retriever_name == "CompletionRetriever":
|
||||||
mock_retriever_class.assert_called_once_with(
|
mock_retriever_class.assert_called_once_with(
|
||||||
system_prompt_path="answer_simple_question.txt", top_k=top_k
|
system_prompt_path="answer_simple_question.txt",
|
||||||
|
top_k=top_k,
|
||||||
|
system_prompt=None,
|
||||||
|
only_context=None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
mock_retriever_class.assert_called_once_with(top_k=top_k)
|
mock_retriever_class.assert_called_once_with(top_k=top_k)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue