diff --git a/cognee-frontend/src/ui/elements/Notebook/Notebook.tsx b/cognee-frontend/src/ui/elements/Notebook/Notebook.tsx index 553c254c9..31c716b96 100644 --- a/cognee-frontend/src/ui/elements/Notebook/Notebook.tsx +++ b/cognee-frontend/src/ui/elements/Notebook/Notebook.tsx @@ -244,67 +244,122 @@ function CellResult({ content }: { content: [] }) { for (const line of content) { try { if (Array.isArray(line)) { + // Insights search returns uncommon graph data structure + if (Array.from(line).length > 0 && Array.isArray(line[0]) && line[0][1]["relationship_name"]) { + parsedContent.push( +
++ query response (dataset: {item["dataset_name"]}) + {item["search_result"]} ++ ); + } else if (typeof(item) === "object" && item["search_result"] && typeof(item["search_result"]) === "object") { + parsedContent.push( ++ {JSON.stringify(item, null, 2)} ++ ) + } else if (typeof(item) === "string") { + parsedContent.push( +{item}); - } - if (typeof item === "object" && item["search_result"]) { + } else if (typeof(item) === "object" && !(item["search_result"] || item["graphs"])) { parsedContent.push( -- query response (dataset: {item["dataset_name"]}) - {item["search_result"]} -- ); ++ {JSON.stringify(item, null, 2)} ++ ) } - if (typeof item === "object" && item["graph"] && typeof item["graph"] === "object") { - parsedContent.push( -- reasoning graph -- ); + + if (typeof item === "object" && item["graphs"] && typeof item["graphs"] === "object") { + Object.entries<{ nodes: []; edges: []; }>(item["graphs"]).forEach(([datasetName, graph]) => { + parsedContent.push( +} - graphControls={graphControls} - className="min-h-48" - /> - + reasoning graph (datasets: {datasetName}) ++ ); + }); } } } - if (typeof(line) === "object" && line["result"]) { + + if (typeof(line) === "object" && line["result"] && typeof(line["result"]) === "string") { const datasets = Array.from( // eslint-disable-next-line @typescript-eslint/no-explicit-any new Set(Object.values(line["datasets"]).map((dataset: any) => dataset.name)) ).join(", "); parsedContent.push( -} + graphControls={graphControls} + className="min-h-80" + /> + +query response (datasets: {datasets}) - {line["result"]} + {line["result"]}); - if (line["graphs"]) { + } + if (typeof(line) === "object" && line["graphs"]) { + Object.entries<{ nodes: []; edges: []; }>(line["graphs"]).forEach(([datasetName, graph]) => { parsedContent.push( -- reasoning graph ++ reasoning graph (datasets: {datasetName})); - } + }); + } + + if (typeof(line) === "object" && line["result"] && typeof(line["result"]) === "object") { + parsedContent.push( +} graphControls={graphControls} className="min-h-80" /> + {JSON.stringify(line["result"], null, 2)} ++ ) + } + if (typeof(line) === "string") { + parsedContent.push( ++ {line} ++ ) } } catch (error) { console.error(error); - parsedContent.push(line); + parsedContent.push( ++ {line} ++ ); } } @@ -317,38 +372,61 @@ function CellResult({ content }: { content: [] }) { }; function transformToVisualizationData(graph: { nodes: [], edges: [] }) { - // Implementation to transform triplet to visualization data - return { nodes: graph.nodes, links: graph.edges, }; - - // const nodes = {}; - // const links = {}; - - // for (const triplet of triplets) { - // nodes[triplet.source.id] = { - // id: triplet.source.id, - // label: triplet.source.attributes.name, - // type: triplet.source.attributes.type, - // attributes: triplet.source.attributes, - // }; - // nodes[triplet.destination.id] = { - // id: triplet.destination.id, - // label: triplet.destination.attributes.name, - // type: triplet.destination.attributes.type, - // attributes: triplet.destination.attributes, - // }; - // links[`${triplet.source.id}_${triplet.attributes.relationship_name}_${triplet.destination.id}`] = { - // source: triplet.source.id, - // target: triplet.destination.id, - // label: triplet.attributes.relationship_name, - // } - // } - - // return { - // nodes: Object.values(nodes), - // links: Object.values(links), - // }; +} + +type Triplet = [{ + id: string, + name: string, + type: string, +}, { + relationship_name: string, +}, { + id: string, + name: string, + type: string, +}] + +function transformInsightsGraphData(triplets: Triplet[]) { + const nodes: { + [key: string]: { + id: string, + label: string, + type: string, + } + } = {}; + const links: { + [key: string]: { + source: string, + target: string, + label: string, + } + } = {}; + + for (const triplet of triplets) { + nodes[triplet[0].id] = { + id: triplet[0].id, + label: triplet[0].name || triplet[0].id, + type: triplet[0].type, + }; + nodes[triplet[2].id] = { + id: triplet[2].id, + label: triplet[2].name || triplet[2].id, + type: triplet[2].type, + }; + const linkKey = `${triplet[0]["id"]}_${triplet[1]["relationship_name"]}_${triplet[2]["id"]}`; + links[linkKey] = { + source: triplet[0].id, + target: triplet[2].id, + label: triplet[1]["relationship_name"], + }; + } + + return { + nodes: Object.values(nodes), + links: Object.values(links), + }; } diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index 0c236d896..65efafb4c 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -136,12 +136,19 @@ async def search( if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true": return_value = [] for search_result in search_results: - result, context, datasets = search_result + prepared_search_results = await prepare_search_result(search_result) + + result = prepared_search_results["result"] + graphs = prepared_search_results["graphs"] + context = prepared_search_results["context"] + datasets = prepared_search_results["datasets"] + return_value.append( { - "search_result": result, + "search_result": [result] if result else None, "dataset_id": datasets[0].id, "dataset_name": datasets[0].name, + "graphs": graphs, } ) return return_value @@ -155,14 +162,6 @@ async def search( return return_value[0] else: return return_value - # return [ - # SearchResult( - # search_result=result, - # dataset_id=datasets[min(index, len(datasets) - 1)].id if datasets else None, - # dataset_name=datasets[min(index, len(datasets) - 1)].name if datasets else None, - # ) - # for index, (result, _, datasets) in enumerate(search_results) - # ] async def authorized_search( @@ -208,11 +207,11 @@ async def authorized_search( context = {} datasets: List[Dataset] = [] - for _, search_context, datasets in search_responses: - for dataset in datasets: + for _, search_context, search_datasets in search_responses: + for dataset in search_datasets: context[str(dataset.id)] = search_context - datasets.extend(datasets) + datasets.extend(search_datasets) specific_search_tools = await get_search_type_tools( query_type=query_type, diff --git a/cognee/modules/search/utils/prepare_search_result.py b/cognee/modules/search/utils/prepare_search_result.py index 19cbe07ac..b854a318d 100644 --- a/cognee/modules/search/utils/prepare_search_result.py +++ b/cognee/modules/search/utils/prepare_search_result.py @@ -1,8 +1,11 @@ from typing import List, cast +from uuid import uuid5, NAMESPACE_OID from cognee.modules.graph.utils import resolve_edges_to_text from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.modules.search.types.SearchResult import SearchResultDataset from cognee.modules.search.utils.transform_context_to_graph import transform_context_to_graph +from cognee.modules.search.utils.transform_insights_to_graph import transform_insights_to_graph async def prepare_search_result(search_result): @@ -12,29 +15,48 @@ async def prepare_search_result(search_result): result_graph = None context_texts = {} - if isinstance(context, List) and len(context) > 0 and isinstance(context[0], Edge): + if isinstance(datasets, list) and len(datasets) == 0: + datasets = [ + SearchResultDataset( + id=uuid5(NAMESPACE_OID, "*"), + name="all available datasets", + ) + ] + + if ( + isinstance(context, List) + and len(context) > 0 + and isinstance(context[0], tuple) + and context[0][1].get("relationship_name") + ): + context_graph = transform_insights_to_graph(context) + graphs = { + ", ".join([dataset.name for dataset in datasets]): context_graph, + } + results = None + elif isinstance(context, List) and len(context) > 0 and isinstance(context[0], Edge): context_graph = transform_context_to_graph(context) graphs = { - "*": context_graph, + ", ".join([dataset.name for dataset in datasets]): context_graph, } context_texts = { - "*": await resolve_edges_to_text(context), + ", ".join([dataset.name for dataset in datasets]): await resolve_edges_to_text(context), } elif isinstance(context, str): context_texts = { - "*": context, + ", ".join([dataset.name for dataset in datasets]): context, } elif isinstance(context, List) and len(context) > 0 and isinstance(context[0], str): context_texts = { - "*": "\n".join(cast(List[str], context)), + ", ".join([dataset.name for dataset in datasets]): "\n".join(cast(List[str], context)), } if isinstance(results, List) and len(results) > 0 and isinstance(results[0], Edge): result_graph = transform_context_to_graph(results) return { - "result": result_graph or results[0] if len(results) == 1 else results, + "result": result_graph or results[0] if results and len(results) == 1 else results, "graphs": graphs, "context": context_texts, "datasets": datasets, diff --git a/cognee/modules/search/utils/transform_context_to_graph.py b/cognee/modules/search/utils/transform_context_to_graph.py index 0bc889575..4fc722dc6 100644 --- a/cognee/modules/search/utils/transform_context_to_graph.py +++ b/cognee/modules/search/utils/transform_context_to_graph.py @@ -14,7 +14,7 @@ def transform_context_to_graph(context: List[Edge]): if "name" in triplet.node1.attributes else triplet.node1.id, "type": triplet.node1.attributes["type"], - "attributes": triplet.node2.attributes, + "attributes": triplet.node1.attributes, } nodes[triplet.node2.id] = { "id": triplet.node2.id, diff --git a/cognee/modules/search/utils/transform_insights_to_graph.py b/cognee/modules/search/utils/transform_insights_to_graph.py new file mode 100644 index 000000000..e01a444cd --- /dev/null +++ b/cognee/modules/search/utils/transform_insights_to_graph.py @@ -0,0 +1,28 @@ +from typing import Dict, List, Tuple + + +def transform_insights_to_graph(context: List[Tuple[Dict, Dict, Dict]]): + nodes = {} + edges = {} + + for triplet in context: + nodes[triplet[0]["id"]] = { + "id": triplet[0]["id"], + "label": triplet[0]["name"] if "name" in triplet[0] else triplet[0]["id"], + "type": triplet[0]["type"], + } + nodes[triplet[2]["id"]] = { + "id": triplet[2]["id"], + "label": triplet[2]["name"] if "name" in triplet[2] else triplet[2]["id"], + "type": triplet[2]["type"], + } + edges[f"{triplet[0]['id']}_{triplet[1]['relationship_name']}_{triplet[2]['id']}"] = { + "source": triplet[0]["id"], + "target": triplet[2]["id"], + "label": triplet[1]["relationship_name"], + } + + return { + "nodes": list(nodes.values()), + "edges": list(edges.values()), + }