Fix search; add assorted improvements

This commit is contained in:
Vasilije 2024-05-09 23:11:33 +02:00
parent 8e850c19f8
commit e1a9a236a5
12 changed files with 132 additions and 11 deletions

View file

@ -14,6 +14,7 @@ from cognee.modules.cognify.graph.add_cognitive_layer_graphs import add_cognitiv
from cognee.modules.cognify.graph.add_summary_nodes import add_summary_nodes from cognee.modules.cognify.graph.add_summary_nodes import add_summary_nodes
from cognee.modules.cognify.graph.add_node_connections import group_nodes_by_layer, \ from cognee.modules.cognify.graph.add_node_connections import group_nodes_by_layer, \
graph_ready_output, connect_nodes_in_graph graph_ready_output, connect_nodes_in_graph
from cognee.modules.cognify.graph.initialize_graph import initialize_graph
from cognee.modules.cognify.llm.resolve_cross_graph_references import resolve_cross_graph_references from cognee.modules.cognify.llm.resolve_cross_graph_references import resolve_cross_graph_references
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
from cognee.modules.cognify.graph.add_label_nodes import add_label_nodes from cognee.modules.cognify.graph.add_label_nodes import add_label_nodes
@ -73,7 +74,14 @@ async def cognify(datasets: Union[str, List[str]] = None):
if dataset_name in added_dataset: if dataset_name in added_dataset:
dataset_files.append((added_dataset, db_engine.get_files_metadata(added_dataset))) dataset_files.append((added_dataset, db_engine.get_files_metadata(added_dataset)))
# await initialize_graph(USER_ID, graph_data_model, graph_client)
graph_topology = infrastructure_config.get_config()["graph_topology"]
await initialize_graph(USER_ID, graph_client=graph_client)
data_chunks = {} data_chunks = {}
@ -174,11 +182,11 @@ if __name__ == "__main__":
async def test(): async def test():
# #
# from cognee.api.v1.add import add from cognee.api.v1.add import add
#
# await add(["A large language model (LLM) is a language model notable for its ability to achieve general-purpose language generation and other natural language processing tasks such as classification"], "code") await add(["A large language model (LLM) is a language model notable for its ability to achieve general-purpose language generation and other natural language processing tasks such as classification"], "code")
#
# graph = await cognify() graph = await cognify()
from cognee.utils import render_graph from cognee.utils import render_graph

View file

@ -72,3 +72,9 @@ class config():
infrastructure_config.set_config({ infrastructure_config.set_config({
"chunk_strategy": chunk_strategy "chunk_strategy": chunk_strategy
}) })
@staticmethod
def set_graph_topology(graph_topology: object):
infrastructure_config.set_config({
"graph_topology": graph_topology
})

View file

@ -8,7 +8,7 @@ from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from dotenv import load_dotenv from dotenv import load_dotenv
from cognee.root_dir import get_absolute_path from cognee.root_dir import get_absolute_path
from cognee.shared.data_models import ChunkStrategy from cognee.shared.data_models import ChunkStrategy, DefaultGraphModel
base_dir = Path(__file__).resolve().parent.parent base_dir = Path(__file__).resolve().parent.parent
# Load the .env file from the base directory # Load the .env file from the base directory
@ -74,6 +74,7 @@ class Config:
# Database parameters # Database parameters
graph_database_provider: str = os.getenv("GRAPH_DB_PROVIDER", "NETWORKX") graph_database_provider: str = os.getenv("GRAPH_DB_PROVIDER", "NETWORKX")
graph_topology:str = DefaultGraphModel
if ( if (
os.getenv("ENV") == "prod" os.getenv("ENV") == "prod"

View file

@ -33,6 +33,7 @@ class InfrastructureConfig():
database_file_path: str = None database_file_path: str = None
chunk_strategy = config.chunk_strategy chunk_strategy = config.chunk_strategy
chunk_engine = None chunk_engine = None
graph_topology = config.graph_topology
def get_config(self, config_entity: str = None) -> dict: def get_config(self, config_entity: str = None) -> dict:
if (config_entity is None or config_entity == "database_engine") and self.database_engine is None: if (config_entity is None or config_entity == "database_engine") and self.database_engine is None:
@ -78,6 +79,9 @@ class InfrastructureConfig():
if self.chunk_engine is None: if self.chunk_engine is None:
self.chunk_engine = DefaultChunkEngine() self.chunk_engine = DefaultChunkEngine()
if self.graph_topology is None:
self.graph_topology = config.graph_topology
if (config_entity is None or config_entity == "llm_engine") and self.llm_engine is None: if (config_entity is None or config_entity == "llm_engine") and self.llm_engine is None:
self.llm_engine = OpenAIAdapter(config.openai_key, config.openai_model) self.llm_engine = OpenAIAdapter(config.openai_key, config.openai_model)
@ -129,6 +133,7 @@ class InfrastructureConfig():
"database_path": self.database_file_path, "database_path": self.database_file_path,
"chunk_strategy": self.chunk_strategy, "chunk_strategy": self.chunk_strategy,
"chunk_engine": self.chunk_engine, "chunk_engine": self.chunk_engine,
"graph_topology": self.graph_topology
} }
def set_config(self, new_config: dict): def set_config(self, new_config: dict):
@ -183,4 +188,7 @@ class InfrastructureConfig():
if "chunk_engine" in new_config: if "chunk_engine" in new_config:
self.chunk_engine = new_config["chunk_engine"] self.chunk_engine = new_config["chunk_engine"]
if "graph_topology" in new_config:
self.graph_topology = new_config["graph_topology"]
infrastructure_config = InfrastructureConfig() infrastructure_config = InfrastructureConfig()

View file

@ -0,0 +1,50 @@
from __future__ import annotations
import re
from cognee.infrastructure.data.chunking.DefaultChunkEngine import DefaultChunkEngine
from cognee.shared.data_models import ChunkStrategy
class LangchainChunkEngine():
    """Chunking engine that adds code-aware splitting on top of the default engine.

    For ``ChunkStrategy.CODE`` the text is split with langchain's
    language-aware splitter; every other strategy falls back to the
    default paragraph-based chunking.
    """

    @staticmethod
    def chunk_data(
        chunk_strategy = None,
        source_data = None,
        chunk_size = None,
        chunk_overlap = None,
    ):
        """
        Chunk data based on the specified strategy.
        Parameters:
        - chunk_strategy: The strategy to use for chunking.
        - source_data: The data to be chunked.
        - chunk_size: The size of each chunk.
        - chunk_overlap: The overlap between chunks.
        Returns:
        - The chunked data.
        """
        # Only the CODE strategy is handled here; everything else is
        # delegated to the default paragraph chunker.
        if chunk_strategy == ChunkStrategy.CODE:
            return LangchainChunkEngine.chunk_data_by_code(source_data, chunk_size, chunk_overlap)
        return DefaultChunkEngine.chunk_data_by_paragraph(source_data, chunk_size, chunk_overlap)

    @staticmethod
    def chunk_data_by_code(data_chunks, chunk_size, chunk_overlap, language=None):
        """Split source code into chunks using langchain's language-aware splitter.

        Parameters:
        - data_chunks: the source text to split.
        - chunk_size: maximum size of each chunk.
        - chunk_overlap: overlap between consecutive chunks.
        - language: a ``langchain_text_splitters.Language`` member; defaults to Python.
        Returns:
        - A list of langchain Document chunks.
        """
        # Imported lazily so langchain is only required when code chunking is used.
        from langchain_text_splitters import (
            Language,
            RecursiveCharacterTextSplitter,
        )

        selected_language = Language.PYTHON if language is None else language
        splitter = RecursiveCharacterTextSplitter.from_language(
            language=selected_language, chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        return splitter.create_documents([data_chunks])

View file

@ -16,7 +16,14 @@ def get_file_metadata(file: BinaryIO) -> FileMetadata:
file.seek(0) file.seek(0)
file_text = extract_text_from_file(file, file_type) file_text = extract_text_from_file(file, file_type)
keywords = extract_keywords(file_text)
import uuid
try:
keywords = extract_keywords(file_text)
except:
keywords = ["no keywords detected" + str(uuid.uuid4())]
file_path = file.name file_path = file.name
file_name = file_path.split("/")[-1].split(".")[0] file_name = file_path.split("/")[-1].split(".")[0]

View file

@ -44,6 +44,7 @@ async def add_node(client, parent_id: Optional[str], node_id: str, node_data: di
# Add an edge if a parent ID is provided and the graph engine is NETWORKX # Add an edge if a parent ID is provided and the graph engine is NETWORKX
if parent_id and "default_relationship" in node_data and infrastructure_config.get_config()["graph_engine"] == GraphDBType.NETWORKX: if parent_id and "default_relationship" in node_data and infrastructure_config.get_config()["graph_engine"] == GraphDBType.NETWORKX:
print("Node id", node_id)
await client.add_edge(parent_id, node_id, relationship_name = node_data["default_relationship"]["type"], edge_properties = node_data) await client.add_edge(parent_id, node_id, relationship_name = node_data["default_relationship"]["type"], edge_properties = node_data)
except Exception as e: except Exception as e:
# Log the exception; consider a logging framework for production use # Log the exception; consider a logging framework for production use
@ -103,6 +104,7 @@ async def add_node(client, parent_id: Optional[str], node_id: str, node_data: di
async def add_edge(client, parent_id: Optional[str], node_id: str, node_data: dict, created_node_ids): async def add_edge(client, parent_id: Optional[str], node_id: str, node_data: dict, created_node_ids):
print('NODE ID', node_data)
if node_id == "Relationship_default" and parent_id: if node_id == "Relationship_default" and parent_id:
# Initialize source and target variables outside the loop # Initialize source and target variables outside the loop

View file

@ -2,9 +2,9 @@ from datetime import datetime
from cognee.shared.data_models import DefaultGraphModel, Relationship, UserProperties, UserLocation from cognee.shared.data_models import DefaultGraphModel, Relationship, UserProperties, UserLocation
from cognee.modules.cognify.graph.create import create_semantic_graph from cognee.modules.cognify.graph.create import create_semantic_graph
async def initialize_graph(root_id: str, graphdatamodel, graph_client): async def initialize_graph(root_id: str, graphdatamodel=None, graph_client=None):
if graphdatamodel: if graphdatamodel:
graph = graphdatamodel(id = root_id) graph = graphdatamodel(node_id= root_id)
graph_ = await create_semantic_graph(graph, graph_client) graph_ = await create_semantic_graph(graph, graph_client)
return graph_ return graph_
else: else:

View file

@ -0,0 +1,36 @@
from pydantic import BaseModel
from typing import List, Optional, Dict, Any, Union
class Relationship(BaseModel):
    """A typed edge between repository graph nodes."""
    # Relationship label (free-form string).
    type: str
    # Arbitrary edge metadata. NOTE: pydantic deep-copies mutable field
    # defaults per instance, so the shared-{} pitfall does not apply here.
    attributes: Optional[Dict[str, Any]] = {}
class Document(BaseModel):
    """A single file tracked in the repository graph."""
    # File name (no path component implied here — presumably just the
    # basename; verify against callers).
    name: str
    # Full textual content of the file.
    content: str
    # File type/extension tag, e.g. used to pick a chunking strategy.
    filetype: str
class Directory(BaseModel):
    """A directory node holding documents and nested sub-directories."""
    # Directory name.
    name: str
    # Files directly contained in this directory.
    documents: List[Document] = []
    # Child directories; the forward reference makes the model recursive.
    directories: List['Directory'] = []

# Allows recursive Directory Model: resolves the 'Directory' forward
# reference used in the `directories` field above.
Directory.update_forward_refs()
class RepositoryProperties(BaseModel):
    """Optional metadata attached to a repository node."""
    # Free-form key/value properties supplied by the caller.
    custom_properties: Optional[Dict[str, Any]] = None
    # Simplified location reference (plain string rather than a structured type).
    location: Optional[str] = None
class RepositoryNode(BaseModel):
    """A node in the repository graph wrapping either a document or a directory."""
    # Unique identifier for this node.
    node_id: str
    # Discriminator: 'document' or 'directory'.
    node_type: str
    # Node metadata; pydantic deep-copies this default per instance.
    properties: RepositoryProperties = RepositoryProperties()
    # Payload matching node_type; None when the node carries no content.
    content: Union[Document, Directory, None] = None
    # Outgoing edges from this node.
    relationships: List[Relationship] = []
class RepositoryGraphModel(BaseModel):
    """Top-level repository graph: a root node plus default relationships."""
    # Entry point of the graph (typically the repository root directory).
    root: RepositoryNode
    # Relationships applied by default when wiring nodes into the graph.
    default_relationships: List[Relationship] = []

View file

@ -34,6 +34,7 @@ class ChunkStrategy(Enum):
EXACT = "exact" EXACT = "exact"
PARAGRAPH = "paragraph" PARAGRAPH = "paragraph"
SENTENCE = "sentence" SENTENCE = "sentence"
CODE = "code"
class MemorySummary(BaseModel): class MemorySummary(BaseModel):
""" Memory summary. """ """ Memory summary. """

View file

@ -1,5 +1,5 @@
""" This module contains utility functions for the cognee. """ """ This module contains utility functions for the cognee. """
import logging
import os import os
import uuid import uuid
import datetime import datetime
@ -20,6 +20,8 @@ config.load()
def send_telemetry(event_name: str): def send_telemetry(event_name: str):
if os.getenv("TELEMETRY_DISABLED"): if os.getenv("TELEMETRY_DISABLED"):
print("Telemetry is disabled.")
logging.info("Telemetry is disabled.")
return return
env = os.getenv("ENV") env = os.getenv("ENV")