chore: merge dev

This commit is contained in:
Andrej Milicevic 2026-01-14 15:54:11 +01:00
commit 2cdbc02b35
39 changed files with 7673 additions and 7387 deletions

View file

@ -76,7 +76,7 @@ git clone https://github.com/<your-github-username>/cognee.git
cd cognee
```
In case you are working on Vector and Graph Adapters
1. Fork the [**cognee**](https://github.com/topoteretes/cognee-community) repository
1. Fork the [**cognee-community**](https://github.com/topoteretes/cognee-community) repository
2. Clone your fork:
```shell
git clone https://github.com/<your-github-username>/cognee-community.git
@ -120,6 +120,21 @@ or
uv run python examples/python/simple_example.py
```
### Running Simple Example
Change .env.example into .env and provide your OPENAI_API_KEY as LLM_API_KEY
Make sure to run ```shell uv sync ``` in the root cloned folder or set up a virtual environment to run cognee
```shell
python cognee/cognee/examples/python/simple_example.py
```
or
```shell
uv run python cognee/cognee/examples/python/simple_example.py
```
## 4. 📤 Submitting Changes
1. Make sure that `pre-commit` and hooks are installed. See `Required tools` section for more information. Try executing `pre-commit run` if you are not sure.

View file

@ -126,6 +126,7 @@ Now, run a minimal pipeline:
```python
import cognee
import asyncio
from pprint import pprint
async def main():
@ -143,7 +144,7 @@ async def main():
# Display the results
for result in results:
print(result)
pprint(result)
if __name__ == '__main__':

File diff suppressed because it is too large Load diff

View file

@ -13,7 +13,7 @@
"classnames": "^2.5.1",
"culori": "^4.0.1",
"d3-force-3d": "^3.0.6",
"next": "^16.1.7",
"next": "^16.1.1",
"react": "^19.2.3",
"react-dom": "^19.2.3",
"react-force-graph-2d": "^1.27.1",

View file

@ -192,7 +192,7 @@ class CogneeClient:
with redirect_stdout(sys.stderr):
results = await self.cognee.search(
query_type=SearchType[query_type.upper()], query_text=query_text
query_type=SearchType[query_type.upper()], query_text=query_text, top_k=top_k
)
return results

View file

@ -316,7 +316,7 @@ async def save_interaction(data: str) -> list:
@mcp.tool()
async def search(search_query: str, search_type: str) -> list:
async def search(search_query: str, search_type: str, top_k: int = 10) -> list:
"""
Search and query the knowledge graph for insights, information, and connections.
@ -389,6 +389,13 @@ async def search(search_query: str, search_type: str) -> list:
The search_type is case-insensitive and will be converted to uppercase.
top_k : int, optional
Maximum number of results to return (default: 10).
Controls the amount of context retrieved from the knowledge graph.
- Lower values (3-5): Faster, more focused results
- Higher values (10-20): More comprehensive, but slower and more context-heavy
Helps manage response size and context window usage in MCP clients.
Returns
-------
list
@ -425,13 +432,32 @@ async def search(search_query: str, search_type: str) -> list:
"""
async def search_task(search_query: str, search_type: str) -> str:
"""Search the knowledge graph"""
async def search_task(search_query: str, search_type: str, top_k: int) -> str:
"""
Internal task to execute knowledge graph search with result formatting.
Handles the actual search execution and formats results appropriately
for MCP clients based on the search type and execution mode (API vs direct).
Parameters
----------
search_query : str
The search query in natural language
search_type : str
Type of search to perform (GRAPH_COMPLETION, CHUNKS, etc.)
top_k : int
Maximum number of results to return
Returns
-------
str
Formatted search results as a string, with format depending on search_type
"""
# NOTE: MCP uses stdout to communicate, we must redirect all output
# going to stdout ( like the print function ) to stderr.
with redirect_stdout(sys.stderr):
search_results = await cognee_client.search(
query_text=search_query, query_type=search_type
query_text=search_query, query_type=search_type, top_k=top_k
)
# Handle different result formats based on API vs direct mode
@ -465,7 +491,7 @@ async def search(search_query: str, search_type: str) -> list:
else:
return str(search_results)
search_results = await search_task(search_query, search_type)
search_results = await search_task(search_query, search_type, top_k)
return [types.TextContent(type="text", text=search_results)]

View file

@ -6,7 +6,7 @@ from fastapi import Depends, APIRouter
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder
from cognee.modules.search.types import SearchType, SearchResult, CombinedSearchResult
from cognee.modules.search.types import SearchType, SearchResult
from cognee.api.DTO import InDTO, OutDTO
from cognee.modules.users.exceptions.exceptions import PermissionDeniedError, UserNotFoundError
from cognee.modules.users.models import User
@ -31,7 +31,7 @@ class SearchPayloadDTO(InDTO):
node_name: Optional[list[str]] = Field(default=None, example=[])
top_k: Optional[int] = Field(default=10)
only_context: bool = Field(default=False)
use_combined_context: bool = Field(default=False)
verbose: bool = Field(default=False)
def get_search_router() -> APIRouter:
@ -74,7 +74,7 @@ def get_search_router() -> APIRouter:
except Exception as error:
return JSONResponse(status_code=500, content={"error": str(error)})
@router.post("", response_model=Union[List[SearchResult], CombinedSearchResult, List])
@router.post("", response_model=Union[List[SearchResult], List])
async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):
"""
Search for nodes in the graph database.
@ -118,7 +118,7 @@ def get_search_router() -> APIRouter:
"node_name": payload.node_name,
"top_k": payload.top_k,
"only_context": payload.only_context,
"use_combined_context": payload.use_combined_context,
"verbose": payload.verbose,
"cognee_version": cognee_version,
},
)
@ -135,8 +135,8 @@ def get_search_router() -> APIRouter:
system_prompt=payload.system_prompt,
node_name=payload.node_name,
top_k=payload.top_k,
verbose=payload.verbose,
only_context=payload.only_context,
use_combined_context=payload.use_combined_context,
)
return jsonable_encoder(results)

View file

@ -4,7 +4,7 @@ from typing import Union, Optional, List, Type
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.users.models import User
from cognee.modules.search.types import SearchResult, SearchType, CombinedSearchResult
from cognee.modules.search.types import SearchResult, SearchType
from cognee.modules.users.methods import get_default_user
from cognee.modules.search.methods import search as search_function
from cognee.modules.data.methods import get_authorized_existing_datasets
@ -32,11 +32,11 @@ async def search(
save_interaction: bool = False,
last_k: Optional[int] = 1,
only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[List[SearchResult], CombinedSearchResult]:
verbose: bool = False,
) -> List[SearchResult]:
"""
Search and query the knowledge graph for insights, information, and connections.
@ -126,6 +126,8 @@ async def search(
session_id: Optional session identifier for caching Q&A interactions. Defaults to 'default_session' if None.
verbose: If True, returns detailed result information including graph representation (when possible).
Returns:
list: Search results in format determined by query_type:
@ -214,10 +216,10 @@ async def search(
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
use_combined_context=use_combined_context,
session_id=session_id,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
verbose=verbose,
)
return filtered_search_results

View file

@ -17,3 +17,9 @@ async def setup():
await create_relational_db_and_tables()
if not backend_access_control_enabled():
await create_pgvector_db_and_tables()
if __name__ == "__main__":
import asyncio
asyncio.run(setup())

View file

@ -215,9 +215,6 @@ class CogneeGraph(CogneeAbstractGraph):
edge_penalty=triplet_distance_penalty,
)
self.add_edge(edge)
source_node.add_skeleton_edge(edge)
target_node.add_skeleton_edge(edge)
else:
raise EntityNotFoundError(
message=f"Edge references nonexistent nodes: {source_id} -> {target_id}"

View file

@ -27,6 +27,5 @@ await cognee.cognify(datasets=["python-development-with-cognee"], temporal_cogni
results = await cognee.search(
"What Python type hinting challenges did I face, and how does Guido approach similar problems in mypy?",
datasets=["python-development-with-cognee"],
use_combined_context=True, # Used to show reasoning graph visualization
)
print(results)

View file

@ -1,21 +1,18 @@
import asyncio
import time
from typing import List, Optional, Type
from typing import List, Optional, Type, Union
from cognee.shared.logging_utils import get_logger, ERROR
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.users.models import User
from cognee.shared.utils import send_telemetry
from cognee.modules.retrieval.utils.node_edge_vector_search import NodeEdgeVectorSearch
logger = get_logger(level=ERROR)
def format_triplets(edges):
"""Formats edges into human-readable triplet strings."""
triplets = []
for edge in edges:
node1 = edge.node1
@ -24,12 +21,10 @@ def format_triplets(edges):
node1_attributes = node1.attributes
node2_attributes = node2.attributes
# Filter only non-None properties
node1_info = {key: value for key, value in node1_attributes.items() if value is not None}
node2_info = {key: value for key, value in node2_attributes.items() if value is not None}
edge_info = {key: value for key, value in edge_attributes.items() if value is not None}
# Create the formatted triplet
triplet = f"Node1: {node1_info}\nEdge: {edge_info}\nNode2: {node2_info}\n\n\n"
triplets.append(triplet)
@ -51,7 +46,6 @@ async def get_memory_fragment(
try:
graph_engine = await get_graph_engine()
await memory_fragment.project_graph_from_db(
graph_engine,
node_properties_to_project=properties_to_project,
@ -61,20 +55,64 @@ async def get_memory_fragment(
relevant_ids_to_filter=relevant_ids_to_filter,
triplet_distance_penalty=triplet_distance_penalty,
)
except EntityNotFoundError:
# This is expected behavior - continue with empty fragment
pass
except Exception as e:
logger.error(f"Error during memory fragment creation: {str(e)}")
# Still return the fragment even if projection failed
pass
return memory_fragment
async def _get_top_triplet_importances(
memory_fragment: Optional[CogneeGraph],
vector_search: NodeEdgeVectorSearch,
properties_to_project: Optional[List[str]],
node_type: Optional[Type],
node_name: Optional[List[str]],
triplet_distance_penalty: float,
wide_search_limit: Optional[int],
top_k: int,
query_list_length: Optional[int] = None,
) -> Union[List[Edge], List[List[Edge]]]:
"""Creates memory fragment (if needed), maps distances, and calculates top triplet importances.
Args:
query_list_length: Number of queries in batch mode (None for single-query mode).
When None, node_distances/edge_distances are flat lists; when set, they are list-of-lists.
Returns:
List[Edge]: For single-query mode (query_list_length is None).
List[List[Edge]]: For batch mode (query_list_length is set), one list per query.
"""
if memory_fragment is None:
if wide_search_limit is None:
relevant_node_ids = None
else:
relevant_node_ids = vector_search.extract_relevant_node_ids()
memory_fragment = await get_memory_fragment(
properties_to_project=properties_to_project,
node_type=node_type,
node_name=node_name,
relevant_ids_to_filter=relevant_node_ids,
triplet_distance_penalty=triplet_distance_penalty,
)
await memory_fragment.map_vector_distances_to_graph_nodes(
node_distances=vector_search.node_distances, query_list_length=query_list_length
)
await memory_fragment.map_vector_distances_to_graph_edges(
edge_distances=vector_search.edge_distances, query_list_length=query_list_length
)
return await memory_fragment.calculate_top_triplet_importances(
k=top_k, query_list_length=query_list_length
)
async def brute_force_triplet_search(
query: str,
query: Optional[str] = None,
query_batch: Optional[List[str]] = None,
top_k: int = 5,
collections: Optional[List[str]] = None,
properties_to_project: Optional[List[str]] = None,
@ -83,33 +121,49 @@ async def brute_force_triplet_search(
node_name: Optional[List[str]] = None,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
) -> List[Edge]:
) -> Union[List[Edge], List[List[Edge]]]:
"""
Performs a brute force search to retrieve the top triplets from the graph.
Args:
query (str): The search query.
query (Optional[str]): The search query (single query mode). Exactly one of query or query_batch must be provided.
query_batch (Optional[List[str]]): List of search queries (batch mode). Exactly one of query or query_batch must be provided.
top_k (int): The number of top results to retrieve.
collections (Optional[List[str]]): List of collections to query.
properties_to_project (Optional[List[str]]): List of properties to project.
memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse.
node_type: node type to filter
node_name: node name to filter
wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections
wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections.
Ignored in batch mode (always None to project full graph).
triplet_distance_penalty (Optional[float]): Default distance penalty in graph projection
Returns:
list: The top triplet results.
List[Edge]: The top triplet results for single query mode (flat list).
List[List[Edge]]: List of top triplet results (one per query) for batch mode (list-of-lists).
Note:
In single-query mode, node_distances and edge_distances are stored as flat lists.
In batch mode, they are stored as list-of-lists (one list per query).
"""
if not query or not isinstance(query, str):
if query is not None and query_batch is not None:
raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.")
if query is None and query_batch is None:
raise ValueError("Must provide either 'query' or 'query_batch'.")
if query is not None and (not query or not isinstance(query, str)):
raise ValueError("The query must be a non-empty string.")
if query_batch is not None:
if not isinstance(query_batch, list) or not query_batch:
raise ValueError("query_batch must be a non-empty list of strings.")
if not all(isinstance(q, str) and q for q in query_batch):
raise ValueError("All items in query_batch must be non-empty strings.")
if top_k <= 0:
raise ValueError("top_k must be a positive integer.")
# Setting wide search limit based on the parameters
non_global_search = node_name is None
wide_search_limit = wide_search_top_k if non_global_search else None
query_list_length = len(query_batch) if query_batch is not None else None
wide_search_limit = (
None if query_list_length else (wide_search_top_k if node_name is None else None)
)
if collections is None:
collections = [
@ -123,77 +177,37 @@ async def brute_force_triplet_search(
collections.append("EdgeType_relationship_name")
try:
vector_engine = get_vector_engine()
except Exception as e:
logger.error("Failed to initialize vector engine: %s", e)
raise RuntimeError("Initialization error") from e
vector_search = NodeEdgeVectorSearch()
query_vector = (await vector_engine.embedding_engine.embed_text([query]))[0]
async def search_in_collection(collection_name: str):
try:
return await vector_engine.search(
collection_name=collection_name, query_vector=query_vector, limit=wide_search_limit
)
except CollectionNotFoundError:
return []
try:
start_time = time.time()
results = await asyncio.gather(
*[search_in_collection(collection_name) for collection_name in collections]
await vector_search.embed_and_retrieve_distances(
query=None if query_list_length else query,
query_batch=query_batch if query_list_length else None,
collections=collections,
wide_search_limit=wide_search_limit,
)
if all(not item for item in results):
return []
if not vector_search.has_results():
return [[] for _ in range(query_list_length)] if query_list_length else []
# Final statistics
vector_collection_search_time = time.time() - start_time
logger.info(
f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {vector_collection_search_time:.2f}s"
results = await _get_top_triplet_importances(
memory_fragment,
vector_search,
properties_to_project,
node_type,
node_name,
triplet_distance_penalty,
wide_search_limit,
top_k,
query_list_length=query_list_length,
)
node_distances = {collection: result for collection, result in zip(collections, results)}
edge_distances = node_distances.get("EdgeType_relationship_name", None)
if wide_search_limit is not None:
relevant_ids_to_filter = list(
{
str(getattr(scored_node, "id"))
for collection_name, score_collection in node_distances.items()
if collection_name != "EdgeType_relationship_name"
and isinstance(score_collection, (list, tuple))
for scored_node in score_collection
if getattr(scored_node, "id", None)
}
)
else:
relevant_ids_to_filter = None
if memory_fragment is None:
memory_fragment = await get_memory_fragment(
properties_to_project=properties_to_project,
node_type=node_type,
node_name=node_name,
relevant_ids_to_filter=relevant_ids_to_filter,
triplet_distance_penalty=triplet_distance_penalty,
)
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
await memory_fragment.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
results = await memory_fragment.calculate_top_triplet_importances(k=top_k)
return results
except CollectionNotFoundError:
return []
return [[] for _ in range(query_list_length)] if query_list_length else []
except Exception as error:
logger.error(
"Error during brute force search for query: %s. Error: %s",
query,
query_batch if query_list_length else [query],
error,
)
raise error

View file

@ -0,0 +1,174 @@
import asyncio
import time
from typing import Any, List, Optional
from cognee.shared.logging_utils import get_logger, ERROR
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from cognee.infrastructure.databases.vector import get_vector_engine
logger = get_logger(level=ERROR)
class NodeEdgeVectorSearch:
"""Manages vector search and distance retrieval for graph nodes and edges."""
def __init__(self, edge_collection: str = "EdgeType_relationship_name", vector_engine=None):
self.edge_collection = edge_collection
self.vector_engine = vector_engine or self._init_vector_engine()
self.query_vector: Optional[Any] = None
self.node_distances: dict[str, list[Any]] = {}
self.edge_distances: list[Any] = []
self.query_list_length: Optional[int] = None
def _init_vector_engine(self):
try:
return get_vector_engine()
except Exception as e:
logger.error("Failed to initialize vector engine: %s", e)
raise RuntimeError("Initialization error") from e
async def embed_and_retrieve_distances(
self,
query: Optional[str] = None,
query_batch: Optional[List[str]] = None,
collections: List[str] = None,
wide_search_limit: Optional[int] = None,
):
"""Embeds query/queries and retrieves vector distances from all collections."""
if query is not None and query_batch is not None:
raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.")
if query is None and query_batch is None:
raise ValueError("Must provide either 'query' or 'query_batch'.")
if not collections:
raise ValueError("'collections' must be a non-empty list.")
start_time = time.time()
if query_batch is not None:
self.query_list_length = len(query_batch)
search_results = await self._run_batch_search(collections, query_batch)
else:
self.query_list_length = None
search_results = await self._run_single_search(collections, query, wide_search_limit)
elapsed_time = time.time() - start_time
collections_with_results = sum(1 for result in search_results if any(result))
logger.info(
f"Vector collection retrieval completed: Retrieved distances from "
f"{collections_with_results} collections in {elapsed_time:.2f}s"
)
self.set_distances_from_results(collections, search_results, self.query_list_length)
def has_results(self) -> bool:
"""Checks if any collections returned results."""
if self.query_list_length is None:
if self.edge_distances and any(self.edge_distances):
return True
return any(
bool(collection_results) for collection_results in self.node_distances.values()
)
if self.edge_distances and any(inner_list for inner_list in self.edge_distances):
return True
return any(
any(results_per_query for results_per_query in collection_results)
for collection_results in self.node_distances.values()
)
def extract_relevant_node_ids(self) -> List[str]:
"""Extracts unique node IDs from search results."""
if self.query_list_length is not None:
return []
relevant_node_ids = set()
for scored_results in self.node_distances.values():
for scored_node in scored_results:
node_id = getattr(scored_node, "id", None)
if node_id:
relevant_node_ids.add(str(node_id))
return list(relevant_node_ids)
def set_distances_from_results(
self,
collections: List[str],
search_results: List[List[Any]],
query_list_length: Optional[int] = None,
):
"""Separates search results into node and edge distances with stable shapes.
Ensures all collections are present in the output, even if empty:
- Batch mode: missing/empty collections become [[]] * query_list_length
- Single mode: missing/empty collections become []
"""
self.node_distances = {}
self.edge_distances = (
[] if query_list_length is None else [[] for _ in range(query_list_length)]
)
for collection, result in zip(collections, search_results):
if not result:
empty_result = (
[] if query_list_length is None else [[] for _ in range(query_list_length)]
)
if collection == self.edge_collection:
self.edge_distances = empty_result
else:
self.node_distances[collection] = empty_result
else:
if collection == self.edge_collection:
self.edge_distances = result
else:
self.node_distances[collection] = result
async def _run_batch_search(
self, collections: List[str], query_batch: List[str]
) -> List[List[Any]]:
"""Runs batch search across all collections and returns list-of-lists per collection."""
search_tasks = [
self._search_batch_collection(collection, query_batch) for collection in collections
]
return await asyncio.gather(*search_tasks)
async def _search_batch_collection(
self, collection_name: str, query_batch: List[str]
) -> List[List[Any]]:
"""Searches one collection with batch queries and returns list-of-lists."""
try:
return await self.vector_engine.batch_search(
collection_name=collection_name, query_texts=query_batch, limit=None
)
except CollectionNotFoundError:
return [[]] * len(query_batch)
async def _run_single_search(
self, collections: List[str], query: str, wide_search_limit: Optional[int]
) -> List[List[Any]]:
"""Runs single query search and returns flat lists per collection.
Returns a list where each element is a collection's results (flat list).
These are stored as flat lists in node_distances/edge_distances for single-query mode.
"""
await self._embed_query(query)
search_tasks = [
self._search_single_collection(self.vector_engine, wide_search_limit, collection)
for collection in collections
]
search_results = await asyncio.gather(*search_tasks)
return search_results
async def _embed_query(self, query: str):
"""Embeds the query and stores the resulting vector."""
query_embeddings = await self.vector_engine.embedding_engine.embed_text([query])
self.query_vector = query_embeddings[0]
async def _search_single_collection(
self, vector_engine: Any, wide_search_limit: Optional[int], collection_name: str
):
"""Searches one collection and returns results or empty list if not found."""
try:
return await vector_engine.search(
collection_name=collection_name,
query_vector=self.query_vector,
limit=wide_search_limit,
)
except CollectionNotFoundError:
return []

View file

@ -14,8 +14,6 @@ from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.search.types import (
SearchResult,
CombinedSearchResult,
SearchResultDataset,
SearchType,
)
from cognee.modules.search.operations import log_query, log_result
@ -45,11 +43,11 @@ async def search(
save_interaction: bool = False,
last_k: Optional[int] = None,
only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[CombinedSearchResult, List[SearchResult]]:
verbose=False,
) -> List[SearchResult]:
"""
Args:
@ -90,7 +88,6 @@ async def search(
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
use_combined_context=use_combined_context,
session_id=session_id,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
@ -127,87 +124,63 @@ async def search(
query.id,
json.dumps(
jsonable_encoder(
await prepare_search_result(
search_results[0] if isinstance(search_results, list) else search_results
)
if use_combined_context
else [
await prepare_search_result(search_result) for search_result in search_results
]
[await prepare_search_result(search_result) for search_result in search_results]
)
),
user.id,
)
if use_combined_context:
prepared_search_results = await prepare_search_result(
search_results[0] if isinstance(search_results, list) else search_results
)
result = prepared_search_results["result"]
graphs = prepared_search_results["graphs"]
context = prepared_search_results["context"]
datasets = prepared_search_results["datasets"]
# This is for maintaining backwards compatibility
if backend_access_control_enabled():
return_value = []
for search_result in search_results:
prepared_search_results = await prepare_search_result(search_result)
return CombinedSearchResult(
result=result,
graphs=graphs,
context=context,
datasets=[
SearchResultDataset(
id=dataset.id,
name=dataset.name,
)
for dataset in datasets
],
)
result = prepared_search_results["result"]
graphs = prepared_search_results["graphs"]
context = prepared_search_results["context"]
datasets = prepared_search_results["datasets"]
if only_context:
search_result_dict = {
"search_result": [context] if context else None,
"dataset_id": datasets[0].id,
"dataset_name": datasets[0].name,
"dataset_tenant_id": datasets[0].tenant_id,
}
if verbose:
# Include graphs only in verbose mode
search_result_dict["graphs"] = graphs
return_value.append(search_result_dict)
else:
search_result_dict = {
"search_result": [result] if result else None,
"dataset_id": datasets[0].id,
"dataset_name": datasets[0].name,
"dataset_tenant_id": datasets[0].tenant_id,
}
if verbose:
# Include graphs only in verbose mode
search_result_dict["graphs"] = graphs
return_value.append(search_result_dict)
return return_value
else:
# This is for maintaining backwards compatibility
if backend_access_control_enabled():
return_value = []
return_value = []
if only_context:
for search_result in search_results:
prepared_search_results = await prepare_search_result(search_result)
result = prepared_search_results["result"]
graphs = prepared_search_results["graphs"]
context = prepared_search_results["context"]
datasets = prepared_search_results["datasets"]
if only_context:
return_value.append(
{
"search_result": [context] if context else None,
"dataset_id": datasets[0].id,
"dataset_name": datasets[0].name,
"dataset_tenant_id": datasets[0].tenant_id,
"graphs": graphs,
}
)
else:
return_value.append(
{
"search_result": [result] if result else None,
"dataset_id": datasets[0].id,
"dataset_name": datasets[0].name,
"dataset_tenant_id": datasets[0].tenant_id,
"graphs": graphs,
}
)
return return_value
return_value.append(prepared_search_results["context"])
else:
return_value = []
if only_context:
for search_result in search_results:
prepared_search_results = await prepare_search_result(search_result)
return_value.append(prepared_search_results["context"])
else:
for search_result in search_results:
result, context, datasets = search_result
return_value.append(result)
# For maintaining backwards compatibility
if len(return_value) == 1 and isinstance(return_value[0], list):
return return_value[0]
else:
return return_value
for search_result in search_results:
result, context, datasets = search_result
return_value.append(result)
# For maintaining backwards compatibility
if len(return_value) == 1 and isinstance(return_value[0], list):
return return_value[0]
else:
return return_value
async def authorized_search(
@ -223,14 +196,10 @@ async def authorized_search(
save_interaction: bool = False,
last_k: Optional[int] = None,
only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[
Tuple[Any, Union[List[Edge], str], List[Dataset]],
List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],
]:
) -> List[Tuple[Any, Union[List[Edge], str], List[Dataset]]]:
"""
Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset.
Not to be used outside of active access control mode.
@ -240,70 +209,6 @@ async def authorized_search(
datasets=dataset_ids, permission_type="read", user=user
)
if use_combined_context:
search_responses = await search_in_datasets_context(
search_datasets=search_datasets,
query_type=query_type,
query_text=query_text,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
only_context=True,
session_id=session_id,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
context = {}
datasets: List[Dataset] = []
for _, search_context, search_datasets in search_responses:
for dataset in search_datasets:
context[str(dataset.id)] = search_context
datasets.extend(search_datasets)
specific_search_tools = await get_search_type_tools(
query_type=query_type,
query_text=query_text,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
search_tools = specific_search_tools
if len(search_tools) == 2:
[get_completion, _] = search_tools
else:
get_completion = search_tools[0]
def prepare_combined_context(
context,
) -> Union[List[Edge], str]:
combined_context = []
for dataset_context in context.values():
combined_context += dataset_context
if combined_context and isinstance(combined_context[0], str):
return "\n".join(combined_context)
return combined_context
combined_context = prepare_combined_context(context)
completion = await get_completion(query_text, combined_context, session_id=session_id)
return completion, combined_context, datasets
# Searches all provided datasets and handles setting up of appropriate database context based on permissions
search_results = await search_in_datasets_context(
search_datasets=search_datasets,
@ -319,6 +224,7 @@ async def authorized_search(
only_context=only_context,
session_id=session_id,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
return search_results

View file

@ -1,6 +1,6 @@
from uuid import UUID
from pydantic import BaseModel
from typing import Any, Dict, List, Optional
from typing import Any, Optional
class SearchResultDataset(BaseModel):
@ -8,13 +8,6 @@ class SearchResultDataset(BaseModel):
name: str
class CombinedSearchResult(BaseModel):
result: Optional[Any]
context: Dict[str, Any]
graphs: Optional[Dict[str, Any]] = {}
datasets: Optional[List[SearchResultDataset]] = None
class SearchResult(BaseModel):
search_result: Any
dataset_id: Optional[UUID]

View file

@ -1,2 +1,2 @@
from .SearchType import SearchType
from .SearchResult import SearchResult, SearchResultDataset, CombinedSearchResult
from .SearchResult import SearchResult, SearchResultDataset

View file

@ -92,7 +92,7 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
}
links_list.append(link_data)
html_template = """
html_template = r"""
<!DOCTYPE html>
<html>
<head>

View file

@ -0,0 +1,613 @@
# cognee/tasks/memify/extract_usage_frequency.py
from typing import List, Dict, Any, Optional
from datetime import datetime, timedelta
from cognee.shared.logging_utils import get_logger
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.pipelines.tasks.task import Task
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
logger = get_logger("extract_usage_frequency")
def _parse_interaction_timestamp(timestamp_value) -> Optional[datetime]:
    """Best-effort conversion of a stored timestamp into a naive datetime.

    Accepts native ``datetime`` objects, Unix timestamps in seconds or
    milliseconds (int, float, or numeric string — values above 10 digits are
    treated as milliseconds), ISO-8601 strings, Neo4j datetimes (via
    ``to_native()``), and generic datetime-like objects exposing
    ``year``/``month`` attributes. As a last resort the value is stringified
    and re-parsed. Raises ``ValueError``/``TypeError``/``OSError`` on
    unparseable input; callers are expected to catch and log.
    """
    if isinstance(timestamp_value, datetime):
        return timestamp_value
    if isinstance(timestamp_value, (int, float)):
        # Heuristic: more than 10 digits means milliseconds since epoch.
        if timestamp_value > 10000000000:
            return datetime.fromtimestamp(timestamp_value / 1000.0)
        return datetime.fromtimestamp(timestamp_value)
    if isinstance(timestamp_value, str):
        if timestamp_value.isdigit():
            ts_int = int(timestamp_value)
            if ts_int > 10000000000:
                return datetime.fromtimestamp(ts_int / 1000.0)
            return datetime.fromtimestamp(ts_int)
        return datetime.fromisoformat(timestamp_value)
    if hasattr(timestamp_value, "to_native"):
        # Neo4j DateTime objects expose to_native() -> python datetime.
        return timestamp_value.to_native()
    if hasattr(timestamp_value, "year") and hasattr(timestamp_value, "month"):
        # Datetime-like object — rebuild from its components.
        try:
            return datetime(
                year=timestamp_value.year,
                month=timestamp_value.month,
                day=timestamp_value.day,
                hour=getattr(timestamp_value, "hour", 0),
                minute=getattr(timestamp_value, "minute", 0),
                second=getattr(timestamp_value, "second", 0),
                microsecond=getattr(timestamp_value, "microsecond", 0),
            )
        except (AttributeError, ValueError):
            pass
    # Last resort: stringify and retry the numeric / ISO paths.
    str_value = str(timestamp_value)
    if str_value.isdigit():
        ts_int = int(str_value)
        if ts_int > 10000000000:
            return datetime.fromtimestamp(ts_int / 1000.0)
        return datetime.fromtimestamp(ts_int)
    return datetime.fromisoformat(str_value)


async def extract_usage_frequency(
    subgraphs: List[CogneeGraph],
    time_window: timedelta = timedelta(days=7),
    min_interaction_threshold: int = 1,
) -> Dict[str, Any]:
    """
    Extract usage frequency from CogneeUserInteraction nodes.

    When save_interaction=True in cognee.search(), the system creates:
    - CogneeUserInteraction nodes (representing the query/answer interaction)
    - used_graph_element_to_answer edges (connecting interactions to graph elements used)

    This function tallies how often each graph element is referenced via these edges,
    enabling frequency-based ranking in downstream retrievers.

    :param subgraphs: List of CogneeGraph instances containing interaction data
    :param time_window: Time window to consider for interactions (default: 7 days)
    :param min_interaction_threshold: Minimum interactions to track (default: 1)
    :return: Dictionary containing node frequencies, edge frequencies, and metadata
    """
    current_time = datetime.now()
    cutoff_time = current_time - time_window

    # Tallies keyed by node id / composite edge key / element type name.
    node_frequencies: Dict[str, int] = {}
    edge_frequencies: Dict[str, int] = {}
    relationship_type_frequencies: Dict[str, int] = {}

    interaction_count = 0
    interactions_in_window = 0

    logger.info(f"Extracting usage frequencies from {len(subgraphs)} subgraphs")
    logger.info(f"Time window: {time_window}, Cutoff: {cutoff_time.isoformat()}")

    for subgraph in subgraphs:
        # Pass 1: collect CogneeUserInteraction nodes with a parseable timestamp.
        interaction_nodes = {}
        for node_id, node in subgraph.nodes.items():
            node_type = node.attributes.get("type") or node.attributes.get("node_type")
            if node_type != "CogneeUserInteraction":
                continue
            timestamp_value = node.attributes.get("timestamp") or node.attributes.get(
                "created_at"
            )
            if timestamp_value is None:
                # Interactions without a timestamp cannot be windowed; skip.
                continue
            try:
                interaction_time = _parse_interaction_timestamp(timestamp_value)
                if interaction_time is None:
                    raise ValueError(f"Could not parse timestamp: {timestamp_value}")
                # Strip tzinfo so the comparison against the naive cutoff works.
                if interaction_time.tzinfo is not None:
                    interaction_time = interaction_time.replace(tzinfo=None)
                interaction_nodes[node_id] = {
                    "node": node,
                    "timestamp": interaction_time,
                    "in_window": interaction_time >= cutoff_time,
                }
                interaction_count += 1
                if interaction_time >= cutoff_time:
                    interactions_in_window += 1
            except (ValueError, TypeError, AttributeError, OSError) as e:
                logger.warning(f"Failed to parse timestamp for interaction node {node_id}: {e}")
                logger.debug(
                    f"Timestamp value type: {type(timestamp_value)}, value: {timestamp_value}"
                )

        # Pass 2: walk edges and tally graph elements referenced by interactions.
        for edge in subgraph.edges:
            relationship_type = edge.attributes.get("relationship_type")
            source_id = str(edge.node1.id)
            target_id = str(edge.node2.id)

            if relationship_type == "used_graph_element_to_answer":
                # node1 is the interaction, node2 the graph element it used.
                interaction_data = interaction_nodes.get(source_id)
                if interaction_data and interaction_data["in_window"]:
                    node_frequencies[target_id] = node_frequencies.get(target_id, 0) + 1
                    # Track the element's type for the analytics breakdown.
                    target_node = subgraph.get_node(target_id)
                    if target_node:
                        element_type = target_node.attributes.get(
                            "type"
                        ) or target_node.attributes.get("node_type")
                        if element_type:
                            relationship_type_frequencies[element_type] = (
                                relationship_type_frequencies.get(element_type, 0) + 1
                            )
            elif relationship_type:
                # NOTE(review): this tally only sees node_frequencies entries
                # added by edges iterated earlier in the same pass, so the
                # result is sensitive to edge ordering — confirm whether a
                # separate second pass over edges was intended.
                if source_id in node_frequencies or target_id in node_frequencies:
                    edge_key = f"{relationship_type}:{source_id}:{target_id}"
                    edge_frequencies[edge_key] = edge_frequencies.get(edge_key, 0) + 1

    # Drop entries below the configured threshold.
    filtered_node_frequencies = {
        node_id: freq
        for node_id, freq in node_frequencies.items()
        if freq >= min_interaction_threshold
    }
    filtered_edge_frequencies = {
        edge_key: freq
        for edge_key, freq in edge_frequencies.items()
        if freq >= min_interaction_threshold
    }

    logger.info(
        f"Processed {interactions_in_window}/{interaction_count} interactions in time window"
    )
    logger.info(
        f"Found {len(filtered_node_frequencies)} nodes and {len(filtered_edge_frequencies)} edges "
        f"above threshold (min: {min_interaction_threshold})"
    )
    logger.info(f"Element type distribution: {relationship_type_frequencies}")

    return {
        "node_frequencies": filtered_node_frequencies,
        "edge_frequencies": filtered_edge_frequencies,
        "element_type_frequencies": relationship_type_frequencies,
        "total_interactions": interaction_count,
        "interactions_in_window": interactions_in_window,
        "time_window_days": time_window.days,
        "last_processed_timestamp": current_time.isoformat(),
        "cutoff_timestamp": cutoff_time.isoformat(),
    }
async def _set_frequencies_via_cypher(
    graph_adapter: GraphDBInterface, node_frequencies: Dict[str, Any], last_updated
) -> tuple:
    """Neo4j path: per-node Cypher SET (creates the properties on the fly)."""
    nodes_updated = 0
    nodes_failed = 0
    for node_id, frequency in node_frequencies.items():
        try:
            query = """
            MATCH (n)
            WHERE n.id = $node_id
            SET n.frequency_weight = $frequency,
                n.frequency_updated_at = $updated_at
            RETURN n.id as id
            """
            result = await graph_adapter.query(
                query,
                params={
                    "node_id": node_id,
                    "frequency": frequency,
                    "updated_at": last_updated,
                },
            )
            if result and len(result) > 0:
                nodes_updated += 1
            else:
                logger.warning(f"Node {node_id} not found or not updated")
                nodes_failed += 1
        except Exception as e:
            logger.error(f"Error updating node {node_id}: {e}")
            nodes_failed += 1
    return nodes_updated, nodes_failed


async def _set_frequencies_via_kuzu(
    graph_adapter: GraphDBInterface, node_frequencies: Dict[str, Any], last_updated
) -> tuple:
    """Kuzu path: re-add each node dict with the frequency properties merged in."""
    nodes_updated = 0
    nodes_failed = 0
    for node_id, frequency in node_frequencies.items():
        try:
            existing_node_dict = await graph_adapter.get_node(node_id)
            if not existing_node_dict:
                logger.warning(f"Node {node_id} not found in graph")
                nodes_failed += 1
                continue
            existing_node_dict["frequency_weight"] = frequency
            existing_node_dict["frequency_updated_at"] = last_updated
            try:
                # Kuzu's add_node may accept the raw dict directly.
                await graph_adapter.add_node(existing_node_dict)
                nodes_updated += 1
            except Exception as dict_error:
                # Fall back to wrapping the dict in a Node object.
                logger.debug(f"Dict add failed, trying Node object: {dict_error}")
                try:
                    from cognee.infrastructure.engine import Node

                    try:
                        # Pattern 1: properties-only constructor.
                        node_obj = Node(existing_node_dict)
                    except Exception:
                        # Pattern 2: explicit type plus keyword properties.
                        node_obj = Node(
                            type=existing_node_dict.get("type", "Unknown"),
                            **existing_node_dict,
                        )
                    await graph_adapter.add_node(node_obj)
                    nodes_updated += 1
                except Exception as node_error:
                    logger.error(f"Both dict and Node object failed: {node_error}")
                    nodes_failed += 1
        except Exception as e:
            logger.error(f"Error updating node {node_id}: {e}")
            nodes_failed += 1
    return nodes_updated, nodes_failed


async def _set_frequencies_via_properties(
    graph_adapter: GraphDBInterface, node_frequencies: Dict[str, Any], last_updated
) -> tuple:
    """Generic path: get_node_by_id + update_node_properties."""
    nodes_updated = 0
    nodes_failed = 0
    for node_id, frequency in node_frequencies.items():
        try:
            node_data = await graph_adapter.get_node_by_id(node_id)
            if node_data:
                # Node data may come back as a dict or an object with .properties.
                if isinstance(node_data, dict):
                    properties = node_data.get("properties", {})
                else:
                    properties = getattr(node_data, "properties", {}) or {}
                properties["frequency_weight"] = frequency
                properties["frequency_updated_at"] = last_updated
                await graph_adapter.update_node_properties(node_id, properties)
                nodes_updated += 1
            else:
                logger.warning(f"Node {node_id} not found in graph")
                nodes_failed += 1
        except Exception as e:
            logger.error(f"Error updating node {node_id}: {e}")
            nodes_failed += 1
    return nodes_updated, nodes_failed


async def add_frequency_weights(
    graph_adapter: GraphDBInterface, usage_frequencies: Dict[str, Any]
) -> None:
    """
    Add frequency weights to graph nodes and edges using the graph adapter.

    Writes frequency_weight properties back to the graph for use in:
    - Ranking frequently referenced entities higher during retrieval
    - Adjusting scoring for completion strategies
    - Exposing usage metrics in dashboards or audits

    Backend selection: direct Cypher for Neo4j, get_node/add_node for Kuzu,
    and get_node_by_id/update_node_properties for any other adapter that
    supports it. (A verbatim duplicate of the generic update loop that ran a
    second time after the adapter dispatch has been removed — it double-wrote
    every node on capable adapters and raised AttributeError on others.)

    :param graph_adapter: Graph database adapter interface
    :param usage_frequencies: Calculated usage frequencies from extract_usage_frequency
    """
    node_frequencies = usage_frequencies.get("node_frequencies", {})
    edge_frequencies = usage_frequencies.get("edge_frequencies", {})
    last_updated = usage_frequencies.get("last_processed_timestamp")

    logger.info(f"Adding frequency weights to {len(node_frequencies)} nodes")

    adapter_type = type(graph_adapter).__name__
    logger.info(f"Using adapter: {adapter_type}")

    # Capability probing — adapters expose different update surfaces.
    use_neo4j_cypher = adapter_type == "Neo4jAdapter" and hasattr(graph_adapter, "query")
    use_kuzu = (
        adapter_type == "KuzuAdapter"
        and hasattr(graph_adapter, "query")
        and hasattr(graph_adapter, "get_node")
        and hasattr(graph_adapter, "add_node")
    )
    use_get_update = hasattr(graph_adapter, "get_node_by_id") and hasattr(
        graph_adapter, "update_node_properties"
    )

    if use_neo4j_cypher:
        try:
            logger.info("Using Neo4j Cypher SET method")
            nodes_updated, nodes_failed = await _set_frequencies_via_cypher(
                graph_adapter, node_frequencies, last_updated
            )
            logger.info(f"Node update complete: {nodes_updated} succeeded, {nodes_failed} failed")
        except Exception as e:
            logger.error(f"Neo4j Cypher update failed: {e}")
    elif use_kuzu:
        logger.info("Using Kuzu get_node + add_node method")
        nodes_updated, nodes_failed = await _set_frequencies_via_kuzu(
            graph_adapter, node_frequencies, last_updated
        )
        logger.info(f"Node update complete: {nodes_updated} succeeded, {nodes_failed} failed")
    elif use_get_update:
        logger.info("Using get/update method for adapter")
        nodes_updated, nodes_failed = await _set_frequencies_via_properties(
            graph_adapter, node_frequencies, last_updated
        )
        logger.info(f"Node update complete: {nodes_updated} succeeded, {nodes_failed} failed")
    else:
        logger.error(f"Adapter {adapter_type} does not support required update methods")
        logger.error(
            "Required: either 'query' method or both 'get_node_by_id' and 'update_node_properties'"
        )
        return

    # Update edge frequencies.
    # Note: Edge property updates are backend-specific.
    if edge_frequencies:
        logger.info(f"Processing {len(edge_frequencies)} edge frequency entries")
        edges_updated = 0
        edges_failed = 0
        for edge_key, frequency in edge_frequencies.items():
            try:
                # Parse edge key: "relationship_type:source_id:target_id"
                parts = edge_key.split(":", 2)
                if len(parts) == 3:
                    relationship_type, source_id, target_id = parts
                    if hasattr(graph_adapter, "update_edge_properties"):
                        edge_properties = {
                            "frequency_weight": frequency,
                            "frequency_updated_at": last_updated,
                        }
                        await graph_adapter.update_edge_properties(
                            source_id, target_id, relationship_type, edge_properties
                        )
                        edges_updated += 1
                    else:
                        logger.debug(
                            f"Adapter doesn't support update_edge_properties for "
                            f"{relationship_type} ({source_id} -> {target_id})"
                        )
            except Exception as e:
                logger.error(f"Error updating edge {edge_key}: {e}")
                edges_failed += 1
        if edges_updated > 0:
            logger.info(f"Edge update complete: {edges_updated} succeeded, {edges_failed} failed")
        else:
            logger.info(
                "Edge frequency updates skipped (adapter may not support edge property updates)"
            )

    # Store aggregate statistics as metadata if the adapter supports it.
    if hasattr(graph_adapter, "set_metadata"):
        try:
            metadata = {
                "element_type_frequencies": usage_frequencies.get("element_type_frequencies", {}),
                "total_interactions": usage_frequencies.get("total_interactions", 0),
                "interactions_in_window": usage_frequencies.get("interactions_in_window", 0),
                "last_frequency_update": last_updated,
            }
            await graph_adapter.set_metadata("usage_frequency_stats", metadata)
            logger.info("Stored usage frequency statistics as metadata")
        except Exception as e:
            logger.warning(f"Could not store usage statistics as metadata: {e}")
async def create_usage_frequency_pipeline(
    graph_adapter: GraphDBInterface,
    time_window: timedelta = timedelta(days=7),
    min_interaction_threshold: int = 1,
    batch_size: int = 100,
) -> tuple:
    """
    Build the (extraction_tasks, enrichment_tasks) pair for a memify pipeline.

    Mirrors the feedback-enrichment flow: the extraction stage tallies usage
    frequencies from saved interactions, and the enrichment stage writes the
    resulting weights back through the adapter, so the whole frequency update
    can run end-to-end inside a custom memify pipeline.

    Use case example:
        extraction_tasks, enrichment_tasks = await create_usage_frequency_pipeline(
            graph_adapter=my_adapter,
            time_window=timedelta(days=30),
            min_interaction_threshold=2,
        )
        pipeline = Pipeline(extraction_tasks + enrichment_tasks)
        results = await pipeline.run()

    :param graph_adapter: Graph database adapter
    :param time_window: Time window for counting interactions (default: 7 days)
    :param min_interaction_threshold: Minimum interactions to track (default: 1)
    :param batch_size: Batch size for processing (default: 100)
    :return: Tuple of (extraction_tasks, enrichment_tasks)
    """
    logger.info("Creating usage frequency pipeline")
    logger.info(f"Config: time_window={time_window}, threshold={min_interaction_threshold}")

    extraction_stage = [
        Task(
            extract_usage_frequency,
            time_window=time_window,
            min_interaction_threshold=min_interaction_threshold,
        )
    ]
    enrichment_stage = [
        Task(
            add_frequency_weights,
            graph_adapter=graph_adapter,
            task_config={"batch_size": batch_size},
        )
    ]
    return extraction_stage, enrichment_stage
async def run_usage_frequency_update(
    graph_adapter: GraphDBInterface,
    subgraphs: List[CogneeGraph],
    time_window: timedelta = timedelta(days=7),
    min_interaction_threshold: int = 1,
) -> Dict[str, Any]:
    """
    Run the complete usage-frequency pipeline: extract, then write back.

    Main entry point for refreshing frequency_weight values on graph elements
    from CogneeUserInteraction data created by cognee.search(save_interaction=True).

    Example usage:
        from cognee.tasks.memify.extract_usage_frequency import run_usage_frequency_update

        graph = await get_cognee_graph_with_interactions()
        stats = await run_usage_frequency_update(
            graph_adapter=graph_adapter,
            subgraphs=[graph],
            time_window=timedelta(days=30),   # Last 30 days
            min_interaction_threshold=2,      # At least 2 uses
        )
        print(f"Updated {len(stats['node_frequencies'])} nodes")

    :param graph_adapter: Graph database adapter
    :param subgraphs: List of CogneeGraph instances with interaction data
    :param time_window: Time window for counting interactions
    :param min_interaction_threshold: Minimum interactions to track
    :return: Usage frequency statistics
    """
    logger.info("Starting usage frequency update")
    try:
        # Stage 1: tally per-element usage from interaction nodes/edges.
        frequency_stats = await extract_usage_frequency(
            subgraphs=subgraphs,
            time_window=time_window,
            min_interaction_threshold=min_interaction_threshold,
        )
        # Stage 2: persist the weights through the adapter.
        await add_frequency_weights(
            graph_adapter=graph_adapter, usage_frequencies=frequency_stats
        )
        logger.info("Usage frequency update completed successfully")
        logger.info(
            f"Summary: {frequency_stats['interactions_in_window']} interactions processed, "
            f"{len(frequency_stats['node_frequencies'])} nodes weighted"
        )
        return frequency_stats
    except Exception as e:
        logger.error(f"Error during usage frequency update: {str(e)}")
        raise
async def get_most_frequent_elements(
    graph_adapter: GraphDBInterface, top_n: int = 10, element_type: Optional[str] = None
) -> List[Dict[str, Any]]:
    """
    Retrieve the most frequently accessed graph elements.

    Useful for analytics dashboards and understanding user behavior.
    Currently a stub: returns an empty list until an adapter-specific
    query is implemented.

    :param graph_adapter: Graph database adapter
    :param top_n: Number of top elements to return
    :param element_type: Optional filter by element type
    :return: List of elements with their frequency weights
    """
    logger.info(f"Retrieving top {top_n} most frequent elements")
    # Intended adapter call (pseudocode), pending a real implementation:
    #   graph_adapter.query_nodes_by_property(
    #       property_name="frequency_weight",
    #       order_by="DESC",
    #       limit=top_n,
    #       filters={"type": element_type} if element_type else None,
    #   )
    logger.warning("get_most_frequent_elements needs adapter-specific implementation")
    return []

View file

@ -0,0 +1,67 @@
import os
import pathlib
import pytest
import pytest_asyncio
import cognee
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
# Module-wide skip marker: these end-to-end tests call cognify/search, which
# need real embedding/LLM provider credentials (OpenAI or Azure OpenAI).
skip_without_provider = pytest.mark.skipif(
    not (os.getenv("OPENAI_API_KEY") or os.getenv("AZURE_OPENAI_API_KEY")),
    reason="requires embedding/vector provider credentials",
)
@pytest_asyncio.fixture
async def clean_environment():
    """Point cognee at isolated test storage and prune before and after."""
    repo_root = pathlib.Path(__file__).parents[3]
    system_dir = str(repo_root / ".cognee_system/test_brute_force_triplet_search_e2e")
    data_dir = str(repo_root / ".data_storage/test_brute_force_triplet_search_e2e")

    cognee.config.system_root_directory(system_dir)
    cognee.config.data_root_directory(data_dir)

    # Start from a clean slate.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    yield

    # Best-effort teardown; failures here must not mask test results.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
@skip_without_provider
@pytest.mark.asyncio
async def test_brute_force_triplet_search_end_to_end(clean_environment):
    """Ingest one document, then exercise single-query and batched triplet search."""
    text = """
    Cognee is an open-source AI memory engine that structures data into searchable formats for use with AI agents.
    The company focuses on persistent memory systems using knowledge graphs and vector search.
    It is a Berlin-based startup building infrastructure for context-aware AI applications.
    """
    await cognee.add(text)
    await cognee.cognify()

    # Single-query path: a list of Edge objects (possibly empty).
    single_result = await brute_force_triplet_search(query="What is NLP?", top_k=1)
    assert isinstance(single_result, list)
    for edge in single_result:
        assert isinstance(edge, Edge)

    # Batched path: one result list per query, in order.
    batch_queries = ["What is Cognee?", "What is the company's focus?"]
    batch_result = await brute_force_triplet_search(query_batch=batch_queries, top_k=1)
    assert isinstance(batch_result, list)
    assert len(batch_result) == len(batch_queries)
    for per_query in batch_result:
        assert isinstance(per_query, list)
        for edge in per_query:
            assert isinstance(edge, Edge)

View file

@ -0,0 +1,308 @@
"""
Test Suite: Usage Frequency Tracking
Comprehensive tests for the usage frequency tracking implementation.
Tests cover extraction logic, adapter integration, edge cases, and end-to-end workflows.
Run with:
pytest test_usage_frequency_comprehensive.py -v
Or without pytest:
python test_usage_frequency_comprehensive.py
"""
import asyncio
import unittest
from datetime import datetime, timedelta
from typing import List, Dict
# Mock imports for testing without full Cognee setup.
# COGNEE_AVAILABLE is consulted by each test's setUp so the suite degrades
# to skips (rather than collection errors) when cognee is not installed.
try:
    from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
    from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
    from cognee.tasks.memify.extract_usage_frequency import (
        extract_usage_frequency,
        add_frequency_weights,
        run_usage_frequency_update,
    )
    COGNEE_AVAILABLE = True
except ImportError:
    COGNEE_AVAILABLE = False
    print("⚠ Cognee not fully available - some tests will be skipped")
class TestUsageFrequencyExtraction(unittest.IsolatedAsyncioTestCase):
    """Test the core frequency extraction logic.

    FIX: the base class was ``unittest.TestCase``, which invokes each
    ``async def test_*`` method without awaiting it — the coroutine objects
    were created and discarded, so every assertion in this class was
    silently skipped. ``IsolatedAsyncioTestCase`` runs each async test on a
    dedicated event loop so they actually execute.
    """

    def setUp(self):
        """Skip the whole case when cognee modules are not importable."""
        if not COGNEE_AVAILABLE:
            self.skipTest("Cognee modules not available")

    def create_mock_graph(self, num_interactions: int = 3, num_elements: int = 5):
        """Create a mock graph with interactions and elements.

        Builds ``num_interactions`` CogneeUserInteraction nodes (timestamps
        spaced one hour apart, stored as epoch milliseconds),
        ``num_elements`` DocumentChunk nodes, and two
        ``used_graph_element_to_answer`` edges per interaction.
        """
        graph = CogneeGraph()

        # Create interaction nodes with recent, hour-spaced timestamps.
        current_time = datetime.now()
        for i in range(num_interactions):
            interaction_node = Node(
                id=f"interaction_{i}",
                node_type="CogneeUserInteraction",
                attributes={
                    "type": "CogneeUserInteraction",
                    "query_text": f"Test query {i}",
                    "timestamp": int((current_time - timedelta(hours=i)).timestamp() * 1000),
                },
            )
            graph.add_node(interaction_node)

        # Create graph element nodes.
        for i in range(num_elements):
            element_node = Node(
                id=f"element_{i}",
                node_type="DocumentChunk",
                attributes={"type": "DocumentChunk", "text": f"Element content {i}"},
            )
            graph.add_node(element_node)

        # Create usage edges: each interaction references 2 elements,
        # wrapping around the element list.
        for i in range(num_interactions):
            for j in range(2):
                element_idx = (i + j) % num_elements
                edge = Edge(
                    node1=graph.get_node(f"interaction_{i}"),
                    node2=graph.get_node(f"element_{element_idx}"),
                    edge_type="used_graph_element_to_answer",
                    attributes={"relationship_type": "used_graph_element_to_answer"},
                )
                graph.add_edge(edge)

        return graph

    async def test_basic_frequency_extraction(self):
        """Test basic frequency extraction with simple graph."""
        graph = self.create_mock_graph(num_interactions=3, num_elements=5)
        result = await extract_usage_frequency(
            subgraphs=[graph], time_window=timedelta(days=7), min_interaction_threshold=1
        )
        self.assertIn("node_frequencies", result)
        self.assertIn("total_interactions", result)
        self.assertEqual(result["total_interactions"], 3)
        self.assertGreater(len(result["node_frequencies"]), 0)

    async def test_time_window_filtering(self):
        """Test that time window correctly filters old interactions."""
        graph = CogneeGraph()
        current_time = datetime.now()

        # Add recent interaction (within window).
        recent_node = Node(
            id="recent_interaction",
            node_type="CogneeUserInteraction",
            attributes={
                "type": "CogneeUserInteraction",
                "timestamp": int(current_time.timestamp() * 1000),
            },
        )
        graph.add_node(recent_node)

        # Add old interaction (outside the 7-day window).
        old_node = Node(
            id="old_interaction",
            node_type="CogneeUserInteraction",
            attributes={
                "type": "CogneeUserInteraction",
                "timestamp": int((current_time - timedelta(days=10)).timestamp() * 1000),
            },
        )
        graph.add_node(old_node)

        # One shared element referenced by both interactions.
        element = Node(
            id="element_1", node_type="DocumentChunk", attributes={"type": "DocumentChunk"}
        )
        graph.add_node(element)

        graph.add_edge(
            Edge(
                node1=recent_node,
                node2=element,
                edge_type="used_graph_element_to_answer",
                attributes={"relationship_type": "used_graph_element_to_answer"},
            )
        )
        graph.add_edge(
            Edge(
                node1=old_node,
                node2=element,
                edge_type="used_graph_element_to_answer",
                attributes={"relationship_type": "used_graph_element_to_answer"},
            )
        )

        # Extract with 7-day window: only the recent interaction counts.
        result = await extract_usage_frequency(
            subgraphs=[graph], time_window=timedelta(days=7), min_interaction_threshold=1
        )
        self.assertEqual(result["interactions_in_window"], 1)
        self.assertEqual(result["total_interactions"], 2)

    async def test_threshold_filtering(self):
        """Test that minimum threshold filters low-frequency nodes."""
        graph = self.create_mock_graph(num_interactions=5, num_elements=10)
        result = await extract_usage_frequency(
            subgraphs=[graph], time_window=timedelta(days=7), min_interaction_threshold=3
        )
        # Only nodes with 3+ accesses should survive filtering.
        for node_id, freq in result["node_frequencies"].items():
            self.assertGreaterEqual(freq, 3)

    async def test_element_type_tracking(self):
        """Test that element types are properly tracked."""
        graph = CogneeGraph()
        interaction = Node(
            id="interaction_1",
            node_type="CogneeUserInteraction",
            attributes={
                "type": "CogneeUserInteraction",
                "timestamp": int(datetime.now().timestamp() * 1000),
            },
        )
        graph.add_node(interaction)

        # Elements of two different types.
        chunk = Node(id="chunk_1", node_type="DocumentChunk", attributes={"type": "DocumentChunk"})
        entity = Node(id="entity_1", node_type="Entity", attributes={"type": "Entity"})
        graph.add_node(chunk)
        graph.add_node(entity)

        for element in [chunk, entity]:
            graph.add_edge(
                Edge(
                    node1=interaction,
                    node2=element,
                    edge_type="used_graph_element_to_answer",
                    attributes={"relationship_type": "used_graph_element_to_answer"},
                )
            )

        result = await extract_usage_frequency(subgraphs=[graph], time_window=timedelta(days=7))

        # Both element types should appear in the breakdown.
        self.assertIn("element_type_frequencies", result)
        types = result["element_type_frequencies"]
        self.assertIn("DocumentChunk", types)
        self.assertIn("Entity", types)

    async def test_empty_graph(self):
        """Test handling of empty graph."""
        graph = CogneeGraph()
        result = await extract_usage_frequency(subgraphs=[graph], time_window=timedelta(days=7))
        self.assertEqual(result["total_interactions"], 0)
        self.assertEqual(len(result["node_frequencies"]), 0)

    async def test_no_interactions_in_window(self):
        """Test handling when all interactions are outside time window."""
        graph = CogneeGraph()
        old_time = datetime.now() - timedelta(days=30)
        old_interaction = Node(
            id="old_interaction",
            node_type="CogneeUserInteraction",
            attributes={
                "type": "CogneeUserInteraction",
                "timestamp": int(old_time.timestamp() * 1000),
            },
        )
        graph.add_node(old_interaction)

        result = await extract_usage_frequency(subgraphs=[graph], time_window=timedelta(days=7))
        self.assertEqual(result["interactions_in_window"], 0)
        self.assertEqual(result["total_interactions"], 1)
class TestIntegration(unittest.IsolatedAsyncioTestCase):
    """Integration tests for the complete workflow.

    FIX: base class changed from ``unittest.TestCase`` to
    ``IsolatedAsyncioTestCase`` so async test methods are actually awaited
    instead of being created and discarded as unawaited coroutines.
    """

    def setUp(self):
        """Skip the whole case when cognee modules are not importable."""
        if not COGNEE_AVAILABLE:
            self.skipTest("Cognee modules not available")

    async def test_end_to_end_workflow(self):
        """Test the complete end-to-end frequency tracking workflow."""
        # Requires a full Cognee setup with a database; covered by
        # example_usage_frequency_e2e.py instead of this unit suite.
        self.skipTest("E2E test - run example_usage_frequency_e2e.py instead")
# ============================================================================
# Test Runner
# ============================================================================
def run_async_test(test_func):
    """Drive an async test function to completion on a fresh event loop."""
    coroutine = test_func()
    asyncio.run(coroutine)
def main():
    """Run all tests and return a process exit code (0 on success)."""
    if not COGNEE_AVAILABLE:
        print("⚠ Cognee not available - skipping tests")
        print("Install with: pip install cognee[neo4j]")
        return

    banner = "=" * 80
    print(banner)
    print("Running Usage Frequency Tests")
    print(banner)
    print()

    # Assemble the suite from both test cases.
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    for case in (TestUsageFrequencyExtraction, TestIntegration):
        suite.addTests(loader.loadTestsFromTestCase(case))

    outcome = unittest.TextTestRunner(verbosity=2).run(suite)

    # Summary block.
    print()
    print(banner)
    print("Test Summary")
    print(banner)
    print(f"Tests run: {outcome.testsRun}")
    print(f"Successes: {outcome.testsRun - len(outcome.failures) - len(outcome.errors)}")
    print(f"Failures: {len(outcome.failures)}")
    print(f"Errors: {len(outcome.errors)}")
    print(f"Skipped: {len(outcome.skipped)}")

    return 0 if outcome.wasSuccessful() else 1


if __name__ == "__main__":
    exit(main())

View file

@ -149,7 +149,9 @@ async def e2e_state():
vector_engine = get_vector_engine()
collection = await vector_engine.search(
collection_name="Triplet_text", query_text="Test", limit=None
collection_name="Triplet_text",
query_text="Test",
limit=None,
)
# --- Retriever contexts ---
@ -188,57 +190,70 @@ async def e2e_state():
query_type=SearchType.GRAPH_COMPLETION,
query_text="Where is germany located, next to which country?",
save_interaction=True,
verbose=True,
)
completion_cot = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION_COT,
query_text="What is the country next to germany??",
save_interaction=True,
verbose=True,
)
completion_ext = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION,
query_text="What is the name of the country next to germany",
save_interaction=True,
verbose=True,
)
await cognee.search(
query_type=SearchType.FEEDBACK, query_text="This was not the best answer", last_k=1
query_type=SearchType.FEEDBACK,
query_text="This was not the best answer",
last_k=1,
verbose=True,
)
completion_sum = await cognee.search(
query_type=SearchType.GRAPH_SUMMARY_COMPLETION,
query_text="Next to which country is Germany located?",
save_interaction=True,
verbose=True,
)
completion_triplet = await cognee.search(
query_type=SearchType.TRIPLET_COMPLETION,
query_text="Next to which country is Germany located?",
save_interaction=True,
verbose=True,
)
completion_chunks = await cognee.search(
query_type=SearchType.CHUNKS,
query_text="Germany",
save_interaction=False,
verbose=True,
)
completion_summaries = await cognee.search(
query_type=SearchType.SUMMARIES,
query_text="Germany",
save_interaction=False,
verbose=True,
)
completion_rag = await cognee.search(
query_type=SearchType.RAG_COMPLETION,
query_text="Next to which country is Germany located?",
save_interaction=False,
verbose=True,
)
completion_temporal = await cognee.search(
query_type=SearchType.TEMPORAL,
query_text="Next to which country is Germany located?",
save_interaction=False,
verbose=True,
)
await cognee.search(
query_type=SearchType.FEEDBACK,
query_text="This answer was great",
last_k=1,
verbose=True,
)
# Snapshot after all E2E operations above (used by assertion-only tests).

View file

@ -718,3 +718,49 @@ async def test_calculate_top_triplet_importances_raises_on_missing_attribute(set
with pytest.raises(ValueError):
await graph.calculate_top_triplet_importances(k=1, query_list_length=1)
def test_normalize_query_distance_lists_flat_list_single_query(setup_graph):
    """A flat result list is wrapped into a single-entry list-of-lists in single-query mode."""
    scored_hits = [MockScoredResult("node1", 0.95), MockScoredResult("node2", 0.87)]
    normalized = setup_graph._normalize_query_distance_lists(
        scored_hits, query_list_length=None, name="test"
    )
    # Exactly one inner list, containing the original flat results.
    assert normalized == [scored_hits]
def test_normalize_query_distance_lists_nested_list_batch_mode(setup_graph):
    """An already-nested list is passed through untouched when its length matches."""
    per_query_hits = [
        [MockScoredResult("node1", 0.95)],
        [MockScoredResult("node2", 0.87)],
    ]
    normalized = setup_graph._normalize_query_distance_lists(
        per_query_hits, query_list_length=2, name="test"
    )
    assert normalized == per_query_hits
def test_normalize_query_distance_lists_raises_on_length_mismatch(setup_graph):
    """A nested list whose length disagrees with query_list_length raises ValueError."""
    per_query_hits = [
        [MockScoredResult("node1", 0.95)],
        [MockScoredResult("node2", 0.87)],
    ]
    expected_message = "test has 2 query lists, but query_list_length is 3"
    with pytest.raises(ValueError, match=expected_message):
        setup_graph._normalize_query_distance_lists(
            per_query_hits, query_list_length=3, name="test"
        )
def test_normalize_query_distance_lists_empty_list(setup_graph):
    """Normalizing an empty input yields an empty output."""
    normalized = setup_graph._normalize_query_distance_lists(
        [], query_list_length=None, name="test"
    )
    assert normalized == []

View file

@ -30,7 +30,7 @@ async def test_brute_force_triplet_search_empty_query():
@pytest.mark.asyncio
async def test_brute_force_triplet_search_none_query():
"""Test that None query raises ValueError."""
with pytest.raises(ValueError, match="The query must be a non-empty string."):
with pytest.raises(ValueError, match="Must provide either 'query' or 'query_batch'."):
await brute_force_triplet_search(query=None)
@ -57,7 +57,7 @@ async def test_brute_force_triplet_search_wide_search_limit_global_search():
mock_vector_engine.search = AsyncMock(return_value=[])
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
return_value=mock_vector_engine,
):
await brute_force_triplet_search(
@ -79,7 +79,7 @@ async def test_brute_force_triplet_search_wide_search_limit_filtered_search():
mock_vector_engine.search = AsyncMock(return_value=[])
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
return_value=mock_vector_engine,
):
await brute_force_triplet_search(
@ -101,7 +101,7 @@ async def test_brute_force_triplet_search_wide_search_default():
mock_vector_engine.search = AsyncMock(return_value=[])
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
return_value=mock_vector_engine,
):
await brute_force_triplet_search(query="test", node_name=None)
@ -119,7 +119,7 @@ async def test_brute_force_triplet_search_default_collections():
mock_vector_engine.search = AsyncMock(return_value=[])
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
return_value=mock_vector_engine,
):
await brute_force_triplet_search(query="test")
@ -149,7 +149,7 @@ async def test_brute_force_triplet_search_custom_collections():
custom_collections = ["CustomCol1", "CustomCol2"]
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
return_value=mock_vector_engine,
):
await brute_force_triplet_search(query="test", collections=custom_collections)
@ -171,7 +171,7 @@ async def test_brute_force_triplet_search_always_includes_edge_collection():
collections_without_edge = ["Entity_name", "TextSummary_text"]
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
return_value=mock_vector_engine,
):
await brute_force_triplet_search(query="test", collections=collections_without_edge)
@ -194,7 +194,7 @@ async def test_brute_force_triplet_search_all_collections_empty():
mock_vector_engine.search = AsyncMock(return_value=[])
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
return_value=mock_vector_engine,
):
results = await brute_force_triplet_search(query="test")
@ -216,7 +216,7 @@ async def test_brute_force_triplet_search_embeds_query():
mock_vector_engine.search = AsyncMock(return_value=[])
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
return_value=mock_vector_engine,
):
await brute_force_triplet_search(query=query_text)
@ -249,7 +249,7 @@ async def test_brute_force_triplet_search_extracts_node_ids_global_search():
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
@ -279,7 +279,7 @@ async def test_brute_force_triplet_search_reuses_provided_fragment():
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
@ -311,7 +311,7 @@ async def test_brute_force_triplet_search_creates_fragment_when_not_provided():
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
@ -340,7 +340,7 @@ async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
@ -351,7 +351,9 @@ async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation
custom_top_k = 15
await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"])
mock_fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k)
mock_fragment.calculate_top_triplet_importances.assert_called_once_with(
k=custom_top_k, query_list_length=None
)
@pytest.mark.asyncio
@ -430,7 +432,7 @@ async def test_brute_force_triplet_search_deduplicates_node_ids():
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
@ -471,7 +473,7 @@ async def test_brute_force_triplet_search_excludes_edge_collection():
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
@ -523,7 +525,7 @@ async def test_brute_force_triplet_search_skips_nodes_without_ids():
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
@ -564,7 +566,7 @@ async def test_brute_force_triplet_search_handles_tuple_results():
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
@ -606,7 +608,7 @@ async def test_brute_force_triplet_search_mixed_empty_collections():
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
@ -689,7 +691,7 @@ async def test_brute_force_triplet_search_vector_engine_init_error():
"""Test brute_force_triplet_search handles vector engine initialization error (lines 145-147)."""
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine"
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine"
) as mock_get_vector_engine,
):
mock_get_vector_engine.side_effect = Exception("Initialization error")
@ -716,7 +718,7 @@ async def test_brute_force_triplet_search_collection_not_found_error():
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
@ -743,7 +745,7 @@ async def test_brute_force_triplet_search_generic_exception():
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
return_value=mock_vector_engine,
),
):
@ -769,7 +771,7 @@ async def test_brute_force_triplet_search_with_node_name_sets_relevant_ids_to_no
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
@ -804,7 +806,7 @@ async def test_brute_force_triplet_search_collection_not_found_at_top_level():
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
@ -815,3 +817,237 @@ async def test_brute_force_triplet_search_collection_not_found_at_top_level():
result = await brute_force_triplet_search(query="test query")
assert result == []
@pytest.mark.asyncio
async def test_brute_force_triplet_search_single_query_regression():
    """Test that single-query mode maintains legacy behavior (flat list, ID filtering)."""
    # Engine stub: embedding + search return canned values so the code under
    # test exercises the single-query (non-batch) path end to end.
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("node1", 0.95)])
    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )
    with (
        patch(
            "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment,
    ):
        result = await brute_force_triplet_search(
            query="q1", query_batch=None, wide_search_top_k=10, node_name=None
        )
    # Legacy contract: a flat list of triplets, never a per-query list-of-lists.
    assert isinstance(result, list)
    assert not (result and isinstance(result[0], list))
    # Single-query mode must still pre-filter the memory fragment by the node
    # IDs discovered via vector search (batch mode skips this filtering).
    mock_get_fragment.assert_called_once()
    call_kwargs = mock_get_fragment.call_args[1]
    assert call_kwargs["relevant_ids_to_filter"] is not None
@pytest.mark.asyncio
async def test_brute_force_triplet_search_batch_wiring_happy_path():
    """Test that batch mode returns list-of-lists and skips ID filtering."""
    # batch_search yields one result list per query in the two-query batch.
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.batch_search = AsyncMock(
        return_value=[
            [MockScoredResult("node1", 0.95)],
            [MockScoredResult("node2", 0.87)],
        ]
    )
    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[[], []]),
    )
    with (
        patch(
            "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment,
    ):
        result = await brute_force_triplet_search(query_batch=["q1", "q2"])
    # Batch contract: one inner result list per input query.
    assert isinstance(result, list)
    assert len(result) == 2
    assert isinstance(result[0], list)
    assert isinstance(result[1], list)
    # Batch mode must NOT pre-filter the memory fragment by node IDs.
    mock_get_fragment.assert_called_once()
    call_kwargs = mock_get_fragment.call_args[1]
    assert call_kwargs["relevant_ids_to_filter"] is None
@pytest.mark.asyncio
async def test_brute_force_triplet_search_shape_propagation_to_graph():
    """Test that query_list_length is passed through to graph mapping methods."""
    # Two queries in the batch -> batch_search yields two per-query result lists.
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.batch_search = AsyncMock(
        return_value=[
            [MockScoredResult("node1", 0.95)],
            [MockScoredResult("node2", 0.87)],
        ]
    )
    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[[], []]),
    )
    with (
        patch(
            "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ),
    ):
        await brute_force_triplet_search(query_batch=["q1", "q2"])
    # Every downstream graph call must receive the batch size so per-query
    # score lists stay aligned (query_list_length == len(query_batch) == 2).
    mock_fragment.map_vector_distances_to_graph_nodes.assert_called_once()
    node_call_kwargs = mock_fragment.map_vector_distances_to_graph_nodes.call_args[1]
    assert "query_list_length" in node_call_kwargs
    assert node_call_kwargs["query_list_length"] == 2
    mock_fragment.map_vector_distances_to_graph_edges.assert_called_once()
    edge_call_kwargs = mock_fragment.map_vector_distances_to_graph_edges.call_args[1]
    assert "query_list_length" in edge_call_kwargs
    assert edge_call_kwargs["query_list_length"] == 2
    mock_fragment.calculate_top_triplet_importances.assert_called_once()
    importance_call_kwargs = mock_fragment.calculate_top_triplet_importances.call_args[1]
    assert "query_list_length" in importance_call_kwargs
    assert importance_call_kwargs["query_list_length"] == 2
@pytest.mark.asyncio
async def test_brute_force_triplet_search_batch_path_comprehensive():
    """Test batch mode: returns list-of-lists, skips ID filtering, passes None for wide_search_limit."""
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()

    # Per-collection canned results: distinct node and edge hits per query,
    # and [[], []] for any collection not explicitly handled.
    def batch_search_side_effect(*args, **kwargs):
        collection_name = kwargs.get("collection_name")
        if collection_name == "Entity_name":
            return [
                [MockScoredResult("node1", 0.95)],
                [MockScoredResult("node2", 0.87)],
            ]
        elif collection_name == "EdgeType_relationship_name":
            return [
                [MockScoredResult("edge1", 0.92)],
                [MockScoredResult("edge2", 0.88)],
            ]
        return [[], []]

    mock_vector_engine.batch_search = AsyncMock(side_effect=batch_search_side_effect)
    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[[], []]),
    )
    with (
        patch(
            "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment,
    ):
        result = await brute_force_triplet_search(
            query_batch=["q1", "q2"], collections=["Entity_name", "EdgeType_relationship_name"]
        )
    # Batch contract: one inner result list per input query.
    assert isinstance(result, list)
    assert len(result) == 2
    assert isinstance(result[0], list)
    assert isinstance(result[1], list)
    # No per-node-ID filtering of the memory fragment in batch mode.
    mock_get_fragment.assert_called_once()
    fragment_call_kwargs = mock_get_fragment.call_args[1]
    assert fragment_call_kwargs["relevant_ids_to_filter"] is None
    # Every batch_search call must be issued without a wide-search limit.
    batch_search_calls = mock_vector_engine.batch_search.call_args_list
    assert len(batch_search_calls) > 0
    for call in batch_search_calls:
        assert call[1]["limit"] is None
@pytest.mark.asyncio
async def test_brute_force_triplet_search_batch_error_fallback():
    """A missing collection in batch mode yields one empty result list per query."""
    engine_stub = AsyncMock()
    engine_stub.embedding_engine = AsyncMock()
    # Every batch lookup fails with a missing-collection error.
    engine_stub.batch_search = AsyncMock(
        side_effect=CollectionNotFoundError("Collection not found")
    )
    patch_target = "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine"
    with patch(patch_target, return_value=engine_stub):
        batch_results = await brute_force_triplet_search(query_batch=["q1", "q2"])
    # Fallback shape mirrors the two-query batch.
    assert batch_results == [[], []]
    assert len(batch_results) == 2
@pytest.mark.asyncio
async def test_cognee_graph_mapping_batch_shapes():
    """Test that CogneeGraph mapping methods accept list-of-lists with query_list_length set."""
    from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge

    # Minimal graph: two nodes joined by one edge.
    graph = CogneeGraph()
    node1 = Node("node1", {"name": "Node1"})
    node2 = Node("node2", {"name": "Node2"})
    graph.add_node(node1)
    graph.add_node(node2)
    edge = Edge(node1, node2, attributes={"edge_text": "relates_to"})
    graph.add_edge(edge)
    # Batch of two queries: query 0 hits node1 only, query 1 hits node2 only.
    node_distances_batch = {
        "Entity_name": [
            [MockScoredResult("node1", 0.95)],
            [MockScoredResult("node2", 0.87)],
        ]
    }
    edge_distances_batch = [
        [MockScoredResult("edge1", 0.92, payload={"text": "relates_to"})],
        [MockScoredResult("edge2", 0.88, payload={"text": "relates_to"})],
    ]
    await graph.map_vector_distances_to_graph_nodes(
        node_distances=node_distances_batch, query_list_length=2
    )
    await graph.map_vector_distances_to_graph_edges(
        edge_distances=edge_distances_batch, query_list_length=2
    )
    # Each node carries one distance per query; 3.5 presumably is the graph's
    # default distance for queries that did not hit the node -- TODO confirm
    # against CogneeGraph's mapping implementation.
    assert node1.attributes.get("vector_distance") == [0.95, 3.5]
    assert node2.attributes.get("vector_distance") == [3.5, 0.87]
    assert edge.attributes.get("vector_distance") == [0.92, 0.88]

View file

@ -0,0 +1,273 @@
import pytest
from unittest.mock import AsyncMock
from cognee.modules.retrieval.utils.node_edge_vector_search import NodeEdgeVectorSearch
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
class MockScoredResult:
    """Minimal stand-in for a scored vector-search hit (id, score, payload)."""

    def __init__(self, id, score, payload=None):
        self.id = id
        self.score = score
        # Falsy/omitted payloads normalize to a fresh empty dict, mirroring
        # real scored results that always expose a mapping payload.
        self.payload = payload or {}
@pytest.mark.asyncio
async def test_node_edge_vector_search_single_query_shape():
    """Test that single query mode produces flat lists (not list-of-lists)."""
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    node_results = [MockScoredResult("node1", 0.95), MockScoredResult("node2", 0.87)]
    edge_results = [MockScoredResult("edge1", 0.92)]

    # Route edge-collection lookups to edge results, everything else to nodes.
    def search_side_effect(*args, **kwargs):
        collection_name = kwargs.get("collection_name")
        if collection_name == "EdgeType_relationship_name":
            return edge_results
        return node_results

    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
    vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
    collections = ["Entity_name", "EdgeType_relationship_name"]
    await vector_search.embed_and_retrieve_distances(
        query="test query", query_batch=None, collections=collections, wide_search_limit=10
    )
    # Single-query mode: query_list_length stays None, distances stay flat.
    assert vector_search.query_list_length is None
    assert vector_search.edge_distances == edge_results
    assert vector_search.node_distances["Entity_name"] == node_results
    # The raw query is embedded exactly once, wrapped in a one-element batch.
    mock_vector_engine.embedding_engine.embed_text.assert_called_once_with(["test query"])
@pytest.mark.asyncio
async def test_node_edge_vector_search_batch_query_shape_and_empties():
    """Test that batch query mode produces list-of-lists with correct length and handles empty collections."""
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    query_batch = ["query a", "query b"]
    node_results_query_a = [MockScoredResult("node1", 0.95)]
    node_results_query_b = [MockScoredResult("node2", 0.87)]
    edge_results_query_a = [MockScoredResult("edge1", 0.92)]
    edge_results_query_b = []

    # Per-collection behavior: known collections return per-query results,
    # "MissingCollection" raises, everything else comes back empty.
    def batch_search_side_effect(*args, **kwargs):
        collection_name = kwargs.get("collection_name")
        if collection_name == "EdgeType_relationship_name":
            return [edge_results_query_a, edge_results_query_b]
        elif collection_name == "Entity_name":
            return [node_results_query_a, node_results_query_b]
        elif collection_name == "MissingCollection":
            raise CollectionNotFoundError("Collection not found")
        return [[], []]

    mock_vector_engine.batch_search = AsyncMock(side_effect=batch_search_side_effect)
    vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
    collections = [
        "Entity_name",
        "EdgeType_relationship_name",
        "MissingCollection",
        "EmptyCollection",
    ]
    await vector_search.embed_and_retrieve_distances(
        query=None, query_batch=query_batch, collections=collections, wide_search_limit=None
    )
    # Batch mode: query_list_length records the batch size, and every
    # collection's distances become a list-of-lists of that length --
    # including missing and empty collections, which degrade to [[], []].
    assert vector_search.query_list_length == 2
    assert len(vector_search.edge_distances) == 2
    assert vector_search.edge_distances[0] == edge_results_query_a
    assert vector_search.edge_distances[1] == edge_results_query_b
    assert len(vector_search.node_distances["Entity_name"]) == 2
    assert vector_search.node_distances["Entity_name"][0] == node_results_query_a
    assert vector_search.node_distances["Entity_name"][1] == node_results_query_b
    assert len(vector_search.node_distances["MissingCollection"]) == 2
    assert vector_search.node_distances["MissingCollection"] == [[], []]
    assert len(vector_search.node_distances["EmptyCollection"]) == 2
    assert vector_search.node_distances["EmptyCollection"] == [[], []]
    # Batch mode uses batch_search directly; no per-query embed_text calls.
    mock_vector_engine.embedding_engine.embed_text.assert_not_called()
@pytest.mark.asyncio
async def test_node_edge_vector_search_input_validation_both_provided():
    """Supplying query and query_batch together must be rejected with ValueError."""
    search = NodeEdgeVectorSearch(vector_engine=AsyncMock())
    with pytest.raises(ValueError, match="Cannot provide both 'query' and 'query_batch'"):
        await search.embed_and_retrieve_distances(
            query="test",
            query_batch=["test1", "test2"],
            collections=["Entity_name"],
        )
@pytest.mark.asyncio
async def test_node_edge_vector_search_input_validation_neither_provided():
    """Omitting both query and query_batch must be rejected with ValueError."""
    search = NodeEdgeVectorSearch(vector_engine=AsyncMock())
    with pytest.raises(ValueError, match="Must provide either 'query' or 'query_batch'"):
        await search.embed_and_retrieve_distances(
            query=None,
            query_batch=None,
            collections=["Entity_name"],
        )
@pytest.mark.asyncio
async def test_node_edge_vector_search_extract_relevant_node_ids_single_query():
    """In single-query mode every distinct node ID across collections is returned."""
    search = NodeEdgeVectorSearch(vector_engine=AsyncMock())
    search.query_list_length = None
    # "node1" appears in both collections and must only count once.
    search.node_distances = {
        "Entity_name": [MockScoredResult("node1", 0.95), MockScoredResult("node2", 0.87)],
        "TextSummary_text": [MockScoredResult("node1", 0.90), MockScoredResult("node3", 0.92)],
    }
    assert set(search.extract_relevant_node_ids()) == {"node1", "node2", "node3"}
@pytest.mark.asyncio
async def test_node_edge_vector_search_extract_relevant_node_ids_batch():
    """Batch mode performs no ID filtering, so no node IDs are extracted."""
    search = NodeEdgeVectorSearch(vector_engine=AsyncMock())
    search.query_list_length = 2
    search.node_distances = {
        "Entity_name": [
            [MockScoredResult("node1", 0.95)],
            [MockScoredResult("node2", 0.87)],
        ],
    }
    assert search.extract_relevant_node_ids() == []
@pytest.mark.asyncio
async def test_node_edge_vector_search_has_results_single_query():
    """has_results is True with any edge or node hits and False when both are empty."""
    search = NodeEdgeVectorSearch(vector_engine=AsyncMock())
    scenarios = [
        # (edge_distances, node_distances, expected)
        ([MockScoredResult("edge1", 0.92)], {}, True),
        ([], {"Entity_name": [MockScoredResult("node1", 0.95)]}, True),
        ([], {}, False),
    ]
    for edge_hits, node_hits, expected in scenarios:
        search.edge_distances = edge_hits
        search.node_distances = node_hits
        assert search.has_results() is expected
@pytest.mark.asyncio
async def test_node_edge_vector_search_has_results_batch():
    """In batch mode, any non-empty inner list counts as a result."""
    search = NodeEdgeVectorSearch(vector_engine=AsyncMock())
    search.query_list_length = 2
    scenarios = [
        # (edge_distances, node_distances, expected)
        ([[MockScoredResult("edge1", 0.92)], []], {}, True),
        ([[], []], {"Entity_name": [[MockScoredResult("node1", 0.95)], []]}, True),
        ([[], []], {"Entity_name": [[], []]}, False),
    ]
    for edge_hits, node_hits, expected in scenarios:
        search.edge_distances = edge_hits
        search.node_distances = node_hits
        assert search.has_results() is expected
@pytest.mark.asyncio
async def test_node_edge_vector_search_single_query_collection_not_found():
    """A missing collection in single-query mode degrades to an empty result list."""
    engine_stub = AsyncMock()
    engine_stub.embedding_engine = AsyncMock()
    engine_stub.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    # Every search call fails with a missing-collection error.
    engine_stub.search = AsyncMock(side_effect=CollectionNotFoundError("Collection not found"))
    search = NodeEdgeVectorSearch(vector_engine=engine_stub)
    await search.embed_and_retrieve_distances(
        query="test query",
        query_batch=None,
        collections=["MissingCollection"],
        wide_search_limit=10,
    )
    assert search.node_distances["MissingCollection"] == []
@pytest.mark.asyncio
async def test_node_edge_vector_search_missing_collections_single_query():
    """Test that missing collections in single-query mode are handled gracefully with empty lists."""
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    node_result = MockScoredResult("node1", 0.95)

    # One collection has a hit, one raises, and the default branch is empty.
    def search_side_effect(*args, **kwargs):
        collection_name = kwargs.get("collection_name")
        if collection_name == "Entity_name":
            return [node_result]
        elif collection_name == "MissingCollection":
            raise CollectionNotFoundError("Collection not found")
        return []

    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
    vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
    collections = ["Entity_name", "MissingCollection", "EmptyCollection"]
    await vector_search.embed_and_retrieve_distances(
        query="test query", query_batch=None, collections=collections, wide_search_limit=10
    )
    # Healthy collection keeps its results; missing/empty ones degrade to [].
    assert len(vector_search.node_distances["Entity_name"]) == 1
    assert vector_search.node_distances["Entity_name"][0].id == "node1"
    assert vector_search.node_distances["Entity_name"][0].score == 0.95
    assert vector_search.node_distances["MissingCollection"] == []
    assert vector_search.node_distances["EmptyCollection"] == []
@pytest.mark.asyncio
async def test_node_edge_vector_search_has_results_batch_nodes_only():
    """Node hits alone are sufficient for has_results in batch mode."""
    search = NodeEdgeVectorSearch(vector_engine=AsyncMock())
    search.query_list_length = 2
    search.edge_distances = [[], []]
    search.node_distances = {
        "Entity_name": [[MockScoredResult("node1", 0.95)], []],
    }
    assert search.has_results() is True
@pytest.mark.asyncio
async def test_node_edge_vector_search_has_results_batch_edges_only():
    """Edge hits alone are sufficient for has_results in batch mode."""
    search = NodeEdgeVectorSearch(vector_engine=AsyncMock())
    search.query_list_length = 2
    search.edge_distances = [[MockScoredResult("edge1", 0.92)], []]
    search.node_distances = {}
    assert search.has_results() is True

View file

@ -129,14 +129,32 @@ async def test_search_access_control_returns_dataset_shaped_dicts(monkeypatch, s
monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
out = await search_mod.search(
out_non_verbose = await search_mod.search(
query_text="q",
query_type=SearchType.CHUNKS,
dataset_ids=[ds.id],
user=user,
verbose=False,
)
assert out == [
assert out_non_verbose == [
{
"search_result": ["r"],
"dataset_id": ds.id,
"dataset_name": "ds1",
"dataset_tenant_id": "t1",
}
]
out_verbose = await search_mod.search(
query_text="q",
query_type=SearchType.CHUNKS,
dataset_ids=[ds.id],
user=user,
verbose=True,
)
assert out_verbose == [
{
"search_result": ["r"],
"dataset_id": ds.id,
@ -166,6 +184,7 @@ async def test_search_access_control_only_context_returns_dataset_shaped_dicts(
dataset_ids=[ds.id],
user=user,
only_context=True,
verbose=True,
)
assert out == [
@ -180,35 +199,7 @@ async def test_search_access_control_only_context_returns_dataset_shaped_dicts(
@pytest.mark.asyncio
async def test_search_access_control_use_combined_context_returns_combined_model(
monkeypatch, search_mod
):
user = _make_user()
ds1 = _make_dataset(name="ds1", tenant_id="t1")
ds2 = _make_dataset(name="ds2", tenant_id="t1")
async def dummy_authorized_search(**_kwargs):
return ("answer", {"k": "v"}, [ds1, ds2])
monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
out = await search_mod.search(
query_text="q",
query_type=SearchType.CHUNKS,
dataset_ids=[ds1.id, ds2.id],
user=user,
use_combined_context=True,
)
assert out.result == "answer"
assert out.context == {"k": "v"}
assert out.graphs == {}
assert [d.id for d in out.datasets] == [ds1.id, ds2.id]
@pytest.mark.asyncio
async def test_authorized_search_non_combined_delegates(monkeypatch, search_mod):
async def test_authorized_search_delegates_to_search_in_datasets_context(monkeypatch, search_mod):
user = _make_user()
ds = _make_dataset(name="ds1")
@ -218,7 +209,6 @@ async def test_authorized_search_non_combined_delegates(monkeypatch, search_mod)
expected = [("r", ["ctx"], [ds])]
async def dummy_search_in_datasets_context(**kwargs):
assert kwargs["use_combined_context"] is False if "use_combined_context" in kwargs else True
return expected
monkeypatch.setattr(
@ -231,104 +221,12 @@ async def test_authorized_search_non_combined_delegates(monkeypatch, search_mod)
query_text="q",
user=user,
dataset_ids=[ds.id],
use_combined_context=False,
only_context=False,
)
assert out == expected
@pytest.mark.asyncio
async def test_authorized_search_use_combined_context_joins_string_context(monkeypatch, search_mod):
    """String contexts from all datasets are newline-joined before completion.

    search_in_datasets_context is stubbed to return one string context per
    dataset; authorized_search must concatenate them into "a\\nb", pass the
    combined text (plus query and session id) to the completion tool, and
    return the completion, combined context, and all datasets.
    """
    user = _make_user()
    ds1 = _make_dataset(name="ds1")
    ds2 = _make_dataset(name="ds2")

    async def dummy_get_authorized_existing_datasets(*_args, **_kwargs):
        return [ds1, ds2]

    async def dummy_search_in_datasets_context(**kwargs):
        # The combined-context path must request context only, not per-dataset completions.
        assert kwargs["only_context"] is True
        return [(None, ["a"], [ds1]), (None, ["b"], [ds2])]

    # Records exactly what the completion tool is invoked with.
    seen = {}

    async def dummy_get_completion(query_text, context, session_id=None):
        seen["query_text"] = query_text
        seen["context"] = context
        seen["session_id"] = session_id
        return ["answer"]

    async def dummy_get_search_type_tools(**_kwargs):
        return [dummy_get_completion, lambda *_a, **_k: None]

    monkeypatch.setattr(
        search_mod, "get_authorized_existing_datasets", dummy_get_authorized_existing_datasets
    )
    monkeypatch.setattr(search_mod, "search_in_datasets_context", dummy_search_in_datasets_context)
    monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools)

    completion, combined_context, datasets = await search_mod.authorized_search(
        query_type=SearchType.CHUNKS,
        query_text="q",
        user=user,
        dataset_ids=[ds1.id, ds2.id],
        use_combined_context=True,
        session_id="s1",
    )

    assert combined_context == "a\nb"
    assert completion == ["answer"]
    assert datasets == [ds1, ds2]
    assert seen == {"query_text": "q", "context": "a\nb", "session_id": "s1"}
@pytest.mark.asyncio
async def test_authorized_search_use_combined_context_keeps_non_string_context(
    monkeypatch, search_mod
):
    """Non-string contexts are concatenated as a list, not string-joined.

    Each dataset contributes an opaque edge object instead of text;
    authorized_search must hand the concatenated *list* to the completion
    tool and return it unchanged as the combined context.
    """
    user = _make_user()
    ds1 = _make_dataset(name="ds1")
    ds2 = _make_dataset(name="ds2")

    class DummyEdge:
        # Stand-in for a graph edge: deliberately not a str, so the
        # newline-joining branch must be skipped.
        pass

    e1, e2 = DummyEdge(), DummyEdge()

    async def dummy_get_authorized_existing_datasets(*_args, **_kwargs):
        return [ds1, ds2]

    async def dummy_search_in_datasets_context(**_kwargs):
        return [(None, [e1], [ds1]), (None, [e2], [ds2])]

    async def dummy_get_completion(query_text, context, session_id=None):
        assert query_text == "q"
        assert context == [e1, e2]
        return ["answer"]

    async def dummy_get_search_type_tools(**_kwargs):
        return [dummy_get_completion]

    monkeypatch.setattr(
        search_mod, "get_authorized_existing_datasets", dummy_get_authorized_existing_datasets
    )
    monkeypatch.setattr(search_mod, "search_in_datasets_context", dummy_search_in_datasets_context)
    monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools)

    completion, combined_context, datasets = await search_mod.authorized_search(
        query_type=SearchType.CHUNKS,
        query_text="q",
        user=user,
        dataset_ids=[ds1.id, ds2.id],
        use_combined_context=True,
    )

    assert combined_context == [e1, e2]
    assert completion == ["answer"]
    assert datasets == [ds1, ds2]
@pytest.mark.asyncio
async def test_search_in_datasets_context_two_tool_context_override_and_is_empty_branches(
monkeypatch, search_mod

View file

@ -90,6 +90,7 @@ async def test_search_access_control_edges_context_produces_graphs_and_context_m
query_type=SearchType.CHUNKS,
dataset_ids=[ds.id],
user=user,
verbose=True,
)
assert out[0]["dataset_name"] == "ds1"
@ -126,6 +127,7 @@ async def test_search_access_control_insights_context_produces_graphs_and_null_r
query_type=SearchType.CHUNKS,
dataset_ids=[ds.id],
user=user,
verbose=True,
)
assert out[0]["graphs"] is not None
@ -150,6 +152,7 @@ async def test_search_access_control_only_context_returns_context_text_map(monke
dataset_ids=[ds.id],
user=user,
only_context=True,
verbose=True,
)
assert out[0]["search_result"] == [{"ds1": "a\nb"}]
@ -172,6 +175,7 @@ async def test_search_access_control_results_edges_become_graph_result(monkeypat
query_type=SearchType.CHUNKS,
dataset_ids=[ds.id],
user=user,
verbose=True,
)
assert isinstance(out[0]["search_result"][0], dict)
@ -179,29 +183,6 @@ async def test_search_access_control_results_edges_become_graph_result(monkeypat
assert "edges" in out[0]["search_result"][0]
@pytest.mark.asyncio
async def test_search_use_combined_context_defaults_empty_datasets(monkeypatch, search_mod):
    """Without dataset ids, the combined result is labelled "all available datasets".

    authorized_search returns no datasets; search() must wrap the string
    context under the synthetic "all available datasets" key and synthesize
    a placeholder dataset with that name.
    """
    user = types.SimpleNamespace(id="u1", tenant_id=None)

    async def dummy_authorized_search(**_kwargs):
        return ("answer", "ctx", [])

    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
    monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)

    out = await search_mod.search(
        query_text="q",
        query_type=SearchType.CHUNKS,
        dataset_ids=None,
        user=user,
        use_combined_context=True,
    )

    assert out.result == "answer"
    assert out.context == {"all available datasets": "ctx"}
    assert out.datasets[0].name == "all available datasets"
@pytest.mark.asyncio
async def test_search_access_control_context_str_branch(monkeypatch, search_mod):
"""Covers prepare_search_result(context is str) through search()."""
@ -219,6 +200,7 @@ async def test_search_access_control_context_str_branch(monkeypatch, search_mod)
query_type=SearchType.CHUNKS,
dataset_ids=[ds.id],
user=user,
verbose=True,
)
assert out[0]["graphs"] is None
@ -242,6 +224,7 @@ async def test_search_access_control_context_empty_list_branch(monkeypatch, sear
query_type=SearchType.CHUNKS,
dataset_ids=[ds.id],
user=user,
verbose=True,
)
assert out[0]["graphs"] is None
@ -265,6 +248,7 @@ async def test_search_access_control_multiple_results_list_branch(monkeypatch, s
query_type=SearchType.CHUNKS,
dataset_ids=[ds.id],
user=user,
verbose=True,
)
assert out[0]["search_result"] == [["r1", "r2"]]
@ -293,4 +277,5 @@ async def test_search_access_control_defaults_empty_datasets(monkeypatch, search
query_type=SearchType.CHUNKS,
dataset_ids=None,
user=user,
verbose=True,
)

View file

@ -20,19 +20,29 @@ echo "HTTP port: $HTTP_PORT"
# smooth redeployments and container restarts while maintaining data integrity.
echo "Running database migrations..."
set +e # Disable exit on error to handle specific migration errors
MIGRATION_OUTPUT=$(alembic upgrade head)
MIGRATION_EXIT_CODE=$?
set -e
if [[ $MIGRATION_EXIT_CODE -ne 0 ]]; then
if [[ "$MIGRATION_OUTPUT" == *"UserAlreadyExists"* ]] || [[ "$MIGRATION_OUTPUT" == *"User default_user@example.com already exists"* ]]; then
echo "Warning: Default user already exists, continuing startup..."
else
echo "Migration failed with unexpected error."
exit 1
fi
fi
echo "Migration failed with unexpected error. Trying to run Cognee without migrations."
echo "Database migrations done."
echo "Initializing database tables..."
python /app/cognee/modules/engine/operations/setup.py
INIT_EXIT_CODE=$?
if [[ $INIT_EXIT_CODE -ne 0 ]]; then
echo "Database initialization failed!"
exit 1
fi
fi
else
echo "Database migrations done."
fi
echo "Starting server..."

View file

@ -1,8 +1,9 @@
import asyncio
import cognee
import os
from pprint import pprint
# By default cognee uses OpenAI's gpt-5-mini LLM model
# Provide your OpenAI LLM API KEY
os.environ["LLM_API_KEY"] = ""
@ -24,13 +25,13 @@ async def cognee_demo():
# Query Cognee for information from provided document
answer = await cognee.search("List me all the important characters in Alice in Wonderland.")
print(answer)
pprint(answer)
answer = await cognee.search("How did Alice end up in Wonderland?")
print(answer)
pprint(answer)
answer = await cognee.search("Tell me about Alice's personality.")
print(answer)
pprint(answer)
# Cognee is an async library, it has to be called in an async context

View file

@ -1,4 +1,5 @@
import asyncio
from pprint import pprint
import cognee
from cognee.api.v1.search import SearchType
@ -187,7 +188,7 @@ async def main(enable_steps):
search_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION, query_text="Who has experience in design tools?"
)
print(search_results)
pprint(search_results)
if __name__ == "__main__":

View file

@ -0,0 +1,482 @@
#!/usr/bin/env python3
"""
End-to-End Example: Usage Frequency Tracking in Cognee
This example demonstrates the complete workflow for tracking and analyzing
how frequently different graph elements are accessed through user searches.
Features demonstrated:
- Setting up a knowledge base
- Running searches with interaction tracking (save_interaction=True)
- Extracting usage frequencies from interaction data
- Applying frequency weights to graph nodes
- Analyzing and visualizing the results
Use cases:
- Ranking search results by popularity
- Identifying "hot topics" in your knowledge base
- Understanding user behavior and interests
- Improving retrieval based on usage patterns
"""
import asyncio
import os
from datetime import timedelta
from typing import List, Dict, Any
from dotenv import load_dotenv
import cognee
from cognee.api.v1.search import SearchType
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.tasks.memify.extract_usage_frequency import run_usage_frequency_update
# Load environment variables
load_dotenv()
# ============================================================================
# STEP 1: Setup and Configuration
# ============================================================================
async def setup_knowledge_base():
    """
    Create a fresh knowledge base with sample content.

    Prunes all existing Cognee data and system metadata, adds four short
    AI/ML articles to the "ai_ml_fundamentals" dataset, and runs cognify()
    to build the knowledge graph.

    In a real application, you would:
    - Load documents from files, databases, or APIs
    - Process larger datasets
    - Organize content by datasets/categories
    """
    print("=" * 80)
    print("STEP 1: Setting up knowledge base")
    print("=" * 80)

    # Reset state for clean demo (optional in production) — wipes both the
    # stored data and system metadata so repeated runs start identical.
    print("\nResetting Cognee state...")
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    print("✓ Reset complete")

    # Sample content: AI/ML educational material
    documents = [
        """
        Machine Learning Fundamentals:
        Machine learning is a subset of artificial intelligence that enables systems
        to learn and improve from experience without being explicitly programmed.
        The three main types are supervised learning, unsupervised learning, and
        reinforcement learning.
        """,
        """
        Neural Networks Explained:
        Neural networks are computing systems inspired by biological neural networks.
        They consist of layers of interconnected nodes (neurons) that process information
        through weighted connections. Deep learning uses neural networks with many layers
        to automatically learn hierarchical representations of data.
        """,
        """
        Natural Language Processing:
        NLP enables computers to understand, interpret, and generate human language.
        Modern NLP uses transformer architectures like BERT and GPT, which have
        revolutionized tasks such as translation, summarization, and question answering.
        """,
        """
        Computer Vision Applications:
        Computer vision allows machines to interpret visual information from the world.
        Convolutional neural networks (CNNs) are particularly effective for image
        recognition, object detection, and image segmentation tasks.
        """,
    ]

    print(f"\nAdding {len(documents)} documents to knowledge base...")
    await cognee.add(documents, dataset_name="ai_ml_fundamentals")
    print("✓ Documents added")

    # Build knowledge graph
    print("\nBuilding knowledge graph (cognify)...")
    await cognee.cognify()
    print("✓ Knowledge graph built")
    print("\n" + "=" * 80)
# ============================================================================
# STEP 2: Simulate User Searches with Interaction Tracking
# ============================================================================
async def simulate_user_searches(queries: List[str]):
    """
    Execute every query against the knowledge base with tracking enabled.

    The essential detail is ``save_interaction=True``: each search stores a
    CogneeUserInteraction node plus used_graph_element_to_answer edges that
    connect the query to the graph elements used in the answer.

    Args:
        queries: Search queries to run, in order.

    Returns:
        Number of searches that finished without raising.
    """
    print("=" * 80)
    print("STEP 2: Simulating user searches with interaction tracking")
    print("=" * 80)

    completed = 0
    total = len(queries)
    for position, query in enumerate(queries, 1):
        print(f"\nSearch {position}/{total}: '{query}'")
        try:
            results = await cognee.search(
                query_type=SearchType.GRAPH_COMPLETION,
                query_text=query,
                save_interaction=True,  # ← THIS IS CRITICAL!
                top_k=5,
            )
        except Exception as e:
            print(f" ✗ Failed: {e}")
        else:
            completed += 1
            # Show snippet of results
            result_preview = str(results)[:100] if results else "No results"
            print(f" ✓ Completed ({result_preview}...)")

    print(f"\n✓ Completed {completed}/{total} searches")
    print("=" * 80)
    return completed
# ============================================================================
# STEP 3: Extract and Apply Usage Frequencies
# ============================================================================
async def extract_and_apply_frequencies(
    time_window_days: int = 7, min_threshold: int = 1
) -> Dict[str, Any]:
    """
    Extract usage frequencies from interactions and apply them to the graph.

    This function:
    1. Retrieves the graph with interaction data
    2. Counts how often each node was accessed
    3. Writes frequency_weight property back to nodes

    Args:
        time_window_days: Only count interactions from last N days
        min_threshold: Minimum accesses to track (filter out rarely used nodes)

    Returns:
        Dictionary with statistics about the frequency update
    """
    print("=" * 80)
    print("STEP 3: Extracting and applying usage frequencies")
    print("=" * 80)

    # Get graph adapter
    graph_engine = await get_graph_engine()

    # Retrieve graph with interactions. Only the properties needed for
    # frequency extraction are projected: interaction metadata (timestamp,
    # query_text) plus any pre-existing frequency_weight values.
    print("\nRetrieving graph from database...")
    graph = CogneeGraph()
    await graph.project_graph_from_db(
        adapter=graph_engine,
        node_properties_to_project=[
            "type",
            "node_type",
            "timestamp",
            "created_at",
            "text",
            "name",
            "query_text",
            "frequency_weight",
        ],
        edge_properties_to_project=["relationship_type", "timestamp"],
        directed=True,
    )
    print(f"✓ Retrieved: {len(graph.nodes)} nodes, {len(graph.edges)} edges")

    # Count interaction nodes; the type/node_type double check presumably
    # covers differing property names across graph adapters — TODO confirm.
    interaction_nodes = [
        n
        for n in graph.nodes.values()
        if n.attributes.get("type") == "CogneeUserInteraction"
        or n.attributes.get("node_type") == "CogneeUserInteraction"
    ]
    print(f"✓ Found {len(interaction_nodes)} interaction nodes")

    # Run frequency extraction and update (writes weights back via the adapter)
    print(f"\nExtracting frequencies (time window: {time_window_days} days)...")
    stats = await run_usage_frequency_update(
        graph_adapter=graph_engine,
        subgraphs=[graph],
        time_window=timedelta(days=time_window_days),
        min_interaction_threshold=min_threshold,
    )

    print("\n✓ Frequency extraction complete!")
    print(
        f" - Interactions processed: {stats['interactions_in_window']}/{stats['total_interactions']}"
    )
    print(f" - Nodes weighted: {len(stats['node_frequencies'])}")
    print(f" - Element types tracked: {stats.get('element_type_frequencies', {})}")
    print("=" * 80)

    return stats
# ============================================================================
# STEP 4: Analyze and Display Results
# ============================================================================
async def analyze_results(stats: Dict[str, Any]):
    """
    Analyze and display the frequency tracking results.

    Shows:
    - Top most frequently accessed nodes
    - Element type distribution
    - Verification that weights were written to database

    Args:
        stats: Statistics from frequency extraction
    """
    print("=" * 80)
    print("STEP 4: Analyzing usage frequency results")
    print("=" * 80)

    # Display top nodes by frequency
    if stats["node_frequencies"]:
        print("\n📊 Top 10 Most Frequently Accessed Elements:")
        print("-" * 80)
        sorted_nodes = sorted(stats["node_frequencies"].items(), key=lambda x: x[1], reverse=True)

        # Re-project the graph (text/name only) so each node id can be
        # resolved to human-readable content for display.
        graph_engine = await get_graph_engine()
        graph = CogneeGraph()
        await graph.project_graph_from_db(
            adapter=graph_engine,
            node_properties_to_project=["type", "text", "name"],
            edge_properties_to_project=[],
            directed=True,
        )

        for i, (node_id, frequency) in enumerate(sorted_nodes[:10], 1):
            node = graph.get_node(node_id)
            if node:
                node_type = node.attributes.get("type", "Unknown")
                text = node.attributes.get("text") or node.attributes.get("name") or ""
                # Truncate long content for terminal display.
                text_preview = text[:60] + "..." if len(text) > 60 else text
                print(f"\n{i}. Frequency: {frequency} accesses")
                print(f" Type: {node_type}")
                print(f" Content: {text_preview}")
            else:
                # Node was weighted but not found in the re-projection; show id only.
                print(f"\n{i}. Frequency: {frequency} accesses")
                print(f" Node ID: {node_id[:50]}...")

    # Display element type distribution
    if stats.get("element_type_frequencies"):
        print("\n\n📈 Element Type Distribution:")
        print("-" * 80)
        type_dist = stats["element_type_frequencies"]
        for elem_type, count in sorted(type_dist.items(), key=lambda x: x[1], reverse=True):
            print(f" {elem_type}: {count} accesses")

    # Verify weights in database (Neo4j only — other adapters are skipped)
    print("\n\n🔍 Verifying weights in database...")
    print("-" * 80)
    graph_engine = await get_graph_engine()
    adapter_type = type(graph_engine).__name__

    if adapter_type == "Neo4jAdapter":
        try:
            result = await graph_engine.query("""
                MATCH (n)
                WHERE n.frequency_weight IS NOT NULL
                RETURN count(n) as weighted_count
            """)
            count = result[0]["weighted_count"] if result else 0
            if count > 0:
                print(f"{count} nodes have frequency_weight in Neo4j database")
                # Show sample
                sample = await graph_engine.query("""
                    MATCH (n)
                    WHERE n.frequency_weight IS NOT NULL
                    RETURN n.frequency_weight as weight, labels(n) as labels
                    ORDER BY n.frequency_weight DESC
                    LIMIT 3
                """)
                print("\nSample weighted nodes:")
                for row in sample:
                    print(f" - Weight: {row['weight']}, Type: {row['labels']}")
            else:
                print("⚠ No nodes with frequency_weight found in database")
        except Exception as e:
            # Verification is best-effort; a query failure should not abort the demo.
            print(f"Could not verify in Neo4j: {e}")
    else:
        print(f"Database verification not implemented for {adapter_type}")

    print("\n" + "=" * 80)
# ============================================================================
# STEP 5: Demonstrate Usage in Retrieval
# ============================================================================
async def demonstrate_retrieval_usage():
    """
    Demonstrate how frequency weights can be used in retrieval.

    Note: This is a conceptual demonstration. To actually use frequency
    weights in ranking, you would need to modify the retrieval/completion
    strategies to incorporate the frequency_weight property.
    """
    print("=" * 80)
    print("STEP 5: How to use frequency weights in retrieval")
    print("=" * 80)
    # Printed verbatim; this step performs no graph or LLM calls.
    print("""
    Frequency weights can be used to improve search results:

    1. RANKING BOOST:
       - Multiply relevance scores by frequency_weight
       - Prioritize frequently accessed nodes in results

    2. COMPLETION STRATEGIES:
       - Adjust triplet importance based on usage
       - Filter out rarely accessed information

    3. ANALYTICS:
       - Track trending topics over time
       - Understand user interests and behavior
       - Identify knowledge gaps (low-frequency nodes)

    4. ADAPTIVE RETRIEVAL:
       - Personalize results based on team usage patterns
       - Surface popular answers faster

    Example Cypher query with frequency boost (Neo4j):

        MATCH (n)
        WHERE n.text CONTAINS $search_term
        RETURN n, n.frequency_weight as boost
        ORDER BY (n.relevance_score * COALESCE(n.frequency_weight, 1)) DESC
        LIMIT 10

    To integrate this into Cognee, you would modify the completion
    strategy to include frequency_weight in the scoring function.
    """)
    print("=" * 80)
# ============================================================================
# MAIN: Run Complete Example
# ============================================================================
def print_banner(title: str) -> None:
    """Print *title* centered inside a simple 80-column ASCII frame.

    Replaces the original hand-rolled banner prints, whose non-ASCII frame
    characters were lost (leaving no-op ``"" + "=" * 78 + ""``
    concatenations) and which were duplicated at the start and end of main.
    """
    edge = "+" + "=" * 78 + "+"
    blank = "|" + " " * 78 + "|"
    print("\n")
    print(edge)
    print(blank)
    print("|" + title.center(78) + "|")
    print(blank)
    print(edge)
    print("\n")


async def main():
    """
    Run the complete end-to-end usage frequency tracking example.

    Orchestrates all five steps: knowledge-base setup, simulated searches
    with interaction tracking, frequency extraction, result analysis, and
    the conceptual retrieval demonstration. Any exception during the steps
    is reported with a traceback instead of propagating.
    """
    print_banner("Usage Frequency Tracking - End-to-End Example")

    # Configuration check (informational; values come from the .env file).
    print("Configuration:")
    print(f" Graph Provider: {os.getenv('GRAPH_DATABASE_PROVIDER')}")
    print(f" Graph Handler: {os.getenv('GRAPH_DATASET_HANDLER')}")
    print(f" LLM Provider: {os.getenv('LLM_PROVIDER')}")

    # Verify LLM key is set — searches cannot run without it.
    if not os.getenv("LLM_API_KEY") or os.getenv("LLM_API_KEY") == "sk-your-key-here":
        print("\n⚠ WARNING: LLM_API_KEY not set in .env file")
        print(" Set your API key to run searches")
        return

    print("\n")

    try:
        # Step 1: Setup
        await setup_knowledge_base()

        # Step 2: Simulate searches
        # Note: Repeat queries increase frequency for those topics
        queries = [
            "What is machine learning?",
            "Explain neural networks",
            "How does deep learning work?",
            "Tell me about neural networks",  # Repeat - increases frequency
            "What are transformers in NLP?",
            "Explain neural networks again",  # Another repeat
            "How does computer vision work?",
            "What is reinforcement learning?",
            "Tell me more about neural networks",  # Third repeat
        ]
        successful_searches = await simulate_user_searches(queries)

        if successful_searches == 0:
            print("⚠ No searches completed - cannot demonstrate frequency tracking")
            return

        # Step 3: Extract frequencies
        stats = await extract_and_apply_frequencies(time_window_days=7, min_threshold=1)

        # Step 4: Analyze results
        await analyze_results(stats)

        # Step 5: Show usage examples
        await demonstrate_retrieval_usage()

        # Summary
        print_banner("Example Complete!")
        print("Summary:")
        print(" ✓ Documents added: 4")
        print(f" ✓ Searches performed: {successful_searches}")
        print(f" ✓ Interactions tracked: {stats['interactions_in_window']}")
        print(f" ✓ Nodes weighted: {len(stats['node_frequencies'])}")
        print("\nNext steps:")
        print(" 1. Open Neo4j Browser (http://localhost:7474) to explore the graph")
        print(" 2. Modify retrieval strategies to use frequency_weight")
        print(" 3. Build analytics dashboards using element_type_frequencies")
        print(" 4. Run periodic frequency updates to track trends over time")
        print("\n")
    except Exception as e:
        print(f"\n✗ Example failed: {e}")
        import traceback

        traceback.print_exc()
if __name__ == "__main__":
    # Entry point: the example is fully async, so hand control to asyncio.
    asyncio.run(main())

View file

@ -1,6 +1,8 @@
import os
import asyncio
import pathlib
from pprint import pprint
from cognee.shared.logging_utils import setup_logging, ERROR
import cognee
@ -42,7 +44,7 @@ async def main():
# Display search results
for result_text in search_results:
print(result_text)
pprint(result_text)
if __name__ == "__main__":

View file

@ -1,5 +1,6 @@
import asyncio
import os
from pprint import pprint
import cognee
from cognee.api.v1.search import SearchType
@ -77,7 +78,7 @@ async def main():
query_type=SearchType.GRAPH_COMPLETION,
query_text="What are the exact cars and their types produced by Audi?",
)
print(search_results)
pprint(search_results)
await visualize_graph()

View file

@ -1,6 +1,7 @@
import os
import cognee
import pathlib
from pprint import pprint
from cognee.modules.users.exceptions import PermissionDeniedError
from cognee.modules.users.tenants.methods import select_tenant
@ -86,7 +87,7 @@ async def main():
)
print("\nSearch results as user_1 on dataset owned by user_1:")
for result in search_results:
print(f"{result}\n")
pprint(result)
# But user_1 cant read the dataset owned by user_2 (QUANTUM dataset)
print("\nSearch result as user_1 on the dataset owned by user_2:")
@ -134,7 +135,7 @@ async def main():
dataset_ids=[quantum_dataset_id],
)
for result in search_results:
print(f"{result}\n")
pprint(result)
# If we'd like for user_1 to add new documents to the QUANTUM dataset owned by user_2, user_1 would have to get
# "write" access permission, which user_1 currently does not have
@ -217,7 +218,7 @@ async def main():
dataset_ids=[quantum_cognee_lab_dataset_id],
)
for result in search_results:
print(f"{result}\n")
pprint(result)
# Note: All of these function calls and permission system is available through our backend endpoints as well

View file

@ -1,4 +1,6 @@
import asyncio
from pprint import pprint
import cognee
from cognee.modules.engine.operations.setup import setup
from cognee.modules.users.methods import get_default_user
@ -71,7 +73,7 @@ async def main():
print("Search results:")
# Display results
for result_text in search_results:
print(result_text)
pprint(result_text)
if __name__ == "__main__":

View file

@ -1,4 +1,6 @@
import asyncio
from pprint import pprint
import cognee
from cognee.shared.logging_utils import setup_logging, ERROR
from cognee.api.v1.search import SearchType
@ -54,7 +56,7 @@ async def main():
print("Search results:")
# Display results
for result_text in search_results:
print(result_text)
pprint(result_text)
if __name__ == "__main__":

View file

@ -1,4 +1,5 @@
import asyncio
from pprint import pprint
import cognee
from cognee.shared.logging_utils import setup_logging, INFO
from cognee.api.v1.search import SearchType
@ -87,7 +88,8 @@ async def main():
top_k=15,
)
print(f"Query: {query_text}")
print(f"Results: {search_results}\n")
print("Results:")
pprint(search_results)
if __name__ == "__main__":

View file

@ -1,4 +1,5 @@
import asyncio
from pprint import pprint
import cognee
from cognee.memify_pipelines.create_triplet_embeddings import create_triplet_embeddings
@ -65,7 +66,7 @@ async def main():
query_type=SearchType.TRIPLET_COMPLETION,
query_text="What are the models produced by Volkswagen based on the context?",
)
print(search_results)
pprint(search_results)
if __name__ == "__main__":

View file

@ -1,7 +1,7 @@
[project]
name = "cognee"
version = "0.5.1.dev0"
version = "0.5.1"
description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning."
authors = [
{ name = "Vasilije Markovic" },

9461
uv.lock generated

File diff suppressed because it is too large Load diff