rewrote configs
This commit is contained in:
parent
a23fc40f6e
commit
c9b2a06dff
8 changed files with 22 additions and 18 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
|
@ -10,6 +10,10 @@ __pycache__/
|
|||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
notebooks/
|
||||
full_run.ipynb
|
||||
evals/
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
|
|
|
|||
|
|
@ -73,9 +73,9 @@ async def get_dataset_graph(dataset_id: str):
|
|||
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
||||
|
||||
try:
|
||||
graph_config = get_graph_config()
|
||||
graph_engine = graph_config.graph_engine
|
||||
graph_client = await get_graph_client(graph_engine)
|
||||
# graph_config = get_graph_config()
|
||||
# graph_engine = graph_config.graph_engine
|
||||
graph_client = await get_graph_client()
|
||||
graph_url = await render_graph(graph_client.graph)
|
||||
|
||||
return JSONResponse(
|
||||
|
|
|
|||
|
|
@ -49,9 +49,9 @@ async def cognify(datasets: Union[str, List[str]] = None):
|
|||
stopwords.ensure_loaded()
|
||||
create_task_status_table()
|
||||
|
||||
graph_config = get_graph_config()
|
||||
graph_db_type = graph_config.graph_engine
|
||||
graph_client = await get_graph_client(graph_db_type)
|
||||
# graph_config = get_graph_config()
|
||||
# graph_db_type = graph_config.graph_engine
|
||||
graph_client = await get_graph_client()
|
||||
|
||||
relational_config = get_relationaldb_config()
|
||||
db_engine = relational_config.database_engine
|
||||
|
|
@ -180,7 +180,7 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
|
|||
print(f"Processing chunk ({chunk_id}) from document ({file_metadata['id']}).")
|
||||
|
||||
graph_config = get_graph_config()
|
||||
graph_client = await get_graph_client(graph_config.graph_engine)
|
||||
graph_client = await get_graph_client()
|
||||
graph_topology = graph_config.graph_model
|
||||
|
||||
if graph_topology == SourceCodeGraph:
|
||||
|
|
|
|||
|
|
@ -42,9 +42,9 @@ USER_ID = "default_user"
|
|||
|
||||
async def add_topology(directory: str = "example", model: BaseModel = GitHubRepositoryModel) -> Any:
|
||||
graph_config = get_graph_config()
|
||||
graph_db_type = graph_config.graph_database_provider
|
||||
# graph_db_type = graph_config.graph_database_provider
|
||||
|
||||
graph_client = await get_graph_client(graph_db_type)
|
||||
graph_client = await get_graph_client()
|
||||
|
||||
engine = TopologyEngine()
|
||||
topology = await engine.infer_from_directory_structure(node_id=USER_ID, repository=directory, model=model)
|
||||
|
|
|
|||
|
|
@ -6,11 +6,11 @@ from .graph_db_interface import GraphDBInterface
|
|||
from .networkx.adapter import NetworkXAdapter
|
||||
|
||||
|
||||
async def get_graph_client(graph_type: GraphDBType, graph_file_name: str = None) -> GraphDBInterface :
|
||||
async def get_graph_client(graph_type: GraphDBType=None, graph_file_name: str = None) -> GraphDBInterface :
|
||||
"""Factory function to get the appropriate graph client based on the graph type."""
|
||||
config = get_graph_config()
|
||||
|
||||
if graph_type == GraphDBType.NEO4J:
|
||||
if config.graph_engine == GraphDBType.NEO4J:
|
||||
try:
|
||||
from .neo4j_driver.adapter import Neo4jAdapter
|
||||
|
||||
|
|
@ -22,7 +22,7 @@ async def get_graph_client(graph_type: GraphDBType, graph_file_name: str = None)
|
|||
except:
|
||||
pass
|
||||
|
||||
elif graph_type == GraphDBType.FALKORDB:
|
||||
elif config.graph_engine == GraphDBType.FALKORDB:
|
||||
try:
|
||||
from .falkordb.adapter import FalcorDBAdapter
|
||||
|
||||
|
|
|
|||
|
|
@ -92,7 +92,7 @@ def graph_ready_output(results):
|
|||
if __name__ == "__main__":
|
||||
|
||||
async def main():
|
||||
graph_client = await get_graph_client(GraphDBType.NEO4J)
|
||||
graph_client = await get_graph_client()
|
||||
graph = graph_client.graph
|
||||
|
||||
# for nodes, attr in graph.nodes(data=True):
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from cognee.infrastructure.databases.graph.get_graph_client import get_graph_cli
|
|||
async def prune_system(graph = True, vector = True):
|
||||
if graph:
|
||||
graph_config = get_graph_config()
|
||||
graph_client = await get_graph_client(graph_config.graph_engine)
|
||||
graph_client = await get_graph_client()
|
||||
await graph_client.delete_graph()
|
||||
|
||||
if vector:
|
||||
|
|
|
|||
|
|
@ -3,11 +3,11 @@ from cognee.infrastructure.databases.graph.config import get_graph_config
|
|||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
async def search_similarity(query: str, graph):
|
||||
graph_config = get_graph_config()
|
||||
# graph_config = get_graph_config()
|
||||
#
|
||||
# graph_db_type = graph_config.graph_engine
|
||||
|
||||
graph_db_type = graph_config.graph_engine
|
||||
|
||||
graph_client = await get_graph_client(graph_db_type)
|
||||
graph_client = await get_graph_client()
|
||||
|
||||
layer_nodes = await graph_client.get_layer_nodes()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue