chore: merge dev
This commit is contained in:
commit
2cdbc02b35
39 changed files with 7673 additions and 7387 deletions
|
|
@ -76,7 +76,7 @@ git clone https://github.com/<your-github-username>/cognee.git
|
|||
cd cognee
|
||||
```
|
||||
In case you are working on Vector and Graph Adapters
|
||||
1. Fork the [**cognee**](https://github.com/topoteretes/cognee-community) repository
|
||||
1. Fork the [**cognee-community**](https://github.com/topoteretes/cognee-community) repository
|
||||
2. Clone your fork:
|
||||
```shell
|
||||
git clone https://github.com/<your-github-username>/cognee-community.git
|
||||
|
|
@ -120,6 +120,21 @@ or
|
|||
uv run python examples/python/simple_example.py
|
||||
```
|
||||
|
||||
### Running Simple Example
|
||||
|
||||
Change .env.example into .env and provide your OPENAI_API_KEY as LLM_API_KEY
|
||||
|
||||
Make sure to run ```shell uv sync ``` in the root cloned folder or set up a virtual environment to run cognee
|
||||
|
||||
```shell
|
||||
python cognee/cognee/examples/python/simple_example.py
|
||||
```
|
||||
or
|
||||
|
||||
```shell
|
||||
uv run python cognee/cognee/examples/python/simple_example.py
|
||||
```
|
||||
|
||||
## 4. 📤 Submitting Changes
|
||||
|
||||
1. Make sure that `pre-commit` and hooks are installed. See `Required tools` section for more information. Try executing `pre-commit run` if you are not sure.
|
||||
|
|
|
|||
|
|
@ -126,6 +126,7 @@ Now, run a minimal pipeline:
|
|||
```python
|
||||
import cognee
|
||||
import asyncio
|
||||
from pprint import pprint
|
||||
|
||||
|
||||
async def main():
|
||||
|
|
@ -143,7 +144,7 @@ async def main():
|
|||
|
||||
# Display the results
|
||||
for result in results:
|
||||
print(result)
|
||||
pprint(result)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
2612
cognee-frontend/package-lock.json
generated
2612
cognee-frontend/package-lock.json
generated
File diff suppressed because it is too large
Load diff
|
|
@ -13,7 +13,7 @@
|
|||
"classnames": "^2.5.1",
|
||||
"culori": "^4.0.1",
|
||||
"d3-force-3d": "^3.0.6",
|
||||
"next": "^16.1.7",
|
||||
"next": "^16.1.1",
|
||||
"react": "^19.2.3",
|
||||
"react-dom": "^19.2.3",
|
||||
"react-force-graph-2d": "^1.27.1",
|
||||
|
|
|
|||
|
|
@ -192,7 +192,7 @@ class CogneeClient:
|
|||
|
||||
with redirect_stdout(sys.stderr):
|
||||
results = await self.cognee.search(
|
||||
query_type=SearchType[query_type.upper()], query_text=query_text
|
||||
query_type=SearchType[query_type.upper()], query_text=query_text, top_k=top_k
|
||||
)
|
||||
return results
|
||||
|
||||
|
|
|
|||
|
|
@ -316,7 +316,7 @@ async def save_interaction(data: str) -> list:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
async def search(search_query: str, search_type: str) -> list:
|
||||
async def search(search_query: str, search_type: str, top_k: int = 10) -> list:
|
||||
"""
|
||||
Search and query the knowledge graph for insights, information, and connections.
|
||||
|
||||
|
|
@ -389,6 +389,13 @@ async def search(search_query: str, search_type: str) -> list:
|
|||
|
||||
The search_type is case-insensitive and will be converted to uppercase.
|
||||
|
||||
top_k : int, optional
|
||||
Maximum number of results to return (default: 10).
|
||||
Controls the amount of context retrieved from the knowledge graph.
|
||||
- Lower values (3-5): Faster, more focused results
|
||||
- Higher values (10-20): More comprehensive, but slower and more context-heavy
|
||||
Helps manage response size and context window usage in MCP clients.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
|
|
@ -425,13 +432,32 @@ async def search(search_query: str, search_type: str) -> list:
|
|||
|
||||
"""
|
||||
|
||||
async def search_task(search_query: str, search_type: str) -> str:
|
||||
"""Search the knowledge graph"""
|
||||
async def search_task(search_query: str, search_type: str, top_k: int) -> str:
|
||||
"""
|
||||
Internal task to execute knowledge graph search with result formatting.
|
||||
|
||||
Handles the actual search execution and formats results appropriately
|
||||
for MCP clients based on the search type and execution mode (API vs direct).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
search_query : str
|
||||
The search query in natural language
|
||||
search_type : str
|
||||
Type of search to perform (GRAPH_COMPLETION, CHUNKS, etc.)
|
||||
top_k : int
|
||||
Maximum number of results to return
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
Formatted search results as a string, with format depending on search_type
|
||||
"""
|
||||
# NOTE: MCP uses stdout to communicate, we must redirect all output
|
||||
# going to stdout ( like the print function ) to stderr.
|
||||
with redirect_stdout(sys.stderr):
|
||||
search_results = await cognee_client.search(
|
||||
query_text=search_query, query_type=search_type
|
||||
query_text=search_query, query_type=search_type, top_k=top_k
|
||||
)
|
||||
|
||||
# Handle different result formats based on API vs direct mode
|
||||
|
|
@ -465,7 +491,7 @@ async def search(search_query: str, search_type: str) -> list:
|
|||
else:
|
||||
return str(search_results)
|
||||
|
||||
search_results = await search_task(search_query, search_type)
|
||||
search_results = await search_task(search_query, search_type, top_k)
|
||||
return [types.TextContent(type="text", text=search_results)]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from fastapi import Depends, APIRouter
|
|||
from fastapi.responses import JSONResponse
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
|
||||
from cognee.modules.search.types import SearchType, SearchResult, CombinedSearchResult
|
||||
from cognee.modules.search.types import SearchType, SearchResult
|
||||
from cognee.api.DTO import InDTO, OutDTO
|
||||
from cognee.modules.users.exceptions.exceptions import PermissionDeniedError, UserNotFoundError
|
||||
from cognee.modules.users.models import User
|
||||
|
|
@ -31,7 +31,7 @@ class SearchPayloadDTO(InDTO):
|
|||
node_name: Optional[list[str]] = Field(default=None, example=[])
|
||||
top_k: Optional[int] = Field(default=10)
|
||||
only_context: bool = Field(default=False)
|
||||
use_combined_context: bool = Field(default=False)
|
||||
verbose: bool = Field(default=False)
|
||||
|
||||
|
||||
def get_search_router() -> APIRouter:
|
||||
|
|
@ -74,7 +74,7 @@ def get_search_router() -> APIRouter:
|
|||
except Exception as error:
|
||||
return JSONResponse(status_code=500, content={"error": str(error)})
|
||||
|
||||
@router.post("", response_model=Union[List[SearchResult], CombinedSearchResult, List])
|
||||
@router.post("", response_model=Union[List[SearchResult], List])
|
||||
async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):
|
||||
"""
|
||||
Search for nodes in the graph database.
|
||||
|
|
@ -118,7 +118,7 @@ def get_search_router() -> APIRouter:
|
|||
"node_name": payload.node_name,
|
||||
"top_k": payload.top_k,
|
||||
"only_context": payload.only_context,
|
||||
"use_combined_context": payload.use_combined_context,
|
||||
"verbose": payload.verbose,
|
||||
"cognee_version": cognee_version,
|
||||
},
|
||||
)
|
||||
|
|
@ -135,8 +135,8 @@ def get_search_router() -> APIRouter:
|
|||
system_prompt=payload.system_prompt,
|
||||
node_name=payload.node_name,
|
||||
top_k=payload.top_k,
|
||||
verbose=payload.verbose,
|
||||
only_context=payload.only_context,
|
||||
use_combined_context=payload.use_combined_context,
|
||||
)
|
||||
|
||||
return jsonable_encoder(results)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from typing import Union, Optional, List, Type
|
|||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.modules.engine.models.node_set import NodeSet
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.search.types import SearchResult, SearchType, CombinedSearchResult
|
||||
from cognee.modules.search.types import SearchResult, SearchType
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.search.methods import search as search_function
|
||||
from cognee.modules.data.methods import get_authorized_existing_datasets
|
||||
|
|
@ -32,11 +32,11 @@ async def search(
|
|||
save_interaction: bool = False,
|
||||
last_k: Optional[int] = 1,
|
||||
only_context: bool = False,
|
||||
use_combined_context: bool = False,
|
||||
session_id: Optional[str] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
) -> Union[List[SearchResult], CombinedSearchResult]:
|
||||
verbose: bool = False,
|
||||
) -> List[SearchResult]:
|
||||
"""
|
||||
Search and query the knowledge graph for insights, information, and connections.
|
||||
|
||||
|
|
@ -126,6 +126,8 @@ async def search(
|
|||
|
||||
session_id: Optional session identifier for caching Q&A interactions. Defaults to 'default_session' if None.
|
||||
|
||||
verbose: If True, returns detailed result information including graph representation (when possible).
|
||||
|
||||
Returns:
|
||||
list: Search results in format determined by query_type:
|
||||
|
||||
|
|
@ -214,10 +216,10 @@ async def search(
|
|||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
only_context=only_context,
|
||||
use_combined_context=use_combined_context,
|
||||
session_id=session_id,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
return filtered_search_results
|
||||
|
|
|
|||
|
|
@ -17,3 +17,9 @@ async def setup():
|
|||
await create_relational_db_and_tables()
|
||||
if not backend_access_control_enabled():
|
||||
await create_pgvector_db_and_tables()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(setup())
|
||||
|
|
|
|||
|
|
@ -215,9 +215,6 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
edge_penalty=triplet_distance_penalty,
|
||||
)
|
||||
self.add_edge(edge)
|
||||
|
||||
source_node.add_skeleton_edge(edge)
|
||||
target_node.add_skeleton_edge(edge)
|
||||
else:
|
||||
raise EntityNotFoundError(
|
||||
message=f"Edge references nonexistent nodes: {source_id} -> {target_id}"
|
||||
|
|
|
|||
|
|
@ -27,6 +27,5 @@ await cognee.cognify(datasets=["python-development-with-cognee"], temporal_cogni
|
|||
results = await cognee.search(
|
||||
"What Python type hinting challenges did I face, and how does Guido approach similar problems in mypy?",
|
||||
datasets=["python-development-with-cognee"],
|
||||
use_combined_context=True, # Used to show reasoning graph visualization
|
||||
)
|
||||
print(results)
|
||||
|
|
|
|||
|
|
@ -1,21 +1,18 @@
|
|||
import asyncio
|
||||
import time
|
||||
from typing import List, Optional, Type
|
||||
from typing import List, Optional, Type, Union
|
||||
|
||||
from cognee.shared.logging_utils import get_logger, ERROR
|
||||
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
|
||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.modules.retrieval.utils.node_edge_vector_search import NodeEdgeVectorSearch
|
||||
|
||||
logger = get_logger(level=ERROR)
|
||||
|
||||
|
||||
def format_triplets(edges):
|
||||
"""Formats edges into human-readable triplet strings."""
|
||||
triplets = []
|
||||
for edge in edges:
|
||||
node1 = edge.node1
|
||||
|
|
@ -24,12 +21,10 @@ def format_triplets(edges):
|
|||
node1_attributes = node1.attributes
|
||||
node2_attributes = node2.attributes
|
||||
|
||||
# Filter only non-None properties
|
||||
node1_info = {key: value for key, value in node1_attributes.items() if value is not None}
|
||||
node2_info = {key: value for key, value in node2_attributes.items() if value is not None}
|
||||
edge_info = {key: value for key, value in edge_attributes.items() if value is not None}
|
||||
|
||||
# Create the formatted triplet
|
||||
triplet = f"Node1: {node1_info}\nEdge: {edge_info}\nNode2: {node2_info}\n\n\n"
|
||||
triplets.append(triplet)
|
||||
|
||||
|
|
@ -51,7 +46,6 @@ async def get_memory_fragment(
|
|||
|
||||
try:
|
||||
graph_engine = await get_graph_engine()
|
||||
|
||||
await memory_fragment.project_graph_from_db(
|
||||
graph_engine,
|
||||
node_properties_to_project=properties_to_project,
|
||||
|
|
@ -61,20 +55,64 @@ async def get_memory_fragment(
|
|||
relevant_ids_to_filter=relevant_ids_to_filter,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
|
||||
except EntityNotFoundError:
|
||||
# This is expected behavior - continue with empty fragment
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Error during memory fragment creation: {str(e)}")
|
||||
# Still return the fragment even if projection failed
|
||||
pass
|
||||
|
||||
return memory_fragment
|
||||
|
||||
|
||||
async def _get_top_triplet_importances(
|
||||
memory_fragment: Optional[CogneeGraph],
|
||||
vector_search: NodeEdgeVectorSearch,
|
||||
properties_to_project: Optional[List[str]],
|
||||
node_type: Optional[Type],
|
||||
node_name: Optional[List[str]],
|
||||
triplet_distance_penalty: float,
|
||||
wide_search_limit: Optional[int],
|
||||
top_k: int,
|
||||
query_list_length: Optional[int] = None,
|
||||
) -> Union[List[Edge], List[List[Edge]]]:
|
||||
"""Creates memory fragment (if needed), maps distances, and calculates top triplet importances.
|
||||
|
||||
Args:
|
||||
query_list_length: Number of queries in batch mode (None for single-query mode).
|
||||
When None, node_distances/edge_distances are flat lists; when set, they are list-of-lists.
|
||||
|
||||
Returns:
|
||||
List[Edge]: For single-query mode (query_list_length is None).
|
||||
List[List[Edge]]: For batch mode (query_list_length is set), one list per query.
|
||||
"""
|
||||
if memory_fragment is None:
|
||||
if wide_search_limit is None:
|
||||
relevant_node_ids = None
|
||||
else:
|
||||
relevant_node_ids = vector_search.extract_relevant_node_ids()
|
||||
|
||||
memory_fragment = await get_memory_fragment(
|
||||
properties_to_project=properties_to_project,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
relevant_ids_to_filter=relevant_node_ids,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
|
||||
await memory_fragment.map_vector_distances_to_graph_nodes(
|
||||
node_distances=vector_search.node_distances, query_list_length=query_list_length
|
||||
)
|
||||
await memory_fragment.map_vector_distances_to_graph_edges(
|
||||
edge_distances=vector_search.edge_distances, query_list_length=query_list_length
|
||||
)
|
||||
|
||||
return await memory_fragment.calculate_top_triplet_importances(
|
||||
k=top_k, query_list_length=query_list_length
|
||||
)
|
||||
|
||||
|
||||
async def brute_force_triplet_search(
|
||||
query: str,
|
||||
query: Optional[str] = None,
|
||||
query_batch: Optional[List[str]] = None,
|
||||
top_k: int = 5,
|
||||
collections: Optional[List[str]] = None,
|
||||
properties_to_project: Optional[List[str]] = None,
|
||||
|
|
@ -83,33 +121,49 @@ async def brute_force_triplet_search(
|
|||
node_name: Optional[List[str]] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
) -> List[Edge]:
|
||||
) -> Union[List[Edge], List[List[Edge]]]:
|
||||
"""
|
||||
Performs a brute force search to retrieve the top triplets from the graph.
|
||||
|
||||
Args:
|
||||
query (str): The search query.
|
||||
query (Optional[str]): The search query (single query mode). Exactly one of query or query_batch must be provided.
|
||||
query_batch (Optional[List[str]]): List of search queries (batch mode). Exactly one of query or query_batch must be provided.
|
||||
top_k (int): The number of top results to retrieve.
|
||||
collections (Optional[List[str]]): List of collections to query.
|
||||
properties_to_project (Optional[List[str]]): List of properties to project.
|
||||
memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse.
|
||||
node_type: node type to filter
|
||||
node_name: node name to filter
|
||||
wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections
|
||||
wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections.
|
||||
Ignored in batch mode (always None to project full graph).
|
||||
triplet_distance_penalty (Optional[float]): Default distance penalty in graph projection
|
||||
|
||||
Returns:
|
||||
list: The top triplet results.
|
||||
List[Edge]: The top triplet results for single query mode (flat list).
|
||||
List[List[Edge]]: List of top triplet results (one per query) for batch mode (list-of-lists).
|
||||
|
||||
Note:
|
||||
In single-query mode, node_distances and edge_distances are stored as flat lists.
|
||||
In batch mode, they are stored as list-of-lists (one list per query).
|
||||
"""
|
||||
if not query or not isinstance(query, str):
|
||||
if query is not None and query_batch is not None:
|
||||
raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.")
|
||||
if query is None and query_batch is None:
|
||||
raise ValueError("Must provide either 'query' or 'query_batch'.")
|
||||
if query is not None and (not query or not isinstance(query, str)):
|
||||
raise ValueError("The query must be a non-empty string.")
|
||||
if query_batch is not None:
|
||||
if not isinstance(query_batch, list) or not query_batch:
|
||||
raise ValueError("query_batch must be a non-empty list of strings.")
|
||||
if not all(isinstance(q, str) and q for q in query_batch):
|
||||
raise ValueError("All items in query_batch must be non-empty strings.")
|
||||
if top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer.")
|
||||
|
||||
# Setting wide search limit based on the parameters
|
||||
non_global_search = node_name is None
|
||||
|
||||
wide_search_limit = wide_search_top_k if non_global_search else None
|
||||
query_list_length = len(query_batch) if query_batch is not None else None
|
||||
wide_search_limit = (
|
||||
None if query_list_length else (wide_search_top_k if node_name is None else None)
|
||||
)
|
||||
|
||||
if collections is None:
|
||||
collections = [
|
||||
|
|
@ -123,77 +177,37 @@ async def brute_force_triplet_search(
|
|||
collections.append("EdgeType_relationship_name")
|
||||
|
||||
try:
|
||||
vector_engine = get_vector_engine()
|
||||
except Exception as e:
|
||||
logger.error("Failed to initialize vector engine: %s", e)
|
||||
raise RuntimeError("Initialization error") from e
|
||||
vector_search = NodeEdgeVectorSearch()
|
||||
|
||||
query_vector = (await vector_engine.embedding_engine.embed_text([query]))[0]
|
||||
|
||||
async def search_in_collection(collection_name: str):
|
||||
try:
|
||||
return await vector_engine.search(
|
||||
collection_name=collection_name, query_vector=query_vector, limit=wide_search_limit
|
||||
)
|
||||
except CollectionNotFoundError:
|
||||
return []
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[search_in_collection(collection_name) for collection_name in collections]
|
||||
await vector_search.embed_and_retrieve_distances(
|
||||
query=None if query_list_length else query,
|
||||
query_batch=query_batch if query_list_length else None,
|
||||
collections=collections,
|
||||
wide_search_limit=wide_search_limit,
|
||||
)
|
||||
|
||||
if all(not item for item in results):
|
||||
return []
|
||||
if not vector_search.has_results():
|
||||
return [[] for _ in range(query_list_length)] if query_list_length else []
|
||||
|
||||
# Final statistics
|
||||
vector_collection_search_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {vector_collection_search_time:.2f}s"
|
||||
results = await _get_top_triplet_importances(
|
||||
memory_fragment,
|
||||
vector_search,
|
||||
properties_to_project,
|
||||
node_type,
|
||||
node_name,
|
||||
triplet_distance_penalty,
|
||||
wide_search_limit,
|
||||
top_k,
|
||||
query_list_length=query_list_length,
|
||||
)
|
||||
|
||||
node_distances = {collection: result for collection, result in zip(collections, results)}
|
||||
|
||||
edge_distances = node_distances.get("EdgeType_relationship_name", None)
|
||||
|
||||
if wide_search_limit is not None:
|
||||
relevant_ids_to_filter = list(
|
||||
{
|
||||
str(getattr(scored_node, "id"))
|
||||
for collection_name, score_collection in node_distances.items()
|
||||
if collection_name != "EdgeType_relationship_name"
|
||||
and isinstance(score_collection, (list, tuple))
|
||||
for scored_node in score_collection
|
||||
if getattr(scored_node, "id", None)
|
||||
}
|
||||
)
|
||||
else:
|
||||
relevant_ids_to_filter = None
|
||||
|
||||
if memory_fragment is None:
|
||||
memory_fragment = await get_memory_fragment(
|
||||
properties_to_project=properties_to_project,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
relevant_ids_to_filter=relevant_ids_to_filter,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
|
||||
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
|
||||
await memory_fragment.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
||||
results = await memory_fragment.calculate_top_triplet_importances(k=top_k)
|
||||
|
||||
return results
|
||||
|
||||
except CollectionNotFoundError:
|
||||
return []
|
||||
return [[] for _ in range(query_list_length)] if query_list_length else []
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
"Error during brute force search for query: %s. Error: %s",
|
||||
query,
|
||||
query_batch if query_list_length else [query],
|
||||
error,
|
||||
)
|
||||
raise error
|
||||
|
|
|
|||
174
cognee/modules/retrieval/utils/node_edge_vector_search.py
Normal file
174
cognee/modules/retrieval/utils/node_edge_vector_search.py
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
import asyncio
|
||||
import time
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from cognee.shared.logging_utils import get_logger, ERROR
|
||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
logger = get_logger(level=ERROR)
|
||||
|
||||
|
||||
class NodeEdgeVectorSearch:
|
||||
"""Manages vector search and distance retrieval for graph nodes and edges."""
|
||||
|
||||
def __init__(self, edge_collection: str = "EdgeType_relationship_name", vector_engine=None):
|
||||
self.edge_collection = edge_collection
|
||||
self.vector_engine = vector_engine or self._init_vector_engine()
|
||||
self.query_vector: Optional[Any] = None
|
||||
self.node_distances: dict[str, list[Any]] = {}
|
||||
self.edge_distances: list[Any] = []
|
||||
self.query_list_length: Optional[int] = None
|
||||
|
||||
def _init_vector_engine(self):
|
||||
try:
|
||||
return get_vector_engine()
|
||||
except Exception as e:
|
||||
logger.error("Failed to initialize vector engine: %s", e)
|
||||
raise RuntimeError("Initialization error") from e
|
||||
|
||||
async def embed_and_retrieve_distances(
|
||||
self,
|
||||
query: Optional[str] = None,
|
||||
query_batch: Optional[List[str]] = None,
|
||||
collections: List[str] = None,
|
||||
wide_search_limit: Optional[int] = None,
|
||||
):
|
||||
"""Embeds query/queries and retrieves vector distances from all collections."""
|
||||
if query is not None and query_batch is not None:
|
||||
raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.")
|
||||
if query is None and query_batch is None:
|
||||
raise ValueError("Must provide either 'query' or 'query_batch'.")
|
||||
if not collections:
|
||||
raise ValueError("'collections' must be a non-empty list.")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
if query_batch is not None:
|
||||
self.query_list_length = len(query_batch)
|
||||
search_results = await self._run_batch_search(collections, query_batch)
|
||||
else:
|
||||
self.query_list_length = None
|
||||
search_results = await self._run_single_search(collections, query, wide_search_limit)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
collections_with_results = sum(1 for result in search_results if any(result))
|
||||
logger.info(
|
||||
f"Vector collection retrieval completed: Retrieved distances from "
|
||||
f"{collections_with_results} collections in {elapsed_time:.2f}s"
|
||||
)
|
||||
|
||||
self.set_distances_from_results(collections, search_results, self.query_list_length)
|
||||
|
||||
def has_results(self) -> bool:
|
||||
"""Checks if any collections returned results."""
|
||||
if self.query_list_length is None:
|
||||
if self.edge_distances and any(self.edge_distances):
|
||||
return True
|
||||
return any(
|
||||
bool(collection_results) for collection_results in self.node_distances.values()
|
||||
)
|
||||
|
||||
if self.edge_distances and any(inner_list for inner_list in self.edge_distances):
|
||||
return True
|
||||
return any(
|
||||
any(results_per_query for results_per_query in collection_results)
|
||||
for collection_results in self.node_distances.values()
|
||||
)
|
||||
|
||||
def extract_relevant_node_ids(self) -> List[str]:
|
||||
"""Extracts unique node IDs from search results."""
|
||||
if self.query_list_length is not None:
|
||||
return []
|
||||
relevant_node_ids = set()
|
||||
for scored_results in self.node_distances.values():
|
||||
for scored_node in scored_results:
|
||||
node_id = getattr(scored_node, "id", None)
|
||||
if node_id:
|
||||
relevant_node_ids.add(str(node_id))
|
||||
return list(relevant_node_ids)
|
||||
|
||||
def set_distances_from_results(
|
||||
self,
|
||||
collections: List[str],
|
||||
search_results: List[List[Any]],
|
||||
query_list_length: Optional[int] = None,
|
||||
):
|
||||
"""Separates search results into node and edge distances with stable shapes.
|
||||
|
||||
Ensures all collections are present in the output, even if empty:
|
||||
- Batch mode: missing/empty collections become [[]] * query_list_length
|
||||
- Single mode: missing/empty collections become []
|
||||
"""
|
||||
self.node_distances = {}
|
||||
self.edge_distances = (
|
||||
[] if query_list_length is None else [[] for _ in range(query_list_length)]
|
||||
)
|
||||
for collection, result in zip(collections, search_results):
|
||||
if not result:
|
||||
empty_result = (
|
||||
[] if query_list_length is None else [[] for _ in range(query_list_length)]
|
||||
)
|
||||
if collection == self.edge_collection:
|
||||
self.edge_distances = empty_result
|
||||
else:
|
||||
self.node_distances[collection] = empty_result
|
||||
else:
|
||||
if collection == self.edge_collection:
|
||||
self.edge_distances = result
|
||||
else:
|
||||
self.node_distances[collection] = result
|
||||
|
||||
async def _run_batch_search(
|
||||
self, collections: List[str], query_batch: List[str]
|
||||
) -> List[List[Any]]:
|
||||
"""Runs batch search across all collections and returns list-of-lists per collection."""
|
||||
search_tasks = [
|
||||
self._search_batch_collection(collection, query_batch) for collection in collections
|
||||
]
|
||||
return await asyncio.gather(*search_tasks)
|
||||
|
||||
async def _search_batch_collection(
|
||||
self, collection_name: str, query_batch: List[str]
|
||||
) -> List[List[Any]]:
|
||||
"""Searches one collection with batch queries and returns list-of-lists."""
|
||||
try:
|
||||
return await self.vector_engine.batch_search(
|
||||
collection_name=collection_name, query_texts=query_batch, limit=None
|
||||
)
|
||||
except CollectionNotFoundError:
|
||||
return [[]] * len(query_batch)
|
||||
|
||||
async def _run_single_search(
|
||||
self, collections: List[str], query: str, wide_search_limit: Optional[int]
|
||||
) -> List[List[Any]]:
|
||||
"""Runs single query search and returns flat lists per collection.
|
||||
|
||||
Returns a list where each element is a collection's results (flat list).
|
||||
These are stored as flat lists in node_distances/edge_distances for single-query mode.
|
||||
"""
|
||||
await self._embed_query(query)
|
||||
search_tasks = [
|
||||
self._search_single_collection(self.vector_engine, wide_search_limit, collection)
|
||||
for collection in collections
|
||||
]
|
||||
search_results = await asyncio.gather(*search_tasks)
|
||||
return search_results
|
||||
|
||||
async def _embed_query(self, query: str):
|
||||
"""Embeds the query and stores the resulting vector."""
|
||||
query_embeddings = await self.vector_engine.embedding_engine.embed_text([query])
|
||||
self.query_vector = query_embeddings[0]
|
||||
|
||||
async def _search_single_collection(
|
||||
self, vector_engine: Any, wide_search_limit: Optional[int], collection_name: str
|
||||
):
|
||||
"""Searches one collection and returns results or empty list if not found."""
|
||||
try:
|
||||
return await vector_engine.search(
|
||||
collection_name=collection_name,
|
||||
query_vector=self.query_vector,
|
||||
limit=wide_search_limit,
|
||||
)
|
||||
except CollectionNotFoundError:
|
||||
return []
|
||||
|
|
@ -14,8 +14,6 @@ from cognee.modules.engine.models.node_set import NodeSet
|
|||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.modules.search.types import (
|
||||
SearchResult,
|
||||
CombinedSearchResult,
|
||||
SearchResultDataset,
|
||||
SearchType,
|
||||
)
|
||||
from cognee.modules.search.operations import log_query, log_result
|
||||
|
|
@ -45,11 +43,11 @@ async def search(
|
|||
save_interaction: bool = False,
|
||||
last_k: Optional[int] = None,
|
||||
only_context: bool = False,
|
||||
use_combined_context: bool = False,
|
||||
session_id: Optional[str] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
) -> Union[CombinedSearchResult, List[SearchResult]]:
|
||||
verbose=False,
|
||||
) -> List[SearchResult]:
|
||||
"""
|
||||
|
||||
Args:
|
||||
|
|
@ -90,7 +88,6 @@ async def search(
|
|||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
only_context=only_context,
|
||||
use_combined_context=use_combined_context,
|
||||
session_id=session_id,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
|
|
@ -127,87 +124,63 @@ async def search(
|
|||
query.id,
|
||||
json.dumps(
|
||||
jsonable_encoder(
|
||||
await prepare_search_result(
|
||||
search_results[0] if isinstance(search_results, list) else search_results
|
||||
)
|
||||
if use_combined_context
|
||||
else [
|
||||
await prepare_search_result(search_result) for search_result in search_results
|
||||
]
|
||||
[await prepare_search_result(search_result) for search_result in search_results]
|
||||
)
|
||||
),
|
||||
user.id,
|
||||
)
|
||||
|
||||
if use_combined_context:
|
||||
prepared_search_results = await prepare_search_result(
|
||||
search_results[0] if isinstance(search_results, list) else search_results
|
||||
)
|
||||
result = prepared_search_results["result"]
|
||||
graphs = prepared_search_results["graphs"]
|
||||
context = prepared_search_results["context"]
|
||||
datasets = prepared_search_results["datasets"]
|
||||
# This is for maintaining backwards compatibility
|
||||
if backend_access_control_enabled():
|
||||
return_value = []
|
||||
for search_result in search_results:
|
||||
prepared_search_results = await prepare_search_result(search_result)
|
||||
|
||||
return CombinedSearchResult(
|
||||
result=result,
|
||||
graphs=graphs,
|
||||
context=context,
|
||||
datasets=[
|
||||
SearchResultDataset(
|
||||
id=dataset.id,
|
||||
name=dataset.name,
|
||||
)
|
||||
for dataset in datasets
|
||||
],
|
||||
)
|
||||
result = prepared_search_results["result"]
|
||||
graphs = prepared_search_results["graphs"]
|
||||
context = prepared_search_results["context"]
|
||||
datasets = prepared_search_results["datasets"]
|
||||
|
||||
if only_context:
|
||||
search_result_dict = {
|
||||
"search_result": [context] if context else None,
|
||||
"dataset_id": datasets[0].id,
|
||||
"dataset_name": datasets[0].name,
|
||||
"dataset_tenant_id": datasets[0].tenant_id,
|
||||
}
|
||||
if verbose:
|
||||
# Include graphs only in verbose mode
|
||||
search_result_dict["graphs"] = graphs
|
||||
|
||||
return_value.append(search_result_dict)
|
||||
else:
|
||||
search_result_dict = {
|
||||
"search_result": [result] if result else None,
|
||||
"dataset_id": datasets[0].id,
|
||||
"dataset_name": datasets[0].name,
|
||||
"dataset_tenant_id": datasets[0].tenant_id,
|
||||
}
|
||||
if verbose:
|
||||
# Include graphs only in verbose mode
|
||||
search_result_dict["graphs"] = graphs
|
||||
return_value.append(search_result_dict)
|
||||
|
||||
return return_value
|
||||
else:
|
||||
# This is for maintaining backwards compatibility
|
||||
if backend_access_control_enabled():
|
||||
return_value = []
|
||||
return_value = []
|
||||
if only_context:
|
||||
for search_result in search_results:
|
||||
prepared_search_results = await prepare_search_result(search_result)
|
||||
|
||||
result = prepared_search_results["result"]
|
||||
graphs = prepared_search_results["graphs"]
|
||||
context = prepared_search_results["context"]
|
||||
datasets = prepared_search_results["datasets"]
|
||||
|
||||
if only_context:
|
||||
return_value.append(
|
||||
{
|
||||
"search_result": [context] if context else None,
|
||||
"dataset_id": datasets[0].id,
|
||||
"dataset_name": datasets[0].name,
|
||||
"dataset_tenant_id": datasets[0].tenant_id,
|
||||
"graphs": graphs,
|
||||
}
|
||||
)
|
||||
else:
|
||||
return_value.append(
|
||||
{
|
||||
"search_result": [result] if result else None,
|
||||
"dataset_id": datasets[0].id,
|
||||
"dataset_name": datasets[0].name,
|
||||
"dataset_tenant_id": datasets[0].tenant_id,
|
||||
"graphs": graphs,
|
||||
}
|
||||
)
|
||||
return return_value
|
||||
return_value.append(prepared_search_results["context"])
|
||||
else:
|
||||
return_value = []
|
||||
if only_context:
|
||||
for search_result in search_results:
|
||||
prepared_search_results = await prepare_search_result(search_result)
|
||||
return_value.append(prepared_search_results["context"])
|
||||
else:
|
||||
for search_result in search_results:
|
||||
result, context, datasets = search_result
|
||||
return_value.append(result)
|
||||
# For maintaining backwards compatibility
|
||||
if len(return_value) == 1 and isinstance(return_value[0], list):
|
||||
return return_value[0]
|
||||
else:
|
||||
return return_value
|
||||
for search_result in search_results:
|
||||
result, context, datasets = search_result
|
||||
return_value.append(result)
|
||||
# For maintaining backwards compatibility
|
||||
if len(return_value) == 1 and isinstance(return_value[0], list):
|
||||
return return_value[0]
|
||||
else:
|
||||
return return_value
|
||||
|
||||
|
||||
async def authorized_search(
|
||||
|
|
@ -223,14 +196,10 @@ async def authorized_search(
|
|||
save_interaction: bool = False,
|
||||
last_k: Optional[int] = None,
|
||||
only_context: bool = False,
|
||||
use_combined_context: bool = False,
|
||||
session_id: Optional[str] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
) -> Union[
|
||||
Tuple[Any, Union[List[Edge], str], List[Dataset]],
|
||||
List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],
|
||||
]:
|
||||
) -> List[Tuple[Any, Union[List[Edge], str], List[Dataset]]]:
|
||||
"""
|
||||
Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset.
|
||||
Not to be used outside of active access control mode.
|
||||
|
|
@ -240,70 +209,6 @@ async def authorized_search(
|
|||
datasets=dataset_ids, permission_type="read", user=user
|
||||
)
|
||||
|
||||
if use_combined_context:
|
||||
search_responses = await search_in_datasets_context(
|
||||
search_datasets=search_datasets,
|
||||
query_type=query_type,
|
||||
query_text=query_text,
|
||||
system_prompt_path=system_prompt_path,
|
||||
system_prompt=system_prompt,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
only_context=True,
|
||||
session_id=session_id,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
|
||||
context = {}
|
||||
datasets: List[Dataset] = []
|
||||
|
||||
for _, search_context, search_datasets in search_responses:
|
||||
for dataset in search_datasets:
|
||||
context[str(dataset.id)] = search_context
|
||||
|
||||
datasets.extend(search_datasets)
|
||||
|
||||
specific_search_tools = await get_search_type_tools(
|
||||
query_type=query_type,
|
||||
query_text=query_text,
|
||||
system_prompt_path=system_prompt_path,
|
||||
system_prompt=system_prompt,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
search_tools = specific_search_tools
|
||||
if len(search_tools) == 2:
|
||||
[get_completion, _] = search_tools
|
||||
else:
|
||||
get_completion = search_tools[0]
|
||||
|
||||
def prepare_combined_context(
|
||||
context,
|
||||
) -> Union[List[Edge], str]:
|
||||
combined_context = []
|
||||
|
||||
for dataset_context in context.values():
|
||||
combined_context += dataset_context
|
||||
|
||||
if combined_context and isinstance(combined_context[0], str):
|
||||
return "\n".join(combined_context)
|
||||
|
||||
return combined_context
|
||||
|
||||
combined_context = prepare_combined_context(context)
|
||||
completion = await get_completion(query_text, combined_context, session_id=session_id)
|
||||
|
||||
return completion, combined_context, datasets
|
||||
|
||||
# Searches all provided datasets and handles setting up of appropriate database context based on permissions
|
||||
search_results = await search_in_datasets_context(
|
||||
search_datasets=search_datasets,
|
||||
|
|
@ -319,6 +224,7 @@ async def authorized_search(
|
|||
only_context=only_context,
|
||||
session_id=session_id,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
|
||||
return search_results
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class SearchResultDataset(BaseModel):
|
||||
|
|
@ -8,13 +8,6 @@ class SearchResultDataset(BaseModel):
|
|||
name: str
|
||||
|
||||
|
||||
class CombinedSearchResult(BaseModel):
|
||||
result: Optional[Any]
|
||||
context: Dict[str, Any]
|
||||
graphs: Optional[Dict[str, Any]] = {}
|
||||
datasets: Optional[List[SearchResultDataset]] = None
|
||||
|
||||
|
||||
class SearchResult(BaseModel):
|
||||
search_result: Any
|
||||
dataset_id: Optional[UUID]
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
from .SearchType import SearchType
|
||||
from .SearchResult import SearchResult, SearchResultDataset, CombinedSearchResult
|
||||
from .SearchResult import SearchResult, SearchResultDataset
|
||||
|
|
|
|||
|
|
@ -92,7 +92,7 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
|
|||
}
|
||||
links_list.append(link_data)
|
||||
|
||||
html_template = """
|
||||
html_template = r"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
|
|
|
|||
613
cognee/tasks/memify/extract_usage_frequency.py
Normal file
613
cognee/tasks/memify/extract_usage_frequency.py
Normal file
|
|
@ -0,0 +1,613 @@
|
|||
# cognee/tasks/memify/extract_usage_frequency.py
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||
from cognee.modules.pipelines.tasks.task import Task
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
||||
|
||||
logger = get_logger("extract_usage_frequency")
|
||||
|
||||
|
||||
async def extract_usage_frequency(
|
||||
subgraphs: List[CogneeGraph],
|
||||
time_window: timedelta = timedelta(days=7),
|
||||
min_interaction_threshold: int = 1,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract usage frequency from CogneeUserInteraction nodes.
|
||||
|
||||
When save_interaction=True in cognee.search(), the system creates:
|
||||
- CogneeUserInteraction nodes (representing the query/answer interaction)
|
||||
- used_graph_element_to_answer edges (connecting interactions to graph elements used)
|
||||
|
||||
This function tallies how often each graph element is referenced via these edges,
|
||||
enabling frequency-based ranking in downstream retrievers.
|
||||
|
||||
:param subgraphs: List of CogneeGraph instances containing interaction data
|
||||
:param time_window: Time window to consider for interactions (default: 7 days)
|
||||
:param min_interaction_threshold: Minimum interactions to track (default: 1)
|
||||
:return: Dictionary containing node frequencies, edge frequencies, and metadata
|
||||
"""
|
||||
current_time = datetime.now()
|
||||
cutoff_time = current_time - time_window
|
||||
|
||||
# Track frequencies for graph elements (nodes and edges)
|
||||
node_frequencies = {}
|
||||
edge_frequencies = {}
|
||||
relationship_type_frequencies = {}
|
||||
|
||||
# Track interaction metadata
|
||||
interaction_count = 0
|
||||
interactions_in_window = 0
|
||||
|
||||
logger.info(f"Extracting usage frequencies from {len(subgraphs)} subgraphs")
|
||||
logger.info(f"Time window: {time_window}, Cutoff: {cutoff_time.isoformat()}")
|
||||
|
||||
for subgraph in subgraphs:
|
||||
# Find all CogneeUserInteraction nodes
|
||||
interaction_nodes = {}
|
||||
for node_id, node in subgraph.nodes.items():
|
||||
node_type = node.attributes.get("type") or node.attributes.get("node_type")
|
||||
|
||||
if node_type == "CogneeUserInteraction":
|
||||
# Parse and validate timestamp
|
||||
timestamp_value = node.attributes.get("timestamp") or node.attributes.get(
|
||||
"created_at"
|
||||
)
|
||||
if timestamp_value is not None:
|
||||
try:
|
||||
# Handle various timestamp formats
|
||||
interaction_time = None
|
||||
|
||||
if isinstance(timestamp_value, datetime):
|
||||
# Already a Python datetime
|
||||
interaction_time = timestamp_value
|
||||
elif isinstance(timestamp_value, (int, float)):
|
||||
# Unix timestamp (assume milliseconds if > 10 digits)
|
||||
if timestamp_value > 10000000000:
|
||||
# Milliseconds since epoch
|
||||
interaction_time = datetime.fromtimestamp(timestamp_value / 1000.0)
|
||||
else:
|
||||
# Seconds since epoch
|
||||
interaction_time = datetime.fromtimestamp(timestamp_value)
|
||||
elif isinstance(timestamp_value, str):
|
||||
# Try different string formats
|
||||
if timestamp_value.isdigit():
|
||||
# Numeric string - treat as Unix timestamp
|
||||
ts_int = int(timestamp_value)
|
||||
if ts_int > 10000000000:
|
||||
interaction_time = datetime.fromtimestamp(ts_int / 1000.0)
|
||||
else:
|
||||
interaction_time = datetime.fromtimestamp(ts_int)
|
||||
else:
|
||||
# ISO format string
|
||||
interaction_time = datetime.fromisoformat(timestamp_value)
|
||||
elif hasattr(timestamp_value, "to_native"):
|
||||
# Neo4j datetime object - convert to Python datetime
|
||||
interaction_time = timestamp_value.to_native()
|
||||
elif hasattr(timestamp_value, "year") and hasattr(timestamp_value, "month"):
|
||||
# Datetime-like object - extract components
|
||||
try:
|
||||
interaction_time = datetime(
|
||||
year=timestamp_value.year,
|
||||
month=timestamp_value.month,
|
||||
day=timestamp_value.day,
|
||||
hour=getattr(timestamp_value, "hour", 0),
|
||||
minute=getattr(timestamp_value, "minute", 0),
|
||||
second=getattr(timestamp_value, "second", 0),
|
||||
microsecond=getattr(timestamp_value, "microsecond", 0),
|
||||
)
|
||||
except (AttributeError, ValueError):
|
||||
pass
|
||||
|
||||
if interaction_time is None:
|
||||
# Last resort: try converting to string and parsing
|
||||
str_value = str(timestamp_value)
|
||||
if str_value.isdigit():
|
||||
ts_int = int(str_value)
|
||||
if ts_int > 10000000000:
|
||||
interaction_time = datetime.fromtimestamp(ts_int / 1000.0)
|
||||
else:
|
||||
interaction_time = datetime.fromtimestamp(ts_int)
|
||||
else:
|
||||
interaction_time = datetime.fromisoformat(str_value)
|
||||
|
||||
if interaction_time is None:
|
||||
raise ValueError(f"Could not parse timestamp: {timestamp_value}")
|
||||
|
||||
# Make sure it's timezone-naive for comparison
|
||||
if interaction_time.tzinfo is not None:
|
||||
interaction_time = interaction_time.replace(tzinfo=None)
|
||||
|
||||
interaction_nodes[node_id] = {
|
||||
"node": node,
|
||||
"timestamp": interaction_time,
|
||||
"in_window": interaction_time >= cutoff_time,
|
||||
}
|
||||
interaction_count += 1
|
||||
if interaction_time >= cutoff_time:
|
||||
interactions_in_window += 1
|
||||
except (ValueError, TypeError, AttributeError, OSError) as e:
|
||||
logger.warning(
|
||||
f"Failed to parse timestamp for interaction node {node_id}: {e}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Timestamp value type: {type(timestamp_value)}, value: {timestamp_value}"
|
||||
)
|
||||
|
||||
# Process edges to find graph elements used in interactions
|
||||
for edge in subgraph.edges:
|
||||
relationship_type = edge.attributes.get("relationship_type")
|
||||
|
||||
# Look for 'used_graph_element_to_answer' edges
|
||||
if relationship_type == "used_graph_element_to_answer":
|
||||
# node1 should be the CogneeUserInteraction, node2 is the graph element
|
||||
source_id = str(edge.node1.id)
|
||||
target_id = str(edge.node2.id)
|
||||
|
||||
# Check if source is an interaction node in our time window
|
||||
if source_id in interaction_nodes:
|
||||
interaction_data = interaction_nodes[source_id]
|
||||
|
||||
if interaction_data["in_window"]:
|
||||
# Count the graph element (target node) being used
|
||||
node_frequencies[target_id] = node_frequencies.get(target_id, 0) + 1
|
||||
|
||||
# Also track what type of element it is for analytics
|
||||
target_node = subgraph.get_node(target_id)
|
||||
if target_node:
|
||||
element_type = target_node.attributes.get(
|
||||
"type"
|
||||
) or target_node.attributes.get("node_type")
|
||||
if element_type:
|
||||
relationship_type_frequencies[element_type] = (
|
||||
relationship_type_frequencies.get(element_type, 0) + 1
|
||||
)
|
||||
|
||||
# Also track general edge usage patterns
|
||||
elif relationship_type and relationship_type != "used_graph_element_to_answer":
|
||||
# Check if either endpoint is referenced in a recent interaction
|
||||
source_id = str(edge.node1.id)
|
||||
target_id = str(edge.node2.id)
|
||||
|
||||
# If this edge connects to any frequently accessed nodes, track the edge type
|
||||
if source_id in node_frequencies or target_id in node_frequencies:
|
||||
edge_key = f"{relationship_type}:{source_id}:{target_id}"
|
||||
edge_frequencies[edge_key] = edge_frequencies.get(edge_key, 0) + 1
|
||||
|
||||
# Filter frequencies above threshold
|
||||
filtered_node_frequencies = {
|
||||
node_id: freq
|
||||
for node_id, freq in node_frequencies.items()
|
||||
if freq >= min_interaction_threshold
|
||||
}
|
||||
|
||||
filtered_edge_frequencies = {
|
||||
edge_key: freq
|
||||
for edge_key, freq in edge_frequencies.items()
|
||||
if freq >= min_interaction_threshold
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Processed {interactions_in_window}/{interaction_count} interactions in time window"
|
||||
)
|
||||
logger.info(
|
||||
f"Found {len(filtered_node_frequencies)} nodes and {len(filtered_edge_frequencies)} edges "
|
||||
f"above threshold (min: {min_interaction_threshold})"
|
||||
)
|
||||
logger.info(f"Element type distribution: {relationship_type_frequencies}")
|
||||
|
||||
return {
|
||||
"node_frequencies": filtered_node_frequencies,
|
||||
"edge_frequencies": filtered_edge_frequencies,
|
||||
"element_type_frequencies": relationship_type_frequencies,
|
||||
"total_interactions": interaction_count,
|
||||
"interactions_in_window": interactions_in_window,
|
||||
"time_window_days": time_window.days,
|
||||
"last_processed_timestamp": current_time.isoformat(),
|
||||
"cutoff_timestamp": cutoff_time.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
async def add_frequency_weights(
|
||||
graph_adapter: GraphDBInterface, usage_frequencies: Dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
Add frequency weights to graph nodes and edges using the graph adapter.
|
||||
|
||||
Uses direct Cypher queries for Neo4j adapter compatibility.
|
||||
Writes frequency_weight properties back to the graph for use in:
|
||||
- Ranking frequently referenced entities higher during retrieval
|
||||
- Adjusting scoring for completion strategies
|
||||
- Exposing usage metrics in dashboards or audits
|
||||
|
||||
:param graph_adapter: Graph database adapter interface
|
||||
:param usage_frequencies: Calculated usage frequencies from extract_usage_frequency
|
||||
"""
|
||||
node_frequencies = usage_frequencies.get("node_frequencies", {})
|
||||
edge_frequencies = usage_frequencies.get("edge_frequencies", {})
|
||||
|
||||
logger.info(f"Adding frequency weights to {len(node_frequencies)} nodes")
|
||||
|
||||
# Check adapter type and use appropriate method
|
||||
adapter_type = type(graph_adapter).__name__
|
||||
logger.info(f"Using adapter: {adapter_type}")
|
||||
|
||||
nodes_updated = 0
|
||||
nodes_failed = 0
|
||||
|
||||
# Determine which method to use based on adapter type
|
||||
use_neo4j_cypher = adapter_type == "Neo4jAdapter" and hasattr(graph_adapter, "query")
|
||||
use_kuzu_query = adapter_type == "KuzuAdapter" and hasattr(graph_adapter, "query")
|
||||
use_get_update = hasattr(graph_adapter, "get_node_by_id") and hasattr(
|
||||
graph_adapter, "update_node_properties"
|
||||
)
|
||||
|
||||
# Method 1: Neo4j Cypher with SET (creates properties on the fly)
|
||||
if use_neo4j_cypher:
|
||||
try:
|
||||
logger.info("Using Neo4j Cypher SET method")
|
||||
last_updated = usage_frequencies.get("last_processed_timestamp")
|
||||
|
||||
for node_id, frequency in node_frequencies.items():
|
||||
try:
|
||||
query = """
|
||||
MATCH (n)
|
||||
WHERE n.id = $node_id
|
||||
SET n.frequency_weight = $frequency,
|
||||
n.frequency_updated_at = $updated_at
|
||||
RETURN n.id as id
|
||||
"""
|
||||
|
||||
result = await graph_adapter.query(
|
||||
query,
|
||||
params={
|
||||
"node_id": node_id,
|
||||
"frequency": frequency,
|
||||
"updated_at": last_updated,
|
||||
},
|
||||
)
|
||||
|
||||
if result and len(result) > 0:
|
||||
nodes_updated += 1
|
||||
else:
|
||||
logger.warning(f"Node {node_id} not found or not updated")
|
||||
nodes_failed += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating node {node_id}: {e}")
|
||||
nodes_failed += 1
|
||||
|
||||
logger.info(f"Node update complete: {nodes_updated} succeeded, {nodes_failed} failed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Neo4j Cypher update failed: {e}")
|
||||
use_neo4j_cypher = False
|
||||
|
||||
# Method 2: Kuzu - use get_node + add_node (updates via re-adding with same ID)
|
||||
elif (
|
||||
use_kuzu_query and hasattr(graph_adapter, "get_node") and hasattr(graph_adapter, "add_node")
|
||||
):
|
||||
logger.info("Using Kuzu get_node + add_node method")
|
||||
last_updated = usage_frequencies.get("last_processed_timestamp")
|
||||
|
||||
for node_id, frequency in node_frequencies.items():
|
||||
try:
|
||||
# Get the existing node (returns a dict)
|
||||
existing_node_dict = await graph_adapter.get_node(node_id)
|
||||
|
||||
if existing_node_dict:
|
||||
# Update the dict with new properties
|
||||
existing_node_dict["frequency_weight"] = frequency
|
||||
existing_node_dict["frequency_updated_at"] = last_updated
|
||||
|
||||
# Kuzu's add_node likely just takes the dict directly, not a Node object
|
||||
# Try passing the dict directly first
|
||||
try:
|
||||
await graph_adapter.add_node(existing_node_dict)
|
||||
nodes_updated += 1
|
||||
except Exception as dict_error:
|
||||
# If dict doesn't work, try creating a Node object
|
||||
logger.debug(f"Dict add failed, trying Node object: {dict_error}")
|
||||
|
||||
try:
|
||||
from cognee.infrastructure.engine import Node
|
||||
|
||||
# Try different Node constructor patterns
|
||||
try:
|
||||
# Pattern 1: Just properties
|
||||
node_obj = Node(existing_node_dict)
|
||||
except Exception:
|
||||
# Pattern 2: Type and properties
|
||||
node_obj = Node(
|
||||
type=existing_node_dict.get("type", "Unknown"),
|
||||
**existing_node_dict,
|
||||
)
|
||||
|
||||
await graph_adapter.add_node(node_obj)
|
||||
nodes_updated += 1
|
||||
except Exception as node_error:
|
||||
logger.error(f"Both dict and Node object failed: {node_error}")
|
||||
nodes_failed += 1
|
||||
else:
|
||||
logger.warning(f"Node {node_id} not found in graph")
|
||||
nodes_failed += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating node {node_id}: {e}")
|
||||
nodes_failed += 1
|
||||
|
||||
logger.info(f"Node update complete: {nodes_updated} succeeded, {nodes_failed} failed")
|
||||
|
||||
# Method 3: Generic get_node_by_id + update_node_properties
|
||||
elif use_get_update:
|
||||
logger.info("Using get/update method for adapter")
|
||||
for node_id, frequency in node_frequencies.items():
|
||||
try:
|
||||
# Get current node data
|
||||
node_data = await graph_adapter.get_node_by_id(node_id)
|
||||
|
||||
if node_data:
|
||||
# Tweak the properties dict - add frequency_weight
|
||||
if isinstance(node_data, dict):
|
||||
properties = node_data.get("properties", {})
|
||||
else:
|
||||
properties = getattr(node_data, "properties", {}) or {}
|
||||
|
||||
# Update with frequency weight
|
||||
properties["frequency_weight"] = frequency
|
||||
properties["frequency_updated_at"] = usage_frequencies.get(
|
||||
"last_processed_timestamp"
|
||||
)
|
||||
|
||||
# Write back via adapter
|
||||
await graph_adapter.update_node_properties(node_id, properties)
|
||||
nodes_updated += 1
|
||||
else:
|
||||
logger.warning(f"Node {node_id} not found in graph")
|
||||
nodes_failed += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating node {node_id}: {e}")
|
||||
nodes_failed += 1
|
||||
|
||||
logger.info(f"Node update complete: {nodes_updated} succeeded, {nodes_failed} failed")
|
||||
for node_id, frequency in node_frequencies.items():
|
||||
try:
|
||||
# Get current node data
|
||||
node_data = await graph_adapter.get_node_by_id(node_id)
|
||||
|
||||
if node_data:
|
||||
# Tweak the properties dict - add frequency_weight
|
||||
if isinstance(node_data, dict):
|
||||
properties = node_data.get("properties", {})
|
||||
else:
|
||||
properties = getattr(node_data, "properties", {}) or {}
|
||||
|
||||
# Update with frequency weight
|
||||
properties["frequency_weight"] = frequency
|
||||
properties["frequency_updated_at"] = usage_frequencies.get(
|
||||
"last_processed_timestamp"
|
||||
)
|
||||
|
||||
# Write back via adapter
|
||||
await graph_adapter.update_node_properties(node_id, properties)
|
||||
nodes_updated += 1
|
||||
else:
|
||||
logger.warning(f"Node {node_id} not found in graph")
|
||||
nodes_failed += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating node {node_id}: {e}")
|
||||
nodes_failed += 1
|
||||
|
||||
# If no method is available
|
||||
if not use_neo4j_cypher and not use_kuzu_query and not use_get_update:
|
||||
logger.error(f"Adapter {adapter_type} does not support required update methods")
|
||||
logger.error(
|
||||
"Required: either 'query' method or both 'get_node_by_id' and 'update_node_properties'"
|
||||
)
|
||||
return
|
||||
|
||||
# Update edge frequencies
|
||||
# Note: Edge property updates are backend-specific
|
||||
if edge_frequencies:
|
||||
logger.info(f"Processing {len(edge_frequencies)} edge frequency entries")
|
||||
|
||||
edges_updated = 0
|
||||
edges_failed = 0
|
||||
|
||||
for edge_key, frequency in edge_frequencies.items():
|
||||
try:
|
||||
# Parse edge key: "relationship_type:source_id:target_id"
|
||||
parts = edge_key.split(":", 2)
|
||||
if len(parts) == 3:
|
||||
relationship_type, source_id, target_id = parts
|
||||
|
||||
# Try to update edge if adapter supports it
|
||||
if hasattr(graph_adapter, "update_edge_properties"):
|
||||
edge_properties = {
|
||||
"frequency_weight": frequency,
|
||||
"frequency_updated_at": usage_frequencies.get(
|
||||
"last_processed_timestamp"
|
||||
),
|
||||
}
|
||||
|
||||
await graph_adapter.update_edge_properties(
|
||||
source_id, target_id, relationship_type, edge_properties
|
||||
)
|
||||
edges_updated += 1
|
||||
else:
|
||||
# Fallback: store in metadata or log
|
||||
logger.debug(
|
||||
f"Adapter doesn't support update_edge_properties for "
|
||||
f"{relationship_type} ({source_id} -> {target_id})"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating edge {edge_key}: {e}")
|
||||
edges_failed += 1
|
||||
|
||||
if edges_updated > 0:
|
||||
logger.info(f"Edge update complete: {edges_updated} succeeded, {edges_failed} failed")
|
||||
else:
|
||||
logger.info(
|
||||
"Edge frequency updates skipped (adapter may not support edge property updates)"
|
||||
)
|
||||
|
||||
# Store aggregate statistics as metadata if supported
|
||||
if hasattr(graph_adapter, "set_metadata"):
|
||||
try:
|
||||
metadata = {
|
||||
"element_type_frequencies": usage_frequencies.get("element_type_frequencies", {}),
|
||||
"total_interactions": usage_frequencies.get("total_interactions", 0),
|
||||
"interactions_in_window": usage_frequencies.get("interactions_in_window", 0),
|
||||
"last_frequency_update": usage_frequencies.get("last_processed_timestamp"),
|
||||
}
|
||||
await graph_adapter.set_metadata("usage_frequency_stats", metadata)
|
||||
logger.info("Stored usage frequency statistics as metadata")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not store usage statistics as metadata: {e}")
|
||||
|
||||
|
||||
async def create_usage_frequency_pipeline(
|
||||
graph_adapter: GraphDBInterface,
|
||||
time_window: timedelta = timedelta(days=7),
|
||||
min_interaction_threshold: int = 1,
|
||||
batch_size: int = 100,
|
||||
) -> tuple:
|
||||
"""
|
||||
Create memify pipeline entry for usage frequency tracking.
|
||||
|
||||
This follows the same pattern as feedback enrichment flows, allowing
|
||||
the frequency update to run end-to-end in a custom memify pipeline.
|
||||
|
||||
Use case example:
|
||||
extraction_tasks, enrichment_tasks = await create_usage_frequency_pipeline(
|
||||
graph_adapter=my_adapter,
|
||||
time_window=timedelta(days=30),
|
||||
min_interaction_threshold=2
|
||||
)
|
||||
|
||||
# Run in memify pipeline
|
||||
pipeline = Pipeline(extraction_tasks + enrichment_tasks)
|
||||
results = await pipeline.run()
|
||||
|
||||
:param graph_adapter: Graph database adapter
|
||||
:param time_window: Time window for counting interactions (default: 7 days)
|
||||
:param min_interaction_threshold: Minimum interactions to track (default: 1)
|
||||
:param batch_size: Batch size for processing (default: 100)
|
||||
:return: Tuple of (extraction_tasks, enrichment_tasks)
|
||||
"""
|
||||
logger.info("Creating usage frequency pipeline")
|
||||
logger.info(f"Config: time_window={time_window}, threshold={min_interaction_threshold}")
|
||||
|
||||
extraction_tasks = [
|
||||
Task(
|
||||
extract_usage_frequency,
|
||||
time_window=time_window,
|
||||
min_interaction_threshold=min_interaction_threshold,
|
||||
)
|
||||
]
|
||||
|
||||
enrichment_tasks = [
|
||||
Task(
|
||||
add_frequency_weights,
|
||||
graph_adapter=graph_adapter,
|
||||
task_config={"batch_size": batch_size},
|
||||
)
|
||||
]
|
||||
|
||||
return extraction_tasks, enrichment_tasks
|
||||
|
||||
|
||||
async def run_usage_frequency_update(
|
||||
graph_adapter: GraphDBInterface,
|
||||
subgraphs: List[CogneeGraph],
|
||||
time_window: timedelta = timedelta(days=7),
|
||||
min_interaction_threshold: int = 1,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convenience function to run the complete usage frequency update pipeline.
|
||||
|
||||
This is the main entry point for updating frequency weights on graph elements
|
||||
based on CogneeUserInteraction data from cognee.search(save_interaction=True).
|
||||
|
||||
Example usage:
|
||||
# After running searches with save_interaction=True
|
||||
from cognee.tasks.memify.extract_usage_frequency import run_usage_frequency_update
|
||||
|
||||
# Get the graph with interactions
|
||||
graph = await get_cognee_graph_with_interactions()
|
||||
|
||||
# Update frequency weights
|
||||
stats = await run_usage_frequency_update(
|
||||
graph_adapter=graph_adapter,
|
||||
subgraphs=[graph],
|
||||
time_window=timedelta(days=30), # Last 30 days
|
||||
min_interaction_threshold=2 # At least 2 uses
|
||||
)
|
||||
|
||||
print(f"Updated {len(stats['node_frequencies'])} nodes")
|
||||
|
||||
:param graph_adapter: Graph database adapter
|
||||
:param subgraphs: List of CogneeGraph instances with interaction data
|
||||
:param time_window: Time window for counting interactions
|
||||
:param min_interaction_threshold: Minimum interactions to track
|
||||
:return: Usage frequency statistics
|
||||
"""
|
||||
logger.info("Starting usage frequency update")
|
||||
|
||||
try:
|
||||
# Extract frequencies from interaction data
|
||||
usage_frequencies = await extract_usage_frequency(
|
||||
subgraphs=subgraphs,
|
||||
time_window=time_window,
|
||||
min_interaction_threshold=min_interaction_threshold,
|
||||
)
|
||||
|
||||
# Add frequency weights back to the graph
|
||||
await add_frequency_weights(
|
||||
graph_adapter=graph_adapter, usage_frequencies=usage_frequencies
|
||||
)
|
||||
|
||||
logger.info("Usage frequency update completed successfully")
|
||||
logger.info(
|
||||
f"Summary: {usage_frequencies['interactions_in_window']} interactions processed, "
|
||||
f"{len(usage_frequencies['node_frequencies'])} nodes weighted"
|
||||
)
|
||||
|
||||
return usage_frequencies
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during usage frequency update: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def get_most_frequent_elements(
    graph_adapter: GraphDBInterface, top_n: int = 10, element_type: Optional[str] = None
) -> List[Dict[str, Any]]:
    """
    Return the graph elements with the highest frequency weights.

    Intended for analytics dashboards and for understanding how users interact
    with the graph. The adapter-specific query has not been written yet, so the
    function currently logs a warning and always returns an empty list.

    :param graph_adapter: Graph database adapter
    :param top_n: Number of top elements to return
    :param element_type: Optional filter by element type
    :return: List of elements with their frequency weights
    """
    logger.info(f"Retrieving top {top_n} most frequent elements")

    # Placeholder — the real lookup must be implemented per graph adapter.
    # Sketch of the intended query:
    #   results = await graph_adapter.query_nodes_by_property(
    #       property_name='frequency_weight',
    #       order_by='DESC',
    #       limit=top_n,
    #       filters={'type': element_type} if element_type else None,
    #   )

    logger.warning("get_most_frequent_elements needs adapter-specific implementation")
    return []
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
import os
|
||||
import pathlib
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import cognee
|
||||
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
|
||||
|
||||
|
||||
# Embedding/vector search needs credentials for at least one supported provider.
_has_provider_credentials = bool(
    os.getenv("OPENAI_API_KEY") or os.getenv("AZURE_OPENAI_API_KEY")
)

skip_without_provider = pytest.mark.skipif(
    not _has_provider_credentials,
    reason="requires embedding/vector provider credentials",
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def clean_environment():
    """Point cognee at isolated test directories and prune state before/after."""
    repo_root = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(repo_root / ".cognee_system/test_brute_force_triplet_search_e2e")
    data_directory_path = str(repo_root / ".data_storage/test_brute_force_triplet_search_e2e")

    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)

    # Start from a clean slate so earlier runs cannot leak data into this one.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    yield

    # Best-effort teardown: a cleanup failure must not fail the test itself.
    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass
|
||||
|
||||
|
||||
@skip_without_provider
@pytest.mark.asyncio
async def test_brute_force_triplet_search_end_to_end(clean_environment):
    """Minimal end-to-end exercise of single and batch triplet search."""

    text = """
    Cognee is an open-source AI memory engine that structures data into searchable formats for use with AI agents.
    The company focuses on persistent memory systems using knowledge graphs and vector search.
    It is a Berlin-based startup building infrastructure for context-aware AI applications.
    """

    await cognee.add(text)
    await cognee.cognify()

    # Single-query mode returns a flat list of Edge objects.
    single_result = await brute_force_triplet_search(query="What is NLP?", top_k=1)
    assert isinstance(single_result, list)
    for edge in single_result:
        assert isinstance(edge, Edge)

    # Batch mode returns one list of edges per query, in input order.
    batch_queries = ["What is Cognee?", "What is the company's focus?"]
    batch_result = await brute_force_triplet_search(query_batch=batch_queries, top_k=1)

    assert isinstance(batch_result, list)
    assert len(batch_result) == len(batch_queries)
    for per_query in batch_result:
        assert isinstance(per_query, list)
        for edge in per_query:
            assert isinstance(edge, Edge)
|
||||
308
cognee/tests/test_extract_usage_frequency.py
Normal file
308
cognee/tests/test_extract_usage_frequency.py
Normal file
|
|
@ -0,0 +1,308 @@
|
|||
"""
|
||||
Test Suite: Usage Frequency Tracking
|
||||
|
||||
Comprehensive tests for the usage frequency tracking implementation.
|
||||
Tests cover extraction logic, adapter integration, edge cases, and end-to-end workflows.
|
||||
|
||||
Run with:
|
||||
pytest test_usage_frequency_comprehensive.py -v
|
||||
|
||||
Or without pytest:
|
||||
python test_usage_frequency_comprehensive.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict
|
||||
|
||||
# Mock imports for testing without full Cognee setup
|
||||
try:
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
|
||||
from cognee.tasks.memify.extract_usage_frequency import (
|
||||
extract_usage_frequency,
|
||||
add_frequency_weights,
|
||||
run_usage_frequency_update,
|
||||
)
|
||||
|
||||
COGNEE_AVAILABLE = True
|
||||
except ImportError:
|
||||
COGNEE_AVAILABLE = False
|
||||
print("⚠ Cognee not fully available - some tests will be skipped")
|
||||
|
||||
|
||||
class TestUsageFrequencyExtraction(unittest.IsolatedAsyncioTestCase):
    """Test the core frequency extraction logic.

    Inherits from ``IsolatedAsyncioTestCase`` so the ``async def test_*``
    methods are actually awaited. Under plain ``unittest.TestCase`` the runner
    would call each async test, receive an un-awaited coroutine, and mark the
    test as passing without ever executing its body.
    """

    def setUp(self):
        """Skip the whole suite when cognee modules are not importable."""
        if not COGNEE_AVAILABLE:
            self.skipTest("Cognee modules not available")

    def create_mock_graph(self, num_interactions: int = 3, num_elements: int = 5):
        """Create a mock graph with interactions and elements.

        Interaction nodes carry millisecond timestamps spaced one hour apart
        going back from "now"; each interaction is linked to two element nodes
        via ``used_graph_element_to_answer`` edges (wrapping around so elements
        are shared across interactions).
        """
        graph = CogneeGraph()

        # Create interaction nodes (interaction i happened i hours ago).
        current_time = datetime.now()
        for i in range(num_interactions):
            interaction_node = Node(
                id=f"interaction_{i}",
                node_type="CogneeUserInteraction",
                attributes={
                    "type": "CogneeUserInteraction",
                    "query_text": f"Test query {i}",
                    "timestamp": int((current_time - timedelta(hours=i)).timestamp() * 1000),
                },
            )
            graph.add_node(interaction_node)

        # Create graph element nodes.
        for i in range(num_elements):
            element_node = Node(
                id=f"element_{i}",
                node_type="DocumentChunk",
                attributes={"type": "DocumentChunk", "text": f"Element content {i}"},
            )
            graph.add_node(element_node)

        # Create usage edges: each interaction references two elements.
        for i in range(num_interactions):
            for j in range(2):
                element_idx = (i + j) % num_elements
                edge = Edge(
                    node1=graph.get_node(f"interaction_{i}"),
                    node2=graph.get_node(f"element_{element_idx}"),
                    edge_type="used_graph_element_to_answer",
                    attributes={"relationship_type": "used_graph_element_to_answer"},
                )
                graph.add_edge(edge)

        return graph

    async def test_basic_frequency_extraction(self):
        """Test basic frequency extraction with simple graph."""
        graph = self.create_mock_graph(num_interactions=3, num_elements=5)

        result = await extract_usage_frequency(
            subgraphs=[graph], time_window=timedelta(days=7), min_interaction_threshold=1
        )

        self.assertIn("node_frequencies", result)
        self.assertIn("total_interactions", result)
        self.assertEqual(result["total_interactions"], 3)
        self.assertGreater(len(result["node_frequencies"]), 0)

    async def test_time_window_filtering(self):
        """Test that time window correctly filters old interactions."""
        graph = CogneeGraph()

        current_time = datetime.now()

        # Recent interaction (inside the 7-day window).
        recent_node = Node(
            id="recent_interaction",
            node_type="CogneeUserInteraction",
            attributes={
                "type": "CogneeUserInteraction",
                "timestamp": int(current_time.timestamp() * 1000),
            },
        )
        graph.add_node(recent_node)

        # Old interaction (outside the window).
        old_node = Node(
            id="old_interaction",
            node_type="CogneeUserInteraction",
            attributes={
                "type": "CogneeUserInteraction",
                "timestamp": int((current_time - timedelta(days=10)).timestamp() * 1000),
            },
        )
        graph.add_node(old_node)

        # One element referenced by both interactions.
        element = Node(
            id="element_1", node_type="DocumentChunk", attributes={"type": "DocumentChunk"}
        )
        graph.add_node(element)

        graph.add_edge(
            Edge(
                node1=recent_node,
                node2=element,
                edge_type="used_graph_element_to_answer",
                attributes={"relationship_type": "used_graph_element_to_answer"},
            )
        )
        graph.add_edge(
            Edge(
                node1=old_node,
                node2=element,
                edge_type="used_graph_element_to_answer",
                attributes={"relationship_type": "used_graph_element_to_answer"},
            )
        )

        result = await extract_usage_frequency(
            subgraphs=[graph], time_window=timedelta(days=7), min_interaction_threshold=1
        )

        # Only the recent interaction counts toward the window; both exist in total.
        self.assertEqual(result["interactions_in_window"], 1)
        self.assertEqual(result["total_interactions"], 2)

    async def test_threshold_filtering(self):
        """Test that minimum threshold filters low-frequency nodes."""
        graph = self.create_mock_graph(num_interactions=5, num_elements=10)

        result = await extract_usage_frequency(
            subgraphs=[graph], time_window=timedelta(days=7), min_interaction_threshold=3
        )

        # Only nodes accessed at least 3 times should survive the threshold.
        for node_id, freq in result["node_frequencies"].items():
            self.assertGreaterEqual(freq, 3)

    async def test_element_type_tracking(self):
        """Test that element types are properly tracked."""
        graph = CogneeGraph()

        interaction = Node(
            id="interaction_1",
            node_type="CogneeUserInteraction",
            attributes={
                "type": "CogneeUserInteraction",
                "timestamp": int(datetime.now().timestamp() * 1000),
            },
        )
        graph.add_node(interaction)

        # Elements of two different types.
        chunk = Node(id="chunk_1", node_type="DocumentChunk", attributes={"type": "DocumentChunk"})
        entity = Node(id="entity_1", node_type="Entity", attributes={"type": "Entity"})

        graph.add_node(chunk)
        graph.add_node(entity)

        for element in [chunk, entity]:
            graph.add_edge(
                Edge(
                    node1=interaction,
                    node2=element,
                    edge_type="used_graph_element_to_answer",
                    attributes={"relationship_type": "used_graph_element_to_answer"},
                )
            )

        result = await extract_usage_frequency(subgraphs=[graph], time_window=timedelta(days=7))

        # Both element types must appear in the per-type frequency breakdown.
        self.assertIn("element_type_frequencies", result)
        types = result["element_type_frequencies"]
        self.assertIn("DocumentChunk", types)
        self.assertIn("Entity", types)

    async def test_empty_graph(self):
        """Test handling of empty graph."""
        graph = CogneeGraph()

        result = await extract_usage_frequency(subgraphs=[graph], time_window=timedelta(days=7))

        self.assertEqual(result["total_interactions"], 0)
        self.assertEqual(len(result["node_frequencies"]), 0)

    async def test_no_interactions_in_window(self):
        """Test handling when all interactions are outside time window."""
        graph = CogneeGraph()

        old_time = datetime.now() - timedelta(days=30)
        old_interaction = Node(
            id="old_interaction",
            node_type="CogneeUserInteraction",
            attributes={
                "type": "CogneeUserInteraction",
                "timestamp": int(old_time.timestamp() * 1000),
            },
        )
        graph.add_node(old_interaction)

        result = await extract_usage_frequency(subgraphs=[graph], time_window=timedelta(days=7))

        # The interaction exists but is too old to count toward the window.
        self.assertEqual(result["interactions_in_window"], 0)
        self.assertEqual(result["total_interactions"], 1)
|
||||
|
||||
|
||||
class TestIntegration(unittest.IsolatedAsyncioTestCase):
    """Integration tests for the complete workflow.

    Inherits from ``IsolatedAsyncioTestCase`` so the async test below is
    actually awaited; under plain ``unittest.TestCase`` the coroutine would
    never run and its ``skipTest`` call would never register as a skip.
    """

    def setUp(self):
        """Skip when cognee modules are not importable."""
        if not COGNEE_AVAILABLE:
            self.skipTest("Cognee modules not available")

    async def test_end_to_end_workflow(self):
        """Test the complete end-to-end frequency tracking workflow."""
        # Requires a full Cognee setup with a database; exercised by
        # example_usage_frequency_e2e.py rather than this unit suite.
        self.skipTest("E2E test - run example_usage_frequency_e2e.py instead")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test Runner
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def run_async_test(test_func):
    """Run an async test coroutine function to completion on a fresh event loop."""
    coro = test_func()
    asyncio.run(coro)
|
||||
|
||||
|
||||
def main():
    """Run all tests and return a process exit code (0 on success, 1 otherwise)."""
    if not COGNEE_AVAILABLE:
        print("⚠ Cognee not available - skipping tests")
        print("Install with: pip install cognee[neo4j]")
        return

    banner = "=" * 80
    print(banner)
    print("Running Usage Frequency Tests")
    print(banner)
    print()

    # Assemble the suite from both test classes.
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    suite.addTests(loader.loadTestsFromTestCase(TestUsageFrequencyExtraction))
    suite.addTests(loader.loadTestsFromTestCase(TestIntegration))

    outcome = unittest.TextTestRunner(verbosity=2).run(suite)

    # Human-readable summary of the run.
    print()
    print(banner)
    print("Test Summary")
    print(banner)
    successes = outcome.testsRun - len(outcome.failures) - len(outcome.errors)
    print(f"Tests run: {outcome.testsRun}")
    print(f"Successes: {successes}")
    print(f"Failures: {len(outcome.failures)}")
    print(f"Errors: {len(outcome.errors)}")
    print(f"Skipped: {len(outcome.skipped)}")

    return 0 if outcome.wasSuccessful() else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
||||
|
|
@ -149,7 +149,9 @@ async def e2e_state():
|
|||
|
||||
vector_engine = get_vector_engine()
|
||||
collection = await vector_engine.search(
|
||||
collection_name="Triplet_text", query_text="Test", limit=None
|
||||
collection_name="Triplet_text",
|
||||
query_text="Test",
|
||||
limit=None,
|
||||
)
|
||||
|
||||
# --- Retriever contexts ---
|
||||
|
|
@ -188,57 +190,70 @@ async def e2e_state():
|
|||
query_type=SearchType.GRAPH_COMPLETION,
|
||||
query_text="Where is germany located, next to which country?",
|
||||
save_interaction=True,
|
||||
verbose=True,
|
||||
)
|
||||
completion_cot = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION_COT,
|
||||
query_text="What is the country next to germany??",
|
||||
save_interaction=True,
|
||||
verbose=True,
|
||||
)
|
||||
completion_ext = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION,
|
||||
query_text="What is the name of the country next to germany",
|
||||
save_interaction=True,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
await cognee.search(
|
||||
query_type=SearchType.FEEDBACK, query_text="This was not the best answer", last_k=1
|
||||
query_type=SearchType.FEEDBACK,
|
||||
query_text="This was not the best answer",
|
||||
last_k=1,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
completion_sum = await cognee.search(
|
||||
query_type=SearchType.GRAPH_SUMMARY_COMPLETION,
|
||||
query_text="Next to which country is Germany located?",
|
||||
save_interaction=True,
|
||||
verbose=True,
|
||||
)
|
||||
completion_triplet = await cognee.search(
|
||||
query_type=SearchType.TRIPLET_COMPLETION,
|
||||
query_text="Next to which country is Germany located?",
|
||||
save_interaction=True,
|
||||
verbose=True,
|
||||
)
|
||||
completion_chunks = await cognee.search(
|
||||
query_type=SearchType.CHUNKS,
|
||||
query_text="Germany",
|
||||
save_interaction=False,
|
||||
verbose=True,
|
||||
)
|
||||
completion_summaries = await cognee.search(
|
||||
query_type=SearchType.SUMMARIES,
|
||||
query_text="Germany",
|
||||
save_interaction=False,
|
||||
verbose=True,
|
||||
)
|
||||
completion_rag = await cognee.search(
|
||||
query_type=SearchType.RAG_COMPLETION,
|
||||
query_text="Next to which country is Germany located?",
|
||||
save_interaction=False,
|
||||
verbose=True,
|
||||
)
|
||||
completion_temporal = await cognee.search(
|
||||
query_type=SearchType.TEMPORAL,
|
||||
query_text="Next to which country is Germany located?",
|
||||
save_interaction=False,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
await cognee.search(
|
||||
query_type=SearchType.FEEDBACK,
|
||||
query_text="This answer was great",
|
||||
last_k=1,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
# Snapshot after all E2E operations above (used by assertion-only tests).
|
||||
|
|
|
|||
|
|
@ -718,3 +718,49 @@ async def test_calculate_top_triplet_importances_raises_on_missing_attribute(set
|
|||
|
||||
with pytest.raises(ValueError):
|
||||
await graph.calculate_top_triplet_importances(k=1, query_list_length=1)
|
||||
|
||||
|
||||
def test_normalize_query_distance_lists_flat_list_single_query(setup_graph):
    """A flat result list is wrapped into a list-of-lists of length 1 (single-query mode)."""
    graph = setup_graph
    scored_results = [MockScoredResult("node1", 0.95), MockScoredResult("node2", 0.87)]

    normalized = graph._normalize_query_distance_lists(
        scored_results, query_list_length=None, name="test"
    )

    assert len(normalized) == 1
    assert normalized[0] == scored_results
|
||||
|
||||
|
||||
def test_normalize_query_distance_lists_nested_list_batch_mode(setup_graph):
    """A nested list passes through unchanged when its length matches query_list_length."""
    graph = setup_graph
    per_query_results = [
        [MockScoredResult("node1", 0.95)],
        [MockScoredResult("node2", 0.87)],
    ]

    normalized = graph._normalize_query_distance_lists(
        per_query_results, query_list_length=2, name="test"
    )

    assert normalized == per_query_results
    assert len(normalized) == 2
|
||||
|
||||
|
||||
def test_normalize_query_distance_lists_raises_on_length_mismatch(setup_graph):
    """A nested list whose length disagrees with query_list_length raises ValueError."""
    graph = setup_graph
    per_query_results = [
        [MockScoredResult("node1", 0.95)],
        [MockScoredResult("node2", 0.87)],
    ]

    with pytest.raises(ValueError, match="test has 2 query lists, but query_list_length is 3"):
        graph._normalize_query_distance_lists(per_query_results, query_list_length=3, name="test")
|
||||
|
||||
|
||||
def test_normalize_query_distance_lists_empty_list(setup_graph):
    """An empty input normalizes to an empty list."""
    normalized = setup_graph._normalize_query_distance_lists(
        [], query_list_length=None, name="test"
    )

    assert normalized == []
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ async def test_brute_force_triplet_search_empty_query():
|
|||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_none_query():
|
||||
"""Test that None query raises ValueError."""
|
||||
with pytest.raises(ValueError, match="The query must be a non-empty string."):
|
||||
with pytest.raises(ValueError, match="Must provide either 'query' or 'query_batch'."):
|
||||
await brute_force_triplet_search(query=None)
|
||||
|
||||
|
||||
|
|
@ -57,7 +57,7 @@ async def test_brute_force_triplet_search_wide_search_limit_global_search():
|
|||
mock_vector_engine.search = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
await brute_force_triplet_search(
|
||||
|
|
@ -79,7 +79,7 @@ async def test_brute_force_triplet_search_wide_search_limit_filtered_search():
|
|||
mock_vector_engine.search = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
await brute_force_triplet_search(
|
||||
|
|
@ -101,7 +101,7 @@ async def test_brute_force_triplet_search_wide_search_default():
|
|||
mock_vector_engine.search = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
await brute_force_triplet_search(query="test", node_name=None)
|
||||
|
|
@ -119,7 +119,7 @@ async def test_brute_force_triplet_search_default_collections():
|
|||
mock_vector_engine.search = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
await brute_force_triplet_search(query="test")
|
||||
|
|
@ -149,7 +149,7 @@ async def test_brute_force_triplet_search_custom_collections():
|
|||
custom_collections = ["CustomCol1", "CustomCol2"]
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
await brute_force_triplet_search(query="test", collections=custom_collections)
|
||||
|
|
@ -171,7 +171,7 @@ async def test_brute_force_triplet_search_always_includes_edge_collection():
|
|||
collections_without_edge = ["Entity_name", "TextSummary_text"]
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
await brute_force_triplet_search(query="test", collections=collections_without_edge)
|
||||
|
|
@ -194,7 +194,7 @@ async def test_brute_force_triplet_search_all_collections_empty():
|
|||
mock_vector_engine.search = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
results = await brute_force_triplet_search(query="test")
|
||||
|
|
@ -216,7 +216,7 @@ async def test_brute_force_triplet_search_embeds_query():
|
|||
mock_vector_engine.search = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
await brute_force_triplet_search(query=query_text)
|
||||
|
|
@ -249,7 +249,7 @@ async def test_brute_force_triplet_search_extracts_node_ids_global_search():
|
|||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -279,7 +279,7 @@ async def test_brute_force_triplet_search_reuses_provided_fragment():
|
|||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -311,7 +311,7 @@ async def test_brute_force_triplet_search_creates_fragment_when_not_provided():
|
|||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -340,7 +340,7 @@ async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation
|
|||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -351,7 +351,9 @@ async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation
|
|||
custom_top_k = 15
|
||||
await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"])
|
||||
|
||||
mock_fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k)
|
||||
mock_fragment.calculate_top_triplet_importances.assert_called_once_with(
|
||||
k=custom_top_k, query_list_length=None
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -430,7 +432,7 @@ async def test_brute_force_triplet_search_deduplicates_node_ids():
|
|||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -471,7 +473,7 @@ async def test_brute_force_triplet_search_excludes_edge_collection():
|
|||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -523,7 +525,7 @@ async def test_brute_force_triplet_search_skips_nodes_without_ids():
|
|||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -564,7 +566,7 @@ async def test_brute_force_triplet_search_handles_tuple_results():
|
|||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -606,7 +608,7 @@ async def test_brute_force_triplet_search_mixed_empty_collections():
|
|||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -689,7 +691,7 @@ async def test_brute_force_triplet_search_vector_engine_init_error():
|
|||
"""Test brute_force_triplet_search handles vector engine initialization error (lines 145-147)."""
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine"
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine"
|
||||
) as mock_get_vector_engine,
|
||||
):
|
||||
mock_get_vector_engine.side_effect = Exception("Initialization error")
|
||||
|
|
@ -716,7 +718,7 @@ async def test_brute_force_triplet_search_collection_not_found_error():
|
|||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -743,7 +745,7 @@ async def test_brute_force_triplet_search_generic_exception():
|
|||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
):
|
||||
|
|
@ -769,7 +771,7 @@ async def test_brute_force_triplet_search_with_node_name_sets_relevant_ids_to_no
|
|||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -804,7 +806,7 @@ async def test_brute_force_triplet_search_collection_not_found_at_top_level():
|
|||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -815,3 +817,237 @@ async def test_brute_force_triplet_search_collection_not_found_at_top_level():
|
|||
result = await brute_force_triplet_search(query="test query")
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_single_query_regression():
    """Test that single-query mode maintains legacy behavior (flat list, ID filtering)."""
    vector_engine = AsyncMock()
    vector_engine.embedding_engine = AsyncMock()
    vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    vector_engine.search = AsyncMock(return_value=[MockScoredResult("node1", 0.95)])

    fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
            return_value=vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=fragment,
        ) as get_fragment_mock,
    ):
        result = await brute_force_triplet_search(
            query="q1", query_batch=None, wide_search_top_k=10, node_name=None
        )

    # Legacy single-query mode: a flat list of edges, never a list of lists.
    assert isinstance(result, list)
    assert not (result and isinstance(result[0], list))

    # Vector hits must be forwarded as relevant IDs for subgraph filtering.
    get_fragment_mock.assert_called_once()
    assert get_fragment_mock.call_args[1]["relevant_ids_to_filter"] is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_batch_wiring_happy_path():
    """Test that batch mode returns list-of-lists and skips ID filtering."""
    vector_engine = AsyncMock()
    vector_engine.embedding_engine = AsyncMock()
    vector_engine.batch_search = AsyncMock(
        return_value=[
            [MockScoredResult("node1", 0.95)],
            [MockScoredResult("node2", 0.87)],
        ]
    )

    fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[[], []]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
            return_value=vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=fragment,
        ) as get_fragment_mock,
    ):
        result = await brute_force_triplet_search(query_batch=["q1", "q2"])

    # One result list per query, in input order.
    assert isinstance(result, list)
    assert len(result) == 2
    for per_query in result:
        assert isinstance(per_query, list)

    # Batch mode must not restrict the fragment to specific node IDs.
    get_fragment_mock.assert_called_once()
    assert get_fragment_mock.call_args[1]["relevant_ids_to_filter"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_shape_propagation_to_graph():
    """Test that query_list_length is passed through to graph mapping methods."""
    # Vector engine stub: batch_search yields one inner result list per query.
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.batch_search = AsyncMock(
        return_value=[
            [MockScoredResult("node1", 0.95)],
            [MockScoredResult("node2", 0.87)],
        ]
    )

    # Memory-fragment stub records the kwargs forwarded by brute_force_triplet_search.
    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[[], []]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ),
    ):
        await brute_force_triplet_search(query_batch=["q1", "q2"])

    # Every graph mapping step must receive query_list_length == len(query_batch).
    mock_fragment.map_vector_distances_to_graph_nodes.assert_called_once()
    node_call_kwargs = mock_fragment.map_vector_distances_to_graph_nodes.call_args[1]
    assert "query_list_length" in node_call_kwargs
    assert node_call_kwargs["query_list_length"] == 2

    mock_fragment.map_vector_distances_to_graph_edges.assert_called_once()
    edge_call_kwargs = mock_fragment.map_vector_distances_to_graph_edges.call_args[1]
    assert "query_list_length" in edge_call_kwargs
    assert edge_call_kwargs["query_list_length"] == 2

    mock_fragment.calculate_top_triplet_importances.assert_called_once()
    importance_call_kwargs = mock_fragment.calculate_top_triplet_importances.call_args[1]
    assert "query_list_length" in importance_call_kwargs
    assert importance_call_kwargs["query_list_length"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_batch_path_comprehensive():
    """Test batch mode: returns list-of-lists, skips ID filtering, passes None for wide_search_limit."""
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()

    # Per-collection batch results: one inner list per query in the batch.
    def batch_search_side_effect(*args, **kwargs):
        collection_name = kwargs.get("collection_name")
        if collection_name == "Entity_name":
            return [
                [MockScoredResult("node1", 0.95)],
                [MockScoredResult("node2", 0.87)],
            ]
        elif collection_name == "EdgeType_relationship_name":
            return [
                [MockScoredResult("edge1", 0.92)],
                [MockScoredResult("edge2", 0.88)],
            ]
        return [[], []]

    mock_vector_engine.batch_search = AsyncMock(side_effect=batch_search_side_effect)

    # Memory-fragment stub: mapping steps are no-ops; importance calculation
    # returns one (empty) triplet list per query.
    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[[], []]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment,
    ):
        result = await brute_force_triplet_search(
            query_batch=["q1", "q2"], collections=["Entity_name", "EdgeType_relationship_name"]
        )

    # Batch mode returns one result list per query.
    assert isinstance(result, list)
    assert len(result) == 2
    assert isinstance(result[0], list)
    assert isinstance(result[1], list)

    # Batch mode must not pre-filter the graph by relevant node IDs.
    mock_get_fragment.assert_called_once()
    fragment_call_kwargs = mock_get_fragment.call_args[1]
    assert fragment_call_kwargs["relevant_ids_to_filter"] is None

    # Batch mode passes limit=None (no wide-search cap) on every batch_search call.
    batch_search_calls = mock_vector_engine.batch_search.call_args_list
    assert len(batch_search_calls) > 0
    for call in batch_search_calls:
        assert call[1]["limit"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_batch_error_fallback():
    """Test that CollectionNotFoundError in batch mode returns [[], []] matching batch length."""
    # Engine stub whose batch_search always fails with a missing-collection error.
    engine = AsyncMock()
    engine.embedding_engine = AsyncMock()
    engine.batch_search = AsyncMock(
        side_effect=CollectionNotFoundError("Collection not found")
    )

    with patch(
        "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
        return_value=engine,
    ):
        outcome = await brute_force_triplet_search(query_batch=["q1", "q2"])

    # The fallback preserves batch shape: one empty result list per query.
    assert outcome == [[], []]
    assert len(outcome) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cognee_graph_mapping_batch_shapes():
    """Test that CogneeGraph mapping methods accept list-of-lists with query_list_length set."""
    from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge

    memory_graph = CogneeGraph()
    first_node = Node("node1", {"name": "Node1"})
    second_node = Node("node2", {"name": "Node2"})
    for graph_node in (first_node, second_node):
        memory_graph.add_node(graph_node)

    connecting_edge = Edge(first_node, second_node, attributes={"edge_text": "relates_to"})
    memory_graph.add_edge(connecting_edge)

    # Per-query node hits: query 0 scored node1, query 1 scored node2.
    batched_node_distances = {
        "Entity_name": [
            [MockScoredResult("node1", 0.95)],
            [MockScoredResult("node2", 0.87)],
        ]
    }
    # Per-query edge hits, both resolving to the same edge text.
    batched_edge_distances = [
        [MockScoredResult("edge1", 0.92, payload={"text": "relates_to"})],
        [MockScoredResult("edge2", 0.88, payload={"text": "relates_to"})],
    ]

    await memory_graph.map_vector_distances_to_graph_nodes(
        node_distances=batched_node_distances, query_list_length=2
    )
    await memory_graph.map_vector_distances_to_graph_edges(
        edge_distances=batched_edge_distances, query_list_length=2
    )

    # Each element keeps one distance per query; positions without a hit read
    # 3.5 — presumably the implementation's miss sentinel (confirmed by these values).
    assert first_node.attributes.get("vector_distance") == [0.95, 3.5]
    assert second_node.attributes.get("vector_distance") == [3.5, 0.87]
    assert connecting_edge.attributes.get("vector_distance") == [0.92, 0.88]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,273 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from cognee.modules.retrieval.utils.node_edge_vector_search import NodeEdgeVectorSearch
|
||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||
|
||||
|
||||
class MockScoredResult:
    """Mock class for vector search results.

    Mimics the minimal interface of a scored vector-store hit: an ``id``,
    a similarity ``score``, and an optional ``payload`` mapping.
    """

    def __init__(self, id, score, payload=None):
        # A missing payload is normalized to an empty dict so tests can
        # index it without None checks.
        self.id = id
        self.score = score
        self.payload = payload or {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_node_edge_vector_search_single_query_shape():
    """Test that single query mode produces flat lists (not list-of-lists)."""
    engine = AsyncMock()
    engine.embedding_engine = AsyncMock()
    engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])

    expected_nodes = [MockScoredResult("node1", 0.95), MockScoredResult("node2", 0.87)]
    expected_edges = [MockScoredResult("edge1", 0.92)]

    # Dispatch per collection: edges get the edge hits, everything else node hits.
    def fake_search(*_args, **kwargs):
        if kwargs.get("collection_name") == "EdgeType_relationship_name":
            return expected_edges
        return expected_nodes

    engine.search = AsyncMock(side_effect=fake_search)

    searcher = NodeEdgeVectorSearch(vector_engine=engine)

    await searcher.embed_and_retrieve_distances(
        query="test query",
        query_batch=None,
        collections=["Entity_name", "EdgeType_relationship_name"],
        wide_search_limit=10,
    )

    # Single-query mode: no batch length recorded, flat result lists.
    assert searcher.query_list_length is None
    assert searcher.edge_distances == expected_edges
    assert searcher.node_distances["Entity_name"] == expected_nodes
    # The query is embedded exactly once, wrapped in a list.
    engine.embedding_engine.embed_text.assert_called_once_with(["test query"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_node_edge_vector_search_batch_query_shape_and_empties():
    """Test that batch query mode produces list-of-lists with correct length and handles empty collections."""
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()

    query_batch = ["query a", "query b"]
    # Per-query fixtures; query b deliberately has no edge hits.
    node_results_query_a = [MockScoredResult("node1", 0.95)]
    node_results_query_b = [MockScoredResult("node2", 0.87)]
    edge_results_query_a = [MockScoredResult("edge1", 0.92)]
    edge_results_query_b = []

    # Dispatch per collection; MissingCollection simulates a vector-store miss.
    def batch_search_side_effect(*args, **kwargs):
        collection_name = kwargs.get("collection_name")
        if collection_name == "EdgeType_relationship_name":
            return [edge_results_query_a, edge_results_query_b]
        elif collection_name == "Entity_name":
            return [node_results_query_a, node_results_query_b]
        elif collection_name == "MissingCollection":
            raise CollectionNotFoundError("Collection not found")
        return [[], []]

    mock_vector_engine.batch_search = AsyncMock(side_effect=batch_search_side_effect)

    vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
    collections = [
        "Entity_name",
        "EdgeType_relationship_name",
        "MissingCollection",
        "EmptyCollection",
    ]

    await vector_search.embed_and_retrieve_distances(
        query=None, query_batch=query_batch, collections=collections, wide_search_limit=None
    )

    # Batch length is recorded; all shapes are list-of-lists of that length.
    assert vector_search.query_list_length == 2
    assert len(vector_search.edge_distances) == 2
    assert vector_search.edge_distances[0] == edge_results_query_a
    assert vector_search.edge_distances[1] == edge_results_query_b
    assert len(vector_search.node_distances["Entity_name"]) == 2
    assert vector_search.node_distances["Entity_name"][0] == node_results_query_a
    assert vector_search.node_distances["Entity_name"][1] == node_results_query_b
    # Missing and empty collections both degrade to per-query empty lists.
    assert len(vector_search.node_distances["MissingCollection"]) == 2
    assert vector_search.node_distances["MissingCollection"] == [[], []]
    assert len(vector_search.node_distances["EmptyCollection"]) == 2
    assert vector_search.node_distances["EmptyCollection"] == [[], []]
    # Batch mode delegates embedding to batch_search; embed_text is not used directly.
    mock_vector_engine.embedding_engine.embed_text.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_node_edge_vector_search_input_validation_both_provided():
    """Test that providing both query and query_batch raises ValueError."""
    searcher = NodeEdgeVectorSearch(vector_engine=AsyncMock())

    # The two query inputs are mutually exclusive.
    with pytest.raises(ValueError, match="Cannot provide both 'query' and 'query_batch'"):
        await searcher.embed_and_retrieve_distances(
            query="test", query_batch=["test1", "test2"], collections=["Entity_name"]
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_node_edge_vector_search_input_validation_neither_provided():
    """Test that providing neither query nor query_batch raises ValueError."""
    searcher = NodeEdgeVectorSearch(vector_engine=AsyncMock())

    # At least one of the two query inputs is mandatory.
    with pytest.raises(ValueError, match="Must provide either 'query' or 'query_batch'"):
        await searcher.embed_and_retrieve_distances(
            query=None, query_batch=None, collections=["Entity_name"]
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_node_edge_vector_search_extract_relevant_node_ids_single_query():
    """Test that extract_relevant_node_ids returns IDs for single query mode."""
    searcher = NodeEdgeVectorSearch(vector_engine=AsyncMock())
    searcher.query_list_length = None  # single-query mode
    searcher.node_distances = {
        "Entity_name": [MockScoredResult("node1", 0.95), MockScoredResult("node2", 0.87)],
        "TextSummary_text": [MockScoredResult("node1", 0.90), MockScoredResult("node3", 0.92)],
    }

    # IDs are gathered across all collections; node1 appears in two of them.
    extracted = searcher.extract_relevant_node_ids()
    assert set(extracted) == {"node1", "node2", "node3"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_node_edge_vector_search_extract_relevant_node_ids_batch():
    """Test that extract_relevant_node_ids returns empty list for batch mode."""
    searcher = NodeEdgeVectorSearch(vector_engine=AsyncMock())
    searcher.query_list_length = 2  # batch mode: ID filtering is skipped
    searcher.node_distances = {
        "Entity_name": [
            [MockScoredResult("node1", 0.95)],
            [MockScoredResult("node2", 0.87)],
        ],
    }

    # Batch mode never narrows by node IDs, so the result is always empty.
    assert searcher.extract_relevant_node_ids() == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_node_edge_vector_search_has_results_single_query():
    """Test has_results returns True when results exist and False when only empties."""
    searcher = NodeEdgeVectorSearch(vector_engine=AsyncMock())

    # Edge hits alone are enough.
    searcher.edge_distances = [MockScoredResult("edge1", 0.92)]
    searcher.node_distances = {}
    assert searcher.has_results() is True

    # Node hits alone are enough.
    searcher.edge_distances = []
    searcher.node_distances = {"Entity_name": [MockScoredResult("node1", 0.95)]}
    assert searcher.has_results() is True

    # No hits anywhere -> False.
    searcher.edge_distances = []
    searcher.node_distances = {}
    assert searcher.has_results() is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_node_edge_vector_search_has_results_batch():
    """Test has_results works correctly for batch mode with list-of-lists."""
    searcher = NodeEdgeVectorSearch(vector_engine=AsyncMock())
    searcher.query_list_length = 2

    # An edge hit for any query in the batch -> True.
    searcher.edge_distances = [[MockScoredResult("edge1", 0.92)], []]
    searcher.node_distances = {}
    assert searcher.has_results() is True

    # A node hit for any query in the batch -> True.
    searcher.edge_distances = [[], []]
    searcher.node_distances = {
        "Entity_name": [[MockScoredResult("node1", 0.95)], []],
    }
    assert searcher.has_results() is True

    # Only empty inner lists everywhere -> False.
    searcher.edge_distances = [[], []]
    searcher.node_distances = {"Entity_name": [[], []]}
    assert searcher.has_results() is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_node_edge_vector_search_single_query_collection_not_found():
    """Test that CollectionNotFoundError in single query mode returns empty list."""
    engine = AsyncMock()
    engine.embedding_engine = AsyncMock()
    engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    # Every search call fails as if the collection does not exist.
    engine.search = AsyncMock(side_effect=CollectionNotFoundError("Collection not found"))

    searcher = NodeEdgeVectorSearch(vector_engine=engine)

    await searcher.embed_and_retrieve_distances(
        query="test query",
        query_batch=None,
        collections=["MissingCollection"],
        wide_search_limit=10,
    )

    # The missing collection is recorded with an empty hit list instead of raising.
    assert searcher.node_distances["MissingCollection"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_node_edge_vector_search_missing_collections_single_query():
    """Test that missing collections in single-query mode are handled gracefully with empty lists."""
    engine = AsyncMock()
    engine.embedding_engine = AsyncMock()
    engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])

    found_hit = MockScoredResult("node1", 0.95)

    # One good collection, one that raises, and (implicitly) one that is empty.
    def fake_search(*_args, **kwargs):
        name = kwargs.get("collection_name")
        if name == "Entity_name":
            return [found_hit]
        elif name == "MissingCollection":
            raise CollectionNotFoundError("Collection not found")
        return []

    engine.search = AsyncMock(side_effect=fake_search)

    searcher = NodeEdgeVectorSearch(vector_engine=engine)

    await searcher.embed_and_retrieve_distances(
        query="test query",
        query_batch=None,
        collections=["Entity_name", "MissingCollection", "EmptyCollection"],
        wide_search_limit=10,
    )

    entity_hits = searcher.node_distances["Entity_name"]
    assert len(entity_hits) == 1
    assert entity_hits[0].id == "node1"
    assert entity_hits[0].score == 0.95
    # Missing and empty collections both degrade to empty hit lists.
    assert searcher.node_distances["MissingCollection"] == []
    assert searcher.node_distances["EmptyCollection"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_node_edge_vector_search_has_results_batch_nodes_only():
    """Test has_results returns True when only node distances are populated in batch mode."""
    searcher = NodeEdgeVectorSearch(vector_engine=AsyncMock())
    searcher.query_list_length = 2
    searcher.edge_distances = [[], []]  # no edge hits for either query
    searcher.node_distances = {
        "Entity_name": [[MockScoredResult("node1", 0.95)], []],
    }

    # A single node hit in the batch is sufficient.
    assert searcher.has_results() is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_node_edge_vector_search_has_results_batch_edges_only():
    """Test has_results returns True when only edge distances are populated in batch mode."""
    searcher = NodeEdgeVectorSearch(vector_engine=AsyncMock())
    searcher.query_list_length = 2
    searcher.edge_distances = [[MockScoredResult("edge1", 0.92)], []]
    searcher.node_distances = {}  # no node hits at all

    # A single edge hit in the batch is sufficient.
    assert searcher.has_results() is True
|
||||
|
|
@ -129,14 +129,32 @@ async def test_search_access_control_returns_dataset_shaped_dicts(monkeypatch, s
|
|||
monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
|
||||
monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
|
||||
|
||||
out = await search_mod.search(
|
||||
out_non_verbose = await search_mod.search(
|
||||
query_text="q",
|
||||
query_type=SearchType.CHUNKS,
|
||||
dataset_ids=[ds.id],
|
||||
user=user,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
assert out == [
|
||||
assert out_non_verbose == [
|
||||
{
|
||||
"search_result": ["r"],
|
||||
"dataset_id": ds.id,
|
||||
"dataset_name": "ds1",
|
||||
"dataset_tenant_id": "t1",
|
||||
}
|
||||
]
|
||||
|
||||
out_verbose = await search_mod.search(
|
||||
query_text="q",
|
||||
query_type=SearchType.CHUNKS,
|
||||
dataset_ids=[ds.id],
|
||||
user=user,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
assert out_verbose == [
|
||||
{
|
||||
"search_result": ["r"],
|
||||
"dataset_id": ds.id,
|
||||
|
|
@ -166,6 +184,7 @@ async def test_search_access_control_only_context_returns_dataset_shaped_dicts(
|
|||
dataset_ids=[ds.id],
|
||||
user=user,
|
||||
only_context=True,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
assert out == [
|
||||
|
|
@ -180,35 +199,7 @@ async def test_search_access_control_only_context_returns_dataset_shaped_dicts(
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_access_control_use_combined_context_returns_combined_model(
|
||||
monkeypatch, search_mod
|
||||
):
|
||||
user = _make_user()
|
||||
ds1 = _make_dataset(name="ds1", tenant_id="t1")
|
||||
ds2 = _make_dataset(name="ds2", tenant_id="t1")
|
||||
|
||||
async def dummy_authorized_search(**_kwargs):
|
||||
return ("answer", {"k": "v"}, [ds1, ds2])
|
||||
|
||||
monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
|
||||
monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
|
||||
|
||||
out = await search_mod.search(
|
||||
query_text="q",
|
||||
query_type=SearchType.CHUNKS,
|
||||
dataset_ids=[ds1.id, ds2.id],
|
||||
user=user,
|
||||
use_combined_context=True,
|
||||
)
|
||||
|
||||
assert out.result == "answer"
|
||||
assert out.context == {"k": "v"}
|
||||
assert out.graphs == {}
|
||||
assert [d.id for d in out.datasets] == [ds1.id, ds2.id]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorized_search_non_combined_delegates(monkeypatch, search_mod):
|
||||
async def test_authorized_search_delegates_to_search_in_datasets_context(monkeypatch, search_mod):
|
||||
user = _make_user()
|
||||
ds = _make_dataset(name="ds1")
|
||||
|
||||
|
|
@ -218,7 +209,6 @@ async def test_authorized_search_non_combined_delegates(monkeypatch, search_mod)
|
|||
expected = [("r", ["ctx"], [ds])]
|
||||
|
||||
async def dummy_search_in_datasets_context(**kwargs):
|
||||
assert kwargs["use_combined_context"] is False if "use_combined_context" in kwargs else True
|
||||
return expected
|
||||
|
||||
monkeypatch.setattr(
|
||||
|
|
@ -231,104 +221,12 @@ async def test_authorized_search_non_combined_delegates(monkeypatch, search_mod)
|
|||
query_text="q",
|
||||
user=user,
|
||||
dataset_ids=[ds.id],
|
||||
use_combined_context=False,
|
||||
only_context=False,
|
||||
)
|
||||
|
||||
assert out == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorized_search_use_combined_context_joins_string_context(monkeypatch, search_mod):
|
||||
user = _make_user()
|
||||
ds1 = _make_dataset(name="ds1")
|
||||
ds2 = _make_dataset(name="ds2")
|
||||
|
||||
async def dummy_get_authorized_existing_datasets(*_args, **_kwargs):
|
||||
return [ds1, ds2]
|
||||
|
||||
async def dummy_search_in_datasets_context(**kwargs):
|
||||
assert kwargs["only_context"] is True
|
||||
return [(None, ["a"], [ds1]), (None, ["b"], [ds2])]
|
||||
|
||||
seen = {}
|
||||
|
||||
async def dummy_get_completion(query_text, context, session_id=None):
|
||||
seen["query_text"] = query_text
|
||||
seen["context"] = context
|
||||
seen["session_id"] = session_id
|
||||
return ["answer"]
|
||||
|
||||
async def dummy_get_search_type_tools(**_kwargs):
|
||||
return [dummy_get_completion, lambda *_a, **_k: None]
|
||||
|
||||
monkeypatch.setattr(
|
||||
search_mod, "get_authorized_existing_datasets", dummy_get_authorized_existing_datasets
|
||||
)
|
||||
monkeypatch.setattr(search_mod, "search_in_datasets_context", dummy_search_in_datasets_context)
|
||||
monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools)
|
||||
|
||||
completion, combined_context, datasets = await search_mod.authorized_search(
|
||||
query_type=SearchType.CHUNKS,
|
||||
query_text="q",
|
||||
user=user,
|
||||
dataset_ids=[ds1.id, ds2.id],
|
||||
use_combined_context=True,
|
||||
session_id="s1",
|
||||
)
|
||||
|
||||
assert combined_context == "a\nb"
|
||||
assert completion == ["answer"]
|
||||
assert datasets == [ds1, ds2]
|
||||
assert seen == {"query_text": "q", "context": "a\nb", "session_id": "s1"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorized_search_use_combined_context_keeps_non_string_context(
|
||||
monkeypatch, search_mod
|
||||
):
|
||||
user = _make_user()
|
||||
ds1 = _make_dataset(name="ds1")
|
||||
ds2 = _make_dataset(name="ds2")
|
||||
|
||||
class DummyEdge:
|
||||
pass
|
||||
|
||||
e1, e2 = DummyEdge(), DummyEdge()
|
||||
|
||||
async def dummy_get_authorized_existing_datasets(*_args, **_kwargs):
|
||||
return [ds1, ds2]
|
||||
|
||||
async def dummy_search_in_datasets_context(**_kwargs):
|
||||
return [(None, [e1], [ds1]), (None, [e2], [ds2])]
|
||||
|
||||
async def dummy_get_completion(query_text, context, session_id=None):
|
||||
assert query_text == "q"
|
||||
assert context == [e1, e2]
|
||||
return ["answer"]
|
||||
|
||||
async def dummy_get_search_type_tools(**_kwargs):
|
||||
return [dummy_get_completion]
|
||||
|
||||
monkeypatch.setattr(
|
||||
search_mod, "get_authorized_existing_datasets", dummy_get_authorized_existing_datasets
|
||||
)
|
||||
monkeypatch.setattr(search_mod, "search_in_datasets_context", dummy_search_in_datasets_context)
|
||||
monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools)
|
||||
|
||||
completion, combined_context, datasets = await search_mod.authorized_search(
|
||||
query_type=SearchType.CHUNKS,
|
||||
query_text="q",
|
||||
user=user,
|
||||
dataset_ids=[ds1.id, ds2.id],
|
||||
use_combined_context=True,
|
||||
)
|
||||
|
||||
assert combined_context == [e1, e2]
|
||||
assert completion == ["answer"]
|
||||
assert datasets == [ds1, ds2]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_in_datasets_context_two_tool_context_override_and_is_empty_branches(
|
||||
monkeypatch, search_mod
|
||||
|
|
|
|||
|
|
@ -90,6 +90,7 @@ async def test_search_access_control_edges_context_produces_graphs_and_context_m
|
|||
query_type=SearchType.CHUNKS,
|
||||
dataset_ids=[ds.id],
|
||||
user=user,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
assert out[0]["dataset_name"] == "ds1"
|
||||
|
|
@ -126,6 +127,7 @@ async def test_search_access_control_insights_context_produces_graphs_and_null_r
|
|||
query_type=SearchType.CHUNKS,
|
||||
dataset_ids=[ds.id],
|
||||
user=user,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
assert out[0]["graphs"] is not None
|
||||
|
|
@ -150,6 +152,7 @@ async def test_search_access_control_only_context_returns_context_text_map(monke
|
|||
dataset_ids=[ds.id],
|
||||
user=user,
|
||||
only_context=True,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
assert out[0]["search_result"] == [{"ds1": "a\nb"}]
|
||||
|
|
@ -172,6 +175,7 @@ async def test_search_access_control_results_edges_become_graph_result(monkeypat
|
|||
query_type=SearchType.CHUNKS,
|
||||
dataset_ids=[ds.id],
|
||||
user=user,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
assert isinstance(out[0]["search_result"][0], dict)
|
||||
|
|
@ -179,29 +183,6 @@ async def test_search_access_control_results_edges_become_graph_result(monkeypat
|
|||
assert "edges" in out[0]["search_result"][0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_use_combined_context_defaults_empty_datasets(monkeypatch, search_mod):
|
||||
user = types.SimpleNamespace(id="u1", tenant_id=None)
|
||||
|
||||
async def dummy_authorized_search(**_kwargs):
|
||||
return ("answer", "ctx", [])
|
||||
|
||||
monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
|
||||
monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
|
||||
|
||||
out = await search_mod.search(
|
||||
query_text="q",
|
||||
query_type=SearchType.CHUNKS,
|
||||
dataset_ids=None,
|
||||
user=user,
|
||||
use_combined_context=True,
|
||||
)
|
||||
|
||||
assert out.result == "answer"
|
||||
assert out.context == {"all available datasets": "ctx"}
|
||||
assert out.datasets[0].name == "all available datasets"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_access_control_context_str_branch(monkeypatch, search_mod):
|
||||
"""Covers prepare_search_result(context is str) through search()."""
|
||||
|
|
@ -219,6 +200,7 @@ async def test_search_access_control_context_str_branch(monkeypatch, search_mod)
|
|||
query_type=SearchType.CHUNKS,
|
||||
dataset_ids=[ds.id],
|
||||
user=user,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
assert out[0]["graphs"] is None
|
||||
|
|
@ -242,6 +224,7 @@ async def test_search_access_control_context_empty_list_branch(monkeypatch, sear
|
|||
query_type=SearchType.CHUNKS,
|
||||
dataset_ids=[ds.id],
|
||||
user=user,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
assert out[0]["graphs"] is None
|
||||
|
|
@ -265,6 +248,7 @@ async def test_search_access_control_multiple_results_list_branch(monkeypatch, s
|
|||
query_type=SearchType.CHUNKS,
|
||||
dataset_ids=[ds.id],
|
||||
user=user,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
assert out[0]["search_result"] == [["r1", "r2"]]
|
||||
|
|
@ -293,4 +277,5 @@ async def test_search_access_control_defaults_empty_datasets(monkeypatch, search
|
|||
query_type=SearchType.CHUNKS,
|
||||
dataset_ids=None,
|
||||
user=user,
|
||||
verbose=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -20,19 +20,29 @@ echo "HTTP port: $HTTP_PORT"
|
|||
# smooth redeployments and container restarts while maintaining data integrity.
|
||||
echo "Running database migrations..."
|
||||
|
||||
set +e # Disable exit on error to handle specific migration errors
|
||||
MIGRATION_OUTPUT=$(alembic upgrade head)
|
||||
MIGRATION_EXIT_CODE=$?
|
||||
set -e
|
||||
|
||||
if [[ $MIGRATION_EXIT_CODE -ne 0 ]]; then
|
||||
if [[ "$MIGRATION_OUTPUT" == *"UserAlreadyExists"* ]] || [[ "$MIGRATION_OUTPUT" == *"User default_user@example.com already exists"* ]]; then
|
||||
echo "Warning: Default user already exists, continuing startup..."
|
||||
else
|
||||
echo "Migration failed with unexpected error."
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
echo "Migration failed with unexpected error. Trying to run Cognee without migrations."
|
||||
|
||||
echo "Database migrations done."
|
||||
echo "Initializing database tables..."
|
||||
python /app/cognee/modules/engine/operations/setup.py
|
||||
INIT_EXIT_CODE=$?
|
||||
|
||||
if [[ $INIT_EXIT_CODE -ne 0 ]]; then
|
||||
echo "Database initialization failed!"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
else
|
||||
echo "Database migrations done."
|
||||
fi
|
||||
|
||||
echo "Starting server..."
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import asyncio
|
||||
import cognee
|
||||
|
||||
import os
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
# By default cognee uses OpenAI's gpt-5-mini LLM model
|
||||
# Provide your OpenAI LLM API KEY
|
||||
os.environ["LLM_API_KEY"] = ""
|
||||
|
|
@ -24,13 +25,13 @@ async def cognee_demo():
|
|||
|
||||
# Query Cognee for information from provided document
|
||||
answer = await cognee.search("List me all the important characters in Alice in Wonderland.")
|
||||
print(answer)
|
||||
pprint(answer)
|
||||
|
||||
answer = await cognee.search("How did Alice end up in Wonderland?")
|
||||
print(answer)
|
||||
pprint(answer)
|
||||
|
||||
answer = await cognee.search("Tell me about Alice's personality.")
|
||||
print(answer)
|
||||
pprint(answer)
|
||||
|
||||
|
||||
# Cognee is an async library, it has to be called in an async context
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
from pprint import pprint
|
||||
|
||||
import cognee
|
||||
from cognee.api.v1.search import SearchType
|
||||
|
|
@ -187,7 +188,7 @@ async def main(enable_steps):
|
|||
search_results = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION, query_text="Who has experience in design tools?"
|
||||
)
|
||||
print(search_results)
|
||||
pprint(search_results)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
482
examples/python/extract_usage_frequency_example.py
Normal file
482
examples/python/extract_usage_frequency_example.py
Normal file
|
|
@ -0,0 +1,482 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
End-to-End Example: Usage Frequency Tracking in Cognee
|
||||
|
||||
This example demonstrates the complete workflow for tracking and analyzing
|
||||
how frequently different graph elements are accessed through user searches.
|
||||
|
||||
Features demonstrated:
|
||||
- Setting up a knowledge base
|
||||
- Running searches with interaction tracking (save_interaction=True)
|
||||
- Extracting usage frequencies from interaction data
|
||||
- Applying frequency weights to graph nodes
|
||||
- Analyzing and visualizing the results
|
||||
|
||||
Use cases:
|
||||
- Ranking search results by popularity
|
||||
- Identifying "hot topics" in your knowledge base
|
||||
- Understanding user behavior and interests
|
||||
- Improving retrieval based on usage patterns
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from typing import List, Dict, Any
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import cognee
|
||||
from cognee.api.v1.search import SearchType
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||
from cognee.tasks.memify.extract_usage_frequency import run_usage_frequency_update
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# STEP 1: Setup and Configuration
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def setup_knowledge_base():
|
||||
"""
|
||||
Create a fresh knowledge base with sample content.
|
||||
|
||||
In a real application, you would:
|
||||
- Load documents from files, databases, or APIs
|
||||
- Process larger datasets
|
||||
- Organize content by datasets/categories
|
||||
"""
|
||||
print("=" * 80)
|
||||
print("STEP 1: Setting up knowledge base")
|
||||
print("=" * 80)
|
||||
|
||||
# Reset state for clean demo (optional in production)
|
||||
print("\nResetting Cognee state...")
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
print("✓ Reset complete")
|
||||
|
||||
# Sample content: AI/ML educational material
|
||||
documents = [
|
||||
"""
|
||||
Machine Learning Fundamentals:
|
||||
Machine learning is a subset of artificial intelligence that enables systems
|
||||
to learn and improve from experience without being explicitly programmed.
|
||||
The three main types are supervised learning, unsupervised learning, and
|
||||
reinforcement learning.
|
||||
""",
|
||||
"""
|
||||
Neural Networks Explained:
|
||||
Neural networks are computing systems inspired by biological neural networks.
|
||||
They consist of layers of interconnected nodes (neurons) that process information
|
||||
through weighted connections. Deep learning uses neural networks with many layers
|
||||
to automatically learn hierarchical representations of data.
|
||||
""",
|
||||
"""
|
||||
Natural Language Processing:
|
||||
NLP enables computers to understand, interpret, and generate human language.
|
||||
Modern NLP uses transformer architectures like BERT and GPT, which have
|
||||
revolutionized tasks such as translation, summarization, and question answering.
|
||||
""",
|
||||
"""
|
||||
Computer Vision Applications:
|
||||
Computer vision allows machines to interpret visual information from the world.
|
||||
Convolutional neural networks (CNNs) are particularly effective for image
|
||||
recognition, object detection, and image segmentation tasks.
|
||||
""",
|
||||
]
|
||||
|
||||
print(f"\nAdding {len(documents)} documents to knowledge base...")
|
||||
await cognee.add(documents, dataset_name="ai_ml_fundamentals")
|
||||
print("✓ Documents added")
|
||||
|
||||
# Build knowledge graph
|
||||
print("\nBuilding knowledge graph (cognify)...")
|
||||
await cognee.cognify()
|
||||
print("✓ Knowledge graph built")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# STEP 2: Simulate User Searches with Interaction Tracking
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def simulate_user_searches(queries: List[str]):
|
||||
"""
|
||||
Simulate users searching the knowledge base.
|
||||
|
||||
The key parameter is save_interaction=True, which creates:
|
||||
- CogneeUserInteraction nodes (one per search)
|
||||
- used_graph_element_to_answer edges (connecting queries to relevant nodes)
|
||||
|
||||
Args:
|
||||
queries: List of search queries to simulate
|
||||
|
||||
Returns:
|
||||
Number of successful searches
|
||||
"""
|
||||
print("=" * 80)
|
||||
print("STEP 2: Simulating user searches with interaction tracking")
|
||||
print("=" * 80)
|
||||
|
||||
successful_searches = 0
|
||||
|
||||
for i, query in enumerate(queries, 1):
|
||||
print(f"\nSearch {i}/{len(queries)}: '{query}'")
|
||||
try:
|
||||
results = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION,
|
||||
query_text=query,
|
||||
save_interaction=True, # ← THIS IS CRITICAL!
|
||||
top_k=5,
|
||||
)
|
||||
successful_searches += 1
|
||||
|
||||
# Show snippet of results
|
||||
result_preview = str(results)[:100] if results else "No results"
|
||||
print(f" ✓ Completed ({result_preview}...)")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ✗ Failed: {e}")
|
||||
|
||||
print(f"\n✓ Completed {successful_searches}/{len(queries)} searches")
|
||||
print("=" * 80)
|
||||
|
||||
return successful_searches
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# STEP 3: Extract and Apply Usage Frequencies
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def extract_and_apply_frequencies(
|
||||
time_window_days: int = 7, min_threshold: int = 1
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract usage frequencies from interactions and apply them to the graph.
|
||||
|
||||
This function:
|
||||
1. Retrieves the graph with interaction data
|
||||
2. Counts how often each node was accessed
|
||||
3. Writes frequency_weight property back to nodes
|
||||
|
||||
Args:
|
||||
time_window_days: Only count interactions from last N days
|
||||
min_threshold: Minimum accesses to track (filter out rarely used nodes)
|
||||
|
||||
Returns:
|
||||
Dictionary with statistics about the frequency update
|
||||
"""
|
||||
print("=" * 80)
|
||||
print("STEP 3: Extracting and applying usage frequencies")
|
||||
print("=" * 80)
|
||||
|
||||
# Get graph adapter
|
||||
graph_engine = await get_graph_engine()
|
||||
|
||||
# Retrieve graph with interactions
|
||||
print("\nRetrieving graph from database...")
|
||||
graph = CogneeGraph()
|
||||
await graph.project_graph_from_db(
|
||||
adapter=graph_engine,
|
||||
node_properties_to_project=[
|
||||
"type",
|
||||
"node_type",
|
||||
"timestamp",
|
||||
"created_at",
|
||||
"text",
|
||||
"name",
|
||||
"query_text",
|
||||
"frequency_weight",
|
||||
],
|
||||
edge_properties_to_project=["relationship_type", "timestamp"],
|
||||
directed=True,
|
||||
)
|
||||
|
||||
print(f"✓ Retrieved: {len(graph.nodes)} nodes, {len(graph.edges)} edges")
|
||||
|
||||
# Count interaction nodes
|
||||
interaction_nodes = [
|
||||
n
|
||||
for n in graph.nodes.values()
|
||||
if n.attributes.get("type") == "CogneeUserInteraction"
|
||||
or n.attributes.get("node_type") == "CogneeUserInteraction"
|
||||
]
|
||||
print(f"✓ Found {len(interaction_nodes)} interaction nodes")
|
||||
|
||||
# Run frequency extraction and update
|
||||
print(f"\nExtracting frequencies (time window: {time_window_days} days)...")
|
||||
stats = await run_usage_frequency_update(
|
||||
graph_adapter=graph_engine,
|
||||
subgraphs=[graph],
|
||||
time_window=timedelta(days=time_window_days),
|
||||
min_interaction_threshold=min_threshold,
|
||||
)
|
||||
|
||||
print("\n✓ Frequency extraction complete!")
|
||||
print(
|
||||
f" - Interactions processed: {stats['interactions_in_window']}/{stats['total_interactions']}"
|
||||
)
|
||||
print(f" - Nodes weighted: {len(stats['node_frequencies'])}")
|
||||
print(f" - Element types tracked: {stats.get('element_type_frequencies', {})}")
|
||||
|
||||
print("=" * 80)
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# STEP 4: Analyze and Display Results
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def analyze_results(stats: Dict[str, Any]):
|
||||
"""
|
||||
Analyze and display the frequency tracking results.
|
||||
|
||||
Shows:
|
||||
- Top most frequently accessed nodes
|
||||
- Element type distribution
|
||||
- Verification that weights were written to database
|
||||
|
||||
Args:
|
||||
stats: Statistics from frequency extraction
|
||||
"""
|
||||
print("=" * 80)
|
||||
print("STEP 4: Analyzing usage frequency results")
|
||||
print("=" * 80)
|
||||
|
||||
# Display top nodes by frequency
|
||||
if stats["node_frequencies"]:
|
||||
print("\n📊 Top 10 Most Frequently Accessed Elements:")
|
||||
print("-" * 80)
|
||||
|
||||
sorted_nodes = sorted(stats["node_frequencies"].items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
# Get graph to display node details
|
||||
graph_engine = await get_graph_engine()
|
||||
graph = CogneeGraph()
|
||||
await graph.project_graph_from_db(
|
||||
adapter=graph_engine,
|
||||
node_properties_to_project=["type", "text", "name"],
|
||||
edge_properties_to_project=[],
|
||||
directed=True,
|
||||
)
|
||||
|
||||
for i, (node_id, frequency) in enumerate(sorted_nodes[:10], 1):
|
||||
node = graph.get_node(node_id)
|
||||
if node:
|
||||
node_type = node.attributes.get("type", "Unknown")
|
||||
text = node.attributes.get("text") or node.attributes.get("name") or ""
|
||||
text_preview = text[:60] + "..." if len(text) > 60 else text
|
||||
|
||||
print(f"\n{i}. Frequency: {frequency} accesses")
|
||||
print(f" Type: {node_type}")
|
||||
print(f" Content: {text_preview}")
|
||||
else:
|
||||
print(f"\n{i}. Frequency: {frequency} accesses")
|
||||
print(f" Node ID: {node_id[:50]}...")
|
||||
|
||||
# Display element type distribution
|
||||
if stats.get("element_type_frequencies"):
|
||||
print("\n\n📈 Element Type Distribution:")
|
||||
print("-" * 80)
|
||||
type_dist = stats["element_type_frequencies"]
|
||||
for elem_type, count in sorted(type_dist.items(), key=lambda x: x[1], reverse=True):
|
||||
print(f" {elem_type}: {count} accesses")
|
||||
|
||||
# Verify weights in database (Neo4j only)
|
||||
print("\n\n🔍 Verifying weights in database...")
|
||||
print("-" * 80)
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
adapter_type = type(graph_engine).__name__
|
||||
|
||||
if adapter_type == "Neo4jAdapter":
|
||||
try:
|
||||
result = await graph_engine.query("""
|
||||
MATCH (n)
|
||||
WHERE n.frequency_weight IS NOT NULL
|
||||
RETURN count(n) as weighted_count
|
||||
""")
|
||||
|
||||
count = result[0]["weighted_count"] if result else 0
|
||||
if count > 0:
|
||||
print(f"✓ {count} nodes have frequency_weight in Neo4j database")
|
||||
|
||||
# Show sample
|
||||
sample = await graph_engine.query("""
|
||||
MATCH (n)
|
||||
WHERE n.frequency_weight IS NOT NULL
|
||||
RETURN n.frequency_weight as weight, labels(n) as labels
|
||||
ORDER BY n.frequency_weight DESC
|
||||
LIMIT 3
|
||||
""")
|
||||
|
||||
print("\nSample weighted nodes:")
|
||||
for row in sample:
|
||||
print(f" - Weight: {row['weight']}, Type: {row['labels']}")
|
||||
else:
|
||||
print("⚠ No nodes with frequency_weight found in database")
|
||||
except Exception as e:
|
||||
print(f"Could not verify in Neo4j: {e}")
|
||||
else:
|
||||
print(f"Database verification not implemented for {adapter_type}")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# STEP 5: Demonstrate Usage in Retrieval
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def demonstrate_retrieval_usage():
|
||||
"""
|
||||
Demonstrate how frequency weights can be used in retrieval.
|
||||
|
||||
Note: This is a conceptual demonstration. To actually use frequency
|
||||
weights in ranking, you would need to modify the retrieval/completion
|
||||
strategies to incorporate the frequency_weight property.
|
||||
"""
|
||||
print("=" * 80)
|
||||
print("STEP 5: How to use frequency weights in retrieval")
|
||||
print("=" * 80)
|
||||
|
||||
print("""
|
||||
Frequency weights can be used to improve search results:
|
||||
|
||||
1. RANKING BOOST:
|
||||
- Multiply relevance scores by frequency_weight
|
||||
- Prioritize frequently accessed nodes in results
|
||||
|
||||
2. COMPLETION STRATEGIES:
|
||||
- Adjust triplet importance based on usage
|
||||
- Filter out rarely accessed information
|
||||
|
||||
3. ANALYTICS:
|
||||
- Track trending topics over time
|
||||
- Understand user interests and behavior
|
||||
- Identify knowledge gaps (low-frequency nodes)
|
||||
|
||||
4. ADAPTIVE RETRIEVAL:
|
||||
- Personalize results based on team usage patterns
|
||||
- Surface popular answers faster
|
||||
|
||||
Example Cypher query with frequency boost (Neo4j):
|
||||
|
||||
MATCH (n)
|
||||
WHERE n.text CONTAINS $search_term
|
||||
RETURN n, n.frequency_weight as boost
|
||||
ORDER BY (n.relevance_score * COALESCE(n.frequency_weight, 1)) DESC
|
||||
LIMIT 10
|
||||
|
||||
To integrate this into Cognee, you would modify the completion
|
||||
strategy to include frequency_weight in the scoring function.
|
||||
""")
|
||||
|
||||
print("=" * 80)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# MAIN: Run Complete Example
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def main():
|
||||
"""
|
||||
Run the complete end-to-end usage frequency tracking example.
|
||||
"""
|
||||
print("\n")
|
||||
print("╔" + "=" * 78 + "╗")
|
||||
print("║" + " " * 78 + "║")
|
||||
print("║" + " Usage Frequency Tracking - End-to-End Example".center(78) + "║")
|
||||
print("║" + " " * 78 + "║")
|
||||
print("╚" + "=" * 78 + "╝")
|
||||
print("\n")
|
||||
|
||||
# Configuration check
|
||||
print("Configuration:")
|
||||
print(f" Graph Provider: {os.getenv('GRAPH_DATABASE_PROVIDER')}")
|
||||
print(f" Graph Handler: {os.getenv('GRAPH_DATASET_HANDLER')}")
|
||||
print(f" LLM Provider: {os.getenv('LLM_PROVIDER')}")
|
||||
|
||||
# Verify LLM key is set
|
||||
if not os.getenv("LLM_API_KEY") or os.getenv("LLM_API_KEY") == "sk-your-key-here":
|
||||
print("\n⚠ WARNING: LLM_API_KEY not set in .env file")
|
||||
print(" Set your API key to run searches")
|
||||
return
|
||||
|
||||
print("\n")
|
||||
|
||||
try:
|
||||
# Step 1: Setup
|
||||
await setup_knowledge_base()
|
||||
|
||||
# Step 2: Simulate searches
|
||||
# Note: Repeat queries increase frequency for those topics
|
||||
queries = [
|
||||
"What is machine learning?",
|
||||
"Explain neural networks",
|
||||
"How does deep learning work?",
|
||||
"Tell me about neural networks", # Repeat - increases frequency
|
||||
"What are transformers in NLP?",
|
||||
"Explain neural networks again", # Another repeat
|
||||
"How does computer vision work?",
|
||||
"What is reinforcement learning?",
|
||||
"Tell me more about neural networks", # Third repeat
|
||||
]
|
||||
|
||||
successful_searches = await simulate_user_searches(queries)
|
||||
|
||||
if successful_searches == 0:
|
||||
print("⚠ No searches completed - cannot demonstrate frequency tracking")
|
||||
return
|
||||
|
||||
# Step 3: Extract frequencies
|
||||
stats = await extract_and_apply_frequencies(time_window_days=7, min_threshold=1)
|
||||
|
||||
# Step 4: Analyze results
|
||||
await analyze_results(stats)
|
||||
|
||||
# Step 5: Show usage examples
|
||||
await demonstrate_retrieval_usage()
|
||||
|
||||
# Summary
|
||||
print("\n")
|
||||
print("╔" + "=" * 78 + "╗")
|
||||
print("║" + " " * 78 + "║")
|
||||
print("║" + " Example Complete!".center(78) + "║")
|
||||
print("║" + " " * 78 + "║")
|
||||
print("╚" + "=" * 78 + "╝")
|
||||
print("\n")
|
||||
|
||||
print("Summary:")
|
||||
print(" ✓ Documents added: 4")
|
||||
print(f" ✓ Searches performed: {successful_searches}")
|
||||
print(f" ✓ Interactions tracked: {stats['interactions_in_window']}")
|
||||
print(f" ✓ Nodes weighted: {len(stats['node_frequencies'])}")
|
||||
|
||||
print("\nNext steps:")
|
||||
print(" 1. Open Neo4j Browser (http://localhost:7474) to explore the graph")
|
||||
print(" 2. Modify retrieval strategies to use frequency_weight")
|
||||
print(" 3. Build analytics dashboards using element_type_frequencies")
|
||||
print(" 4. Run periodic frequency updates to track trends over time")
|
||||
|
||||
print("\n")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ Example failed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
@ -1,6 +1,8 @@
|
|||
import os
|
||||
import asyncio
|
||||
import pathlib
|
||||
from pprint import pprint
|
||||
|
||||
from cognee.shared.logging_utils import setup_logging, ERROR
|
||||
|
||||
import cognee
|
||||
|
|
@ -42,7 +44,7 @@ async def main():
|
|||
|
||||
# Display search results
|
||||
for result_text in search_results:
|
||||
print(result_text)
|
||||
pprint(result_text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import asyncio
|
||||
import os
|
||||
from pprint import pprint
|
||||
|
||||
import cognee
|
||||
from cognee.api.v1.search import SearchType
|
||||
|
|
@ -77,7 +78,7 @@ async def main():
|
|||
query_type=SearchType.GRAPH_COMPLETION,
|
||||
query_text="What are the exact cars and their types produced by Audi?",
|
||||
)
|
||||
print(search_results)
|
||||
pprint(search_results)
|
||||
|
||||
await visualize_graph()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
import cognee
|
||||
import pathlib
|
||||
from pprint import pprint
|
||||
|
||||
from cognee.modules.users.exceptions import PermissionDeniedError
|
||||
from cognee.modules.users.tenants.methods import select_tenant
|
||||
|
|
@ -86,7 +87,7 @@ async def main():
|
|||
)
|
||||
print("\nSearch results as user_1 on dataset owned by user_1:")
|
||||
for result in search_results:
|
||||
print(f"{result}\n")
|
||||
pprint(result)
|
||||
|
||||
# But user_1 cant read the dataset owned by user_2 (QUANTUM dataset)
|
||||
print("\nSearch result as user_1 on the dataset owned by user_2:")
|
||||
|
|
@ -134,7 +135,7 @@ async def main():
|
|||
dataset_ids=[quantum_dataset_id],
|
||||
)
|
||||
for result in search_results:
|
||||
print(f"{result}\n")
|
||||
pprint(result)
|
||||
|
||||
# If we'd like for user_1 to add new documents to the QUANTUM dataset owned by user_2, user_1 would have to get
|
||||
# "write" access permission, which user_1 currently does not have
|
||||
|
|
@ -217,7 +218,7 @@ async def main():
|
|||
dataset_ids=[quantum_cognee_lab_dataset_id],
|
||||
)
|
||||
for result in search_results:
|
||||
print(f"{result}\n")
|
||||
pprint(result)
|
||||
|
||||
# Note: All of these function calls and permission system is available through our backend endpoints as well
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
import asyncio
|
||||
from pprint import pprint
|
||||
|
||||
import cognee
|
||||
from cognee.modules.engine.operations.setup import setup
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
|
|
@ -71,7 +73,7 @@ async def main():
|
|||
print("Search results:")
|
||||
# Display results
|
||||
for result_text in search_results:
|
||||
print(result_text)
|
||||
pprint(result_text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
import asyncio
|
||||
from pprint import pprint
|
||||
|
||||
import cognee
|
||||
from cognee.shared.logging_utils import setup_logging, ERROR
|
||||
from cognee.api.v1.search import SearchType
|
||||
|
|
@ -54,7 +56,7 @@ async def main():
|
|||
print("Search results:")
|
||||
# Display results
|
||||
for result_text in search_results:
|
||||
print(result_text)
|
||||
pprint(result_text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
from pprint import pprint
|
||||
import cognee
|
||||
from cognee.shared.logging_utils import setup_logging, INFO
|
||||
from cognee.api.v1.search import SearchType
|
||||
|
|
@ -87,7 +88,8 @@ async def main():
|
|||
top_k=15,
|
||||
)
|
||||
print(f"Query: {query_text}")
|
||||
print(f"Results: {search_results}\n")
|
||||
print("Results:")
|
||||
pprint(search_results)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
from pprint import pprint
|
||||
|
||||
import cognee
|
||||
from cognee.memify_pipelines.create_triplet_embeddings import create_triplet_embeddings
|
||||
|
|
@ -65,7 +66,7 @@ async def main():
|
|||
query_type=SearchType.TRIPLET_COMPLETION,
|
||||
query_text="What are the models produced by Volkswagen based on the context?",
|
||||
)
|
||||
print(search_results)
|
||||
pprint(search_results)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
[project]
|
||||
name = "cognee"
|
||||
|
||||
version = "0.5.1.dev0"
|
||||
version = "0.5.1"
|
||||
description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning."
|
||||
authors = [
|
||||
{ name = "Vasilije Markovic" },
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue