Merge branch 'dev' into main

This commit is contained in:
Vasilije 2025-10-29 20:34:58 +01:00 committed by GitHub
commit 9865e4c5ec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
49 changed files with 6165 additions and 4425 deletions

View file

@ -358,6 +358,34 @@ jobs:
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/tasks/entity_extraction/entity_extraction_test.py
test-feedback-enrichment:
name: Test Feedback Enrichment
runs-on: ubuntu-22.04
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
- name: Dependencies already installed
run: echo "Dependencies already installed in setup"
- name: Run Feedback Enrichment Test
env:
ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_feedback_enrichment.py
run_conversation_sessions_test:
name: Conversation sessions test
runs-on: ubuntu-latest

View file

@ -110,6 +110,47 @@ If you'd rather run cognee-mcp in a container, you have two options:
# For stdio transport (default)
docker run -e TRANSPORT_MODE=stdio --env-file ./.env --rm -it cognee/cognee-mcp:main
```
**Installing optional dependencies at runtime:**
You can install optional dependencies when running the container by setting the `EXTRAS` environment variable:
```bash
# Install a single optional dependency group at runtime
docker run \
-e TRANSPORT_MODE=http \
-e EXTRAS=aws \
--env-file ./.env \
-p 8000:8000 \
--rm -it cognee/cognee-mcp:main
# Install multiple optional dependency groups at runtime (comma-separated)
docker run \
-e TRANSPORT_MODE=sse \
-e EXTRAS=aws,postgres,neo4j \
--env-file ./.env \
-p 8000:8000 \
--rm -it cognee/cognee-mcp:main
```
**Available optional dependency groups:**
- `aws` - S3 storage support
- `postgres` / `postgres-binary` - PostgreSQL database support
- `neo4j` - Neo4j graph database support
- `neptune` - AWS Neptune support
- `chromadb` - ChromaDB vector store support
- `scraping` - Web scraping capabilities
- `distributed` - Modal distributed execution
- `langchain` - LangChain integration
- `llama-index` - LlamaIndex integration
- `anthropic` - Anthropic models
- `groq` - Groq models
- `mistral` - Mistral models
- `ollama` / `huggingface` - Local model support
- `docs` - Document processing
- `codegraph` - Code analysis
- `monitoring` - Sentry & Langfuse monitoring
- `redis` - Redis support
- And more (see [pyproject.toml](https://github.com/topoteretes/cognee/blob/main/pyproject.toml) for full list)
2. **Pull from Docker Hub** (no build required):
```bash
# With HTTP transport (recommended for web deployments)
@ -119,6 +160,17 @@ If you'd rather run cognee-mcp in a container, you have two options:
# With stdio transport (default)
docker run -e TRANSPORT_MODE=stdio --env-file ./.env --rm -it cognee/cognee-mcp:main
```
**With runtime installation of optional dependencies:**
```bash
# Install optional dependencies from Docker Hub image
docker run \
-e TRANSPORT_MODE=http \
-e EXTRAS=aws,postgres \
--env-file ./.env \
-p 8000:8000 \
--rm -it cognee/cognee-mcp:main
```
### **Important: Docker vs Direct Usage**
**Docker uses environment variables**, not command line arguments:

View file

@ -4,6 +4,42 @@ set -e # Exit on error
echo "Debug mode: $DEBUG"
echo "Environment: $ENVIRONMENT"
# Install optional dependencies if EXTRAS is set
if [ -n "$EXTRAS" ]; then
    echo "Installing optional dependencies: $EXTRAS"

    # Pin the extras install to the exact cognee release already in the image,
    # so we never silently upgrade/downgrade the base package.
    COGNEE_VERSION=$(uv pip show cognee | grep "Version:" | awk '{print $2}')
    # The pipeline above succeeds even when cognee is missing (grep's failure is
    # masked by awk under a plain `set -e`), so guard explicitly.
    if [ -z "$COGNEE_VERSION" ]; then
        echo "ERROR: could not determine installed cognee version; cannot install extras" >&2
        exit 1
    fi
    echo "Current cognee version: $COGNEE_VERSION"

    # Split the comma-separated EXTRAS value into an array
    IFS=',' read -ra EXTRA_ARRAY <<< "$EXTRAS"

    # Build a deduplicated, comma-separated extras list
    ALL_EXTRAS=""
    for extra in "${EXTRA_ARRAY[@]}"; do
        # Trim surrounding whitespace
        extra=$(echo "$extra" | xargs)
        # Skip empty entries (e.g. "aws,,postgres" or a trailing comma)
        if [ -z "$extra" ]; then
            continue
        fi
        # Add to extras list if not already present
        if [[ ! "$ALL_EXTRAS" =~ (^|,)"$extra"(,|$) ]]; then
            if [ -z "$ALL_EXTRAS" ]; then
                ALL_EXTRAS="$extra"
            else
                ALL_EXTRAS="$ALL_EXTRAS,$extra"
            fi
        fi
    done

    echo "Installing cognee with extras: $ALL_EXTRAS"
    echo "Running: uv pip install 'cognee[$ALL_EXTRAS]==$COGNEE_VERSION'"
    uv pip install "cognee[$ALL_EXTRAS]==$COGNEE_VERSION"

    # Verify installation
    echo ""
    echo "✓ Optional dependencies installation completed"
else
    echo "No optional dependencies specified"
fi
# Set default transport mode if not specified
TRANSPORT_MODE=${TRANSPORT_MODE:-"stdio"}
echo "Transport mode: $TRANSPORT_MODE"

View file

@ -10,6 +10,7 @@ from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee.modules.pipelines.models import PipelineRunErrored
from cognee.shared.logging_utils import get_logger
from cognee import __version__ as cognee_version
logger = get_logger()
@ -63,7 +64,11 @@ def get_add_router() -> APIRouter:
send_telemetry(
"Add API Endpoint Invoked",
user.id,
additional_properties={"endpoint": "POST /v1/add", "node_set": node_set},
additional_properties={
"endpoint": "POST /v1/add",
"node_set": node_set,
"cognee_version": cognee_version,
},
)
from cognee.api.v1.add import add as cognee_add

View file

@ -29,7 +29,7 @@ from cognee.modules.pipelines.queues.pipeline_run_info_queues import (
)
from cognee.shared.logging_utils import get_logger
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
logger = get_logger("api.cognify")
@ -98,6 +98,7 @@ def get_cognify_router() -> APIRouter:
user.id,
additional_properties={
"endpoint": "POST /v1/cognify",
"cognee_version": cognee_version,
},
)

View file

@ -24,6 +24,7 @@ from cognee.modules.users.permissions.methods import (
from cognee.modules.graph.methods import get_formatted_graph_data
from cognee.modules.pipelines.models import PipelineRunStatus
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
logger = get_logger()
@ -100,6 +101,7 @@ def get_datasets_router() -> APIRouter:
user.id,
additional_properties={
"endpoint": "GET /v1/datasets",
"cognee_version": cognee_version,
},
)
@ -147,6 +149,7 @@ def get_datasets_router() -> APIRouter:
user.id,
additional_properties={
"endpoint": "POST /v1/datasets",
"cognee_version": cognee_version,
},
)
@ -201,6 +204,7 @@ def get_datasets_router() -> APIRouter:
additional_properties={
"endpoint": f"DELETE /v1/datasets/{str(dataset_id)}",
"dataset_id": str(dataset_id),
"cognee_version": cognee_version,
},
)
@ -246,6 +250,7 @@ def get_datasets_router() -> APIRouter:
"endpoint": f"DELETE /v1/datasets/{str(dataset_id)}/data/{str(data_id)}",
"dataset_id": str(dataset_id),
"data_id": str(data_id),
"cognee_version": cognee_version,
},
)
@ -327,6 +332,7 @@ def get_datasets_router() -> APIRouter:
additional_properties={
"endpoint": f"GET /v1/datasets/{str(dataset_id)}/data",
"dataset_id": str(dataset_id),
"cognee_version": cognee_version,
},
)
@ -387,6 +393,7 @@ def get_datasets_router() -> APIRouter:
additional_properties={
"endpoint": "GET /v1/datasets/status",
"datasets": [str(dataset_id) for dataset_id in datasets],
"cognee_version": cognee_version,
},
)
@ -433,6 +440,7 @@ def get_datasets_router() -> APIRouter:
"endpoint": f"GET /v1/datasets/{str(dataset_id)}/data/{str(data_id)}/raw",
"dataset_id": str(dataset_id),
"data_id": str(data_id),
"cognee_version": cognee_version,
},
)

View file

@ -6,6 +6,7 @@ from cognee.shared.logging_utils import get_logger
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
logger = get_logger()
@ -39,6 +40,7 @@ def get_delete_router() -> APIRouter:
"endpoint": "DELETE /v1/delete",
"dataset_id": str(dataset_id),
"data_id": str(data_id),
"cognee_version": cognee_version,
},
)

View file

@ -12,6 +12,7 @@ from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee.modules.pipelines.models import PipelineRunErrored
from cognee.shared.logging_utils import get_logger
from cognee import __version__ as cognee_version
logger = get_logger()
@ -73,7 +74,7 @@ def get_memify_router() -> APIRouter:
send_telemetry(
"Memify API Endpoint Invoked",
user.id,
additional_properties={"endpoint": "POST /v1/memify"},
additional_properties={"endpoint": "POST /v1/memify", "cognee_version": cognee_version},
)
if not payload.dataset_id and not payload.dataset_name:

View file

@ -7,6 +7,7 @@ from fastapi.responses import JSONResponse
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
def get_permissions_router() -> APIRouter:
@ -48,6 +49,7 @@ def get_permissions_router() -> APIRouter:
"endpoint": f"POST /v1/permissions/datasets/{str(principal_id)}",
"dataset_ids": str(dataset_ids),
"principal_id": str(principal_id),
"cognee_version": cognee_version,
},
)
@ -89,6 +91,7 @@ def get_permissions_router() -> APIRouter:
additional_properties={
"endpoint": "POST /v1/permissions/roles",
"role_name": role_name,
"cognee_version": cognee_version,
},
)
@ -133,6 +136,7 @@ def get_permissions_router() -> APIRouter:
"endpoint": f"POST /v1/permissions/users/{str(user_id)}/roles",
"user_id": str(user_id),
"role_id": str(role_id),
"cognee_version": cognee_version,
},
)
@ -175,6 +179,7 @@ def get_permissions_router() -> APIRouter:
"endpoint": f"POST /v1/permissions/users/{str(user_id)}/tenants",
"user_id": str(user_id),
"tenant_id": str(tenant_id),
"cognee_version": cognee_version,
},
)
@ -209,6 +214,7 @@ def get_permissions_router() -> APIRouter:
additional_properties={
"endpoint": "POST /v1/permissions/tenants",
"tenant_name": tenant_name,
"cognee_version": cognee_version,
},
)

View file

@ -13,6 +13,7 @@ from cognee.modules.users.models import User
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
# Note: Datasets sent by name will only map to datasets owned by the request sender
@ -61,9 +62,7 @@ def get_search_router() -> APIRouter:
send_telemetry(
"Search API Endpoint Invoked",
user.id,
additional_properties={
"endpoint": "GET /v1/search",
},
additional_properties={"endpoint": "GET /v1/search", "cognee_version": cognee_version},
)
try:
@ -118,6 +117,7 @@ def get_search_router() -> APIRouter:
"top_k": payload.top_k,
"only_context": payload.only_context,
"use_combined_context": payload.use_combined_context,
"cognee_version": cognee_version,
},
)

View file

@ -12,6 +12,7 @@ from cognee.modules.sync.methods import get_running_sync_operations_for_user, ge
from cognee.shared.utils import send_telemetry
from cognee.shared.logging_utils import get_logger
from cognee.api.v1.sync import SyncResponse
from cognee import __version__ as cognee_version
from cognee.context_global_variables import set_database_global_context_variables
logger = get_logger()
@ -99,6 +100,7 @@ def get_sync_router() -> APIRouter:
user.id,
additional_properties={
"endpoint": "POST /v1/sync",
"cognee_version": cognee_version,
"dataset_ids": [str(id) for id in request.dataset_ids]
if request.dataset_ids
else "*",
@ -205,6 +207,7 @@ def get_sync_router() -> APIRouter:
user.id,
additional_properties={
"endpoint": "GET /v1/sync/status",
"cognee_version": cognee_version,
},
)

View file

@ -503,7 +503,7 @@ def start_ui(
if start_mcp:
logger.info("Starting Cognee MCP server with Docker...")
try:
image = "cognee/cognee-mcp:feature-standalone-mcp" # TODO: change to "cognee/cognee-mcp:main" right before merging into main
image = "cognee/cognee-mcp:main"
subprocess.run(["docker", "pull", image], check=True)
import uuid
@ -538,9 +538,7 @@ def start_ui(
env_file = os.path.join(cwd, ".env")
docker_cmd.extend(["--env-file", env_file])
docker_cmd.append(
image
) # TODO: change to "cognee/cognee-mcp:main" right before merging into main
docker_cmd.append(image)
mcp_process = subprocess.Popen(
docker_cmd,

View file

@ -9,6 +9,7 @@ from cognee.shared.logging_utils import get_logger
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
from cognee.modules.pipelines.models.PipelineRunInfo import (
PipelineRunErrored,
)
@ -64,6 +65,7 @@ def get_update_router() -> APIRouter:
"dataset_id": str(dataset_id),
"data_id": str(data_id),
"node_set": str(node_set),
"cognee_version": cognee_version,
},
)

View file

@ -8,6 +8,7 @@ from cognee.modules.users.models import User
from cognee.context_global_variables import set_database_global_context_variables
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
logger = get_logger()
@ -46,6 +47,7 @@ def get_visualize_router() -> APIRouter:
additional_properties={
"endpoint": "GET /v1/visualize",
"dataset_id": str(dataset_id),
"cognee_version": cognee_version,
},
)

View file

@ -1,4 +1,5 @@
import os
from pathlib import Path
from typing import Optional
from functools import lru_cache
from cognee.root_dir import get_absolute_path, ensure_absolute_path
@ -11,6 +12,9 @@ class BaseConfig(BaseSettings):
data_root_directory: str = get_absolute_path(".data_storage")
system_root_directory: str = get_absolute_path(".cognee_system")
cache_root_directory: str = get_absolute_path(".cognee_cache")
logs_root_directory: str = os.getenv(
"COGNEE_LOGS_DIR", str(os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs"))
)
monitoring_tool: object = Observer.NONE
@pydantic.model_validator(mode="after")
@ -30,6 +34,8 @@ class BaseConfig(BaseSettings):
# Require absolute paths for root directories
self.data_root_directory = ensure_absolute_path(self.data_root_directory)
self.system_root_directory = ensure_absolute_path(self.system_root_directory)
self.logs_root_directory = ensure_absolute_path(self.logs_root_directory)
# Set monitoring tool based on available keys
if self.langfuse_public_key and self.langfuse_secret_key:
self.monitoring_tool = Observer.LANGFUSE
@ -49,6 +55,7 @@ class BaseConfig(BaseSettings):
"system_root_directory": self.system_root_directory,
"monitoring_tool": self.monitoring_tool,
"cache_root_directory": self.cache_root_directory,
"logs_root_directory": self.logs_root_directory,
}

View file

@ -1366,9 +1366,15 @@ class KuzuAdapter(GraphDBInterface):
params[param_name] = values
where_clause = " AND ".join(where_clauses)
nodes_query = (
f"MATCH (n:Node) WHERE {where_clause} RETURN n.id, {{properties: n.properties}}"
)
nodes_query = f"""
MATCH (n:Node)
WHERE {where_clause}
RETURN n.id, {{
name: n.name,
type: n.type,
properties: n.properties
}}
"""
edges_query = f"""
MATCH (n1:Node)-[r:EDGE]->(n2:Node)
WHERE {where_clause.replace("n.", "n1.")} AND {where_clause.replace("n.", "n2.")}

View file

@ -47,7 +47,7 @@ def create_vector_engine(
embedding_engine=embedding_engine,
)
if vector_db_provider == "pgvector":
if vector_db_provider.lower() == "pgvector":
from cognee.infrastructure.databases.relational import get_relational_config
# Get configuration for postgres database
@ -78,7 +78,7 @@ def create_vector_engine(
embedding_engine,
)
elif vector_db_provider == "chromadb":
elif vector_db_provider.lower() == "chromadb":
try:
import chromadb
except ImportError:
@ -94,7 +94,7 @@ def create_vector_engine(
embedding_engine=embedding_engine,
)
elif vector_db_provider == "neptune_analytics":
elif vector_db_provider.lower() == "neptune_analytics":
try:
from langchain_aws import NeptuneAnalyticsGraph
except ImportError:
@ -122,7 +122,7 @@ def create_vector_engine(
embedding_engine=embedding_engine,
)
else:
elif vector_db_provider.lower() == "lancedb":
from .lancedb.LanceDBAdapter import LanceDBAdapter
return LanceDBAdapter(
@ -130,3 +130,9 @@ def create_vector_engine(
api_key=vector_db_key,
embedding_engine=embedding_engine,
)
else:
raise EnvironmentError(
f"Unsupported graph database provider: {vector_db_provider}. "
f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['LanceDB', 'PGVector', 'neptune_analytics', 'ChromaDB'])}"
)

View file

@ -22,89 +22,6 @@ class FileTypeException(Exception):
self.message = message
class TxtFileType(filetype.Type):
"""
Represents a text file type with specific MIME and extension properties.
Public methods:
- match: Determines whether a given buffer matches the text file type.
"""
MIME = "text/plain"
EXTENSION = "txt"
def __init__(self):
super(TxtFileType, self).__init__(mime=TxtFileType.MIME, extension=TxtFileType.EXTENSION)
def match(self, buf):
"""
Determine if the given buffer contains text content.
Parameters:
-----------
- buf: The buffer to check for text content.
Returns:
--------
Returns True if the buffer is identified as text content, otherwise False.
"""
return is_text_content(buf)
txt_file_type = TxtFileType()
filetype.add_type(txt_file_type)
class CustomPdfMatcher(filetype.Type):
"""
Match PDF file types based on MIME type and extension.
Public methods:
- match
Instance variables:
- MIME: The MIME type of the PDF.
- EXTENSION: The file extension of the PDF.
"""
MIME = "application/pdf"
EXTENSION = "pdf"
def __init__(self):
super(CustomPdfMatcher, self).__init__(
mime=CustomPdfMatcher.MIME, extension=CustomPdfMatcher.EXTENSION
)
def match(self, buf):
"""
Determine if the provided buffer is a PDF file.
This method checks for the presence of the PDF signature in the buffer.
Raises:
- TypeError: If the buffer is not of bytes type.
Parameters:
-----------
- buf: The buffer containing the data to be checked.
Returns:
--------
Returns True if the buffer contains a PDF signature, otherwise returns False.
"""
return b"PDF-" in buf
custom_pdf_matcher = CustomPdfMatcher()
filetype.add_type(custom_pdf_matcher)
def guess_file_type(file: BinaryIO) -> filetype.Type:
"""
Guess the file type from the given binary file stream.

View file

@ -1,15 +1,13 @@
For the purposes of identifying timestamps in a query, you are tasked with extracting relevant timestamps from the query.
## Timestamp requirements
- If the query contains an interval, extract both starts_at and ends_at properties
- If the query contains an instantaneous timestamp, starts_at and ends_at should be the same
- If the query is open-ended (before 2009 or after 2009), the corresponding undefined end of the interval should be None
-For example: "before 2009" -- starts_at: None, ends_at: 2009 or "after 2009" -- starts_at: 2009, ends_at: None
- Put always the data that comes first in time as starts_at and the timestamps that comes second in time as ends_at
- If starts_at or ends_at cannot be extracted, both of them have to be None
## Output Format
Your reply should be a JSON: list of dictionaries with the following structure:
```python
class QueryInterval(BaseModel):
starts_at: Optional[Timestamp] = None
ends_at: Optional[Timestamp] = None
```
You are tasked with identifying relevant time periods where the answer to a given query should be searched.
Current date is: `{{ time_now }}`. Determine relevant period(s) and return structured intervals.
Extraction rules:
1. Query without specific timestamp: use the time period with starts_at set to None and ends_at set to now.
2. Explicit time intervals: If the query specifies a range (e.g., from 2010 to 2020, between January and March 2023), extract both start and end dates. Always assign the earlier date to starts_at and the later date to ends_at.
3. Single timestamp: If the query refers to one specific moment (e.g., in 2015, on March 5, 2022), set starts_at and ends_at to that same timestamp.
4. Open-ended time references: For phrases such as "before X" or "after X", represent the unspecified side as None. For example: before 2009 → starts_at: None, ends_at: 2009; after 2009 → starts_at: 2009, ends_at: None.
5. Current-time references ("now", "current", "today"): If the query explicitly refers to the present, set both starts_at and ends_at to now (the ingestion timestamp).
6. "Who is" and "Who was" questions: These imply a general identity or biographical inquiry without a specific temporal scope. Set both starts_at and ends_at to None.
7. Ordering rule: Always ensure the earlier date is assigned to starts_at and the later date to ends_at.
8. No temporal information: If no valid or inferable time reference is found, set both starts_at and ends_at to None.

View file

@ -0,0 +1,14 @@
A question was previously answered, but the answer received negative feedback.
Please reconsider and improve the response.
Question: {question}
Context originally used: {context}
Previous answer: {wrong_answer}
Feedback on that answer: {negative_feedback}
Task: Provide a better response. The new answer should be short and direct.
Then explain briefly why this answer is better.
Format your reply as:
Answer: <improved answer>
Explanation: <short explanation>

View file

@ -0,0 +1,13 @@
Write a concise, stand-alone paragraph that explains the correct answer to the question below.
The paragraph should read naturally on its own, providing all necessary context and reasoning
so the answer is clear and well-supported.
Question: {question}
Correct answer: {improved_answer}
Supporting context: {new_context}
Your paragraph should:
- First sentence clearly states the correct answer as a full sentence
- Remainder flows from first sentence and provides explanation based on context
- Use simple, direct language that is easy to follow
- Use shorter sentences, no long-winded explanations

View file

@ -0,0 +1,5 @@
Question: {question}
Context: {context}
Provide a one paragraph human readable summary of this interaction context,
listing all the relevant facts and information in a simple and direct way.

View file

@ -1,6 +1,7 @@
import filetype
from typing import Dict, List, Optional, Any
from .LoaderInterface import LoaderInterface
from cognee.infrastructure.files.utils.guess_file_type import guess_file_type
from cognee.shared.logging_utils import get_logger
logger = get_logger(__name__)
@ -80,7 +81,7 @@ class LoaderEngine:
"""
from pathlib import Path
file_info = filetype.guess(file_path)
file_info = guess_file_type(file_path)
path_extension = Path(file_path).suffix.lstrip(".")

View file

@ -21,7 +21,8 @@ def get_ontology_resolver_from_env(
Supported value: "rdflib".
matching_strategy (str): The matching strategy to apply.
Supported value: "fuzzy".
ontology_file_path (str): Path to the ontology file required for the resolver.
ontology_file_path (str): Path to the ontology file(s) required for the resolver.
Can be a single path or comma-separated paths for multiple files.
Returns:
BaseOntologyResolver: An instance of the requested ontology resolver.
@ -31,8 +32,13 @@ def get_ontology_resolver_from_env(
or if required parameters are missing.
"""
if ontology_resolver == "rdflib" and matching_strategy == "fuzzy" and ontology_file_path:
if "," in ontology_file_path:
file_paths = [path.strip() for path in ontology_file_path.split(",")]
else:
file_paths = ontology_file_path
return RDFLibOntologyResolver(
matching_strategy=FuzzyMatchingStrategy(), ontology_file=ontology_file_path
matching_strategy=FuzzyMatchingStrategy(), ontology_file=file_paths
)
else:
raise EnvironmentError(

View file

@ -2,7 +2,7 @@ import os
import difflib
from cognee.shared.logging_utils import get_logger
from collections import deque
from typing import List, Tuple, Dict, Optional, Any
from typing import List, Tuple, Dict, Optional, Any, Union
from rdflib import Graph, URIRef, RDF, RDFS, OWL
from cognee.modules.ontology.exceptions import (
@ -26,22 +26,50 @@ class RDFLibOntologyResolver(BaseOntologyResolver):
def __init__(
self,
ontology_file: Optional[str] = None,
ontology_file: Optional[Union[str, List[str]]] = None,
matching_strategy: Optional[MatchingStrategy] = None,
) -> None:
super().__init__(matching_strategy)
self.ontology_file = ontology_file
try:
if ontology_file and os.path.exists(ontology_file):
files_to_load = []
if ontology_file is not None:
if isinstance(ontology_file, str):
files_to_load = [ontology_file]
elif isinstance(ontology_file, list):
files_to_load = ontology_file
else:
raise ValueError(
f"ontology_file must be a string, list of strings, or None. Got: {type(ontology_file)}"
)
if files_to_load:
self.graph = Graph()
self.graph.parse(ontology_file)
logger.info("Ontology loaded successfully from file: %s", ontology_file)
loaded_files = []
for file_path in files_to_load:
if os.path.exists(file_path):
self.graph.parse(file_path)
loaded_files.append(file_path)
logger.info("Ontology loaded successfully from file: %s", file_path)
else:
logger.warning(
"Ontology file '%s' not found. Skipping this file.",
file_path,
)
if not loaded_files:
logger.info(
"No valid ontology files found. No owl ontology will be attached to the graph."
)
self.graph = None
else:
logger.info("Total ontology files loaded: %d", len(loaded_files))
else:
logger.info(
"Ontology file '%s' not found. No owl ontology will be attached to the graph.",
ontology_file,
"No ontology file provided. No owl ontology will be attached to the graph."
)
self.graph = None
self.build_lookup()
except Exception as e:
logger.error("Failed to load ontology", exc_info=e)

View file

@ -2,6 +2,7 @@ import inspect
from cognee.shared.logging_utils import get_logger
from cognee.modules.users.models import User
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
from ..tasks.task import Task
@ -25,6 +26,8 @@ async def handle_task(
user_id=user.id,
additional_properties={
"task_name": running_task.executable.__name__,
"cognee_version": cognee_version,
"tenant_id": str(user.tenant_id) if user.tenant_id else "Single User Tenant",
},
)
@ -46,6 +49,8 @@ async def handle_task(
user_id=user.id,
additional_properties={
"task_name": running_task.executable.__name__,
"cognee_version": cognee_version,
"tenant_id": str(user.tenant_id) if user.tenant_id else "Single User Tenant",
},
)
except Exception as error:
@ -58,6 +63,8 @@ async def handle_task(
user_id=user.id,
additional_properties={
"task_name": running_task.executable.__name__,
"cognee_version": cognee_version,
"tenant_id": str(user.tenant_id) if user.tenant_id else "Single User Tenant",
},
)
raise error

View file

@ -4,6 +4,7 @@ from cognee.modules.settings import get_current_settings
from cognee.modules.users.models import User
from cognee.shared.logging_utils import get_logger
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
from .run_tasks_base import run_tasks_base
from ..tasks.task import Task
@ -26,6 +27,8 @@ async def run_tasks_with_telemetry(
user.id,
additional_properties={
"pipeline_name": str(pipeline_name),
"cognee_version": cognee_version,
"tenant_id": str(user.tenant_id) if user.tenant_id else "Single User Tenant",
}
| config,
)
@ -39,7 +42,10 @@ async def run_tasks_with_telemetry(
user.id,
additional_properties={
"pipeline_name": str(pipeline_name),
},
"cognee_version": cognee_version,
"tenant_id": str(user.tenant_id) if user.tenant_id else "Single User Tenant",
}
| config,
)
except Exception as error:
logger.error(
@ -53,6 +59,8 @@ async def run_tasks_with_telemetry(
user.id,
additional_properties={
"pipeline_name": str(pipeline_name),
"cognee_version": cognee_version,
"tenant_id": str(user.tenant_id) if user.tenant_id else "Single User Tenant",
}
| config,
)

View file

@ -1,10 +1,15 @@
import asyncio
import json
from typing import Optional, List, Type, Any
from pydantic import BaseModel
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
from cognee.modules.retrieval.utils.completion import (
generate_structured_completion,
summarize_text,
)
from cognee.modules.retrieval.utils.session_cache import (
save_conversation_history,
get_conversation_history,
@ -17,6 +22,20 @@ from cognee.infrastructure.databases.cache.config import CacheConfig
logger = get_logger()
def _as_answer_text(completion: Any) -> str:
"""Convert completion to human-readable text for validation and follow-up prompts."""
if isinstance(completion, str):
return completion
if isinstance(completion, BaseModel):
# Add notice that this is a structured response
json_str = completion.model_dump_json(indent=2)
return f"[Structured Response]\n{json_str}"
try:
return json.dumps(completion, indent=2)
except TypeError:
return str(completion)
class GraphCompletionCotRetriever(GraphCompletionRetriever):
"""
Handles graph completion by generating responses based on a series of interactions with
@ -25,6 +44,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
questions based on reasoning. The public methods are:
- get_completion
- get_structured_completion
Instance variables include:
- validation_system_prompt_path
@ -61,6 +81,155 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
self.followup_system_prompt_path = followup_system_prompt_path
self.followup_user_prompt_path = followup_user_prompt_path
async def _run_cot_completion(
    self,
    query: str,
    context: Optional[List[Edge]] = None,
    conversation_history: str = "",
    max_iter: int = 4,
    response_model: Type = str,
) -> tuple[Any, str, List[Edge]]:
    """
    Run chain-of-thought completion with optional structured output.

    Performs up to ``max_iter + 1`` answer rounds. Each round generates a
    completion from the accumulated graph context; on every round except the
    last, the answer is critiqued by the LLM and turned into a follow-up
    question whose retrieved triplets are merged into the next round's context.

    Parameters:
    -----------

        - query: User query
        - context: Optional pre-fetched context edges; if None, context is
          retrieved from the graph based on the query
        - conversation_history: Optional conversation history string
        - max_iter: Maximum CoT iterations
        - response_model: Type for structured output (str for plain text)

    Returns:
    --------

        - completion_result: The generated completion (string or structured model)
        - context_text: The resolved context text
        - triplets: The list of triplets used
    """
    followup_question = ""
    triplets = []
    completion = ""
    for round_idx in range(max_iter + 1):
        if round_idx == 0:
            # First round: use caller-provided edges if given, otherwise
            # retrieve triplets for the original query.
            if context is None:
                triplets = await self.get_context(query)
                context_text = await self.resolve_edges_to_text(triplets)
            else:
                context_text = await self.resolve_edges_to_text(context)
        else:
            # Later rounds: fetch additional triplets for the follow-up
            # question and merge them (deduplicated via set) with the
            # triplets already gathered.
            triplets += await self.get_context(followup_question)
            context_text = await self.resolve_edges_to_text(list(set(triplets)))

        completion = await generate_structured_completion(
            query=query,
            context=context_text,
            user_prompt_path=self.user_prompt_path,
            system_prompt_path=self.system_prompt_path,
            system_prompt=self.system_prompt,
            conversation_history=conversation_history if conversation_history else None,
            response_model=response_model,
        )
        logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")

        # Skip critique/follow-up on the final round: that answer is returned.
        if round_idx < max_iter:
            # Ask the LLM to reason about the current answer vs. the context.
            answer_text = _as_answer_text(completion)
            valid_args = {"query": query, "answer": answer_text, "context": context_text}
            valid_user_prompt = render_prompt(
                filename=self.validation_user_prompt_path, context=valid_args
            )
            valid_system_prompt = read_query_prompt(
                prompt_file_name=self.validation_system_prompt_path
            )
            reasoning = await LLMGateway.acreate_structured_output(
                text_input=valid_user_prompt,
                system_prompt=valid_system_prompt,
                response_model=str,
            )
            # Convert that critique into a follow-up question that drives the
            # next round's graph retrieval.
            followup_args = {"query": query, "answer": answer_text, "reasoning": reasoning}
            followup_prompt = render_prompt(
                filename=self.followup_user_prompt_path, context=followup_args
            )
            followup_system = read_query_prompt(
                prompt_file_name=self.followup_system_prompt_path
            )
            followup_question = await LLMGateway.acreate_structured_output(
                text_input=followup_prompt, system_prompt=followup_system, response_model=str
            )
            logger.info(
                f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
            )
    return completion, context_text, triplets
async def get_structured_completion(
    self,
    query: str,
    context: Optional[List[Edge]] = None,
    session_id: Optional[str] = None,
    max_iter: int = 4,
    response_model: Type = str,
) -> Any:
    """Answer *query* via chain-of-thought refinement and return structured output.

    Applies the same iterative refine/follow-up logic as get_completion, but the
    final answer is shaped by *response_model* instead of plain text.

    Parameters:
    -----------
        - query (str): The user's query to be processed and answered.
        - context (Optional[List[Edge]]): Pre-fetched graph edges to ground the
          answer; fetched from the graph when omitted. (default None)
        - session_id (Optional[str]): Session identifier used for conversation
          caching. (default None)
        - max_iter: Upper bound on answer-refinement / follow-up rounds. (default 4)
        - response_model (Type): Pydantic model type for the structured output.
          (default str)

    Returns:
    --------
        - Any: The generated structured completion based on the response model.
    """
    # Session persistence is active only when we know the user AND caching is on.
    cache_config = CacheConfig()
    current_user = session_user.get()
    caching_enabled = getattr(current_user, "id", None) and cache_config.caching

    prior_turns = ""
    if caching_enabled:
        prior_turns = await get_conversation_history(session_id=session_id)

    completion, context_text, triplets = await self._run_cot_completion(
        query=query,
        context=context,
        conversation_history=prior_turns,
        max_iter=max_iter,
        response_model=response_model,
    )

    # NOTE(review): persistence requires caller-supplied context; internally
    # fetched triplets alone never trigger save_qa — confirm this is intentional.
    if self.save_interaction and context and triplets and completion:
        await self.save_qa(
            question=query, answer=str(completion), context=context_text, triplets=triplets
        )

    if caching_enabled:
        context_summary = await summarize_text(context_text)
        await save_conversation_history(
            query=query,
            context_summary=context_summary,
            answer=str(completion),
            session_id=session_id,
        )

    return completion
async def get_completion(
self,
query: str,
@ -92,82 +261,12 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
- List[str]: A list containing the generated answer to the user's query.
"""
followup_question = ""
triplets = []
completion = ""
# Retrieve conversation history if session saving is enabled
cache_config = CacheConfig()
user = session_user.get()
user_id = getattr(user, "id", None)
session_save = user_id and cache_config.caching
conversation_history = ""
if session_save:
conversation_history = await get_conversation_history(session_id=session_id)
for round_idx in range(max_iter + 1):
if round_idx == 0:
if context is None:
triplets = await self.get_context(query)
context_text = await self.resolve_edges_to_text(triplets)
else:
context_text = await self.resolve_edges_to_text(context)
else:
triplets += await self.get_context(followup_question)
context_text = await self.resolve_edges_to_text(list(set(triplets)))
completion = await generate_completion(
query=query,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
conversation_history=conversation_history if session_save else None,
)
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
if round_idx < max_iter:
valid_args = {"query": query, "answer": completion, "context": context_text}
valid_user_prompt = render_prompt(
filename=self.validation_user_prompt_path, context=valid_args
)
valid_system_prompt = read_query_prompt(
prompt_file_name=self.validation_system_prompt_path
)
reasoning = await LLMGateway.acreate_structured_output(
text_input=valid_user_prompt,
system_prompt=valid_system_prompt,
response_model=str,
)
followup_args = {"query": query, "answer": completion, "reasoning": reasoning}
followup_prompt = render_prompt(
filename=self.followup_user_prompt_path, context=followup_args
)
followup_system = read_query_prompt(
prompt_file_name=self.followup_system_prompt_path
)
followup_question = await LLMGateway.acreate_structured_output(
text_input=followup_prompt, system_prompt=followup_system, response_model=str
)
logger.info(
f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
)
if self.save_interaction and context and triplets and completion:
await self.save_qa(
question=query, answer=completion, context=context_text, triplets=triplets
)
# Save to session cache
if session_save:
context_summary = await summarize_text(context_text)
await save_conversation_history(
query=query,
context_summary=context_summary,
answer=completion,
session_id=session_id,
)
completion = await self.get_structured_completion(
query=query,
context=context,
session_id=session_id,
max_iter=max_iter,
response_model=str,
)
return [completion]

View file

@ -1,7 +1,7 @@
import os
import asyncio
from typing import Any, Optional, List, Type
from datetime import datetime
from operator import itemgetter
from cognee.infrastructure.databases.vector import get_vector_engine
@ -79,7 +79,11 @@ class TemporalRetriever(GraphCompletionRetriever):
else:
base_directory = None
system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory)
time_now = datetime.now().strftime("%d-%m-%Y")
system_prompt = render_prompt(
prompt_path, {"time_now": time_now}, base_directory=base_directory
)
interval = await LLMGateway.acreate_structured_output(query, system_prompt, QueryInterval)
@ -108,8 +112,6 @@ class TemporalRetriever(GraphCompletionRetriever):
graph_engine = await get_graph_engine()
triplets = []
if time_from and time_to:
ids = await graph_engine.collect_time_ids(time_from=time_from, time_to=time_to)
elif time_from:

View file

@ -1,17 +1,18 @@
from typing import Optional
from typing import Optional, Type, Any
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
async def generate_completion(
async def generate_structured_completion(
query: str,
context: str,
user_prompt_path: str,
system_prompt_path: str,
system_prompt: Optional[str] = None,
conversation_history: Optional[str] = None,
) -> str:
"""Generates a completion using LLM with given context and prompts."""
response_model: Type = str,
) -> Any:
"""Generates a structured completion using LLM with given context and prompts."""
args = {"question": query, "context": context}
user_prompt = render_prompt(user_prompt_path, args)
system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path)
@ -23,6 +24,26 @@ async def generate_completion(
return await LLMGateway.acreate_structured_output(
text_input=user_prompt,
system_prompt=system_prompt,
response_model=response_model,
)
async def generate_completion(
    query: str,
    context: str,
    user_prompt_path: str,
    system_prompt_path: str,
    system_prompt: Optional[str] = None,
    conversation_history: Optional[str] = None,
) -> str:
    """Generate a plain-text LLM completion for *query* grounded in *context*.

    Thin convenience wrapper around generate_structured_completion that pins
    the response model to ``str``.
    """
    request = {
        "query": query,
        "context": context,
        "user_prompt_path": user_prompt_path,
        "system_prompt_path": system_prompt_path,
        "system_prompt": system_prompt,
        "conversation_history": conversation_history,
        # Plain-string output distinguishes this wrapper from the structured variant.
        "response_model": str,
    }
    return await generate_structured_completion(**request)

View file

@ -24,7 +24,7 @@ from cognee.modules.data.models import Dataset
from cognee.modules.data.methods.get_authorized_existing_datasets import (
get_authorized_existing_datasets,
)
from cognee import __version__ as cognee_version
from .get_search_type_tools import get_search_type_tools
from .no_access_control_search import no_access_control_search
from ..utils.prepare_search_result import prepare_search_result
@ -64,7 +64,14 @@ async def search(
Searching by dataset is only available in ENABLE_BACKEND_ACCESS_CONTROL mode
"""
query = await log_query(query_text, query_type.value, user.id)
send_telemetry("cognee.search EXECUTION STARTED", user.id)
send_telemetry(
"cognee.search EXECUTION STARTED",
user.id,
additional_properties={
"cognee_version": cognee_version,
"tenant_id": str(user.tenant_id) if user.tenant_id else "Single User Tenant",
},
)
# Use search function filtered by permissions if access control is enabled
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
@ -101,7 +108,14 @@ async def search(
)
]
send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
send_telemetry(
"cognee.search EXECUTION COMPLETED",
user.id,
additional_properties={
"cognee_version": cognee_version,
"tenant_id": str(user.tenant_id) if user.tenant_id else "Single User Tenant",
},
)
await log_result(
query.id,

View file

@ -16,17 +16,17 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
nodes_list = []
color_map = {
"Entity": "#f47710",
"EntityType": "#6510f4",
"DocumentChunk": "#801212",
"TextSummary": "#1077f4",
"TableRow": "#f47710",
"TableType": "#6510f4",
"ColumnValue": "#13613a",
"SchemaTable": "#f47710",
"DatabaseSchema": "#6510f4",
"SchemaRelationship": "#13613a",
"default": "#D3D3D3",
"Entity": "#5C10F4",
"EntityType": "#A550FF",
"DocumentChunk": "#0DFF00",
"TextSummary": "#5C10F4",
"TableRow": "#A550FF",
"TableType": "#5C10F4",
"ColumnValue": "#757470",
"SchemaTable": "#A550FF",
"DatabaseSchema": "#5C10F4",
"SchemaRelationship": "#323332",
"default": "#D8D8D8",
}
for node_id, node_info in nodes_data:
@ -98,16 +98,19 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
<head>
<meta charset="utf-8">
<script src="https://d3js.org/d3.v5.min.js"></script>
<script src="https://d3js.org/d3-contour.v1.min.js"></script>
<style>
body, html { margin: 0; padding: 0; width: 100%; height: 100%; overflow: hidden; background: linear-gradient(90deg, #101010, #1a1a2e); color: white; font-family: 'Inter', sans-serif; }
svg { width: 100vw; height: 100vh; display: block; }
.links line { stroke: rgba(255, 255, 255, 0.4); stroke-width: 2px; }
.links line.weighted { stroke: rgba(255, 215, 0, 0.7); }
.links line.multi-weighted { stroke: rgba(0, 255, 127, 0.8); }
.nodes circle { stroke: white; stroke-width: 0.5px; filter: drop-shadow(0 0 5px rgba(255,255,255,0.3)); }
.node-label { font-size: 5px; font-weight: bold; fill: white; text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', sans-serif; pointer-events: none; }
.edge-label { font-size: 3px; fill: rgba(255, 255, 255, 0.7); text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', sans-serif; pointer-events: none; }
.links line { stroke: rgba(160, 160, 160, 0.25); stroke-width: 1.5px; stroke-linecap: round; }
.links line.weighted { stroke: rgba(255, 215, 0, 0.4); }
.links line.multi-weighted { stroke: rgba(0, 255, 127, 0.45); }
.nodes circle { stroke: white; stroke-width: 0.5px; }
.node-label { font-size: 5px; font-weight: bold; fill: #F4F4F4; text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', sans-serif; pointer-events: none; }
.edge-label { font-size: 3px; fill: #F4F4F4; text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', sans-serif; pointer-events: none; paint-order: stroke; stroke: rgba(50,51,50,0.75); stroke-width: 1px; }
.density path { mix-blend-mode: screen; }
.tooltip {
position: absolute;
@ -125,11 +128,32 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
max-width: 300px;
word-wrap: break-word;
}
#info-panel {
position: fixed;
left: 12px;
top: 12px;
width: 340px;
max-height: calc(100vh - 24px);
overflow: auto;
background: rgba(50, 51, 50, 0.7);
backdrop-filter: blur(6px);
border: 1px solid rgba(216, 216, 216, 0.35);
border-radius: 8px;
color: #F4F4F4;
padding: 12px 14px;
z-index: 1100;
}
#info-panel h3 { margin: 0 0 8px 0; font-size: 14px; color: #F4F4F4; }
#info-panel .kv { font-size: 12px; line-height: 1.4; }
#info-panel .kv .k { color: #D8D8D8; }
#info-panel .kv .v { color: #F4F4F4; }
#info-panel .placeholder { opacity: 0.7; font-size: 12px; }
</style>
</head>
<body>
<svg></svg>
<div class="tooltip" id="tooltip"></div>
<div id="info-panel"><div class="placeholder">Hover a node or edge to inspect details</div></div>
<script>
var nodes = {nodes};
var links = {links};
@ -140,19 +164,141 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
var container = svg.append("g");
var tooltip = d3.select("#tooltip");
var infoPanel = d3.select('#info-panel');
function renderInfo(title, entries){
function esc(s){ return String(s).replace(/&/g,'&amp;').replace(/</g,'&lt;').replace(/>/g,'&gt;'); }
var html = '<h3>' + esc(title) + '</h3>';
html += '<div class="kv">';
entries.forEach(function(e){
html += '<div><span class="k">' + esc(e.k) + ':</span> <span class="v">' + esc(e.v) + '</span></div>';
});
html += '</div>';
infoPanel.html(html);
}
function pickDescription(obj){
if (!obj) return null;
var keys = ['description','summary','text','content'];
for (var i=0; i<keys.length; i++){
var v = obj[keys[i]];
if (typeof v === 'string' && v.trim()) return v.trim();
}
return null;
}
function truncate(s, n){ if (!s) return s; return s.length > n ? (s.slice(0, n) + '…') : s; }
function renderNodeInfo(n){
var entries = [];
if (n.name) entries.push({k:'Name', v: n.name});
if (n.type) entries.push({k:'Type', v: n.type});
if (n.id) entries.push({k:'ID', v: n.id});
var desc = pickDescription(n) || pickDescription(n.properties);
if (desc) entries.push({k:'Description', v: truncate(desc.replace(/\s+/g,' ').trim(), 280)});
if (n.properties) {
Object.keys(n.properties).slice(0, 12).forEach(function(key){
var v = n.properties[key];
if (v !== undefined && v !== null && typeof v !== 'object') entries.push({k: key, v: String(v)});
});
}
renderInfo(n.name || 'Node', entries);
}
function renderEdgeInfo(e){
var entries = [];
if (e.relation) entries.push({k:'Relation', v: e.relation});
if (e.weight !== undefined && e.weight !== null) entries.push({k:'Weight', v: e.weight});
if (e.all_weights && Object.keys(e.all_weights).length){
Object.keys(e.all_weights).slice(0, 8).forEach(function(k){ entries.push({k: 'w.'+k, v: e.all_weights[k]}); });
}
if (e.relationship_type) entries.push({k:'Type', v: e.relationship_type});
var edesc = pickDescription(e.edge_info);
if (edesc) entries.push({k:'Description', v: truncate(edesc.replace(/\s+/g,' ').trim(), 280)});
renderInfo('Edge', entries);
}
// Basic runtime diagnostics
console.log('[Cognee Visualization] nodes:', nodes ? nodes.length : 0, 'links:', links ? links.length : 0);
window.addEventListener('error', function(e){
try {
tooltip.html('<strong>Error:</strong> ' + e.message)
.style('left', '12px')
.style('top', '12px')
.style('opacity', 1);
} catch(_) {}
});
// Normalize node IDs and link endpoints for robustness
function resolveId(d){ return (d && (d.id || d.node_id || d.uuid || d.external_id || d.name)) || undefined; }
if (Array.isArray(nodes)) {
nodes.forEach(function(n){ var id = resolveId(n); if (id !== undefined) n.id = id; });
}
if (Array.isArray(links)) {
links.forEach(function(l){
if (typeof l.source === 'object') l.source = resolveId(l.source);
if (typeof l.target === 'object') l.target = resolveId(l.target);
});
}
if (!nodes || nodes.length === 0) {
container.append('text')
.attr('x', width / 2)
.attr('y', height / 2)
.attr('fill', '#fff')
.attr('font-size', 14)
.attr('text-anchor', 'middle')
.text('No graph data available');
}
// Visual defs - reusable glow
var defs = svg.append("defs");
var glow = defs.append("filter").attr("id", "glow")
.attr("x", "-30%")
.attr("y", "-30%")
.attr("width", "160%")
.attr("height", "160%");
glow.append("feGaussianBlur").attr("stdDeviation", 8).attr("result", "coloredBlur");
var feMerge = glow.append("feMerge");
feMerge.append("feMergeNode").attr("in", "coloredBlur");
feMerge.append("feMergeNode").attr("in", "SourceGraphic");
// Stronger glow for hovered adjacency
var glowStrong = defs.append("filter").attr("id", "glow-strong")
.attr("x", "-40%")
.attr("y", "-40%")
.attr("width", "180%")
.attr("height", "180%");
glowStrong.append("feGaussianBlur").attr("stdDeviation", 14).attr("result", "coloredBlur");
var feMerge2 = glowStrong.append("feMerge");
feMerge2.append("feMergeNode").attr("in", "coloredBlur");
feMerge2.append("feMergeNode").attr("in", "SourceGraphic");
var currentTransform = d3.zoomIdentity;
var densityZoomTimer = null;
var isInteracting = false;
var labelBaseSize = 10;
function getGroupKey(d){ return d && (d.type || d.category || d.group || d.color) || 'default'; }
var simulation = d3.forceSimulation(nodes)
.force("link", d3.forceLink(links).id(d => d.id).strength(0.1))
.force("charge", d3.forceManyBody().strength(-275))
.force("link", d3.forceLink(links).id(function(d){ return d.id; }).distance(100).strength(0.2))
.force("charge", d3.forceManyBody().strength(-180))
.force("collide", d3.forceCollide().radius(16).iterations(2))
.force("center", d3.forceCenter(width / 2, height / 2))
.force("x", d3.forceX().strength(0.1).x(width / 2))
.force("y", d3.forceY().strength(0.1).y(height / 2));
.force("x", d3.forceX().strength(0.06).x(width / 2))
.force("y", d3.forceY().strength(0.06).y(height / 2))
.alphaDecay(0.06)
.velocityDecay(0.6);
// Density layer (sibling of container to avoid double transforms)
var densityLayer = svg.append("g")
.attr("class", "density")
.style("pointer-events", "none");
if (densityLayer.lower) densityLayer.lower();
var link = container.append("g")
.attr("class", "links")
.selectAll("line")
.data(links)
.enter().append("line")
.style("opacity", 0)
.style("pointer-events", "none")
.attr("stroke-width", d => {
if (d.weight) return Math.max(2, d.weight * 5);
if (d.all_weights && Object.keys(d.all_weights).length > 0) {
@ -168,6 +314,7 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
})
.on("mouseover", function(d) {
// Create tooltip content for edge
renderEdgeInfo(d);
var content = "<strong>Edge Information</strong><br/>";
content += "Relationship: " + d.relation + "<br/>";
@ -212,6 +359,7 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
.data(links)
.enter().append("text")
.attr("class", "edge-label")
.style("opacity", 0)
.text(d => {
var label = d.relation;
if (d.all_weights && Object.keys(d.all_weights).length > 1) {
@ -232,21 +380,225 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
.data(nodes)
.enter().append("g");
// Color fallback by type when d.color is missing
var colorByType = {
"Entity": "#5C10F4",
"EntityType": "#A550FF",
"DocumentChunk": "#0DFF00",
"TextSummary": "#5C10F4",
"TableRow": "#A550FF",
"TableType": "#5C10F4",
"ColumnValue": "#757470",
"SchemaTable": "#A550FF",
"DatabaseSchema": "#5C10F4",
"SchemaRelationship": "#323332"
};
var node = nodeGroup.append("circle")
.attr("r", 13)
.attr("fill", d => d.color)
.attr("fill", function(d){ return d.color || colorByType[d.type] || "#D3D3D3"; })
.style("filter", "url(#glow)")
.attr("shape-rendering", "geometricPrecision")
.call(d3.drag()
.on("start", dragstarted)
.on("drag", dragged)
.on("drag", function(d){ dragged(d); updateDensity(); showAdjacency(d); })
.on("end", dragended));
nodeGroup.append("text")
// Show links only for hovered node adjacency
function isAdjacent(linkDatum, nodeId) {
var sid = linkDatum && linkDatum.source && (linkDatum.source.id || linkDatum.source);
var tid = linkDatum && linkDatum.target && (linkDatum.target.id || linkDatum.target);
return sid === nodeId || tid === nodeId;
}
function showAdjacency(d) {
var nodeId = d && (d.id || d.node_id || d.uuid || d.external_id || d.name);
if (!nodeId) return;
// Build neighbor set
var neighborIds = {};
neighborIds[nodeId] = true;
for (var i = 0; i < links.length; i++) {
var l = links[i];
var sid = l && l.source && (l.source.id || l.source);
var tid = l && l.target && (l.target.id || l.target);
if (sid === nodeId) neighborIds[tid] = true;
if (tid === nodeId) neighborIds[sid] = true;
}
link
.style("opacity", function(l){ return isAdjacent(l, nodeId) ? 0.95 : 0; })
.style("stroke", function(l){ return isAdjacent(l, nodeId) ? "rgba(255,255,255,0.95)" : null; })
.style("stroke-width", function(l){ return isAdjacent(l, nodeId) ? 2.5 : 1.5; });
edgeLabels.style("opacity", function(l){ return isAdjacent(l, nodeId) ? 1 : 0; });
densityLayer.style("opacity", 0.35);
// Highlight neighbor nodes and dim others
node
.style("opacity", function(n){ return neighborIds[n.id] ? 1 : 0.25; })
.style("filter", function(n){ return neighborIds[n.id] ? "url(#glow-strong)" : "url(#glow)"; })
.attr("r", function(n){ return neighborIds[n.id] ? 15 : 13; });
// Raise highlighted nodes
node.filter(function(n){ return neighborIds[n.id]; }).raise();
// Neighbor labels brighter
nodeGroup.select("text")
.style("opacity", function(n){ return neighborIds[n.id] ? 1 : 0.2; })
.style("font-size", function(n){
var size = neighborIds[n.id] ? Math.min(22, labelBaseSize * 1.25) : labelBaseSize;
return size + "px";
});
}
function clearAdjacency() {
link.style("opacity", 0)
.style("stroke", null)
.style("stroke-width", 1.5);
edgeLabels.style("opacity", 0);
densityLayer.style("opacity", 1);
node
.style("opacity", 1)
.style("filter", "url(#glow)")
.attr("r", 13);
nodeGroup.select("text")
.style("opacity", 1)
.style("font-size", labelBaseSize + "px");
}
node.on("mouseover", function(d){ showAdjacency(d); })
.on("mouseout", function(){ clearAdjacency(); });
node.on("mouseover", function(d){ renderNodeInfo(d); tooltip.style('opacity', 0); });
// Also bind on the group so labels trigger adjacency too
nodeGroup.on("mouseover", function(d){ showAdjacency(d); })
.on("mouseout", function(){ clearAdjacency(); });
// Density always on; no hover gating
// Add labels sparsely to reduce clutter (every ~14th node), and truncate long text
nodeGroup
.filter(function(d, i){ return i % 14 === 0; })
.append("text")
.attr("class", "node-label")
.attr("dy", 4)
.attr("text-anchor", "middle")
.text(d => d.name);
.text(function(d){
var s = d && d.name ? String(d.name) : '';
return s.length > 40 ? (s.slice(0, 40) + "…") : s;
})
.style("font-size", labelBaseSize + "px");
node.append("title").text(d => JSON.stringify(d));
function applyLabelSize() {
var k = (currentTransform && currentTransform.k) || 1;
// Keep labels readable across zoom levels and hide when too small
labelBaseSize = Math.max(7, Math.min(18, 10 / Math.sqrt(k)));
nodeGroup.select("text")
.style("font-size", labelBaseSize + "px")
.style("display", (k < 0.35 ? "none" : null));
}
// Density cloud computation (throttled)
var densityTick = 0;
var geoPath = d3.geoPath().projection(null);
var MAX_POINTS_PER_GROUP = 400;
function updateDensity() {
try {
if (isInteracting) return; // skip during interaction for smoother UX
if (typeof d3 === 'undefined' || typeof d3.contourDensity !== 'function') {
return; // d3-contour not available; skip gracefully
}
if (!nodes || nodes.length === 0) return;
var usable = nodes.filter(function(d){ return d && typeof d.x === 'number' && isFinite(d.x) && typeof d.y === 'number' && isFinite(d.y); });
if (usable.length < 3) return; // not enough positioned points yet
var t = currentTransform || d3.zoomIdentity;
if (t.k && t.k < 0.08) {
// Skip density at extreme zoom-out to avoid numerical instability/perf issues
densityLayer.selectAll('*').remove();
return;
}
function hexToRgb(hex){
if (!hex) return {r: 0, g: 200, b: 255};
var c = hex.replace('#','');
if (c.length === 3) c = c.split('').map(function(x){ return x+x; }).join('');
var num = parseInt(c, 16);
return { r: (num >> 16) & 255, g: (num >> 8) & 255, b: num & 255 };
}
// Build groups across all nodes
var groups = {};
for (var i = 0; i < usable.length; i++) {
var k = getGroupKey(usable[i]);
if (!groups[k]) groups[k] = [];
groups[k].push(usable[i]);
}
densityLayer.selectAll('*').remove();
Object.keys(groups).forEach(function(key){
var arr = groups[key];
if (!arr || arr.length < 3) return;
// Transform positions into screen space and sample to cap cost
var arrT = [];
var step = Math.max(1, Math.floor(arr.length / MAX_POINTS_PER_GROUP));
for (var j = 0; j < arr.length; j += step) {
var nx = t.applyX(arr[j].x);
var ny = t.applyY(arr[j].y);
if (isFinite(nx) && isFinite(ny)) {
arrT.push({ x: nx, y: ny, type: arr[j].type, color: arr[j].color });
}
}
if (arrT.length < 3) return;
// Compute adaptive bandwidth based on group spread
var cx = 0, cy = 0;
for (var k = 0; k < arrT.length; k++){ cx += arrT[k].x; cy += arrT[k].y; }
cx /= arrT.length; cy /= arrT.length;
var sumR = 0;
for (var k2 = 0; k2 < arrT.length; k2++){
var dx = arrT[k2].x - cx, dy = arrT[k2].y - cy;
sumR += Math.sqrt(dx*dx + dy*dy);
}
var avgR = sumR / arrT.length;
var dynamicBandwidth = Math.max(12, Math.min(80, avgR));
var densityBandwidth = dynamicBandwidth / (t.k || 1);
var contours = d3.contourDensity()
.x(function(d){ return d.x; })
.y(function(d){ return d.y; })
.size([width, height])
.bandwidth(densityBandwidth)
.thresholds(8)
(arrT);
if (!contours || contours.length === 0) return;
var maxVal = d3.max(contours, function(d){ return d.value; }) || 1;
// Use the first node color in the group or fallback neon palette
var baseColor = (arr.find(function(d){ return d && d.color; }) || {}).color || '#00c8ff';
var rgb = hexToRgb(baseColor);
var g = densityLayer.append('g').attr('data-group', key);
g.selectAll('path')
.data(contours)
.enter()
.append('path')
.attr('d', geoPath)
.attr('fill', 'rgb(' + rgb.r + ',' + rgb.g + ',' + rgb.b + ')')
.attr('stroke', 'none')
.style('opacity', function(d){
var v = maxVal ? (d.value / maxVal) : 0;
var alpha = Math.pow(Math.max(0, Math.min(1, v)), 1.6); // accentuate clusters
return 0.65 * alpha; // up to 0.65 opacity at peak density
})
.style('filter', 'blur(2px)');
});
} catch (e) {
// Reduce impact of any runtime errors during zoom
console.warn('Density update failed:', e);
}
}
simulation.on("tick", function() {
link.attr("x1", d => d.source.x)
@ -266,16 +618,29 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
.attr("y", d => d.y)
.attr("dy", 4)
.attr("text-anchor", "middle");
densityTick += 1;
if (densityTick % 24 === 0) updateDensity();
});
svg.call(d3.zoom().on("zoom", function() {
container.attr("transform", d3.event.transform);
}));
var zoomBehavior = d3.zoom()
.on("start", function(){ isInteracting = true; densityLayer.style("opacity", 0.2); })
.on("zoom", function(){
currentTransform = d3.event.transform;
container.attr("transform", currentTransform);
})
.on("end", function(){
if (densityZoomTimer) clearTimeout(densityZoomTimer);
densityZoomTimer = setTimeout(function(){ isInteracting = false; densityLayer.style("opacity", 1); updateDensity(); }, 140);
});
svg.call(zoomBehavior);
function dragstarted(d) {
if (!d3.event.active) simulation.alphaTarget(0.3).restart();
d.fx = d.x;
d.fy = d.y;
isInteracting = true;
densityLayer.style("opacity", 0.2);
}
function dragged(d) {
@ -287,6 +652,8 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
if (!d3.event.active) simulation.alphaTarget(0);
d.fx = null;
d.fy = null;
if (densityZoomTimer) clearTimeout(densityZoomTimer);
densityZoomTimer = setTimeout(function(){ isInteracting = false; densityLayer.style("opacity", 1); updateDensity(); }, 140);
}
window.addEventListener("resize", function() {
@ -295,7 +662,13 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
svg.attr("width", width).attr("height", height);
simulation.force("center", d3.forceCenter(width / 2, height / 2));
simulation.alpha(1).restart();
updateDensity();
applyLabelSize();
});
// Initial density draw
updateDensity();
applyLabelSize();
</script>
<svg style="position: fixed; bottom: 10px; right: 10px; width: 150px; height: auto; z-index: 9999;" viewBox="0 0 158 44" fill="none" xmlns="http://www.w3.org/2000/svg">
@ -305,8 +678,12 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
</html>
"""
html_content = html_template.replace("{nodes}", json.dumps(nodes_list))
html_content = html_content.replace("{links}", json.dumps(links_list))
# Safely embed JSON inside <script> by escaping </ to avoid prematurely closing the tag
def _safe_json_embed(obj):
return json.dumps(obj).replace("</", "<\\/")
html_content = html_template.replace("{nodes}", _safe_json_embed(nodes_list))
html_content = html_content.replace("{links}", _safe_json_embed(links_list))
if not destination_file_path:
home_dir = os.path.expanduser("~")

View file

@ -1,6 +1,7 @@
import os
import sys
import logging
import tempfile
import structlog
import traceback
import platform
@ -76,9 +77,38 @@ log_levels = {
# Track if structlog logging has been configured
_is_structlog_configured = False
# Path to logs directory
LOGS_DIR = Path(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "logs"))
LOGS_DIR.mkdir(exist_ok=True) # Create logs dir if it doesn't exist
def resolve_logs_dir():
    """Resolve a writable logs directory, best effort.

    Priority:
      1) BaseConfig.logs_root_directory (respects COGNEE_LOGS_DIR)
      2) <system temp dir>/cognee_logs (default, best-effort create)

    Returns:
        A Path to a writable directory, or None if none are writable/creatable.
    """
    # Candidate 1: configured location. The import/config access lives inside
    # the try block so a broken or unavailable config degrades to the temp-dir
    # fallback instead of raising (this function is documented as best-effort).
    try:
        from cognee.base_config import get_base_config

        logs_root_directory = Path(get_base_config().logs_root_directory)
        logs_root_directory.mkdir(parents=True, exist_ok=True)
        if os.access(logs_root_directory, os.W_OK):
            return logs_root_directory
    except Exception:
        pass

    # Candidate 2: platform temp directory. tempfile.gettempdir() is portable,
    # unlike a hardcoded "/tmp" (which does not exist on Windows).
    try:
        tmp_log_path = Path(tempfile.gettempdir()) / "cognee_logs"
        tmp_log_path.mkdir(parents=True, exist_ok=True)
        if os.access(tmp_log_path, os.W_OK):
            return tmp_log_path
    except Exception:
        pass

    return None
# Maximum number of log files to keep
MAX_LOG_FILES = 10
@ -430,28 +460,38 @@ def setup_logging(log_level=None, name=None):
stream_handler.setFormatter(console_formatter)
stream_handler.setLevel(log_level)
root_logger = logging.getLogger()
if root_logger.hasHandlers():
root_logger.handlers.clear()
root_logger.addHandler(stream_handler)
# Note: root logger needs to be set at NOTSET to allow all messages through and specific stream and file handlers
# can define their own levels.
root_logger.setLevel(logging.NOTSET)
# Resolve logs directory with env and safe fallbacks
logs_dir = resolve_logs_dir()
# Check if we already have a log file path from the environment
# NOTE: environment variable must be used here as it allows us to
# log to a single file with a name based on a timestamp in a multiprocess setting.
# Without it, we would have a separate log file for every process.
log_file_path = os.environ.get("LOG_FILE_NAME")
if not log_file_path:
if not log_file_path and logs_dir is not None:
# Create a new log file name with the cognee start time
start_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
log_file_path = os.path.join(LOGS_DIR, f"{start_time}.log")
log_file_path = str((logs_dir / f"{start_time}.log").resolve())
os.environ["LOG_FILE_NAME"] = log_file_path
# Create a file handler that uses our custom PlainFileHandler
file_handler = PlainFileHandler(log_file_path, encoding="utf-8")
file_handler.setLevel(DEBUG)
# Configure root logger
root_logger = logging.getLogger()
if root_logger.hasHandlers():
root_logger.handlers.clear()
root_logger.addHandler(stream_handler)
root_logger.addHandler(file_handler)
root_logger.setLevel(log_level)
try:
# Create a file handler that uses our custom PlainFileHandler
file_handler = PlainFileHandler(log_file_path, encoding="utf-8")
file_handler.setLevel(DEBUG)
root_logger.addHandler(file_handler)
except Exception as e:
# Note: Exceptions happen in case of read-only file systems or a log file path pointing to a location where
# it does not have write permission. Logging to file is not mandatory, so we just log a warning to console.
root_logger.warning(f"Warning: Could not create log file handler at {log_file_path}: {e}")
if log_level > logging.DEBUG:
import warnings
@ -466,7 +506,8 @@ def setup_logging(log_level=None, name=None):
)
# Clean up old log files, keeping only the most recent ones
cleanup_old_logs(LOGS_DIR, MAX_LOG_FILES)
if logs_dir is not None:
cleanup_old_logs(logs_dir, MAX_LOG_FILES)
# Mark logging as configured
_is_structlog_configured = True
@ -490,6 +531,10 @@ def setup_logging(log_level=None, name=None):
# Get a configured logger and log system information
logger = structlog.get_logger(name if name else __name__)
if logs_dir is not None:
logger.info(f"Log file created at: {log_file_path}", log_file=log_file_path)
# Detailed initialization for regular usage
logger.info(
"Logging initialized",

View file

@ -8,11 +8,13 @@ import http.server
import socketserver
from threading import Thread
import pathlib
from uuid import uuid4
from uuid import uuid4, uuid5, NAMESPACE_OID
from cognee.base_config import get_base_config
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.graph import get_graph_engine
logger = get_logger()
# Analytics Proxy Url, currently hosted by Vercel
proxy_url = "https://test.prometh.ai"
@ -38,19 +40,44 @@ def get_anonymous_id():
home_dir = str(pathlib.Path(pathlib.Path(__file__).parent.parent.parent.resolve()))
if not os.path.isdir(home_dir):
os.makedirs(home_dir, exist_ok=True)
anonymous_id_file = os.path.join(home_dir, ".anon_id")
if not os.path.isfile(anonymous_id_file):
anonymous_id = str(uuid4())
with open(anonymous_id_file, "w", encoding="utf-8") as f:
f.write(anonymous_id)
else:
with open(anonymous_id_file, "r", encoding="utf-8") as f:
anonymous_id = f.read()
try:
if not os.path.isdir(home_dir):
os.makedirs(home_dir, exist_ok=True)
anonymous_id_file = os.path.join(home_dir, ".anon_id")
if not os.path.isfile(anonymous_id_file):
anonymous_id = str(uuid4())
with open(anonymous_id_file, "w", encoding="utf-8") as f:
f.write(anonymous_id)
else:
with open(anonymous_id_file, "r", encoding="utf-8") as f:
anonymous_id = f.read()
except Exception as e:
# In case of read-only filesystem or other issues
logger.warning("Could not create or read anonymous id file: %s", e)
return "unknown-anonymous-id"
return anonymous_id
def _sanitize_nested_properties(obj, property_names: list[str]):
"""
Recursively replaces any property whose key matches one of `property_names`
(e.g., ['url', 'path']) in a nested dict or list with a uuid5 hash
of its string value. Returns a new sanitized copy.
"""
if isinstance(obj, dict):
new_obj = {}
for k, v in obj.items():
if k in property_names and isinstance(v, str):
new_obj[k] = str(uuid5(NAMESPACE_OID, v))
else:
new_obj[k] = _sanitize_nested_properties(v, property_names)
return new_obj
elif isinstance(obj, list):
return [_sanitize_nested_properties(item, property_names) for item in obj]
else:
return obj
def send_telemetry(event_name: str, user_id, additional_properties: dict = {}):
if os.getenv("TELEMETRY_DISABLED"):
return
@ -58,7 +85,9 @@ def send_telemetry(event_name: str, user_id, additional_properties: dict = {}):
env = os.getenv("ENV")
if env in ["test", "dev"]:
return
additional_properties = _sanitize_nested_properties(
obj=additional_properties, property_names=["url"]
)
current_time = datetime.now(timezone.utc)
payload = {
"anonymous_id": str(get_anonymous_id()),

View file

@ -0,0 +1,13 @@
"""Feedback enrichment task package: turns negative user feedback into improved answers linked back into the graph."""

from .extract_feedback_interactions import extract_feedback_interactions
from .generate_improved_answers import generate_improved_answers
from .create_enrichments import create_enrichments
from .link_enrichments_to_feedback import link_enrichments_to_feedback
from .models import FeedbackEnrichment

# Public API of the package, in pipeline order (extract -> improve -> enrich -> link).
__all__ = [
    "extract_feedback_interactions",
    "generate_improved_answers",
    "create_enrichments",
    "link_enrichments_to_feedback",
    "FeedbackEnrichment",
]

View file

@ -0,0 +1,84 @@
from __future__ import annotations
from typing import List
from uuid import NAMESPACE_OID, uuid5
from cognee.infrastructure.llm import LLMGateway
from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt
from cognee.shared.logging_utils import get_logger
from cognee.modules.engine.models import NodeSet
from .models import FeedbackEnrichment
logger = get_logger("create_enrichments")
def _validate_enrichments(enrichments: List[FeedbackEnrichment]) -> bool:
"""Validate that all enrichments contain required fields for completion."""
return all(
enrichment.question is not None
and enrichment.original_answer is not None
and enrichment.improved_answer is not None
and enrichment.new_context is not None
and enrichment.feedback_id is not None
and enrichment.interaction_id is not None
for enrichment in enrichments
)
async def _generate_enrichment_report(
    question: str, improved_answer: str, new_context: str, report_prompt_location: str
) -> str:
    """Generate an educational report via the LLM using the feedback report prompt.

    The template at ``report_prompt_location`` is formatted with the question,
    improved answer and new context, then sent through
    ``LLMGateway.acreate_structured_output`` with a plain ``str`` response model.

    Falls back to a simple synthesized string if prompt loading or the LLM call
    fails, so report generation never hard-fails the enrichment pipeline.
    """
    try:
        prompt_template = read_query_prompt(report_prompt_location)
        rendered_prompt = prompt_template.format(
            question=question,
            improved_answer=improved_answer,
            new_context=new_context,
        )
        return await LLMGateway.acreate_structured_output(
            text_input=rendered_prompt,
            system_prompt="You are a helpful assistant that creates educational content.",
            response_model=str,
        )
    except Exception as exc:
        logger.warning("Failed to generate enrichment report", error=str(exc), question=question)
        # Best-effort fallback text keeps downstream tasks working.
        return f"Educational content for: {question} - {improved_answer}"
async def create_enrichments(
    enrichments: List[FeedbackEnrichment],
    report_prompt_location: str = "feedback_report_prompt.txt",
) -> List[FeedbackEnrichment]:
    """Fill text and belongs_to_set fields of existing FeedbackEnrichment DataPoints.

    For each enrichment an educational report is generated from its question,
    improved answer and new context and stored in ``text``; all enrichments are
    attached to one shared "FeedbackEnrichment" NodeSet. Returns an empty list
    when the input is empty or fails field validation.
    """
    if not enrichments:
        logger.info("No enrichments provided; returning empty list")
        return []
    if not _validate_enrichments(enrichments):
        logger.error("Input validation failed; missing required fields")
        return []
    logger.info("Completing enrichments", count=len(enrichments))
    # Deterministic uuid5 id: repeated runs resolve to the same NodeSet node.
    nodeset = NodeSet(id=uuid5(NAMESPACE_OID, name="FeedbackEnrichment"), name="FeedbackEnrichment")
    completed_enrichments: List[FeedbackEnrichment] = []
    for enrichment in enrichments:
        report_text = await _generate_enrichment_report(
            enrichment.question,
            enrichment.improved_answer,
            enrichment.new_context,
            report_prompt_location,
        )
        # Enrichments are mutated in place and also returned for pipeline chaining.
        enrichment.text = report_text
        enrichment.belongs_to_set = [nodeset]
        completed_enrichments.append(enrichment)
    logger.info("Completed enrichments", successful=len(completed_enrichments))
    return completed_enrichments

View file

@ -0,0 +1,230 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple
from uuid import UUID, uuid5, NAMESPACE_OID
from cognee.infrastructure.llm import LLMGateway
from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.graph import get_graph_engine
from .models import FeedbackEnrichment
logger = get_logger("extract_feedback_interactions")
def _filter_negative_feedback(feedback_nodes):
"""Filter for negative sentiment feedback using precise sentiment classification."""
return [
(node_id, props)
for node_id, props in feedback_nodes
if (props.get("sentiment", "").casefold() == "negative" or props.get("score", 0) < 0)
]
def _get_normalized_id(node_id, props) -> str:
"""Return Cognee node id preference: props.id → props.node_id → raw node_id."""
return str(props.get("id") or props.get("node_id") or node_id)
async def _fetch_feedback_and_interaction_graph_data() -> Tuple[List, List]:
    """Fetch feedback and interaction nodes with edges from graph engine.

    Filters the graph to CogneeUserFeedback / CogneeUserInteraction node types
    and returns ``(nodes, edges)``. On any failure it logs the error and
    returns ``([], [])`` so the calling pipeline degrades to a no-op instead
    of crashing.
    """
    try:
        graph_engine = await get_graph_engine()
        attribute_filters = [{"type": ["CogneeUserFeedback", "CogneeUserInteraction"]}]
        return await graph_engine.get_filtered_graph_data(attribute_filters)
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to fetch filtered graph data", error=str(exc))
        return [], []
def _separate_feedback_and_interaction_nodes(graph_nodes: List) -> Tuple[List, List]:
    """Partition graph nodes into (feedback, interaction) lists by their 'type' property."""
    feedback_nodes: List = []
    interaction_nodes: List = []
    for node_id, props in graph_nodes:
        node_type = props.get("type")
        if node_type == "CogneeUserFeedback":
            feedback_nodes.append((_get_normalized_id(node_id, props), props))
        elif node_type == "CogneeUserInteraction":
            interaction_nodes.append((_get_normalized_id(node_id, props), props))
    return feedback_nodes, interaction_nodes
def _match_feedback_nodes_to_interactions_by_edges(
feedback_nodes: List, interaction_nodes: List, graph_edges: List
) -> List[Tuple[Tuple, Tuple]]:
"""Match feedback to interactions using gives_feedback_to edges."""
interaction_by_id = {node_id: (node_id, props) for node_id, props in interaction_nodes}
feedback_by_id = {node_id: (node_id, props) for node_id, props in feedback_nodes}
feedback_edges = [
(source_id, target_id)
for source_id, target_id, rel, _ in graph_edges
if rel == "gives_feedback_to"
]
feedback_interaction_pairs: List[Tuple[Tuple, Tuple]] = []
for source_id, target_id in feedback_edges:
source_id_str, target_id_str = str(source_id), str(target_id)
feedback_node = feedback_by_id.get(source_id_str)
interaction_node = interaction_by_id.get(target_id_str)
if feedback_node and interaction_node:
feedback_interaction_pairs.append((feedback_node, interaction_node))
return feedback_interaction_pairs
def _sort_pairs_by_recency_and_limit(
feedback_interaction_pairs: List[Tuple[Tuple, Tuple]], last_n_limit: Optional[int]
) -> List[Tuple[Tuple, Tuple]]:
"""Sort by interaction created_at desc with updated_at fallback, then limit."""
def _recency_key(pair):
_, (_, interaction_props) = pair
created_at = interaction_props.get("created_at") or ""
updated_at = interaction_props.get("updated_at") or ""
return (created_at, updated_at)
sorted_pairs = sorted(feedback_interaction_pairs, key=_recency_key, reverse=True)
return sorted_pairs[: last_n_limit or len(sorted_pairs)]
async def _generate_human_readable_context_summary(
    question_text: str, raw_context_text: str
) -> str:
    """Generate a concise human-readable summary for given context.

    Renders the feedback_user_context prompt with the question and raw context
    and asks the LLM for a plain-string summary. Falls back to the raw context
    (or "" when it is falsy) if prompt loading or the LLM call fails.
    """
    try:
        prompt = read_query_prompt("feedback_user_context_prompt.txt")
        rendered = prompt.format(question=question_text, context=raw_context_text)
        return await LLMGateway.acreate_structured_output(
            text_input=rendered, system_prompt="", response_model=str
        )
    except Exception as exc:  # noqa: BLE001
        logger.warning("Failed to summarize context", error=str(exc))
        return raw_context_text or ""
def _has_required_feedback_fields(enrichment: FeedbackEnrichment) -> bool:
"""Validate required fields exist in the FeedbackEnrichment DataPoint."""
return (
enrichment.question is not None
and enrichment.original_answer is not None
and enrichment.context is not None
and enrichment.feedback_text is not None
and enrichment.feedback_id is not None
and enrichment.interaction_id is not None
)
async def _build_feedback_interaction_record(
    feedback_node_id: str, feedback_props: Dict, interaction_node_id: str, interaction_props: Dict
) -> Optional[FeedbackEnrichment]:
    """Build a single FeedbackEnrichment DataPoint with context summary.

    The enrichment id is a deterministic uuid5 of question + interaction id,
    so re-running extraction produces the same node. Returns None (logged)
    when required fields are missing or any step raises.
    """
    try:
        question_text = interaction_props.get("question")
        original_answer_text = interaction_props.get("answer")
        raw_context_text = interaction_props.get("context", "")
        # Feedback text may live under 'feedback' or 'text' depending on the producer.
        feedback_text = feedback_props.get("feedback") or feedback_props.get("text") or ""
        context_summary_text = await _generate_human_readable_context_summary(
            question_text or "", raw_context_text
        )
        enrichment = FeedbackEnrichment(
            # NOTE(review): id is passed as str; presumably the DataPoint model coerces it to UUID — confirm.
            id=str(uuid5(NAMESPACE_OID, f"{question_text}_{interaction_node_id}")),
            text="",
            question=question_text,
            original_answer=original_answer_text,
            improved_answer="",
            # Raises ValueError (caught below) if node ids are not valid UUID strings.
            feedback_id=UUID(str(feedback_node_id)),
            interaction_id=UUID(str(interaction_node_id)),
            belongs_to_set=None,
            context=context_summary_text,
            feedback_text=feedback_text,
            new_context="",
            explanation="",
        )
        if _has_required_feedback_fields(enrichment):
            return enrichment
        else:
            logger.warning("Skipping invalid feedback item", interaction=str(interaction_node_id))
            return None
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to process feedback pair", error=str(exc))
        return None
async def _build_feedback_interaction_records(
    matched_feedback_interaction_pairs: List[Tuple[Tuple, Tuple]],
) -> List[FeedbackEnrichment]:
    """Build FeedbackEnrichment DataPoints from matched pairs, skipping failed builds."""
    records: List[FeedbackEnrichment] = []
    for feedback_entry, interaction_entry in matched_feedback_interaction_pairs:
        feedback_node_id, feedback_props = feedback_entry
        interaction_node_id, interaction_props = interaction_entry
        built_record = await _build_feedback_interaction_record(
            feedback_node_id, feedback_props, interaction_node_id, interaction_props
        )
        if built_record:
            records.append(built_record)
    return records
async def extract_feedback_interactions(
    data: Any, last_n: Optional[int] = None
) -> List[FeedbackEnrichment]:
    """Extract negative feedback-interaction pairs and create FeedbackEnrichment DataPoints.

    Reads CogneeUserFeedback / CogneeUserInteraction nodes directly from the
    graph — the ``data`` argument is only logged, never consumed; it exists to
    fit the pipeline Task signature. Negative feedback is joined to its
    interaction over ``gives_feedback_to`` edges, ordered newest-first, capped
    at ``last_n`` pairs, and turned into one FeedbackEnrichment per pair.
    """
    if not data or data == [{}]:
        logger.info(
            "No data passed to the extraction task (extraction task fetches data from graph directly)",
            data=data,
        )
    graph_nodes, graph_edges = await _fetch_feedback_and_interaction_graph_data()
    if not graph_nodes:
        logger.warning("No graph nodes retrieved from database")
        return []
    feedback_nodes, interaction_nodes = _separate_feedback_and_interaction_nodes(graph_nodes)
    logger.info(
        "Retrieved nodes from graph",
        total_nodes=len(graph_nodes),
        feedback_nodes=len(feedback_nodes),
        interaction_nodes=len(interaction_nodes),
    )
    negative_feedback_nodes = _filter_negative_feedback(feedback_nodes)
    logger.info(
        "Filtered feedback nodes",
        total_feedback=len(feedback_nodes),
        negative_feedback=len(negative_feedback_nodes),
    )
    if not negative_feedback_nodes:
        logger.info("No negative feedback found; returning empty list")
        return []
    matched_feedback_interaction_pairs = _match_feedback_nodes_to_interactions_by_edges(
        negative_feedback_nodes, interaction_nodes, graph_edges
    )
    if not matched_feedback_interaction_pairs:
        logger.info("No feedback-to-interaction matches found; returning empty list")
        return []
    # Most recent interactions first; last_n caps how many pairs are enriched per run.
    matched_feedback_interaction_pairs = _sort_pairs_by_recency_and_limit(
        matched_feedback_interaction_pairs, last_n
    )
    feedback_interaction_records = await _build_feedback_interaction_records(
        matched_feedback_interaction_pairs
    )
    logger.info("Extracted feedback pairs", count=len(feedback_interaction_records))
    return feedback_interaction_records

View file

@ -0,0 +1,130 @@
from __future__ import annotations
from typing import List, Optional
from pydantic import BaseModel
from cognee.infrastructure.llm import LLMGateway
from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt
from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
from .models import FeedbackEnrichment
class ImprovedAnswerResponse(BaseModel):
    """Response model for improved answer generation containing answer and explanation."""

    # Corrected answer produced after taking the negative feedback into account.
    answer: str
    # Rationale accompanying the improved answer.
    explanation: str
logger = get_logger("generate_improved_answers")
def _validate_input_data(enrichments: List[FeedbackEnrichment]) -> bool:
"""Validate that input contains required fields for all enrichments."""
return all(
enrichment.question is not None
and enrichment.original_answer is not None
and enrichment.context is not None
and enrichment.feedback_text is not None
and enrichment.feedback_id is not None
and enrichment.interaction_id is not None
for enrichment in enrichments
)
def _render_reaction_prompt(
    question: str, context: str, wrong_answer: str, negative_feedback: str
) -> str:
    """Render the feedback reaction prompt with provided variables.

    NOTE(review): the prompt filename is hard-coded here even though callers
    accept a ``reaction_prompt_location`` argument — confirm this is intended.
    """
    prompt_template = read_query_prompt("feedback_reaction_prompt.txt")
    return prompt_template.format(
        question=question,
        context=context,
        wrong_answer=wrong_answer,
        negative_feedback=negative_feedback,
    )
async def _generate_improved_answer_for_single_interaction(
    enrichment: FeedbackEnrichment, retriever, reaction_prompt_location: str
) -> Optional[FeedbackEnrichment]:
    """Generate an improved answer for one enrichment via the CoT retriever.

    Renders the reaction prompt from ``reaction_prompt_location`` — previously
    this argument was accepted but silently ignored in favor of a hard-coded
    prompt file (the default value preserves the old behavior). Then retrieves
    fresh context, requests a structured completion, and fills
    ``improved_answer``/``new_context``/``explanation`` on the enrichment
    in place.

    Returns the mutated enrichment, or None (logged) when rendering,
    retrieval, or completion fails.
    """
    try:
        prompt_template = read_query_prompt(reaction_prompt_location)
        query_text = prompt_template.format(
            question=enrichment.question,
            context=enrichment.context,
            wrong_answer=enrichment.original_answer,
            negative_feedback=enrichment.feedback_text,
        )
        retrieved_context = await retriever.get_context(query_text)
        completion = await retriever.get_structured_completion(
            query=query_text,
            context=retrieved_context,
            response_model=ImprovedAnswerResponse,
            max_iter=4,
        )
        # NOTE(review): this uses the retriever's own resolve_edges_to_text; the
        # module-level import of resolve_edges_to_text appears unused — confirm.
        new_context_text = await retriever.resolve_edges_to_text(retrieved_context)
        if completion:
            enrichment.improved_answer = completion.answer
            enrichment.new_context = new_context_text
            enrichment.explanation = completion.explanation
            return enrichment
        logger.warning(
            "Failed to get structured completion from retriever", question=enrichment.question
        )
        return None
    except Exception as exc:  # noqa: BLE001
        logger.error(
            "Failed to generate improved answer",
            error=str(exc),
            question=enrichment.question,
        )
        return None
async def generate_improved_answers(
    enrichments: List[FeedbackEnrichment],
    top_k: int = 20,
    reaction_prompt_location: str = "feedback_reaction_prompt.txt",
) -> List[FeedbackEnrichment]:
    """Generate improved answers using CoT retriever and LLM.

    Validates the enrichments, builds a single GraphCompletionCotRetriever
    (shared across items, ``top_k`` graph results, interactions not re-saved),
    and fills each enrichment's improved answer fields. Items that fail are
    logged and dropped from the returned list.
    """
    if not enrichments:
        logger.info("No enrichments provided; returning empty list")
        return []
    if not _validate_input_data(enrichments):
        logger.error("Input data validation failed; missing required fields")
        return []
    retriever = GraphCompletionCotRetriever(
        top_k=top_k,
        save_interaction=False,
        user_prompt_path="graph_context_for_question.txt",
        system_prompt_path="answer_simple_question.txt",
    )
    improved_answers: List[FeedbackEnrichment] = []
    for enrichment in enrichments:
        result = await _generate_improved_answer_for_single_interaction(
            enrichment, retriever, reaction_prompt_location
        )
        if result:
            improved_answers.append(result)
        else:
            # Failures are non-fatal: remaining enrichments are still processed.
            logger.warning(
                "Failed to generate improved answer",
                question=enrichment.question,
                interaction_id=enrichment.interaction_id,
            )
    logger.info("Generated improved answers", count=len(improved_answers))
    return improved_answers

View file

@ -0,0 +1,67 @@
from __future__ import annotations
from typing import List, Tuple
from uuid import UUID
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.tasks.storage import index_graph_edges
from cognee.shared.logging_utils import get_logger
from .models import FeedbackEnrichment
logger = get_logger("link_enrichments_to_feedback")
def _create_edge_tuple(
source_id: UUID, target_id: UUID, relationship_name: str
) -> Tuple[UUID, UUID, str, dict]:
"""Create an edge tuple with proper properties structure."""
return (
source_id,
target_id,
relationship_name,
{
"relationship_name": relationship_name,
"source_node_id": source_id,
"target_node_id": target_id,
"ontology_valid": False,
},
)
async def link_enrichments_to_feedback(
    enrichments: List[FeedbackEnrichment],
) -> List[FeedbackEnrichment]:
    """Manually create edges from enrichments to original feedback/interaction nodes.

    For every enrichment, adds an ``enriches_feedback`` edge to its feedback
    node and an ``improves_interaction`` edge to its interaction node, then
    persists and indexes the edges. Returns the enrichments unchanged so the
    task composes in a pipeline.
    """
    if not enrichments:
        logger.info("No enrichments provided; returning empty list")
        return []
    relationships = []
    for enrichment in enrichments:
        enrichment_id = enrichment.id
        feedback_id = enrichment.feedback_id
        interaction_id = enrichment.interaction_id
        if enrichment_id and feedback_id:
            enriches_feedback_edge = _create_edge_tuple(
                enrichment_id, feedback_id, "enriches_feedback"
            )
            relationships.append(enriches_feedback_edge)
        if enrichment_id and interaction_id:
            improves_interaction_edge = _create_edge_tuple(
                enrichment_id, interaction_id, "improves_interaction"
            )
            relationships.append(improves_interaction_edge)
    if relationships:
        graph_engine = await get_graph_engine()
        await graph_engine.add_edges(relationships)
        # Keep the edge index in sync with the freshly added edges.
        await index_graph_edges(relationships)
    logger.info("Linking enrichments to feedback", edge_count=len(relationships))
    logger.info("Linked enrichments", enrichment_count=len(enrichments))
    return enrichments

View file

@ -0,0 +1,26 @@
from typing import List, Optional, Union
from uuid import UUID
from cognee.infrastructure.engine import DataPoint
from cognee.modules.engine.models import Entity, NodeSet
from cognee.tasks.temporal_graph.models import Event
class FeedbackEnrichment(DataPoint):
    """Minimal DataPoint for feedback enrichment that works with extract_graph_from_data.

    Carries the original Q&A pair, the negative feedback that triggered the
    enrichment, the regenerated answer/context, and the educational report in
    ``text`` (the only indexed field).
    """

    # Educational report generated from the improved answer; indexed for search.
    text: str
    # Entities/events extracted from `text` by the graph extraction task.
    contains: Optional[List[Union[Entity, Event]]] = None
    # Only `text` is embedded/indexed.
    metadata: dict = {"index_fields": ["text"]}
    # Original user question and the answer that received negative feedback.
    question: str
    original_answer: str
    # Regenerated answer produced by generate_improved_answers.
    improved_answer: str
    # Graph ids of the CogneeUserFeedback / CogneeUserInteraction nodes this enriches.
    feedback_id: UUID
    interaction_id: UUID
    # Shared "FeedbackEnrichment" NodeSet membership, assigned by create_enrichments.
    belongs_to_set: Optional[List[NodeSet]] = None
    # Human-readable summary of the interaction's original retrieval context.
    context: str = ""
    # Raw negative feedback text.
    feedback_text: str = ""
    # Context gathered while generating the improved answer.
    new_context: str = ""
    # Model's rationale for the improved answer.
    explanation: str = ""

View file

@ -239,7 +239,7 @@ async def complete_database_ingestion(schema, migrate_column_data):
id=uuid5(NAMESPACE_OID, name=column_node_id),
name=column_node_id,
properties=f"{key} {value} {table_name}",
description=f"column from relational database table={table_name}. Column name={key} and value={value}. The value of the column is related to the following row with this id: {row_node.id}. This column has the following ID: {column_node_id}",
description=f"column from relational database table={table_name}. Column name={key} and value={value}. This column has the following ID: {column_node_id}",
)
node_mapping[column_node_id] = column_node

View file

@ -0,0 +1,174 @@
"""
End-to-end integration test for feedback enrichment feature.
Tests the complete feedback enrichment pipeline:
1. Add data and cognify
2. Run search with save_interaction=True to create CogneeUserInteraction nodes
3. Submit feedback to create CogneeUserFeedback nodes
4. Run memify with feedback enrichment tasks to create FeedbackEnrichment nodes
5. Verify all nodes and edges are properly created and linked in the graph
"""
import os
import pathlib
from collections import Counter
import cognee
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.search.types import SearchType
from cognee.shared.data_models import KnowledgeGraph
from cognee.shared.logging_utils import get_logger
from cognee.tasks.feedback.create_enrichments import create_enrichments
from cognee.tasks.feedback.extract_feedback_interactions import (
extract_feedback_interactions,
)
from cognee.tasks.feedback.generate_improved_answers import generate_improved_answers
from cognee.tasks.feedback.link_enrichments_to_feedback import (
link_enrichments_to_feedback,
)
from cognee.tasks.graph import extract_graph_from_data
from cognee.tasks.storage import add_data_points
logger = get_logger()
async def main():
    """Run the end-to-end feedback enrichment scenario and assert on graph state.

    Steps: isolate storage directories, add + cognify a tiny dataset, run a
    search that saves the interaction, submit negative feedback, then memify
    with the feedback enrichment task chain and verify FeedbackEnrichment
    nodes and their edges exist in the graph.
    """
    # Dedicated data/system directories keep this test isolated from other runs.
    data_directory_path = str(
        pathlib.Path(
            os.path.join(
                pathlib.Path(__file__).parent,
                ".data_storage/test_feedback_enrichment",
            )
        ).resolve()
    )
    cognee_directory_path = str(
        pathlib.Path(
            os.path.join(
                pathlib.Path(__file__).parent,
                ".cognee_system/test_feedback_enrichment",
            )
        ).resolve()
    )
    cognee.config.data_root_directory(data_directory_path)
    cognee.config.system_root_directory(cognee_directory_path)
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    dataset_name = "feedback_enrichment_test"
    await cognee.add("Cognee turns documents into AI memory.", dataset_name)
    await cognee.cognify([dataset_name])
    question_text = "Say something."
    # save_interaction=True persists a CogneeUserInteraction node for this Q&A.
    result = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION,
        query_text=question_text,
        save_interaction=True,
    )
    assert len(result) > 0, "Search should return non-empty results"
    feedback_text = "This answer was completely useless, my feedback is definitely negative."
    # FEEDBACK search attaches the feedback to the last_k saved interactions.
    await cognee.search(
        query_type=SearchType.FEEDBACK,
        query_text=feedback_text,
        last_k=1,
    )
    graph_engine = await get_graph_engine()
    nodes_before, edges_before = await graph_engine.get_graph_data()
    interaction_nodes_before = [
        (node_id, props)
        for node_id, props in nodes_before
        if props.get("type") == "CogneeUserInteraction"
    ]
    feedback_nodes_before = [
        (node_id, props)
        for node_id, props in nodes_before
        if props.get("type") == "CogneeUserFeedback"
    ]
    edge_types_before = Counter(edge[2] for edge in edges_before)
    assert len(interaction_nodes_before) >= 1, (
        f"Expected at least 1 CogneeUserInteraction node, found {len(interaction_nodes_before)}"
    )
    assert len(feedback_nodes_before) >= 1, (
        f"Expected at least 1 CogneeUserFeedback node, found {len(feedback_nodes_before)}"
    )
    for node_id, props in feedback_nodes_before:
        sentiment = props.get("sentiment", "")
        score = props.get("score", 0)
        # NOTE(review): this rebinds the outer `feedback_text`; harmless because the
        # original value is not used again below, but rename if that changes.
        feedback_text = props.get("feedback", "")
        logger.info(
            "Feedback node created",
            feedback=feedback_text,
            sentiment=sentiment,
            score=score,
        )
    assert edge_types_before.get("gives_feedback_to", 0) >= 1, (
        f"Expected at least 1 'gives_feedback_to' edge, found {edge_types_before.get('gives_feedback_to', 0)}"
    )
    # Memify pipeline: extract pairs -> improve answers -> build reports ->
    # extract graph from reports -> persist -> link back to feedback/interactions.
    extraction_tasks = [Task(extract_feedback_interactions, last_n=5)]
    enrichment_tasks = [
        Task(generate_improved_answers, top_k=20),
        Task(create_enrichments),
        Task(
            extract_graph_from_data,
            graph_model=KnowledgeGraph,
            task_config={"batch_size": 10},
        ),
        Task(add_data_points, task_config={"batch_size": 10}),
        Task(link_enrichments_to_feedback),
    ]
    await cognee.memify(
        extraction_tasks=extraction_tasks,
        enrichment_tasks=enrichment_tasks,
        data=[{}],
        dataset="feedback_enrichment_test_memify",
    )
    nodes_after, edges_after = await graph_engine.get_graph_data()
    enrichment_nodes = [
        (node_id, props)
        for node_id, props in nodes_after
        if props.get("type") == "FeedbackEnrichment"
    ]
    assert len(enrichment_nodes) >= 1, (
        f"Expected at least 1 FeedbackEnrichment node, found {len(enrichment_nodes)}"
    )
    for node_id, props in enrichment_nodes:
        assert "text" in props, f"FeedbackEnrichment node {node_id} missing 'text' property"
    enrichment_node_ids = {node_id for node_id, _ in enrichment_nodes}
    edges_with_enrichments = [
        edge
        for edge in edges_after
        if edge[0] in enrichment_node_ids or edge[1] in enrichment_node_ids
    ]
    assert len(edges_with_enrichments) >= 1, (
        f"Expected enrichment nodes to have at least 1 edge, found {len(edges_with_enrichments)}"
    )
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    logger.info("All feedback enrichment tests passed successfully")


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())

View file

@ -489,3 +489,154 @@ def test_get_ontology_resolver_from_env_resolver_functionality():
assert nodes == []
assert relationships == []
assert start_node is None
def test_multifile_ontology_loading_success():
    """Test successful loading of multiple ontology files."""
    # Two unrelated ontologies: cars and tech companies.
    ns1 = Namespace("http://example.org/cars#")
    ns2 = Namespace("http://example.org/tech#")
    g1 = Graph()
    g1.add((ns1.Vehicle, RDF.type, OWL.Class))
    g1.add((ns1.Car, RDF.type, OWL.Class))
    g1.add((ns1.Car, RDFS.subClassOf, ns1.Vehicle))
    g1.add((ns1.Audi, RDF.type, ns1.Car))
    g1.add((ns1.BMW, RDF.type, ns1.Car))
    g2 = Graph()
    g2.add((ns2.Company, RDF.type, OWL.Class))
    g2.add((ns2.TechCompany, RDF.type, OWL.Class))
    g2.add((ns2.TechCompany, RDFS.subClassOf, ns2.Company))
    g2.add((ns2.Apple, RDF.type, ns2.TechCompany))
    g2.add((ns2.Google, RDF.type, ns2.TechCompany))
    import tempfile

    # delete=False so the resolver can reopen the files by path; removed in finally.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".owl", delete=False) as f1:
        g1.serialize(f1.name, format="xml")
        file1_path = f1.name
    with tempfile.NamedTemporaryFile(mode="w", suffix=".owl", delete=False) as f2:
        g2.serialize(f2.name, format="xml")
        file2_path = f2.name
    try:
        resolver = RDFLibOntologyResolver(ontology_file=[file1_path, file2_path])
        # Entities from both files must be merged into one lookup (lowercased keys).
        assert resolver.graph is not None
        assert "car" in resolver.lookup["classes"]
        assert "vehicle" in resolver.lookup["classes"]
        assert "company" in resolver.lookup["classes"]
        assert "techcompany" in resolver.lookup["classes"]
        assert "audi" in resolver.lookup["individuals"]
        assert "bmw" in resolver.lookup["individuals"]
        assert "apple" in resolver.lookup["individuals"]
        assert "google" in resolver.lookup["individuals"]
        # Matching works against individuals from either source file.
        car_match = resolver.find_closest_match("Audi", "individuals")
        assert car_match == "audi"
        tech_match = resolver.find_closest_match("Google", "individuals")
        assert tech_match == "google"
    finally:
        import os

        os.unlink(file1_path)
        os.unlink(file2_path)
def test_multifile_ontology_with_missing_files():
    """Test loading multiple ontology files where some don't exist."""
    ns = Namespace("http://example.org/test#")
    g = Graph()
    g.add((ns.Car, RDF.type, OWL.Class))
    g.add((ns.Audi, RDF.type, ns.Car))
    import tempfile

    # delete=False so the resolver can reopen the file by path; removed in finally.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".owl", delete=False) as f:
        g.serialize(f.name, format="xml")
        valid_file = f.name
    try:
        # Missing entries should be skipped, not abort loading of the valid file.
        resolver = RDFLibOntologyResolver(
            ontology_file=["nonexistent_file_1.owl", valid_file, "nonexistent_file_2.owl"]
        )
        assert resolver.graph is not None
        assert "car" in resolver.lookup["classes"]
        assert "audi" in resolver.lookup["individuals"]
        match = resolver.find_closest_match("Audi", "individuals")
        assert match == "audi"
    finally:
        import os

        os.unlink(valid_file)
def test_multifile_ontology_all_files_missing():
    """A resolver built only from nonexistent ontology files ends up empty."""
    missing_files = [f"nonexistent_file_{index}.owl" for index in (1, 2, 3)]
    resolver = RDFLibOntologyResolver(ontology_file=missing_files)
    assert resolver.graph is None
    for category in ("classes", "individuals"):
        assert resolver.lookup[category] == {}
def test_multifile_ontology_with_overlapping_entities():
    """Test loading multiple ontology files with overlapping/related entities."""
    # Both files share one namespace; file 2 extends the hierarchy from file 1.
    ns = Namespace("http://example.org/automotive#")
    g1 = Graph()
    g1.add((ns.Vehicle, RDF.type, OWL.Class))
    g1.add((ns.Car, RDF.type, OWL.Class))
    g1.add((ns.Car, RDFS.subClassOf, ns.Vehicle))
    g2 = Graph()
    g2.add((ns.LuxuryCar, RDF.type, OWL.Class))
    g2.add((ns.LuxuryCar, RDFS.subClassOf, ns.Car))
    g2.add((ns.Mercedes, RDF.type, ns.LuxuryCar))
    g2.add((ns.BMW, RDF.type, ns.LuxuryCar))
    import tempfile

    # delete=False so the resolver can reopen the files by path; removed in finally.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".owl", delete=False) as f1:
        g1.serialize(f1.name, format="xml")
        file1_path = f1.name
    with tempfile.NamedTemporaryFile(mode="w", suffix=".owl", delete=False) as f2:
        g2.serialize(f2.name, format="xml")
        file2_path = f2.name
    try:
        resolver = RDFLibOntologyResolver(ontology_file=[file1_path, file2_path])
        assert "vehicle" in resolver.lookup["classes"]
        assert "car" in resolver.lookup["classes"]
        assert "luxurycar" in resolver.lookup["classes"]
        assert "mercedes" in resolver.lookup["individuals"]
        assert "bmw" in resolver.lookup["individuals"]
        # The subgraph must walk the class hierarchy across both files:
        # Mercedes -> LuxuryCar (file 2) -> Car -> Vehicle (file 1).
        nodes, relationships, start_node = resolver.get_subgraph("Mercedes", "individuals")
        uri_labels = {resolver._uri_to_key(n.uri) for n in nodes}
        assert "mercedes" in uri_labels
        assert "luxurycar" in uri_labels
        assert "car" in uri_labels
        assert "vehicle" in uri_labels
    finally:
        import os

        os.unlink(file1_path)
        os.unlink(file2_path)

View file

@ -2,6 +2,7 @@ import os
import pytest
import pathlib
from typing import Optional, Union
from pydantic import BaseModel
import cognee
from cognee.low_level import setup, DataPoint
@ -10,6 +11,11 @@ from cognee.tasks.storage import add_data_points
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
class TestAnswer(BaseModel):
    """Structured response model used to exercise get_structured_completion in tests."""

    # Free-text answer to the query.
    answer: str
    # Model's reasoning accompanying the answer.
    explanation: str
class TestGraphCompletionCoTRetriever:
@pytest.mark.asyncio
async def test_graph_completion_cot_context_simple(self):
@ -168,3 +174,48 @@ class TestGraphCompletionCoTRetriever:
assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
)
    @pytest.mark.asyncio
    async def test_get_structured_completion(self):
        """get_structured_completion returns a str by default and a populated model when one is supplied."""
        # Dedicated directories keep this test's state isolated from other runs.
        system_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion"
        )
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(
            pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion"
        )
        cognee.config.data_root_directory(data_directory_path)
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

        # Minimal two-node graph: a person working for a company.
        class Company(DataPoint):
            name: str

        class Person(DataPoint):
            name: str
            works_for: Company

        company1 = Company(name="Figma")
        person1 = Person(name="Steve Rodger", works_for=company1)
        entities = [company1, person1]
        await add_data_points(entities)
        retriever = GraphCompletionCotRetriever()
        # Test with string response model (default)
        string_answer = await retriever.get_structured_completion("Who works at Figma?")
        assert isinstance(string_answer, str), f"Expected str, got {type(string_answer).__name__}"
        assert string_answer.strip(), "Answer should not be empty"
        # Test with structured response model
        structured_answer = await retriever.get_structured_completion(
            "Who works at Figma?", response_model=TestAnswer
        )
        assert isinstance(structured_answer, TestAnswer), (
            f"Expected TestAnswer, got {type(structured_answer).__name__}"
        )
        assert structured_answer.answer.strip(), "Answer field should not be empty"
        assert structured_answer.explanation.strip(), "Explanation field should not be empty"

View file

@ -0,0 +1,82 @@
import asyncio
import cognee
from cognee.api.v1.search import SearchType
from cognee.modules.pipelines.tasks.task import Task
from cognee.tasks.graph import extract_graph_from_data
from cognee.tasks.storage import add_data_points
from cognee.shared.data_models import KnowledgeGraph
from cognee.tasks.feedback.extract_feedback_interactions import extract_feedback_interactions
from cognee.tasks.feedback.generate_improved_answers import generate_improved_answers
from cognee.tasks.feedback.create_enrichments import create_enrichments
from cognee.tasks.feedback.link_enrichments_to_feedback import link_enrichments_to_feedback
# Scripted five-turn chat used as the test corpus. It contains a deliberate
# factual twist (Mallory relayed a donut request, not a document request) so
# the GRAPH_COMPLETION answer checked later can be judged correct or wrong.
CONVERSATION = [
    "Alice: Hey, Bob. Did you talk to Mallory?",
    "Bob: Yeah, I just saw her before coming here.",
    "Alice: Then she told you to bring my documents, right?",
    "Bob: Uh… not exactly. She said you wanted me to bring you donuts. Which sounded kind of odd…",
    "Alice: Ugh, shes so annoying. Thanks for the donuts anyway!",
]
async def initialize_conversation_and_graph(conversation):
    """Reset cognee state, ingest ``conversation``, and build the knowledge graph.

    Prunes both user data and system metadata so the test starts from a clean
    slate, then adds the conversation lines and runs cognify over them.
    """
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await cognee.add(conversation)
    await cognee.cognify()
async def run_question_and_submit_feedback(question_text: str) -> bool:
    """Ask a GRAPH_COMPLETION question, file feedback on it, and report correctness.

    The answer is judged correct when it mentions Mallory — the person who
    actually relayed the donut request in the seeded conversation. Positive or
    negative feedback is submitted accordingly against the saved interaction.
    """
    search_result = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION,
        query_text=question_text,
        save_interaction=True,
    )
    is_correct = "mallory" in str(search_result).lower()

    if is_correct:
        feedback = "Great answers, very helpful!"
    else:
        feedback = "The answer about Bob and donuts was wrong."

    # last_k=1 attaches the feedback to the most recently saved interaction.
    await cognee.search(
        query_type=SearchType.FEEDBACK,
        query_text=feedback,
        last_k=1,
    )
    return is_correct
async def run_feedback_enrichment_memify(last_n: int = 5):
    """Run the memify pipeline that turns negative feedback into graph enrichments.

    The extraction stage pulls the last ``last_n`` feedback interactions; the
    enrichment stage then generates improved answers, materializes enrichment
    records, re-extracts a knowledge graph from them, stores the resulting data
    points, and finally links the enrichments back to the originating feedback.
    """
    await cognee.memify(
        extraction_tasks=[Task(extract_feedback_interactions, last_n=last_n)],
        enrichment_tasks=[
            Task(generate_improved_answers, top_k=20),
            Task(create_enrichments),
            Task(
                extract_graph_from_data,
                graph_model=KnowledgeGraph,
                task_config={"batch_size": 10},
            ),
            Task(add_data_points, task_config={"batch_size": 10}),
            Task(link_enrichments_to_feedback),
        ],
        data=[{}],  # A placeholder to prevent fetching the entire graph
        dataset="feedback_enrichment_minimal",
    )
async def main():
    """Seed the graph, ask the trick question, and enrich only on a wrong answer."""
    await initialize_conversation_and_graph(CONVERSATION)
    answered_correctly = await run_question_and_submit_feedback(
        "Who told Bob to bring the donuts?"
    )
    if not answered_correctly:
        await run_feedback_enrichment_memify(last_n=5)


if __name__ == "__main__":
    asyncio.run(main())

View file

@ -77,6 +77,7 @@ async def main():
"What happened between 2000 and 2006?",
"What happened between 1903 and 1995, I am interested in the Selected Works of Arnulf Øverland Ole Peter Arnulf Øverland?",
"Who is Attaphol Buspakom Attaphol Buspakom?",
"Who was Arnulf Øverland?",
]
for query_text in queries:

6
poetry.lock generated
View file

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand.
# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand.
[[package]]
name = "accelerate"
@ -2543,7 +2543,6 @@ files = [
{file = "fastuuid-0.12.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b9b31dd488d0778c36f8279b306dc92a42f16904cba54acca71e107d65b60b0c"},
{file = "fastuuid-0.12.0-cp313-cp313-manylinux_2_34_x86_64.whl", hash = "sha256:b19361ee649365eefc717ec08005972d3d1eb9ee39908022d98e3bfa9da59e37"},
{file = "fastuuid-0.12.0-cp313-cp313-win_amd64.whl", hash = "sha256:8fc66b11423e6f3e1937385f655bedd67aebe56a3dcec0cb835351cfe7d358c9"},
{file = "fastuuid-0.12.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:2925f67b88d47cb16aa3eb1ab20fdcf21b94d74490e0818c91ea41434b987493"},
{file = "fastuuid-0.12.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7b15c54d300279ab20a9cc0579ada9c9f80d1bc92997fc61fb7bf3103d7cb26b"},
{file = "fastuuid-0.12.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:458f1bc3ebbd76fdb89ad83e6b81ccd3b2a99fa6707cd3650b27606745cfb170"},
{file = "fastuuid-0.12.0-cp38-cp38-manylinux_2_34_x86_64.whl", hash = "sha256:a8f0f83fbba6dc44271a11b22e15838641b8c45612cdf541b4822a5930f6893c"},
@ -4170,8 +4169,6 @@ groups = ["main"]
markers = "extra == \"dlt\""
files = [
{file = "jsonpath-ng-1.7.0.tar.gz", hash = "sha256:f6f5f7fd4e5ff79c785f1573b394043b39849fb2bb47bcead935d12b00beab3c"},
{file = "jsonpath_ng-1.7.0-py2-none-any.whl", hash = "sha256:898c93fc173f0c336784a3fa63d7434297544b7198124a68f9a3ef9597b0ae6e"},
{file = "jsonpath_ng-1.7.0-py3-none-any.whl", hash = "sha256:f3d7f9e848cba1b6da28c55b1c26ff915dc9e0b1ba7e752a53d6da8d5cbd00b6"},
]
[package.dependencies]
@ -8593,7 +8590,6 @@ files = [
{file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:bb89f0a835bcfc1d42ccd5f41f04870c1b936d8507c6df12b7737febc40f0909"},
{file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f0c2d907a1e102526dd2986df638343388b94c33860ff3bbe1384130828714b1"},
{file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f8157bed2f51db683f31306aa497311b560f2265998122abe1dce6428bd86567"},
{file = "psycopg2_binary-2.9.10-cp313-cp313-win_amd64.whl", hash = "sha256:27422aa5f11fbcd9b18da48373eb67081243662f9b46e6fd07c3eb46e4535142"},
{file = "psycopg2_binary-2.9.10-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:eb09aa7f9cecb45027683bb55aebaaf45a0df8bf6de68801a6afdc7947bb09d4"},
{file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b73d6d7f0ccdad7bc43e6d34273f70d587ef62f824d7261c4ae9b8b1b6af90e8"},
{file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ce5ab4bf46a211a8e924d307c1b1fcda82368586a19d0a24f8ae166f5c784864"},

View file

@ -1,7 +1,7 @@
[project]
name = "cognee"
version = "0.3.7"
version = "0.3.7.dev1"
description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning."
authors = [
{ name = "Vasilije Markovic" },

8274
uv.lock generated

File diff suppressed because it is too large Load diff