Compare commits

Comparing main...example_fo (1 commit)

Commit: c1eb4daef5

4 changed files with 504 additions and 13 deletions
```diff
@@ -27,6 +27,9 @@ from .api.v1.visualize import visualize_graph, start_visualization_server
 from cognee.modules.visualization.cognee_network_visualization import (
     cognee_network_visualization,
 )
+from cognee.modules.visualization.embedding_atlas_export import (
+    get_embeddings_for_atlas,
+)
 
 # Pipelines
 from .modules import pipelines
```
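If this hunk lands in cognee's package `__init__.py`, as the relative `.api.v1.visualize` import in the hunk header suggests, `get_embeddings_for_atlas` becomes importable from the package root. A minimal sketch of that usage (the `limit` value is illustrative):

```python
import asyncio

from cognee import get_embeddings_for_atlas  # assumes the hunk above edits cognee/__init__.py

async def main():
    # Capped at 100 embeddings per collection; returns a pandas DataFrame.
    df = await get_embeddings_for_atlas(limit=100)
    print(df.columns.tolist())  # expect: id, text, collection, embedding, dim_0, ...

asyncio.run(main())
```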
```diff
@@ -1,6 +1,12 @@
+from typing import Optional, List, Literal, Union
+import pandas as pd
 from cognee.modules.visualization.cognee_network_visualization import (
     cognee_network_visualization,
 )
+from cognee.modules.visualization.embedding_atlas_export import (
+    export_embeddings_to_atlas,
+    get_embeddings_for_atlas,
+)
 from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.shared.logging_utils import get_logger, setup_logging, ERROR
 
```
```diff
@@ -11,20 +17,95 @@ import asyncio
 logger = get_logger()
 
 
-async def visualize_graph(destination_file_path: str = None):
-    graph_engine = await get_graph_engine()
-    graph_data = await graph_engine.get_graph_data()
-
-    graph = await cognee_network_visualization(graph_data, destination_file_path)
-
-    if destination_file_path:
-        logger.info(f"The HTML file has been stored at path: {destination_file_path}")
-    else:
-        logger.info(
-            "The HTML file has been stored on your home directory! Navigate there with cd ~"
-        )
-
-    return graph
+async def visualize_graph(
+    destination_file_path: str = None,
+    mode: Literal["network", "atlas", "atlas_component"] = "network",
+    collections: Optional[List[str]] = None,
+    limit: Optional[int] = None,
+    compute_projection: bool = True
+) -> Union[str, pd.DataFrame]:
+    """
+    Visualize Cognee's knowledge graph using different visualization modes.
+
+    Parameters:
+    -----------
+    destination_file_path (str, optional): Path where to save the output file.
+        For network mode: saves HTML file
+        For atlas mode: saves parquet file
+        Not used for atlas_component mode
+    mode (str): Visualization mode:
+        - "network": Interactive HTML graph visualization
+        - "atlas": Export to parquet for embedding-atlas CLI
+        - "atlas_component": Return DataFrame for Streamlit component
+    collections (List[str], optional): For atlas modes - list of collections to export
+    limit (int, optional): For atlas modes - maximum number of embeddings to export per collection
+    compute_projection (bool): For atlas_component mode - whether to compute 2D projection
+
+    Returns:
+    --------
+    Union[str, pd.DataFrame]:
+        - str: Path to generated file (network and atlas modes)
+        - pd.DataFrame: DataFrame for Streamlit component (atlas_component mode)
+
+    Usage:
+    ------
+    # Traditional network visualization
+    await visualize_graph()
+
+    # Embedding atlas CLI export
+    await visualize_graph(mode="atlas", destination_file_path="my_embeddings.parquet")
+
+    # Streamlit component DataFrame
+    df = await visualize_graph(mode="atlas_component")
+
+    # Then use in Streamlit:
+    from embedding_atlas.streamlit import embedding_atlas
+    selection = embedding_atlas(df, text="text", x="projection_x", y="projection_y")
+    """
+
+    if mode == "atlas":
+        # Export embeddings for atlas CLI visualization
+        output_path = destination_file_path or "cognee_embeddings.parquet"
+        result_path = await export_embeddings_to_atlas(
+            output_path=output_path,
+            collections=collections,
+            limit=limit
+        )
+
+        logger.info(f"Embeddings exported to: {result_path}")
+        logger.info(f"To visualize with Embedding Atlas, run: embedding-atlas {result_path}")
+
+        return result_path
+
+    elif mode == "atlas_component":
+        # Return DataFrame for Streamlit component
+        df = await get_embeddings_for_atlas(
+            collections=collections,
+            limit=limit,
+            compute_projection=compute_projection
+        )
+
+        logger.info(f"Prepared DataFrame with {len(df)} embeddings for Streamlit component")
+        if compute_projection and 'projection_x' in df.columns:
+            logger.info("DataFrame includes 2D projection coordinates")
+
+        return df
+
+    else:
+        # Traditional network visualization
+        graph_engine = await get_graph_engine()
+        graph_data = await graph_engine.get_graph_data()
+
+        graph = await cognee_network_visualization(graph_data, destination_file_path)
+
+        if destination_file_path:
+            logger.info(f"The HTML file has been stored at path: {destination_file_path}")
+        else:
+            logger.info(
+                "The HTML file has been stored on your home directory! Navigate there with cd ~"
+            )
+
+        return graph
 
 
 if __name__ == "__main__":
```
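Based on the new signature, a short driver script could exercise all three modes as follows; a sketch only, with hypothetical output paths:

```python
import asyncio

from cognee import visualize_graph  # re-exported from .api.v1.visualize per the first hunk

async def demo():
    # Default "network" mode: renders the interactive HTML graph (path is illustrative).
    await visualize_graph(destination_file_path="graph.html")

    # "atlas" mode: writes a parquet file and returns its path.
    parquet_path = await visualize_graph(mode="atlas", destination_file_path="embeddings.parquet")
    print(f"Visualize with: embedding-atlas {parquet_path}")

    # "atlas_component" mode: returns a DataFrame, no file is written.
    df = await visualize_graph(mode="atlas_component", limit=500)
    print(df.shape)

asyncio.run(demo())
```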
cognee/modules/visualization/embedding_atlas_export.py (new file, 406 lines)

````python
import os
import pandas as pd
from typing import List, Optional, Dict, Any, Union, Tuple
from pathlib import Path

from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult

logger = get_logger("EmbeddingAtlasExport")


async def get_embeddings_for_atlas(
    collections: Optional[List[str]] = None,
    limit: Optional[int] = None,
    compute_projection: bool = True
) -> pd.DataFrame:
    """
    Get embeddings from Cognee vector database as a DataFrame for use with
    Embedding Atlas Streamlit component.

    Parameters:
    -----------
    collections (List[str], optional): List of collection names to export.
        If None, exports from all collections
    limit (int, optional): Maximum number of embeddings to export per collection.
        If None, exports all embeddings
    compute_projection (bool): Whether to compute 2D projection for visualization.
        If False, returns raw embeddings only.

    Returns:
    --------
    pd.DataFrame: DataFrame ready for use with embedding_atlas() Streamlit component
        Contains columns: id, text, collection, embedding, dim_0, dim_1, ...
        If compute_projection=True, also includes: projection_x, projection_y

    Usage with Streamlit:
    --------------------
    ```python
    import streamlit as st
    from embedding_atlas.streamlit import embedding_atlas
    from embedding_atlas.projection import compute_text_projection
    import cognee

    # Get embeddings DataFrame
    df = await cognee.get_embeddings_for_atlas()

    # Use with Embedding Atlas Streamlit component
    selection = embedding_atlas(
        df,
        text="text",
        x="projection_x",
        y="projection_y",
        show_table=True
    )
    ```
    """

    vector_engine = get_vector_engine()

    # Get all collections if none specified
    if collections is None:
        collections = await _get_all_collections(vector_engine)
        logger.info(f"Found {len(collections)} collections: {collections}")

    all_data = []

    for collection_name in collections:
        logger.info(f"Getting embeddings from collection: {collection_name}")

        try:
            # Get all data points from the collection with embeddings
            collection_data = await _get_collection_embeddings(
                vector_engine, collection_name, limit
            )

            if collection_data:
                all_data.extend(collection_data)
                logger.info(f"Retrieved {len(collection_data)} embeddings from {collection_name}")
            else:
                logger.warning(f"No data found in collection: {collection_name}")

        except Exception as e:
            logger.error(f"Error getting embeddings from collection {collection_name}: {e}")
            continue

    if not all_data:
        logger.warning("No embeddings found")
        return pd.DataFrame()

    # Convert to DataFrame
    df = pd.DataFrame(all_data)

    # Compute 2D projection if requested
    if compute_projection and 'embedding' in df.columns:
        try:
            from embedding_atlas.projection import compute_text_projection

            # Compute projection using the embedding_atlas library
            df = compute_text_projection(
                df,
                text="text",
                x="projection_x",
                y="projection_y",
                neighbors="neighbors"
            )
            logger.info("Computed 2D projection for embeddings")

        except ImportError:
            logger.warning("embedding-atlas not installed. Install with: pip install embedding-atlas")
            logger.info("Returning DataFrame without projection")
        except Exception as e:
            logger.error(f"Error computing projection: {e}")
            logger.info("Returning DataFrame without projection")

    logger.info(f"Prepared DataFrame with {len(df)} embeddings for Atlas component")
    return df


async def export_embeddings_to_atlas(
    output_path: str = None,
    collections: Optional[List[str]] = None,
    limit: Optional[int] = None
) -> str:
    """
    Export embeddings and metadata from Cognee vector database to parquet format
    compatible with Embedding Atlas.

    Parameters:
    -----------
    output_path (str, optional): Path where to save the parquet file.
        If None, saves to current directory as 'cognee_embeddings.parquet'
    collections (List[str], optional): List of collection names to export.
        If None, exports from all collections
    limit (int, optional): Maximum number of embeddings to export per collection.
        If None, exports all embeddings

    Returns:
    --------
    str: Path to the generated parquet file

    Usage:
    ------
    After calling this function, you can use the generated parquet file with embedding-atlas:
    ```
    embedding-atlas your-dataset.parquet
    ```
    """

    if output_path is None:
        output_path = "cognee_embeddings.parquet"

    vector_engine = get_vector_engine()

    # Get all collections if none specified
    if collections is None:
        collections = await _get_all_collections(vector_engine)
        logger.info(f"Found {len(collections)} collections: {collections}")

    all_data = []

    for collection_name in collections:
        logger.info(f"Exporting embeddings from collection: {collection_name}")

        try:
            # Get all data points from the collection with embeddings
            collection_data = await _get_collection_embeddings(
                vector_engine, collection_name, limit
            )

            if collection_data:
                all_data.extend(collection_data)
                logger.info(f"Exported {len(collection_data)} embeddings from {collection_name}")
            else:
                logger.warning(f"No data found in collection: {collection_name}")

        except Exception as e:
            logger.error(f"Error exporting from collection {collection_name}: {e}")
            continue

    if not all_data:
        raise ValueError("No embeddings found to export")

    # Convert to DataFrame and save as parquet
    df = pd.DataFrame(all_data)

    # Ensure output directory exists
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    df.to_parquet(output_path, index=False)

    logger.info(f"Successfully exported {len(all_data)} embeddings to {output_path}")
    logger.info(f"You can now visualize with: embedding-atlas {output_path}")

    return str(output_path)


async def _get_all_collections(vector_engine) -> List[str]:
    """Get all collection names from the vector database."""
    try:
        # LanceDB specific method
        if hasattr(vector_engine, 'get_connection'):
            connection = await vector_engine.get_connection()
            if hasattr(connection, 'table_names'):
                return await connection.table_names()

        # ChromaDB specific method
        if hasattr(vector_engine, 'get_collection_names'):
            return await vector_engine.get_collection_names()
        elif hasattr(vector_engine, 'list_collections'):
            collections = await vector_engine.list_collections()
            return [col.name if hasattr(col, 'name') else str(col) for col in collections]
        else:
            logger.warning("Vector engine doesn't support listing collections")
            return []
    except Exception as e:
        logger.error(f"Error getting collections: {e}")
        return []


async def _get_collection_embeddings(
    vector_engine,
    collection_name: str,
    limit: Optional[int] = None
) -> List[Dict[str, Any]]:
    """Get all embeddings and metadata from a specific collection."""

    try:
        # First check if collection exists
        if not await vector_engine.has_collection(collection_name):
            logger.warning(f"Collection {collection_name} does not exist")
            return []

        collection_data = []

        # Get collection object to work with directly
        collection = await vector_engine.get_collection(collection_name)

        if collection is None:
            logger.warning(f"Could not get collection object for {collection_name}")
            return []

        # Strategy 1: LanceDB specific - query all data with vectors
        if hasattr(collection, 'query') and hasattr(collection, 'to_pandas'):
            try:
                logger.info(f"Using LanceDB query method for {collection_name}")

                # Query all data from LanceDB table
                query = collection.query()
                if limit:
                    query = query.limit(limit)

                results_df = await query.to_pandas()

                if not results_df.empty:
                    for _, row in results_df.iterrows():
                        item = {
                            'id': str(row.get('id', '')),
                            'collection': collection_name
                        }

                        # Extract text from payload
                        payload = row.get('payload', {})
                        if isinstance(payload, dict):
                            item['text'] = _extract_text_from_payload(payload)

                            # Add payload metadata
                            for key, value in payload.items():
                                if key not in ['id', 'text', 'embedding', 'vector']:
                                    item[f'meta_{key}'] = value
                        else:
                            item['text'] = str(payload) if payload else ''

                        # Add embedding vector if available
                        if 'vector' in row and row['vector'] is not None:
                            embedding = row['vector']
                            if hasattr(embedding, 'tolist'):
                                embedding = embedding.tolist()
                            elif not isinstance(embedding, list):
                                embedding = list(embedding)

                            item['embedding'] = embedding
                            # Add individual embedding dimensions as columns for atlas
                            for j, val in enumerate(embedding):
                                item[f'dim_{j}'] = float(val)

                        collection_data.append(item)

                logger.info(f"Exported {len(collection_data)} embeddings from LanceDB table {collection_name}")
                return collection_data

            except Exception as e:
                logger.debug(f"LanceDB query failed for {collection_name}: {e}")

        # Strategy 2: ChromaDB specific - collection.get()
        if hasattr(collection, 'get'):
            try:
                logger.info(f"Using ChromaDB get method for {collection_name}")
                results = await collection.get(
                    include=["metadatas", "embeddings", "documents"]
                )

                if results and 'ids' in results:
                    for i, id in enumerate(results['ids']):
                        item = {
                            'id': str(id),
                            'text': results.get('documents', [None])[i] or '',
                            'collection': collection_name
                        }

                        # Add embedding if available
                        if 'embeddings' in results and i < len(results['embeddings']):
                            embedding = results['embeddings'][i]
                            item['embedding'] = embedding
                            # Add individual embedding dimensions as columns for atlas
                            for j, val in enumerate(embedding):
                                item[f'dim_{j}'] = val

                        # Add metadata if available
                        if 'metadatas' in results and i < len(results['metadatas']):
                            metadata = results['metadatas'][i] or {}
                            for key, value in metadata.items():
                                if key not in ['id', 'text', 'embedding']:
                                    item[f'meta_{key}'] = value

                        collection_data.append(item)

                        if limit and len(collection_data) >= limit:
                            break

                logger.info(f"Exported {len(collection_data)} embeddings from ChromaDB collection {collection_name}")
                return collection_data

            except Exception as e:
                logger.debug(f"ChromaDB-style get failed for {collection_name}: {e}")

        # Strategy 3: Fallback - try using search with dummy query
        try:
            logger.info(f"Using search fallback for {collection_name}")
            # Use a very generic search to get all data
            search_results = await vector_engine.search(
                collection_name=collection_name,
                query_text="the",  # Use a common word instead of empty query
                limit=limit or 10000,
                with_vector=True
            )

            if search_results:
                for result in search_results:
                    if isinstance(result, ScoredResult):
                        item = {
                            'id': str(result.id),
                            'text': _extract_text_from_payload(result.payload),
                            'collection': collection_name,
                            'score': result.score
                        }

                        # Add embedding if available
                        if hasattr(result, 'vector') and result.vector:
                            embedding = result.vector
                            item['embedding'] = embedding
                            # Add individual embedding dimensions
                            for j, val in enumerate(embedding):
                                item[f'dim_{j}'] = val

                        # Add payload metadata
                        if result.payload:
                            for key, value in result.payload.items():
                                if key not in ['id', 'text', 'embedding']:
                                    item[f'meta_{key}'] = value

                        collection_data.append(item)

            logger.info(f"Exported {len(collection_data)} embeddings using search fallback for {collection_name}")
            return collection_data

        except Exception as e:
            logger.debug(f"Search-based export failed for {collection_name}: {e}")

        logger.warning(f"Could not export embeddings from {collection_name}")
        return []

    except Exception as e:
        logger.error(f"Error getting embeddings from {collection_name}: {e}")
        return []


def _extract_text_from_payload(payload: Dict[str, Any]) -> str:
    """Extract text content from payload data."""
    if not payload:
        return ""

    # Common text field names
    text_fields = ['text', 'content', 'document', 'data', 'name', 'title']

    for field in text_fields:
        if field in payload and payload[field]:
            return str(payload[field])

    # If no standard text field found, try to find any string value
    for key, value in payload.items():
        if isinstance(value, str) and len(value.strip()) > 0:
            return value

    return ""
````
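For the `atlas_component` path, the module docstring implies a Streamlit flow along these lines. A minimal sketch; the `asyncio.run` bridge and `st.cache_data` wrapper are assumptions, since Streamlit scripts are synchronous and the exporter is async:

```python
import asyncio

import streamlit as st
from embedding_atlas.streamlit import embedding_atlas

from cognee.modules.visualization.embedding_atlas_export import get_embeddings_for_atlas

# Cache the export so Streamlit reruns don't re-query the vector store.
@st.cache_data
def load_embeddings():
    return asyncio.run(get_embeddings_for_atlas(compute_projection=True))

df = load_embeddings()
if df.empty:
    st.warning("No embeddings found in the vector database.")
else:
    # Column names match what the exporter produces above.
    selection = embedding_atlas(df, text="text", x="projection_x", y="projection_y", show_table=True)
```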
```diff
@@ -1,7 +1,7 @@
 [project]
 name = "cognee"
-version = "0.2.4"
+version = "0.2.4b1"
 description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning."
 authors = [
     { name = "Vasilije Markovic" },
```
```diff
@@ -61,6 +61,7 @@ dependencies = [
     "onnxruntime>=1.0.0,<2.0.0",
     "pylance>=0.22.0,<1.0.0",
     "kuzu (==0.11.0)",
+    "embedding-atlas>=0.1.0,<1.0.0",
     "python-magic-bin<0.5 ; platform_system == 'Windows'", # Only needed for Windows
 ]
```
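The module above degrades gracefully when `embedding_atlas` cannot be imported, but this hunk pins it as a regular dependency. A stdlib-only sketch to confirm the resolved version at runtime:

```python
# Verify the new dependency resolves within the pinned range.
from importlib.metadata import PackageNotFoundError, version

try:
    v = version("embedding-atlas")  # distribution name as added to pyproject.toml
    print(f"embedding-atlas {v} installed (pin: >=0.1.0,<1.0.0)")
except PackageNotFoundError:
    # Without it, get_embeddings_for_atlas() logs a warning and skips the 2D projection.
    print("embedding-atlas not installed; projection columns will be absent")
```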