fix: combined context search (#1420)
<!-- .github/pull_request_template.md --> ## Description <!-- Please provide a clear, human-generated description of the changes in this PR. DO NOT use AI-generated descriptions. We want to understand your thought process and reasoning. --> ## Type of Change <!-- Please check the relevant option --> - [x] Bug fix (non-breaking change that fixes an issue) - [ ] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [ ] Code refactoring - [ ] Performance improvement - [ ] Other (please specify): ## Changes Made <!-- List the specific changes made in this PR --> - - - ## Testing <!-- Describe how you tested your changes --> ## Screenshots/Videos (if applicable) <!-- Add screenshots or videos to help explain your changes --> ## Pre-submission Checklist <!-- Please check all boxes that apply before submitting your PR --> - [x] **I have tested my changes thoroughly before submitting this PR** - [x] **This PR contains minimal changes necessary to address the issue/feature** - [x] My code follows the project's coding standards and style guidelines - [x] I have added tests that prove my fix is effective or that my feature works - [x] I have added necessary documentation (if applicable) - [x] All new and existing tests pass - [x] I have searched existing PRs to ensure this change hasn't been submitted already - [x] I have linked any relevant issues in the description - [x] My commits have clear and descriptive messages ## Related Issues <!-- Link any related issues using "Fixes #issue_number" or "Relates to #issue_number" --> ## Additional Notes <!-- Add any additional notes, concerns, or context for reviewers --> ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
parent
250cf59845
commit
3e870a076e
5 changed files with 207 additions and 80 deletions
|
|
@ -244,67 +244,122 @@ function CellResult({ content }: { content: [] }) {
|
|||
for (const line of content) {
|
||||
try {
|
||||
if (Array.isArray(line)) {
|
||||
// Insights search returns uncommon graph data structure
|
||||
if (Array.from(line).length > 0 && Array.isArray(line[0]) && line[0][1]["relationship_name"]) {
|
||||
parsedContent.push(
|
||||
<div key={line[0][1]["relationship_name"]} className="w-full h-full bg-white">
|
||||
<span className="text-sm pl-2 mb-4">reasoning graph</span>
|
||||
<GraphVisualization
|
||||
data={transformInsightsGraphData(line)}
|
||||
ref={graphRef as MutableRefObject<GraphVisualizationAPI>}
|
||||
graphControls={graphControls}
|
||||
className="min-h-48"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// @ts-expect-error line can be Array or string
|
||||
for (const item of line) {
|
||||
if (typeof item === "string") {
|
||||
if (
|
||||
typeof item === "object" && item["search_result"] && (typeof(item["search_result"]) === "string"
|
||||
|| (Array.isArray(item["search_result"]) && typeof(item["search_result"][0]) === "string"))
|
||||
) {
|
||||
parsedContent.push(
|
||||
<pre key={item.slice(0, -10)}>
|
||||
<div key={String(item["search_result"])} className="w-full h-full bg-white">
|
||||
<span className="text-sm pl-2 mb-4">query response (dataset: {item["dataset_name"]})</span>
|
||||
<span className="block px-2 py-2 whitespace-normal">{item["search_result"]}</span>
|
||||
</div>
|
||||
);
|
||||
} else if (typeof(item) === "object" && item["search_result"] && typeof(item["search_result"]) === "object") {
|
||||
parsedContent.push(
|
||||
<pre className="px-2 w-full h-full bg-white text-sm" key={String(item).slice(0, -10)}>
|
||||
{JSON.stringify(item, null, 2)}
|
||||
</pre>
|
||||
)
|
||||
} else if (typeof(item) === "string") {
|
||||
parsedContent.push(
|
||||
<pre className="px-2 w-full h-full bg-white text-sm whitespace-normal" key={item.slice(0, -10)}>
|
||||
{item}
|
||||
</pre>
|
||||
);
|
||||
}
|
||||
if (typeof item === "object" && item["search_result"]) {
|
||||
} else if (typeof(item) === "object" && !(item["search_result"] || item["graphs"])) {
|
||||
parsedContent.push(
|
||||
<div className="w-full h-full bg-white">
|
||||
<span className="text-sm pl-2 mb-4">query response (dataset: {item["dataset_name"]})</span>
|
||||
<span className="block px-2 py-2">{item["search_result"]}</span>
|
||||
</div>
|
||||
);
|
||||
<pre className="px-2 w-full h-full bg-white text-sm" key={String(item).slice(0, -10)}>
|
||||
{JSON.stringify(item, null, 2)}
|
||||
</pre>
|
||||
)
|
||||
}
|
||||
if (typeof item === "object" && item["graph"] && typeof item["graph"] === "object") {
|
||||
parsedContent.push(
|
||||
<div className="w-full h-full bg-white">
|
||||
<span className="text-sm pl-2 mb-4">reasoning graph</span>
|
||||
<GraphVisualization
|
||||
data={transformToVisualizationData(item["graph"])}
|
||||
ref={graphRef as MutableRefObject<GraphVisualizationAPI>}
|
||||
graphControls={graphControls}
|
||||
className="min-h-48"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
||||
if (typeof item === "object" && item["graphs"] && typeof item["graphs"] === "object") {
|
||||
Object.entries<{ nodes: []; edges: []; }>(item["graphs"]).forEach(([datasetName, graph]) => {
|
||||
parsedContent.push(
|
||||
<div key={datasetName} className="w-full h-full bg-white">
|
||||
<span className="text-sm pl-2 mb-4">reasoning graph (datasets: {datasetName})</span>
|
||||
<GraphVisualization
|
||||
data={transformToVisualizationData(graph)}
|
||||
ref={graphRef as MutableRefObject<GraphVisualizationAPI>}
|
||||
graphControls={graphControls}
|
||||
className="min-h-80"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
if (typeof(line) === "object" && line["result"]) {
|
||||
|
||||
if (typeof(line) === "object" && line["result"] && typeof(line["result"]) === "string") {
|
||||
const datasets = Array.from(
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
new Set(Object.values(line["datasets"]).map((dataset: any) => dataset.name))
|
||||
).join(", ");
|
||||
|
||||
parsedContent.push(
|
||||
<div className="w-full h-full bg-white">
|
||||
<div key={line["result"]} className="w-full h-full bg-white">
|
||||
<span className="text-sm pl-2 mb-4">query response (datasets: {datasets})</span>
|
||||
<span className="block px-2 py-2">{line["result"]}</span>
|
||||
<span className="block px-2 py-2 whitespace-normal">{line["result"]}</span>
|
||||
</div>
|
||||
);
|
||||
if (line["graphs"]) {
|
||||
}
|
||||
if (typeof(line) === "object" && line["graphs"]) {
|
||||
Object.entries<{ nodes: []; edges: []; }>(line["graphs"]).forEach(([datasetName, graph]) => {
|
||||
parsedContent.push(
|
||||
<div className="w-full h-full bg-white">
|
||||
<span className="text-sm pl-2 mb-4">reasoning graph</span>
|
||||
<div key={datasetName} className="w-full h-full bg-white">
|
||||
<span className="text-sm pl-2 mb-4">reasoning graph (datasets: {datasetName})</span>
|
||||
<GraphVisualization
|
||||
data={transformToVisualizationData(line["graphs"]["*"])}
|
||||
data={transformToVisualizationData(graph)}
|
||||
ref={graphRef as MutableRefObject<GraphVisualizationAPI>}
|
||||
graphControls={graphControls}
|
||||
className="min-h-80"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if (typeof(line) === "object" && line["result"] && typeof(line["result"]) === "object") {
|
||||
parsedContent.push(
|
||||
<pre className="px-2 w-full h-full bg-white text-sm" key={String(line).slice(0, -10)}>
|
||||
{JSON.stringify(line["result"], null, 2)}
|
||||
</pre>
|
||||
)
|
||||
}
|
||||
if (typeof(line) === "string") {
|
||||
parsedContent.push(
|
||||
<pre className="px-2 w-full h-full bg-white text-sm whitespace-normal" key={String(line).slice(0, -10)}>
|
||||
{line}
|
||||
</pre>
|
||||
)
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
parsedContent.push(line);
|
||||
parsedContent.push(
|
||||
<pre className="px-2 w-full h-full bg-white text-sm whitespace-normal" key={String(line).slice(0, -10)}>
|
||||
{line}
|
||||
</pre>
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -317,38 +372,61 @@ function CellResult({ content }: { content: [] }) {
|
|||
};
|
||||
|
||||
/**
 * Adapt a backend graph payload to the shape passed as `data` to
 * GraphVisualization.
 *
 * The payload already carries ready-made node and edge lists, so the only
 * work is renaming `edges` to `links`; both arrays are passed through by
 * reference, not copied.
 *
 * @param graph - object with `nodes` and `edges` arrays
 * @returns `{ nodes, links }` where `links` is the input's `edges`
 */
function transformToVisualizationData(graph: { nodes: [], edges: [] }) {
  return {
    nodes: graph.nodes,
    links: graph.edges,
  };
}
|
||||
|
||||
// One insights-search row as a 3-tuple:
//   [0] node used as the link source by transformInsightsGraphData
//   [1] relationship descriptor; its relationship_name becomes the link label
//   [2] node used as the link target by transformInsightsGraphData
type Triplet = [{
  id: string,
  name: string,
  type: string,
}, {
  relationship_name: string,
}, {
  id: string,
  name: string,
  type: string,
}]
|
||||
|
||||
/**
 * Convert insights-search triplets into a `{ nodes, links }` object.
 *
 * Nodes are keyed by id and links by `<sourceId>_<relationship>_<targetId>`,
 * so repeated triplets collapse into a single entry (the last occurrence
 * wins while the first-seen position is kept).
 *
 * @param triplets - list of [source node, relationship, target node] tuples
 * @returns de-duplicated node and link arrays
 */
function transformInsightsGraphData(triplets: Triplet[]) {
  type GraphNode = { id: string, label: string, type: string };
  type GraphLink = { source: string, target: string, label: string };

  const nodesById: { [key: string]: GraphNode } = {};
  const linksByKey: { [key: string]: GraphLink } = {};

  // A node without a usable name is labelled with its id instead.
  const toNode = (endpoint: Triplet[0] | Triplet[2]): GraphNode => ({
    id: endpoint.id,
    label: endpoint.name || endpoint.id,
    type: endpoint.type,
  });

  for (const [source, relationship, target] of triplets) {
    nodesById[source.id] = toNode(source);
    nodesById[target.id] = toNode(target);

    const relationshipName = relationship["relationship_name"];
    linksByKey[`${source.id}_${relationshipName}_${target.id}`] = {
      source: source.id,
      target: target.id,
      label: relationshipName,
    };
  }

  return {
    nodes: Object.values(nodesById),
    links: Object.values(linksByKey),
  };
}
|
||||
|
|
|
|||
|
|
@ -136,12 +136,19 @@ async def search(
|
|||
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
|
||||
return_value = []
|
||||
for search_result in search_results:
|
||||
result, context, datasets = search_result
|
||||
prepared_search_results = await prepare_search_result(search_result)
|
||||
|
||||
result = prepared_search_results["result"]
|
||||
graphs = prepared_search_results["graphs"]
|
||||
context = prepared_search_results["context"]
|
||||
datasets = prepared_search_results["datasets"]
|
||||
|
||||
return_value.append(
|
||||
{
|
||||
"search_result": result,
|
||||
"search_result": [result] if result else None,
|
||||
"dataset_id": datasets[0].id,
|
||||
"dataset_name": datasets[0].name,
|
||||
"graphs": graphs,
|
||||
}
|
||||
)
|
||||
return return_value
|
||||
|
|
@ -155,14 +162,6 @@ async def search(
|
|||
return return_value[0]
|
||||
else:
|
||||
return return_value
|
||||
# return [
|
||||
# SearchResult(
|
||||
# search_result=result,
|
||||
# dataset_id=datasets[min(index, len(datasets) - 1)].id if datasets else None,
|
||||
# dataset_name=datasets[min(index, len(datasets) - 1)].name if datasets else None,
|
||||
# )
|
||||
# for index, (result, _, datasets) in enumerate(search_results)
|
||||
# ]
|
||||
|
||||
|
||||
async def authorized_search(
|
||||
|
|
@ -208,11 +207,11 @@ async def authorized_search(
|
|||
context = {}
|
||||
datasets: List[Dataset] = []
|
||||
|
||||
for _, search_context, datasets in search_responses:
|
||||
for dataset in datasets:
|
||||
for _, search_context, search_datasets in search_responses:
|
||||
for dataset in search_datasets:
|
||||
context[str(dataset.id)] = search_context
|
||||
|
||||
datasets.extend(datasets)
|
||||
datasets.extend(search_datasets)
|
||||
|
||||
specific_search_tools = await get_search_type_tools(
|
||||
query_type=query_type,
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
from typing import List, cast
|
||||
from uuid import uuid5, NAMESPACE_OID
|
||||
|
||||
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.modules.search.types.SearchResult import SearchResultDataset
|
||||
from cognee.modules.search.utils.transform_context_to_graph import transform_context_to_graph
|
||||
from cognee.modules.search.utils.transform_insights_to_graph import transform_insights_to_graph
|
||||
|
||||
|
||||
async def prepare_search_result(search_result):
|
||||
|
|
@ -12,29 +15,48 @@ async def prepare_search_result(search_result):
|
|||
result_graph = None
|
||||
context_texts = {}
|
||||
|
||||
if isinstance(context, List) and len(context) > 0 and isinstance(context[0], Edge):
|
||||
if isinstance(datasets, list) and len(datasets) == 0:
|
||||
datasets = [
|
||||
SearchResultDataset(
|
||||
id=uuid5(NAMESPACE_OID, "*"),
|
||||
name="all available datasets",
|
||||
)
|
||||
]
|
||||
|
||||
if (
|
||||
isinstance(context, List)
|
||||
and len(context) > 0
|
||||
and isinstance(context[0], tuple)
|
||||
and context[0][1].get("relationship_name")
|
||||
):
|
||||
context_graph = transform_insights_to_graph(context)
|
||||
graphs = {
|
||||
", ".join([dataset.name for dataset in datasets]): context_graph,
|
||||
}
|
||||
results = None
|
||||
elif isinstance(context, List) and len(context) > 0 and isinstance(context[0], Edge):
|
||||
context_graph = transform_context_to_graph(context)
|
||||
|
||||
graphs = {
|
||||
"*": context_graph,
|
||||
", ".join([dataset.name for dataset in datasets]): context_graph,
|
||||
}
|
||||
context_texts = {
|
||||
"*": await resolve_edges_to_text(context),
|
||||
", ".join([dataset.name for dataset in datasets]): await resolve_edges_to_text(context),
|
||||
}
|
||||
elif isinstance(context, str):
|
||||
context_texts = {
|
||||
"*": context,
|
||||
", ".join([dataset.name for dataset in datasets]): context,
|
||||
}
|
||||
elif isinstance(context, List) and len(context) > 0 and isinstance(context[0], str):
|
||||
context_texts = {
|
||||
"*": "\n".join(cast(List[str], context)),
|
||||
", ".join([dataset.name for dataset in datasets]): "\n".join(cast(List[str], context)),
|
||||
}
|
||||
|
||||
if isinstance(results, List) and len(results) > 0 and isinstance(results[0], Edge):
|
||||
result_graph = transform_context_to_graph(results)
|
||||
|
||||
return {
|
||||
"result": result_graph or results[0] if len(results) == 1 else results,
|
||||
"result": result_graph or results[0] if results and len(results) == 1 else results,
|
||||
"graphs": graphs,
|
||||
"context": context_texts,
|
||||
"datasets": datasets,
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ def transform_context_to_graph(context: List[Edge]):
|
|||
if "name" in triplet.node1.attributes
|
||||
else triplet.node1.id,
|
||||
"type": triplet.node1.attributes["type"],
|
||||
"attributes": triplet.node2.attributes,
|
||||
"attributes": triplet.node1.attributes,
|
||||
}
|
||||
nodes[triplet.node2.id] = {
|
||||
"id": triplet.node2.id,
|
||||
|
|
|
|||
28
cognee/modules/search/utils/transform_insights_to_graph.py
Normal file
28
cognee/modules/search/utils/transform_insights_to_graph.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
from typing import Dict, List, Tuple
|
||||
|
||||
|
||||
def _insight_node(node: Dict) -> Dict:
    """Build the graph node entry for one endpoint of an insights triplet."""
    return {
        "id": node["id"],
        # `or` rather than an `in` check: a name that is present but empty
        # (None or "") must also fall back to the id, otherwise the node
        # ends up with a blank label.
        "label": node.get("name") or node["id"],
        "type": node["type"],
    }


def transform_insights_to_graph(context: List[Tuple[Dict, Dict, Dict]]) -> Dict[str, List[Dict]]:
    """Convert insights triplets into a ``{"nodes": [...], "edges": [...]}`` graph.

    Each triplet is ``(source_node, relationship, target_node)``. Nodes are
    keyed by their ``id`` and edges by
    ``"<source id>_<relationship_name>_<target id>"``, so repeated triplets
    collapse into a single entry (the last occurrence wins, the first-seen
    position is kept).

    Args:
        context: Triplets of dicts. Both node dicts must contain ``"id"`` and
            ``"type"``; ``"name"`` is optional. The relationship dict must
            contain ``"relationship_name"``.

    Returns:
        A dict with the de-duplicated ``"nodes"`` and ``"edges"`` lists. Both
        lists are empty when ``context`` is empty.

    Raises:
        KeyError: If a required key is missing from a triplet element.
    """
    nodes: Dict = {}
    edges: Dict = {}

    for triplet in context:
        source, relationship, target = triplet[0], triplet[1], triplet[2]

        nodes[source["id"]] = _insight_node(source)
        nodes[target["id"]] = _insight_node(target)

        relationship_name = relationship["relationship_name"]
        edges[f"{source['id']}_{relationship_name}_{target['id']}"] = {
            "source": source["id"],
            "target": target["id"],
            "label": relationship_name,
        }

    return {
        "nodes": list(nodes.values()),
        "edges": list(edges.values()),
    }
|
||||
Loading…
Add table
Reference in a new issue