fix: graph view with search (#1368)

<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
Boris 2025-09-11 16:16:03 +02:00 committed by GitHub
parent 47cb34e89c
commit 74a3220e9f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
24 changed files with 268 additions and 141 deletions

View file

@ -0,0 +1,46 @@
"""Add notebook table
Revision ID: 45957f0a9849
Revises: 9e7a3cb85175
Create Date: 2025-09-10 17:47:58.201319
"""
from datetime import datetime, timezone
from uuid import uuid4
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "45957f0a9849"
down_revision: Union[str, None] = "9e7a3cb85175"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
bind = op.get_bind()
inspector = sa.inspect(bind)
if "notebooks" not in inspector.get_table_names():
# Define table with all necessary columns including primary key
op.create_table(
"notebooks",
sa.Column("id", sa.UUID, primary_key=True, default=uuid4), # Critical for SQLite
sa.Column("owner_id", sa.UUID, index=True),
sa.Column("name", sa.String(), nullable=False),
sa.Column("cells", sa.JSON(), nullable=False),
sa.Column("deletable", sa.Boolean(), default=True),
sa.Column("created_at", sa.DateTime(), default=lambda: datetime.now(timezone.utc)),
)
def downgrade() -> None:
bind = op.get_bind()
inspector = sa.inspect(bind)
if "notebooks" in inspector.get_table_names():
op.drop_table("notebooks")

View file

@ -39,7 +39,8 @@ export default function AddDataToCognee({ datasets, refreshDatasets, useCloud =
} : {
name: "main_dataset",
},
Array.from(filesForUpload)
Array.from(filesForUpload),
useCloud
)
.then(({ dataset_id, dataset_name }) => {
refreshDatasets();

View file

@ -5,16 +5,31 @@ import { useCallback, useEffect, useRef, useState } from "react";
import { Header } from "@/ui/Layout";
import { SearchIcon } from "@/ui/Icons";
import { Notebook } from "@/ui/elements";
import { fetch, isCloudEnvironment } from "@/utils";
import { Notebook as NotebookType } from "@/ui/elements/Notebook/types";
import { useAuthenticatedUser } from "@/modules/auth";
import { Dataset } from "@/modules/ingestion/useDatasets";
import useNotebooks from "@/modules/notebooks/useNotebooks";
import AddDataToCognee from "./AddDataToCognee";
import NotebooksAccordion from "./NotebooksAccordion";
import CogneeInstancesAccordion from "./CogneeInstancesAccordion";
import AddDataToCognee from "./AddDataToCognee";
import InstanceDatasetsAccordion from "./InstanceDatasetsAccordion";
export default function Dashboard() {
interface DashboardProps {
user?: {
id: string;
name: string;
email: string;
picture: string;
};
accessToken: string;
}
export default function Dashboard({ accessToken }: DashboardProps) {
fetch.setAccessToken(accessToken);
const { user } = useAuthenticatedUser();
const {
notebooks,
refreshNotebooks,
@ -91,6 +106,8 @@ export default function Dashboard() {
setDatasets(datasets);
}, []);
const isCloudEnv = isCloudEnvironment();
return (
<div className="h-full flex flex-col bg-gray-200">
<video
@ -104,7 +121,7 @@ export default function Dashboard() {
Your browser does not support the video tag.
</video>
<Header />
<Header user={user} />
<div className="relative flex-1 flex flex-row gap-2.5 items-start w-full max-w-[1920px] max-h-[calc(100% - 3.5rem)] overflow-hidden mx-auto px-2.5 py-2.5">
<div className="px-5 py-4 lg:w-96 bg-white rounded-xl min-h-full">
@ -116,6 +133,7 @@ export default function Dashboard() {
<AddDataToCognee
datasets={datasets}
refreshDatasets={refreshDatasetsRef.current}
useCloud={isCloudEnv}
/>
<NotebooksAccordion

View file

@ -1,6 +1,7 @@
import classNames from "classnames";
import { useCallback, useEffect } from "react";
import { fetch, useBoolean } from "@/utils";
import { fetch, isCloudEnvironment, useBoolean } from "@/utils";
import { checkCloudConnection } from "@/modules/cloud";
import { CaretIcon, CloseIcon, CloudIcon, LocalCogneeIcon } from "@/ui/Icons";
import { CTAButton, GhostButton, IconButton, Input, Modal } from "@/ui/elements";
@ -18,11 +19,13 @@ export default function InstanceDatasetsAccordion({ onDatasetsChange }: Instance
const {
value: isCloudCogneeConnected,
setTrue: setCloudCogneeConnected,
} = useBoolean(false);
} = useBoolean(isCloudEnvironment());
const checkConnectionToCloudCognee = useCallback((apiKey: string) => {
fetch.setApiKey(apiKey);
return checkCloudConnection(apiKey)
const checkConnectionToCloudCognee = useCallback((apiKey?: string) => {
if (apiKey) {
fetch.setApiKey(apiKey);
}
return checkCloudConnection()
.then(setCloudCogneeConnected)
}, [setCloudCogneeConnected]);
@ -33,8 +36,6 @@ export default function InstanceDatasetsAccordion({ onDatasetsChange }: Instance
};
checkConnectionToLocalCognee();
checkConnectionToCloudCognee("");
}, [checkConnectionToCloudCognee, setCloudCogneeConnected, setLocalCogneeConnected]);
const {
@ -54,8 +55,12 @@ export default function InstanceDatasetsAccordion({ onDatasetsChange }: Instance
});
};
const isCloudEnv = isCloudEnvironment();
return (
<>
<div className={classNames("flex flex-col", {
"flex-col-reverse": isCloudEnv,
})}>
<DatasetsAccordion
title={(
<div className="flex flex-row items-center justify-between">
@ -69,7 +74,7 @@ export default function InstanceDatasetsAccordion({ onDatasetsChange }: Instance
switchCaretPosition={true}
className="pt-3 pb-1.5"
contentClassName="pl-4"
onDatasetsChange={onDatasetsChange}
onDatasetsChange={!isCloudEnv ? onDatasetsChange : () => {}}
/>
{isCloudCogneeConnected ? (
@ -86,11 +91,11 @@ export default function InstanceDatasetsAccordion({ onDatasetsChange }: Instance
switchCaretPosition={true}
className="pt-3 pb-1.5"
contentClassName="pl-4"
onDatasetsChange={onDatasetsChange}
onDatasetsChange={isCloudEnv ? onDatasetsChange : () => {}}
useCloud={true}
/>
) : (
<button className="w-full flex flex-row items-center justify-between py-1.5 cursor-pointer" onClick={!isCloudCogneeConnected ? openCloudConnectionModal : () => {}}>
<button className="w-full flex flex-row items-center justify-between py-1.5 cursor-pointer pt-3" onClick={!isCloudCogneeConnected ? openCloudConnectionModal : () => {}}>
<div className="flex flex-row items-center gap-1.5">
<CaretIcon className="rotate-[-90deg]" />
<div className="flex flex-row items-center gap-2">
@ -120,6 +125,6 @@ export default function InstanceDatasetsAccordion({ onDatasetsChange }: Instance
</form>
</div>
</Modal>
</>
</div>
);
}

View file

@ -1 +1,11 @@
export { default } from "./Dashboard";
"use server";
import Dashboard from "./Dashboard";
export default async function Page() {
const accessToken = "";
return (
<Dashboard accessToken={accessToken} />
);
}

View file

@ -1,3 +1,3 @@
export { default } from "./(graph)/GraphView";
export { default } from "./dashboard/page";
export const dynamic = "force-dynamic";
// export const dynamic = "force-dynamic";

View file

@ -0,0 +1,6 @@
import fetch from "@/utils/fetch";
export default function getUser() {
return fetch("/v1/auth/me")
.then((response) => response.json());
}

View file

@ -1,2 +1,2 @@
export { default as useAuthenticatedUser } from "./useAuthenticatedUser";
export { type User } from "./types";
export { default as useAuthenticatedUser } from "./useAuthenticatedUser";

View file

@ -2,5 +2,5 @@ export interface User {
id: string;
name: string;
email: string;
avatarImagePath: string;
picture: string;
}

View file

@ -3,7 +3,7 @@ import { fetch } from "@/utils";
import { User } from "./types";
export default function useAuthenticatedUser() {
const [user, setUser] = useState<User | null>(null);
const [user, setUser] = useState<User>();
useEffect(() => {
if (!user) {

View file

@ -1,10 +1,7 @@
import { fetch } from "@/utils";
export default function checkCloudConnection(apiKey: string) {
export default function checkCloudConnection() {
return fetch("/v1/checks/connection", {
method: "POST",
headers: {
"X-Api-Key": apiKey,
},
});
}

View file

@ -1,5 +1,5 @@
import { useCallback, useState } from "react";
import { fetch } from "@/utils";
import { fetch, isCloudEnvironment } from "@/utils";
import { Cell, Notebook } from "@/ui/elements/Notebook/types";
function useNotebooks() {
@ -12,7 +12,7 @@ function useNotebooks() {
headers: {
"Content-Type": "application/json",
},
})
}, isCloudEnvironment())
.then((response) => response.json())
.then((notebook) => {
setNotebooks((notebooks) => [
@ -26,30 +26,31 @@ function useNotebooks() {
const removeNotebook = useCallback((notebookId: string) => {
return fetch(`/v1/notebooks/${notebookId}`, {
method: "DELETE",
})
.then(() => {
setNotebooks((notebooks) =>
notebooks.filter((notebook) => notebook.id !== notebookId)
);
});
method: "DELETE",
}, isCloudEnvironment())
.then(() => {
setNotebooks((notebooks) =>
notebooks.filter((notebook) => notebook.id !== notebookId)
);
});
}, []);
const fetchNotebooks = useCallback(() => {
return fetch("/v1/notebooks", {
headers: {
"Content-Type": "application/json",
},
})
.then((response) => response.json())
.then((notebooks) => {
setNotebooks(notebooks);
headers: {
"Content-Type": "application/json",
},
}, isCloudEnvironment())
.then((response) => response.json())
.then((notebooks) => {
setNotebooks(notebooks);
return notebooks;
})
.catch((error) => {
console.error("Error fetching notebooks:", error);
});
return notebooks;
})
.catch((error) => {
console.error("Error fetching notebooks:", error);
throw error
});
}, []);
const updateNotebook = useCallback((updatedNotebook: Notebook) => {
@ -64,19 +65,19 @@ function useNotebooks() {
const saveNotebook = useCallback((notebook: Notebook) => {
return fetch(`/v1/notebooks/${notebook.id}`, {
body: JSON.stringify({
name: notebook.name,
cells: notebook.cells,
}),
method: "PUT",
headers: {
"Content-Type": "application/json",
},
})
.then((response) => response.json())
body: JSON.stringify({
name: notebook.name,
cells: notebook.cells,
}),
method: "PUT",
headers: {
"Content-Type": "application/json",
},
}, isCloudEnvironment())
.then((response) => response.json())
}, []);
const runCell = useCallback((notebook: Notebook, cell: Cell) => {
const runCell = useCallback((notebook: Notebook, cell: Cell, cogneeInstance: string) => {
setNotebooks((existingNotebooks) =>
existingNotebooks.map((existingNotebook) =>
existingNotebook.id === notebook.id ? {
@ -100,7 +101,7 @@ function useNotebooks() {
headers: {
"Content-Type": "application/json",
},
})
}, cogneeInstance === "cloud")
.then((response) => response.json())
.then((response) => {
setNotebooks((existingNotebooks) =>

View file

@ -6,12 +6,17 @@ import { useBoolean } from "@/utils";
import { CloseIcon, CloudIcon, CogneeIcon } from "../Icons";
import { CTAButton, GhostButton, IconButton, Modal } from "../elements";
import { useAuthenticatedUser } from "@/modules/auth";
import syncData from "@/modules/cloud/syncData";
export default function Header() {
const { user } = useAuthenticatedUser();
interface HeaderProps {
user?: {
name: string;
email: string;
picture: string;
};
}
export default function Header({ user }: HeaderProps) {
const {
value: isSyncModalOpen,
setTrue: openSyncModal,
@ -45,8 +50,8 @@ export default function Header() {
<SettingsIcon />
</div> */}
<Link href="/account" className="bg-indigo-600 w-8 h-8 rounded-full overflow-hidden">
{user?.avatarImagePath ? (
<Image width="32" height="32" alt="Name of the user" src={user.avatarImagePath} />
{user?.picture ? (
<Image width="32" height="32" alt="Name of the user" src={user.picture} />
) : (
<div className="w-8 h-8 rounded-full text-white flex items-center justify-center">
{user?.email?.charAt(0) || "C"}

View file

@ -14,7 +14,7 @@ import { Cell, Notebook as NotebookType } from "./types";
interface NotebookProps {
notebook: NotebookType;
runCell: (notebook: NotebookType, cell: Cell) => Promise<void>;
runCell: (notebook: NotebookType, cell: Cell, cogneeInstance: string) => Promise<void>;
updateNotebook: (updatedNotebook: NotebookType) => void;
saveNotebook: (notebook: NotebookType) => void;
}
@ -47,8 +47,8 @@ export default function Notebook({ notebook, updateNotebook, saveNotebook, runCe
}
}, [notebook, saveNotebook, updateNotebook]);
const handleCellRun = useCallback((cell: Cell) => {
return runCell(notebook, cell);
const handleCellRun = useCallback((cell: Cell, cogneeInstance: string) => {
return runCell(notebook, cell, cogneeInstance);
}, [notebook, runCell]);
const handleCellAdd = useCallback((afterCellIndex: number, cellType: "markdown" | "code") => {
@ -244,7 +244,7 @@ export default function Notebook({ notebook, updateNotebook, saveNotebook, runCe
}
function CellResult({ content = [] }) {
function CellResult({ content }: { content: [] }) {
const parsedContent = [];
const graphRef = useRef<GraphVisualizationAPI>();
@ -256,6 +256,7 @@ function CellResult({ content = [] }) {
for (const line of content) {
try {
if (Array.isArray(line)) {
// @ts-expect-error line can be Array or string
for (const item of line) {
if (typeof item === "string") {
parsedContent.push(
@ -264,15 +265,13 @@ function CellResult({ content = [] }) {
</pre>
);
}
if (typeof item === "object" && item["search_result"] && Array.isArray(item["search_result"])) {
for (const result of item["search_result"]) {
parsedContent.push(
<div className="w-full h-full bg-white">
<span className="text-sm pl-2 mb-4">query response (dataset: {item["dataset_name"]})</span>
<span className="block px-2 py-2">{result}</span>
</div>
);
}
if (typeof item === "object" && item["search_result"]) {
parsedContent.push(
<div className="w-full h-full bg-white">
<span className="text-sm pl-2 mb-4">query response (dataset: {item["dataset_name"]})</span>
<span className="block px-2 py-2">{item["search_result"]}</span>
</div>
);
}
if (typeof item === "object" && item["graph"] && typeof item["graph"] === "object") {
parsedContent.push(
@ -289,6 +288,27 @@ function CellResult({ content = [] }) {
}
}
}
if (typeof(line) === "object" && line["result"]) {
parsedContent.push(
<div className="w-full h-full bg-white">
<span className="text-sm pl-2 mb-4">query response (dataset: {line["dataset_name"]})</span>
<span className="block px-2 py-2">{line["result"]}</span>
</div>
);
if (line["graphs"]) {
parsedContent.push(
<div className="w-full h-full bg-white">
<span className="text-sm pl-2 mb-4">reasoning graph</span>
<GraphVisualization
data={transformToVisualizationData(line["graphs"]["*"])}
ref={graphRef as MutableRefObject<GraphVisualizationAPI>}
graphControls={graphControls}
className="min-h-48"
/>
</div>
);
}
}
} catch (error) {
console.error(error);
parsedContent.push(line);
@ -298,45 +318,44 @@ function CellResult({ content = [] }) {
return parsedContent.map((item, index) => (
<div key={index} className="px-2 py-1">
{item}
{/* {typeof item === "object" && item["search_result"] && Array.isArray(item["search_result"]) && (
(item["search_result"] as []).map((result: string) => (<pre key={result.slice(0, -10)}>{result}</pre>))
)}
{typeof item === "object" && item["graph"] && typeof item["graph"] === "object" && (
(item["graph"])
)} */}
</div>
));
};
function transformToVisualizationData(triplets) {
function transformToVisualizationData(graph: { nodes: [], edges: [] }) {
// Implementation to transform triplet to visualization data
const nodes = {};
const links = {};
for (const triplet of triplets) {
nodes[triplet.source.id] = {
id: triplet.source.id,
label: triplet.source.attributes.name,
type: triplet.source.attributes.type,
attributes: triplet.source.attributes,
};
nodes[triplet.destination.id] = {
id: triplet.destination.id,
label: triplet.destination.attributes.name,
type: triplet.destination.attributes.type,
attributes: triplet.destination.attributes,
};
links[`${triplet.source.id}_${triplet.attributes.relationship_name}_${triplet.destination.id}`] = {
source: triplet.source.id,
target: triplet.destination.id,
label: triplet.attributes.relationship_name,
}
}
return {
nodes: Object.values(nodes),
links: Object.values(links),
nodes: graph.nodes,
links: graph.edges,
};
// const nodes = {};
// const links = {};
// for (const triplet of triplets) {
// nodes[triplet.source.id] = {
// id: triplet.source.id,
// label: triplet.source.attributes.name,
// type: triplet.source.attributes.type,
// attributes: triplet.source.attributes,
// };
// nodes[triplet.destination.id] = {
// id: triplet.destination.id,
// label: triplet.destination.attributes.name,
// type: triplet.destination.attributes.type,
// attributes: triplet.destination.attributes,
// };
// links[`${triplet.source.id}_${triplet.attributes.relationship_name}_${triplet.destination.id}`] = {
// source: triplet.source.id,
// target: triplet.destination.id,
// label: triplet.attributes.relationship_name,
// }
// }
// return {
// nodes: Object.values(nodes),
// links: Object.values(links),
// };
}

View file

@ -1,16 +1,16 @@
import { useState } from "react";
import classNames from "classnames";
import { useBoolean } from "@/utils";
import { LocalCogneeIcon, PlayIcon } from "@/ui/Icons";
import { PopupMenu } from "@/ui/elements";
import { PlayIcon } from "@/ui/Icons";
import { PopupMenu, IconButton, Select } from "@/ui/elements";
import { LoadingIndicator } from "@/ui/App";
import { Cell } from "./types";
import IconButton from "../IconButton";
interface NotebookCellHeaderProps {
cell: Cell;
runCell: (cell: Cell) => Promise<void>;
runCell: (cell: Cell, cogneeInstance: string) => Promise<void>;
renameCell: (cell: Cell) => void;
removeCell: (cell: Cell) => void;
moveCellUp: (cell: Cell) => void;
@ -33,9 +33,11 @@ export default function NotebookCellHeader({
setFalse: setIsNotRunningCell,
} = useBoolean(false);
const [runInstance, setRunInstance] = useState<string>("local");
const handleCellRun = () => {
setIsRunningCell();
runCell(cell)
runCell(cell, runInstance)
.then(() => {
setIsNotRunningCell();
});
@ -48,10 +50,14 @@ export default function NotebookCellHeader({
<span className="ml-4">{cell.name}</span>
</div>
<div className="pr-4 flex flex-row items-center gap-8">
<div className="flex flex-row items-center gap-2">
<LocalCogneeIcon className="text-indigo-700" />
<span className="text-xs">local cognee</span>
</div>
<Select name="cogneeInstance" onChange={(event) => setRunInstance(event.currentTarget.value)} className="!bg-transparent outline-none cursor-pointer !hover:bg-gray-50">
<option value="local" className="flex flex-row items-center gap-2">
local cognee
</option>
<option value="cloud" className="flex flex-row items-center gap-2">
cloud cognee
</option>
</Select>
<PopupMenu>
<div className="flex flex-col gap-0.5">
<button onClick={() => moveCellUp(cell)} className="hover:bg-gray-100 w-full text-left px-2 cursor-pointer">move cell up</button>

View file

@ -1,4 +1,5 @@
import handleServerErrors from "./handleServerErrors";
import isCloudEnvironment from "./isCloudEnvironment";
let numberOfRetries = 0;
@ -6,9 +7,10 @@ const isAuth0Enabled = process.env.USE_AUTH0_AUTHORIZATION?.toLowerCase() === "t
const backendApiUrl = process.env.NEXT_PUBLIC_BACKEND_API_URL || "http://localhost:8000/api";
const cloudApiUrl = process.env.NEXT_PUBLIC_CLOUD_API_URL || "https://api.cognee.ai/api";
const cloudApiUrl = process.env.NEXT_PUBLIC_CLOUD_API_URL || "http://localhost:8001/api";
let apiKey: string | null = null;
let accessToken: string | null = null;
export default async function fetch(url: string, options: RequestInit = {}, useCloud = false): Promise<Response> {
function retry(lastError: Response) {
@ -34,7 +36,10 @@ export default async function fetch(url: string, options: RequestInit = {}, useC
...options,
headers: {
...options.headers,
...(useCloud ? {"X-Api-Key": apiKey!} : {}),
...(useCloud && !isCloudEnvironment()
? {"X-Api-Key": apiKey!}
: {"Authorization": `Bearer ${accessToken}`}
),
},
credentials: "include",
},
@ -66,3 +71,7 @@ fetch.checkHealth = () => {
fetch.setApiKey = (newApiKey: string) => {
apiKey = newApiKey;
};
fetch.setAccessToken = (newAccessToken: string) => {
accessToken = newAccessToken;
};

View file

@ -2,3 +2,4 @@ export { default as fetch } from "./fetch";
export { default as handleServerErrors } from "./handleServerErrors";
export { default as useBoolean } from "./useBoolean";
export { default as useOutsideClick } from "./useOutsideClick";
export { default as isCloudEnvironment } from "./isCloudEnvironment";

View file

@ -0,0 +1,4 @@
export default function isCloudEnvironment() {
return process.env.NEXT_PUBLIC_IS_CLOUD_ENVIRONMENT?.toLowerCase() === "true";
}

View file

@ -7,6 +7,7 @@ from fastapi import APIRouter, Depends
from cognee.api.DTO import InDTO
from cognee.infrastructure.databases.relational import get_async_session
from cognee.infrastructure.utils.run_async import run_async
from cognee.modules.notebooks.models import Notebook, NotebookCell
from cognee.modules.notebooks.operations import run_in_local_sandbox
from cognee.modules.users.models import User
@ -74,7 +75,7 @@ def get_notebooks_router():
if notebook is None:
return JSONResponse(status_code=404, content={"error": "Notebook not found"})
result, error = run_in_local_sandbox(run_code.content)
result, error = await run_async(run_in_local_sandbox, run_code.content)
return JSONResponse(
status_code=200, content={"result": jsonable_encoder(result), "error": error}

View file

@ -8,8 +8,15 @@ def run_sync(coro, timeout=None):
def runner():
nonlocal result, exception
try:
result = asyncio.run(coro)
try:
running_loop = asyncio.get_running_loop()
result = asyncio.run_coroutine_threadsafe(coro, running_loop).result(timeout)
except RuntimeError:
result = asyncio.run(coro)
except Exception as e:
exception = e

View file

@ -5,7 +5,6 @@ import traceback
def wrap_in_async_handler(user_code: str) -> str:
return (
"import asyncio\n\n"
"from cognee.infrastructure.utils.run_sync import run_sync\n\n"
"async def __user_main__():\n"
+ "\n".join(" " + line for line in user_code.strip().split("\n"))
@ -28,19 +27,6 @@ def run_in_local_sandbox(code, environment=None):
printOutput = []
# def process_output(output):
# try:
# result = json.loads(
# re.sub(
# r"'([^']*)'", r'"\1"',
# re.sub(r"\bNone\b", "null", output)
# )
# )
# result = json.loads(output)
# return result
# except json.JSONDecodeError:
# return output
def customPrintFunction(output):
printOutput.append(output)

View file

@ -63,9 +63,10 @@ async def get_memory_fragment(
if properties_to_project is None:
properties_to_project = ["id", "description", "name", "type", "text"]
memory_fragment = CogneeGraph()
try:
graph_engine = await get_graph_engine()
memory_fragment = CogneeGraph()
await memory_fragment.project_graph_from_db(
graph_engine,

View file

@ -98,7 +98,9 @@ async def search(
query.id,
json.dumps(
jsonable_encoder(
await prepare_search_result(search_results)
await prepare_search_result(
search_results[0] if isinstance(search_results, list) else search_results
)
if use_combined_context
else [
await prepare_search_result(search_result) for search_result in search_results
@ -109,7 +111,9 @@ async def search(
)
if use_combined_context:
prepared_search_results = await prepare_search_result(search_results)
prepared_search_results = await prepare_search_result(
search_results[0] if isinstance(search_results, list) else search_results
)
result = prepared_search_results["result"]
graphs = prepared_search_results["graphs"]
context = prepared_search_results["context"]

View file

@ -13,10 +13,10 @@ async def prepare_search_result(search_result):
context_texts = {}
if isinstance(context, List) and len(context) > 0 and isinstance(context[0], Edge):
result_graph = transform_context_to_graph(context)
context_graph = transform_context_to_graph(context)
graphs = {
"*": result_graph,
"*": context_graph,
}
context_texts = {
"*": await resolve_edges_to_text(context),