fix search, add improvements
This commit is contained in:
parent
8e850c19f8
commit
e1a9a236a5
12 changed files with 132 additions and 11 deletions
|
|
@ -14,6 +14,7 @@ from cognee.modules.cognify.graph.add_cognitive_layer_graphs import add_cognitiv
|
||||||
from cognee.modules.cognify.graph.add_summary_nodes import add_summary_nodes
|
from cognee.modules.cognify.graph.add_summary_nodes import add_summary_nodes
|
||||||
from cognee.modules.cognify.graph.add_node_connections import group_nodes_by_layer, \
|
from cognee.modules.cognify.graph.add_node_connections import group_nodes_by_layer, \
|
||||||
graph_ready_output, connect_nodes_in_graph
|
graph_ready_output, connect_nodes_in_graph
|
||||||
|
from cognee.modules.cognify.graph.initialize_graph import initialize_graph
|
||||||
from cognee.modules.cognify.llm.resolve_cross_graph_references import resolve_cross_graph_references
|
from cognee.modules.cognify.llm.resolve_cross_graph_references import resolve_cross_graph_references
|
||||||
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
||||||
from cognee.modules.cognify.graph.add_label_nodes import add_label_nodes
|
from cognee.modules.cognify.graph.add_label_nodes import add_label_nodes
|
||||||
|
|
@ -73,7 +74,14 @@ async def cognify(datasets: Union[str, List[str]] = None):
|
||||||
if dataset_name in added_dataset:
|
if dataset_name in added_dataset:
|
||||||
dataset_files.append((added_dataset, db_engine.get_files_metadata(added_dataset)))
|
dataset_files.append((added_dataset, db_engine.get_files_metadata(added_dataset)))
|
||||||
|
|
||||||
# await initialize_graph(USER_ID, graph_data_model, graph_client)
|
|
||||||
|
|
||||||
|
|
||||||
|
graph_topology = infrastructure_config.get_config()["graph_topology"]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
await initialize_graph(USER_ID, graph_client=graph_client)
|
||||||
|
|
||||||
data_chunks = {}
|
data_chunks = {}
|
||||||
|
|
||||||
|
|
@ -174,11 +182,11 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
async def test():
|
async def test():
|
||||||
#
|
#
|
||||||
# from cognee.api.v1.add import add
|
from cognee.api.v1.add import add
|
||||||
#
|
|
||||||
# await add(["A large language model (LLM) is a language model notable for its ability to achieve general-purpose language generation and other natural language processing tasks such as classification"], "code")
|
await add(["A large language model (LLM) is a language model notable for its ability to achieve general-purpose language generation and other natural language processing tasks such as classification"], "code")
|
||||||
#
|
|
||||||
# graph = await cognify()
|
graph = await cognify()
|
||||||
|
|
||||||
from cognee.utils import render_graph
|
from cognee.utils import render_graph
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -72,3 +72,9 @@ class config():
|
||||||
infrastructure_config.set_config({
|
infrastructure_config.set_config({
|
||||||
"chunk_strategy": chunk_strategy
|
"chunk_strategy": chunk_strategy
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_graph_topology(graph_topology: object):
|
||||||
|
infrastructure_config.set_config({
|
||||||
|
"graph_topology": graph_topology
|
||||||
|
})
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from cognee.root_dir import get_absolute_path
|
from cognee.root_dir import get_absolute_path
|
||||||
from cognee.shared.data_models import ChunkStrategy
|
from cognee.shared.data_models import ChunkStrategy, DefaultGraphModel
|
||||||
|
|
||||||
base_dir = Path(__file__).resolve().parent.parent
|
base_dir = Path(__file__).resolve().parent.parent
|
||||||
# Load the .env file from the base directory
|
# Load the .env file from the base directory
|
||||||
|
|
@ -74,6 +74,7 @@ class Config:
|
||||||
|
|
||||||
# Database parameters
|
# Database parameters
|
||||||
graph_database_provider: str = os.getenv("GRAPH_DB_PROVIDER", "NETWORKX")
|
graph_database_provider: str = os.getenv("GRAPH_DB_PROVIDER", "NETWORKX")
|
||||||
|
graph_topology:str = DefaultGraphModel
|
||||||
|
|
||||||
if (
|
if (
|
||||||
os.getenv("ENV") == "prod"
|
os.getenv("ENV") == "prod"
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,7 @@ class InfrastructureConfig():
|
||||||
database_file_path: str = None
|
database_file_path: str = None
|
||||||
chunk_strategy = config.chunk_strategy
|
chunk_strategy = config.chunk_strategy
|
||||||
chunk_engine = None
|
chunk_engine = None
|
||||||
|
graph_topology = config.graph_topology
|
||||||
|
|
||||||
def get_config(self, config_entity: str = None) -> dict:
|
def get_config(self, config_entity: str = None) -> dict:
|
||||||
if (config_entity is None or config_entity == "database_engine") and self.database_engine is None:
|
if (config_entity is None or config_entity == "database_engine") and self.database_engine is None:
|
||||||
|
|
@ -78,6 +79,9 @@ class InfrastructureConfig():
|
||||||
if self.chunk_engine is None:
|
if self.chunk_engine is None:
|
||||||
self.chunk_engine = DefaultChunkEngine()
|
self.chunk_engine = DefaultChunkEngine()
|
||||||
|
|
||||||
|
if self.graph_topology is None:
|
||||||
|
self.graph_topology = config.graph_topology
|
||||||
|
|
||||||
if (config_entity is None or config_entity == "llm_engine") and self.llm_engine is None:
|
if (config_entity is None or config_entity == "llm_engine") and self.llm_engine is None:
|
||||||
self.llm_engine = OpenAIAdapter(config.openai_key, config.openai_model)
|
self.llm_engine = OpenAIAdapter(config.openai_key, config.openai_model)
|
||||||
|
|
||||||
|
|
@ -129,6 +133,7 @@ class InfrastructureConfig():
|
||||||
"database_path": self.database_file_path,
|
"database_path": self.database_file_path,
|
||||||
"chunk_strategy": self.chunk_strategy,
|
"chunk_strategy": self.chunk_strategy,
|
||||||
"chunk_engine": self.chunk_engine,
|
"chunk_engine": self.chunk_engine,
|
||||||
|
"graph_topology": self.graph_topology
|
||||||
}
|
}
|
||||||
|
|
||||||
def set_config(self, new_config: dict):
|
def set_config(self, new_config: dict):
|
||||||
|
|
@ -183,4 +188,7 @@ class InfrastructureConfig():
|
||||||
if "chunk_engine" in new_config:
|
if "chunk_engine" in new_config:
|
||||||
self.chunk_engine = new_config["chunk_engine"]
|
self.chunk_engine = new_config["chunk_engine"]
|
||||||
|
|
||||||
|
if "graph_topology" in new_config:
|
||||||
|
self.graph_topology = new_config["graph_topology"]
|
||||||
|
|
||||||
infrastructure_config = InfrastructureConfig()
|
infrastructure_config = InfrastructureConfig()
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,50 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
import re
|
||||||
|
|
||||||
|
from cognee.infrastructure.data.chunking.DefaultChunkEngine import DefaultChunkEngine
|
||||||
|
from cognee.shared.data_models import ChunkStrategy
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class LangchainChunkEngine():
|
||||||
|
@staticmethod
|
||||||
|
def chunk_data(
|
||||||
|
chunk_strategy = None,
|
||||||
|
source_data = None,
|
||||||
|
chunk_size = None,
|
||||||
|
chunk_overlap = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Chunk data based on the specified strategy.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- chunk_strategy: The strategy to use for chunking.
|
||||||
|
- source_data: The data to be chunked.
|
||||||
|
- chunk_size: The size of each chunk.
|
||||||
|
- chunk_overlap: The overlap between chunks.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- The chunked data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if chunk_strategy == ChunkStrategy.CODE:
|
||||||
|
chunked_data = LangchainChunkEngine.chunk_data_by_code(source_data,chunk_size, chunk_overlap)
|
||||||
|
else:
|
||||||
|
chunked_data = DefaultChunkEngine.chunk_data_by_paragraph(source_data,chunk_size, chunk_overlap)
|
||||||
|
return chunked_data
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def chunk_data_by_code(data_chunks, chunk_size, chunk_overlap, language=None):
|
||||||
|
from langchain_text_splitters import (
|
||||||
|
Language,
|
||||||
|
RecursiveCharacterTextSplitter,
|
||||||
|
)
|
||||||
|
if language is None:
|
||||||
|
language = Language.PYTHON
|
||||||
|
python_splitter = RecursiveCharacterTextSplitter.from_language(
|
||||||
|
language=language, chunk_size=chunk_size, chunk_overlap=chunk_overlap
|
||||||
|
)
|
||||||
|
code_chunks = python_splitter.create_documents([data_chunks])
|
||||||
|
|
||||||
|
return code_chunks
|
||||||
|
|
||||||
|
|
@ -16,7 +16,14 @@ def get_file_metadata(file: BinaryIO) -> FileMetadata:
|
||||||
|
|
||||||
file.seek(0)
|
file.seek(0)
|
||||||
file_text = extract_text_from_file(file, file_type)
|
file_text = extract_text_from_file(file, file_type)
|
||||||
keywords = extract_keywords(file_text)
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
try:
|
||||||
|
keywords = extract_keywords(file_text)
|
||||||
|
except:
|
||||||
|
keywords = ["no keywords detected" + str(uuid.uuid4())]
|
||||||
|
|
||||||
|
|
||||||
file_path = file.name
|
file_path = file.name
|
||||||
file_name = file_path.split("/")[-1].split(".")[0]
|
file_name = file_path.split("/")[-1].split(".")[0]
|
||||||
|
|
|
||||||
|
|
@ -44,6 +44,7 @@ async def add_node(client, parent_id: Optional[str], node_id: str, node_data: di
|
||||||
|
|
||||||
# Add an edge if a parent ID is provided and the graph engine is NETWORKX
|
# Add an edge if a parent ID is provided and the graph engine is NETWORKX
|
||||||
if parent_id and "default_relationship" in node_data and infrastructure_config.get_config()["graph_engine"] == GraphDBType.NETWORKX:
|
if parent_id and "default_relationship" in node_data and infrastructure_config.get_config()["graph_engine"] == GraphDBType.NETWORKX:
|
||||||
|
print("Node id", node_id)
|
||||||
await client.add_edge(parent_id, node_id, relationship_name = node_data["default_relationship"]["type"], edge_properties = node_data)
|
await client.add_edge(parent_id, node_id, relationship_name = node_data["default_relationship"]["type"], edge_properties = node_data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Log the exception; consider a logging framework for production use
|
# Log the exception; consider a logging framework for production use
|
||||||
|
|
@ -103,6 +104,7 @@ async def add_node(client, parent_id: Optional[str], node_id: str, node_data: di
|
||||||
|
|
||||||
|
|
||||||
async def add_edge(client, parent_id: Optional[str], node_id: str, node_data: dict, created_node_ids):
|
async def add_edge(client, parent_id: Optional[str], node_id: str, node_data: dict, created_node_ids):
|
||||||
|
print('NODE ID', node_data)
|
||||||
|
|
||||||
if node_id == "Relationship_default" and parent_id:
|
if node_id == "Relationship_default" and parent_id:
|
||||||
# Initialize source and target variables outside the loop
|
# Initialize source and target variables outside the loop
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,9 @@ from datetime import datetime
|
||||||
from cognee.shared.data_models import DefaultGraphModel, Relationship, UserProperties, UserLocation
|
from cognee.shared.data_models import DefaultGraphModel, Relationship, UserProperties, UserLocation
|
||||||
from cognee.modules.cognify.graph.create import create_semantic_graph
|
from cognee.modules.cognify.graph.create import create_semantic_graph
|
||||||
|
|
||||||
async def initialize_graph(root_id: str, graphdatamodel, graph_client):
|
async def initialize_graph(root_id: str, graphdatamodel=None, graph_client=None):
|
||||||
if graphdatamodel:
|
if graphdatamodel:
|
||||||
graph = graphdatamodel(id = root_id)
|
graph = graphdatamodel(node_id= root_id)
|
||||||
graph_ = await create_semantic_graph(graph, graph_client)
|
graph_ = await create_semantic_graph(graph, graph_client)
|
||||||
return graph_
|
return graph_
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
36
cognee/shared/GithubTopology.py
Normal file
36
cognee/shared/GithubTopology.py
Normal file
|
|
@ -0,0 +1,36 @@
|
||||||
|
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import List, Optional, Dict, Any, Union
|
||||||
|
|
||||||
|
class Relationship(BaseModel):
|
||||||
|
type: str
|
||||||
|
attributes: Optional[Dict[str, Any]] = {}
|
||||||
|
|
||||||
|
class Document(BaseModel):
|
||||||
|
name: str
|
||||||
|
content: str
|
||||||
|
filetype: str
|
||||||
|
|
||||||
|
class Directory(BaseModel):
|
||||||
|
name: str
|
||||||
|
documents: List[Document] = []
|
||||||
|
directories: List['Directory'] = []
|
||||||
|
|
||||||
|
# Allows recursive Directory Model
|
||||||
|
Directory.update_forward_refs()
|
||||||
|
|
||||||
|
class RepositoryProperties(BaseModel):
|
||||||
|
custom_properties: Optional[Dict[str, Any]] = None
|
||||||
|
location: Optional[str] = None # Simplified location reference
|
||||||
|
|
||||||
|
class RepositoryNode(BaseModel):
|
||||||
|
node_id: str
|
||||||
|
node_type: str # 'document' or 'directory'
|
||||||
|
properties: RepositoryProperties = RepositoryProperties()
|
||||||
|
content: Union[Document, Directory, None] = None
|
||||||
|
relationships: List[Relationship] = []
|
||||||
|
|
||||||
|
class RepositoryGraphModel(BaseModel):
|
||||||
|
root: RepositoryNode
|
||||||
|
default_relationships: List[Relationship] = []
|
||||||
|
|
@ -34,6 +34,7 @@ class ChunkStrategy(Enum):
|
||||||
EXACT = "exact"
|
EXACT = "exact"
|
||||||
PARAGRAPH = "paragraph"
|
PARAGRAPH = "paragraph"
|
||||||
SENTENCE = "sentence"
|
SENTENCE = "sentence"
|
||||||
|
CODE = "code"
|
||||||
|
|
||||||
class MemorySummary(BaseModel):
|
class MemorySummary(BaseModel):
|
||||||
""" Memory summary. """
|
""" Memory summary. """
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
""" This module contains utility functions for the cognee. """
|
""" This module contains utility functions for the cognee. """
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
import datetime
|
import datetime
|
||||||
|
|
@ -20,6 +20,8 @@ config.load()
|
||||||
|
|
||||||
def send_telemetry(event_name: str):
|
def send_telemetry(event_name: str):
|
||||||
if os.getenv("TELEMETRY_DISABLED"):
|
if os.getenv("TELEMETRY_DISABLED"):
|
||||||
|
print("Telemetry is disabled.")
|
||||||
|
logging.info("Telemetry is disabled.")
|
||||||
return
|
return
|
||||||
|
|
||||||
env = os.getenv("ENV")
|
env = os.getenv("ENV")
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue