diff --git a/.gitignore b/.gitignore index 9cb75f3e5..468fc3e80 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,10 @@ __pycache__/ *.py[cod] *$py.class +notebooks/ +full_run.ipynb +evals/ + # C extensions *.so diff --git a/cognee/api/client.py b/cognee/api/client.py index b3bde178e..6e2aa86e1 100644 --- a/cognee/api/client.py +++ b/cognee/api/client.py @@ -73,9 +73,9 @@ async def get_dataset_graph(dataset_id: str): from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client try: - graph_config = get_graph_config() - graph_engine = graph_config.graph_engine - graph_client = await get_graph_client(graph_engine) + # graph_config = get_graph_config() + # graph_engine = graph_config.graph_engine + graph_client = await get_graph_client() graph_url = await render_graph(graph_client.graph) return JSONResponse( diff --git a/cognee/api/v1/cognify/cognify.py b/cognee/api/v1/cognify/cognify.py index 8d11e9ad9..19ca263f1 100644 --- a/cognee/api/v1/cognify/cognify.py +++ b/cognee/api/v1/cognify/cognify.py @@ -49,9 +49,9 @@ async def cognify(datasets: Union[str, List[str]] = None): stopwords.ensure_loaded() create_task_status_table() - graph_config = get_graph_config() - graph_db_type = graph_config.graph_engine - graph_client = await get_graph_client(graph_db_type) + # graph_config = get_graph_config() + # graph_db_type = graph_config.graph_engine + graph_client = await get_graph_client() relational_config = get_relationaldb_config() db_engine = relational_config.database_engine @@ -180,7 +180,7 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi print(f"Processing chunk ({chunk_id}) from document ({file_metadata['id']}).") graph_config = get_graph_config() - graph_client = await get_graph_client(graph_config.graph_engine) + graph_client = await get_graph_client() graph_topology = graph_config.graph_model if graph_topology == SourceCodeGraph: diff --git a/cognee/api/v1/topology/add_topology.py b/cognee/api/v1/topology/add_topology.py index f2da00135..58d99f21d 100644 --- a/cognee/api/v1/topology/add_topology.py +++ b/cognee/api/v1/topology/add_topology.py @@ -42,9 +42,9 @@ USER_ID = "default_user" async def add_topology(directory: str = "example", model: BaseModel = GitHubRepositoryModel) -> Any: graph_config = get_graph_config() - graph_db_type = graph_config.graph_database_provider + # graph_db_type = graph_config.graph_database_provider - graph_client = await get_graph_client(graph_db_type) + graph_client = await get_graph_client() engine = TopologyEngine() topology = await engine.infer_from_directory_structure(node_id=USER_ID, repository=directory, model=model) diff --git a/cognee/infrastructure/databases/graph/get_graph_client.py b/cognee/infrastructure/databases/graph/get_graph_client.py index 51980903a..9ea6bb9bd 100644 --- a/cognee/infrastructure/databases/graph/get_graph_client.py +++ b/cognee/infrastructure/databases/graph/get_graph_client.py @@ -6,11 +6,11 @@ from .graph_db_interface import GraphDBInterface from .networkx.adapter import NetworkXAdapter -async def get_graph_client(graph_type: GraphDBType, graph_file_name: str = None) -> GraphDBInterface : +async def get_graph_client(graph_type: GraphDBType=None, graph_file_name: str = None) -> GraphDBInterface : """Factory function to get the appropriate graph client based on the graph type.""" config = get_graph_config() - if graph_type == GraphDBType.NEO4J: + if config.graph_engine == GraphDBType.NEO4J: try: from .neo4j_driver.adapter import Neo4jAdapter @@ -22,7 +22,7 @@ async def get_graph_client(graph_type: GraphDBType, graph_file_name: str = None) except: pass - elif graph_type == GraphDBType.FALKORDB: + elif config.graph_engine == GraphDBType.FALKORDB: try: from .falkordb.adapter import FalcorDBAdapter diff --git a/cognee/modules/cognify/graph/add_node_connections.py b/cognee/modules/cognify/graph/add_node_connections.py index 758d538dc..96f2dd662 100644 --- a/cognee/modules/cognify/graph/add_node_connections.py +++ b/cognee/modules/cognify/graph/add_node_connections.py @@ -92,7 +92,7 @@ def graph_ready_output(results): if __name__ == "__main__": async def main(): - graph_client = await get_graph_client(GraphDBType.NEO4J) + graph_client = await get_graph_client() graph = graph_client.graph # for nodes, attr in graph.nodes(data=True): diff --git a/cognee/modules/data/deletion/prune_system.py b/cognee/modules/data/deletion/prune_system.py index 3fbce112c..a913edeec 100644 --- a/cognee/modules/data/deletion/prune_system.py +++ b/cognee/modules/data/deletion/prune_system.py @@ -5,7 +5,7 @@ from cognee.infrastructure.databases.graph.get_graph_client import get_graph_cli async def prune_system(graph = True, vector = True): if graph: graph_config = get_graph_config() - graph_client = await get_graph_client(graph_config.graph_engine) + graph_client = await get_graph_client() await graph_client.delete_graph() if vector: diff --git a/cognee/modules/search/vector/search_similarity.py b/cognee/modules/search/vector/search_similarity.py index 17c0e81e4..de0d86e71 100644 --- a/cognee/modules/search/vector/search_similarity.py +++ b/cognee/modules/search/vector/search_similarity.py @@ -3,11 +3,11 @@ from cognee.infrastructure.databases.graph.config import get_graph_config from cognee.infrastructure.databases.vector import get_vector_engine async def search_similarity(query: str, graph): - graph_config = get_graph_config() + # graph_config = get_graph_config() + # + # graph_db_type = graph_config.graph_engine - graph_db_type = graph_config.graph_engine - - graph_client = await get_graph_client(graph_db_type) + graph_client = await get_graph_client() layer_nodes = await graph_client.get_layer_nodes()