Merge branch 'dev' into feature/cog-3409-add-bedrock-as-supported-llm-provider
This commit is contained in commit 8cad9ef225.

41 changed files with 9132 additions and 5607 deletions
127 .coderabbit.yaml (new file)

@@ -0,0 +1,127 @@
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
# .coderabbit.yaml
language: en
early_access: false
enable_free_tier: true
reviews:
  profile: chill
  instructions: >-
    # Code Review Instructions

    - Ensure the code follows best practices and coding standards.
    - For **Python** code, follow
    [PEP 20](https://www.python.org/dev/peps/pep-0020/) and
    [CEP-8](https://gist.github.com/reactive-firewall/b7ee98df9e636a51806e62ef9c4ab161)
    standards.

    # Documentation Review Instructions
    - Verify that documentation and comments are clear and comprehensive.
    - Verify that documentation and comments are free of spelling mistakes.

    # Test Code Review Instructions
    - Ensure that test code is automated, comprehensive, and follows testing best practices.
    - Verify that all critical functionality is covered by tests.
    - Ensure that test code follow
    [CEP-8](https://gist.github.com/reactive-firewall/d840ee9990e65f302ce2a8d78ebe73f6)

    # Misc.
    - Confirm that the code meets the project's requirements and objectives.
    - Confirm that copyright years are up-to date whenever a file is changed.
  request_changes_workflow: false
  high_level_summary: true
  high_level_summary_placeholder: '@coderabbitai summary'
  auto_title_placeholder: '@coderabbitai'
  review_status: true
  poem: false
  collapse_walkthrough: false
  sequence_diagrams: false
  changed_files_summary: true
  path_filters: ['!*.xc*/**', '!node_modules/**', '!dist/**', '!build/**', '!.git/**', '!venv/**', '!__pycache__/**']
  path_instructions:
    - path: README.md
      instructions: >-
        1. Consider the file 'README.md' the overview/introduction of the project.
        Also consider the 'README.md' file the first place to look for project documentation.

        2. When reviewing the file 'README.md' it should be linted with help
        from the tools `markdownlint` and `languagetool`, pointing out any issues.

        3. You may assume the file 'README.md' will contain GitHub flavor Markdown.
    - path: '**/*.py'
      instructions: >-
        When reviewing Python code for this project:

        1. Prioritize portability over clarity, especially when dealing with cross-Python compatibility. However, with the priority in mind, do still consider improvements to clarity when relevant.

        2. As a general guideline, consider the code style advocated in the PEP 8 standard (excluding the use of spaces for indentation) and evaluate suggested changes for code style compliance.

        3. As a style convention, consider the code style advocated in [CEP-8](https://gist.github.com/reactive-firewall/b7ee98df9e636a51806e62ef9c4ab161) and evaluate suggested changes for code style compliance.

        4. As a general guideline, try to provide any relevant, official, and supporting documentation links to any tool's suggestions in review comments. This guideline is important for posterity.

        5. As a general rule, undocumented function definitions and class definitions in the project's Python code are assumed incomplete. Please consider suggesting a short summary of the code for any of these incomplete definitions as docstrings when reviewing.
    - path: cognee/tests/*
      instructions: >-
        When reviewing test code:

        1. Prioritize portability over clarity, especially when dealing with cross-Python compatibility. However, with the priority in mind, do still consider improvements to clarity when relevant.

        2. As a general guideline, consider the code style advocated in the PEP 8 standard (excluding the use of spaces for indentation) and evaluate suggested changes for code style compliance.

        3. As a style convention, consider the code style advocated in [CEP-8](https://gist.github.com/reactive-firewall/b7ee98df9e636a51806e62ef9c4ab161) and evaluate suggested changes for code style compliance, pointing out any violations discovered.

        4. As a general guideline, try to provide any relevant, official, and supporting documentation links to any tool's suggestions in review comments. This guideline is important for posterity.

        5. As a project rule, Python source files with names prefixed by the string "test_" and located in the project's "tests" directory are the project's unit-testing code. It is safe, albeit a heuristic, to assume these are considered part of the project's minimal acceptance testing unless a justifying exception to this assumption is documented.

        6. As a project rule, any files without extensions and with names prefixed by either the string "check_" or the string "test_", and located in the project's "tests" directory, are the project's non-unit test code. "Non-unit test" in this context refers to any type of testing other than unit testing, such as (but not limited to) functional testing, style linting, regression testing, etc. It can also be assumed that non-unit testing code is usually written as Bash shell scripts.
    - path: requirements.txt
      instructions: >-
        * The project's own Python dependencies are recorded in 'requirements.txt' for production code.

        * The project's testing-specific Python dependencies are recorded in 'tests/requirements.txt' and are used for testing the project.

        * The project's documentation-specific Python dependencies are recorded in 'docs/requirements.txt' and are used only for generating Python-focused documentation for the project. 'docs/requirements.txt' may be absent if not applicable.

        Consider these 'requirements.txt' files the records of truth regarding project dependencies.
    - path: .github/**
      instructions: >-
        * When the project is hosted on GitHub: All GitHub-specific configurations, templates, and tools should be found in the '.github' directory tree.

        * 'actionlint' erroneously generates false positives when dealing with GitHub's `${{ ... }}` syntax in conditionals.

        * 'actionlint' erroneously generates incorrect solutions when suggesting the removal of valid `${{ ... }}` syntax.
  abort_on_close: true
  auto_review:
    enabled: true
    auto_incremental_review: true
    ignore_title_keywords: []
    labels: []
    drafts: false
    base_branches:
      - dev
      - main
  tools:
    shellcheck:
      enabled: true
    ruff:
      enabled: true
      configuration:
        extend_select:
          - E # Pycodestyle errors (style issues)
          - F # PyFlakes codes (logical errors)
          - W # Pycodestyle warnings
          - N # PEP 8 naming conventions
        ignore:
          - W191
          - W391
          - E117
          - D208
        line_length: 100
        dummy_variable_rgx: '^(_.*|junk|extra)$' # Variables starting with '_' or named 'junk' or 'extras', are considered dummy variables
    markdownlint:
      enabled: true
    yamllint:
      enabled: true
chat:
  auto_reply: true
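The ruff `dummy_variable_rgx` setting above controls which unused names the linter ignores. Its effect can be checked locally with Python's `re` module; a quick sketch (`is_dummy` is an illustrative helper, not part of the configuration):

```python
import re

# The dummy-variable pattern from the ruff configuration above: names starting
# with "_" or exactly equal to "junk" or "extra" count as intentional dummies.
DUMMY_RE = re.compile(r"^(_.*|junk|extra)$")


def is_dummy(name: str) -> bool:
    # The ^...$ anchors in the pattern itself give full-match semantics.
    return DUMMY_RE.match(name) is not None
```

Note the anchors: `extras` does not match, even though the inline comment in the config mentions that spelling.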
17 .github/workflows/test_ollama.yml (vendored)

@@ -7,13 +7,8 @@ jobs:
run_ollama_test:

# needs 16 Gb RAM for phi4
runs-on: buildjet-4vcpu-ubuntu-2204
# services:
#   ollama:
#     image: ollama/ollama
#     ports:
#       - 11434:11434
# needs 32 Gb RAM for phi4 in a container
runs-on: buildjet-8vcpu-ubuntu-2204

steps:
  - name: Checkout repository

@@ -28,14 +23,6 @@ jobs:
run: |
  uv add torch

# - name: Install ollama
#   run: curl -fsSL https://ollama.com/install.sh | sh
# - name: Run ollama
#   run: |
#     ollama serve --openai &
#     ollama pull llama3.2 &
#     ollama pull avr/sfr-embedding-mistral:latest

- name: Start Ollama container
  run: |
    docker run -d --name ollama -p 11434:11434 ollama/ollama
@@ -31,6 +31,8 @@ async def search(
    only_context: bool = False,
    use_combined_context: bool = False,
    session_id: Optional[str] = None,
    wide_search_top_k: Optional[int] = 100,
    triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[List[SearchResult], CombinedSearchResult]:
    """
    Search and query the knowledge graph for insights, information, and connections.

@@ -200,6 +202,8 @@ async def search(
        only_context=only_context,
        use_combined_context=use_combined_context,
        session_id=session_id,
        wide_search_top_k=wide_search_top_k,
        triplet_distance_penalty=triplet_distance_penalty,
    )

    return filtered_search_results
29 cognee/eval_framework/Dockerfile (new file)

@@ -0,0 +1,29 @@
FROM python:3.11-slim

# Set environment variables
ENV PIP_NO_CACHE_DIR=true
ENV PATH="${PATH}:/root/.poetry/bin"
ENV PYTHONPATH=/app
ENV SKIP_MIGRATIONS=true

# System dependencies
RUN apt-get update && apt-get install -y \
    gcc \
    libpq-dev \
    git \
    curl \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

COPY pyproject.toml poetry.lock README.md /app/

RUN pip install poetry

RUN poetry config virtualenvs.create false

RUN poetry install --extras distributed --extras evals --extras deepeval --no-root

COPY cognee/ /app/cognee
COPY distributed/ /app/distributed
@@ -35,6 +35,16 @@ class AnswerGeneratorExecutor:
            retrieval_context = await retriever.get_context(query_text)
            search_results = await retriever.get_completion(query_text, retrieval_context)

            ############
            #:TODO This is a quick fix until we don't structure retriever results properly but lets not leave it like this...this is needed now due to the changed combined retriever structure..
            if isinstance(retrieval_context, list):
                retrieval_context = await retriever.convert_retrieved_objects_to_context(
                    triplets=retrieval_context
                )

            if isinstance(search_results, str):
                search_results = [search_results]
            #############
            answer = {
                "question": query_text,
                "answer": search_results[0],
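The quick fix above guards against two result shapes coming back from the combined retriever. Stripped of the retriever itself, the normalization logic looks like this; a sketch, where the string-joining of list contexts stands in for the real `convert_retrieved_objects_to_context` call:

```python
# Sketch of the normalization "quick fix": the retriever may return either a
# plain string completion or a list of results, and the context may arrive as
# a list of triplets that still needs flattening. Names are illustrative.
def normalize_results(retrieval_context, search_results):
    # A list-shaped context means raw triplets; flatten them to one string.
    if isinstance(retrieval_context, list):
        retrieval_context = "\n".join(str(t) for t in retrieval_context)
    # Downstream code indexes search_results[0], so wrap bare strings.
    if isinstance(search_results, str):
        search_results = [search_results]
    return retrieval_context, search_results
```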
@@ -35,7 +35,7 @@ async def create_and_insert_answers_table(questions_payload):


async def run_question_answering(
-    params: dict, system_prompt="answer_simple_question.txt", top_k: Optional[int] = None
+    params: dict, system_prompt="answer_simple_question_benchmark.txt", top_k: Optional[int] = None
) -> List[dict]:
    if params.get("answering_questions"):
        logger.info("Question answering started...")
@@ -14,7 +14,7 @@ class EvalConfig(BaseSettings):

    # Question answering params
    answering_questions: bool = True
-    qa_engine: str = "cognee_completion"  # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension'
+    qa_engine: str = "cognee_graph_completion"  # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension'

    # Evaluation params
    evaluating_answers: bool = True

@@ -25,7 +25,7 @@ class EvalConfig(BaseSettings):
        "EM",
        "f1",
    ]  # Use only 'correctness' for DirectLLM
-    deepeval_model: str = "gpt-5-mini"
+    deepeval_model: str = "gpt-4o-mini"

    # Metrics params
    calculate_metrics: bool = True
@@ -2,7 +2,6 @@ import modal
import os
import asyncio
import datetime
import hashlib
import json
from cognee.shared.logging_utils import get_logger
from cognee.eval_framework.eval_config import EvalConfig

@@ -10,6 +9,9 @@ from cognee.eval_framework.corpus_builder.run_corpus_builder import run_corpus_b
from cognee.eval_framework.answer_generation.run_question_answering_module import (
    run_question_answering,
)
import pathlib
from os import path
from modal import Image
from cognee.eval_framework.evaluation.run_evaluation_module import run_evaluation
from cognee.eval_framework.metrics_dashboard import create_dashboard


@@ -38,22 +40,19 @@ def read_and_combine_metrics(eval_params: dict) -> dict:

app = modal.App("modal-run-eval")

image = (
    modal.Image.from_dockerfile(path="Dockerfile_modal", force_build=False)
    .copy_local_file("pyproject.toml", "pyproject.toml")
    .copy_local_file("poetry.lock", "poetry.lock")
    .env(
        {
            "ENV": os.getenv("ENV"),
            "LLM_API_KEY": os.getenv("LLM_API_KEY"),
            "OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"),
        }
    )
    .pip_install("protobuf", "h2", "deepeval", "gdown", "plotly")
image = Image.from_dockerfile(
    path=pathlib.Path(path.join(path.dirname(__file__), "Dockerfile")).resolve(),
    force_build=False,
).add_local_python_source("cognee")


@app.function(
    image=image,
    max_containers=10,
    timeout=86400,
    volumes={"/data": vol},
    secrets=[modal.Secret.from_name("eval_secrets")],
)


@app.function(image=image, concurrency_limit=10, timeout=86400, volumes={"/data": vol})
async def modal_run_eval(eval_params=None):
    """Runs evaluation pipeline and returns combined metrics results."""
    if eval_params is None:

@@ -105,18 +104,7 @@ async def main():
    configs = [
        EvalConfig(
            task_getter_type="Default",
            number_of_samples_in_corpus=10,
            benchmark="HotPotQA",
            qa_engine="cognee_graph_completion",
            building_corpus_from_scratch=True,
            answering_questions=True,
            evaluating_answers=True,
            calculate_metrics=True,
            dashboard=True,
        ),
        EvalConfig(
            task_getter_type="Default",
-            number_of_samples_in_corpus=10,
+            number_of_samples_in_corpus=25,
            benchmark="TwoWikiMultiHop",
            qa_engine="cognee_graph_completion",
            building_corpus_from_scratch=True,

@@ -127,7 +115,7 @@ async def main():
        ),
        EvalConfig(
            task_getter_type="Default",
-            number_of_samples_in_corpus=10,
+            number_of_samples_in_corpus=25,
            benchmark="Musique",
            qa_engine="cognee_graph_completion",
            building_corpus_from_scratch=True,
@@ -398,3 +398,18 @@ class GraphDBInterface(ABC):
        - node_id (Union[str, UUID]): Unique identifier of the node for which to retrieve connections.
        """
        raise NotImplementedError

    @abstractmethod
    async def get_filtered_graph_data(
        self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
    ) -> Tuple[List[Node], List[EdgeData]]:
        """
        Retrieve nodes and edges filtered by the provided attribute criteria.

        Parameters:
        -----------

        - attribute_filters: A list of dictionaries where keys are attribute names and values
          are lists of attribute values to filter by.
        """
        raise NotImplementedError
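The new `get_filtered_graph_data` declaration relies on standard abstract-base-class mechanics: declaring it with `@abstractmethod` forces every adapter subclass to implement it before the class can be instantiated. A minimal stdlib sketch of the same pattern (the class and method names here are illustrative, not the project's actual interface):

```python
from abc import ABC, abstractmethod


class GraphInterface(ABC):
    # Subclasses must override this; instantiating a subclass that does not
    # fails at construction time thanks to ABCMeta.
    @abstractmethod
    def get_filtered_graph_data(self, attribute_filters):
        raise NotImplementedError


class InMemoryGraph(GraphInterface):
    def get_filtered_graph_data(self, attribute_filters):
        # Trivial concrete implementation: no nodes, no edges.
        return [], []
```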
@@ -12,6 +12,7 @@ from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Any, List, Union, Optional, Tuple, Type

from cognee.exceptions import CogneeValidationError
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.utils.run_sync import run_sync
from cognee.infrastructure.files.storage import get_file_storage
@@ -1186,6 +1187,11 @@ class KuzuAdapter(GraphDBInterface):
        A tuple with two elements: a list of tuples of (node_id, properties) and a list of
        tuples of (source_id, target_id, relationship_name, properties).
        """

        import time

        start_time = time.time()

        try:
            nodes_query = """
            MATCH (n:Node)
@@ -1249,6 +1255,11 @@ class KuzuAdapter(GraphDBInterface):
                    },
                )
            )

            retrieval_time = time.time() - start_time
            logger.info(
                f"Retrieved {len(nodes)} nodes and {len(edges)} edges in {retrieval_time:.2f} seconds"
            )
            return formatted_nodes, formatted_edges
        except Exception as e:
            logger.error(f"Failed to get graph data: {e}")
@@ -1417,6 +1428,92 @@ class KuzuAdapter(GraphDBInterface):
            formatted_edges.append((source_id, target_id, rel_type, props))
        return formatted_nodes, formatted_edges

    async def get_id_filtered_graph_data(self, target_ids: list[str]):
        """
        Retrieve graph data filtered by specific node IDs, including their direct neighbors
        and only edges where one endpoint matches those IDs.

        Returns:
            nodes: List[dict] -> Each dict includes "id" and all node properties
            edges: List[dict] -> Each dict includes "source", "target", "type", "properties"
        """
        import time

        start_time = time.time()

        try:
            if not target_ids:
                logger.warning("No target IDs provided for ID-filtered graph retrieval.")
                return [], []

            if not all(isinstance(x, str) for x in target_ids):
                raise CogneeValidationError("target_ids must be a list of strings")

            query = """
                MATCH (n:Node)-[r]->(m:Node)
                WHERE n.id IN $target_ids OR m.id IN $target_ids
                RETURN n.id, {
                    name: n.name,
                    type: n.type,
                    properties: n.properties
                }, m.id, {
                    name: m.name,
                    type: m.type,
                    properties: m.properties
                }, r.relationship_name, r.properties
            """

            result = await self.query(query, {"target_ids": target_ids})

            if not result:
                logger.info("No data returned for the supplied IDs")
                return [], []

            nodes_dict = {}
            edges = []

            for n_id, n_props, m_id, m_props, r_type, r_props_raw in result:
                if n_props.get("properties"):
                    try:
                        additional_props = json.loads(n_props["properties"])
                        n_props.update(additional_props)
                        del n_props["properties"]
                    except json.JSONDecodeError:
                        logger.warning(f"Failed to parse properties JSON for node {n_id}")

                if m_props.get("properties"):
                    try:
                        additional_props = json.loads(m_props["properties"])
                        m_props.update(additional_props)
                        del m_props["properties"]
                    except json.JSONDecodeError:
                        logger.warning(f"Failed to parse properties JSON for node {m_id}")

                nodes_dict[n_id] = (n_id, n_props)
                nodes_dict[m_id] = (m_id, m_props)

                edge_props = {}
                if r_props_raw:
                    try:
                        edge_props = json.loads(r_props_raw)
                    except (json.JSONDecodeError, TypeError):
                        logger.warning(f"Failed to parse edge properties for {n_id}->{m_id}")

                source_id = edge_props.get("source_node_id", n_id)
                target_id = edge_props.get("target_node_id", m_id)
                edges.append((source_id, target_id, r_type, edge_props))

            retrieval_time = time.time() - start_time
            logger.info(
                f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s"
            )

            return list(nodes_dict.values()), edges

        except Exception as e:
            logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}")
            raise

    async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]:
        """
        Get metrics on graph structure and connectivity.
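Both branches of the loop above apply the same transformation: Kuzu returns extra node attributes as a JSON string under a `properties` key, which gets merged into the flat property dict. Isolated as a helper, the step looks like this (a sketch; `merge_json_properties` is illustrative, not the adapter's actual API):

```python
import json


def merge_json_properties(props: dict) -> dict:
    # Extra node attributes arrive serialized as a JSON string under the
    # "properties" key; merge them into the flat dict and drop the raw field.
    if props.get("properties"):
        try:
            extra = json.loads(props["properties"])
            props.update(extra)
            del props["properties"]
        except json.JSONDecodeError:
            pass  # leave the raw string in place if it is not valid JSON
    return props
```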
@@ -1908,3 +2005,134 @@ class KuzuAdapter(GraphDBInterface):
        time_ids_list = [item[0] for item in time_nodes]

        return ", ".join(f"'{uid}'" for uid in time_ids_list)

    async def get_triplets_batch(self, offset: int, limit: int) -> list[dict[str, Any]]:
        """
        Retrieve a batch of triplets (start_node, relationship, end_node) from the graph.

        Parameters:
        -----------

        - offset (int): Number of triplets to skip before returning results.
        - limit (int): Maximum number of triplets to return.

        Returns:
        --------

        - list[dict[str, Any]]: A list of triplets, where each triplet is a dictionary
          with keys: 'start_node', 'relationship_properties', 'end_node'.

        Raises:
        -------

        - ValueError: If offset or limit are negative.
        - Exception: Re-raises any exceptions from query execution.
        """
        if offset < 0:
            raise ValueError(f"Offset must be non-negative, got {offset}")
        if limit < 0:
            raise ValueError(f"Limit must be non-negative, got {limit}")

        query = """
            MATCH (start_node:Node)-[relationship:EDGE]->(end_node:Node)
            RETURN {
                start_node: {
                    id: start_node.id,
                    name: start_node.name,
                    type: start_node.type,
                    properties: start_node.properties
                },
                relationship_properties: {
                    relationship_name: relationship.relationship_name,
                    properties: relationship.properties
                },
                end_node: {
                    id: end_node.id,
                    name: end_node.name,
                    type: end_node.type,
                    properties: end_node.properties
                }
            } AS triplet
            SKIP $offset LIMIT $limit
        """

        try:
            results = await self.query(query, {"offset": offset, "limit": limit})
        except Exception as e:
            logger.error(f"Failed to execute triplet query: {str(e)}")
            logger.error(f"Query: {query}")
            logger.error(f"Parameters: offset={offset}, limit={limit}")
            raise

        triplets = []
        for idx, row in enumerate(results):
            try:
                if not row or len(row) == 0:
                    logger.warning(f"Skipping empty row at index {idx} in triplet batch")
                    continue

                if not isinstance(row[0], dict):
                    logger.warning(
                        f"Skipping invalid row at index {idx}: expected dict, got {type(row[0])}"
                    )
                    continue

                triplet = row[0]

                if "start_node" not in triplet:
                    logger.warning(f"Skipping triplet at index {idx}: missing 'start_node' key")
                    continue

                if not isinstance(triplet["start_node"], dict):
                    logger.warning(f"Skipping triplet at index {idx}: 'start_node' is not a dict")
                    continue

                triplet["start_node"] = self._parse_node_properties(triplet["start_node"].copy())

                if "relationship_properties" not in triplet:
                    logger.warning(
                        f"Skipping triplet at index {idx}: missing 'relationship_properties' key"
                    )
                    continue

                if not isinstance(triplet["relationship_properties"], dict):
                    logger.warning(
                        f"Skipping triplet at index {idx}: 'relationship_properties' is not a dict"
                    )
                    continue

                rel_props = triplet["relationship_properties"].copy()
                relationship_name = rel_props.get("relationship_name") or ""

                if rel_props.get("properties"):
                    try:
                        parsed_props = json.loads(rel_props["properties"])
                        if isinstance(parsed_props, dict):
                            rel_props.update(parsed_props)
                            del rel_props["properties"]
                        else:
                            logger.warning(
                                f"Parsed relationship properties is not a dict for triplet at index {idx}"
                            )
                    except (json.JSONDecodeError, TypeError) as e:
                        logger.warning(
                            f"Failed to parse relationship properties JSON for triplet at index {idx}: {e}"
                        )

                rel_props["relationship_name"] = relationship_name
                triplet["relationship_properties"] = rel_props

                if "end_node" not in triplet:
                    logger.warning(f"Skipping triplet at index {idx}: missing 'end_node' key")
                    continue

                if not isinstance(triplet["end_node"], dict):
                    logger.warning(f"Skipping triplet at index {idx}: 'end_node' is not a dict")
                    continue

                triplet["end_node"] = self._parse_node_properties(triplet["end_node"].copy())

                triplets.append(triplet)

            except Exception as e:
                logger.error(f"Error processing triplet at index {idx}: {e}", exc_info=True)
                continue

        return triplets
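Most of `get_triplets_batch` is defensive validation of row shape; the skip-don't-raise pattern reduces to the following sketch (illustrative, omitting the logging and property parsing):

```python
def extract_valid_triplets(rows):
    # Mirror the batch loop above: each row is expected to be a one-element
    # list holding a dict with start_node / relationship_properties / end_node
    # sub-dicts; anything else is skipped rather than raising.
    required = ("start_node", "relationship_properties", "end_node")
    triplets = []
    for row in rows:
        if not row or not isinstance(row[0], dict):
            continue
        triplet = row[0]
        if all(isinstance(triplet.get(key), dict) for key in required):
            triplets.append(triplet)
    return triplets
```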
@@ -8,7 +8,7 @@ from neo4j import AsyncSession
from neo4j import AsyncGraphDatabase
from neo4j.exceptions import Neo4jError
from contextlib import asynccontextmanager
-from typing import Optional, Any, List, Dict, Type, Tuple
+from typing import Optional, Any, List, Dict, Type, Tuple, Coroutine

from cognee.infrastructure.engine import DataPoint
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
|
@ -964,6 +964,63 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
logger.error(f"Error during graph data retrieval: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_id_filtered_graph_data(self, target_ids: list[str]):
|
||||
"""
|
||||
Retrieve graph data filtered by specific node IDs, including their direct neighbors
|
||||
and only edges where one endpoint matches those IDs.
|
||||
|
||||
This version uses a single Cypher query for efficiency.
|
||||
"""
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
if not target_ids:
|
||||
logger.warning("No target IDs provided for ID-filtered graph retrieval.")
|
||||
return [], []
|
||||
|
||||
query = """
|
||||
MATCH ()-[r]-()
|
||||
WHERE startNode(r).id IN $target_ids
|
||||
OR endNode(r).id IN $target_ids
|
||||
WITH DISTINCT r, startNode(r) AS a, endNode(r) AS b
|
||||
RETURN
|
||||
properties(a) AS n_properties,
|
||||
properties(b) AS m_properties,
|
||||
type(r) AS type,
|
||||
properties(r) AS properties
|
||||
"""
|
||||
|
||||
result = await self.query(query, {"target_ids": target_ids})
|
||||
|
||||
nodes_dict = {}
|
||||
edges = []
|
||||
|
||||
for record in result:
|
||||
n_props = record["n_properties"]
|
||||
m_props = record["m_properties"]
|
||||
r_props = record["properties"]
|
||||
r_type = record["type"]
|
||||
|
||||
nodes_dict[n_props["id"]] = (n_props["id"], n_props)
|
||||
nodes_dict[m_props["id"]] = (m_props["id"], m_props)
|
||||
|
||||
source_id = r_props.get("source_node_id", n_props["id"])
|
||||
target_id = r_props.get("target_node_id", m_props["id"])
|
||||
edges.append((source_id, target_id, r_type, r_props))
|
||||
|
||||
retrieval_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s"
|
||||
)
|
||||
|
||||
return list(nodes_dict.values()), edges
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_nodeset_subgraph(
|
||||
self, node_type: Type[Any], node_name: List[str]
|
||||
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
|
||||
|
|
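Because the single Cypher query above returns each node once per incident relationship, the adapter deduplicates nodes by keying a dict on the node id. The core of that step in isolation (a sketch with illustrative names):

```python
def dedupe_nodes(records):
    # Collapse node records that appear once per incident edge into unique
    # (id, properties) pairs, keyed by id as in the retrieval loop above.
    nodes = {}
    for props in records:
        nodes[props["id"]] = (props["id"], props)
    return list(nodes.values())
```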
@@ -1470,3 +1527,25 @@ class Neo4jAdapter(GraphDBInterface):
        time_ids_list = [item["id"] for item in time_nodes if "id" in item]

        return ", ".join(f"'{uid}'" for uid in time_ids_list)

    async def get_triplets_batch(self, offset: int, limit: int) -> list[dict[str, Any]]:
        """
        Retrieve a batch of triplets (start_node, relationship, end_node) from the graph.

        Parameters:
        -----------

        - offset (int): Number of triplets to skip before returning results.
        - limit (int): Maximum number of triplets to return.

        Returns:
        --------

        - list[dict[str, Any]]: A list of triplets.
        """
        query = f"""
            MATCH (start_node:`{BASE_LABEL}`)-[relationship]->(end_node:`{BASE_LABEL}`)
            RETURN start_node, properties(relationship) AS relationship_properties, end_node
            SKIP $offset LIMIT $limit
        """
        results = await self.query(query, {"offset": offset, "limit": limit})

        return results
53 cognee/memify_pipelines/create_triplet_embeddings.py (new file)

@@ -0,0 +1,53 @@
from typing import Any

from cognee import memify
from cognee.context_global_variables import (
    set_database_global_context_variables,
)
from cognee.exceptions import CogneeValidationError
from cognee.modules.data.methods import get_authorized_existing_datasets
from cognee.shared.logging_utils import get_logger
from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.users.models import User
from cognee.tasks.memify.get_triplet_datapoints import get_triplet_datapoints
from cognee.tasks.storage import index_data_points

logger = get_logger("create_triplet_embeddings")


async def create_triplet_embeddings(
    user: User,
    dataset: str = "main_dataset",
    run_in_background: bool = False,
    triplets_batch_size: int = 100,
) -> dict[str, Any]:
    dataset_to_write = await get_authorized_existing_datasets(
        user=user, datasets=[dataset], permission_type="write"
    )

    if not dataset_to_write:
        raise CogneeValidationError(
            message=f"User does not have write access to dataset: {dataset}",
            log=False,
        )

    await set_database_global_context_variables(
        dataset_to_write[0].id, dataset_to_write[0].owner_id
    )

    extraction_tasks = [Task(get_triplet_datapoints, triplets_batch_size=triplets_batch_size)]

    enrichment_tasks = [
        Task(index_data_points, task_config={"batch_size": triplets_batch_size}),
    ]

    result = await memify(
        extraction_tasks=extraction_tasks,
        enrichment_tasks=enrichment_tasks,
        dataset=dataset_to_write[0].id,
        data=[{}],
        user=user,
        run_in_background=run_in_background,
    )

    return result
9 cognee/modules/engine/models/Triplet.py (new file)

@@ -0,0 +1,9 @@
from cognee.infrastructure.engine import DataPoint


class Triplet(DataPoint):
    text: str
    from_node_id: str
    to_node_id: str

    metadata: dict = {"index_fields": ["text"]}
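Since `metadata` marks `text` as the index field, the embedding for a `Triplet` is computed over its textual rendering. A hypothetical example of composing such a text (the real pipeline derives it from graph datapoints; this exact format is an assumption, not the project's):

```python
# Illustrative only: a possible rendering of a triplet's indexed `text` field
# from its endpoint names and relationship name.
def triplet_text(from_name: str, relationship: str, to_name: str) -> str:
    return f"{from_name} {relationship} {to_name}"
```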
@@ -7,3 +7,4 @@ from .ColumnValue import ColumnValue
from .Timestamp import Timestamp
from .Interval import Interval
from .Event import Event
from .Triplet import Triplet
@@ -56,6 +56,68 @@ class CogneeGraph(CogneeAbstractGraph):
    def get_edges(self) -> List[Edge]:
        return self.edges

    async def _get_nodeset_subgraph(
        self,
        adapter,
        node_type,
        node_name,
    ):
        """Retrieve subgraph based on node type and name."""
        logger.info("Retrieving graph filtered by node type and node name (NodeSet).")
        nodes_data, edges_data = await adapter.get_nodeset_subgraph(
            node_type=node_type, node_name=node_name
        )
        if not nodes_data or not edges_data:
            raise EntityNotFoundError(
                message="Nodeset does not exist, or empty nodeset projected from the database."
            )
        return nodes_data, edges_data

    async def _get_full_or_id_filtered_graph(
        self,
        adapter,
        relevant_ids_to_filter,
    ):
        """Retrieve full or ID-filtered graph with fallback."""
        if relevant_ids_to_filter is None:
            logger.info("Retrieving full graph.")
            nodes_data, edges_data = await adapter.get_graph_data()
            if not nodes_data or not edges_data:
                raise EntityNotFoundError(message="Empty graph projected from the database.")
            return nodes_data, edges_data

        get_graph_data_fn = getattr(adapter, "get_id_filtered_graph_data", adapter.get_graph_data)
        if getattr(adapter.__class__, "get_id_filtered_graph_data", None):
            logger.info("Retrieving ID-filtered graph from database.")
            nodes_data, edges_data = await get_graph_data_fn(target_ids=relevant_ids_to_filter)
        else:
            logger.info("Retrieving full graph from database.")
            nodes_data, edges_data = await get_graph_data_fn()
        if hasattr(adapter, "get_id_filtered_graph_data") and (not nodes_data or not edges_data):
            logger.warning(
                "Id filtered graph returned empty, falling back to full graph retrieval."
            )
            logger.info("Retrieving full graph")
            nodes_data, edges_data = await adapter.get_graph_data()

        if not nodes_data or not edges_data:
            raise EntityNotFoundError("Empty graph projected from the database.")
        return nodes_data, edges_data
async def _get_filtered_graph(
|
||||
self,
|
||||
adapter,
|
||||
memory_fragment_filter,
|
||||
):
|
||||
"""Retrieve graph filtered by attributes."""
|
||||
logger.info("Retrieving graph filtered by memory fragment")
|
||||
nodes_data, edges_data = await adapter.get_filtered_graph_data(
|
||||
attribute_filters=memory_fragment_filter
|
||||
)
|
||||
if not nodes_data or not edges_data:
|
||||
raise EntityNotFoundError(message="Empty filtered graph projected from the database.")
|
||||
return nodes_data, edges_data
|
||||
|
||||
async def project_graph_from_db(
|
||||
self,
|
||||
adapter: Union[GraphDBInterface],
|
||||
|
|
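The capability check in `_get_full_or_id_filtered_graph` is duck typing: prefer `get_id_filtered_graph_data` when the adapter class defines it, otherwise fall back to projecting the full graph. A minimal sketch with illustrative dummy adapters (the class names and canned data are not cognee's real adapters):

```python
import asyncio


class FullOnlyAdapter:
    """Adapter that can only project the whole graph."""

    async def get_graph_data(self):
        return (["n1", "n2"], ["e1"])


class FilteringAdapter(FullOnlyAdapter):
    """Adapter that additionally supports ID-filtered projection."""

    async def get_id_filtered_graph_data(self, target_ids):
        return ([n for n in ["n1", "n2"] if n in target_ids], ["e1"])


async def project(adapter, target_ids):
    # Prefer the ID-filtered projection when the adapter implements it,
    # otherwise fall back to the full-graph projection.
    fn = getattr(adapter, "get_id_filtered_graph_data", adapter.get_graph_data)
    if getattr(adapter.__class__, "get_id_filtered_graph_data", None):
        return await fn(target_ids=target_ids)
    return await fn()


print(asyncio.run(project(FullOnlyAdapter(), ["n1"])))   # (['n1', 'n2'], ['e1'])
print(asyncio.run(project(FilteringAdapter(), ["n1"])))  # (['n1'], ['e1'])
```

Checking the class (rather than the instance) mirrors the diff's `getattr(adapter.__class__, ...)` guard, which avoids being fooled by instance attributes.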
@@ -67,40 +129,39 @@ class CogneeGraph(CogneeAbstractGraph):
        memory_fragment_filter=[],
        node_type: Optional[Type] = None,
        node_name: Optional[List[str]] = None,
        relevant_ids_to_filter: Optional[List[str]] = None,
        triplet_distance_penalty: float = 3.5,
    ) -> None:
        if node_dimension < 1 or edge_dimension < 1:
            raise InvalidDimensionsError()
        try:
            if node_type is not None and node_name not in [None, [], ""]:
                nodes_data, edges_data = await self._get_nodeset_subgraph(
                    adapter, node_type, node_name
                )
            elif len(memory_fragment_filter) == 0:
                nodes_data, edges_data = await self._get_full_or_id_filtered_graph(
                    adapter, relevant_ids_to_filter
                )
            else:
                nodes_data, edges_data = await self._get_filtered_graph(
                    adapter, memory_fragment_filter
                )

            import time

            start_time = time.time()

            # Determine projection strategy
            if node_type is not None and node_name not in [None, [], ""]:
                nodes_data, edges_data = await adapter.get_nodeset_subgraph(
                    node_type=node_type, node_name=node_name
                )
                if not nodes_data or not edges_data:
                    raise EntityNotFoundError(
                        message="Nodeset does not exist, or empty nodeset projected from the database."
                    )
            elif len(memory_fragment_filter) == 0:
                nodes_data, edges_data = await adapter.get_graph_data()
                if not nodes_data or not edges_data:
                    raise EntityNotFoundError(message="Empty graph projected from the database.")
            else:
                nodes_data, edges_data = await adapter.get_filtered_graph_data(
                    attribute_filters=memory_fragment_filter
                )
                if not nodes_data or not edges_data:
                    raise EntityNotFoundError(
                        message="Empty filtered graph projected from the database."
                    )

            # Process nodes
            for node_id, properties in nodes_data:
                node_attributes = {key: properties.get(key) for key in node_properties_to_project}
                self.add_node(Node(str(node_id), node_attributes, dimension=node_dimension))
                self.add_node(
                    Node(
                        str(node_id),
                        node_attributes,
                        dimension=node_dimension,
                        node_penalty=triplet_distance_penalty,
                    )
                )

            # Process edges
            for source_id, target_id, relationship_type, properties in edges_data:
@@ -118,6 +179,7 @@ class CogneeGraph(CogneeAbstractGraph):
                attributes=edge_attributes,
                directed=directed,
                dimension=edge_dimension,
                edge_penalty=triplet_distance_penalty,
            )
            self.add_edge(edge)
@@ -20,13 +20,17 @@ class Node:
    status: np.ndarray

    def __init__(
        self, node_id: str, attributes: Optional[Dict[str, Any]] = None, dimension: int = 1
        self,
        node_id: str,
        attributes: Optional[Dict[str, Any]] = None,
        dimension: int = 1,
        node_penalty: float = 3.5,
    ):
        if dimension <= 0:
            raise InvalidDimensionsError()
        self.id = node_id
        self.attributes = attributes if attributes is not None else {}
        self.attributes["vector_distance"] = float("inf")
        self.attributes["vector_distance"] = node_penalty
        self.skeleton_neighbours = []
        self.skeleton_edges = []
        self.status = np.ones(dimension, dtype=int)
@@ -105,13 +109,14 @@ class Edge:
        attributes: Optional[Dict[str, Any]] = None,
        directed: bool = True,
        dimension: int = 1,
        edge_penalty: float = 3.5,
    ):
        if dimension <= 0:
            raise InvalidDimensionsError()
        self.node1 = node1
        self.node2 = node2
        self.attributes = attributes if attributes is not None else {}
        self.attributes["vector_distance"] = float("inf")
        self.attributes["vector_distance"] = edge_penalty
        self.directed = directed
        self.status = np.ones(dimension, dtype=int)
@@ -39,6 +39,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
        node_type: Optional[Type] = None,
        node_name: Optional[List[str]] = None,
        save_interaction: bool = False,
        wide_search_top_k: Optional[int] = 100,
        triplet_distance_penalty: Optional[float] = 3.5,
    ):
        super().__init__(
            user_prompt_path=user_prompt_path,

@@ -48,6 +50,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
            node_name=node_name,
            save_interaction=save_interaction,
            system_prompt=system_prompt,
            wide_search_top_k=wide_search_top_k,
            triplet_distance_penalty=triplet_distance_penalty,
        )

    async def get_completion(
@@ -65,6 +65,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
        node_type: Optional[Type] = None,
        node_name: Optional[List[str]] = None,
        save_interaction: bool = False,
        wide_search_top_k: Optional[int] = 100,
        triplet_distance_penalty: Optional[float] = 3.5,
    ):
        super().__init__(
            user_prompt_path=user_prompt_path,

@@ -74,6 +76,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
            node_type=node_type,
            node_name=node_name,
            save_interaction=save_interaction,
            wide_search_top_k=wide_search_top_k,
            triplet_distance_penalty=triplet_distance_penalty,
        )
        self.validation_system_prompt_path = validation_system_prompt_path
        self.validation_user_prompt_path = validation_user_prompt_path
@@ -47,6 +47,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
        node_type: Optional[Type] = None,
        node_name: Optional[List[str]] = None,
        save_interaction: bool = False,
        wide_search_top_k: Optional[int] = 100,
        triplet_distance_penalty: Optional[float] = 3.5,
    ):
        """Initialize retriever with prompt paths and search parameters."""
        self.save_interaction = save_interaction

@@ -54,8 +56,10 @@ class GraphCompletionRetriever(BaseGraphRetriever):
        self.system_prompt_path = system_prompt_path
        self.system_prompt = system_prompt
        self.top_k = top_k if top_k is not None else 5
        self.wide_search_top_k = wide_search_top_k
        self.node_type = node_type
        self.node_name = node_name
        self.triplet_distance_penalty = triplet_distance_penalty

    async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
        """

@@ -105,6 +109,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
            collections=vector_index_collections or None,
            node_type=self.node_type,
            node_name=self.node_name,
            wide_search_top_k=self.wide_search_top_k,
            triplet_distance_penalty=self.triplet_distance_penalty,
        )

        return found_triplets

@@ -141,6 +147,10 @@ class GraphCompletionRetriever(BaseGraphRetriever):

        return triplets

    async def convert_retrieved_objects_to_context(self, triplets: List[Edge]):
        context = await self.resolve_edges_to_text(triplets)
        return context

    async def get_completion(
        self,
        query: str,
@@ -26,6 +26,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
        node_type: Optional[Type] = None,
        node_name: Optional[List[str]] = None,
        save_interaction: bool = False,
        wide_search_top_k: Optional[int] = 100,
        triplet_distance_penalty: Optional[float] = 3.5,
    ):
        """Initialize retriever with default prompt paths and search parameters."""
        super().__init__(

@@ -36,6 +38,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
            node_name=node_name,
            save_interaction=save_interaction,
            system_prompt=system_prompt,
            wide_search_top_k=wide_search_top_k,
            triplet_distance_penalty=triplet_distance_penalty,
        )
        self.summarize_prompt_path = summarize_prompt_path
@@ -47,6 +47,8 @@ class TemporalRetriever(GraphCompletionRetriever):
        top_k: Optional[int] = 5,
        node_type: Optional[Type] = None,
        node_name: Optional[List[str]] = None,
        wide_search_top_k: Optional[int] = 100,
        triplet_distance_penalty: Optional[float] = 3.5,
    ):
        super().__init__(
            user_prompt_path=user_prompt_path,

@@ -54,6 +56,8 @@ class TemporalRetriever(GraphCompletionRetriever):
            top_k=top_k,
            node_type=node_type,
            node_name=node_name,
            wide_search_top_k=wide_search_top_k,
            triplet_distance_penalty=triplet_distance_penalty,
        )
        self.user_prompt_path = user_prompt_path
        self.system_prompt_path = system_prompt_path
182
cognee/modules/retrieval/triplet_retriever.py
Normal file

@@ -0,0 +1,182 @@
import asyncio
from typing import Any, Optional, Type, List

from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
from cognee.modules.retrieval.utils.session_cache import (
    save_conversation_history,
    get_conversation_history,
)
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from cognee.context_global_variables import session_user
from cognee.infrastructure.databases.cache.config import CacheConfig

logger = get_logger("TripletRetriever")


class TripletRetriever(BaseRetriever):
    """
    Retriever for handling LLM-based completion searches using triplets.

    Public methods:
    - get_context(query: str) -> str
    - get_completion(query: str, context: Optional[Any] = None) -> Any
    """

    def __init__(
        self,
        user_prompt_path: str = "context_for_question.txt",
        system_prompt_path: str = "answer_simple_question.txt",
        system_prompt: Optional[str] = None,
        top_k: Optional[int] = 5,
    ):
        """Initialize retriever with optional custom prompt paths."""
        self.user_prompt_path = user_prompt_path
        self.system_prompt_path = system_prompt_path
        self.top_k = top_k if top_k is not None else 1
        self.system_prompt = system_prompt

    async def get_context(self, query: str) -> str:
        """
        Retrieves relevant triplets as context.

        Fetches triplets based on a query from a vector engine and combines their text.
        Returns an empty string if no triplets are found. Raises NoDataError if the
        collection is not found.

        Parameters:
        -----------

            - query (str): The query string used to search for relevant triplets.

        Returns:
        --------

            - str: A string containing the combined text of the retrieved triplets, or an
              empty string if none are found.
        """
        vector_engine = get_vector_engine()

        try:
            if not await vector_engine.has_collection(collection_name="Triplet_text"):
                logger.error("Triplet_text collection not found")
                raise NoDataError(
                    "In order to use TRIPLET_COMPLETION, first use the create_triplet_embeddings memify pipeline."
                )

            found_triplets = await vector_engine.search("Triplet_text", query, limit=self.top_k)

            if len(found_triplets) == 0:
                return ""

            triplets_payload = [found_triplet.payload["text"] for found_triplet in found_triplets]
            combined_context = "\n".join(triplets_payload)
            return combined_context
        except CollectionNotFoundError as error:
            logger.error("Triplet_text collection not found")
            raise NoDataError("No data found in the system, please add data first.") from error

    async def get_completion(
        self,
        query: str,
        context: Optional[Any] = None,
        session_id: Optional[str] = None,
        response_model: Type = str,
    ) -> List[Any]:
        """
        Generates an LLM completion using the context.

        Retrieves context if not provided and generates a completion based on the query and
        context using an external completion generator.

        Parameters:
        -----------

            - query (str): The query string to be used for generating a completion.
            - context (Optional[Any]): Optional pre-fetched context to use for generating the
              completion; if None, it retrieves the context for the query. (default None)
            - session_id (Optional[str]): Optional session identifier for caching. If None,
              defaults to 'default_session'. (default None)
            - response_model (Type): The Pydantic model type for structured output. (default str)

        Returns:
        --------

            - Any: The generated completion based on the provided query and context.
        """
        if context is None:
            context = await self.get_context(query)

        cache_config = CacheConfig()
        user = session_user.get()
        user_id = getattr(user, "id", None)
        session_save = user_id and cache_config.caching

        if session_save:
            completion = await self._get_completion_with_session(
                query=query,
                context=context,
                session_id=session_id,
                response_model=response_model,
            )
        else:
            completion = await self._get_completion_without_session(
                query=query,
                context=context,
                response_model=response_model,
            )

        return [completion]

    async def _get_completion_with_session(
        self,
        query: str,
        context: str,
        session_id: Optional[str],
        response_model: Type,
    ) -> Any:
        """Generate completion with session history and caching."""
        conversation_history = await get_conversation_history(session_id=session_id)

        context_summary, completion = await asyncio.gather(
            summarize_text(context),
            generate_completion(
                query=query,
                context=context,
                user_prompt_path=self.user_prompt_path,
                system_prompt_path=self.system_prompt_path,
                system_prompt=self.system_prompt,
                conversation_history=conversation_history,
                response_model=response_model,
            ),
        )

        await save_conversation_history(
            query=query,
            context_summary=context_summary,
            answer=completion,
            session_id=session_id,
        )

        return completion

    async def _get_completion_without_session(
        self,
        query: str,
        context: str,
        response_model: Type,
    ) -> Any:
        """Generate completion without session history."""
        completion = await generate_completion(
            query=query,
            context=context,
            user_prompt_path=self.user_prompt_path,
            system_prompt_path=self.system_prompt_path,
            system_prompt=self.system_prompt,
            response_model=response_model,
        )

        return completion
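The context-assembly step of `TripletRetriever.get_context` is just a newline join over the `text` payloads of the scored search results. A standalone sketch with a stubbed vector engine (the stub class and canned hits are illustrative, not cognee's real engine):

```python
import asyncio


class ScoredResult:
    """Minimal stand-in for a vector search hit carrying a payload dict."""

    def __init__(self, payload):
        self.payload = payload


class StubVectorEngine:
    """Stands in for cognee's vector engine; returns canned triplet texts."""

    async def search(self, collection, query, limit):
        hits = [
            ScoredResult({"text": "Alice -- works_at --> Acme"}),
            ScoredResult({"text": "Acme -- located_in --> Berlin"}),
        ]
        return hits[:limit]


async def get_context(engine, query, top_k=5):
    # Mirrors the retriever: search the triplet collection, join hit texts.
    found = await engine.search("Triplet_text", query, limit=top_k)
    if not found:
        return ""
    return "\n".join(hit.payload["text"] for hit in found)


print(asyncio.run(get_context(StubVectorEngine(), "where is Alice employed?", top_k=1)))
```

Because each stored `Triplet` verbalizes one edge, the joined context reads as a compact list of facts rather than raw chunks, which is the point of the new `TRIPLET_COMPLETION` mode.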
@@ -58,6 +58,8 @@ async def get_memory_fragment(
    properties_to_project: Optional[List[str]] = None,
    node_type: Optional[Type] = None,
    node_name: Optional[List[str]] = None,
    relevant_ids_to_filter: Optional[List[str]] = None,
    triplet_distance_penalty: Optional[float] = 3.5,
) -> CogneeGraph:
    """Creates and initializes a CogneeGraph memory fragment with optional property projections."""
    if properties_to_project is None:

@@ -74,6 +76,8 @@ async def get_memory_fragment(
            edge_properties_to_project=["relationship_name", "edge_text"],
            node_type=node_type,
            node_name=node_name,
            relevant_ids_to_filter=relevant_ids_to_filter,
            triplet_distance_penalty=triplet_distance_penalty,
        )

    except EntityNotFoundError:
@@ -95,6 +99,8 @@ async def brute_force_triplet_search(
    memory_fragment: Optional[CogneeGraph] = None,
    node_type: Optional[Type] = None,
    node_name: Optional[List[str]] = None,
    wide_search_top_k: Optional[int] = 100,
    triplet_distance_penalty: Optional[float] = 3.5,
) -> List[Edge]:
    """
    Performs a brute force search to retrieve the top triplets from the graph.

@@ -107,6 +113,8 @@ async def brute_force_triplet_search(
        memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse.
        node_type: node type to filter
        node_name: node name to filter
        wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections
        triplet_distance_penalty (Optional[float]): Default distance penalty in graph projection

    Returns:
        list: The top triplet results.
@@ -116,10 +124,10 @@ async def brute_force_triplet_search(
    if top_k <= 0:
        raise ValueError("top_k must be a positive integer.")

    if memory_fragment is None:
        memory_fragment = await get_memory_fragment(
            properties_to_project, node_type=node_type, node_name=node_name
        )
    # Setting wide search limit based on the parameters
    non_global_search = node_name is None

    wide_search_limit = wide_search_top_k if non_global_search else None

    if collections is None:
        collections = [

@@ -140,7 +148,7 @@ async def brute_force_triplet_search(
    async def search_in_collection(collection_name: str):
        try:
            return await vector_engine.search(
                collection_name=collection_name, query_vector=query_vector, limit=None
                collection_name=collection_name, query_vector=query_vector, limit=wide_search_limit
            )
        except CollectionNotFoundError:
            return []
@@ -156,15 +164,38 @@ async def brute_force_triplet_search(
        return []

    # Final statistics
    projection_time = time.time() - start_time
    vector_collection_search_time = time.time() - start_time
    logger.info(
        f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {projection_time:.2f}s"
        f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {vector_collection_search_time:.2f}s"
    )

    node_distances = {collection: result for collection, result in zip(collections, results)}

    edge_distances = node_distances.get("EdgeType_relationship_name", None)

    if wide_search_limit is not None:
        relevant_ids_to_filter = list(
            {
                str(getattr(scored_node, "id"))
                for collection_name, score_collection in node_distances.items()
                if collection_name != "EdgeType_relationship_name"
                and isinstance(score_collection, (list, tuple))
                for scored_node in score_collection
                if getattr(scored_node, "id", None)
            }
        )
    else:
        relevant_ids_to_filter = None

    if memory_fragment is None:
        memory_fragment = await get_memory_fragment(
            properties_to_project=properties_to_project,
            node_type=node_type,
            node_name=node_name,
            relevant_ids_to_filter=relevant_ids_to_filter,
            triplet_distance_penalty=triplet_distance_penalty,
        )

    await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
    await memory_fragment.map_vector_distances_to_graph_edges(
        vector_engine=vector_engine, query_vector=query_vector, edge_distances=edge_distances
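The relevant-ID collection above deduplicates node IDs across all node collections (skipping the edge collection and entries without an `id`) so that graph projection can be restricted to nodes the vector search actually touched. The same comprehension, runnable in isolation (`ScoredPoint` and the sample collection names are stand-ins for the vector engine's real result type):

```python
class ScoredPoint:
    """Stand-in for a vector search result carrying a node ID."""

    def __init__(self, id):
        self.id = id


node_distances = {
    "Entity_name": [ScoredPoint("a"), ScoredPoint("b")],
    "TextSummary_text": [ScoredPoint("b"), ScoredPoint(None)],
    "EdgeType_relationship_name": [ScoredPoint("ignored-edge-hit")],
}

# Unique node IDs across node collections; edge hits and missing IDs excluded.
relevant_ids = sorted(
    {
        str(point.id)
        for collection, points in node_distances.items()
        if collection != "EdgeType_relationship_name" and isinstance(points, (list, tuple))
        for point in points
        if getattr(point, "id", None)
    }
)
print(relevant_ids)  # ['a', 'b']
```

Using a set comprehension makes the dedup implicit; the real code wraps it in `list(...)` without sorting, since order does not matter for the downstream ID filter.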
@@ -2,6 +2,7 @@ import os
from typing import Callable, List, Optional, Type

from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.retrieval.triplet_retriever import TripletRetriever
from cognee.modules.search.types import SearchType
from cognee.modules.search.operations import select_search_type
from cognee.modules.search.exceptions import UnsupportedSearchTypeError
@@ -37,6 +38,8 @@ async def get_search_type_tools(
    node_name: Optional[List[str]] = None,
    save_interaction: bool = False,
    last_k: Optional[int] = None,
    wide_search_top_k: Optional[int] = 100,
    triplet_distance_penalty: Optional[float] = 3.5,
) -> list:
    search_tasks: dict[SearchType, List[Callable]] = {
        SearchType.SUMMARIES: [

@@ -59,6 +62,18 @@ async def get_search_type_tools(
                system_prompt=system_prompt,
            ).get_context,
        ],
        SearchType.TRIPLET_COMPLETION: [
            TripletRetriever(
                system_prompt_path=system_prompt_path,
                top_k=top_k,
                system_prompt=system_prompt,
            ).get_completion,
            TripletRetriever(
                system_prompt_path=system_prompt_path,
                top_k=top_k,
                system_prompt=system_prompt,
            ).get_context,
        ],
        SearchType.GRAPH_COMPLETION: [
            GraphCompletionRetriever(
                system_prompt_path=system_prompt_path,

@@ -67,6 +82,8 @@ async def get_search_type_tools(
                node_name=node_name,
                save_interaction=save_interaction,
                system_prompt=system_prompt,
                wide_search_top_k=wide_search_top_k,
                triplet_distance_penalty=triplet_distance_penalty,
            ).get_completion,
            GraphCompletionRetriever(
                system_prompt_path=system_prompt_path,

@@ -75,6 +92,8 @@ async def get_search_type_tools(
                node_name=node_name,
                save_interaction=save_interaction,
                system_prompt=system_prompt,
                wide_search_top_k=wide_search_top_k,
                triplet_distance_penalty=triplet_distance_penalty,
            ).get_context,
        ],
        SearchType.GRAPH_COMPLETION_COT: [

@@ -85,6 +104,8 @@ async def get_search_type_tools(
                node_name=node_name,
                save_interaction=save_interaction,
                system_prompt=system_prompt,
                wide_search_top_k=wide_search_top_k,
                triplet_distance_penalty=triplet_distance_penalty,
            ).get_completion,
            GraphCompletionCotRetriever(
                system_prompt_path=system_prompt_path,

@@ -93,6 +114,8 @@ async def get_search_type_tools(
                node_name=node_name,
                save_interaction=save_interaction,
                system_prompt=system_prompt,
                wide_search_top_k=wide_search_top_k,
                triplet_distance_penalty=triplet_distance_penalty,
            ).get_context,
        ],
        SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: [

@@ -103,6 +126,8 @@ async def get_search_type_tools(
                node_name=node_name,
                save_interaction=save_interaction,
                system_prompt=system_prompt,
                wide_search_top_k=wide_search_top_k,
                triplet_distance_penalty=triplet_distance_penalty,
            ).get_completion,
            GraphCompletionContextExtensionRetriever(
                system_prompt_path=system_prompt_path,

@@ -111,6 +136,8 @@ async def get_search_type_tools(
                node_name=node_name,
                save_interaction=save_interaction,
                system_prompt=system_prompt,
                wide_search_top_k=wide_search_top_k,
                triplet_distance_penalty=triplet_distance_penalty,
            ).get_context,
        ],
        SearchType.GRAPH_SUMMARY_COMPLETION: [

@@ -121,6 +148,8 @@ async def get_search_type_tools(
                node_name=node_name,
                save_interaction=save_interaction,
                system_prompt=system_prompt,
                wide_search_top_k=wide_search_top_k,
                triplet_distance_penalty=triplet_distance_penalty,
            ).get_completion,
            GraphSummaryCompletionRetriever(
                system_prompt_path=system_prompt_path,

@@ -129,6 +158,8 @@ async def get_search_type_tools(
                node_name=node_name,
                save_interaction=save_interaction,
                system_prompt=system_prompt,
                wide_search_top_k=wide_search_top_k,
                triplet_distance_penalty=triplet_distance_penalty,
            ).get_context,
        ],
        SearchType.CODE: [

@@ -145,8 +176,16 @@ async def get_search_type_tools(
        ],
        SearchType.FEEDBACK: [UserQAFeedback(last_k=last_k).add_feedback],
        SearchType.TEMPORAL: [
            TemporalRetriever(top_k=top_k).get_completion,
            TemporalRetriever(top_k=top_k).get_context,
            TemporalRetriever(
                top_k=top_k,
                wide_search_top_k=wide_search_top_k,
                triplet_distance_penalty=triplet_distance_penalty,
            ).get_completion,
            TemporalRetriever(
                top_k=top_k,
                wide_search_top_k=wide_search_top_k,
                triplet_distance_penalty=triplet_distance_penalty,
            ).get_context,
        ],
        SearchType.CHUNKS_LEXICAL: (
            lambda _r=JaccardChunksRetriever(top_k=top_k): [
@@ -24,6 +24,8 @@ async def no_access_control_search(
    last_k: Optional[int] = None,
    only_context: bool = False,
    session_id: Optional[str] = None,
    wide_search_top_k: Optional[int] = 100,
    triplet_distance_penalty: Optional[float] = 3.5,
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
    search_tools = await get_search_type_tools(
        query_type=query_type,

@@ -35,6 +37,8 @@ async def no_access_control_search(
        node_name=node_name,
        save_interaction=save_interaction,
        last_k=last_k,
        wide_search_top_k=wide_search_top_k,
        triplet_distance_penalty=triplet_distance_penalty,
    )
    graph_engine = await get_graph_engine()
    is_empty = await graph_engine.is_empty()
@@ -47,6 +47,8 @@ async def search(
    only_context: bool = False,
    use_combined_context: bool = False,
    session_id: Optional[str] = None,
    wide_search_top_k: Optional[int] = 100,
    triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[CombinedSearchResult, List[SearchResult]]:
    """

@@ -90,6 +92,8 @@ async def search(
            only_context=only_context,
            use_combined_context=use_combined_context,
            session_id=session_id,
            wide_search_top_k=wide_search_top_k,
            triplet_distance_penalty=triplet_distance_penalty,
        )
    else:
        search_results = [

@@ -105,6 +109,8 @@ async def search(
                last_k=last_k,
                only_context=only_context,
                session_id=session_id,
                wide_search_top_k=wide_search_top_k,
                triplet_distance_penalty=triplet_distance_penalty,
            )
        ]

@@ -219,6 +225,8 @@ async def authorized_search(
    only_context: bool = False,
    use_combined_context: bool = False,
    session_id: Optional[str] = None,
    wide_search_top_k: Optional[int] = 100,
    triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[
    Tuple[Any, Union[List[Edge], str], List[Dataset]],
    List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],

@@ -246,6 +254,8 @@ async def authorized_search(
            last_k=last_k,
            only_context=True,
            session_id=session_id,
            wide_search_top_k=wide_search_top_k,
            triplet_distance_penalty=triplet_distance_penalty,
        )

        context = {}

@@ -267,6 +277,8 @@ async def authorized_search(
            node_name=node_name,
            save_interaction=save_interaction,
            last_k=last_k,
            wide_search_top_k=wide_search_top_k,
            triplet_distance_penalty=triplet_distance_penalty,
        )
        search_tools = specific_search_tools
        if len(search_tools) == 2:

@@ -306,6 +318,7 @@ async def authorized_search(
        last_k=last_k,
        only_context=only_context,
        session_id=session_id,
        wide_search_top_k=wide_search_top_k,
    )

    return search_results
@@ -325,6 +338,8 @@ async def search_in_datasets_context(
    only_context: bool = False,
    context: Optional[Any] = None,
    session_id: Optional[str] = None,
    wide_search_top_k: Optional[int] = 100,
    triplet_distance_penalty: Optional[float] = 3.5,
) -> List[Tuple[Any, Union[str, List[Edge]], List[Dataset]]]:
    """
    Searches all provided datasets and handles setting up the appropriate database context based on permissions.

@@ -345,6 +360,8 @@ async def search_in_datasets_context(
        only_context: bool = False,
        context: Optional[Any] = None,
        session_id: Optional[str] = None,
        wide_search_top_k: Optional[int] = 100,
        triplet_distance_penalty: Optional[float] = 3.5,
    ) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
        # Set the database configuration in the async context for each dataset the user has access to
        await set_database_global_context_variables(dataset.id, dataset.owner_id)

@@ -378,6 +395,8 @@ async def search_in_datasets_context(
            node_name=node_name,
            save_interaction=save_interaction,
            last_k=last_k,
            wide_search_top_k=wide_search_top_k,
            triplet_distance_penalty=triplet_distance_penalty,
        )
        search_tools = specific_search_tools
        if len(search_tools) == 2:

@@ -413,6 +432,8 @@ async def search_in_datasets_context(
                only_context=only_context,
                context=context,
                session_id=session_id,
                wide_search_top_k=wide_search_top_k,
                triplet_distance_penalty=triplet_distance_penalty,
            )
        )
@@ -5,6 +5,7 @@ class SearchType(Enum):
    SUMMARIES = "SUMMARIES"
    CHUNKS = "CHUNKS"
    RAG_COMPLETION = "RAG_COMPLETION"
    TRIPLET_COMPLETION = "TRIPLET_COMPLETION"
    GRAPH_COMPLETION = "GRAPH_COMPLETION"
    GRAPH_SUMMARY_COMPLETION = "GRAPH_SUMMARY_COMPLETION"
    CODE = "CODE"
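For reference, the extended enum can be reproduced standalone (a minimal copy of the members in this hunk; the real class lives in cognee and may carry more members):

```python
from enum import Enum

class SearchType(Enum):
    # Standalone copy of the members shown in the diff above.
    SUMMARIES = "SUMMARIES"
    CHUNKS = "CHUNKS"
    RAG_COMPLETION = "RAG_COMPLETION"
    TRIPLET_COMPLETION = "TRIPLET_COMPLETION"
    GRAPH_COMPLETION = "GRAPH_COMPLETION"
    GRAPH_SUMMARY_COMPLETION = "GRAPH_SUMMARY_COMPLETION"
    CODE = "CODE"

# The new member round-trips by value like any other:
print(SearchType("TRIPLET_COMPLETION").name)  # → TRIPLET_COMPLETION
```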
283
cognee/tasks/memify/get_triplet_datapoints.py
Normal file
@@ -0,0 +1,283 @@
from typing import AsyncGenerator, Dict, Any, List, Optional
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from cognee.shared.logging_utils import get_logger
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.infrastructure.engine import DataPoint
from cognee.modules.engine.models import Triplet
from cognee.tasks.storage import index_data_points

logger = get_logger("get_triplet_datapoints")
def _build_datapoint_type_index_mapping() -> Dict[str, List[str]]:
    """
    Build a mapping of DataPoint type names to their index_fields.

    Returns:
    --------
    - Dict[str, List[str]]: Mapping of type name to list of index field names
    """
    logger.debug("Building DataPoint type to index_fields mapping")
    subclasses = get_all_subclasses(DataPoint)
    datapoint_type_index_property = {}

    for subclass in subclasses:
        if "metadata" in subclass.model_fields:
            metadata_field = subclass.model_fields["metadata"]
            default = getattr(metadata_field, "default", None)
            if isinstance(default, dict):
                index_fields = default.get("index_fields", [])
                if index_fields:
                    datapoint_type_index_property[subclass.__name__] = index_fields
                    logger.debug(
                        f"Registered {subclass.__name__} with index_fields: {index_fields}"
                    )

    logger.info(
        f"Found {len(datapoint_type_index_property)} DataPoint types with index_fields: "
        f"{list(datapoint_type_index_property.keys())}"
    )
    return datapoint_type_index_property
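The mapping logic above can be illustrated with plain stand-in classes (`_FieldInfo` and the two sample classes below are hypothetical shims, not the real Pydantic `FieldInfo` or cognee DataPoint models):

```python
from typing import Dict, List

class _FieldInfo:
    """Minimal stand-in for a Pydantic FieldInfo carrying a default value."""
    def __init__(self, default):
        self.default = default

class Entity:
    model_fields = {"metadata": _FieldInfo({"index_fields": ["name"]})}

class DocumentChunk:
    model_fields = {"metadata": _FieldInfo({"index_fields": ["text"]})}

def build_mapping(subclasses) -> Dict[str, List[str]]:
    # Same traversal as _build_datapoint_type_index_mapping, minus logging.
    mapping: Dict[str, List[str]] = {}
    for subclass in subclasses:
        if "metadata" in subclass.model_fields:
            default = getattr(subclass.model_fields["metadata"], "default", None)
            if isinstance(default, dict):
                index_fields = default.get("index_fields", [])
                if index_fields:
                    mapping[subclass.__name__] = index_fields
    return mapping

print(build_mapping([Entity, DocumentChunk]))  # → {'Entity': ['name'], 'DocumentChunk': ['text']}
```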
def _extract_embeddable_text(node_or_edge: Dict[str, Any], index_fields: List[str]) -> str:
    """
    Extract and concatenate embeddable properties from a node or edge dictionary.

    Parameters:
    -----------
    - node_or_edge (Dict[str, Any]): Dictionary containing node or edge properties.
    - index_fields (List[str]): List of field names to extract and concatenate.

    Returns:
    --------
    - str: Concatenated string of all embeddable property values, or empty string if none found.
    """
    if not node_or_edge or not index_fields:
        return ""

    embeddable_values = []
    for field_name in index_fields:
        field_value = node_or_edge.get(field_name)
        if field_value is not None:
            field_value = str(field_value).strip()

            if field_value:
                embeddable_values.append(field_value)

    return " ".join(embeddable_values) if embeddable_values else ""
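The extraction behavior can be checked with a standalone copy of the same logic (the sample node dict is illustrative):

```python
from typing import Any, Dict, List

def extract_embeddable_text(node_or_edge: Dict[str, Any], index_fields: List[str]) -> str:
    # Standalone reproduction of the _extract_embeddable_text logic above.
    if not node_or_edge or not index_fields:
        return ""
    values = []
    for field_name in index_fields:
        value = node_or_edge.get(field_name)
        if value is not None:
            value = str(value).strip()
            if value:
                values.append(value)
    return " ".join(values)

node = {"name": "Germany", "description": "  a country  ", "id": "n1"}
print(extract_embeddable_text(node, ["name", "description"]))  # → Germany a country
print(repr(extract_embeddable_text(node, ["missing"])))        # → ''
```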
def _extract_relationship_text(
    relationship: Dict[str, Any], datapoint_type_index_property: Dict[str, List[str]]
) -> str:
    """
    Extract relationship text from edge properties.

    Parameters:
    -----------
    - relationship (Dict[str, Any]): Dictionary containing relationship properties
    - datapoint_type_index_property (Dict[str, List[str]]): Mapping of type to index fields

    Returns:
    --------
    - str: Extracted relationship text or empty string
    """
    if not relationship:
        return ""

    edge_text = relationship.get("edge_text")
    if edge_text and isinstance(edge_text, str) and edge_text.strip():
        return edge_text.strip()

    # Fall back to extracting from EdgeType index_fields
    edge_type_index_fields = datapoint_type_index_property.get("EdgeType", [])
    return _extract_embeddable_text(relationship, edge_type_index_fields)
def _process_single_triplet(
    triplet_datapoint: Dict[str, Any],
    datapoint_type_index_property: Dict[str, List[str]],
    offset: int,
    idx: int,
) -> tuple[Optional[Triplet], Optional[str]]:
    """
    Process a single triplet and create a Triplet object.

    Parameters:
    -----------
    - triplet_datapoint (Dict[str, Any]): Raw triplet data from graph engine
    - datapoint_type_index_property (Dict[str, List[str]]): Type to index fields mapping
    - offset (int): Current batch offset
    - idx (int): Index within current batch

    Returns:
    --------
    - tuple[Optional[Triplet], Optional[str]]: (Triplet object, error message if skipped)
    """
    start_node = triplet_datapoint.get("start_node", {})
    end_node = triplet_datapoint.get("end_node", {})
    relationship = triplet_datapoint.get("relationship_properties", {})

    start_node_type = start_node.get("type")
    end_node_type = end_node.get("type")

    start_index_fields = datapoint_type_index_property.get(start_node_type, [])
    end_index_fields = datapoint_type_index_property.get(end_node_type, [])

    if not start_index_fields:
        logger.debug(
            f"No index_fields found for start_node type '{start_node_type}' in triplet {offset + idx}"
        )
    if not end_index_fields:
        logger.debug(
            f"No index_fields found for end_node type '{end_node_type}' in triplet {offset + idx}"
        )

    start_node_id = start_node.get("id", "")
    end_node_id = end_node.get("id", "")

    if not start_node_id or not end_node_id:
        return None, (
            f"Skipping triplet at offset {offset + idx}: missing node IDs "
            f"(start: {start_node_id}, end: {end_node_id})"
        )

    relationship_text = _extract_relationship_text(relationship, datapoint_type_index_property)
    start_node_text = _extract_embeddable_text(start_node, start_index_fields)
    end_node_text = _extract_embeddable_text(end_node, end_index_fields)

    if not start_node_text and not end_node_text and not relationship_text:
        return None, (
            f"Skipping triplet at offset {offset + idx}: empty embeddable text "
            f"(start_node_id: {start_node_id}, end_node_id: {end_node_id})"
        )

    embeddable_text = f"{start_node_text}->{relationship_text}->{end_node_text}".strip()

    triplet_obj = Triplet(from_node_id=start_node_id, to_node_id=end_node_id, text=embeddable_text)

    return triplet_obj, None
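End to end, the per-triplet processing amounts to the following sketch (a simplified stand-in: the real function also logs, tracks offsets, and returns a Triplet model rather than a plain string):

```python
from typing import Dict, List, Optional

def process_triplet(triplet: Dict, type_index: Dict[str, List[str]]) -> Optional[str]:
    # Simplified reproduction of the logic above: validate node IDs, gather
    # embeddable text per node type, and join the parts with "->" separators.
    start = triplet.get("start_node", {})
    end = triplet.get("end_node", {})
    rel = triplet.get("relationship_properties", {})

    if not start.get("id") or not end.get("id"):
        return None  # skipped: missing node IDs

    def text_of(node: Dict) -> str:
        fields = type_index.get(node.get("type"), [])
        return " ".join(str(node[f]).strip() for f in fields if node.get(f))

    start_text, end_text = text_of(start), text_of(end)
    rel_text = (rel.get("edge_text") or "").strip()

    if not (start_text or rel_text or end_text):
        return None  # skipped: nothing embeddable

    return f"{start_text}->{rel_text}->{end_text}"

triplet = {
    "start_node": {"id": "n1", "type": "Entity", "name": "Alice"},
    "end_node": {"id": "n2", "type": "Entity", "name": "Bob"},
    "relationship_properties": {"edge_text": "knows"},
}
print(process_triplet(triplet, {"Entity": ["name"]}))  # → Alice->knows->Bob
```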
async def get_triplet_datapoints(
    data,
    triplets_batch_size: int = 100,
) -> AsyncGenerator[Triplet, None]:
    """
    Async generator that yields Triplet datapoints with embeddable text extracted.

    Each triplet includes:
    - The original triplet structure (start_node, relationship_properties, end_node)
    - Embeddable text extracted from each element based on index_fields

    Parameters:
    -----------
    - triplets_batch_size (int): Number of triplets to retrieve per batch. Default is 100.

    Yields:
    -------
    - Triplet: A triplet datapoint enriched with embeddable text.
    """
    if not data or data == [{}]:
        logger.info("Fetching graph data for current user")

    logger.info(f"Starting triplet datapoints extraction with batch size: {triplets_batch_size}")

    graph_engine = await get_graph_engine()
    graph_engine_type = type(graph_engine).__name__
    logger.debug(f"Using graph engine: {graph_engine_type}")

    if not hasattr(graph_engine, "get_triplets_batch"):
        error_msg = f"Graph adapter {graph_engine_type} does not support get_triplets_batch method"
        logger.error(error_msg)
        raise NotImplementedError(error_msg)

    datapoint_type_index_property = _build_datapoint_type_index_mapping()

    offset = 0
    total_triplets_processed = 0
    batch_number = 0

    while True:
        try:
            batch_number += 1
            logger.debug(
                f"Fetching triplet batch {batch_number} (offset: {offset}, limit: {triplets_batch_size})"
            )

            triplets_batch = await graph_engine.get_triplets_batch(
                offset=offset, limit=triplets_batch_size
            )

            if not triplets_batch:
                logger.info(f"No more triplets found at offset {offset}. Processing complete.")
                break

            logger.debug(f"Retrieved {len(triplets_batch)} triplets in batch {batch_number}")

            triplet_datapoints = []
            skipped_count = 0

            for idx, triplet_datapoint in enumerate(triplets_batch):
                try:
                    triplet_obj, error_msg = _process_single_triplet(
                        triplet_datapoint, datapoint_type_index_property, offset, idx
                    )

                    if error_msg:
                        logger.warning(error_msg)
                        skipped_count += 1
                        continue

                    if triplet_obj:
                        triplet_datapoints.append(triplet_obj)
                        yield triplet_obj

                except Exception as e:
                    logger.warning(
                        f"Error processing triplet at offset {offset + idx}: {e}. "
                        f"Skipping this triplet and continuing."
                    )
                    skipped_count += 1
                    continue

            if skipped_count > 0:
                logger.warning(
                    f"Skipped {skipped_count} out of {len(triplets_batch)} triplets in batch {batch_number}"
                )

            if not triplet_datapoints:
                logger.warning(
                    f"No valid triplet datapoints in batch {batch_number} after processing"
                )
                offset += len(triplets_batch)
                if len(triplets_batch) < triplets_batch_size:
                    break
                continue

            total_triplets_processed += len(triplet_datapoints)
            logger.info(
                f"Batch {batch_number} complete: processed {len(triplet_datapoints)} triplets "
                f"(total processed: {total_triplets_processed})"
            )

            offset += len(triplets_batch)
            if len(triplets_batch) < triplets_batch_size:
                logger.info(
                    f"Last batch retrieved (got {len(triplets_batch)} < {triplets_batch_size} triplets). "
                    f"Processing complete."
                )
                break

        except Exception as e:
            logger.error(
                f"Error retrieving triplet batch {batch_number} at offset {offset}: {e}",
                exc_info=True,
            )
            raise

    logger.info(
        f"Triplet datapoints extraction complete. "
        f"Processed {total_triplets_processed} triplets across {batch_number} batch(es)."
    )
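The offset/limit pagination loop above can be exercised standalone with a fake engine (`FakeGraphEngine` and `paginate` are hypothetical stand-ins for the graph adapter and the generator's loop, with the per-triplet processing omitted):

```python
import asyncio

class FakeGraphEngine:
    """Hypothetical stand-in for a graph adapter exposing get_triplets_batch."""
    def __init__(self, rows):
        self.rows = rows

    async def get_triplets_batch(self, offset: int, limit: int):
        return self.rows[offset : offset + limit]

async def paginate(engine, batch_size: int):
    # Same loop shape as get_triplet_datapoints: fetch pages until an
    # empty page (or a short one) signals the end of the data.
    offset, collected = 0, []
    while True:
        batch = await engine.get_triplets_batch(offset=offset, limit=batch_size)
        if not batch:
            break
        collected.extend(batch)
        offset += len(batch)
        if len(batch) < batch_size:
            break
    return collected

print(asyncio.run(paginate(FakeGraphEngine(list(range(7))), batch_size=3)))  # → [0, 1, 2, 3, 4, 5, 6]
```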
84
cognee/tests/integration/retrieval/test_triplet_retriever.py
Normal file
@@ -0,0 +1,84 @@
import os
import pytest
import pathlib
import pytest_asyncio
import cognee

from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.triplet_retriever import TripletRetriever
from cognee.modules.engine.models import Triplet


@pytest_asyncio.fixture
async def setup_test_environment_with_triplets():
    """Set up a clean test environment with triplets."""
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_triplet_retriever_context_simple")
    data_directory_path = str(base_dir / ".data_storage/test_triplet_retriever_context_simple")

    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    triplet1 = Triplet(
        from_node_id="node1",
        to_node_id="node2",
        text="Alice knows Bob",
    )
    triplet2 = Triplet(
        from_node_id="node2",
        to_node_id="node3",
        text="Bob works at Tech Corp",
    )

    triplets = [triplet1, triplet2]
    await add_data_points(triplets)

    yield

    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass


@pytest_asyncio.fixture
async def setup_test_environment_empty():
    """Set up a clean test environment without triplets."""
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(
        base_dir / ".cognee_system/test_triplet_retriever_context_empty_collection"
    )
    data_directory_path = str(
        base_dir / ".data_storage/test_triplet_retriever_context_empty_collection"
    )

    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    yield

    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        pass


@pytest.mark.asyncio
async def test_triplet_retriever_context_simple(setup_test_environment_with_triplets):
    """Integration test: verify TripletRetriever can retrieve triplet context."""
    retriever = TripletRetriever(top_k=5)

    context = await retriever.get_context("Alice")

    assert "Alice knows Bob" in context, "Failed to get Alice triplet"
@@ -0,0 +1,69 @@
import os
import pathlib
import pytest
import pytest_asyncio
from unittest.mock import AsyncMock, patch

import cognee
from cognee.tasks.memify.get_triplet_datapoints import get_triplet_datapoints
from cognee.modules.engine.models import Triplet


@pytest_asyncio.fixture
async def setup_test_environment():
    """Set up a clean test environment with a simple graph."""
    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    data_directory_path = str(base_dir / ".data_storage/test_get_triplet_datapoints_integration")
    cognee_directory_path = str(base_dir / ".cognee_system/test_get_triplet_datapoints_integration")

    cognee.config.data_root_directory(data_directory_path)
    cognee.config.system_root_directory(cognee_directory_path)

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    dataset_name = "test_triplets"

    text = "Volkswagen is a German car manufacturer from Wolfsburg. They produce different models such as Golf, Polo and Touareg."
    await cognee.add(text, dataset_name)
    await cognee.cognify([dataset_name])

    yield dataset_name

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)


@pytest.mark.asyncio
async def test_get_triplet_datapoints_integration(setup_test_environment):
    """Integration test: verify get_triplet_datapoints works with real graph data."""

    from cognee.infrastructure.databases.graph import get_graph_engine

    graph_engine = await get_graph_engine()

    if not hasattr(graph_engine, "get_triplets_batch"):
        pytest.skip("Graph engine does not support get_triplets_batch")

    triplets = []
    with patch(
        "cognee.tasks.memify.get_triplet_datapoints.index_data_points", new_callable=AsyncMock
    ):
        async for triplet in get_triplet_datapoints([{}], triplets_batch_size=10):
            triplets.append(triplet)

    nodes, edges = await graph_engine.get_graph_data()

    if len(edges) > 0 and len(triplets) == 0:
        test_triplets = await graph_engine.get_triplets_batch(offset=0, limit=10)
        if len(test_triplets) == 0:
            pytest.fail(
                f"Edges exist in graph ({len(edges)} edges) but get_triplets_batch found none. "
                f"This indicates the query pattern may not match the graph structure."
            )

    for triplet in triplets:
        assert isinstance(triplet, Triplet), "Each item should be a Triplet instance"
        assert triplet.from_node_id, "Triplet should have from_node_id"
        assert triplet.to_node_id, "Triplet should have to_node_id"
        assert triplet.text, "Triplet should have embeddable text"
@@ -8,10 +8,10 @@ Tests all retrievers that save conversation history to Redis cache:
4. GRAPH_COMPLETION_CONTEXT_EXTENSION
5. GRAPH_SUMMARY_COMPLETION
6. TEMPORAL
7. TRIPLET_COMPLETION
"""

import os
import shutil
import cognee
import pathlib
@@ -63,6 +63,10 @@ async def main():

    user = await get_default_user()

    from cognee.memify_pipelines.create_triplet_embeddings import create_triplet_embeddings

    await create_triplet_embeddings(user=user, dataset=dataset_name)

    cache_engine = get_cache_engine()
    assert cache_engine is not None, "Cache engine should be available for testing"
@@ -216,6 +220,24 @@ async def main():
    ]
    assert len(our_qa_temporal) == 1, "Should find Temporal question in history"

    session_id_triplet = "test_session_triplet"

    result_triplet = await cognee.search(
        query_type=SearchType.TRIPLET_COMPLETION,
        query_text="What companies are mentioned?",
        session_id=session_id_triplet,
    )

    assert isinstance(result_triplet, list) and len(result_triplet) > 0, (
        f"TRIPLET_COMPLETION should return non-empty list, got: {result_triplet!r}"
    )

    history_triplet = await cache_engine.get_latest_qa(str(user.id), session_id_triplet, last_n=10)
    our_qa_triplet = [
        h for h in history_triplet if h["question"] == "What companies are mentioned?"
    ]
    assert len(our_qa_triplet) == 1, "Should find Triplet question in history"

    from cognee.modules.retrieval.utils.session_cache import (
        get_conversation_history,
    )
@@ -2,6 +2,7 @@ import pathlib
import os
import cognee
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever

@@ -12,8 +13,10 @@ from cognee.modules.retrieval.graph_completion_cot_retriever import GraphComplet
from cognee.modules.retrieval.graph_summary_completion_retriever import (
    GraphSummaryCompletionRetriever,
)
from cognee.modules.retrieval.triplet_retriever import TripletRetriever
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
from cognee.modules.users.methods import get_default_user
from collections import Counter

logger = get_logger()
@@ -37,6 +40,23 @@ async def main():

    await cognee.cognify([dataset_name])

    user = await get_default_user()
    from cognee.memify_pipelines.create_triplet_embeddings import create_triplet_embeddings

    await create_triplet_embeddings(user=user, dataset=dataset_name, triplets_batch_size=5)

    graph_engine = await get_graph_engine()
    nodes, edges = await graph_engine.get_graph_data()

    vector_engine = get_vector_engine()
    collection = await vector_engine.search(
        query_text="Test", limit=None, collection_name="Triplet_text"
    )

    assert len(edges) == len(collection), (
        f"Expected {len(edges)} edges but got {len(collection)} in Triplet_text collection"
    )

    context_gk = await GraphCompletionRetriever().get_context(
        query="Next to which country is Germany located?"
    )
@@ -49,6 +69,9 @@ async def main():
    context_gk_sum = await GraphSummaryCompletionRetriever().get_context(
        query="Next to which country is Germany located?"
    )
    context_triplet = await TripletRetriever().get_context(
        query="Next to which country is Germany located?"
    )

    for name, context in [
        ("GraphCompletionRetriever", context_gk),
@@ -65,6 +88,13 @@ async def main():
        f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}"
    )

    assert isinstance(context_triplet, str), "TripletRetriever: Context should be a string"
    assert len(context_triplet) > 0, "TripletRetriever: Context should not be empty"
    lower_triplet = context_triplet.lower()
    assert "germany" in lower_triplet or "netherlands" in lower_triplet, (
        f"TripletRetriever: Context did not contain 'germany' or 'netherlands'; got: {context_triplet!r}"
    )

    triplets_gk = await GraphCompletionRetriever().get_triplets(
        query="Next to which country is Germany located?"
    )
@@ -129,6 +159,11 @@ async def main():
        query_text="Next to which country is Germany located?",
        save_interaction=True,
    )
    completion_triplet = await cognee.search(
        query_type=SearchType.TRIPLET_COMPLETION,
        query_text="Next to which country is Germany located?",
        save_interaction=True,
    )

    await cognee.search(
        query_type=SearchType.FEEDBACK,
@@ -141,6 +176,7 @@ async def main():
        ("GRAPH_COMPLETION_COT", completion_cot),
        ("GRAPH_COMPLETION_CONTEXT_EXTENSION", completion_ext),
        ("GRAPH_SUMMARY_COMPLETION", completion_sum),
        ("TRIPLET_COMPLETION", completion_triplet),
    ]:
        assert isinstance(search_results, list), f"{name}: should return a list"
        assert len(search_results) == 1, (
@@ -168,7 +204,7 @@ async def main():

    # Assert there are exactly 4 CogneeUserInteraction nodes.
    assert type_counts.get("CogneeUserInteraction", 0) == 4, (
-        f"Expected exactly four DCogneeUserInteraction nodes, but found {type_counts.get('CogneeUserInteraction', 0)}"
+        f"Expected exactly four CogneeUserInteraction nodes, but found {type_counts.get('CogneeUserInteraction', 0)}"
    )

    # Assert there are exactly two CogneeUserFeedback nodes.
@@ -9,7 +9,7 @@ def test_node_initialization():
    """Test that a Node is initialized correctly."""
    node = Node("node1", {"attr1": "value1"}, dimension=2)
    assert node.id == "node1"
-    assert node.attributes == {"attr1": "value1", "vector_distance": np.inf}
+    assert node.attributes == {"attr1": "value1", "vector_distance": 3.5}
    assert len(node.status) == 2
    assert np.all(node.status == 1)
@@ -96,7 +96,7 @@ def test_edge_initialization():
    edge = Edge(node1, node2, {"weight": 10}, directed=False, dimension=2)
    assert edge.node1 == node1
    assert edge.node2 == node2
-    assert edge.attributes == {"vector_distance": np.inf, "weight": 10}
+    assert edge.attributes == {"vector_distance": 3.5, "weight": 10}
    assert edge.directed is False
    assert len(edge.status) == 2
    assert np.all(edge.status == 1)
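The two test updates above reflect the new finite default `vector_distance` of 3.5 (matching the `triplet_distance_penalty` default) in place of `np.inf`. A minimal stand-in illustrates the initialization (`Node` below is a hypothetical sketch, not cognee's actual class):

```python
DEFAULT_VECTOR_DISTANCE = 3.5  # assumed default, mirroring triplet_distance_penalty

class Node:
    """Hypothetical sketch: attributes get a finite default vector_distance."""
    def __init__(self, node_id, attributes=None):
        self.id = node_id
        self.attributes = {"vector_distance": DEFAULT_VECTOR_DISTANCE, **(attributes or {})}

node = Node("node1", {"attr1": "value1"})
print(node.attributes)  # → {'vector_distance': 3.5, 'attr1': 'value1'}
```

An explicitly supplied `vector_distance` still overrides the default, since the caller's attributes are merged last.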
@@ -1,4 +1,5 @@
import pytest
from unittest.mock import AsyncMock

from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
@@ -11,6 +12,30 @@ def setup_graph():
    return CogneeGraph()


@pytest.fixture
def mock_adapter():
    """Fixture to create a mock adapter for database operations."""
    adapter = AsyncMock()
    return adapter


@pytest.fixture
def mock_vector_engine():
    """Fixture to create a mock vector engine."""
    engine = AsyncMock()
    engine.search = AsyncMock()
    return engine


class MockScoredResult:
    """Mock class for vector search results."""

    def __init__(self, id, score, payload=None):
        self.id = id
        self.score = score
        self.payload = payload or {}


def test_add_node_success(setup_graph):
    """Test successful addition of a node."""
    graph = setup_graph
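The AsyncMock-based fixtures follow the standard unittest.mock pattern: each mocked method is awaitable, returns a canned value, and records its calls. Standalone, the shape looks like this (a generic sketch, not cognee-specific):

```python
import asyncio
from unittest.mock import AsyncMock

async def demo():
    # Generic shape of the mock_adapter fixture: an AsyncMock whose methods
    # can be awaited and later asserted on.
    adapter = AsyncMock()
    adapter.get_graph_data = AsyncMock(return_value=([("1", {"name": "Node1"})], []))
    nodes, edges = await adapter.get_graph_data()
    adapter.get_graph_data.assert_called_once()
    return nodes, edges

nodes, edges = asyncio.run(demo())
print(nodes)  # → [('1', {'name': 'Node1'})]
```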
@@ -73,3 +98,433 @@ def test_get_edges_nonexistent_node(setup_graph):
    graph = setup_graph
    with pytest.raises(EntityNotFoundError, match="Node with id nonexistent does not exist."):
        graph.get_edges_from_node("nonexistent")


@pytest.mark.asyncio
async def test_project_graph_from_db_full_graph(setup_graph, mock_adapter):
    """Test projecting a full graph from database."""
    graph = setup_graph

    nodes_data = [
        ("1", {"name": "Node1", "description": "First node"}),
        ("2", {"name": "Node2", "description": "Second node"}),
    ]
    edges_data = [
        ("1", "2", "CONNECTS_TO", {"relationship_name": "connects"}),
    ]

    mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data))

    await graph.project_graph_from_db(
        adapter=mock_adapter,
        node_properties_to_project=["name", "description"],
        edge_properties_to_project=["relationship_name"],
    )

    assert len(graph.nodes) == 2
    assert len(graph.edges) == 1
    assert graph.get_node("1") is not None
    assert graph.get_node("2") is not None
    assert graph.edges[0].node1.id == "1"
    assert graph.edges[0].node2.id == "2"


@pytest.mark.asyncio
async def test_project_graph_from_db_id_filtered(setup_graph, mock_adapter):
    """Test projecting an ID-filtered graph from database."""
    graph = setup_graph

    nodes_data = [
        ("1", {"name": "Node1"}),
        ("2", {"name": "Node2"}),
    ]
    edges_data = [
        ("1", "2", "CONNECTS_TO", {"relationship_name": "connects"}),
    ]

    mock_adapter.get_id_filtered_graph_data = AsyncMock(return_value=(nodes_data, edges_data))

    await graph.project_graph_from_db(
        adapter=mock_adapter,
        node_properties_to_project=["name"],
        edge_properties_to_project=["relationship_name"],
        relevant_ids_to_filter=["1", "2"],
    )

    assert len(graph.nodes) == 2
    assert len(graph.edges) == 1
    mock_adapter.get_id_filtered_graph_data.assert_called_once()


@pytest.mark.asyncio
async def test_project_graph_from_db_nodeset_subgraph(setup_graph, mock_adapter):
    """Test projecting a nodeset subgraph filtered by node type and name."""
    graph = setup_graph

    nodes_data = [
        ("1", {"name": "Alice", "type": "Person"}),
        ("2", {"name": "Bob", "type": "Person"}),
    ]
    edges_data = [
        ("1", "2", "KNOWS", {"relationship_name": "knows"}),
    ]

    mock_adapter.get_nodeset_subgraph = AsyncMock(return_value=(nodes_data, edges_data))

    await graph.project_graph_from_db(
        adapter=mock_adapter,
        node_properties_to_project=["name", "type"],
        edge_properties_to_project=["relationship_name"],
        node_type="Person",
        node_name=["Alice"],
    )

    assert len(graph.nodes) == 2
    assert graph.get_node("1") is not None
    assert len(graph.edges) == 1
    mock_adapter.get_nodeset_subgraph.assert_called_once()


@pytest.mark.asyncio
async def test_project_graph_from_db_empty_graph(setup_graph, mock_adapter):
    """Test projecting empty graph raises EntityNotFoundError."""
    graph = setup_graph

    mock_adapter.get_graph_data = AsyncMock(return_value=([], []))

    with pytest.raises(EntityNotFoundError, match="Empty graph projected from the database."):
        await graph.project_graph_from_db(
            adapter=mock_adapter,
            node_properties_to_project=["name"],
            edge_properties_to_project=[],
        )


@pytest.mark.asyncio
async def test_project_graph_from_db_missing_nodes(setup_graph, mock_adapter):
    """Test that edges referencing missing nodes raise error."""
    graph = setup_graph

    nodes_data = [
        ("1", {"name": "Node1"}),
    ]
    edges_data = [
        ("1", "999", "CONNECTS_TO", {"relationship_name": "connects"}),
    ]

    mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data))

    with pytest.raises(EntityNotFoundError, match="Edge references nonexistent nodes"):
        await graph.project_graph_from_db(
            adapter=mock_adapter,
            node_properties_to_project=["name"],
            edge_properties_to_project=["relationship_name"],
        )


@pytest.mark.asyncio
async def test_map_vector_distances_to_graph_nodes(setup_graph):
    """Test mapping vector distances to graph nodes."""
    graph = setup_graph

    node1 = Node("1", {"name": "Node1"})
    node2 = Node("2", {"name": "Node2"})
    graph.add_node(node1)
    graph.add_node(node2)

    node_distances = {
        "Entity_name": [
            MockScoredResult("1", 0.95),
            MockScoredResult("2", 0.87),
        ]
    }

    await graph.map_vector_distances_to_graph_nodes(node_distances)

    assert graph.get_node("1").attributes.get("vector_distance") == 0.95
    assert graph.get_node("2").attributes.get("vector_distance") == 0.87


@pytest.mark.asyncio
async def test_map_vector_distances_partial_node_coverage(setup_graph):
    """Test mapping vector distances when only some nodes have results."""
    graph = setup_graph

    node1 = Node("1", {"name": "Node1"})
    node2 = Node("2", {"name": "Node2"})
    node3 = Node("3", {"name": "Node3"})
    graph.add_node(node1)
    graph.add_node(node2)
    graph.add_node(node3)

    node_distances = {
        "Entity_name": [
            MockScoredResult("1", 0.95),
            MockScoredResult("2", 0.87),
        ]
    }

    await graph.map_vector_distances_to_graph_nodes(node_distances)

    assert graph.get_node("1").attributes.get("vector_distance") == 0.95
    assert graph.get_node("2").attributes.get("vector_distance") == 0.87
    assert graph.get_node("3").attributes.get("vector_distance") == 3.5


@pytest.mark.asyncio
async def test_map_vector_distances_multiple_categories(setup_graph):
    """Test mapping vector distances from multiple collection categories."""
    graph = setup_graph

    # Create nodes
    node1 = Node("1")
    node2 = Node("2")
    node3 = Node("3")
    node4 = Node("4")
    graph.add_node(node1)
    graph.add_node(node2)
    graph.add_node(node3)
    graph.add_node(node4)

    node_distances = {
        "Entity_name": [
            MockScoredResult("1", 0.95),
            MockScoredResult("2", 0.87),
        ],
        "TextSummary_text": [
            MockScoredResult("3", 0.92),
        ],
    }

    await graph.map_vector_distances_to_graph_nodes(node_distances)

    assert graph.get_node("1").attributes.get("vector_distance") == 0.95
    assert graph.get_node("2").attributes.get("vector_distance") == 0.87
    assert graph.get_node("3").attributes.get("vector_distance") == 0.92
    assert graph.get_node("4").attributes.get("vector_distance") == 3.5


@pytest.mark.asyncio
async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph, mock_vector_engine):
    """Test mapping vector distances to edges when edge_distances provided."""
    graph = setup_graph

    node1 = Node("1")
    node2 = Node("2")
    graph.add_node(node1)
    graph.add_node(node2)

    edge = Edge(
        node1,
        node2,
        attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"},
    )
    graph.add_edge(edge)

    edge_distances = [
        MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
    ]

    await graph.map_vector_distances_to_graph_edges(
        vector_engine=mock_vector_engine,
        query_vector=[0.1, 0.2, 0.3],
        edge_distances=edge_distances,
    )

    assert graph.edges[0].attributes.get("vector_distance") == 0.92


@pytest.mark.asyncio
async def test_map_vector_distances_to_graph_edges_search(setup_graph, mock_vector_engine):
    """Test mapping edge distances when searching for them."""
    graph = setup_graph

    node1 = Node("1")
    node2 = Node("2")
    graph.add_node(node1)
    graph.add_node(node2)

    edge = Edge(
        node1,
        node2,
        attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"},
    )
    graph.add_edge(edge)

    mock_vector_engine.search.return_value = [
        MockScoredResult("e1", 0.88, payload={"text": "CONNECTS_TO"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
edge_distances=None,
|
||||
)
|
||||
|
||||
mock_vector_engine.search.assert_called_once()
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 0.88
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_partial_edge_coverage(setup_graph, mock_vector_engine):
|
||||
"""Test mapping edge distances when only some edges have results."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
node3 = Node("3")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
graph.add_node(node3)
|
||||
|
||||
edge1 = Edge(node1, node2, attributes={"edge_text": "CONNECTS_TO"})
|
||||
edge2 = Edge(node2, node3, attributes={"edge_text": "DEPENDS_ON"})
|
||||
graph.add_edge(edge1)
|
||||
graph.add_edge(edge2)
|
||||
|
||||
edge_distances = [
|
||||
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
edge_distances=edge_distances,
|
||||
)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 0.92
|
||||
assert graph.edges[1].attributes.get("vector_distance") == 3.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_edges_fallback_to_relationship_type(
|
||||
setup_graph, mock_vector_engine
|
||||
):
|
||||
"""Test that edge mapping falls back to relationship_type when edge_text is missing."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
|
||||
edge = Edge(
|
||||
node1,
|
||||
node2,
|
||||
attributes={"relationship_type": "KNOWS"},
|
||||
)
|
||||
graph.add_edge(edge)
|
||||
|
||||
edge_distances = [
|
||||
MockScoredResult("e1", 0.85, payload={"text": "KNOWS"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
edge_distances=edge_distances,
|
||||
)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 0.85
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_no_edge_matches(setup_graph, mock_vector_engine):
|
||||
"""Test edge mapping when no edges match the distance results."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
|
||||
edge = Edge(
|
||||
node1,
|
||||
node2,
|
||||
attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"},
|
||||
)
|
||||
graph.add_edge(edge)
|
||||
|
||||
edge_distances = [
|
||||
MockScoredResult("e1", 0.92, payload={"text": "SOME_OTHER_EDGE"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
edge_distances=edge_distances,
|
||||
)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 3.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_invalid_query_vector(setup_graph, mock_vector_engine):
|
||||
"""Test that invalid query vector raises error."""
|
||||
graph = setup_graph
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to generate query embedding"):
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[],
|
||||
edge_distances=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_top_triplet_importances(setup_graph):
|
||||
"""Test calculating top triplet importances by score."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
node3 = Node("3")
|
||||
node4 = Node("4")
|
||||
|
||||
node1.add_attribute("vector_distance", 0.9)
|
||||
node2.add_attribute("vector_distance", 0.8)
|
||||
node3.add_attribute("vector_distance", 0.7)
|
||||
node4.add_attribute("vector_distance", 0.6)
|
||||
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
graph.add_node(node3)
|
||||
graph.add_node(node4)
|
||||
|
||||
edge1 = Edge(node1, node2)
|
||||
edge2 = Edge(node2, node3)
|
||||
edge3 = Edge(node3, node4)
|
||||
|
||||
edge1.add_attribute("vector_distance", 0.85)
|
||||
edge2.add_attribute("vector_distance", 0.75)
|
||||
edge3.add_attribute("vector_distance", 0.65)
|
||||
|
||||
graph.add_edge(edge1)
|
||||
graph.add_edge(edge2)
|
||||
graph.add_edge(edge3)
|
||||
|
||||
top_triplets = await graph.calculate_top_triplet_importances(k=2)
|
||||
|
||||
assert len(top_triplets) == 2
|
||||
|
||||
assert top_triplets[0] == edge3
|
||||
assert top_triplets[1] == edge2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_top_triplet_importances_default_distances(setup_graph):
|
||||
"""Test calculating importances when nodes/edges have no vector distances."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
|
||||
edge = Edge(node1, node2)
|
||||
graph.add_edge(edge)
|
||||
|
||||
top_triplets = await graph.calculate_top_triplet_importances(k=1)
|
||||
|
||||
assert len(top_triplets) == 1
|
||||
assert top_triplets[0] == edge
|
||||
|
|
|
|||
|
|
@@ -0,0 +1,214 @@
import sys
import pytest
from unittest.mock import AsyncMock, patch

from cognee.tasks.memify.get_triplet_datapoints import get_triplet_datapoints
from cognee.modules.engine.models import Triplet
from cognee.modules.engine.models.Entity import Entity
from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.models.EdgeType import EdgeType


get_triplet_datapoints_module = sys.modules["cognee.tasks.memify.get_triplet_datapoints"]


@pytest.fixture
def mock_graph_engine():
    """Create a mock graph engine with get_triplets_batch method."""
    engine = AsyncMock()
    engine.get_triplets_batch = AsyncMock()
    return engine


@pytest.mark.asyncio
async def test_get_triplet_datapoints_success(mock_graph_engine):
    """Test successful extraction of triplet datapoints."""
    mock_triplets_batch = [
        {
            "start_node": {
                "id": "node1",
                "type": "Entity",
                "name": "Alice",
                "description": "A person",
            },
            "end_node": {
                "id": "node2",
                "type": "Entity",
                "name": "Bob",
                "description": "Another person",
            },
            "relationship_properties": {
                "relationship_name": "knows",
            },
        }
    ]

    mock_graph_engine.get_triplets_batch.return_value = mock_triplets_batch

    with (
        patch.object(
            get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
        ),
        patch.object(get_triplet_datapoints_module, "get_all_subclasses") as mock_get_subclasses,
    ):
        mock_get_subclasses.return_value = [Triplet, EdgeType, Entity]

        triplets = []
        async for triplet in get_triplet_datapoints([{}], triplets_batch_size=100):
            triplets.append(triplet)

    assert len(triplets) == 1
    assert isinstance(triplets[0], Triplet)
    assert triplets[0].from_node_id == "node1"
    assert triplets[0].to_node_id == "node2"
    assert "Alice" in triplets[0].text
    assert "knows" in triplets[0].text
    assert "Bob" in triplets[0].text


@pytest.mark.asyncio
async def test_get_triplet_datapoints_edge_text_priority_and_fallback(mock_graph_engine):
    """Test that edge_text is prioritized over relationship_name, and fallback works."""

    class MockEntity(DataPoint):
        name: str
        metadata: dict = {"index_fields": ["name"]}

    mock_triplets_batch = [
        {
            "start_node": {"id": "node1", "type": "Entity", "name": "Alice"},
            "end_node": {"id": "node2", "type": "Entity", "name": "Bob"},
            "relationship_properties": {
                "relationship_name": "knows",
                "edge_text": "has a close friendship with",
            },
        },
        {
            "start_node": {"id": "node3", "type": "Entity", "name": "Charlie"},
            "end_node": {"id": "node4", "type": "Entity", "name": "Diana"},
            "relationship_properties": {
                "relationship_name": "works_with",
            },
        },
    ]

    mock_graph_engine.get_triplets_batch.return_value = mock_triplets_batch

    with (
        patch.object(
            get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
        ),
        patch.object(get_triplet_datapoints_module, "get_all_subclasses") as mock_get_subclasses,
    ):
        mock_get_subclasses.return_value = [Triplet, EdgeType, MockEntity]

        triplets = []
        async for triplet in get_triplet_datapoints([{}], triplets_batch_size=100):
            triplets.append(triplet)

    assert len(triplets) == 2
    assert "has a close friendship with" in triplets[0].text
    assert "knows" not in triplets[0].text
    assert "works_with" in triplets[1].text


@pytest.mark.asyncio
async def test_get_triplet_datapoints_skips_missing_node_ids(mock_graph_engine):
    """Test that triplets with missing node IDs are skipped."""

    class MockEntity(DataPoint):
        name: str
        metadata: dict = {"index_fields": ["name"]}

    mock_triplets_batch = [
        {
            "start_node": {"id": "", "type": "Entity", "name": "Alice"},
            "end_node": {"id": "node2", "type": "Entity", "name": "Bob"},
            "relationship_properties": {"relationship_name": "knows"},
        },
        {
            "start_node": {"id": "node3", "type": "Entity", "name": "Charlie"},
            "end_node": {"id": "node4", "type": "Entity", "name": "Diana"},
            "relationship_properties": {"relationship_name": "works_with"},
        },
    ]

    mock_graph_engine.get_triplets_batch.return_value = mock_triplets_batch

    with (
        patch.object(
            get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
        ),
        patch.object(get_triplet_datapoints_module, "get_all_subclasses") as mock_get_subclasses,
    ):
        mock_get_subclasses.return_value = [Triplet, EdgeType, MockEntity]

        triplets = []
        async for triplet in get_triplet_datapoints([{}], triplets_batch_size=100):
            triplets.append(triplet)

    assert len(triplets) == 1
    assert triplets[0].from_node_id == "node3"


@pytest.mark.asyncio
async def test_get_triplet_datapoints_error_handling(mock_graph_engine):
    """Test that errors are handled correctly - invalid data is skipped, query errors propagate."""

    class MockEntity(DataPoint):
        name: str
        metadata: dict = {"index_fields": ["name"]}

    mock_triplets_batch = [
        {
            "start_node": {"id": "node1", "type": "Entity", "name": "Alice"},
            "end_node": {"id": "node2", "type": "Entity", "name": "Bob"},
            "relationship_properties": {"relationship_name": "knows"},
        },
        {
            "start_node": None,
            "end_node": {"id": "node4", "type": "Entity", "name": "Diana"},
            "relationship_properties": {"relationship_name": "works_with"},
        },
    ]

    mock_graph_engine.get_triplets_batch.return_value = mock_triplets_batch

    with (
        patch.object(
            get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
        ),
        patch.object(get_triplet_datapoints_module, "get_all_subclasses") as mock_get_subclasses,
    ):
        mock_get_subclasses.return_value = [Triplet, EdgeType, MockEntity]

        triplets = []
        async for triplet in get_triplet_datapoints([{}], triplets_batch_size=100):
            triplets.append(triplet)

    assert len(triplets) == 1
    assert triplets[0].from_node_id == "node1"

    mock_graph_engine.get_triplets_batch.side_effect = Exception("Database connection error")

    with patch.object(
        get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
    ):
        triplets = []
        with pytest.raises(Exception, match="Database connection error"):
            async for triplet in get_triplet_datapoints([{}], triplets_batch_size=100):
                triplets.append(triplet)


@pytest.mark.asyncio
async def test_get_triplet_datapoints_no_get_triplets_batch_method(mock_graph_engine):
    """Test that NotImplementedError is raised when graph engine lacks get_triplets_batch."""
    del mock_graph_engine.get_triplets_batch

    with patch.object(
        get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
    ):
        triplets = []
        with pytest.raises(NotImplementedError, match="does not support get_triplets_batch"):
            async for triplet in get_triplet_datapoints([{}], triplets_batch_size=100):
                triplets.append(triplet)
@@ -0,0 +1,582 @@
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from cognee.modules.retrieval.utils.brute_force_triplet_search import (
|
||||
brute_force_triplet_search,
|
||||
get_memory_fragment,
|
||||
)
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
|
||||
|
||||
|
||||
class MockScoredResult:
|
||||
"""Mock class for vector search results."""
|
||||
|
||||
def __init__(self, id, score, payload=None):
|
||||
self.id = id
|
||||
self.score = score
|
||||
self.payload = payload or {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_empty_query():
|
||||
"""Test that empty query raises ValueError."""
|
||||
with pytest.raises(ValueError, match="The query must be a non-empty string."):
|
||||
await brute_force_triplet_search(query="")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_none_query():
|
||||
"""Test that None query raises ValueError."""
|
||||
with pytest.raises(ValueError, match="The query must be a non-empty string."):
|
||||
await brute_force_triplet_search(query=None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_negative_top_k():
|
||||
"""Test that negative top_k raises ValueError."""
|
||||
with pytest.raises(ValueError, match="top_k must be a positive integer."):
|
||||
await brute_force_triplet_search(query="test query", top_k=-1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_zero_top_k():
|
||||
"""Test that zero top_k raises ValueError."""
|
||||
with pytest.raises(ValueError, match="top_k must be a positive integer."):
|
||||
await brute_force_triplet_search(query="test query", top_k=0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_wide_search_limit_global_search():
|
||||
"""Test that wide_search_limit is applied for global search (node_name=None)."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
await brute_force_triplet_search(
|
||||
query="test",
|
||||
node_name=None, # Global search
|
||||
wide_search_top_k=75,
|
||||
)
|
||||
|
||||
for call in mock_vector_engine.search.call_args_list:
|
||||
assert call[1]["limit"] == 75
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_wide_search_limit_filtered_search():
|
||||
"""Test that wide_search_limit is None for filtered search (node_name provided)."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
await brute_force_triplet_search(
|
||||
query="test",
|
||||
node_name=["Node1"],
|
||||
wide_search_top_k=50,
|
||||
)
|
||||
|
||||
for call in mock_vector_engine.search.call_args_list:
|
||||
assert call[1]["limit"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_wide_search_default():
|
||||
"""Test that wide_search_top_k defaults to 100."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
await brute_force_triplet_search(query="test", node_name=None)
|
||||
|
||||
for call in mock_vector_engine.search.call_args_list:
|
||||
assert call[1]["limit"] == 100
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_default_collections():
|
||||
"""Test that default collections are used when none provided."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
await brute_force_triplet_search(query="test")
|
||||
|
||||
expected_collections = [
|
||||
"Entity_name",
|
||||
"TextSummary_text",
|
||||
"EntityType_name",
|
||||
"DocumentChunk_text",
|
||||
]
|
||||
|
||||
call_collections = [
|
||||
call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
|
||||
]
|
||||
assert call_collections == expected_collections
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_custom_collections():
|
||||
"""Test that custom collections are used when provided."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(return_value=[])
|
||||
|
||||
custom_collections = ["CustomCol1", "CustomCol2"]
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
await brute_force_triplet_search(query="test", collections=custom_collections)
|
||||
|
||||
call_collections = [
|
||||
call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
|
||||
]
|
||||
assert call_collections == custom_collections
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_all_collections_empty():
|
||||
"""Test that empty list is returned when all collections return no results."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
results = await brute_force_triplet_search(query="test")
|
||||
assert results == []
|
||||
|
||||
|
||||
# Tests for query embedding
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_embeds_query():
|
||||
"""Test that query is embedded before searching."""
|
||||
query_text = "test query"
|
||||
expected_vector = [0.1, 0.2, 0.3]
|
||||
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[expected_vector])
|
||||
mock_vector_engine.search = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
await brute_force_triplet_search(query=query_text)
|
||||
|
||||
mock_vector_engine.embedding_engine.embed_text.assert_called_once_with([query_text])
|
||||
|
||||
for call in mock_vector_engine.search.call_args_list:
|
||||
assert call[1]["query_vector"] == expected_vector
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_extracts_node_ids_global_search():
|
||||
"""Test that node IDs are extracted from search results for global search."""
|
||||
scored_results = [
|
||||
MockScoredResult("node1", 0.95),
|
||||
MockScoredResult("node2", 0.87),
|
||||
MockScoredResult("node3", 0.92),
|
||||
]
|
||||
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(return_value=scored_results)
|
||||
|
||||
mock_fragment = AsyncMock(
|
||||
map_vector_distances_to_graph_nodes=AsyncMock(),
|
||||
map_vector_distances_to_graph_edges=AsyncMock(),
|
||||
calculate_top_triplet_importances=AsyncMock(return_value=[]),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
|
||||
return_value=mock_fragment,
|
||||
) as mock_get_fragment_fn,
|
||||
):
|
||||
await brute_force_triplet_search(query="test", node_name=None)
|
||||
|
||||
call_kwargs = mock_get_fragment_fn.call_args[1]
|
||||
assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_reuses_provided_fragment():
|
||||
"""Test that provided memory fragment is reused instead of creating new one."""
|
||||
provided_fragment = AsyncMock(
|
||||
map_vector_distances_to_graph_nodes=AsyncMock(),
|
||||
map_vector_distances_to_graph_edges=AsyncMock(),
|
||||
calculate_top_triplet_importances=AsyncMock(return_value=[]),
|
||||
)
|
||||
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment"
|
||||
) as mock_get_fragment,
|
||||
):
|
||||
await brute_force_triplet_search(
|
||||
query="test",
|
||||
memory_fragment=provided_fragment,
|
||||
node_name=["node"],
|
||||
)
|
||||
|
||||
mock_get_fragment.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_creates_fragment_when_not_provided():
|
||||
"""Test that memory fragment is created when not provided."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
|
||||
|
||||
mock_fragment = AsyncMock(
|
||||
map_vector_distances_to_graph_nodes=AsyncMock(),
|
||||
map_vector_distances_to_graph_edges=AsyncMock(),
|
||||
calculate_top_triplet_importances=AsyncMock(return_value=[]),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
|
||||
return_value=mock_fragment,
|
||||
) as mock_get_fragment,
|
||||
):
|
||||
await brute_force_triplet_search(query="test", node_name=["node"])
|
||||
|
||||
mock_get_fragment.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation():
|
||||
"""Test that custom top_k is passed to importance calculation."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
|
||||
|
||||
mock_fragment = AsyncMock(
|
||||
map_vector_distances_to_graph_nodes=AsyncMock(),
|
||||
map_vector_distances_to_graph_edges=AsyncMock(),
|
||||
calculate_top_triplet_importances=AsyncMock(return_value=[]),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
|
||||
return_value=mock_fragment,
|
||||
),
|
||||
):
|
||||
custom_top_k = 15
|
||||
await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"])
|
||||
|
||||
mock_fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found():
|
||||
"""Test that get_memory_fragment returns empty graph when entity not found."""
|
||||
mock_graph_engine = AsyncMock()
|
||||
mock_graph_engine.project_graph_from_db = AsyncMock(
|
||||
side_effect=EntityNotFoundError("Entity not found")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
):
|
||||
fragment = await get_memory_fragment()
|
||||
|
||||
assert isinstance(fragment, CogneeGraph)
|
||||
assert len(fragment.nodes) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_memory_fragment_returns_empty_graph_on_error():
|
||||
"""Test that get_memory_fragment returns empty graph on generic error."""
|
||||
mock_graph_engine = AsyncMock()
|
||||
mock_graph_engine.project_graph_from_db = AsyncMock(side_effect=Exception("Generic error"))
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
):
|
||||
fragment = await get_memory_fragment()
|
||||
|
||||
assert isinstance(fragment, CogneeGraph)
|
||||
assert len(fragment.nodes) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_deduplicates_node_ids():
|
||||
"""Test that duplicate node IDs across collections are deduplicated."""
|
||||
|
||||
def search_side_effect(*args, **kwargs):
|
||||
collection_name = kwargs.get("collection_name")
|
||||
if collection_name == "Entity_name":
|
||||
return [
|
||||
MockScoredResult("node1", 0.95),
|
||||
MockScoredResult("node2", 0.87),
|
||||
]
|
||||
elif collection_name == "TextSummary_text":
|
||||
return [
|
||||
MockScoredResult("node1", 0.90),
|
||||
MockScoredResult("node3", 0.92),
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
|
||||
|
||||
mock_fragment = AsyncMock(
|
||||
map_vector_distances_to_graph_nodes=AsyncMock(),
|
||||
map_vector_distances_to_graph_edges=AsyncMock(),
|
||||
calculate_top_triplet_importances=AsyncMock(return_value=[]),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
|
||||
return_value=mock_fragment,
|
||||
) as mock_get_fragment_fn,
|
||||
):
|
||||
await brute_force_triplet_search(query="test", node_name=None)
|
||||
|
||||
call_kwargs = mock_get_fragment_fn.call_args[1]
|
||||
assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"}
|
||||
assert len(call_kwargs["relevant_ids_to_filter"]) == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_excludes_edge_collection():
    """Test that EdgeType_relationship_name collection is excluded from ID extraction."""

    def search_side_effect(*args, **kwargs):
        collection_name = kwargs.get("collection_name")
        if collection_name == "Entity_name":
            return [MockScoredResult("node1", 0.95)]
        elif collection_name == "EdgeType_relationship_name":
            return [MockScoredResult("edge1", 0.88)]
        else:
            return []

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)

    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment_fn,
    ):
        await brute_force_triplet_search(
            query="test",
            node_name=None,
            collections=["Entity_name", "EdgeType_relationship_name"],
        )

    call_kwargs = mock_get_fragment_fn.call_args[1]
    assert call_kwargs["relevant_ids_to_filter"] == ["node1"]

@pytest.mark.asyncio
async def test_brute_force_triplet_search_skips_nodes_without_ids():
    """Test that nodes without ID attribute are skipped."""

    class ScoredResultNoId:
        """Mock result without id attribute."""

        def __init__(self, score):
            self.score = score

    def search_side_effect(*args, **kwargs):
        collection_name = kwargs.get("collection_name")
        if collection_name == "Entity_name":
            return [
                MockScoredResult("node1", 0.95),
                ScoredResultNoId(0.90),
                MockScoredResult("node2", 0.87),
            ]
        else:
            return []

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)

    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment_fn,
    ):
        await brute_force_triplet_search(query="test", node_name=None)

    call_kwargs = mock_get_fragment_fn.call_args[1]
    assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}

@pytest.mark.asyncio
async def test_brute_force_triplet_search_handles_tuple_results():
    """Test that both list and tuple results are handled correctly."""

    def search_side_effect(*args, **kwargs):
        collection_name = kwargs.get("collection_name")
        if collection_name == "Entity_name":
            return (
                MockScoredResult("node1", 0.95),
                MockScoredResult("node2", 0.87),
            )
        else:
            return []

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)

    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment_fn,
    ):
        await brute_force_triplet_search(query="test", node_name=None)

    call_kwargs = mock_get_fragment_fn.call_args[1]
    assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}

@pytest.mark.asyncio
async def test_brute_force_triplet_search_mixed_empty_collections():
    """Test ID extraction with mixed empty and non-empty collections."""

    def search_side_effect(*args, **kwargs):
        collection_name = kwargs.get("collection_name")
        if collection_name == "Entity_name":
            return [MockScoredResult("node1", 0.95)]
        elif collection_name == "TextSummary_text":
            return []
        elif collection_name == "EntityType_name":
            return [MockScoredResult("node2", 0.92)]
        else:
            return []

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)

    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment_fn,
    ):
        await brute_force_triplet_search(query="test", node_name=None)

    call_kwargs = mock_get_fragment_fn.call_args[1]
    assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
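The five tests above all assert on the `relevant_ids_to_filter` argument passed to `get_memory_fragment`. A minimal sketch of the extraction logic they collectively pin down is below; note this is an illustration under assumptions, not the actual cognee implementation — `extract_relevant_ids` and `FakeScoredResult` are hypothetical names:

```python
class FakeScoredResult:
    """Hypothetical stand-in for a scored vector-search hit (mirrors the
    tests' MockScoredResult: an ``id`` plus a similarity ``score``)."""

    def __init__(self, node_id, score):
        self.id = node_id
        self.score = score


def extract_relevant_ids(results_by_collection):
    """Collect candidate node IDs across collections: skip the edge
    collection, skip results lacking an ``id`` attribute, and de-duplicate
    while preserving first-seen order."""
    seen = set()
    relevant_ids = []
    for collection_name, results in results_by_collection.items():
        if collection_name == "EdgeType_relationship_name":
            continue  # edge hits are excluded from node-ID extraction
        for result in results:  # iterates lists and tuples alike
            node_id = getattr(result, "id", None)
            if node_id is not None and node_id not in seen:
                seen.add(node_id)
                relevant_ids.append(node_id)
    return relevant_ids
```

Given the search results used in the tests, this sketch would reproduce the asserted outcomes (dedup across `Entity_name` and `TextSummary_text`, edge IDs dropped, id-less results skipped).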
@@ -0,0 +1,83 @@
import pytest
from unittest.mock import AsyncMock, patch, MagicMock

from cognee.modules.retrieval.triplet_retriever import TripletRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError


@pytest.fixture
def mock_vector_engine():
    """Create a mock vector engine."""
    engine = AsyncMock()
    engine.has_collection = AsyncMock(return_value=True)
    engine.search = AsyncMock()
    return engine


@pytest.mark.asyncio
async def test_get_context_success(mock_vector_engine):
    """Test successful retrieval of triplet context."""
    mock_result1 = MagicMock()
    mock_result1.payload = {"text": "Alice knows Bob"}
    mock_result2 = MagicMock()
    mock_result2.payload = {"text": "Bob works at Tech Corp"}

    mock_vector_engine.search.return_value = [mock_result1, mock_result2]

    retriever = TripletRetriever(top_k=5)

    with patch(
        "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        context = await retriever.get_context("test query")

    assert context == "Alice knows Bob\nBob works at Tech Corp"
    mock_vector_engine.search.assert_awaited_once_with("Triplet_text", "test query", limit=5)


@pytest.mark.asyncio
async def test_get_context_no_collection(mock_vector_engine):
    """Test that NoDataError is raised when Triplet_text collection doesn't exist."""
    mock_vector_engine.has_collection.return_value = False

    retriever = TripletRetriever()

    with patch(
        "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        with pytest.raises(NoDataError, match="create_triplet_embeddings"):
            await retriever.get_context("test query")


@pytest.mark.asyncio
async def test_get_context_empty_results(mock_vector_engine):
    """Test that empty string is returned when no triplets are found."""
    mock_vector_engine.search.return_value = []

    retriever = TripletRetriever()

    with patch(
        "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        context = await retriever.get_context("test query")

    assert context == ""


@pytest.mark.asyncio
async def test_get_context_collection_not_found_error(mock_vector_engine):
    """Test that CollectionNotFoundError is converted to NoDataError."""
    mock_vector_engine.has_collection.side_effect = CollectionNotFoundError("Collection not found")

    retriever = TripletRetriever()

    with patch(
        "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        with pytest.raises(NoDataError, match="No data found"):
            await retriever.get_context("test query")
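Taken together, these four tests pin down a small contract for `TripletRetriever.get_context`: join the `payload["text"]` of the top hits with newlines, return an empty string for no hits, and surface a missing `Triplet_text` collection as `NoDataError`. A hedged sketch of logic satisfying that contract follows; the function name and exception bodies here are illustrative assumptions, not the cognee source:

```python
class NoDataError(Exception):
    """Illustrative stand-in for cognee's NoDataError."""


class CollectionNotFoundError(Exception):
    """Illustrative stand-in for the vector-store exception."""


async def get_context_sketch(vector_engine, query, top_k=5):
    """Sketch of the retrieval contract exercised by the tests above."""
    try:
        if not await vector_engine.has_collection("Triplet_text"):
            # Missing collection means triplet embeddings were never built.
            raise NoDataError("No data found. Run create_triplet_embeddings first.")
    except CollectionNotFoundError:
        # Some backends raise instead of returning False; normalize it.
        raise NoDataError("No data found. Run create_triplet_embeddings first.")

    results = await vector_engine.search("Triplet_text", query, limit=top_k)
    # Empty results fall through naturally to an empty string.
    return "\n".join(result.payload["text"] for result in results)
```

With a fake engine returning two hits, this yields `"Alice knows Bob\nBob works at Tech Corp"`, matching `test_get_context_success`.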
79
examples/python/triplet_embeddings_example.py
Normal file
@@ -0,0 +1,79 @@
import asyncio

import cognee
from cognee.memify_pipelines.create_triplet_embeddings import create_triplet_embeddings
from cognee.modules.search.types import SearchType
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import setup_logging, INFO
from cognee.modules.engine.operations.setup import setup

text_1 = """
1. Audi
Audi is known for its modern designs and advanced technology. Founded in the early 1900s, the brand has earned a reputation for precision engineering and innovation. With features like the Quattro all-wheel-drive system, Audi offers a range of vehicles from stylish sedans to high-performance sports cars.

2. BMW
BMW, short for Bayerische Motoren Werke, is celebrated for its focus on performance and driving pleasure. The company's vehicles are designed to provide a dynamic and engaging driving experience, and their slogan, "The Ultimate Driving Machine," reflects that commitment. BMW produces a variety of cars that combine luxury with sporty performance.

3. Mercedes-Benz
Mercedes-Benz is synonymous with luxury and quality. With a history dating back to the early 20th century, the brand is known for its elegant designs, innovative safety features, and high-quality engineering. Mercedes-Benz manufactures not only luxury sedans but also SUVs, sports cars, and commercial vehicles, catering to a wide range of needs.

4. Porsche
Porsche is a name that stands for high-performance sports cars. Founded in 1931, the brand has become famous for models like the iconic Porsche 911. Porsche cars are celebrated for their speed, precision, and distinctive design, appealing to car enthusiasts who value both performance and style.

5. Volkswagen
Volkswagen, which means "people's car" in German, was established with the idea of making affordable and reliable vehicles accessible to everyone. Over the years, Volkswagen has produced several iconic models, such as the Beetle and the Golf. Today, it remains one of the largest car manufacturers in the world, offering a wide range of vehicles that balance practicality with quality.

Each of these car manufacturers contributes to Germany's reputation as a leader in the global automotive industry, showcasing a blend of innovation, performance, and design excellence.
"""

text_2 = """
1. Apple
Apple is renowned for its innovative consumer electronics and software. Its product lineup includes the iPhone, iPad, Mac computers, and wearables like the Apple Watch. Known for its emphasis on sleek design and user-friendly interfaces, Apple has built a loyal customer base and created a seamless ecosystem that integrates hardware, software, and services.

2. Google
Founded in 1998, Google started as a search engine and quickly became the go-to resource for finding information online. Over the years, the company has diversified its offerings to include digital advertising, cloud computing, mobile operating systems (Android), and various web services like Gmail and Google Maps. Google's innovations have played a major role in shaping the internet landscape.

3. Microsoft
Microsoft Corporation has been a dominant force in software for decades. Its Windows operating system and Microsoft Office suite are staples in both business and personal computing. In recent years, Microsoft has expanded into cloud computing with Azure, gaming with the Xbox platform, and even hardware through products like the Surface line. This evolution has helped the company maintain its relevance in a rapidly changing tech world.

4. Amazon
What began as an online bookstore has grown into one of the largest e-commerce platforms globally. Amazon is known for its vast online marketplace, but its influence extends far beyond retail. With Amazon Web Services (AWS), the company has become a leader in cloud computing, offering robust solutions that power websites, applications, and businesses around the world. Amazon's constant drive for innovation continues to reshape both retail and technology sectors.

5. Meta
Meta, originally known as Facebook, revolutionized social media by connecting billions of people worldwide. Beyond its core social networking service, Meta is investing in the next generation of digital experiences through virtual and augmented reality technologies, with projects like Oculus. The company's efforts signal a commitment to evolving digital interaction and building the metaverse—a shared virtual space where users can connect and collaborate.

Each of these companies has significantly impacted the technology landscape, driving innovation and transforming everyday life through their groundbreaking products and services.
"""


async def main():
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    await cognee.add([text_1, text_2])
    await cognee.cognify()

    default_user = await get_default_user()

    await create_triplet_embeddings(
        user=default_user,
        triplets_batch_size=100,
    )

    search_results = await cognee.search(
        query_type=SearchType.TRIPLET_COMPLETION,
        query_text="What are the models produced by Volkswagen based on the context?",
    )
    print(search_results)


if __name__ == "__main__":
    logger = setup_logging(log_level=INFO)

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(main())
    finally:
        loop.run_until_complete(loop.shutdown_asyncgens())
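The example drives `main()` with an explicitly created event loop so it can shut down async generators before exiting. On Python 3.7+ the same steps can be expressed more compactly with `asyncio.run`, which handles loop creation and asyncgen shutdown itself; here `main` is a trivial stand-in for the example's pipeline coroutine:

```python
import asyncio


async def main():
    # Hypothetical stand-in for the example's pipeline coroutine.
    await asyncio.sleep(0)
    return "done"


if __name__ == "__main__":
    # asyncio.run creates a fresh event loop, runs main() to completion, and
    # shuts down async generators before closing the loop, matching the
    # manual new_event_loop/run_until_complete pattern in the example.
    print(asyncio.run(main()))
```

The manual pattern remains useful when the caller must reuse or configure the loop; otherwise `asyncio.run` is the idiomatic entry point.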
8047
poetry.lock
generated
File diff suppressed because it is too large
@@ -22,7 +22,7 @@ classifiers = [
 dependencies = [
     "openai>=1.80.1",
     "python-dotenv>=1.0.1,<2.0.0",
-    "pydantic>=2.10.5,<3.0.0",
+    "pydantic>=2.10.5,<2.12.0",
     "pydantic-settings>=2.2.1,<3",
     "typing_extensions>=4.12.2,<5.0.0",
     "numpy>=1.26.4, <=4.0.0",
@@ -33,7 +33,7 @@ dependencies = [
     "instructor>=1.9.1,<2.0.0",
     "filetype>=1.2.0,<2.0.0",
     "aiohttp>=3.11.14,<4.0.0",
-    "aiofiles>=23.2.1,<24.0.0",
+    "aiofiles>=23.2.1",
     "rdflib>=7.1.4,<7.2.0",
     "pypdf>=4.1.0,<7.0.0",
     "jinja2>=3.1.3,<4",
@@ -199,8 +199,3 @@ exclude = [
 
 [tool.ruff.lint]
 ignore = ["F401"]
-
-[dependency-groups]
-dev = [
-    "pytest-timeout>=2.4.0",
-]