Remove payload from vector search by default [COG-3708] (#1998)
<!-- .github/pull_request_template.md -->
## Description
Make payload information optional when searching vector databases in order
to optimize search performance.
## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):
## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->
## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [ ] **I have tested my changes thoroughly before submitting this PR**
- [ ] **This PR contains minimal changes necessary to address the
issue/feature**
- [ ] My code follows the project's coding standards and style
guidelines
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been
submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages
## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
# Release Notes
* **New Features**
* Added optional `include_payload` parameter to search operations across
all vector database implementations. When enabled, search results
include complete payload data; disabled by default for improved
performance.
* **Improvements**
* Improved search efficiency by making payload data optional in search
results.
* Standardized edge identifier generation for consistent graph
operations.
* Optimized data lookups to use direct IDs instead of payload
extraction.
<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
commit
6e69daa527
29 changed files with 161 additions and 63 deletions
|
|
@ -236,6 +236,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
|
||||||
query_vector: Optional[List[float]] = None,
|
query_vector: Optional[List[float]] = None,
|
||||||
limit: Optional[int] = None,
|
limit: Optional[int] = None,
|
||||||
with_vector: bool = False,
|
with_vector: bool = False,
|
||||||
|
include_payload: bool = False, # TODO: Add support for this parameter
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Perform a search in the specified collection using either a text query or a vector
|
Perform a search in the specified collection using either a text query or a vector
|
||||||
|
|
@ -319,7 +320,12 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
|
||||||
self._na_exception_handler(e, query_string)
|
self._na_exception_handler(e, query_string)
|
||||||
|
|
||||||
async def batch_search(
|
async def batch_search(
|
||||||
self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False
|
self,
|
||||||
|
collection_name: str,
|
||||||
|
query_texts: List[str],
|
||||||
|
limit: int,
|
||||||
|
with_vectors: bool = False,
|
||||||
|
include_payload: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Perform a batch search using multiple text queries against a collection.
|
Perform a batch search using multiple text queries against a collection.
|
||||||
|
|
@ -342,7 +348,14 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
|
||||||
data_vectors = await self.embedding_engine.embed_text(query_texts)
|
data_vectors = await self.embedding_engine.embed_text(query_texts)
|
||||||
return await asyncio.gather(
|
return await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
self.search(collection_name, None, vector, limit, with_vectors)
|
self.search(
|
||||||
|
collection_name,
|
||||||
|
None,
|
||||||
|
vector,
|
||||||
|
limit,
|
||||||
|
with_vectors,
|
||||||
|
include_payload=include_payload,
|
||||||
|
)
|
||||||
for vector in data_vectors
|
for vector in data_vectors
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -355,6 +355,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
limit: Optional[int] = 15,
|
limit: Optional[int] = 15,
|
||||||
with_vector: bool = False,
|
with_vector: bool = False,
|
||||||
normalized: bool = True,
|
normalized: bool = True,
|
||||||
|
include_payload: bool = False, # TODO: Add support for this parameter when set to False
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Search for items in a collection using either a text or a vector query.
|
Search for items in a collection using either a text or a vector query.
|
||||||
|
|
@ -441,6 +442,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
query_texts: List[str],
|
query_texts: List[str],
|
||||||
limit: int = 5,
|
limit: int = 5,
|
||||||
with_vectors: bool = False,
|
with_vectors: bool = False,
|
||||||
|
include_payload: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Perform multiple searches in a single request for efficiency, returning results for each
|
Perform multiple searches in a single request for efficiency, returning results for each
|
||||||
|
|
|
||||||
|
|
@ -231,6 +231,7 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
limit: Optional[int] = 15,
|
limit: Optional[int] = 15,
|
||||||
with_vector: bool = False,
|
with_vector: bool = False,
|
||||||
normalized: bool = True,
|
normalized: bool = True,
|
||||||
|
include_payload: bool = False,
|
||||||
):
|
):
|
||||||
if query_text is None and query_vector is None:
|
if query_text is None and query_vector is None:
|
||||||
raise MissingQueryParameterError()
|
raise MissingQueryParameterError()
|
||||||
|
|
@ -247,17 +248,27 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
if limit <= 0:
|
if limit <= 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
result_values = await collection.vector_search(query_vector).limit(limit).to_list()
|
# Note: Exclude payload if not needed to optimize performance
|
||||||
|
select_columns = (
|
||||||
|
["id", "vector", "payload", "_distance"]
|
||||||
|
if include_payload
|
||||||
|
else ["id", "vector", "_distance"]
|
||||||
|
)
|
||||||
|
result_values = (
|
||||||
|
await collection.vector_search(query_vector)
|
||||||
|
.select(select_columns)
|
||||||
|
.limit(limit)
|
||||||
|
.to_list()
|
||||||
|
)
|
||||||
|
|
||||||
if not result_values:
|
if not result_values:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
normalized_values = normalize_distances(result_values)
|
normalized_values = normalize_distances(result_values)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
ScoredResult(
|
ScoredResult(
|
||||||
id=parse_id(result["id"]),
|
id=parse_id(result["id"]),
|
||||||
payload=result["payload"],
|
payload=result["payload"] if include_payload else None,
|
||||||
score=normalized_values[value_index],
|
score=normalized_values[value_index],
|
||||||
)
|
)
|
||||||
for value_index, result in enumerate(result_values)
|
for value_index, result in enumerate(result_values)
|
||||||
|
|
@ -269,6 +280,7 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
query_texts: List[str],
|
query_texts: List[str],
|
||||||
limit: Optional[int] = None,
|
limit: Optional[int] = None,
|
||||||
with_vectors: bool = False,
|
with_vectors: bool = False,
|
||||||
|
include_payload: bool = False,
|
||||||
):
|
):
|
||||||
query_vectors = await self.embedding_engine.embed_text(query_texts)
|
query_vectors = await self.embedding_engine.embed_text(query_texts)
|
||||||
|
|
||||||
|
|
@ -279,6 +291,7 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
query_vector=query_vector,
|
query_vector=query_vector,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
with_vector=with_vectors,
|
with_vector=with_vectors,
|
||||||
|
include_payload=include_payload,
|
||||||
)
|
)
|
||||||
for query_vector in query_vectors
|
for query_vector in query_vectors
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict, Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
@ -12,10 +12,10 @@ class ScoredResult(BaseModel):
|
||||||
- id (UUID): Unique identifier for the scored result.
|
- id (UUID): Unique identifier for the scored result.
|
||||||
- score (float): The score associated with the result, where a lower score indicates a
|
- score (float): The score associated with the result, where a lower score indicates a
|
||||||
better outcome.
|
better outcome.
|
||||||
- payload (Dict[str, Any]): Additional information related to the score, stored as
|
- payload (Optional[Dict[str, Any]]): Additional information related to the score, stored as
|
||||||
key-value pairs in a dictionary.
|
key-value pairs in a dictionary.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
id: UUID
|
id: UUID
|
||||||
score: float # Lower score is better
|
score: float # Lower score is better
|
||||||
payload: Dict[str, Any]
|
payload: Optional[Dict[str, Any]] = None
|
||||||
|
|
|
||||||
|
|
@ -301,6 +301,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
query_vector: Optional[List[float]] = None,
|
query_vector: Optional[List[float]] = None,
|
||||||
limit: Optional[int] = 15,
|
limit: Optional[int] = 15,
|
||||||
with_vector: bool = False,
|
with_vector: bool = False,
|
||||||
|
include_payload: bool = False,
|
||||||
) -> List[ScoredResult]:
|
) -> List[ScoredResult]:
|
||||||
if query_text is None and query_vector is None:
|
if query_text is None and query_vector is None:
|
||||||
raise MissingQueryParameterError()
|
raise MissingQueryParameterError()
|
||||||
|
|
@ -324,10 +325,16 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
# NOTE: This needs to be initialized in case search doesn't return a value
|
# NOTE: This needs to be initialized in case search doesn't return a value
|
||||||
closest_items = []
|
closest_items = []
|
||||||
|
|
||||||
|
# Note: Exclude payload from returned columns if not needed to optimize performance
|
||||||
|
select_columns = (
|
||||||
|
[PGVectorDataPoint]
|
||||||
|
if include_payload
|
||||||
|
else [PGVectorDataPoint.c.id, PGVectorDataPoint.c.vector]
|
||||||
|
)
|
||||||
# Use async session to connect to the database
|
# Use async session to connect to the database
|
||||||
async with self.get_async_session() as session:
|
async with self.get_async_session() as session:
|
||||||
query = select(
|
query = select(
|
||||||
PGVectorDataPoint,
|
*select_columns,
|
||||||
PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
|
PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
|
||||||
).order_by("similarity")
|
).order_by("similarity")
|
||||||
|
|
||||||
|
|
@ -344,7 +351,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
vector_list.append(
|
vector_list.append(
|
||||||
{
|
{
|
||||||
"id": parse_id(str(vector.id)),
|
"id": parse_id(str(vector.id)),
|
||||||
"payload": vector.payload,
|
"payload": vector.payload if include_payload else None,
|
||||||
"_distance": vector.similarity,
|
"_distance": vector.similarity,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
@ -359,7 +366,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
|
|
||||||
# Create and return ScoredResult objects
|
# Create and return ScoredResult objects
|
||||||
return [
|
return [
|
||||||
ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score"))
|
ScoredResult(
|
||||||
|
id=row.get("id"),
|
||||||
|
payload=row.get("payload") if include_payload else None,
|
||||||
|
score=row.get("score"),
|
||||||
|
)
|
||||||
for row in vector_list
|
for row in vector_list
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -369,6 +380,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
query_texts: List[str],
|
query_texts: List[str],
|
||||||
limit: int = None,
|
limit: int = None,
|
||||||
with_vectors: bool = False,
|
with_vectors: bool = False,
|
||||||
|
include_payload: bool = False,
|
||||||
):
|
):
|
||||||
query_vectors = await self.embedding_engine.embed_text(query_texts)
|
query_vectors = await self.embedding_engine.embed_text(query_texts)
|
||||||
|
|
||||||
|
|
@ -379,6 +391,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
query_vector=query_vector,
|
query_vector=query_vector,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
with_vector=with_vectors,
|
with_vector=with_vectors,
|
||||||
|
include_payload=include_payload,
|
||||||
)
|
)
|
||||||
for query_vector in query_vectors
|
for query_vector in query_vectors
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -87,6 +87,7 @@ class VectorDBInterface(Protocol):
|
||||||
query_vector: Optional[List[float]],
|
query_vector: Optional[List[float]],
|
||||||
limit: Optional[int],
|
limit: Optional[int],
|
||||||
with_vector: bool = False,
|
with_vector: bool = False,
|
||||||
|
include_payload: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Perform a search in the specified collection using either a text query or a vector
|
Perform a search in the specified collection using either a text query or a vector
|
||||||
|
|
@ -103,6 +104,9 @@ class VectorDBInterface(Protocol):
|
||||||
- limit (Optional[int]): The maximum number of results to return from the search.
|
- limit (Optional[int]): The maximum number of results to return from the search.
|
||||||
- with_vector (bool): Whether to return the vector representations with search
|
- with_vector (bool): Whether to return the vector representations with search
|
||||||
results. (default False)
|
results. (default False)
|
||||||
|
- include_payload (bool): Whether to include the payload data with search. Search is faster when set to False.
|
||||||
|
Payload contains metadata about the data point, useful for searches that are only based on embedding distances
|
||||||
|
like the RAG_COMPLETION search type, but not needed when search also contains graph data.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
@ -113,6 +117,7 @@ class VectorDBInterface(Protocol):
|
||||||
query_texts: List[str],
|
query_texts: List[str],
|
||||||
limit: Optional[int],
|
limit: Optional[int],
|
||||||
with_vectors: bool = False,
|
with_vectors: bool = False,
|
||||||
|
include_payload: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Perform a batch search using multiple text queries against a collection.
|
Perform a batch search using multiple text queries against a collection.
|
||||||
|
|
@ -125,6 +130,9 @@ class VectorDBInterface(Protocol):
|
||||||
- limit (Optional[int]): The maximum number of results to return for each query.
|
- limit (Optional[int]): The maximum number of results to return for each query.
|
||||||
- with_vectors (bool): Whether to include vector representations with search
|
- with_vectors (bool): Whether to include vector representations with search
|
||||||
results. (default False)
|
results. (default False)
|
||||||
|
- include_payload (bool): Whether to include the payload data with search. Search is faster when set to False.
|
||||||
|
Payload contains metadata about the data point, useful for searches that are only based on embedding distances
|
||||||
|
like the RAG_COMPLETION search type, but not needed when search also contains graph data.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import time
|
import time
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
|
||||||
from typing import List, Dict, Union, Optional, Type, Iterable, Tuple, Callable, Any
|
from typing import List, Dict, Union, Optional, Type, Iterable, Tuple, Callable, Any
|
||||||
|
|
||||||
from cognee.modules.graph.exceptions import (
|
from cognee.modules.graph.exceptions import (
|
||||||
|
|
@ -44,6 +45,12 @@ class CogneeGraph(CogneeAbstractGraph):
|
||||||
|
|
||||||
def add_edge(self, edge: Edge) -> None:
|
def add_edge(self, edge: Edge) -> None:
|
||||||
self.edges.append(edge)
|
self.edges.append(edge)
|
||||||
|
|
||||||
|
edge_text = edge.attributes.get("edge_text") or edge.attributes.get("relationship_type")
|
||||||
|
edge.attributes["edge_type_id"] = (
|
||||||
|
generate_edge_id(edge_id=edge_text) if edge_text else None
|
||||||
|
) # Update edge with generated edge_type_id
|
||||||
|
|
||||||
edge.node1.add_skeleton_edge(edge)
|
edge.node1.add_skeleton_edge(edge)
|
||||||
edge.node2.add_skeleton_edge(edge)
|
edge.node2.add_skeleton_edge(edge)
|
||||||
key = edge.get_distance_key()
|
key = edge.get_distance_key()
|
||||||
|
|
@ -284,13 +291,7 @@ class CogneeGraph(CogneeAbstractGraph):
|
||||||
|
|
||||||
for query_index, scored_results in enumerate(per_query_scored_results):
|
for query_index, scored_results in enumerate(per_query_scored_results):
|
||||||
for result in scored_results:
|
for result in scored_results:
|
||||||
payload = getattr(result, "payload", None)
|
matching_edges = self.edges_by_distance_key.get(str(result.id))
|
||||||
if not isinstance(payload, dict):
|
|
||||||
continue
|
|
||||||
text = payload.get("text")
|
|
||||||
if not text:
|
|
||||||
continue
|
|
||||||
matching_edges = self.edges_by_distance_key.get(str(text))
|
|
||||||
if not matching_edges:
|
if not matching_edges:
|
||||||
continue
|
continue
|
||||||
for edge in matching_edges:
|
for edge in matching_edges:
|
||||||
|
|
|
||||||
|
|
@ -141,7 +141,7 @@ class Edge:
|
||||||
self.status = np.ones(dimension, dtype=int)
|
self.status = np.ones(dimension, dtype=int)
|
||||||
|
|
||||||
def get_distance_key(self) -> Optional[str]:
|
def get_distance_key(self) -> Optional[str]:
|
||||||
key = self.attributes.get("edge_text") or self.attributes.get("relationship_type")
|
key = self.attributes.get("edge_type_id")
|
||||||
if key is None:
|
if key is None:
|
||||||
return None
|
return None
|
||||||
return str(key)
|
return str(key)
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,9 @@ class ChunksRetriever(BaseRetriever):
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
|
found_chunks = await vector_engine.search(
|
||||||
|
"DocumentChunk_text", query, limit=self.top_k, include_payload=True
|
||||||
|
)
|
||||||
logger.info(f"Found {len(found_chunks)} chunks from vector search")
|
logger.info(f"Found {len(found_chunks)} chunks from vector search")
|
||||||
await update_node_access_timestamps(found_chunks)
|
await update_node_access_timestamps(found_chunks)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -62,7 +62,9 @@ class CompletionRetriever(BaseRetriever):
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
|
found_chunks = await vector_engine.search(
|
||||||
|
"DocumentChunk_text", query, limit=self.top_k, include_payload=True
|
||||||
|
)
|
||||||
|
|
||||||
if len(found_chunks) == 0:
|
if len(found_chunks) == 0:
|
||||||
return ""
|
return ""
|
||||||
|
|
|
||||||
|
|
@ -52,7 +52,7 @@ class SummariesRetriever(BaseRetriever):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
summaries_results = await vector_engine.search(
|
summaries_results = await vector_engine.search(
|
||||||
"TextSummary_text", query, limit=self.top_k
|
"TextSummary_text", query, limit=self.top_k, include_payload=True
|
||||||
)
|
)
|
||||||
logger.info(f"Found {len(summaries_results)} summaries from vector search")
|
logger.info(f"Found {len(summaries_results)} summaries from vector search")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -98,7 +98,7 @@ class TemporalRetriever(GraphCompletionRetriever):
|
||||||
|
|
||||||
async def filter_top_k_events(self, relevant_events, scored_results):
|
async def filter_top_k_events(self, relevant_events, scored_results):
|
||||||
# Build a score lookup from vector search results
|
# Build a score lookup from vector search results
|
||||||
score_lookup = {res.payload["id"]: res.score for res in scored_results}
|
score_lookup = {res.id: res.score for res in scored_results}
|
||||||
|
|
||||||
events_with_scores = []
|
events_with_scores = []
|
||||||
for event in relevant_events[0]["events"]:
|
for event in relevant_events[0]["events"]:
|
||||||
|
|
|
||||||
|
|
@ -67,7 +67,9 @@ class TripletRetriever(BaseRetriever):
|
||||||
"In order to use TRIPLET_COMPLETION first use the create_triplet_embeddings memify pipeline. "
|
"In order to use TRIPLET_COMPLETION first use the create_triplet_embeddings memify pipeline. "
|
||||||
)
|
)
|
||||||
|
|
||||||
found_triplets = await vector_engine.search("Triplet_text", query, limit=self.top_k)
|
found_triplets = await vector_engine.search(
|
||||||
|
"Triplet_text", query, limit=self.top_k, include_payload=True
|
||||||
|
)
|
||||||
|
|
||||||
if len(found_triplets) == 0:
|
if len(found_triplets) == 0:
|
||||||
return ""
|
return ""
|
||||||
|
|
|
||||||
|
|
@ -97,7 +97,7 @@ async def test_vector_engine_search_none_limit():
|
||||||
query_vector = (await vector_engine.embedding_engine.embed_text([query_text]))[0]
|
query_vector = (await vector_engine.embedding_engine.embed_text([query_text]))[0]
|
||||||
|
|
||||||
result = await vector_engine.search(
|
result = await vector_engine.search(
|
||||||
collection_name=collection_name, query_vector=query_vector, limit=None
|
collection_name=collection_name, query_vector=query_vector, limit=None, include_payload=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check that we did not accidentally use any default value for limit
|
# Check that we did not accidentally use any default value for limit
|
||||||
|
|
|
||||||
|
|
@ -70,7 +70,9 @@ async def main():
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
|
random_node = (
|
||||||
|
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
|
||||||
|
)[0]
|
||||||
random_node_name = random_node.payload["text"]
|
random_node_name = random_node.payload["text"]
|
||||||
|
|
||||||
search_results = await cognee.search(
|
search_results = await cognee.search(
|
||||||
|
|
|
||||||
|
|
@ -149,7 +149,9 @@ async def main():
|
||||||
await test_getting_of_documents(dataset_name_1)
|
await test_getting_of_documents(dataset_name_1)
|
||||||
|
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
|
random_node = (
|
||||||
|
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
|
||||||
|
)[0]
|
||||||
random_node_name = random_node.payload["text"]
|
random_node_name = random_node.payload["text"]
|
||||||
|
|
||||||
search_results = await cognee.search(
|
search_results = await cognee.search(
|
||||||
|
|
|
||||||
|
|
@ -48,7 +48,7 @@ async def main():
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
|
random_node = (await vector_engine.search("Entity_name", "AI", include_payload=True))[0]
|
||||||
random_node_name = random_node.payload["text"]
|
random_node_name = random_node.payload["text"]
|
||||||
|
|
||||||
search_results = await cognee.search(
|
search_results = await cognee.search(
|
||||||
|
|
|
||||||
|
|
@ -63,7 +63,9 @@ async def main():
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
|
random_node = (
|
||||||
|
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
|
||||||
|
)[0]
|
||||||
random_node_name = random_node.payload["text"]
|
random_node_name = random_node.payload["text"]
|
||||||
|
|
||||||
search_results = await cognee.search(
|
search_results = await cognee.search(
|
||||||
|
|
|
||||||
|
|
@ -52,7 +52,9 @@ async def main():
|
||||||
await cognee.cognify([dataset_name])
|
await cognee.cognify([dataset_name])
|
||||||
|
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
|
random_node = (
|
||||||
|
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
|
||||||
|
)[0]
|
||||||
random_node_name = random_node.payload["text"]
|
random_node_name = random_node.payload["text"]
|
||||||
|
|
||||||
search_results = await cognee.search(
|
search_results = await cognee.search(
|
||||||
|
|
|
||||||
|
|
@ -163,7 +163,9 @@ async def main():
|
||||||
await test_getting_of_documents(dataset_name_1)
|
await test_getting_of_documents(dataset_name_1)
|
||||||
|
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
|
random_node = (
|
||||||
|
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
|
||||||
|
)[0]
|
||||||
random_node_name = random_node.payload["text"]
|
random_node_name = random_node.payload["text"]
|
||||||
|
|
||||||
search_results = await cognee.search(
|
search_results = await cognee.search(
|
||||||
|
|
|
||||||
|
|
@ -58,7 +58,9 @@ async def main():
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
|
random_node = (
|
||||||
|
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
|
||||||
|
)[0]
|
||||||
random_node_name = random_node.payload["text"]
|
random_node_name = random_node.payload["text"]
|
||||||
|
|
||||||
search_results = await cognee.search(
|
search_results = await cognee.search(
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,7 @@ async def main():
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
|
random_node = (await vector_engine.search("Entity_name", "AI", include_payload=True))[0]
|
||||||
random_node_name = random_node.payload["text"]
|
random_node_name = random_node.payload["text"]
|
||||||
|
|
||||||
search_results = await cognee.search(
|
search_results = await cognee.search(
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
|
||||||
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
|
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
|
||||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
|
||||||
|
|
@ -379,7 +380,7 @@ async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph):
|
||||||
graph.add_edge(edge)
|
graph.add_edge(edge)
|
||||||
|
|
||||||
edge_distances = [
|
edge_distances = [
|
||||||
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
|
MockScoredResult(generate_edge_id("CONNECTS_TO"), 0.92, payload={"text": "CONNECTS_TO"}),
|
||||||
]
|
]
|
||||||
|
|
||||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||||
|
|
@ -404,8 +405,9 @@ async def test_map_vector_distances_partial_edge_coverage(setup_graph):
|
||||||
graph.add_edge(edge1)
|
graph.add_edge(edge1)
|
||||||
graph.add_edge(edge2)
|
graph.add_edge(edge2)
|
||||||
|
|
||||||
|
edge_1_text = "CONNECTS_TO"
|
||||||
edge_distances = [
|
edge_distances = [
|
||||||
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
|
MockScoredResult(generate_edge_id(edge_1_text), 0.92, payload={"text": edge_1_text}),
|
||||||
]
|
]
|
||||||
|
|
||||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||||
|
|
@ -431,8 +433,9 @@ async def test_map_vector_distances_edges_fallback_to_relationship_type(setup_gr
|
||||||
)
|
)
|
||||||
graph.add_edge(edge)
|
graph.add_edge(edge)
|
||||||
|
|
||||||
|
edge_text = "KNOWS"
|
||||||
edge_distances = [
|
edge_distances = [
|
||||||
MockScoredResult("e1", 0.85, payload={"text": "KNOWS"}),
|
MockScoredResult(generate_edge_id(edge_text), 0.85, payload={"text": edge_text}),
|
||||||
]
|
]
|
||||||
|
|
||||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||||
|
|
@ -457,8 +460,9 @@ async def test_map_vector_distances_no_edge_matches(setup_graph):
|
||||||
)
|
)
|
||||||
graph.add_edge(edge)
|
graph.add_edge(edge)
|
||||||
|
|
||||||
|
edge_text = "SOME_OTHER_EDGE"
|
||||||
edge_distances = [
|
edge_distances = [
|
||||||
MockScoredResult("e1", 0.92, payload={"text": "SOME_OTHER_EDGE"}),
|
MockScoredResult(generate_edge_id(edge_text), 0.92, payload={"text": edge_text}),
|
||||||
]
|
]
|
||||||
|
|
||||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||||
|
|
@ -511,9 +515,15 @@ async def test_map_vector_distances_to_graph_edges_multi_query(setup_graph):
|
||||||
graph.add_edge(edge1)
|
graph.add_edge(edge1)
|
||||||
graph.add_edge(edge2)
|
graph.add_edge(edge2)
|
||||||
|
|
||||||
|
edge_1_text = "A"
|
||||||
|
edge_2_text = "B"
|
||||||
edge_distances = [
|
edge_distances = [
|
||||||
[MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0
|
[
|
||||||
[MockScoredResult("e2", 0.2, payload={"text": "B"})], # query 1
|
MockScoredResult(generate_edge_id(edge_1_text), 0.1, payload={"text": edge_1_text})
|
||||||
|
], # query 0
|
||||||
|
[
|
||||||
|
MockScoredResult(generate_edge_id(edge_2_text), 0.2, payload={"text": edge_2_text})
|
||||||
|
], # query 1
|
||||||
]
|
]
|
||||||
|
|
||||||
await graph.map_vector_distances_to_graph_edges(
|
await graph.map_vector_distances_to_graph_edges(
|
||||||
|
|
@ -541,8 +551,11 @@ async def test_map_vector_distances_to_graph_edges_preserves_unmapped_indices(se
|
||||||
graph.add_edge(edge1)
|
graph.add_edge(edge1)
|
||||||
graph.add_edge(edge2)
|
graph.add_edge(edge2)
|
||||||
|
|
||||||
|
edge_1_text = "A"
|
||||||
edge_distances = [
|
edge_distances = [
|
||||||
[MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0: only edge1 mapped
|
[
|
||||||
|
MockScoredResult(generate_edge_id(edge_1_text), 0.1, payload={"text": edge_1_text})
|
||||||
|
], # query 0: only edge1 mapped
|
||||||
[], # query 1: no edges mapped
|
[], # query 1: no edges mapped
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,9 @@ async def test_get_context_success(mock_vector_engine):
|
||||||
assert len(context) == 2
|
assert len(context) == 2
|
||||||
assert context[0]["text"] == "Steve Rodger"
|
assert context[0]["text"] == "Steve Rodger"
|
||||||
assert context[1]["text"] == "Mike Broski"
|
assert context[1]["text"] == "Mike Broski"
|
||||||
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=5)
|
mock_vector_engine.search.assert_awaited_once_with(
|
||||||
|
"DocumentChunk_text", "test query", limit=5, include_payload=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -87,7 +89,9 @@ async def test_get_context_top_k_limit(mock_vector_engine):
|
||||||
context = await retriever.get_context("test query")
|
context = await retriever.get_context("test query")
|
||||||
|
|
||||||
assert len(context) == 3
|
assert len(context) == 3
|
||||||
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=3)
|
mock_vector_engine.search.assert_awaited_once_with(
|
||||||
|
"DocumentChunk_text", "test query", limit=3, include_payload=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,9 @@ async def test_get_context_success(mock_vector_engine):
|
||||||
context = await retriever.get_context("test query")
|
context = await retriever.get_context("test query")
|
||||||
|
|
||||||
assert context == "Steve Rodger\nMike Broski"
|
assert context == "Steve Rodger\nMike Broski"
|
||||||
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2)
|
mock_vector_engine.search.assert_awaited_once_with(
|
||||||
|
"DocumentChunk_text", "test query", limit=2, include_payload=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -85,7 +87,9 @@ async def test_get_context_top_k_limit(mock_vector_engine):
|
||||||
context = await retriever.get_context("test query")
|
context = await retriever.get_context("test query")
|
||||||
|
|
||||||
assert context == "Chunk 0\nChunk 1"
|
assert context == "Chunk 0\nChunk 1"
|
||||||
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2)
|
mock_vector_engine.search.assert_awaited_once_with(
|
||||||
|
"DocumentChunk_text", "test query", limit=2, include_payload=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,9 @@ async def test_get_context_success(mock_vector_engine):
|
||||||
assert len(context) == 2
|
assert len(context) == 2
|
||||||
assert context[0]["text"] == "S.R."
|
assert context[0]["text"] == "S.R."
|
||||||
assert context[1]["text"] == "M.B."
|
assert context[1]["text"] == "M.B."
|
||||||
mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=5)
|
mock_vector_engine.search.assert_awaited_once_with(
|
||||||
|
"TextSummary_text", "test query", limit=5, include_payload=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -87,7 +89,9 @@ async def test_get_context_top_k_limit(mock_vector_engine):
|
||||||
context = await retriever.get_context("test query")
|
context = await retriever.get_context("test query")
|
||||||
|
|
||||||
assert len(context) == 3
|
assert len(context) == 3
|
||||||
mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=3)
|
mock_vector_engine.search.assert_awaited_once_with(
|
||||||
|
"TextSummary_text", "test query", limit=3, include_payload=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
|
||||||
|
|
@ -63,8 +63,8 @@ async def test_filter_top_k_events_sorts_and_limits():
|
||||||
]
|
]
|
||||||
|
|
||||||
scored_results = [
|
scored_results = [
|
||||||
SimpleNamespace(payload={"id": "e2"}, score=0.10),
|
SimpleNamespace(id="e2", payload={"id": "e2"}, score=0.10),
|
||||||
SimpleNamespace(payload={"id": "e1"}, score=0.20),
|
SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.20),
|
||||||
]
|
]
|
||||||
|
|
||||||
top = await tr.filter_top_k_events(relevant_events, scored_results)
|
top = await tr.filter_top_k_events(relevant_events, scored_results)
|
||||||
|
|
@ -91,8 +91,8 @@ async def test_filter_top_k_events_includes_unknown_as_infinite_but_not_in_top_k
|
||||||
]
|
]
|
||||||
|
|
||||||
scored_results = [
|
scored_results = [
|
||||||
SimpleNamespace(payload={"id": "known2"}, score=0.05),
|
SimpleNamespace(id="known2", payload={"id": "known2"}, score=0.05),
|
||||||
SimpleNamespace(payload={"id": "known1"}, score=0.50),
|
SimpleNamespace(id="known1", payload={"id": "known1"}, score=0.50),
|
||||||
]
|
]
|
||||||
|
|
||||||
top = await tr.filter_top_k_events(relevant_events, scored_results)
|
top = await tr.filter_top_k_events(relevant_events, scored_results)
|
||||||
|
|
@ -119,8 +119,8 @@ async def test_filter_top_k_events_limits_when_top_k_exceeds_events():
|
||||||
tr = TemporalRetriever(top_k=10)
|
tr = TemporalRetriever(top_k=10)
|
||||||
relevant_events = [{"events": [{"id": "a"}, {"id": "b"}]}]
|
relevant_events = [{"events": [{"id": "a"}, {"id": "b"}]}]
|
||||||
scored_results = [
|
scored_results = [
|
||||||
SimpleNamespace(payload={"id": "a"}, score=0.1),
|
SimpleNamespace(id="a", payload={"id": "a"}, score=0.1),
|
||||||
SimpleNamespace(payload={"id": "b"}, score=0.2),
|
SimpleNamespace(id="b", payload={"id": "b"}, score=0.2),
|
||||||
]
|
]
|
||||||
out = await tr.filter_top_k_events(relevant_events, scored_results)
|
out = await tr.filter_top_k_events(relevant_events, scored_results)
|
||||||
assert [e["id"] for e in out] == ["a", "b"]
|
assert [e["id"] for e in out] == ["a", "b"]
|
||||||
|
|
@ -179,8 +179,8 @@ async def test_get_context_with_time_range(mock_graph_engine, mock_vector_engine
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
mock_result1 = SimpleNamespace(payload={"id": "e2"}, score=0.05)
|
mock_result1 = SimpleNamespace(id="e2", payload={"id": "e2"}, score=0.05)
|
||||||
mock_result2 = SimpleNamespace(payload={"id": "e1"}, score=0.10)
|
mock_result2 = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.10)
|
||||||
mock_vector_engine.search.return_value = [mock_result1, mock_result2]
|
mock_vector_engine.search.return_value = [mock_result1, mock_result2]
|
||||||
|
|
||||||
with (
|
with (
|
||||||
|
|
@ -279,7 +279,7 @@ async def test_get_context_time_from_only(mock_graph_engine, mock_vector_engine)
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
|
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
|
||||||
mock_vector_engine.search.return_value = [mock_result]
|
mock_vector_engine.search.return_value = [mock_result]
|
||||||
|
|
||||||
with (
|
with (
|
||||||
|
|
@ -313,7 +313,7 @@ async def test_get_context_time_to_only(mock_graph_engine, mock_vector_engine):
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
|
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
|
||||||
mock_vector_engine.search.return_value = [mock_result]
|
mock_vector_engine.search.return_value = [mock_result]
|
||||||
|
|
||||||
with (
|
with (
|
||||||
|
|
@ -347,7 +347,7 @@ async def test_get_completion_without_context(mock_graph_engine, mock_vector_eng
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
|
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
|
||||||
mock_vector_engine.search.return_value = [mock_result]
|
mock_vector_engine.search.return_value = [mock_result]
|
||||||
|
|
||||||
with (
|
with (
|
||||||
|
|
@ -416,7 +416,7 @@ async def test_get_completion_with_session(mock_graph_engine, mock_vector_engine
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
|
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
|
||||||
mock_vector_engine.search.return_value = [mock_result]
|
mock_vector_engine.search.return_value = [mock_result]
|
||||||
|
|
||||||
mock_user = MagicMock()
|
mock_user = MagicMock()
|
||||||
|
|
@ -481,7 +481,7 @@ async def test_get_completion_with_session_no_user_id(mock_graph_engine, mock_ve
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
|
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
|
||||||
mock_vector_engine.search.return_value = [mock_result]
|
mock_vector_engine.search.return_value = [mock_result]
|
||||||
|
|
||||||
with (
|
with (
|
||||||
|
|
@ -570,7 +570,7 @@ async def test_get_completion_with_response_model(mock_graph_engine, mock_vector
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
|
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
|
||||||
mock_vector_engine.search.return_value = [mock_result]
|
mock_vector_engine.search.return_value = [mock_result]
|
||||||
|
|
||||||
with (
|
with (
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from cognee.modules.retrieval.utils.brute_force_triplet_search import (
|
||||||
get_memory_fragment,
|
get_memory_fragment,
|
||||||
format_triplets,
|
format_triplets,
|
||||||
)
|
)
|
||||||
|
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
|
||||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||||
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
|
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
|
||||||
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
|
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
|
||||||
|
|
@ -1036,9 +1037,11 @@ async def test_cognee_graph_mapping_batch_shapes():
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
edge_1_text = "relates_to"
|
||||||
|
edge_2_text = "relates_to"
|
||||||
edge_distances_batch = [
|
edge_distances_batch = [
|
||||||
[MockScoredResult("edge1", 0.92, payload={"text": "relates_to"})],
|
[MockScoredResult(generate_edge_id(edge_1_text), 0.92, payload={"text": edge_1_text})],
|
||||||
[MockScoredResult("edge2", 0.88, payload={"text": "relates_to"})],
|
[MockScoredResult(generate_edge_id(edge_2_text), 0.88, payload={"text": edge_2_text})],
|
||||||
]
|
]
|
||||||
|
|
||||||
await graph.map_vector_distances_to_graph_nodes(
|
await graph.map_vector_distances_to_graph_nodes(
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,9 @@ async def test_get_context_success(mock_vector_engine):
|
||||||
context = await retriever.get_context("test query")
|
context = await retriever.get_context("test query")
|
||||||
|
|
||||||
assert context == "Alice knows Bob\nBob works at Tech Corp"
|
assert context == "Alice knows Bob\nBob works at Tech Corp"
|
||||||
mock_vector_engine.search.assert_awaited_once_with("Triplet_text", "test query", limit=5)
|
mock_vector_engine.search.assert_awaited_once_with(
|
||||||
|
"Triplet_text", "test query", limit=5, include_payload=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue