rewrote configs
This commit is contained in:
parent
a23fc40f6e
commit
c9b2a06dff
8 changed files with 22 additions and 18 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
|
@ -10,6 +10,10 @@ __pycache__/
|
||||||
*.py[cod]
|
*.py[cod]
|
||||||
*$py.class
|
*$py.class
|
||||||
|
|
||||||
|
notebooks/
|
||||||
|
full_run.ipynb
|
||||||
|
evals/
|
||||||
|
|
||||||
# C extensions
|
# C extensions
|
||||||
*.so
|
*.so
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -73,9 +73,9 @@ async def get_dataset_graph(dataset_id: str):
|
||||||
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
||||||
|
|
||||||
try:
|
try:
|
||||||
graph_config = get_graph_config()
|
# graph_config = get_graph_config()
|
||||||
graph_engine = graph_config.graph_engine
|
# graph_engine = graph_config.graph_engine
|
||||||
graph_client = await get_graph_client(graph_engine)
|
graph_client = await get_graph_client()
|
||||||
graph_url = await render_graph(graph_client.graph)
|
graph_url = await render_graph(graph_client.graph)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
|
|
|
||||||
|
|
@ -49,9 +49,9 @@ async def cognify(datasets: Union[str, List[str]] = None):
|
||||||
stopwords.ensure_loaded()
|
stopwords.ensure_loaded()
|
||||||
create_task_status_table()
|
create_task_status_table()
|
||||||
|
|
||||||
graph_config = get_graph_config()
|
# graph_config = get_graph_config()
|
||||||
graph_db_type = graph_config.graph_engine
|
# graph_db_type = graph_config.graph_engine
|
||||||
graph_client = await get_graph_client(graph_db_type)
|
graph_client = await get_graph_client()
|
||||||
|
|
||||||
relational_config = get_relationaldb_config()
|
relational_config = get_relationaldb_config()
|
||||||
db_engine = relational_config.database_engine
|
db_engine = relational_config.database_engine
|
||||||
|
|
@ -180,7 +180,7 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
|
||||||
print(f"Processing chunk ({chunk_id}) from document ({file_metadata['id']}).")
|
print(f"Processing chunk ({chunk_id}) from document ({file_metadata['id']}).")
|
||||||
|
|
||||||
graph_config = get_graph_config()
|
graph_config = get_graph_config()
|
||||||
graph_client = await get_graph_client(graph_config.graph_engine)
|
graph_client = await get_graph_client()
|
||||||
graph_topology = graph_config.graph_model
|
graph_topology = graph_config.graph_model
|
||||||
|
|
||||||
if graph_topology == SourceCodeGraph:
|
if graph_topology == SourceCodeGraph:
|
||||||
|
|
|
||||||
|
|
@ -42,9 +42,9 @@ USER_ID = "default_user"
|
||||||
|
|
||||||
async def add_topology(directory: str = "example", model: BaseModel = GitHubRepositoryModel) -> Any:
|
async def add_topology(directory: str = "example", model: BaseModel = GitHubRepositoryModel) -> Any:
|
||||||
graph_config = get_graph_config()
|
graph_config = get_graph_config()
|
||||||
graph_db_type = graph_config.graph_database_provider
|
# graph_db_type = graph_config.graph_database_provider
|
||||||
|
|
||||||
graph_client = await get_graph_client(graph_db_type)
|
graph_client = await get_graph_client()
|
||||||
|
|
||||||
engine = TopologyEngine()
|
engine = TopologyEngine()
|
||||||
topology = await engine.infer_from_directory_structure(node_id=USER_ID, repository=directory, model=model)
|
topology = await engine.infer_from_directory_structure(node_id=USER_ID, repository=directory, model=model)
|
||||||
|
|
|
||||||
|
|
@ -6,11 +6,11 @@ from .graph_db_interface import GraphDBInterface
|
||||||
from .networkx.adapter import NetworkXAdapter
|
from .networkx.adapter import NetworkXAdapter
|
||||||
|
|
||||||
|
|
||||||
async def get_graph_client(graph_type: GraphDBType, graph_file_name: str = None) -> GraphDBInterface :
|
async def get_graph_client(graph_type: GraphDBType=None, graph_file_name: str = None) -> GraphDBInterface :
|
||||||
"""Factory function to get the appropriate graph client based on the graph type."""
|
"""Factory function to get the appropriate graph client based on the graph type."""
|
||||||
config = get_graph_config()
|
config = get_graph_config()
|
||||||
|
|
||||||
if graph_type == GraphDBType.NEO4J:
|
if config.graph_engine == GraphDBType.NEO4J:
|
||||||
try:
|
try:
|
||||||
from .neo4j_driver.adapter import Neo4jAdapter
|
from .neo4j_driver.adapter import Neo4jAdapter
|
||||||
|
|
||||||
|
|
@ -22,7 +22,7 @@ async def get_graph_client(graph_type: GraphDBType, graph_file_name: str = None)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
elif graph_type == GraphDBType.FALKORDB:
|
elif config.graph_engine == GraphDBType.FALKORDB:
|
||||||
try:
|
try:
|
||||||
from .falkordb.adapter import FalcorDBAdapter
|
from .falkordb.adapter import FalcorDBAdapter
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -92,7 +92,7 @@ def graph_ready_output(results):
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
graph_client = await get_graph_client(GraphDBType.NEO4J)
|
graph_client = await get_graph_client()
|
||||||
graph = graph_client.graph
|
graph = graph_client.graph
|
||||||
|
|
||||||
# for nodes, attr in graph.nodes(data=True):
|
# for nodes, attr in graph.nodes(data=True):
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from cognee.infrastructure.databases.graph.get_graph_client import get_graph_cli
|
||||||
async def prune_system(graph = True, vector = True):
|
async def prune_system(graph = True, vector = True):
|
||||||
if graph:
|
if graph:
|
||||||
graph_config = get_graph_config()
|
graph_config = get_graph_config()
|
||||||
graph_client = await get_graph_client(graph_config.graph_engine)
|
graph_client = await get_graph_client()
|
||||||
await graph_client.delete_graph()
|
await graph_client.delete_graph()
|
||||||
|
|
||||||
if vector:
|
if vector:
|
||||||
|
|
|
||||||
|
|
@ -3,11 +3,11 @@ from cognee.infrastructure.databases.graph.config import get_graph_config
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
|
||||||
async def search_similarity(query: str, graph):
|
async def search_similarity(query: str, graph):
|
||||||
graph_config = get_graph_config()
|
# graph_config = get_graph_config()
|
||||||
|
#
|
||||||
|
# graph_db_type = graph_config.graph_engine
|
||||||
|
|
||||||
graph_db_type = graph_config.graph_engine
|
graph_client = await get_graph_client()
|
||||||
|
|
||||||
graph_client = await get_graph_client(graph_db_type)
|
|
||||||
|
|
||||||
layer_nodes = await graph_client.get_layer_nodes()
|
layer_nodes = await graph_client.get_layer_nodes()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue