fix: download nltk files when needed

Boris Arzentar 2024-04-21 22:03:18 +02:00
parent bdd664a2aa
commit e58251b00c
8 changed files with 54 additions and 40 deletions

View file

@@ -5,7 +5,6 @@ import dlt
 import duckdb
 import cognee.modules.ingestion as ingestion
 from cognee.infrastructure import infrastructure_config
-from cognee.infrastructure.files import get_file_metadata
 from cognee.infrastructure.files.storage import LocalStorage
 from cognee.modules.discovery import discover_directory_datasets
@@ -85,7 +84,7 @@ async def add_files(file_paths: List[str], dataset_name: str):
         data_id = ingestion.identify(classified_data)
-        file_metadata = get_file_metadata(classified_data.get_data())
+        file_metadata = classified_data.get_metadata()
         yield {
             "id": data_id,

View file

@@ -1,13 +1,11 @@
-import nltk
 from sklearn.feature_extraction.text import TfidfVectorizer
+from cognee.utils import extract_pos_tags

 def extract_keywords(text: str) -> list[str]:
     if len(text) == 0:
         raise ValueError("extract_keywords cannot extract keywords from empty text.")

-    tokens = nltk.word_tokenize(text)
-    tags = nltk.pos_tag(tokens)
+    tags = extract_pos_tags(text)

     nouns = [word for (word, tag) in tags if tag == "NN"]

     vectorizer = TfidfVectorizer()
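
extract_keywords now delegates tokenizing and POS tagging to extract_pos_tags from cognee.utils, which downloads the NLTK data it needs on first use. A hypothetical usage sketch (the sample text is illustrative; the remainder of the function body is not shown in this hunk):

keywords = extract_keywords(
    "Knowledge graphs connect entities, relationships and documents for search."
)
print(keywords)  # a list[str] of noun keywords selected via TF-IDF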

View file

@@ -56,9 +56,9 @@ async def add_cognitive_layer_graphs(
             dict(relationship_name = "contains"),
         ))

-        pos_tags = await extract_pos_tags(node.entity_description)
-        named_entities = await extract_named_entities(node.entity_description)
-        sentiment = await extract_sentiment_vader(node.entity_description)
+        pos_tags = extract_pos_tags(node.entity_description)
+        named_entities = extract_named_entities(node.entity_description)
+        sentiment = extract_sentiment_vader(node.entity_description)

         graph_nodes.append((
             node_id,

View file

@@ -13,9 +13,14 @@ class BinaryData(IngestionData):
         self.data = data

     def get_identifier(self):
+        metadata = self.get_metadata()
+        return metadata["mime_type"] + "_" + "|".join(metadata["keywords"])
+
+    def get_metadata(self):
         self.ensure_metadata()
-        return self.metadata["mime_type"] + "_" + "|".join(self.metadata["keywords"])
+        return self.metadata

     def ensure_metadata(self):
         if self.metadata is None:

View file

@@ -8,3 +8,6 @@ class IngestionData(Protocol):
     def get_identifier(self):
         raise NotImplementedError()
+
+    def get_metadata(self):
+        raise NotImplementedError()
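
With get_metadata on the IngestionData protocol, callers such as add_files (first hunk) can read metadata from any ingestion object without reaching for file-specific helpers. A rough sketch of that call pattern; the classify helper name and the exact metadata keys are assumptions inferred from the other hunks:

import cognee.modules.ingestion as ingestion

def describe(raw_data):
    # classify() is assumed to return an IngestionData implementation
    # (BinaryData or TextData), as implied by add_files above.
    classified_data = ingestion.classify(raw_data)
    data_id = ingestion.identify(classified_data)
    metadata = classified_data.get_metadata()  # same call works for both implementations
    return data_id, metadata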

View file

@@ -7,14 +7,24 @@ def create_text_data(data: str):
 class TextData(IngestionData):
     data: str = None
+    metadata = None

     def __init__(self, data: BinaryIO):
         self.data = data

     def get_identifier(self):
-        keywords = extract_keywords(self.data)
+        keywords = self.get_metadata()["keywords"]
         return "text/plain" + "_" + "|".join(keywords)

+    def get_metadata(self):
+        self.ensure_metadata()
+        return self.metadata
+
+    def ensure_metadata(self):
+        if self.metadata is None:
+            self.metadata = dict(keywords = extract_keywords(self.data))
+
     def get_data(self):
         return self.data
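
The get_metadata/ensure_metadata pair makes keyword extraction lazy and memoized, so extract_keywords runs at most once per TextData instance. A small usage sketch, assuming create_text_data simply wraps the string in a TextData (the sample text is illustrative):

text_data = create_text_data("Knowledge graphs connect entities extracted from documents.")
print(text_data.get_identifier())            # "text/plain_" + "|".join(keywords)
print(text_data.get_metadata()["keywords"])  # served from the cached metadata, no second extraction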

View file

@@ -7,13 +7,11 @@ import numpy as np
 import pandas as pd
 import matplotlib.pyplot as plt
 import tiktoken
-from nltk.sentiment import SentimentIntensityAnalyzer
 import nltk
-from nltk.tokenize import word_tokenize
-from nltk.tag import pos_tag
-from nltk.chunk import ne_chunk
 from cognee.config import Config

+config = Config()
+config.load()

 def get_document_names(doc_input):
     """
@@ -93,8 +91,6 @@ def trim_text_to_max_tokens(text: str, max_tokens: int, encoding_name: str) -> s
     return trimmed_text

-
-
 def format_dict(d):
     """Format a dictionary as a string."""
     # Initialize an empty list to store formatted items
@@ -117,9 +113,6 @@ def format_dict(d):
     return formatted_string

-config = Config()
-config.load()
-
 def generate_color_palette(unique_layers):
     colormap = plt.cm.get_cmap("viridis", len(unique_layers))
     colors = [colormap(i) for i in range(len(unique_layers))]
@@ -140,8 +133,8 @@ def prepare_nodes(graph, include_size=False):
     nodes_data = []
     for node in graph.nodes:
         node_info = graph.nodes[node]
-        description = node_info.get('layer_description', {}).get('layer', 'Default Layer') if isinstance(
-            node_info.get('layer_description'), dict) else node_info.get('layer_description', 'Default Layer')
+        description = node_info.get("layer_description", {}).get("layer", "Default Layer") if isinstance(
+            node_info.get("layer_description"), dict) else node_info.get("layer_description", "Default Layer")
         # description = node_info['layer_description']['layer'] if isinstance(node_info.get('layer_description'), dict) and 'layer' in node_info['layer_description'] else node_info.get('layer_description', node)
         # if isinstance(node_info.get('layer_description'), dict) and 'layer' in node_info.get('layer_description'):
         #     description = node_info['layer_description']['layer']
@@ -161,8 +154,6 @@ def prepare_nodes(graph, include_size=False):
     return pd.DataFrame(nodes_data)

-
-
 async def render_graph(graph, include_nodes=False, include_color=False, include_size=False, include_labels=False):
     await register_graphistry()
     edges = prepare_edges(graph)
@@ -174,7 +165,7 @@ async def render_graph(graph, include_nodes=False, include_color=False, include_
     if include_size:
-        plotter = plotter.bind(point_size='size')
+        plotter = plotter.bind(point_size="size")

     if include_color:
@@ -185,7 +176,7 @@ async def render_graph(graph, include_nodes=False, include_color=False, include_
     if include_labels:
-        plotter = plotter.bind(point_label='layer_description')
+        plotter = plotter.bind(point_label = "layer_description")
@@ -199,14 +190,23 @@ def sanitize_df(df):
     return df.replace([np.inf, -np.inf, np.nan], None)

-# # Ensure that the necessary NLTK resources are downloaded
-# nltk.download('maxent_ne_chunker')
-# nltk.download('words')
+def get_entities(tagged_tokens):
+    nltk.download("maxent_ne_chunker")
+    from nltk.chunk import ne_chunk
+    return ne_chunk(tagged_tokens)

-async def extract_pos_tags(sentence):
+def extract_pos_tags(sentence):
     """Extract Part-of-Speech (POS) tags for words in a sentence."""
+    # Ensure that the necessary NLTK resources are downloaded
+    nltk.download("words")
+    nltk.download("punkt")
+    nltk.download("averaged_perceptron_tagger")
+
+    from nltk.tag import pos_tag
+    from nltk.tokenize import word_tokenize
+
     # Tokenize the sentence into words
     tokens = word_tokenize(sentence)
@@ -216,20 +216,18 @@ async def extract_pos_tags(sentence):
     return pos_tags

-async def extract_named_entities(sentence):
+def extract_named_entities(sentence):
     """Extract Named Entities from a sentence."""
     # Tokenize the sentence into words
-    tokens = word_tokenize(sentence)
-    # Perform POS tagging on the tokenized sentence
-    tagged = pos_tag(tokens)
+    tagged_tokens = extract_pos_tags(sentence)
     # Perform Named Entity Recognition (NER) on the tagged tokens
-    entities = ne_chunk(tagged)
+    entities = get_entities(tagged_tokens)

     return entities

-async def extract_sentiment_vader(text):
+def extract_sentiment_vader(text):
     """
     Analyzes the sentiment of a given text using the VADER Sentiment Intensity Analyzer.
@@ -239,6 +237,7 @@ async def extract_sentiment_vader(text):
     Returns:
         dict: A dictionary containing the polarity scores for the text.
     """
+    from nltk.sentiment import SentimentIntensityAnalyzer
     nltk.download("vader_lexicon")
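
The net effect across cognee/utils.py is that these helpers are now plain synchronous functions that fetch their own NLTK data on first use, so callers need neither a prior nltk.download step nor await. A hedged usage sketch (the sample sentence is illustrative; the first call may pause while NLTK data downloads):

from cognee.utils import extract_pos_tags, extract_named_entities, extract_sentiment_vader

sentence = "Boris pushed a fix to the cognee repository on Sunday."

pos_tags = extract_pos_tags(sentence)          # downloads punkt/tagger data if missing
entities = extract_named_entities(sentence)    # tags the sentence, then chunks named entities
sentiment = extract_sentiment_vader(sentence)  # downloads vader_lexicon if missing

print(pos_tags[:3], sentiment["compound"])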

View file

@@ -151,10 +151,10 @@
     "graph_client = await get_graph_client(GraphDBType.NETWORKX)\n",
     "graph = graph_client.graph\n",
     "\n",
-    "results = await search_similarity(\"At My Window was released by which American singer-songwriter?\", graph)\n",
+    "results = await search_similarity(\"Who is Ernie Grunwald?\", graph)\n",
     "\n",
     "for result in results:\n",
-    "    print(\"At My Window\" in result)\n",
+    "    print(\"Ernie Grunwald\" in result)\n",
     "    print(result)"
    ]
   }