fix: download nltk files when needed
parent bdd664a2aa
commit e58251b00c

8 changed files with 54 additions and 40 deletions
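The pattern throughout this commit: every `nltk.download(...)` call moves from import time (or from commented-out reminders) into the function that actually needs the resource, so files are fetched on first use. A minimal sketch of that pattern, with an illustrative function name (`tokenize_lazily` is not part of the codebase):

import nltk

def tokenize_lazily(text: str) -> list[str]:
    # nltk.download checks the local nltk_data directory first and skips
    # resources that are already present, so repeated calls are cheap.
    nltk.download("punkt", quiet=True)

    from nltk.tokenize import word_tokenize
    return word_tokenize(text)

print(tokenize_lazily("Download resources only when needed."))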
@@ -5,7 +5,6 @@ import dlt
 import duckdb
 import cognee.modules.ingestion as ingestion
 from cognee.infrastructure import infrastructure_config
-from cognee.infrastructure.files import get_file_metadata
 from cognee.infrastructure.files.storage import LocalStorage
 from cognee.modules.discovery import discover_directory_datasets
@@ -85,7 +84,7 @@ async def add_files(file_paths: List[str], dataset_name: str):
     data_id = ingestion.identify(classified_data)

-    file_metadata = get_file_metadata(classified_data.get_data())
+    file_metadata = classified_data.get_metadata()

     yield {
         "id": data_id,
@@ -1,13 +1,11 @@
-import nltk
 from sklearn.feature_extraction.text import TfidfVectorizer

+from cognee.utils import extract_pos_tags

 def extract_keywords(text: str) -> list[str]:
     if len(text) == 0:
         raise ValueError("extract_keywords cannot extract keywords from empty text.")

-    tokens = nltk.word_tokenize(text)
-
-    tags = nltk.pos_tag(tokens)
+    tags = extract_pos_tags(text)

     nouns = [word for (word, tag) in tags if tag == "NN"]

     vectorizer = TfidfVectorizer()
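For context, `extract_pos_tags` returns `(word, tag)` pairs in the Penn Treebank tag set, so the `"NN"` filter above keeps only singular common nouns. A small illustration with made-up tags:

tags = [("cognee", "NN"), ("builds", "VBZ"), ("knowledge", "NN"), ("graphs", "NNS")]

nouns = [word for (word, tag) in tags if tag == "NN"]
assert nouns == ["cognee", "knowledge"]  # "graphs" is tagged NNS and is dropped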
@@ -56,9 +56,9 @@ async def add_cognitive_layer_graphs(
             dict(relationship_name = "contains"),
         ))

-        pos_tags = await extract_pos_tags(node.entity_description)
-        named_entities = await extract_named_entities(node.entity_description)
-        sentiment = await extract_sentiment_vader(node.entity_description)
+        pos_tags = extract_pos_tags(node.entity_description)
+        named_entities = extract_named_entities(node.entity_description)
+        sentiment = extract_sentiment_vader(node.entity_description)

         graph_nodes.append((
             node_id,
@@ -13,9 +13,14 @@ class BinaryData(IngestionData):
         self.data = data

     def get_identifier(self):
+        metadata = self.get_metadata()
+
+        return metadata["mime_type"] + "_" + "|".join(metadata["keywords"])
+
+    def get_metadata(self):
         self.ensure_metadata()

-        return self.metadata["mime_type"] + "_" + "|".join(self.metadata["keywords"])
+        return self.metadata

     def ensure_metadata(self):
         if self.metadata is None:
@@ -8,3 +8,6 @@ class IngestionData(Protocol):
     def get_identifier(self):
         raise NotImplementedError()
+
+    def get_metadata(self):
+        raise NotImplementedError()
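This Protocol is what lets `add_files` above call `classified_data.get_metadata()` without caring whether it holds a `BinaryData` or a `TextData`. A sketch of a consumer under that assumption (the real body of `ingestion.identify` is not shown in this diff):

def identify(data: IngestionData) -> str:
    # Accepts any object implementing the protocol: BinaryData, TextData, ...
    return data.get_identifier()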
@@ -7,14 +7,24 @@ def create_text_data(data: str):

 class TextData(IngestionData):
     data: str = None
+    metadata = None

     def __init__(self, data: BinaryIO):
         self.data = data

     def get_identifier(self):
-        keywords = extract_keywords(self.data)
+        keywords = self.get_metadata()["keywords"]

         return "text/plain" + "_" + "|".join(keywords)

+    def get_metadata(self):
+        self.ensure_metadata()
+
+        return self.metadata
+
+    def ensure_metadata(self):
+        if self.metadata is None:
+            self.metadata = dict(keywords = extract_keywords(self.data))
+
     def get_data(self):
         return self.data
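`BinaryData` and `TextData` now share the same lazy-metadata shape: `get_metadata` delegates to `ensure_metadata`, which computes and caches on first access. Reduced to its core as a standalone illustration (not the actual cognee classes):

class LazyMetadata:
    metadata = None

    def get_metadata(self):
        self.ensure_metadata()  # compute on first access only
        return self.metadata

    def ensure_metadata(self):
        if self.metadata is None:
            self.metadata = dict(keywords=["example"])  # expensive work goes here

data = LazyMetadata()
print(data.get_metadata()["keywords"])  # computed once, cached afterwards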
@@ -7,13 +7,11 @@ import numpy as np
 import pandas as pd
 import matplotlib.pyplot as plt
 import tiktoken
-from nltk.sentiment import SentimentIntensityAnalyzer
 import nltk
-from nltk.tokenize import word_tokenize
-from nltk.tag import pos_tag
-from nltk.chunk import ne_chunk
 from cognee.config import Config

+config = Config()
+config.load()

 def get_document_names(doc_input):
     """
@@ -93,8 +91,6 @@ def trim_text_to_max_tokens(text: str, max_tokens: int, encoding_name: str) -> str:
     return trimmed_text
-
-

 def format_dict(d):
     """Format a dictionary as a string."""
     # Initialize an empty list to store formatted items
@@ -117,9 +113,6 @@ def format_dict(d):
     return formatted_string

-config = Config()
-config.load()
-

 def generate_color_palette(unique_layers):
     colormap = plt.cm.get_cmap("viridis", len(unique_layers))
     colors = [colormap(i) for i in range(len(unique_layers))]
@@ -140,8 +133,8 @@ def prepare_nodes(graph, include_size=False):
     nodes_data = []
     for node in graph.nodes:
         node_info = graph.nodes[node]
-        description = node_info.get('layer_description', {}).get('layer', 'Default Layer') if isinstance(
-            node_info.get('layer_description'), dict) else node_info.get('layer_description', 'Default Layer')
+        description = node_info.get("layer_description", {}).get("layer", "Default Layer") if isinstance(
+            node_info.get("layer_description"), dict) else node_info.get("layer_description", "Default Layer")
         # description = node_info['layer_description']['layer'] if isinstance(node_info.get('layer_description'), dict) and 'layer' in node_info['layer_description'] else node_info.get('layer_description', node)
         # if isinstance(node_info.get('layer_description'), dict) and 'layer' in node_info.get('layer_description'):
         #     description = node_info['layer_description']['layer']
@@ -161,8 +154,6 @@ def prepare_nodes(graph, include_size=False):
     return pd.DataFrame(nodes_data)
-
-

 async def render_graph(graph, include_nodes=False, include_color=False, include_size=False, include_labels=False):
     await register_graphistry()
     edges = prepare_edges(graph)
@@ -174,7 +165,7 @@ async def render_graph(graph, include_nodes=False, include_color=False, include_size=False, include_labels=False):

     if include_size:
-        plotter = plotter.bind(point_size='size')
+        plotter = plotter.bind(point_size="size")

     if include_color:
@@ -185,7 +176,7 @@ async def render_graph(graph, include_nodes=False, include_color=False, include_size=False, include_labels=False):

     if include_labels:
-        plotter = plotter.bind(point_label='layer_description')
+        plotter = plotter.bind(point_label = "layer_description")
@@ -199,14 +190,23 @@ def sanitize_df(df):
     return df.replace([np.inf, -np.inf, np.nan], None)


-# # Ensure that the necessary NLTK resources are downloaded
-# nltk.download('maxent_ne_chunker')
-# nltk.download('words')
+def get_entities(tagged_tokens):
+    nltk.download("maxent_ne_chunker")
+    from nltk.chunk import ne_chunk
+    return ne_chunk(tagged_tokens)


-async def extract_pos_tags(sentence):
+def extract_pos_tags(sentence):
     """Extract Part-of-Speech (POS) tags for words in a sentence."""

+    # Ensure that the necessary NLTK resources are downloaded
+    nltk.download("words")
+    nltk.download("punkt")
+    nltk.download("averaged_perceptron_tagger")
+
+    from nltk.tag import pos_tag
+    from nltk.tokenize import word_tokenize
+
     # Tokenize the sentence into words
     tokens = word_tokenize(sentence)
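With the `async` wrapper gone, the helper is a plain call; the first invocation triggers the downloads, later ones find the files locally. Hypothetical usage, assuming the helper stays importable from `cognee.utils` as the extract_keywords change above suggests:

from cognee.utils import extract_pos_tags

tags = extract_pos_tags("Cognee builds knowledge graphs.")
print(tags)  # e.g. [("Cognee", "NNP"), ("builds", "VBZ"), ...]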
@@ -216,20 +216,18 @@
     return pos_tags


-async def extract_named_entities(sentence):
+def extract_named_entities(sentence):
     """Extract Named Entities from a sentence."""
     # Tokenize the sentence into words
-    tokens = word_tokenize(sentence)
-
-    # Perform POS tagging on the tokenized sentence
-    tagged = pos_tag(tokens)
+    tagged_tokens = extract_pos_tags(sentence)

     # Perform Named Entity Recognition (NER) on the tagged tokens
-    entities = ne_chunk(tagged)
+    entities = get_entities(tagged_tokens)

     return entities


-async def extract_sentiment_vader(text):
+def extract_sentiment_vader(text):
     """
     Analyzes the sentiment of a given text using the VADER Sentiment Intensity Analyzer.
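The NER path is now extract_pos_tags → get_entities, with `ne_chunk` imported and its `maxent_ne_chunker` model downloaded only when entities are requested. A hypothetical call, again assuming the helper is exposed via `cognee.utils`; `ne_chunk` returns an `nltk.Tree` whose subtrees carry labels such as PERSON or GPE:

from cognee.utils import extract_named_entities

entities = extract_named_entities("Ernie Grunwald lives in Los Angeles.")
print(entities)  # Tree('S', [Tree('PERSON', [('Ernie', 'NNP'), ...]), ...])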
@@ -239,6 +237,7 @@ async def extract_sentiment_vader(text):
     Returns:
         dict: A dictionary containing the polarity scores for the text.
     """
+    from nltk.sentiment import SentimentIntensityAnalyzer

     nltk.download("vader_lexicon")
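Same treatment for sentiment: the `SentimentIntensityAnalyzer` import and the `vader_lexicon` download both happen inside the function. Hypothetical usage under the same `cognee.utils` assumption; VADER returns neg/neu/pos/compound polarity scores as a dict:

from cognee.utils import extract_sentiment_vader

scores = extract_sentiment_vader("cognee makes ingestion painless")
print(scores["compound"])  # > 0 indicates positive sentiment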
@@ -151,10 +151,10 @@
     "graph_client = await get_graph_client(GraphDBType.NETWORKX)\n",
     "graph = graph_client.graph\n",
     "\n",
-    "results = await search_similarity(\"At My Window was released by which American singer-songwriter?\", graph)\n",
+    "results = await search_similarity(\"Who is Ernie Grunwald?\", graph)\n",
     "\n",
     "for result in results:\n",
-    "    print(\"At My Window\" in result)\n",
+    "    print(\"Ernie Grunwald\" in result)\n",
     "    print(result)"
    ]
   }