rewrote configs

This commit is contained in:
Vasilije 2024-06-10 13:40:05 +02:00
parent a23fc40f6e
commit c9b2a06dff
8 changed files with 22 additions and 18 deletions

4
.gitignore vendored
View file

@ -10,6 +10,10 @@ __pycache__/
*.py[cod]
*$py.class
notebooks/
full_run.ipynb
evals/
# C extensions
*.so

View file

@ -73,9 +73,9 @@ async def get_dataset_graph(dataset_id: str):
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
try:
graph_config = get_graph_config()
graph_engine = graph_config.graph_engine
graph_client = await get_graph_client(graph_engine)
# graph_config = get_graph_config()
# graph_engine = graph_config.graph_engine
graph_client = await get_graph_client()
graph_url = await render_graph(graph_client.graph)
return JSONResponse(

View file

@ -49,9 +49,9 @@ async def cognify(datasets: Union[str, List[str]] = None):
stopwords.ensure_loaded()
create_task_status_table()
graph_config = get_graph_config()
graph_db_type = graph_config.graph_engine
graph_client = await get_graph_client(graph_db_type)
# graph_config = get_graph_config()
# graph_db_type = graph_config.graph_engine
graph_client = await get_graph_client()
relational_config = get_relationaldb_config()
db_engine = relational_config.database_engine
@ -180,7 +180,7 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
print(f"Processing chunk ({chunk_id}) from document ({file_metadata['id']}).")
graph_config = get_graph_config()
graph_client = await get_graph_client(graph_config.graph_engine)
graph_client = await get_graph_client()
graph_topology = graph_config.graph_model
if graph_topology == SourceCodeGraph:

View file

@ -42,9 +42,9 @@ USER_ID = "default_user"
async def add_topology(directory: str = "example", model: BaseModel = GitHubRepositoryModel) -> Any:
graph_config = get_graph_config()
graph_db_type = graph_config.graph_database_provider
# graph_db_type = graph_config.graph_database_provider
graph_client = await get_graph_client(graph_db_type)
graph_client = await get_graph_client()
engine = TopologyEngine()
topology = await engine.infer_from_directory_structure(node_id=USER_ID, repository=directory, model=model)

View file

@ -6,11 +6,11 @@ from .graph_db_interface import GraphDBInterface
from .networkx.adapter import NetworkXAdapter
async def get_graph_client(graph_type: GraphDBType, graph_file_name: str = None) -> GraphDBInterface :
async def get_graph_client(graph_type: GraphDBType=None, graph_file_name: str = None) -> GraphDBInterface :
"""Factory function to get the appropriate graph client based on the graph type."""
config = get_graph_config()
if graph_type == GraphDBType.NEO4J:
if config.graph_engine == GraphDBType.NEO4J:
try:
from .neo4j_driver.adapter import Neo4jAdapter
@ -22,7 +22,7 @@ async def get_graph_client(graph_type: GraphDBType, graph_file_name: str = None)
except:
pass
elif graph_type == GraphDBType.FALKORDB:
elif config.graph_engine == GraphDBType.FALKORDB:
try:
from .falkordb.adapter import FalcorDBAdapter

View file

@ -92,7 +92,7 @@ def graph_ready_output(results):
if __name__ == "__main__":
async def main():
graph_client = await get_graph_client(GraphDBType.NEO4J)
graph_client = await get_graph_client()
graph = graph_client.graph
# for nodes, attr in graph.nodes(data=True):

View file

@ -5,7 +5,7 @@ from cognee.infrastructure.databases.graph.get_graph_client import get_graph_cli
async def prune_system(graph = True, vector = True):
if graph:
graph_config = get_graph_config()
graph_client = await get_graph_client(graph_config.graph_engine)
graph_client = await get_graph_client()
await graph_client.delete_graph()
if vector:

View file

@ -3,11 +3,11 @@ from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.infrastructure.databases.vector import get_vector_engine
async def search_similarity(query: str, graph):
graph_config = get_graph_config()
# graph_config = get_graph_config()
#
# graph_db_type = graph_config.graph_engine
graph_db_type = graph_config.graph_engine
graph_client = await get_graph_client(graph_db_type)
graph_client = await get_graph_client()
layer_nodes = await graph_client.get_layer_nodes()