Compare commits

1 commit

Author     SHA1         Message                        Date
vasilije   c1eb4daef5   add example of pythonic code   2025-09-06 09:46:59 -07:00
4 changed files with 504 additions and 13 deletions

View file

@@ -27,6 +27,9 @@ from .api.v1.visualize import visualize_graph, start_visualization_server
from cognee.modules.visualization.cognee_network_visualization import (
cognee_network_visualization,
)
from cognee.modules.visualization.embedding_atlas_export import (
get_embeddings_for_atlas,
)
# Pipelines
from .modules import pipelines

View file

@@ -1,6 +1,12 @@
from typing import Optional, List, Literal, Union
import pandas as pd
from cognee.modules.visualization.cognee_network_visualization import (
cognee_network_visualization,
)
from cognee.modules.visualization.embedding_atlas_export import (
export_embeddings_to_atlas,
get_embeddings_for_atlas,
)
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.shared.logging_utils import get_logger, setup_logging, ERROR
@@ -11,20 +17,95 @@ import asyncio
logger = get_logger()
async def visualize_graph(destination_file_path: str = None):
graph_engine = await get_graph_engine()
graph_data = await graph_engine.get_graph_data()
graph = await cognee_network_visualization(graph_data, destination_file_path)
if destination_file_path:
logger.info(f"The HTML file has been stored at path: {destination_file_path}")
else:
logger.info(
"The HTML file has been stored on your home directory! Navigate there with cd ~"
async def visualize_graph(
destination_file_path: Optional[str] = None,
mode: Literal["network", "atlas", "atlas_component"] = "network",
collections: Optional[List[str]] = None,
limit: Optional[int] = None,
compute_projection: bool = True
) -> Union[str, pd.DataFrame]:
"""
Visualize Cognee's knowledge graph using different visualization modes.
Parameters:
-----------
destination_file_path (str, optional): Path at which to save the output file.
For network mode: saves HTML file
For atlas mode: saves parquet file
Not used for atlas_component mode
mode (str): Visualization mode:
- "network": Interactive HTML graph visualization
- "atlas": Export to parquet for embedding-atlas CLI
- "atlas_component": Return DataFrame for Streamlit component
collections (List[str], optional): For atlas modes - list of collections to export
limit (int, optional): For atlas modes - maximum number of embeddings to export per collection
compute_projection (bool): For atlas_component mode - whether to compute 2D projection
Returns:
--------
Union[str, pd.DataFrame]:
- str: Path to generated file (network and atlas modes)
- pd.DataFrame: DataFrame for Streamlit component (atlas_component mode)
Usage:
------
# Traditional network visualization
await visualize_graph()
# Embedding atlas CLI export
await visualize_graph(mode="atlas", destination_file_path="my_embeddings.parquet")
# Streamlit component DataFrame
df = await visualize_graph(mode="atlas_component")
# Then use in Streamlit:
from embedding_atlas.streamlit import embedding_atlas
selection = embedding_atlas(df, text="text", x="projection_x", y="projection_y")
"""
if mode == "atlas":
# Export embeddings for atlas CLI visualization
output_path = destination_file_path or "cognee_embeddings.parquet"
result_path = await export_embeddings_to_atlas(
output_path=output_path,
collections=collections,
limit=limit
)
logger.info(f"Embeddings exported to: {result_path}")
logger.info(f"To visualize with Embedding Atlas, run: embedding-atlas {result_path}")
return result_path
elif mode == "atlas_component":
# Return DataFrame for Streamlit component
df = await get_embeddings_for_atlas(
collections=collections,
limit=limit,
compute_projection=compute_projection
)
logger.info(f"Prepared DataFrame with {len(df)} embeddings for Streamlit component")
if compute_projection and 'projection_x' in df.columns:
logger.info("DataFrame includes 2D projection coordinates")
return df
else:
# Traditional network visualization
graph_engine = await get_graph_engine()
graph_data = await graph_engine.get_graph_data()
graph = await cognee_network_visualization(graph_data, destination_file_path)
if destination_file_path:
logger.info(f"The HTML file has been stored at path: {destination_file_path}")
else:
logger.info(
"The HTML file has been stored on your home directory! Navigate there with cd ~"
)
return graph
if __name__ == "__main__":

View file

@@ -0,0 +1,406 @@
import pandas as pd
from typing import List, Optional, Dict, Any
from pathlib import Path
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
logger = get_logger("EmbeddingAtlasExport")
async def get_embeddings_for_atlas(
collections: Optional[List[str]] = None,
limit: Optional[int] = None,
compute_projection: bool = True
) -> pd.DataFrame:
"""
Get embeddings from Cognee vector database as a DataFrame for use with
Embedding Atlas Streamlit component.
Parameters:
-----------
collections (List[str], optional): List of collection names to export.
If None, exports from all collections
limit (int, optional): Maximum number of embeddings to export per collection.
If None, exports all embeddings
compute_projection (bool): Whether to compute 2D projection for visualization.
If False, returns raw embeddings only.
Returns:
--------
pd.DataFrame: DataFrame ready for use with embedding_atlas() Streamlit component
Contains columns: id, text, collection, embedding, dim_0, dim_1, ...
If compute_projection=True, also includes: projection_x, projection_y
Usage with Streamlit:
--------------------
```python
import streamlit as st
from embedding_atlas.streamlit import embedding_atlas
import cognee
# Get embeddings DataFrame
df = await cognee.get_embeddings_for_atlas()
# Use with Embedding Atlas Streamlit component
selection = embedding_atlas(
df,
text="text",
x="projection_x",
y="projection_y",
show_table=True
)
```
"""
vector_engine = get_vector_engine()
# Get all collections if none specified
if collections is None:
collections = await _get_all_collections(vector_engine)
logger.info(f"Found {len(collections)} collections: {collections}")
all_data = []
for collection_name in collections:
logger.info(f"Getting embeddings from collection: {collection_name}")
try:
# Get all data points from the collection with embeddings
collection_data = await _get_collection_embeddings(
vector_engine, collection_name, limit
)
if collection_data:
all_data.extend(collection_data)
logger.info(f"Retrieved {len(collection_data)} embeddings from {collection_name}")
else:
logger.warning(f"No data found in collection: {collection_name}")
except Exception as e:
logger.error(f"Error getting embeddings from collection {collection_name}: {e}")
continue
if not all_data:
logger.warning("No embeddings found")
return pd.DataFrame()
# Convert to DataFrame
df = pd.DataFrame(all_data)
# Compute 2D projection if requested
if compute_projection and 'embedding' in df.columns:
try:
from embedding_atlas.projection import compute_text_projection
# Compute projection using the embedding_atlas library
df = compute_text_projection(
df,
text="text",
x="projection_x",
y="projection_y",
neighbors="neighbors"
)
logger.info("Computed 2D projection for embeddings")
except ImportError:
logger.warning("embedding-atlas not installed. Install with: pip install embedding-atlas")
logger.info("Returning DataFrame without projection")
except Exception as e:
logger.error(f"Error computing projection: {e}")
logger.info("Returning DataFrame without projection")
logger.info(f"Prepared DataFrame with {len(df)} embeddings for Atlas component")
return df
async def export_embeddings_to_atlas(
output_path: Optional[str] = None,
collections: Optional[List[str]] = None,
limit: Optional[int] = None
) -> str:
"""
Export embeddings and metadata from Cognee vector database to parquet format
compatible with Embedding Atlas.
Parameters:
-----------
output_path (str, optional): Path at which to save the parquet file.
If None, saves to current directory as 'cognee_embeddings.parquet'
collections (List[str], optional): List of collection names to export.
If None, exports from all collections
limit (int, optional): Maximum number of embeddings to export per collection.
If None, exports all embeddings
Returns:
--------
str: Path to the generated parquet file
Usage:
------
After calling this function, you can use the generated parquet file with embedding-atlas:
```
embedding-atlas your-dataset.parquet
```
"""
if output_path is None:
output_path = "cognee_embeddings.parquet"
vector_engine = get_vector_engine()
# Get all collections if none specified
if collections is None:
collections = await _get_all_collections(vector_engine)
logger.info(f"Found {len(collections)} collections: {collections}")
all_data = []
for collection_name in collections:
logger.info(f"Exporting embeddings from collection: {collection_name}")
try:
# Get all data points from the collection with embeddings
collection_data = await _get_collection_embeddings(
vector_engine, collection_name, limit
)
if collection_data:
all_data.extend(collection_data)
logger.info(f"Exported {len(collection_data)} embeddings from {collection_name}")
else:
logger.warning(f"No data found in collection: {collection_name}")
except Exception as e:
logger.error(f"Error exporting from collection {collection_name}: {e}")
continue
if not all_data:
raise ValueError("No embeddings found to export")
# Convert to DataFrame and save as parquet
df = pd.DataFrame(all_data)
# Ensure output directory exists
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(output_path, index=False)
logger.info(f"Successfully exported {len(all_data)} embeddings to {output_path}")
logger.info(f"You can now visualize with: embedding-atlas {output_path}")
return str(output_path)
async def _get_all_collections(vector_engine) -> List[str]:
"""Get all collection names from the vector database."""
try:
# LanceDB specific method
if hasattr(vector_engine, 'get_connection'):
connection = await vector_engine.get_connection()
if hasattr(connection, 'table_names'):
return await connection.table_names()
# ChromaDB specific method
if hasattr(vector_engine, 'get_collection_names'):
return await vector_engine.get_collection_names()
elif hasattr(vector_engine, 'list_collections'):
collections = await vector_engine.list_collections()
return [col.name if hasattr(col, 'name') else str(col) for col in collections]
else:
logger.warning("Vector engine doesn't support listing collections")
return []
except Exception as e:
logger.error(f"Error getting collections: {e}")
return []
async def _get_collection_embeddings(
vector_engine,
collection_name: str,
limit: Optional[int] = None
) -> List[Dict[str, Any]]:
"""Get all embeddings and metadata from a specific collection."""
try:
# First check if collection exists
if not await vector_engine.has_collection(collection_name):
logger.warning(f"Collection {collection_name} does not exist")
return []
collection_data = []
# Get collection object to work with directly
collection = await vector_engine.get_collection(collection_name)
if collection is None:
logger.warning(f"Could not get collection object for {collection_name}")
return []
# Strategy 1: LanceDB specific - query all data with vectors
if hasattr(collection, 'query') and hasattr(collection, 'to_pandas'):
try:
logger.info(f"Using LanceDB query method for {collection_name}")
# Query all data from LanceDB table
query = collection.query()
if limit:
query = query.limit(limit)
results_df = await query.to_pandas()
if not results_df.empty:
for _, row in results_df.iterrows():
item = {
'id': str(row.get('id', '')),
'collection': collection_name
}
# Extract text from payload
payload = row.get('payload', {})
if isinstance(payload, dict):
item['text'] = _extract_text_from_payload(payload)
# Add payload metadata
for key, value in payload.items():
if key not in ['id', 'text', 'embedding', 'vector']:
item[f'meta_{key}'] = value
else:
item['text'] = str(payload) if payload else ''
# Add embedding vector if available
if 'vector' in row and row['vector'] is not None:
embedding = row['vector']
if hasattr(embedding, 'tolist'):
embedding = embedding.tolist()
elif not isinstance(embedding, list):
embedding = list(embedding)
item['embedding'] = embedding
# Add individual embedding dimensions as columns for atlas
for j, val in enumerate(embedding):
item[f'dim_{j}'] = float(val)
collection_data.append(item)
logger.info(f"Exported {len(collection_data)} embeddings from LanceDB table {collection_name}")
return collection_data
except Exception as e:
logger.debug(f"LanceDB query failed for {collection_name}: {e}")
# Strategy 2: ChromaDB specific - collection.get()
if hasattr(collection, 'get'):
try:
logger.info(f"Using ChromaDB get method for {collection_name}")
results = await collection.get(
include=["metadatas", "embeddings", "documents"]
)
if results and 'ids' in results:
documents = results.get('documents') or []
for i, item_id in enumerate(results['ids']):
item = {
'id': str(item_id),
'text': (documents[i] or '') if i < len(documents) else '',
'collection': collection_name
}
# Add embedding if available
if 'embeddings' in results and i < len(results['embeddings']):
embedding = results['embeddings'][i]
item['embedding'] = embedding
# Add individual embedding dimensions as columns for atlas
for j, val in enumerate(embedding):
item[f'dim_{j}'] = val
# Add metadata if available
if 'metadatas' in results and i < len(results['metadatas']):
metadata = results['metadatas'][i] or {}
for key, value in metadata.items():
if key not in ['id', 'text', 'embedding']:
item[f'meta_{key}'] = value
collection_data.append(item)
if limit and len(collection_data) >= limit:
break
logger.info(f"Exported {len(collection_data)} embeddings from ChromaDB collection {collection_name}")
return collection_data
except Exception as e:
logger.debug(f"ChromaDB-style get failed for {collection_name}: {e}")
# Strategy 3: Fallback - try using search with dummy query
try:
logger.info(f"Using search fallback for {collection_name}")
# Use a very generic search to get all data
search_results = await vector_engine.search(
collection_name=collection_name,
query_text="the", # Use a common word instead of empty query
limit=limit or 10000,
with_vector=True
)
if search_results:
for result in search_results:
if isinstance(result, ScoredResult):
item = {
'id': str(result.id),
'text': _extract_text_from_payload(result.payload),
'collection': collection_name,
'score': result.score
}
# Add embedding if available
if hasattr(result, 'vector') and result.vector:
embedding = result.vector
item['embedding'] = embedding
# Add individual embedding dimensions
for j, val in enumerate(embedding):
item[f'dim_{j}'] = val
# Add payload metadata
if result.payload:
for key, value in result.payload.items():
if key not in ['id', 'text', 'embedding']:
item[f'meta_{key}'] = value
collection_data.append(item)
logger.info(f"Exported {len(collection_data)} embeddings using search fallback for {collection_name}")
return collection_data
except Exception as e:
logger.debug(f"Search-based export failed for {collection_name}: {e}")
logger.warning(f"Could not export embeddings from {collection_name}")
return []
except Exception as e:
logger.error(f"Error getting embeddings from {collection_name}: {e}")
return []
def _extract_text_from_payload(payload: Dict[str, Any]) -> str:
"""Extract text content from payload data."""
if not payload:
return ""
# Common text field names
text_fields = ['text', 'content', 'document', 'data', 'name', 'title']
for field in text_fields:
if field in payload and payload[field]:
return str(payload[field])
# If no standard text field found, try to find any string value
for key, value in payload.items():
if isinstance(value, str) and len(value.strip()) > 0:
return value
return ""

View file

@@ -1,7 +1,7 @@
[project]
name = "cognee"
version = "0.2.4"
version = "0.2.4b1"
description = "Cognee is a library for enriching LLM context with a semantic layer for better understanding and reasoning."
authors = [
{ name = "Vasilije Markovic" },
@@ -61,6 +61,7 @@ dependencies = [
"onnxruntime>=1.0.0,<2.0.0",
"pylance>=0.22.0,<1.0.0",
"kuzu (==0.11.0)",
"embedding-atlas>=0.1.0,<1.0.0",
"python-magic-bin<0.5 ; platform_system == 'Windows'", # Only needed for Windows
]
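A quick, stdlib-only check that the new pin resolves in a given environment; the version bound is copied from the dependency line above:

```python
from importlib.metadata import PackageNotFoundError, version

try:
    print(f"embedding-atlas {version('embedding-atlas')} is installed")
except PackageNotFoundError:
    print("missing; install with: pip install 'embedding-atlas>=0.1.0,<1.0.0'")
```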