Fix search; add improvements
This commit is contained in:
parent
8e850c19f8
commit
e1a9a236a5
12 changed files with 132 additions and 11 deletions
|
|
@ -14,6 +14,7 @@ from cognee.modules.cognify.graph.add_cognitive_layer_graphs import add_cognitiv
|
|||
from cognee.modules.cognify.graph.add_summary_nodes import add_summary_nodes
|
||||
from cognee.modules.cognify.graph.add_node_connections import group_nodes_by_layer, \
|
||||
graph_ready_output, connect_nodes_in_graph
|
||||
from cognee.modules.cognify.graph.initialize_graph import initialize_graph
|
||||
from cognee.modules.cognify.llm.resolve_cross_graph_references import resolve_cross_graph_references
|
||||
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
||||
from cognee.modules.cognify.graph.add_label_nodes import add_label_nodes
|
||||
|
|
@ -73,7 +74,14 @@ async def cognify(datasets: Union[str, List[str]] = None):
|
|||
if dataset_name in added_dataset:
|
||||
dataset_files.append((added_dataset, db_engine.get_files_metadata(added_dataset)))
|
||||
|
||||
# await initialize_graph(USER_ID, graph_data_model, graph_client)
|
||||
|
||||
|
||||
|
||||
graph_topology = infrastructure_config.get_config()["graph_topology"]
|
||||
|
||||
|
||||
|
||||
await initialize_graph(USER_ID, graph_client=graph_client)
|
||||
|
||||
data_chunks = {}
|
||||
|
||||
|
|
@ -174,11 +182,11 @@ if __name__ == "__main__":
|
|||
|
||||
async def test():
|
||||
#
|
||||
# from cognee.api.v1.add import add
|
||||
#
|
||||
# await add(["A large language model (LLM) is a language model notable for its ability to achieve general-purpose language generation and other natural language processing tasks such as classification"], "code")
|
||||
#
|
||||
# graph = await cognify()
|
||||
from cognee.api.v1.add import add
|
||||
|
||||
await add(["A large language model (LLM) is a language model notable for its ability to achieve general-purpose language generation and other natural language processing tasks such as classification"], "code")
|
||||
|
||||
graph = await cognify()
|
||||
|
||||
from cognee.utils import render_graph
|
||||
|
||||
|
|
|
|||
|
|
@ -72,3 +72,9 @@ class config():
|
|||
infrastructure_config.set_config({
|
||||
"chunk_strategy": chunk_strategy
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def set_graph_topology(graph_topology: object):
|
||||
infrastructure_config.set_config({
|
||||
"graph_topology": graph_topology
|
||||
})
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from dataclasses import dataclass, field
|
|||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
from cognee.root_dir import get_absolute_path
|
||||
from cognee.shared.data_models import ChunkStrategy
|
||||
from cognee.shared.data_models import ChunkStrategy, DefaultGraphModel
|
||||
|
||||
base_dir = Path(__file__).resolve().parent.parent
|
||||
# Load the .env file from the base directory
|
||||
|
|
@ -74,6 +74,7 @@ class Config:
|
|||
|
||||
# Database parameters
|
||||
graph_database_provider: str = os.getenv("GRAPH_DB_PROVIDER", "NETWORKX")
|
||||
graph_topology:str = DefaultGraphModel
|
||||
|
||||
if (
|
||||
os.getenv("ENV") == "prod"
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ class InfrastructureConfig():
|
|||
database_file_path: str = None
|
||||
chunk_strategy = config.chunk_strategy
|
||||
chunk_engine = None
|
||||
graph_topology = config.graph_topology
|
||||
|
||||
def get_config(self, config_entity: str = None) -> dict:
|
||||
if (config_entity is None or config_entity == "database_engine") and self.database_engine is None:
|
||||
|
|
@ -78,6 +79,9 @@ class InfrastructureConfig():
|
|||
if self.chunk_engine is None:
|
||||
self.chunk_engine = DefaultChunkEngine()
|
||||
|
||||
if self.graph_topology is None:
|
||||
self.graph_topology = config.graph_topology
|
||||
|
||||
if (config_entity is None or config_entity == "llm_engine") and self.llm_engine is None:
|
||||
self.llm_engine = OpenAIAdapter(config.openai_key, config.openai_model)
|
||||
|
||||
|
|
@ -129,6 +133,7 @@ class InfrastructureConfig():
|
|||
"database_path": self.database_file_path,
|
||||
"chunk_strategy": self.chunk_strategy,
|
||||
"chunk_engine": self.chunk_engine,
|
||||
"graph_topology": self.graph_topology
|
||||
}
|
||||
|
||||
def set_config(self, new_config: dict):
|
||||
|
|
@ -183,4 +188,7 @@ class InfrastructureConfig():
|
|||
if "chunk_engine" in new_config:
|
||||
self.chunk_engine = new_config["chunk_engine"]
|
||||
|
||||
if "graph_topology" in new_config:
|
||||
self.graph_topology = new_config["graph_topology"]
|
||||
|
||||
infrastructure_config = InfrastructureConfig()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,50 @@
|
|||
from __future__ import annotations
|
||||
import re
|
||||
|
||||
from cognee.infrastructure.data.chunking.DefaultChunkEngine import DefaultChunkEngine
|
||||
from cognee.shared.data_models import ChunkStrategy
|
||||
|
||||
|
||||
|
||||
class LangchainChunkEngine():
    """Chunking engine that delegates code-aware splitting to langchain.

    For ``ChunkStrategy.CODE`` the text is split with langchain's
    language-aware ``RecursiveCharacterTextSplitter``; any other (or
    missing) strategy falls back to ``DefaultChunkEngine`` paragraph
    chunking.
    """

    @staticmethod
    def chunk_data(
        chunk_strategy = None,
        source_data = None,
        chunk_size = None,
        chunk_overlap = None,
    ):
        """
        Chunk data based on the specified strategy.

        Parameters:
        - chunk_strategy: The strategy to use for chunking.
        - source_data: The data to be chunked.
        - chunk_size: The size of each chunk.
        - chunk_overlap: The overlap between chunks.

        Returns:
        - The chunked data.
        """
        if chunk_strategy == ChunkStrategy.CODE:
            return LangchainChunkEngine.chunk_data_by_code(source_data, chunk_size, chunk_overlap)
        # Every non-CODE strategy is handled by the default paragraph chunker.
        return DefaultChunkEngine.chunk_data_by_paragraph(source_data, chunk_size, chunk_overlap)

    @staticmethod
    def chunk_data_by_code(data_chunks, chunk_size, chunk_overlap, language=None):
        """
        Split source-code text into chunks with langchain's language-aware splitter.

        Parameters:
        - data_chunks: The source-code text to split (a single string).
        - chunk_size: The maximum size of each chunk.
        - chunk_overlap: The overlap between consecutive chunks.
        - language: A langchain ``Language`` member; defaults to ``Language.PYTHON``.

        Returns:
        - A list of langchain document chunks produced by ``create_documents``.
        """
        # Imported lazily so langchain is only required when CODE chunking
        # is actually requested.
        from langchain_text_splitters import (
            Language,
            RecursiveCharacterTextSplitter,
        )

        if language is None:
            language = Language.PYTHON

        # Renamed from `python_splitter`: the splitter handles whichever
        # language was requested, not only Python.
        splitter = RecursiveCharacterTextSplitter.from_language(
            language = language, chunk_size = chunk_size, chunk_overlap = chunk_overlap
        )

        return splitter.create_documents([data_chunks])
|
||||
|
||||
|
|
@ -16,7 +16,14 @@ def get_file_metadata(file: BinaryIO) -> FileMetadata:
|
|||
|
||||
file.seek(0)
|
||||
file_text = extract_text_from_file(file, file_type)
|
||||
keywords = extract_keywords(file_text)
|
||||
|
||||
import uuid
|
||||
|
||||
try:
|
||||
keywords = extract_keywords(file_text)
|
||||
except:
|
||||
keywords = ["no keywords detected" + str(uuid.uuid4())]
|
||||
|
||||
|
||||
file_path = file.name
|
||||
file_name = file_path.split("/")[-1].split(".")[0]
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ async def add_node(client, parent_id: Optional[str], node_id: str, node_data: di
|
|||
|
||||
# Add an edge if a parent ID is provided and the graph engine is NETWORKX
|
||||
if parent_id and "default_relationship" in node_data and infrastructure_config.get_config()["graph_engine"] == GraphDBType.NETWORKX:
|
||||
print("Node id", node_id)
|
||||
await client.add_edge(parent_id, node_id, relationship_name = node_data["default_relationship"]["type"], edge_properties = node_data)
|
||||
except Exception as e:
|
||||
# Log the exception; consider a logging framework for production use
|
||||
|
|
@ -103,6 +104,7 @@ async def add_node(client, parent_id: Optional[str], node_id: str, node_data: di
|
|||
|
||||
|
||||
async def add_edge(client, parent_id: Optional[str], node_id: str, node_data: dict, created_node_ids):
|
||||
print('NODE ID', node_data)
|
||||
|
||||
if node_id == "Relationship_default" and parent_id:
|
||||
# Initialize source and target variables outside the loop
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@ from datetime import datetime
|
|||
from cognee.shared.data_models import DefaultGraphModel, Relationship, UserProperties, UserLocation
|
||||
from cognee.modules.cognify.graph.create import create_semantic_graph
|
||||
|
||||
async def initialize_graph(root_id: str, graphdatamodel, graph_client):
|
||||
async def initialize_graph(root_id: str, graphdatamodel=None, graph_client=None):
|
||||
if graphdatamodel:
|
||||
graph = graphdatamodel(id = root_id)
|
||||
graph = graphdatamodel(node_id= root_id)
|
||||
graph_ = await create_semantic_graph(graph, graph_client)
|
||||
return graph_
|
||||
else:
|
||||
|
|
|
|||
36
cognee/shared/GithubTopology.py
Normal file
36
cognee/shared/GithubTopology.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
|
||||
class Relationship(BaseModel):
    """A typed edge between repository-graph nodes."""
    # Relationship kind (free-form string, e.g. "contains").
    type: str
    # Arbitrary edge metadata. NOTE(review): relies on pydantic copying
    # field defaults per instance so the shared {} literal is safe —
    # TODO confirm against the pinned pydantic version.
    attributes: Optional[Dict[str, Any]] = {}
|
||||
|
||||
class Document(BaseModel):
    """A single file stored somewhere in a repository directory tree."""
    # File name; no path information is carried in this model.
    name: str
    # Full text content of the file.
    content: str
    # File-type label; presumably an extension like "py" — verify against callers.
    filetype: str
|
||||
|
||||
class Directory(BaseModel):
    """A directory holding documents and nested sub-directories."""
    name: str
    # Files directly inside this directory.
    documents: List[Document] = []
    # Child directories; the 'Directory' forward reference enables
    # arbitrarily deep nesting.
    directories: List['Directory'] = []

# Allows recursive Directory Model: resolve the 'Directory' forward
# reference so the self-referential field validates.
Directory.update_forward_refs()
|
||||
|
||||
class RepositoryProperties(BaseModel):
    """Optional metadata attached to a repository node."""
    # Free-form extra attributes supplied by the caller.
    custom_properties: Optional[Dict[str, Any]] = None
    location: Optional[str] = None # Simplified location reference
|
||||
|
||||
class RepositoryNode(BaseModel):
    """A node in the repository graph: a document or a directory."""
    node_id: str
    node_type: str # 'document' or 'directory'
    # NOTE(review): shared default instance — assumed safe because pydantic
    # copies field defaults per instance; confirm for the pinned version.
    properties: RepositoryProperties = RepositoryProperties()
    # Payload matching node_type; None is allowed for placeholder nodes.
    content: Union[Document, Directory, None] = None
    # Outgoing edges from this node.
    relationships: List[Relationship] = []
|
||||
|
||||
class RepositoryGraphModel(BaseModel):
    """Top-level repository graph: a root node plus default relationships."""
    root: RepositoryNode
    # Relationships applied graph-wide unless a node overrides them —
    # presumably; verify against the graph-construction code.
    default_relationships: List[Relationship] = []
|
||||
|
|
@ -34,6 +34,7 @@ class ChunkStrategy(Enum):
|
|||
EXACT = "exact"
|
||||
PARAGRAPH = "paragraph"
|
||||
SENTENCE = "sentence"
|
||||
CODE = "code"
|
||||
|
||||
class MemorySummary(BaseModel):
|
||||
""" Memory summary. """
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
""" This module contains utility functions for the cognee. """
|
||||
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
import datetime
|
||||
|
|
@ -20,6 +20,8 @@ config.load()
|
|||
|
||||
def send_telemetry(event_name: str):
|
||||
if os.getenv("TELEMETRY_DISABLED"):
|
||||
print("Telemetry is disabled.")
|
||||
logging.info("Telemetry is disabled.")
|
||||
return
|
||||
|
||||
env = os.getenv("ENV")
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue