Remove payload from vector search [COG-3708] (#1998)

<!-- .github/pull_request_template.md -->

## Description
Make payload retrieval optional when searching vector databases to
improve performance. Payloads are now excluded from search results by
default and returned only when `include_payload=True` is passed.
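
A minimal usage sketch for reviewers (assumes a configured cognee
deployment; the collection name and query string are illustrative):

```python
from cognee.infrastructure.databases.vector import get_vector_engine

async def demo():
    vector_engine = get_vector_engine()

    # Default: payloads are skipped, so each ScoredResult has payload=None.
    fast_results = await vector_engine.search(
        "DocumentChunk_text", "quantum computing", limit=5
    )

    # Opt in when the caller needs payload fields such as the chunk text.
    full_results = await vector_engine.search(
        "DocumentChunk_text", "quantum computing", limit=5, include_payload=True
    )
    return fast_results, full_results
```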

## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [ ] **I have tested my changes thoroughly before submitting this PR**
- [ ] **This PR contains minimal changes necessary to address the
issue/feature**
- [ ] My code follows the project's coding standards and style
guidelines
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been
submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

### Release Notes

* **New Features**
  * Added an optional `include_payload` parameter to search operations
    across all vector database implementations. When enabled, search
    results include complete payload data; it is disabled by default for
    better performance.

* **Improvements**
  * Made payloads optional in search results, trimming what each query
    returns.
  * Standardized edge identifier generation for consistent graph
    operations.
  * Optimized data lookups to use direct IDs instead of payload
    extraction.


<!-- end of auto-generated comment: release notes by coderabbit.ai -->
Igor Ilic 2026-01-20 10:25:08 +01:00 committed by GitHub
commit 6e69daa527
29 changed files with 161 additions and 63 deletions

View file

@@ -236,6 +236,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
query_vector: Optional[List[float]] = None,
limit: Optional[int] = None,
with_vector: bool = False,
include_payload: bool = False, # TODO: Add support for this parameter
):
"""
Perform a search in the specified collection using either a text query or a vector
@@ -319,7 +320,12 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
self._na_exception_handler(e, query_string)
async def batch_search(
self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False
self,
collection_name: str,
query_texts: List[str],
limit: int,
with_vectors: bool = False,
include_payload: bool = False,
):
"""
Perform a batch search using multiple text queries against a collection.
@@ -342,7 +348,14 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
data_vectors = await self.embedding_engine.embed_text(query_texts)
return await asyncio.gather(
*[
self.search(collection_name, None, vector, limit, with_vectors)
self.search(
collection_name,
None,
vector,
limit,
with_vectors,
include_payload=include_payload,
)
for vector in data_vectors
]
)
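
The batch path embeds all query texts once, then fans each vector out
to `search` via `asyncio.gather`, forwarding `include_payload`. A
self-contained sketch of the same pattern (the `embed` and `search`
stubs below are illustrative stand-ins, not the adapter's real
dependencies):

```python
import asyncio

async def embed(texts):
    # Stand-in for embedding_engine.embed_text: one fake vector per text.
    return [[float(len(t))] for t in texts]

async def search(collection, vector, limit, include_payload=False):
    # Stand-in for the per-vector search call.
    return {"collection": collection, "vector": vector, "include_payload": include_payload}

async def batch_search(collection, query_texts, limit, include_payload=False):
    vectors = await embed(query_texts)
    # One search task per query vector, all awaited concurrently.
    return await asyncio.gather(
        *[search(collection, v, limit, include_payload=include_payload) for v in vectors]
    )

print(asyncio.run(batch_search("Entity_name", ["a", "bb"], limit=3)))
```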

View file

@@ -355,6 +355,7 @@ class ChromaDBAdapter(VectorDBInterface):
limit: Optional[int] = 15,
with_vector: bool = False,
normalized: bool = True,
include_payload: bool = False, # TODO: Add support for this parameter when set to False
):
"""
Search for items in a collection using either a text or a vector query.
@@ -441,6 +442,7 @@ class ChromaDBAdapter(VectorDBInterface):
query_texts: List[str],
limit: int = 5,
with_vectors: bool = False,
include_payload: bool = False,
):
"""
Perform multiple searches in a single request for efficiency, returning results for each

View file

@@ -231,6 +231,7 @@ class LanceDBAdapter(VectorDBInterface):
limit: Optional[int] = 15,
with_vector: bool = False,
normalized: bool = True,
include_payload: bool = False,
):
if query_text is None and query_vector is None:
raise MissingQueryParameterError()
@@ -247,17 +248,27 @@ class LanceDBAdapter(VectorDBInterface):
if limit <= 0:
return []
result_values = await collection.vector_search(query_vector).limit(limit).to_list()
# Note: Exclude payload if not needed to optimize performance
select_columns = (
["id", "vector", "payload", "_distance"]
if include_payload
else ["id", "vector", "_distance"]
)
result_values = (
await collection.vector_search(query_vector)
.select(select_columns)
.limit(limit)
.to_list()
)
if not result_values:
return []
normalized_values = normalize_distances(result_values)
return [
ScoredResult(
id=parse_id(result["id"]),
payload=result["payload"],
payload=result["payload"] if include_payload else None,
score=normalized_values[value_index],
)
for value_index, result in enumerate(result_values)
@@ -269,6 +280,7 @@ class LanceDBAdapter(VectorDBInterface):
query_texts: List[str],
limit: Optional[int] = None,
with_vectors: bool = False,
include_payload: bool = False,
):
query_vectors = await self.embedding_engine.embed_text(query_texts)
@@ -279,6 +291,7 @@ class LanceDBAdapter(VectorDBInterface):
query_vector=query_vector,
limit=limit,
with_vector=with_vectors,
include_payload=include_payload,
)
for query_vector in query_vectors
]
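
The key change here is projecting only the needed columns before
materializing results, so the wide `payload` column is never read
unless requested. The projection logic in isolation (column names as in
the diff above):

```python
def select_columns(include_payload: bool) -> list:
    # "payload" is the expensive column; fetch it only on request.
    if include_payload:
        return ["id", "vector", "payload", "_distance"]
    return ["id", "vector", "_distance"]

assert select_columns(False) == ["id", "vector", "_distance"]
assert select_columns(True) == ["id", "vector", "payload", "_distance"]
```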

View file

@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, Optional
from uuid import UUID
from pydantic import BaseModel
@@ -12,10 +12,10 @@ class ScoredResult(BaseModel):
- id (UUID): Unique identifier for the scored result.
- score (float): The score associated with the result, where a lower score indicates a
better outcome.
- payload (Dict[str, Any]): Additional information related to the score, stored as
- payload (Optional[Dict[str, Any]]): Additional information related to the score, stored as
key-value pairs in a dictionary.
"""
id: UUID
score: float # Lower score is better
payload: Dict[str, Any]
payload: Optional[Dict[str, Any]] = None
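
With `payload` now `Optional` and defaulting to `None`, callers that
skip the payload still get valid results. A quick check of the loosened
model (requires only pydantic):

```python
from typing import Any, Dict, Optional
from uuid import UUID, uuid4

from pydantic import BaseModel

class ScoredResult(BaseModel):
    id: UUID
    score: float  # Lower score is better
    payload: Optional[Dict[str, Any]] = None

lean = ScoredResult(id=uuid4(), score=0.12)  # payload omitted -> None
full = ScoredResult(id=uuid4(), score=0.12, payload={"text": "chunk"})
assert lean.payload is None and full.payload["text"] == "chunk"
```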

View file

@@ -301,6 +301,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
query_vector: Optional[List[float]] = None,
limit: Optional[int] = 15,
with_vector: bool = False,
include_payload: bool = False,
) -> List[ScoredResult]:
if query_text is None and query_vector is None:
raise MissingQueryParameterError()
@@ -324,10 +325,16 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
# NOTE: This needs to be initialized in case search doesn't return a value
closest_items = []
# Note: Exclude payload from returned columns if not needed to optimize performance
select_columns = (
[PGVectorDataPoint]
if include_payload
else [PGVectorDataPoint.c.id, PGVectorDataPoint.c.vector]
)
# Use async session to connect to the database
async with self.get_async_session() as session:
query = select(
PGVectorDataPoint,
*select_columns,
PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
).order_by("similarity")
@@ -344,7 +351,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
vector_list.append(
{
"id": parse_id(str(vector.id)),
"payload": vector.payload,
"payload": vector.payload if include_payload else None,
"_distance": vector.similarity,
}
)
@@ -359,7 +366,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
# Create and return ScoredResult objects
return [
ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score"))
ScoredResult(
id=row.get("id"),
payload=row.get("payload") if include_payload else None,
score=row.get("score"),
)
for row in vector_list
]
@@ -369,6 +380,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
query_texts: List[str],
limit: int = None,
with_vectors: bool = False,
include_payload: bool = False,
):
query_vectors = await self.embedding_engine.embed_text(query_texts)
@@ -379,6 +391,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
query_vector=query_vector,
limit=limit,
with_vector=with_vectors,
include_payload=include_payload,
)
for query_vector in query_vectors
]
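
On the SQL side the same idea becomes a narrower column list in the
`SELECT`. A hedged sketch of how such a statement can be built with
SQLAlchemy and pgvector (the table and column names here are
illustrative; the real adapter constructs `PGVectorDataPoint`
dynamically per collection):

```python
from pgvector.sqlalchemy import Vector
from sqlalchemy import Column, MetaData, Table, Text, select
from sqlalchemy.dialects.postgresql import JSONB

metadata = MetaData()
data_points = Table(
    "data_points", metadata,
    Column("id", Text, primary_key=True),
    Column("payload", JSONB),
    Column("vector", Vector(3)),
)

def build_query(query_vector, include_payload: bool):
    # Skip the JSONB payload column entirely when it is not requested.
    cols = [data_points] if include_payload else [data_points.c.id, data_points.c.vector]
    return select(
        *cols,
        data_points.c.vector.cosine_distance(query_vector).label("similarity"),
    ).order_by("similarity")

print(build_query([0.1, 0.2, 0.3], include_payload=False))
```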

View file

@@ -87,6 +87,7 @@ class VectorDBInterface(Protocol):
query_vector: Optional[List[float]],
limit: Optional[int],
with_vector: bool = False,
include_payload: bool = False,
):
"""
Perform a search in the specified collection using either a text query or a vector
@@ -103,6 +104,9 @@ class VectorDBInterface(Protocol):
- limit (Optional[int]): The maximum number of results to return from the search.
- with_vector (bool): Whether to return the vector representations with search
results. (default False)
- include_payload (bool): Whether to include payload data in the search results. Search is faster when set to False.
The payload carries metadata about the data point; it is useful for searches based purely on embedding distances
(such as the RAG_COMPLETION search type), but is not needed when the search also pulls in graph data.
"""
raise NotImplementedError
@@ -113,6 +117,7 @@ class VectorDBInterface(Protocol):
query_texts: List[str],
limit: Optional[int],
with_vectors: bool = False,
include_payload: bool = False,
):
"""
Perform a batch search using multiple text queries against a collection.
@@ -125,6 +130,9 @@ class VectorDBInterface(Protocol):
- limit (Optional[int]): The maximum number of results to return for each query.
- with_vectors (bool): Whether to include vector representations with search
results. (default False)
- include_payload (bool): Whether to include payload data in the search results. Search is faster when set to False.
The payload carries metadata about the data point; it is useful for searches based purely on embedding distances
(such as the RAG_COMPLETION search type), but is not needed when the search also pulls in graph data.
"""
raise NotImplementedError
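
Because `VectorDBInterface` is a `Protocol`, adapters conform
structurally rather than by inheritance; each one simply has to accept
the new keyword. A toy conformance sketch (the in-memory adapter below
is hypothetical, used only for illustration):

```python
from typing import List, Optional, Protocol

class VectorDBInterface(Protocol):
    async def search(
        self,
        collection_name: str,
        query_text: Optional[str],
        query_vector: Optional[List[float]],
        limit: Optional[int],
        with_vector: bool = False,
        include_payload: bool = False,
    ): ...

class InMemoryAdapter:  # hypothetical adapter for this sketch only
    async def search(self, collection_name, query_text, query_vector, limit,
                     with_vector=False, include_payload=False):
        return []

adapter: VectorDBInterface = InMemoryAdapter()  # type-checks structurally
```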

View file

@@ -1,5 +1,6 @@
import time
from cognee.shared.logging_utils import get_logger
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
from typing import List, Dict, Union, Optional, Type, Iterable, Tuple, Callable, Any
from cognee.modules.graph.exceptions import (
@@ -44,6 +45,12 @@ class CogneeGraph(CogneeAbstractGraph):
def add_edge(self, edge: Edge) -> None:
self.edges.append(edge)
edge_text = edge.attributes.get("edge_text") or edge.attributes.get("relationship_type")
edge.attributes["edge_type_id"] = (
generate_edge_id(edge_id=edge_text) if edge_text else None
) # Update edge with generated edge_type_id
edge.node1.add_skeleton_edge(edge)
edge.node2.add_skeleton_edge(edge)
key = edge.get_distance_key()
@@ -284,13 +291,7 @@ class CogneeGraph(CogneeAbstractGraph):
for query_index, scored_results in enumerate(per_query_scored_results):
for result in scored_results:
payload = getattr(result, "payload", None)
if not isinstance(payload, dict):
continue
text = payload.get("text")
if not text:
continue
matching_edges = self.edges_by_distance_key.get(str(text))
matching_edges = self.edges_by_distance_key.get(str(result.id))
if not matching_edges:
continue
for edge in matching_edges:
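
Edges now carry a precomputed `edge_type_id`, and vector results are
matched by `result.id` instead of by text pulled from the payload. This
only works if `generate_edge_id` is deterministic for a given edge
text; a sketch under that assumption (e.g. a `uuid5` derivation — the
real helper's implementation is not shown in this diff):

```python
from uuid import NAMESPACE_OID, uuid5

def generate_edge_id(edge_id: str):
    # Assumption: deterministic UUID from the edge text (real helper not shown).
    return uuid5(NAMESPACE_OID, edge_id)

edges_by_distance_key = {str(generate_edge_id("CONNECTS_TO")): ["edge-object"]}

class Result:  # stand-in for a ScoredResult fetched without payload
    def __init__(self, id, score):
        self.id, self.score = id, score

result = Result(generate_edge_id("CONNECTS_TO"), 0.92)
# The lookup keys on the result id, so no payload round-trip is needed.
assert edges_by_distance_key[str(result.id)] == ["edge-object"]
```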

View file

@@ -141,7 +141,7 @@ class Edge:
self.status = np.ones(dimension, dtype=int)
def get_distance_key(self) -> Optional[str]:
key = self.attributes.get("edge_text") or self.attributes.get("relationship_type")
key = self.attributes.get("edge_type_id")
if key is None:
return None
return str(key)

View file

@@ -47,7 +47,9 @@ class ChunksRetriever(BaseRetriever):
vector_engine = get_vector_engine()
try:
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
found_chunks = await vector_engine.search(
"DocumentChunk_text", query, limit=self.top_k, include_payload=True
)
logger.info(f"Found {len(found_chunks)} chunks from vector search")
await update_node_access_timestamps(found_chunks)
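
Retrievers that read fields off the payload (chunk text, summaries,
triplets) now opt in explicitly with `include_payload=True`; any caller
that forgets will see `payload=None`. A defensive consumption sketch
(the helper name is illustrative):

```python
def extract_texts(found_chunks):
    texts = []
    for chunk in found_chunks:
        # payload is Optional now; guard before indexing into it.
        if chunk.payload and "text" in chunk.payload:
            texts.append(chunk.payload["text"])
    return texts
```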

View file

@@ -62,7 +62,9 @@ class CompletionRetriever(BaseRetriever):
vector_engine = get_vector_engine()
try:
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
found_chunks = await vector_engine.search(
"DocumentChunk_text", query, limit=self.top_k, include_payload=True
)
if len(found_chunks) == 0:
return ""

View file

@@ -52,7 +52,7 @@ class SummariesRetriever(BaseRetriever):
try:
summaries_results = await vector_engine.search(
"TextSummary_text", query, limit=self.top_k
"TextSummary_text", query, limit=self.top_k, include_payload=True
)
logger.info(f"Found {len(summaries_results)} summaries from vector search")

View file

@@ -98,7 +98,7 @@ class TemporalRetriever(GraphCompletionRetriever):
async def filter_top_k_events(self, relevant_events, scored_results):
# Build a score lookup from vector search results
score_lookup = {res.payload["id"]: res.score for res in scored_results}
score_lookup = {res.id: res.score for res in scored_results}
events_with_scores = []
for event in relevant_events[0]["events"]:
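
The score lookup now keys on `res.id` directly instead of reading
`res.payload["id"]`, which keeps working even when payloads were never
fetched. In isolation:

```python
from types import SimpleNamespace

scored_results = [
    SimpleNamespace(id="e2", payload=None, score=0.10),
    SimpleNamespace(id="e1", payload=None, score=0.20),
]
# Build a score lookup without touching the (possibly absent) payload.
score_lookup = {res.id: res.score for res in scored_results}
assert score_lookup == {"e2": 0.10, "e1": 0.20}
```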

View file

@@ -67,7 +67,9 @@ class TripletRetriever(BaseRetriever):
"In order to use TRIPLET_COMPLETION first use the create_triplet_embeddings memify pipeline. "
)
found_triplets = await vector_engine.search("Triplet_text", query, limit=self.top_k)
found_triplets = await vector_engine.search(
"Triplet_text", query, limit=self.top_k, include_payload=True
)
if len(found_triplets) == 0:
return ""

View file

@@ -97,7 +97,7 @@ async def test_vector_engine_search_none_limit():
query_vector = (await vector_engine.embedding_engine.embed_text([query_text]))[0]
result = await vector_engine.search(
collection_name=collection_name, query_vector=query_vector, limit=None
collection_name=collection_name, query_vector=query_vector, limit=None, include_payload=True
)
# Check that we did not accidentally use any default value for limit

View file

@@ -70,7 +70,9 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node = (
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
)[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(

View file

@@ -149,7 +149,9 @@ async def main():
await test_getting_of_documents(dataset_name_1)
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node = (
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
)[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(

View file

@@ -48,7 +48,7 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
random_node = (await vector_engine.search("Entity_name", "AI", include_payload=True))[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(

View file

@@ -63,7 +63,9 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node = (
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
)[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(

View file

@@ -52,7 +52,9 @@ async def main():
await cognee.cognify([dataset_name])
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node = (
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
)[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(

View file

@@ -163,7 +163,9 @@ async def main():
await test_getting_of_documents(dataset_name_1)
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node = (
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
)[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(

View file

@@ -58,7 +58,9 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node = (
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
)[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(

View file

@@ -43,7 +43,7 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
random_node = (await vector_engine.search("Entity_name", "AI", include_payload=True))[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(

View file

@@ -1,6 +1,7 @@
import pytest
from unittest.mock import AsyncMock
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
@@ -379,7 +380,7 @@ async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph):
graph.add_edge(edge)
edge_distances = [
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
MockScoredResult(generate_edge_id("CONNECTS_TO"), 0.92, payload={"text": "CONNECTS_TO"}),
]
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
@@ -404,8 +405,9 @@ async def test_map_vector_distances_partial_edge_coverage(setup_graph):
graph.add_edge(edge1)
graph.add_edge(edge2)
edge_1_text = "CONNECTS_TO"
edge_distances = [
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
MockScoredResult(generate_edge_id(edge_1_text), 0.92, payload={"text": edge_1_text}),
]
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
@@ -431,8 +433,9 @@ async def test_map_vector_distances_edges_fallback_to_relationship_type(setup_graph):
)
graph.add_edge(edge)
edge_text = "KNOWS"
edge_distances = [
MockScoredResult("e1", 0.85, payload={"text": "KNOWS"}),
MockScoredResult(generate_edge_id(edge_text), 0.85, payload={"text": edge_text}),
]
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
@@ -457,8 +460,9 @@ async def test_map_vector_distances_no_edge_matches(setup_graph):
)
graph.add_edge(edge)
edge_text = "SOME_OTHER_EDGE"
edge_distances = [
MockScoredResult("e1", 0.92, payload={"text": "SOME_OTHER_EDGE"}),
MockScoredResult(generate_edge_id(edge_text), 0.92, payload={"text": edge_text}),
]
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
@@ -511,9 +515,15 @@ async def test_map_vector_distances_to_graph_edges_multi_query(setup_graph):
graph.add_edge(edge1)
graph.add_edge(edge2)
edge_1_text = "A"
edge_2_text = "B"
edge_distances = [
[MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0
[MockScoredResult("e2", 0.2, payload={"text": "B"})], # query 1
[
MockScoredResult(generate_edge_id(edge_1_text), 0.1, payload={"text": edge_1_text})
], # query 0
[
MockScoredResult(generate_edge_id(edge_2_text), 0.2, payload={"text": edge_2_text})
], # query 1
]
await graph.map_vector_distances_to_graph_edges(
@@ -541,8 +551,11 @@ async def test_map_vector_distances_to_graph_edges_preserves_unmapped_indices(setup_graph):
graph.add_edge(edge1)
graph.add_edge(edge2)
edge_1_text = "A"
edge_distances = [
[MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0: only edge1 mapped
[
MockScoredResult(generate_edge_id(edge_1_text), 0.1, payload={"text": edge_1_text})
], # query 0: only edge1 mapped
[], # query 1: no edges mapped
]

View file

@@ -35,7 +35,9 @@ async def test_get_context_success(mock_vector_engine):
assert len(context) == 2
assert context[0]["text"] == "Steve Rodger"
assert context[1]["text"] == "Mike Broski"
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=5)
mock_vector_engine.search.assert_awaited_once_with(
"DocumentChunk_text", "test query", limit=5, include_payload=True
)
@pytest.mark.asyncio
@@ -87,7 +89,9 @@ async def test_get_context_top_k_limit(mock_vector_engine):
context = await retriever.get_context("test query")
assert len(context) == 3
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=3)
mock_vector_engine.search.assert_awaited_once_with(
"DocumentChunk_text", "test query", limit=3, include_payload=True
)
@pytest.mark.asyncio

View file

@@ -33,7 +33,9 @@ async def test_get_context_success(mock_vector_engine):
context = await retriever.get_context("test query")
assert context == "Steve Rodger\nMike Broski"
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2)
mock_vector_engine.search.assert_awaited_once_with(
"DocumentChunk_text", "test query", limit=2, include_payload=True
)
@pytest.mark.asyncio
@@ -85,7 +87,9 @@ async def test_get_context_top_k_limit(mock_vector_engine):
context = await retriever.get_context("test query")
assert context == "Chunk 0\nChunk 1"
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2)
mock_vector_engine.search.assert_awaited_once_with(
"DocumentChunk_text", "test query", limit=2, include_payload=True
)
@pytest.mark.asyncio

View file

@@ -35,7 +35,9 @@ async def test_get_context_success(mock_vector_engine):
assert len(context) == 2
assert context[0]["text"] == "S.R."
assert context[1]["text"] == "M.B."
mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=5)
mock_vector_engine.search.assert_awaited_once_with(
"TextSummary_text", "test query", limit=5, include_payload=True
)
@pytest.mark.asyncio
@@ -87,7 +89,9 @@ async def test_get_context_top_k_limit(mock_vector_engine):
context = await retriever.get_context("test query")
assert len(context) == 3
mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=3)
mock_vector_engine.search.assert_awaited_once_with(
"TextSummary_text", "test query", limit=3, include_payload=True
)
@pytest.mark.asyncio
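
These retriever unit tests pin the exact call signature, so every
assertion gains the new keyword. The pattern with
`unittest.mock.AsyncMock` in isolation:

```python
import asyncio
from unittest.mock import AsyncMock

async def run():
    search = AsyncMock(return_value=[])
    await search("DocumentChunk_text", "test query", limit=5, include_payload=True)
    # Fails if the keyword (or any other argument) drifts from the expected call.
    search.assert_awaited_once_with(
        "DocumentChunk_text", "test query", limit=5, include_payload=True
    )

asyncio.run(run())
```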

View file

@@ -63,8 +63,8 @@ async def test_filter_top_k_events_sorts_and_limits():
]
scored_results = [
SimpleNamespace(payload={"id": "e2"}, score=0.10),
SimpleNamespace(payload={"id": "e1"}, score=0.20),
SimpleNamespace(id="e2", payload={"id": "e2"}, score=0.10),
SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.20),
]
top = await tr.filter_top_k_events(relevant_events, scored_results)
@@ -91,8 +91,8 @@ async def test_filter_top_k_events_includes_unknown_as_infinite_but_not_in_top_k():
]
scored_results = [
SimpleNamespace(payload={"id": "known2"}, score=0.05),
SimpleNamespace(payload={"id": "known1"}, score=0.50),
SimpleNamespace(id="known2", payload={"id": "known2"}, score=0.05),
SimpleNamespace(id="known1", payload={"id": "known1"}, score=0.50),
]
top = await tr.filter_top_k_events(relevant_events, scored_results)
@@ -119,8 +119,8 @@ async def test_filter_top_k_events_limits_when_top_k_exceeds_events():
tr = TemporalRetriever(top_k=10)
relevant_events = [{"events": [{"id": "a"}, {"id": "b"}]}]
scored_results = [
SimpleNamespace(payload={"id": "a"}, score=0.1),
SimpleNamespace(payload={"id": "b"}, score=0.2),
SimpleNamespace(id="a", payload={"id": "a"}, score=0.1),
SimpleNamespace(id="b", payload={"id": "b"}, score=0.2),
]
out = await tr.filter_top_k_events(relevant_events, scored_results)
assert [e["id"] for e in out] == ["a", "b"]
@@ -179,8 +179,8 @@ async def test_get_context_with_time_range(mock_graph_engine, mock_vector_engine):
}
]
mock_result1 = SimpleNamespace(payload={"id": "e2"}, score=0.05)
mock_result2 = SimpleNamespace(payload={"id": "e1"}, score=0.10)
mock_result1 = SimpleNamespace(id="e2", payload={"id": "e2"}, score=0.05)
mock_result2 = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.10)
mock_vector_engine.search.return_value = [mock_result1, mock_result2]
with (
@@ -279,7 +279,7 @@ async def test_get_context_time_from_only(mock_graph_engine, mock_vector_engine):
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
with (
@@ -313,7 +313,7 @@ async def test_get_context_time_to_only(mock_graph_engine, mock_vector_engine):
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
with (
@@ -347,7 +347,7 @@ async def test_get_completion_without_context(mock_graph_engine, mock_vector_engine):
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
with (
@@ -416,7 +416,7 @@ async def test_get_completion_with_session(mock_graph_engine, mock_vector_engine):
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
mock_user = MagicMock()
@@ -481,7 +481,7 @@ async def test_get_completion_with_session_no_user_id(mock_graph_engine, mock_vector_engine):
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
with (
@@ -570,7 +570,7 @@ async def test_get_completion_with_response_model(mock_graph_engine, mock_vector_engine):
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
with (

View file

@@ -6,6 +6,7 @@ from cognee.modules.retrieval.utils.brute_force_triplet_search import (
get_memory_fragment,
format_triplets,
)
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
@@ -1036,9 +1037,11 @@ async def test_cognee_graph_mapping_batch_shapes():
]
}
edge_1_text = "relates_to"
edge_2_text = "relates_to"
edge_distances_batch = [
[MockScoredResult("edge1", 0.92, payload={"text": "relates_to"})],
[MockScoredResult("edge2", 0.88, payload={"text": "relates_to"})],
[MockScoredResult(generate_edge_id(edge_1_text), 0.92, payload={"text": edge_1_text})],
[MockScoredResult(generate_edge_id(edge_2_text), 0.88, payload={"text": edge_2_text})],
]
await graph.map_vector_distances_to_graph_nodes(

View file

@@ -34,7 +34,9 @@ async def test_get_context_success(mock_vector_engine):
context = await retriever.get_context("test query")
assert context == "Alice knows Bob\nBob works at Tech Corp"
mock_vector_engine.search.assert_awaited_once_with("Triplet_text", "test query", limit=5)
mock_vector_engine.search.assert_awaited_once_with(
"Triplet_text", "test query", limit=5, include_payload=True
)
@pytest.mark.asyncio