cognee/cognee/shared/utils.py

304 lines
9.4 KiB
Python

""" This module contains utility functions for the cognee. """
import os
from typing import BinaryIO, Union
import requests
import hashlib
from datetime import datetime, timezone
import graphistry
import networkx as nx
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tiktoken
import nltk
import logging
import sys
from cognee.base_config import get_base_config
from cognee.infrastructure.databases.graph import get_graph_engine
from uuid import uuid4
import pathlib
from cognee.shared.exceptions import IngestionError
# Analytics proxy URL, currently hosted by Vercel.
# send_telemetry() POSTs every telemetry event payload to this endpoint.
proxy_url: str = "https://test.prometh.ai"
def get_anonymous_id():
    """Create or read the anonymous user id used for telemetry.

    The id is persisted in a ``.anon_id`` file located at
    ``Path(__file__).parent.parent.parent`` (two levels above the directory
    containing this module), so repeated runs report the same id. When the
    file is missing, or exists but holds no id, a new UUID4 is generated and
    written there.

    Returns:
        str: The anonymous id, with surrounding whitespace removed.

    Raises:
        OSError: If the id file cannot be read or written.
    """
    home_dir = pathlib.Path(__file__).parent.parent.parent.resolve()
    home_dir.mkdir(parents=True, exist_ok=True)
    anonymous_id_file = home_dir / ".anon_id"

    anonymous_id = ""
    if anonymous_id_file.is_file():
        # Strip so a trailing newline (e.g. from a hand edit) is not sent as part of the id.
        anonymous_id = anonymous_id_file.read_text(encoding="utf-8").strip()

    if not anonymous_id:
        # Missing or empty file: mint a fresh id and persist it for later runs.
        anonymous_id = str(uuid4())
        anonymous_id_file.write_text(anonymous_id, encoding="utf-8")

    return anonymous_id
def send_telemetry(event_name: str, user_id, additional_properties: Union[dict, None] = None):
    """Send an anonymous telemetry event to the analytics proxy (best effort).

    Nothing is sent when the ``TELEMETRY_DISABLED`` environment variable is set
    to any non-empty value, or when ``ENV`` is ``"test"`` or ``"dev"``.

    Args:
        event_name: Name of the event being reported.
        user_id: Identifier of the user; it is sent as a string.
        additional_properties: Optional extra key/value pairs merged into the
            event's ``properties``. Keys given here override ``time`` and
            ``user_id``.

    Network failures and non-200 responses are printed, never raised, so
    telemetry cannot break the caller.
    """
    if os.getenv("TELEMETRY_DISABLED"):
        return

    if os.getenv("ENV") in ("test", "dev"):
        return

    # None instead of a shared mutable default dict.
    if additional_properties is None:
        additional_properties = {}

    current_time = datetime.now(timezone.utc)
    payload = {
        "anonymous_id": str(get_anonymous_id()),
        "event_name": event_name,
        "user_properties": {
            "user_id": str(user_id),
        },
        "properties": {
            "time": current_time.strftime("%m/%d/%Y"),
            "user_id": str(user_id),
            **additional_properties,
        },
    }

    try:
        # Without a timeout an unreachable proxy would block the caller indefinitely.
        response = requests.post(proxy_url, json=payload, timeout=10)
    except requests.RequestException as error:
        print(f"Error sending telemetry through proxy: {error}")
        return

    if response.status_code != 200:
        print(f"Error sending telemetry through proxy: {response.status_code}")
def num_tokens_from_string(string: str, encoding_name: str) -> int:
    """Return the number of tokens in a text string.

    Args:
        string: The text to tokenize.
        encoding_name: A model name understood by ``tiktoken.encoding_for_model``
            (e.g. "gpt-4"). If it is not a known model name, it is treated as a
            raw encoding name for ``tiktoken.get_encoding`` (e.g. "cl100k_base").

    Returns:
        int: The number of tokens ``string`` encodes to.
    """
    try:
        encoding = tiktoken.encoding_for_model(encoding_name)
    except KeyError:
        # Not a known model name: accept a raw encoding name as well.
        encoding = tiktoken.get_encoding(encoding_name)

    # disallowed_special=() counts special-token markers such as "<|endoftext|>"
    # appearing in user text as ordinary text instead of raising ValueError.
    return len(encoding.encode(string, disallowed_special=()))
def get_file_content_hash(file_obj: Union[str, BinaryIO]) -> str:
    """Return the hexadecimal MD5 digest of a file's contents.

    Args:
        file_obj: Either a filesystem path, or an already-open binary file-like
            object. A file-like object is read from its current position to EOF;
            it is neither rewound nor closed.

    Returns:
        str: The 32-character hexadecimal MD5 digest.

    Raises:
        IngestionError: If the file cannot be opened or read.
    """
    # MD5 serves as a content fingerprint here, not as a security primitive.
    digest = hashlib.md5()
    # md5's block_size is only 64 bytes; reading in larger chunks avoids many tiny
    # reads, especially on unbuffered streams. The digest is the same either way.
    chunk_size = 64 * 1024

    def _feed(stream):
        # Consume the stream to EOF, folding every chunk into the digest.
        while True:
            chunk = stream.read(chunk_size)
            if not chunk:
                break
            digest.update(chunk)

    try:
        if isinstance(file_obj, str):
            with open(file_obj, "rb") as stream:
                _feed(stream)
        else:
            _feed(file_obj)
    except OSError as error:
        # Report file_obj itself: a local bound by `with ... as` would be unset
        # when open() is what failed, or when a stream was passed in.
        raise IngestionError(message=f"Failed to load data from {file_obj}: {error}") from error

    return digest.hexdigest()
def trim_text_to_max_tokens(text: str, max_tokens: int, encoding_name: str) -> str:
    """
    Trims the text so that the number of tokens does not exceed max_tokens.

    Args:
        text (str): Original text string to be trimmed.
        max_tokens (int): Maximum number of tokens allowed. Negative values are
            treated as 0.
        encoding_name (str): A model name understood by
            ``tiktoken.encoding_for_model`` (e.g. "gpt-4"), or, if it is not a
            known model, a raw encoding name for ``tiktoken.get_encoding``
            (e.g. "cl100k_base").

    Returns:
        str: Trimmed version of text or original text if under the limit.
    """
    # Resolve the tokenizer once and use it for both counting and trimming, so the
    # count and the cut always agree on which encoding is in effect.
    try:
        encoding = tiktoken.encoding_for_model(encoding_name)
    except KeyError:
        encoding = tiktoken.get_encoding(encoding_name)

    # disallowed_special=() treats markers such as "<|endoftext|>" in the text as
    # ordinary text instead of raising ValueError.
    encoded_text = encoding.encode(text, disallowed_special=())

    # A negative slice bound would drop tokens from the end instead of trimming.
    limit = max(max_tokens, 0)
    if len(encoded_text) <= limit:
        return text

    # This is a simple trim and may cut words in half; consider word boundaries
    # for a cleaner cut.
    return encoding.decode(encoded_text[:limit])
def generate_color_palette(unique_layers):
    """Map each layer to a distinct hex color sampled from the "viridis" colormap.

    Args:
        unique_layers: Iterable of hashable layer identifiers, expected to be
            unique (a repeated identifier keeps only its last color).

    Returns:
        dict: ``{layer: "#rrggbb"}``, or an empty dict when there are no layers.
    """
    # Materialize so generators and other one-shot iterables work with len() and zip().
    layers = list(unique_layers)
    if not layers:
        # Avoid requesting a colormap with zero entries.
        return {}

    # pyplot.get_cmap(name, lut) replaces matplotlib.cm.get_cmap, removed in Matplotlib 3.9.
    colormap = plt.get_cmap("viridis", len(layers))

    hex_colors = []
    for index in range(len(layers)):
        red, green, blue = colormap(index)[:3]
        # Truncate (not round) each 0..1 channel to 0..255, as before.
        hex_colors.append("#%02x%02x%02x" % (int(red * 255), int(green * 255), int(blue * 255)))

    return dict(zip(layers, hex_colors))
async def register_graphistry():
    """Register the graphistry client (API version 3).

    The username and password are taken from the base config returned by
    ``get_base_config()``.
    """
    base_config = get_base_config()
    credentials = {
        "username": base_config.graphistry_username,
        "password": base_config.graphistry_password,
    }
    graphistry.register(api=3, **credentials)
def prepare_edges(graph, source, target, edge_key):
    """Build an edge DataFrame (one row per edge) from a multigraph.

    Args:
        graph: A graph whose ``edges(keys=True)`` yields ``(u, v, key)`` tuples,
            e.g. a ``networkx.MultiDiGraph``.
        source: Column name for the stringified source node id.
        target: Column name for the stringified target node id.
        edge_key: Column name for the stringified edge key.

    Returns:
        pandas.DataFrame: Columns ``source``, ``target`` and ``edge_key``. The
        columns are present even when the graph has no edges, so callers that
        bind these column names keep working on an empty graph.
    """
    # Edge attribute data is not needed, so it is not requested.
    edge_list = [
        {source: str(u), target: str(v), edge_key: str(key)}
        for u, v, key in graph.edges(keys=True)
    ]
    return pd.DataFrame(edge_list, columns=[source, target, edge_key])
def prepare_nodes(graph, include_size=False):
    """Build a node DataFrame (one row per node that has attributes).

    Args:
        graph: A graph exposing ``graph.nodes``, which is iterable over node ids
            and indexable by id to that node's attribute mapping (e.g. networkx).
        include_size: When True, add a "size" column: 20 for nodes whose id
            contains "DOCUMENT" or "User", otherwise 10.

    Returns:
        pandas.DataFrame: Columns "id" and "name" (plus "size" when requested).
        "name" falls back to the stringified id. The columns are present even
        when no rows are produced.
    """
    default_size = 10  # Default node size
    larger_size = 20  # Size for nodes whose id contains one of the keywords
    keywords = ("DOCUMENT", "User")

    columns = ["id", "name"] + (["size"] if include_size else [])
    nodes_data = []
    for node in graph.nodes:
        node_info = graph.nodes[node]
        if not node_info:
            # NOTE(review): attribute-less nodes are skipped here, although edges
            # touching them may still be plotted; confirm this is intended.
            continue

        node_id = str(node)
        node_data = {"id": node_id, "name": node_info.get("name", node_id)}
        if include_size:
            is_large = any(keyword in node_id for keyword in keywords)
            node_data["size"] = larger_size if is_large else default_size
        nodes_data.append(node_data)

    return pd.DataFrame(nodes_data, columns=columns)
async def render_graph(graph, include_nodes=False, include_color=False, include_size=False, include_labels=False):
    """Plot a graph with graphistry and return the visualization URL.

    When ``graph`` is not a ``networkx.MultiDiGraph``, the argument is ignored.
    In that case the nodes and edges returned by the graph engine's
    ``get_graph_data()`` are rendered instead.

    Args:
        graph: A ``networkx.MultiDiGraph`` to render, or any other value to
            render the graph engine's data.
        include_nodes: Attach a node table (ids and names) to the plot. The
            options below only take effect together with this one.
        include_color: Currently has no effect; per-layer coloring is not implemented.
        include_size: Add a per-node "size" column and bind it as point size.
        include_labels: Bind the node "name" column as the point label.

    Returns:
        The result of ``plot(render=False, as_files=True, memoize=False)``,
        which is also printed.
    """
    await register_graphistry()

    if not isinstance(graph, nx.MultiDiGraph):
        engine = await get_graph_engine()
        engine_nodes, engine_edges = await engine.get_graph_data()
        graph = nx.MultiDiGraph()
        graph.add_nodes_from(engine_nodes)
        graph.add_edges_from(engine_edges)

    edge_frame = prepare_edges(graph, "source_node", "target_node", "relationship_name")
    plotter = graphistry.edges(edge_frame, "source_node", "target_node").bind(edge_label="relationship_name")

    if include_nodes:
        node_frame = prepare_nodes(graph, include_size=include_size)
        plotter = plotter.nodes(node_frame, "id")

        if include_size:
            plotter = plotter.bind(point_size="size")

        if include_color:
            # TODO: color nodes by layer (e.g. generate_color_palette() over a
            # "layer_description" column, then encode_point_color with a "silver"
            # default). The node table has no such column yet.
            pass

        if include_labels:
            plotter = plotter.bind(point_label="name")

    url = plotter.plot(render=False, as_files=True, memoize=False)
    print(f"Graph is visualized at: {url}")
    return url
def sanitize_df(df):
    """Replace NaNs and infinities in a DataFrame with None, making it JSON compliant.

    Args:
        df (pandas.DataFrame): The frame to sanitize; it is not modified in place.

    Returns:
        pandas.DataFrame: A new frame in which every ``NaN``, ``inf`` and
        ``-inf`` value is ``None``.
    """
    # The dict form states the replacement value explicitly. Older pandas releases
    # read the list form ``replace([...], None)`` as "no value given" and
    # forward-filled instead of inserting None.
    return df.replace({np.inf: None, -np.inf: None, np.nan: None})
def get_entities(tagged_tokens):
    """Run NLTK's named-entity chunker over POS-tagged tokens.

    The "maxent_ne_chunker" resource is downloaded quietly first.

    Args:
        tagged_tokens: (word, POS tag) pairs, e.g. the output of ``extract_pos_tags``.

    Returns:
        The result of ``nltk.chunk.ne_chunk`` for the given tokens.
    """
    nltk.download("maxent_ne_chunker", quiet=True)
    from nltk.chunk import ne_chunk as chunk_named_entities

    entity_tree = chunk_named_entities(tagged_tokens)
    return entity_tree
def extract_pos_tags(sentence):
    """Extract Part-of-Speech (POS) tags for words in a sentence.

    The NLTK "words", "punkt" and "averaged_perceptron_tagger" resources are
    downloaded quietly (in that order) before tokenizing and tagging.

    Args:
        sentence (str): The sentence to tag.

    Returns:
        The ``nltk.tag.pos_tag`` result for the tokens from ``word_tokenize``.
    """
    for resource_name in ("words", "punkt", "averaged_perceptron_tagger"):
        nltk.download(resource_name, quiet=True)

    from nltk.tag import pos_tag
    from nltk.tokenize import word_tokenize

    return pos_tag(word_tokenize(sentence))
def extract_named_entities(sentence):
    """Extract named entities from a sentence.

    The sentence is POS-tagged with ``extract_pos_tags``, and the tagged tokens
    are then chunked with ``get_entities``.

    Args:
        sentence (str): The sentence to analyze.

    Returns:
        Whatever ``get_entities`` returns for the tagged tokens.
    """
    return get_entities(extract_pos_tags(sentence))
def extract_sentiment_vader(text):
    """
    Score the sentiment of a text with NLTK's VADER Sentiment Intensity Analyzer.

    Parameters:
        text (str): The text to analyze.

    Returns:
        dict: The polarity scores produced by ``SentimentIntensityAnalyzer.polarity_scores``.
    """
    from nltk.sentiment import SentimentIntensityAnalyzer

    # Fetch the VADER lexicon before constructing the analyzer.
    nltk.download("vader_lexicon", quiet=True)

    analyzer = SentimentIntensityAnalyzer()
    return analyzer.polarity_scores(text)
def setup_logging(log_level=logging.INFO):
    """Configure root logging to write formatted records to stdout.

    A single stdout ``StreamHandler`` is installed via ``logging.basicConfig``,
    with both the handler and the root logger set to ``log_level``. Each record
    is followed by an extra blank line. As with any ``basicConfig`` call, this
    has no effect if the root logger already has handlers.

    Args:
        log_level: The logging level to apply (default ``logging.INFO``).
    """
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(log_level)
    console_handler.setFormatter(
        logging.Formatter("%(asctime)s - %(levelname)s - %(message)s\n")
    )
    logging.basicConfig(handlers=[console_handler], level=log_level)
if __name__ == "__main__":
    # Manual smoke check: running this module directly prints VADER scores for a sample sentence.
    demo_scores = extract_sentiment_vader("I love sunny days, but I hate the rain.")
    print("Sentiment analysis results:", demo_scores)