Compare commits
34 commits
main
...
crewai-dem
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4d7c07e483 | ||
|
|
fac930dc59 | ||
|
|
6da557565a | ||
|
|
4636b69664 | ||
|
|
dda6e92169 | ||
|
|
c9590ef760 | ||
|
|
ef929ba442 | ||
|
|
342cbc9461 | ||
|
|
ea1e23a7aa | ||
|
|
7df1daee71 | ||
|
|
950195223e | ||
|
|
b4b55b820d | ||
|
|
d7d626698d | ||
|
|
aecdff0503 | ||
|
|
4e373cfee7 | ||
|
|
f1e254f357 | ||
|
|
8f5d5b9ac2 | ||
|
|
ce14a441af | ||
|
|
f825732eb2 | ||
|
|
ecdf624bda | ||
|
|
1267f6c1e7 | ||
|
|
d8fde4c527 | ||
|
|
96d1dd772c | ||
|
|
cc52df94b7 | ||
|
|
ad9abb8b76 | ||
|
|
5d4f82fdd4 | ||
|
|
8aae9f8dd8 | ||
|
|
cd813c5732 | ||
|
|
7456567597 | ||
|
|
b29ab72c50 | ||
|
|
5cbdbf3abf | ||
|
|
cc4fab9e75 | ||
|
|
0c1e515c8f | ||
|
|
fe83a25576 |
103 changed files with 7449 additions and 1625 deletions
|
|
@ -156,6 +156,15 @@ Try cognee UI out locally [here](https://docs.cognee.ai/how-to-guides/cognee-ui)
|
|||
</div>
|
||||
|
||||
|
||||
## CrewAI
|
||||
|
||||
Note1: After each restart go to `localhost:3000/auth` and login again.
|
||||
Note2: Activity is not preserved in the DB, so it will be lost after page refresh.
|
||||
|
||||
1. Start FastAPI server by running `client.py` inside `cognee/api` directory
|
||||
2. Start NextJS server by running `npm run dev` inside `cognee-frontend` directory.
|
||||
3. If you are not logged-in, app will redirect to `/auth` page. Otherwise go there manually and login (if server is restarted).
|
||||
|
||||
|
||||
## Demos
|
||||
|
||||
|
|
|
|||
7
cognee-frontend/.prettierrc
Normal file
7
cognee-frontend/.prettierrc
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
"trailingComma": "es5",
|
||||
"tabWidth": 2,
|
||||
"semi": true,
|
||||
"singleQuote": false,
|
||||
"plugins": ["prettier-plugin-tailwindcss"]
|
||||
}
|
||||
1739
cognee-frontend/package-lock.json
generated
1739
cognee-frontend/package-lock.json
generated
File diff suppressed because it is too large
Load diff
|
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"name": "cognee-frontend",
|
||||
"version": "0.1.0",
|
||||
"version": "1.0.0",
|
||||
"private": true,
|
||||
"scripts": {
|
||||
"dev": "next dev",
|
||||
|
|
@ -10,13 +10,17 @@
|
|||
},
|
||||
"dependencies": {
|
||||
"classnames": "^2.5.1",
|
||||
"next": "14.2.3",
|
||||
"d3-force-3d": "^3.0.6",
|
||||
"next": "15.3.2",
|
||||
"ohmy-ui": "^0.0.6",
|
||||
"react": "^18",
|
||||
"react-dom": "^18",
|
||||
"react-force-graph-2d": "^1.27.1",
|
||||
"tailwindcss": "^4.1.7",
|
||||
"uuid": "^9.0.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@tailwindcss/postcss": "^4.1.7",
|
||||
"@types/node": "^20",
|
||||
"@types/react": "^18",
|
||||
"@types/react-dom": "^18",
|
||||
|
|
|
|||
5
cognee-frontend/postcss.config.mjs
Normal file
5
cognee-frontend/postcss.config.mjs
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
export default {
|
||||
plugins: {
|
||||
"@tailwindcss/postcss": {},
|
||||
}
|
||||
}
|
||||
90
cognee-frontend/src/app/(graph)/CogneeAddWidget.tsx
Normal file
90
cognee-frontend/src/app/(graph)/CogneeAddWidget.tsx
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
"use client";
|
||||
|
||||
import { v4 as uuid4 } from "uuid";
|
||||
import { ChangeEvent, useEffect } from "react";
|
||||
import { CTAButton, StatusIndicator } from "@/ui/elements";
|
||||
|
||||
import addData from "@/modules/ingestion/addData";
|
||||
import cognifyDataset from "@/modules/datasets/cognifyDataset";
|
||||
import useDatasets from "@/modules/ingestion/useDatasets";
|
||||
import getDatasetGraph from '@/modules/datasets/getDatasetGraph';
|
||||
|
||||
export interface NodesAndEdges {
|
||||
nodes: { id: string; label: string }[];
|
||||
links: { source: string; target: string; label: string }[];
|
||||
}
|
||||
|
||||
interface CogneeAddWidgetProps {
|
||||
onData: (data: NodesAndEdges) => void;
|
||||
}
|
||||
|
||||
export default function CogneeAddWidget({ onData }: CogneeAddWidgetProps) {
|
||||
const {
|
||||
datasets,
|
||||
addDataset,
|
||||
removeDataset,
|
||||
refreshDatasets,
|
||||
} = useDatasets();
|
||||
|
||||
useEffect(() => {
|
||||
refreshDatasets()
|
||||
.then((datasets) => {
|
||||
const dataset = datasets?.[0];
|
||||
|
||||
// For CrewAI we don't have a dataset.
|
||||
// if (dataset) {
|
||||
getDatasetGraph(dataset || { id: uuid4() })
|
||||
.then((graph) => onData({
|
||||
nodes: graph.nodes,
|
||||
links: graph.edges,
|
||||
}));
|
||||
// }
|
||||
});
|
||||
}, [onData, refreshDatasets]);
|
||||
|
||||
const handleAddFiles = (dataset: { id?: string, name?: string }, event: ChangeEvent<HTMLInputElement>) => {
|
||||
event.stopPropagation();
|
||||
|
||||
if (!event.currentTarget.files) {
|
||||
throw new Error("Error: No files added to the uploader input.");
|
||||
}
|
||||
|
||||
const files: File[] = Array.from(event.currentTarget.files);
|
||||
|
||||
return addData(dataset, files)
|
||||
.then(() => {
|
||||
const onUpdate = (data: any) => {
|
||||
onData({
|
||||
nodes: data.payload.nodes,
|
||||
links: data.payload.edges,
|
||||
});
|
||||
};
|
||||
|
||||
return cognifyDataset(dataset, onUpdate);
|
||||
});
|
||||
};
|
||||
|
||||
return null;
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-4 mb-4">
|
||||
{datasets.length ? datasets.map((dataset) => (
|
||||
<div key={dataset.id} className="flex gap-8 items-center">
|
||||
<div className="flex flex-row gap-4 items-center">
|
||||
<StatusIndicator status={dataset.status} />
|
||||
<span className="text-white">{dataset.name}</span>
|
||||
</div>
|
||||
<CTAButton type="button" className="relative">
|
||||
<input type="file" multiple onChange={handleAddFiles.bind(null, dataset)} className="absolute w-full h-full cursor-pointer opacity-0" />
|
||||
<span>+ Add Data</span>
|
||||
</CTAButton>
|
||||
</div>
|
||||
)) : (
|
||||
<CTAButton type="button" className="relative">
|
||||
<input type="file" multiple onChange={handleAddFiles.bind(null, { name: "main_dataset" })} className="absolute w-full h-full cursor-pointer opacity-0" />
|
||||
<span>+ Add Data</span>
|
||||
</CTAButton>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
106
cognee-frontend/src/app/(graph)/CrewAITrigger.tsx
Normal file
106
cognee-frontend/src/app/(graph)/CrewAITrigger.tsx
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
import { useState } from "react";
|
||||
import { fetch } from "@/utils";
|
||||
import { v4 as uuid4 } from "uuid";
|
||||
import { LoadingIndicator } from "@/ui/App";
|
||||
import { CTAButton, Input } from "@/ui/elements";
|
||||
|
||||
interface CrewAIFormPayload extends HTMLFormElement {
|
||||
username1: HTMLInputElement;
|
||||
username2: HTMLInputElement;
|
||||
}
|
||||
|
||||
interface CrewAITriggerProps {
|
||||
onData: (data: any) => void;
|
||||
onActivity: (activities: any) => void;
|
||||
}
|
||||
|
||||
export default function CrewAITrigger({ onData, onActivity }: CrewAITriggerProps) {
|
||||
const [isCrewAIRunning, setIsCrewAIRunning] = useState(false);
|
||||
|
||||
const handleRunCrewAI = (event: React.FormEvent<CrewAIFormPayload>) => {
|
||||
event.preventDefault();
|
||||
const formElements = event.currentTarget;
|
||||
|
||||
const crewAIConfig = {
|
||||
username1: formElements.username1.value,
|
||||
username2: formElements.username2.value,
|
||||
};
|
||||
|
||||
const websocket = new WebSocket("ws://localhost:8000/api/v1/crewai/subscribe");
|
||||
|
||||
websocket.onopen = () => {
|
||||
websocket.send(JSON.stringify({
|
||||
"Authorization": `Bearer ${localStorage.getItem("access_token")}`,
|
||||
}));
|
||||
};
|
||||
|
||||
let isCrewAIDone = false;
|
||||
onActivity([{ id: uuid4(), timestamp: Date.now(), activity: "Running CrewAI" }]);
|
||||
|
||||
websocket.onmessage = (event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
|
||||
if (data.status === "PipelineRunActivity") {
|
||||
onActivity(data.payload);
|
||||
return;
|
||||
}
|
||||
|
||||
onData({
|
||||
nodes: data.payload.nodes,
|
||||
links: data.payload.edges,
|
||||
});
|
||||
|
||||
const nodes_type_map: { [key: string]: number } = {};
|
||||
|
||||
for (let i = 0; i < data.payload.nodes.length; i++) {
|
||||
const node = data.payload.nodes[i];
|
||||
if (!nodes_type_map[node.type]) {
|
||||
nodes_type_map[node.type] = 0;
|
||||
}
|
||||
nodes_type_map[node.type] += 1;
|
||||
}
|
||||
|
||||
const activityMessage = Object.entries(nodes_type_map).reduce((message, [type, count]) => {
|
||||
return `${message}\n | ${type}: ${count}`;
|
||||
}, "Graph updated:");
|
||||
|
||||
onActivity([{
|
||||
id: uuid4(),
|
||||
timestamp: Date.now(),
|
||||
activity: activityMessage,
|
||||
}]);
|
||||
|
||||
if (data.status === "PipelineRunCompleted") {
|
||||
isCrewAIDone = true;
|
||||
websocket.close();
|
||||
}
|
||||
};
|
||||
|
||||
setIsCrewAIRunning(true);
|
||||
|
||||
return fetch("/v1/crewai/run", {
|
||||
method: "POST",
|
||||
body: JSON.stringify(crewAIConfig),
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
})
|
||||
.then(response => response.json())
|
||||
.finally(() => {
|
||||
websocket.close();
|
||||
setIsCrewAIRunning(false);
|
||||
onActivity([{ id: uuid4(), timestamp: Date.now(), activity: "CrewAI run done" }]);
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<form className="w-full flex flex-row gap-2 items-center" onSubmit={handleRunCrewAI}>
|
||||
<Input name="username1" type="text" placeholder="Github Username" required defaultValue="hajdul88" />
|
||||
<Input name="username2" type="text" placeholder="Github Username" required defaultValue="lxobr" />
|
||||
<CTAButton type="submit" className="whitespace-nowrap">
|
||||
Run CrewAI
|
||||
{isCrewAIRunning && <LoadingIndicator />}
|
||||
</CTAButton>
|
||||
</form>
|
||||
);
|
||||
}
|
||||
196
cognee-frontend/src/app/(graph)/GraphControls.tsx
Normal file
196
cognee-frontend/src/app/(graph)/GraphControls.tsx
Normal file
|
|
@ -0,0 +1,196 @@
|
|||
"use client";
|
||||
|
||||
import { v4 as uuid4 } from "uuid";
|
||||
import classNames from "classnames";
|
||||
import { NodeObject } from "react-force-graph-2d";
|
||||
import { ChangeEvent, useImperativeHandle, useState } from "react";
|
||||
|
||||
import { DeleteIcon } from "@/ui/Icons";
|
||||
import { FeedbackForm } from "@/ui/Partials";
|
||||
import { CTAButton, Input, NeutralButton, Select } from "@/ui/elements";
|
||||
|
||||
interface GraphControlsProps {
|
||||
isAddNodeFormOpen: boolean;
|
||||
ref: React.RefObject<GraphControlsAPI>;
|
||||
onFitIntoView: () => void;
|
||||
onGraphShapeChange: (shape: string) => void;
|
||||
}
|
||||
|
||||
export interface GraphControlsAPI {
|
||||
setSelectedNode: (node: NodeObject | null) => void;
|
||||
getSelectedNode: () => NodeObject | null;
|
||||
updateActivity: (activities: ActivityLog[]) => void;
|
||||
}
|
||||
|
||||
type ActivityLog = {
|
||||
id: string;
|
||||
timestamp: number;
|
||||
activity: string;
|
||||
};
|
||||
|
||||
type NodeProperty = {
|
||||
id: string;
|
||||
name: string;
|
||||
value: string;
|
||||
};
|
||||
|
||||
const formatter = new Intl.DateTimeFormat("en-GB", { dateStyle: "short", timeStyle: "medium" });
|
||||
|
||||
export default function GraphControls({ isAddNodeFormOpen, onGraphShapeChange, onFitIntoView, ref }: GraphControlsProps) {
|
||||
const [selectedNode, setSelectedNode] = useState<NodeObject | null>(null);
|
||||
const [activityLog, setActivityLog] = useState<ActivityLog[]>([]);
|
||||
const [nodeProperties, setNodeProperties] = useState<NodeProperty[]>([]);
|
||||
const [newProperty, setNewProperty] = useState<NodeProperty>({
|
||||
id: uuid4(),
|
||||
name: "",
|
||||
value: "",
|
||||
});
|
||||
|
||||
const updateActivity = (newActivities: ActivityLog[]) => {
|
||||
setActivityLog((activities) => [...activities, ...newActivities]);
|
||||
};
|
||||
|
||||
const handlePropertyChange = (property: NodeProperty, property_key: string, event: ChangeEvent<HTMLInputElement>) => {
|
||||
const value = event.target.value;
|
||||
|
||||
setNodeProperties(nodeProperties.map((nodeProperty) => (nodeProperty.id === property.id ? {...nodeProperty, [property_key]: value } : nodeProperty)));
|
||||
};
|
||||
|
||||
const handlePropertyAdd = () => {
|
||||
if (newProperty.name && newProperty.value) {
|
||||
setNodeProperties([...nodeProperties, newProperty]);
|
||||
setNewProperty({ id: uuid4(), name: "", value: "" });
|
||||
} else {
|
||||
alert("Please fill in both name and value fields for the new property.");
|
||||
}
|
||||
};
|
||||
|
||||
const handlePropertyDelete = (property: NodeProperty) => {
|
||||
setNodeProperties(nodeProperties.filter((nodeProperty) => nodeProperty.id !== property.id));
|
||||
};
|
||||
|
||||
const handleNewPropertyChange = (property: NodeProperty, property_key: string, event: ChangeEvent<HTMLInputElement>) => {
|
||||
const value = event.target.value;
|
||||
|
||||
setNewProperty({...property, [property_key]: value });
|
||||
};
|
||||
|
||||
useImperativeHandle(ref, () => ({
|
||||
setSelectedNode,
|
||||
getSelectedNode: () => selectedNode,
|
||||
updateActivity,
|
||||
}));
|
||||
|
||||
const [selectedTab, setSelectedTab] = useState("nodeDetails");
|
||||
|
||||
const handleGraphShapeControl = (event: ChangeEvent<HTMLSelectElement>) => {
|
||||
onGraphShapeChange(event.target.value);
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="flex w-full">
|
||||
<button onClick={() => setSelectedTab("nodeDetails")} className={classNames("cursor-pointer pt-4 pb-4 align-center text-gray-300 border-b-2 w-30 flex-1/3", { "border-b-indigo-600 text-white": selectedTab === "nodeDetails" })}>
|
||||
<span className="whitespace-nowrap">Node Details</span>
|
||||
</button>
|
||||
<button onClick={() => setSelectedTab("activityLog")} className={classNames("cursor-pointer pt-4 pb-4 align-center text-gray-300 border-b-2 w-30 flex-1/3", { "border-b-indigo-600 text-white": selectedTab === "activityLog" })}>
|
||||
<span className="whitespace-nowrap">Activity Log</span>
|
||||
</button>
|
||||
<button onClick={() => setSelectedTab("feedback")} className={classNames("cursor-pointer pt-4 pb-4 align-center text-gray-300 border-b-2 w-30 flex-1/3", { "border-b-indigo-600 text-white": selectedTab === "feedback" })}>
|
||||
<span className="whitespace-nowrap">Feedback</span>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div className="pt-4">
|
||||
{selectedTab === "nodeDetails" && (
|
||||
<>
|
||||
<div className="w-full flex flex-row gap-2 items-center mb-4">
|
||||
<label className="text-gray-300 whitespace-nowrap flex-1/5">Graph Shape:</label>
|
||||
<Select defaultValue="none" onChange={handleGraphShapeControl} className="flex-2/5">
|
||||
<option value="none">None</option>
|
||||
<option value="td">Top-down</option>
|
||||
<option value="bu">Bottom-up</option>
|
||||
<option value="lr">Left-right</option>
|
||||
<option value="rl">Right-left</option>
|
||||
<option value="radialin">Radial-in</option>
|
||||
<option value="radialout">Radial-out</option>
|
||||
</Select>
|
||||
<NeutralButton onClick={onFitIntoView} className="flex-2/5 whitespace-nowrap">Fit Graph into View</NeutralButton>
|
||||
</div>
|
||||
|
||||
|
||||
{isAddNodeFormOpen ? (
|
||||
<form className="flex flex-col gap-4" onSubmit={() => {}}>
|
||||
<div className="flex flex-row gap-4 items-center">
|
||||
<span className="text-gray-300 whitespace-nowrap">Source Node ID:</span>
|
||||
<Input readOnly type="text" defaultValue={selectedNode!.id} />
|
||||
</div>
|
||||
<div className="flex flex-col gap-4 items-end">
|
||||
{nodeProperties.map((property) => (
|
||||
<div key={property.id} className="w-full flex flex-row gap-2 items-center">
|
||||
<Input className="flex-1/3" type="text" placeholder="Property name" required value={property.name} onChange={handlePropertyChange.bind(null, property, "name")} />
|
||||
<Input className="flex-2/3" type="text" placeholder="Property value" required value={property.value} onChange={handlePropertyChange.bind(null, property, "value")} />
|
||||
<button className="border-1 border-white p-2 rounded-sm" onClick={handlePropertyDelete.bind(null, property)}>
|
||||
<DeleteIcon width={16} height={18} color="white" />
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
<div className="w-full flex flex-row gap-2 items-center">
|
||||
<Input className="flex-1/3" type="text" placeholder="Property name" required value={newProperty.name} onChange={handleNewPropertyChange.bind(null, newProperty, "name")} />
|
||||
<Input className="flex-2/3" type="text" placeholder="Property value" required value={newProperty.value} onChange={handleNewPropertyChange.bind(null, newProperty, "value")} />
|
||||
<NeutralButton type="button" className="" onClick={handlePropertyAdd}>Add</NeutralButton>
|
||||
</div>
|
||||
</div>
|
||||
<CTAButton type="submit">Add Node</CTAButton>
|
||||
</form>
|
||||
) : (
|
||||
selectedNode ? (
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex flex-col gap-2 overflow-y-auto max-h-96">
|
||||
<div className="flex gap-2 items-top">
|
||||
<span className="text-gray-300">ID:</span>
|
||||
<span className="text-white">{selectedNode.id}</span>
|
||||
</div>
|
||||
<div className="flex gap-2 items-top">
|
||||
<span className="text-gray-300">Label:</span>
|
||||
<span className="text-white">{selectedNode.label}</span>
|
||||
</div>
|
||||
|
||||
{Object.entries(selectedNode.properties).map(([key, value]) => (
|
||||
<div key={key} className="flex gap-2 items-top">
|
||||
<span className="text-gray-300">{key.charAt(0).toUpperCase() + key.slice(1)}:</span>
|
||||
<span className="text-white truncate">{typeof value === "object" ? JSON.stringify(value) : value as string}</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{/* <CTAButton type="button" onClick={() => {}}>Edit Node</CTAButton> */}
|
||||
</div>
|
||||
) : (
|
||||
<span className="text-white">No node selected.</span>
|
||||
)
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
|
||||
{selectedTab === "activityLog" && (
|
||||
<div className="flex flex-col gap-2">
|
||||
{activityLog.map((activity) => (
|
||||
<div key={activity.id} className="flex gap-2 items-top">
|
||||
<span className="text-gray-300 whitespace-nowrap">{formatter.format(activity.timestamp)}: </span>
|
||||
<span className="text-white whitespace-normal">{activity.activity}</span>
|
||||
</div>
|
||||
))}
|
||||
{!activityLog.length && <span className="text-white">No activity logged.</span>}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{selectedTab === "feedback" && (
|
||||
<div className="flex flex-col gap-2">
|
||||
<FeedbackForm onSuccess={() => {}} />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
279
cognee-frontend/src/app/(graph)/GraphView.tsx
Normal file
279
cognee-frontend/src/app/(graph)/GraphView.tsx
Normal file
|
|
@ -0,0 +1,279 @@
|
|||
"use client";
|
||||
|
||||
import { forceCollide, forceManyBody } from "d3-force-3d";
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import ForceGraph, { ForceGraphMethods, LinkObject, NodeObject } from "react-force-graph-2d";
|
||||
|
||||
import { TextLogo } from "@/ui/App";
|
||||
import { Divider } from "@/ui/Layout";
|
||||
import { Footer } from "@/ui/Partials";
|
||||
import CrewAITrigger from "./CrewAITrigger";
|
||||
import CogneeAddWidget, { NodesAndEdges } from "./CogneeAddWidget";
|
||||
import GraphControls, { GraphControlsAPI } from "./GraphControls";
|
||||
|
||||
import { useBoolean } from "@/utils";
|
||||
|
||||
// import exampleData from "./example_data.json";
|
||||
|
||||
interface GraphNode {
|
||||
id: string | number;
|
||||
label: string;
|
||||
properties?: {};
|
||||
}
|
||||
|
||||
interface GraphData {
|
||||
nodes: GraphNode[];
|
||||
links: { source: string | number; target: string | number; label: string }[];
|
||||
}
|
||||
|
||||
export default function GraphView() {
|
||||
const {
|
||||
value: isAddNodeFormOpen,
|
||||
setTrue: enableAddNodeForm,
|
||||
setFalse: disableAddNodeForm,
|
||||
} = useBoolean(false);
|
||||
|
||||
const [data, updateData] = useState<GraphData | null>(null);
|
||||
|
||||
const onDataChange = useCallback((newData: NodesAndEdges) => {
|
||||
if (!newData.nodes.length && !newData.links.length) {
|
||||
return;
|
||||
}
|
||||
|
||||
updateData({
|
||||
nodes: newData.nodes,
|
||||
links: newData.links,
|
||||
});
|
||||
}, []);
|
||||
|
||||
const graphRef = useRef<ForceGraphMethods>();
|
||||
|
||||
const graphControls = useRef<GraphControlsAPI>(null);
|
||||
|
||||
const onActivityChange = (activities: any) => {
|
||||
graphControls.current?.updateActivity(activities);
|
||||
};
|
||||
|
||||
const handleNodeClick = (node: NodeObject) => {
|
||||
graphControls.current?.setSelectedNode(node);
|
||||
graphRef.current?.d3ReheatSimulation();
|
||||
};
|
||||
|
||||
const textSize = 6;
|
||||
const nodeSize = 15;
|
||||
const addNodeDistanceFromSourceNode = 15;
|
||||
|
||||
const handleBackgroundClick = (event: MouseEvent) => {
|
||||
const selectedNode = graphControls.current?.getSelectedNode();
|
||||
|
||||
if (!selectedNode) {
|
||||
return;
|
||||
}
|
||||
|
||||
graphControls.current?.setSelectedNode(null);
|
||||
|
||||
// const graphBoundingBox = document.getElementById("graph-container")?.querySelector("canvas")?.getBoundingClientRect();
|
||||
// const x = event.clientX - graphBoundingBox!.x;
|
||||
// const y = event.clientY - graphBoundingBox!.y;
|
||||
|
||||
// const graphClickCoords = graphRef.current!.screen2GraphCoords(x, y);
|
||||
|
||||
// const distanceFromAddNode = Math.sqrt(
|
||||
// Math.pow(graphClickCoords.x - (selectedNode!.x! + addNodeDistanceFromSourceNode), 2)
|
||||
// + Math.pow(graphClickCoords.y - (selectedNode!.y! + addNodeDistanceFromSourceNode), 2)
|
||||
// );
|
||||
|
||||
// if (distanceFromAddNode <= 10) {
|
||||
// enableAddNodeForm();
|
||||
// } else {
|
||||
// disableAddNodeForm();
|
||||
// graphControls.current?.setSelectedNode(null);
|
||||
// }
|
||||
};
|
||||
|
||||
function renderNode(node: NodeObject, ctx: CanvasRenderingContext2D, globalScale: number) {
|
||||
const selectedNode = graphControls.current?.getSelectedNode();
|
||||
|
||||
ctx.save();
|
||||
|
||||
// if (node.id === selectedNode?.id) {
|
||||
// ctx.fillStyle = "gray";
|
||||
|
||||
// ctx.beginPath();
|
||||
// ctx.arc(node.x! + addNodeDistanceFromSourceNode, node.y! + addNodeDistanceFromSourceNode, 10, 0, 2 * Math.PI);
|
||||
// ctx.fill();
|
||||
|
||||
// ctx.beginPath();
|
||||
// ctx.moveTo(node.x! + addNodeDistanceFromSourceNode - 5, node.y! + addNodeDistanceFromSourceNode)
|
||||
// ctx.lineTo(node.x! + addNodeDistanceFromSourceNode - 5 + 10, node.y! + addNodeDistanceFromSourceNode);
|
||||
// ctx.stroke();
|
||||
|
||||
// ctx.beginPath();
|
||||
// ctx.moveTo(node.x! + addNodeDistanceFromSourceNode, node.y! + addNodeDistanceFromSourceNode - 5)
|
||||
// ctx.lineTo(node.x! + addNodeDistanceFromSourceNode, node.y! + addNodeDistanceFromSourceNode - 5 + 10);
|
||||
// ctx.stroke();
|
||||
// }
|
||||
|
||||
// ctx.beginPath();
|
||||
// ctx.arc(node.x, node.y, nodeSize, 0, 2 * Math.PI);
|
||||
// ctx.fill();
|
||||
|
||||
// draw text label (with background rect)
|
||||
const textPos = {
|
||||
x: node.x!,
|
||||
y: node.y!,
|
||||
};
|
||||
|
||||
ctx.translate(textPos.x, textPos.y);
|
||||
ctx.textAlign = "center";
|
||||
ctx.textBaseline = "middle";
|
||||
ctx.fillStyle = "#333333";
|
||||
ctx.font = `${textSize}px Sans-Serif`;
|
||||
ctx.fillText(node.label, 0, 0);
|
||||
|
||||
ctx.restore();
|
||||
}
|
||||
|
||||
function renderLink(link: LinkObject, ctx: CanvasRenderingContext2D) {
|
||||
const MAX_FONT_SIZE = 4;
|
||||
const LABEL_NODE_MARGIN = nodeSize * 1.5;
|
||||
|
||||
const start = link.source;
|
||||
const end = link.target;
|
||||
|
||||
// ignore unbound links
|
||||
if (typeof start !== "object" || typeof end !== "object") return;
|
||||
|
||||
const textPos = {
|
||||
x: start.x! + (end.x! - start.x!) / 2,
|
||||
y: start.y! + (end.y! - start.y!) / 2,
|
||||
};
|
||||
|
||||
const relLink = { x: end.x! - start.x!, y: end.y! - start.y! };
|
||||
|
||||
const maxTextLength = Math.sqrt(Math.pow(relLink.x, 2) + Math.pow(relLink.y, 2)) - LABEL_NODE_MARGIN * 2;
|
||||
|
||||
let textAngle = Math.atan2(relLink.y, relLink.x);
|
||||
// maintain label vertical orientation for legibility
|
||||
if (textAngle > Math.PI / 2) textAngle = -(Math.PI - textAngle);
|
||||
if (textAngle < -Math.PI / 2) textAngle = -(-Math.PI - textAngle);
|
||||
|
||||
const label = link.label
|
||||
|
||||
// estimate fontSize to fit in link length
|
||||
ctx.font = "1px Sans-Serif";
|
||||
const fontSize = Math.min(MAX_FONT_SIZE, maxTextLength / ctx.measureText(label).width);
|
||||
ctx.font = `${fontSize}px Sans-Serif`;
|
||||
const textWidth = ctx.measureText(label).width;
|
||||
const bckgDimensions = [textWidth, fontSize].map(n => n + fontSize * 0.2); // some padding
|
||||
|
||||
// draw text label (with background rect)
|
||||
ctx.save();
|
||||
ctx.translate(textPos.x, textPos.y);
|
||||
ctx.rotate(textAngle);
|
||||
|
||||
ctx.fillStyle = "rgba(255, 255, 255, 0.8)";
|
||||
ctx.fillRect(- bckgDimensions[0] / 2, - bckgDimensions[1] / 2, bckgDimensions[0], bckgDimensions[1]);
|
||||
|
||||
ctx.textAlign = "center";
|
||||
ctx.textBaseline = "middle";
|
||||
ctx.fillStyle = "darkgrey";
|
||||
ctx.fillText(label, 0, 0);
|
||||
ctx.restore();
|
||||
}
|
||||
|
||||
function handleDagError(loopNodeIds: (string | number)[]) {}
|
||||
|
||||
useEffect(() => {
|
||||
// add collision force
|
||||
graphRef.current!.d3Force("collision", forceCollide(nodeSize * 1.5));
|
||||
graphRef.current!.d3Force("charge", forceManyBody().strength(-1500).distanceMin(300).distanceMax(900));
|
||||
}, [data]);
|
||||
|
||||
const [graphShape, setGraphShape] = useState<string | undefined>(undefined);
|
||||
|
||||
return (
|
||||
<main className="flex flex-col h-full">
|
||||
<div className="pt-6 pr-3 pb-3 pl-6">
|
||||
<TextLogo width={86} height={24} />
|
||||
</div>
|
||||
<Divider />
|
||||
<div className="w-full h-full relative overflow-hidden">
|
||||
<div className="w-full h-full" id="graph-container">
|
||||
{data ? (
|
||||
<ForceGraph
|
||||
ref={graphRef}
|
||||
dagMode={graphShape as undefined}
|
||||
dagLevelDistance={300}
|
||||
onDagError={handleDagError}
|
||||
graphData={data}
|
||||
|
||||
nodeLabel="label"
|
||||
nodeRelSize={nodeSize}
|
||||
nodeCanvasObject={renderNode}
|
||||
nodeCanvasObjectMode={() => "after"}
|
||||
nodeAutoColorBy="type"
|
||||
|
||||
linkLabel="label"
|
||||
linkCanvasObject={renderLink}
|
||||
linkCanvasObjectMode={() => "after"}
|
||||
linkDirectionalArrowLength={3.5}
|
||||
linkDirectionalArrowRelPos={1}
|
||||
|
||||
onNodeClick={handleNodeClick}
|
||||
onBackgroundClick={handleBackgroundClick}
|
||||
d3VelocityDecay={0.3}
|
||||
/>
|
||||
) : (
|
||||
<ForceGraph
|
||||
ref={graphRef}
|
||||
dagMode="lr"
|
||||
dagLevelDistance={100}
|
||||
graphData={{
|
||||
nodes: [{ id: 1, label: "Add" }, { id: 2, label: "Cognify" }, { id: 3, label: "Search" }],
|
||||
links: [{ source: 1, target: 2, label: "but don't forget to" }, { source: 2, target: 3, label: "and after that you can" }],
|
||||
}}
|
||||
|
||||
nodeLabel="label"
|
||||
nodeRelSize={20}
|
||||
nodeCanvasObject={renderNode}
|
||||
nodeCanvasObjectMode={() => "after"}
|
||||
nodeAutoColorBy="type"
|
||||
|
||||
linkLabel="label"
|
||||
linkCanvasObject={renderLink}
|
||||
linkCanvasObjectMode={() => "after"}
|
||||
linkDirectionalArrowLength={3.5}
|
||||
linkDirectionalArrowRelPos={1}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="absolute top-2 left-2 bg-gray-500 pt-4 pr-4 pb-4 pl-4 rounded-md max-w-2xl">
|
||||
<CogneeAddWidget onData={onDataChange} />
|
||||
<CrewAITrigger onData={onDataChange} onActivity={onActivityChange} />
|
||||
</div>
|
||||
|
||||
<div className="absolute top-2 right-2 bg-gray-500 pt-4 pr-4 pb-4 pl-4 rounded-md w-110">
|
||||
<GraphControls
|
||||
ref={graphControls}
|
||||
isAddNodeFormOpen={isAddNodeFormOpen}
|
||||
onFitIntoView={() => graphRef.current?.zoomToFit(1000, 50)}
|
||||
onGraphShapeChange={setGraphShape}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<Divider />
|
||||
<div className="pl-6 pr-6">
|
||||
<Footer>
|
||||
{(data?.nodes.length || data?.links.length) && (
|
||||
<div className="flex flex-row items-center gap-6">
|
||||
<span>Nodes: {data?.nodes.length || 0}</span>
|
||||
<span>Edges: {data?.links.length || 0}</span>
|
||||
</div>
|
||||
)}
|
||||
</Footer>
|
||||
</div>
|
||||
</main>
|
||||
);
|
||||
}
|
||||
1376
cognee-frontend/src/app/(graph)/example_data.json
Normal file
1376
cognee-frontend/src/app/(graph)/example_data.json
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -1,16 +0,0 @@
|
|||
.main {
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
flex-direction: column;
|
||||
padding: 0;
|
||||
min-height: 100vh;
|
||||
}
|
||||
|
||||
.authContainer {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
padding: 24px 0;
|
||||
margin: 0 auto;
|
||||
max-width: 440px;
|
||||
width: 100%;
|
||||
}
|
||||
|
|
@ -1,29 +1,24 @@
|
|||
import { Spacer, Stack, Text } from 'ohmy-ui';
|
||||
import { TextLogo } from '@/ui/App';
|
||||
import Footer from '@/ui/Partials/Footer/Footer';
|
||||
|
||||
import styles from './AuthPage.module.css';
|
||||
import { Divider } from '@/ui/Layout';
|
||||
import SignInForm from '@/ui/Partials/SignInForm/SignInForm';
|
||||
import { TextLogo } from "@/ui/App";
|
||||
import { Divider } from "@/ui/Layout";
|
||||
import Footer from "@/ui/Partials/Footer/Footer";
|
||||
import SignInForm from "@/ui/Partials/SignInForm/SignInForm";
|
||||
|
||||
export default function AuthPage() {
|
||||
return (
|
||||
<main className={styles.main}>
|
||||
<Spacer inset vertical="2" horizontal="2">
|
||||
<Stack orientation="horizontal" gap="between" align="center">
|
||||
<TextLogo width={158} height={44} color="white" />
|
||||
</Stack>
|
||||
</Spacer>
|
||||
<Divider />
|
||||
<div className={styles.authContainer}>
|
||||
<Stack gap="4" style={{ width: '100%' }}>
|
||||
<h1><Text size="large">Sign in</Text></h1>
|
||||
<SignInForm />
|
||||
</Stack>
|
||||
<main className="flex flex-col h-full">
|
||||
<div className="pt-6 pr-3 pb-3 pl-6">
|
||||
<TextLogo width={86} height={24} />
|
||||
</div>
|
||||
<Spacer inset horizontal="3" wrap>
|
||||
<Divider />
|
||||
<div className="w-full max-w-md pt-12 pb-6 m-auto">
|
||||
<div className="flex flex-col w-full gap-8">
|
||||
<h1><span className="text-xl">Sign in</span></h1>
|
||||
<SignInForm />
|
||||
</div>
|
||||
</div>
|
||||
<div className="pl-6 pr-6">
|
||||
<Footer />
|
||||
</Spacer>
|
||||
</div>
|
||||
</main>
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,23 +15,16 @@
|
|||
--textarea-default-color: #0D051C !important;
|
||||
}
|
||||
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
padding: 0;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
html,
|
||||
body {
|
||||
height: 100%;
|
||||
max-width: 100vw;
|
||||
overflow-x: hidden;
|
||||
}
|
||||
|
||||
body {
|
||||
background: var(--global-background-default);
|
||||
}
|
||||
|
||||
a {
|
||||
color: inherit;
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
@import "tailwindcss";
|
||||
|
|
|
|||
130
cognee-frontend/src/app/page copy.tsx
Normal file
130
cognee-frontend/src/app/page copy.tsx
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
'use client';
|
||||
|
||||
import { useCallback, useEffect, useState } from 'react';
|
||||
import styles from "./page.module.css";
|
||||
import { GhostButton, Notification, NotificationContainer, Spacer, Stack, Text, useBoolean, useNotifications } from 'ohmy-ui';
|
||||
import useDatasets from '@/modules/ingestion/useDatasets';
|
||||
import DataView, { Data } from '@/modules/ingestion/DataView';
|
||||
import DatasetsView from '@/modules/ingestion/DatasetsView';
|
||||
import classNames from 'classnames';
|
||||
import addData from '@/modules/ingestion/addData';
|
||||
import cognifyDataset from '@/modules/datasets/cognifyDataset';
|
||||
import getDatasetData from '@/modules/datasets/getDatasetData';
|
||||
import { Footer, SettingsModal } from '@/ui/Partials';
|
||||
import { TextLogo } from '@/ui/App';
|
||||
import { SettingsIcon } from '@/ui/Icons';
|
||||
|
||||
export default function Home() {
|
||||
const {
|
||||
datasets,
|
||||
refreshDatasets,
|
||||
} = useDatasets();
|
||||
|
||||
const [datasetData, setDatasetData] = useState<Data[]>([]);
|
||||
const [selectedDataset, setSelectedDataset] = useState<string | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
refreshDatasets();
|
||||
}, [refreshDatasets]);
|
||||
|
||||
const openDatasetData = (dataset: { id: string }) => {
|
||||
getDatasetData(dataset)
|
||||
.then(setDatasetData)
|
||||
.then(() => setSelectedDataset(dataset.id));
|
||||
};
|
||||
|
||||
const closeDatasetData = () => {
|
||||
setDatasetData([]);
|
||||
setSelectedDataset(null);
|
||||
};
|
||||
|
||||
const { notifications, showNotification } = useNotifications();
|
||||
|
||||
const onDataAdd = useCallback((dataset: { id: string }, files: File[]) => {
|
||||
return addData(dataset, files)
|
||||
.then(() => {
|
||||
showNotification("Data added successfully. Please run \"Cognify\" when ready.", 5000);
|
||||
openDatasetData(dataset);
|
||||
});
|
||||
}, [showNotification])
|
||||
|
||||
const onDatasetCognify = useCallback((dataset: { id: string, name: string }) => {
|
||||
showNotification(`Cognification started for dataset "${dataset.name}".`, 5000);
|
||||
|
||||
return cognifyDataset(dataset)
|
||||
.then(() => {
|
||||
showNotification(`Dataset "${dataset.name}" cognified.`, 5000);
|
||||
})
|
||||
.catch(() => {
|
||||
showNotification(`Dataset "${dataset.name}" cognification failed. Please try again.`, 5000);
|
||||
});
|
||||
}, [showNotification]);
|
||||
|
||||
const onCognify = useCallback(() => {
|
||||
const dataset = datasets.find((dataset) => dataset.id === selectedDataset);
|
||||
return onDatasetCognify({
|
||||
id: dataset!.id,
|
||||
name: dataset!.name,
|
||||
});
|
||||
}, [datasets, onDatasetCognify, selectedDataset]);
|
||||
|
||||
const {
|
||||
value: isSettingsModalOpen,
|
||||
setTrue: openSettingsModal,
|
||||
setFalse: closeSettingsModal,
|
||||
} = useBoolean(false);
|
||||
|
||||
return (
|
||||
<main className={styles.main}>
|
||||
<Spacer inset vertical="2" horizontal="2">
|
||||
<Stack orientation="horizontal" gap="between" align="center">
|
||||
<TextLogo width={158} height={44} color="white" />
|
||||
<GhostButton hugContent onClick={openSettingsModal}>
|
||||
<SettingsIcon />
|
||||
</GhostButton>
|
||||
</Stack>
|
||||
</Spacer>
|
||||
<SettingsModal isOpen={isSettingsModalOpen} onClose={closeSettingsModal} />
|
||||
<Spacer inset vertical="1" horizontal="3">
|
||||
<div className={styles.data}>
|
||||
<div className={classNames(styles.datasetsView, {
|
||||
[styles.openDatasetData]: datasetData.length > 0,
|
||||
})}>
|
||||
<DatasetsView
|
||||
datasets={datasets}
|
||||
onDatasetClick={openDatasetData}
|
||||
onDatasetCognify={onDatasetCognify}
|
||||
/>
|
||||
</div>
|
||||
{datasetData.length > 0 && selectedDataset && (
|
||||
<div className={styles.dataView}>
|
||||
<DataView
|
||||
data={datasetData}
|
||||
datasetId={selectedDataset}
|
||||
onClose={closeDatasetData}
|
||||
onDataAdd={onDataAdd}
|
||||
onCognify={onCognify}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</Spacer>
|
||||
<Spacer inset horizontal="3" wrap>
|
||||
<Footer />
|
||||
</Spacer>
|
||||
<NotificationContainer gap="1" bottom right>
|
||||
{notifications.map((notification, index: number) => (
|
||||
<Notification
|
||||
key={notification.id}
|
||||
isOpen={notification.isOpen}
|
||||
style={{ top: `${index * 60}px` }}
|
||||
expireIn={notification.expireIn}
|
||||
onClose={notification.delete}
|
||||
>
|
||||
<Text nowrap>{notification.message}</Text>
|
||||
</Notification>
|
||||
))}
|
||||
</NotificationContainer>
|
||||
</main>
|
||||
);
|
||||
}
|
||||
|
|
@ -1,130 +1 @@
|
|||
'use client';
|
||||
|
||||
import { useCallback, useEffect, useState } from 'react';
|
||||
import styles from "./page.module.css";
|
||||
import { GhostButton, Notification, NotificationContainer, Spacer, Stack, Text, useBoolean, useNotifications } from 'ohmy-ui';
|
||||
import useDatasets from '@/modules/ingestion/useDatasets';
|
||||
import DataView, { Data } from '@/modules/ingestion/DataView';
|
||||
import DatasetsView from '@/modules/ingestion/DatasetsView';
|
||||
import classNames from 'classnames';
|
||||
import addData from '@/modules/ingestion/addData';
|
||||
import cognifyDataset from '@/modules/datasets/cognifyDataset';
|
||||
import getDatasetData from '@/modules/datasets/getDatasetData';
|
||||
import { Footer, SettingsModal } from '@/ui/Partials';
|
||||
import { TextLogo } from '@/ui/App';
|
||||
import { SettingsIcon } from '@/ui/Icons';
|
||||
|
||||
export default function Home() {
|
||||
const {
|
||||
datasets,
|
||||
refreshDatasets,
|
||||
} = useDatasets();
|
||||
|
||||
const [datasetData, setDatasetData] = useState<Data[]>([]);
|
||||
const [selectedDataset, setSelectedDataset] = useState<string | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
refreshDatasets();
|
||||
}, [refreshDatasets]);
|
||||
|
||||
const openDatasetData = (dataset: { id: string }) => {
|
||||
getDatasetData(dataset)
|
||||
.then(setDatasetData)
|
||||
.then(() => setSelectedDataset(dataset.id));
|
||||
};
|
||||
|
||||
const closeDatasetData = () => {
|
||||
setDatasetData([]);
|
||||
setSelectedDataset(null);
|
||||
};
|
||||
|
||||
const { notifications, showNotification } = useNotifications();
|
||||
|
||||
const onDataAdd = useCallback((dataset: { id: string }, files: File[]) => {
|
||||
return addData(dataset, files)
|
||||
.then(() => {
|
||||
showNotification("Data added successfully. Please run \"Cognify\" when ready.", 5000);
|
||||
openDatasetData(dataset);
|
||||
});
|
||||
}, [showNotification])
|
||||
|
||||
const onDatasetCognify = useCallback((dataset: { id: string, name: string }) => {
|
||||
showNotification(`Cognification started for dataset "${dataset.name}".`, 5000);
|
||||
|
||||
return cognifyDataset(dataset)
|
||||
.then(() => {
|
||||
showNotification(`Dataset "${dataset.name}" cognified.`, 5000);
|
||||
})
|
||||
.catch(() => {
|
||||
showNotification(`Dataset "${dataset.name}" cognification failed. Please try again.`, 5000);
|
||||
});
|
||||
}, [showNotification]);
|
||||
|
||||
const onCognify = useCallback(() => {
|
||||
const dataset = datasets.find((dataset) => dataset.id === selectedDataset);
|
||||
return onDatasetCognify({
|
||||
id: dataset!.id,
|
||||
name: dataset!.name,
|
||||
});
|
||||
}, [datasets, onDatasetCognify, selectedDataset]);
|
||||
|
||||
const {
|
||||
value: isSettingsModalOpen,
|
||||
setTrue: openSettingsModal,
|
||||
setFalse: closeSettingsModal,
|
||||
} = useBoolean(false);
|
||||
|
||||
return (
|
||||
<main className={styles.main}>
|
||||
<Spacer inset vertical="2" horizontal="2">
|
||||
<Stack orientation="horizontal" gap="between" align="center">
|
||||
<TextLogo width={158} height={44} color="white" />
|
||||
<GhostButton hugContent onClick={openSettingsModal}>
|
||||
<SettingsIcon />
|
||||
</GhostButton>
|
||||
</Stack>
|
||||
</Spacer>
|
||||
<SettingsModal isOpen={isSettingsModalOpen} onClose={closeSettingsModal} />
|
||||
<Spacer inset vertical="1" horizontal="3">
|
||||
<div className={styles.data}>
|
||||
<div className={classNames(styles.datasetsView, {
|
||||
[styles.openDatasetData]: datasetData.length > 0,
|
||||
})}>
|
||||
<DatasetsView
|
||||
datasets={datasets}
|
||||
onDatasetClick={openDatasetData}
|
||||
onDatasetCognify={onDatasetCognify}
|
||||
/>
|
||||
</div>
|
||||
{datasetData.length > 0 && selectedDataset && (
|
||||
<div className={styles.dataView}>
|
||||
<DataView
|
||||
data={datasetData}
|
||||
datasetId={selectedDataset}
|
||||
onClose={closeDatasetData}
|
||||
onDataAdd={onDataAdd}
|
||||
onCognify={onCognify}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</Spacer>
|
||||
<Spacer inset horizontal="3" wrap>
|
||||
<Footer />
|
||||
</Spacer>
|
||||
<NotificationContainer gap="1" bottom right>
|
||||
{notifications.map((notification, index: number) => (
|
||||
<Notification
|
||||
key={notification.id}
|
||||
isOpen={notification.isOpen}
|
||||
style={{ top: `${index * 60}px` }}
|
||||
expireIn={notification.expireIn}
|
||||
onClose={notification.delete}
|
||||
>
|
||||
<Text nowrap>{notification.message}</Text>
|
||||
</Notification>
|
||||
))}
|
||||
</NotificationContainer>
|
||||
</main>
|
||||
);
|
||||
}
|
||||
export { default } from "./(graph)/GraphView";
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import { fetch } from '@/utils';
|
||||
|
||||
export default function cognifyDataset(dataset: { id?: string, name?: string }) {
|
||||
export default function cognifyDataset(dataset: { id?: string, name?: string }, onUpdate = (data: []) => {}) {
|
||||
return fetch('/v1/cognify', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
|
|
@ -9,5 +9,35 @@ export default function cognifyDataset(dataset: { id?: string, name?: string })
|
|||
body: JSON.stringify({
|
||||
datasets: [dataset.id || dataset.name],
|
||||
}),
|
||||
}).then((response) => response.json());
|
||||
})
|
||||
.then((response) => response.json())
|
||||
.then((data) => {
|
||||
const websocket = new WebSocket(`ws://localhost:8000/api/v1/cognify/subscribe/${data.pipeline_run_id}`);
|
||||
|
||||
websocket.onopen = () => {
|
||||
websocket.send(JSON.stringify({
|
||||
"Authorization": `Bearer ${localStorage.getItem("access_token")}`,
|
||||
}));
|
||||
};
|
||||
|
||||
let isCognifyDone = false;
|
||||
|
||||
websocket.onmessage = (event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
onUpdate(data);
|
||||
|
||||
if (data.status === "PipelineRunCompleted") {
|
||||
isCognifyDone = true;
|
||||
websocket.close();
|
||||
}
|
||||
};
|
||||
|
||||
return new Promise(async (resolve) => {
|
||||
while (!isCognifyDone) {
|
||||
await new Promise(resolve => setTimeout(resolve, 1000));
|
||||
}
|
||||
|
||||
resolve(true);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
|
|
|||
6
cognee-frontend/src/modules/datasets/getDatasetGraph.ts
Normal file
6
cognee-frontend/src/modules/datasets/getDatasetGraph.ts
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
import { fetch } from '@/utils';
|
||||
|
||||
export default function getDatasetGraph(dataset: { id: string }) {
|
||||
return fetch(`/v1/datasets/${dataset.id}/graph`)
|
||||
.then((response) => response.json());
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
import { useState } from 'react';
|
||||
import Link from 'next/link';
|
||||
import { Explorer } from '@/ui/Partials';
|
||||
import StatusIcon from './StatusIcon';
|
||||
import StatusIcon from '@/ui/elements/StatusIndicator';
|
||||
import { LoadingIndicator } from '@/ui/App';
|
||||
import { DropdownMenu, GhostButton, Stack, Text, CTAButton, useBoolean, Modal, Spacer } from "ohmy-ui";
|
||||
import styles from "./DatasetsView.module.css";
|
||||
|
|
|
|||
|
|
@ -1,15 +0,0 @@
|
|||
export default function StatusIcon({ status }: { status: 'DATASET_PROCESSING_COMPLETED' | string }) {
|
||||
const isSuccess = status === 'DATASET_PROCESSING_COMPLETED';
|
||||
|
||||
return (
|
||||
<div
|
||||
style={{
|
||||
width: '16px',
|
||||
height: '16px',
|
||||
borderRadius: '4px',
|
||||
background: isSuccess ? '#53ff24' : '#ff5024',
|
||||
}}
|
||||
title={isSuccess ? 'Dataset cognified' : 'Cognify data in order to explore it'}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
|
@ -42,7 +42,7 @@ function useDatasets() {
|
|||
|
||||
statusTimeout.current = setTimeout(() => {
|
||||
checkDatasetStatuses(datasets);
|
||||
}, 5000);
|
||||
}, 50000);
|
||||
}, [fetchDatasetStatuses]);
|
||||
|
||||
useEffect(() => {
|
||||
|
|
@ -73,7 +73,7 @@ function useDatasets() {
|
|||
}, []);
|
||||
|
||||
const fetchDatasets = useCallback(() => {
|
||||
fetch('/v1/datasets', {
|
||||
return fetch('/v1/datasets', {
|
||||
headers: {
|
||||
Authorization: `Bearer ${localStorage.getItem('access_token')}`,
|
||||
},
|
||||
|
|
@ -84,9 +84,9 @@ function useDatasets() {
|
|||
|
||||
if (datasets.length > 0) {
|
||||
checkDatasetStatuses(datasets);
|
||||
} else {
|
||||
window.location.href = '/wizard';
|
||||
}
|
||||
|
||||
return datasets;
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Error fetching datasets:', error);
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
|
||||
.loadingIndicator {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
width: 1rem;
|
||||
height: 1rem;
|
||||
border-radius: 50%;
|
||||
border: 2px solid var(--global-color-primary);
|
||||
border: 0.18rem solid white;
|
||||
border-top-color: transparent;
|
||||
border-bottom-color: transparent;
|
||||
animation: spin 2s linear infinite;
|
||||
|
|
|
|||
7
cognee-frontend/src/ui/Icons/DeleteIcon.tsx
Normal file
7
cognee-frontend/src/ui/Icons/DeleteIcon.tsx
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
export default function DeleteIcon({ width = 12, height = 14, color = 'currentColor' }) {
|
||||
return (
|
||||
<svg width={width} height={height} viewBox="0 0 12 14" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M3.625 1.87357H3.5C3.56875 1.87357 3.625 1.81732 3.625 1.74857V1.87357H8.375V1.74857C8.375 1.81732 8.43125 1.87357 8.5 1.87357H8.375V2.99857H9.5V1.74857C9.5 1.197 9.05156 0.748566 8.5 0.748566H3.5C2.94844 0.748566 2.5 1.197 2.5 1.74857V2.99857H3.625V1.87357ZM11.5 2.99857H0.5C0.223438 2.99857 0 3.222 0 3.49857V3.99857C0 4.06732 0.05625 4.12357 0.125 4.12357H1.06875L1.45469 12.2954C1.47969 12.8283 1.92031 13.2486 2.45313 13.2486H9.54688C10.0813 13.2486 10.5203 12.8298 10.5453 12.2954L10.9313 4.12357H11.875C11.9438 4.12357 12 4.06732 12 3.99857V3.49857C12 3.222 11.7766 2.99857 11.5 2.99857ZM9.42656 12.1236H2.57344L2.19531 4.12357H9.80469L9.42656 12.1236Z" fill={color} />
|
||||
</svg>
|
||||
);
|
||||
}
|
||||
|
|
@ -3,7 +3,7 @@ export default function GitHubIcon({ width = 24, height = 24, color = 'currentCo
|
|||
<svg xmlns="http://www.w3.org/2000/svg" width={width} height={height} viewBox="0 0 28 28" className={className}>
|
||||
<g transform="translate(-1477 -38)">
|
||||
<rect width="28" height="28" transform="translate(1477 38)" fill={color} opacity="0" />
|
||||
<path d="M16.142,1.9A13.854,13.854,0,0,0,11.78,28.966c.641.128,1.155-.577,1.155-1.154v-1.86c-3.848.834-5.067-1.86-5.067-1.86a4.169,4.169,0,0,0-1.411-2.052c-1.283-.9.064-.834.064-.834a2.758,2.758,0,0,1,2.117,1.283c1.09,1.86,3.528,1.668,4.3,1.347a3.463,3.463,0,0,1,.321-1.86c-4.361-.77-6.735-3.335-6.735-6.8A6.863,6.863,0,0,1,8.381,10.3a3.977,3.977,0,0,1,.192-4.1,5.708,5.708,0,0,1,4.1,1.86,9.685,9.685,0,0,1,3.463-.513,10.968,10.968,0,0,1,3.463.449,5.773,5.773,0,0,1,4.1-1.8,4.169,4.169,0,0,1,.257,4.1,6.863,6.863,0,0,1,1.8,4.875c0,3.463-2.373,6.029-6.735,6.8a3.464,3.464,0,0,1,.321,1.86v3.977a1.155,1.155,0,0,0,1.219,1.155A13.918,13.918,0,0,0,16.142,1.9Z" transform="translate(1474.913 36.102)" fill="#fdfdfd"/>
|
||||
<path d="M16.142,1.9A13.854,13.854,0,0,0,11.78,28.966c.641.128,1.155-.577,1.155-1.154v-1.86c-3.848.834-5.067-1.86-5.067-1.86a4.169,4.169,0,0,0-1.411-2.052c-1.283-.9.064-.834.064-.834a2.758,2.758,0,0,1,2.117,1.283c1.09,1.86,3.528,1.668,4.3,1.347a3.463,3.463,0,0,1,.321-1.86c-4.361-.77-6.735-3.335-6.735-6.8A6.863,6.863,0,0,1,8.381,10.3a3.977,3.977,0,0,1,.192-4.1,5.708,5.708,0,0,1,4.1,1.86,9.685,9.685,0,0,1,3.463-.513,10.968,10.968,0,0,1,3.463.449,5.773,5.773,0,0,1,4.1-1.8,4.169,4.169,0,0,1,.257,4.1,6.863,6.863,0,0,1,1.8,4.875c0,3.463-2.373,6.029-6.735,6.8a3.464,3.464,0,0,1,.321,1.86v3.977a1.155,1.155,0,0,0,1.219,1.155A13.918,13.918,0,0,0,16.142,1.9Z" transform="translate(1474.913 36.102)" fill={color}/>
|
||||
</g>
|
||||
</svg>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
export { default as DeleteIcon } from './DeleteIcon';
|
||||
export { default as GithubIcon } from './GitHubIcon';
|
||||
export { default as DiscordIcon } from './DiscordIcon';
|
||||
export { default as SettingsIcon } from './SettingsIcon';
|
||||
|
|
|
|||
69
cognee-frontend/src/ui/Partials/FeedbackForm.tsx
Normal file
69
cognee-frontend/src/ui/Partials/FeedbackForm.tsx
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { LoadingIndicator } from "@/ui/App";
|
||||
import { fetch, useBoolean } from "@/utils";
|
||||
import { CTAButton, TextArea } from "@/ui/elements";
|
||||
|
||||
interface SignInFormPayload extends HTMLFormElement {
|
||||
feedback: HTMLTextAreaElement;
|
||||
}
|
||||
|
||||
interface FeedbackFormProps {
|
||||
onSuccess: () => void;
|
||||
}
|
||||
|
||||
export default function FeedbackForm({ onSuccess }: FeedbackFormProps) {
|
||||
const {
|
||||
value: isSubmittingFeedback,
|
||||
setTrue: disableFeedbackSubmit,
|
||||
setFalse: enableFeedbackSubmit,
|
||||
} = useBoolean(false);
|
||||
|
||||
const [feedbackError, setFeedbackError] = useState<string | null>(null);
|
||||
|
||||
const signIn = (event: React.FormEvent<SignInFormPayload>) => {
|
||||
event.preventDefault();
|
||||
const formElements = event.currentTarget;
|
||||
|
||||
setFeedbackError(null);
|
||||
disableFeedbackSubmit();
|
||||
|
||||
fetch("/v1/crewai/feedback", {
|
||||
method: "POST",
|
||||
body: JSON.stringify({
|
||||
feedback: formElements.feedback.value,
|
||||
}),
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(() => {
|
||||
onSuccess();
|
||||
formElements.feedback.value = "";
|
||||
})
|
||||
.catch(error => setFeedbackError(error.detail))
|
||||
.finally(() => enableFeedbackSubmit());
|
||||
};
|
||||
|
||||
return (
|
||||
<form onSubmit={signIn} className="flex flex-col gap-2">
|
||||
<div className="flex flex-col gap-2">
|
||||
<div className="mb-4">
|
||||
<label className="block text-white" htmlFor="feedback">Feedback on agent's reasoning</label>
|
||||
<TextArea id="feedback" name="feedback" type="text" placeholder="Your feedback" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<CTAButton type="submit">
|
||||
<span>Submit feedback</span>
|
||||
{isSubmittingFeedback && <LoadingIndicator />}
|
||||
</CTAButton>
|
||||
|
||||
{feedbackError && (
|
||||
<span className="text-s text-white">{feedbackError}</span>
|
||||
)}
|
||||
</form>
|
||||
)
|
||||
}
|
||||
|
|
@ -1,16 +0,0 @@
|
|||
.footer {
|
||||
padding: 24px 0;
|
||||
}
|
||||
|
||||
.leftSide {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.rightSide {
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
align-items: center;
|
||||
gap: 24px;
|
||||
}
|
||||
|
|
@ -1,25 +1,25 @@
|
|||
import Link from 'next/link';
|
||||
import { Stack } from 'ohmy-ui';
|
||||
import { DiscordIcon, GithubIcon } from '@/ui/Icons';
|
||||
// import { TextLogo } from '@/ui/App';
|
||||
import styles from './Footer.module.css';
|
||||
import Link from "next/link";
|
||||
import { DiscordIcon, GithubIcon } from "@/ui/Icons";
|
||||
|
||||
export default function Footer() {
|
||||
interface FooterProps {
|
||||
children?: React.ReactNode;
|
||||
}
|
||||
|
||||
export default function Footer({ children }: FooterProps) {
|
||||
return (
|
||||
<footer className={styles.footer}>
|
||||
<Stack orientation="horizontal" gap="between">
|
||||
<div className={styles.leftSide}>
|
||||
{/* <TextLogo width={92} height={24} /> */}
|
||||
</div>
|
||||
<div className={styles.rightSide}>
|
||||
<Link target="_blank" href="https://github.com/topoteretes/cognee">
|
||||
<GithubIcon color="white" />
|
||||
</Link>
|
||||
<Link target="_blank" href="https://discord.gg/m63hxKsp4p">
|
||||
<DiscordIcon color="white" />
|
||||
</Link>
|
||||
</div>
|
||||
</Stack>
|
||||
<footer className="pt-6 pb-6 flex flex-row items-center justify-between">
|
||||
<div>
|
||||
{children}
|
||||
</div>
|
||||
|
||||
<div className="flex flex-row gap-4">
|
||||
<Link target="_blank" href="https://github.com/topoteretes/cognee">
|
||||
<GithubIcon color="black" />
|
||||
</Link>
|
||||
<Link target="_blank" href="https://discord.gg/m63hxKsp4p">
|
||||
<DiscordIcon color="black" />
|
||||
</Link>
|
||||
</div>
|
||||
</footer>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,19 +1,9 @@
|
|||
"use client";
|
||||
|
||||
import {
|
||||
CTAButton,
|
||||
FormGroup,
|
||||
FormInput,
|
||||
FormLabel,
|
||||
Input,
|
||||
Spacer,
|
||||
Stack,
|
||||
Text,
|
||||
useBoolean,
|
||||
} from 'ohmy-ui';
|
||||
import { LoadingIndicator } from '@/ui/App';
|
||||
import { fetch, handleServerErrors } from '@/utils';
|
||||
import { useState } from 'react';
|
||||
import { useState } from "react";
|
||||
import { LoadingIndicator } from "@/ui/App";
|
||||
import { fetch, useBoolean } from "@/utils";
|
||||
import { CTAButton, Input } from "@/ui/elements";
|
||||
|
||||
interface SignInFormPayload extends HTMLFormElement {
|
||||
vectorDBUrl: HTMLInputElement;
|
||||
|
|
@ -22,10 +12,10 @@ interface SignInFormPayload extends HTMLFormElement {
|
|||
}
|
||||
|
||||
const errorsMap = {
|
||||
LOGIN_BAD_CREDENTIALS: 'Invalid username or password',
|
||||
LOGIN_BAD_CREDENTIALS: "Invalid username or password",
|
||||
};
|
||||
|
||||
export default function SignInForm({ onSignInSuccess = () => window.location.href = '/', submitButtonText = 'Sign in' }) {
|
||||
export default function SignInForm({ onSignInSuccess = () => window.location.href = "/", submitButtonText = "Sign in" }) {
|
||||
const {
|
||||
value: isSigningIn,
|
||||
setTrue: disableSignIn,
|
||||
|
|
@ -46,14 +36,13 @@ export default function SignInForm({ onSignInSuccess = () => window.location.hre
|
|||
setSignInError(null);
|
||||
disableSignIn();
|
||||
|
||||
fetch('/v1/auth/login', {
|
||||
method: 'POST',
|
||||
fetch("/v1/auth/login", {
|
||||
method: "POST",
|
||||
body: authCredentials,
|
||||
})
|
||||
.then(handleServerErrors)
|
||||
.then(response => response.json())
|
||||
.then((bearer) => {
|
||||
window.localStorage.setItem('access_token', bearer.access_token);
|
||||
window.localStorage.setItem("access_token", bearer.access_token);
|
||||
onSignInSuccess();
|
||||
})
|
||||
.catch(error => setSignInError(errorsMap[error.detail as keyof typeof errorsMap]))
|
||||
|
|
@ -61,36 +50,26 @@ export default function SignInForm({ onSignInSuccess = () => window.location.hre
|
|||
};
|
||||
|
||||
return (
|
||||
<form onSubmit={signIn} style={{ width: '100%' }}>
|
||||
<Stack gap="4" orientation="vertical">
|
||||
<Stack gap="4" orientation="vertical">
|
||||
<FormGroup orientation="vertical" align="center/" gap="2">
|
||||
<FormLabel>Email:</FormLabel>
|
||||
<FormInput>
|
||||
<Input defaultValue="default_user@example.com" name="email" type="email" placeholder="Your email address" />
|
||||
</FormInput>
|
||||
</FormGroup>
|
||||
<FormGroup orientation="vertical" align="center/" gap="2">
|
||||
<FormLabel>Password:</FormLabel>
|
||||
<FormInput>
|
||||
<Input defaultValue="default_password" name="password" type="password" placeholder="Your password" />
|
||||
</FormInput>
|
||||
</FormGroup>
|
||||
</Stack>
|
||||
<form onSubmit={signIn} className="flex flex-col gap-2">
|
||||
<div className="flex flex-col gap-2">
|
||||
<div className="mb-4">
|
||||
<label className="block mb-2" htmlFor="email">Email</label>
|
||||
<Input id="email" defaultValue="default_user@example.com" name="email" type="email" placeholder="Your email address" />
|
||||
</div>
|
||||
<div className="mb-4">
|
||||
<label className="block mb-2" htmlFor="password">Password</label>
|
||||
<Input id="password" defaultValue="default_password" name="password" type="password" placeholder="Your password" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Spacer top="2">
|
||||
<CTAButton type="submit">
|
||||
<Stack gap="2" orientation="horizontal" align="/center">
|
||||
{submitButtonText}
|
||||
{isSigningIn && <LoadingIndicator />}
|
||||
</Stack>
|
||||
</CTAButton>
|
||||
</Spacer>
|
||||
<CTAButton type="submit">
|
||||
{submitButtonText}
|
||||
{isSigningIn && <LoadingIndicator />}
|
||||
</CTAButton>
|
||||
|
||||
{signInError && (
|
||||
<Text>{signInError}</Text>
|
||||
)}
|
||||
</Stack>
|
||||
{signInError && (
|
||||
<span className="text-s text-white">{signInError}</span>
|
||||
)}
|
||||
</form>
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
export { default as Footer } from './Footer/Footer';
|
||||
export { default as SettingsModal } from './SettingsModal/SettingsModal';
|
||||
export { default as SearchView } from './SearchView/SearchView';
|
||||
export { default as IFrameView } from './IFrameView/IFrameView';
|
||||
export { default as Explorer } from './Explorer/Explorer';
|
||||
export { default as Footer } from "./Footer/Footer";
|
||||
export { default as SettingsModal } from "./SettingsModal/SettingsModal";
|
||||
export { default as SearchView } from "./SearchView/SearchView";
|
||||
export { default as IFrameView } from "./IFrameView/IFrameView";
|
||||
export { default as Explorer } from "./Explorer/Explorer";
|
||||
export { default as FeedbackForm } from "./FeedbackForm";
|
||||
|
|
|
|||
8
cognee-frontend/src/ui/elements/CTAButton.tsx
Normal file
8
cognee-frontend/src/ui/elements/CTAButton.tsx
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
import classNames from 'classnames';
|
||||
import { ButtonHTMLAttributes } from "react";
|
||||
|
||||
export default function CTAButton({ children, className, ...props }: ButtonHTMLAttributes<HTMLButtonElement>) {
|
||||
return (
|
||||
<button className={classNames("flex flex-row justify-center items-center gap-2 cursor-pointer rounded-md bg-indigo-600 px-3 py-2 text-sm font-semibold text-white shadow-xs hover:bg-indigo-500 focus-visible:outline-2 focus-visible:outline-offset-2 focus-visible:outline-indigo-600", className)} {...props}>{children}</button>
|
||||
);
|
||||
}
|
||||
8
cognee-frontend/src/ui/elements/Input.tsx
Normal file
8
cognee-frontend/src/ui/elements/Input.tsx
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
import classNames from "classnames"
|
||||
import { InputHTMLAttributes } from "react"
|
||||
|
||||
export default function Input({ className, ...props }: InputHTMLAttributes<HTMLInputElement>) {
|
||||
return (
|
||||
<input className={classNames("block w-full rounded-md bg-white px-3 py-1.5 text-base text-gray-900 outline-1 -outline-offset-1 outline-gray-300 placeholder:text-gray-400 focus:outline-2 focus:-outline-offset-2 focus:outline-indigo-600 sm:text-sm/6", className)} {...props} />
|
||||
)
|
||||
}
|
||||
8
cognee-frontend/src/ui/elements/NeutralButton.tsx
Normal file
8
cognee-frontend/src/ui/elements/NeutralButton.tsx
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
import classNames from 'classnames';
|
||||
import { ButtonHTMLAttributes } from "react";
|
||||
|
||||
export default function CTAButton({ children, className, ...props }: ButtonHTMLAttributes<HTMLButtonElement>) {
|
||||
return (
|
||||
<button className={classNames("flex flex-row justify-center items-center gap-2 cursor-pointer rounded-md bg-transparent px-3 py-2 text-sm font-semibold text-white shadow-xs border-1 border-white hover:bg-gray-400 focus-visible:outline-2 focus-visible:outline-offset-2 focus-visible:outline-indigo-600", className)} {...props}>{children}</button>
|
||||
);
|
||||
}
|
||||
10
cognee-frontend/src/ui/elements/Select.tsx
Normal file
10
cognee-frontend/src/ui/elements/Select.tsx
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
import classNames from "classnames";
|
||||
import { SelectHTMLAttributes } from "react";
|
||||
|
||||
export default function Select({ children, className, ...props }: SelectHTMLAttributes<HTMLSelectElement>) {
|
||||
return (
|
||||
<select className={classNames("block w-full appearance-none rounded-md bg-white py-1.5 pr-8 pl-3 text-base text-gray-900 outline-1 -outline-offset-1 outline-gray-300 focus:outline-2 focus:-outline-offset-2 focus:outline-indigo-600 sm:text-sm/6", className)} {...props}>
|
||||
{children}
|
||||
</select>
|
||||
);
|
||||
}
|
||||
22
cognee-frontend/src/ui/elements/StatusIndicator.tsx
Normal file
22
cognee-frontend/src/ui/elements/StatusIndicator.tsx
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
export default function StatusIndicator({ status }: { status: "DATASET_PROCESSING_COMPLETED" | string }) {
|
||||
const statusColor = {
|
||||
DATASET_PROCESSING_STARTED: "#ffd500",
|
||||
DATASET_PROCESSING_INITIATED: "#ffd500",
|
||||
DATASET_PROCESSING_COMPLETED: "#53ff24",
|
||||
DATASET_PROCESSING_ERRORED: "#ff5024",
|
||||
};
|
||||
|
||||
const isSuccess = status === "DATASET_PROCESSING_COMPLETED";
|
||||
|
||||
return (
|
||||
<div
|
||||
style={{
|
||||
width: "16px",
|
||||
height: "16px",
|
||||
borderRadius: "4px",
|
||||
background: statusColor[status as keyof typeof statusColor],
|
||||
}}
|
||||
title={isSuccess ? "Dataset cognified" : "Cognify data in order to explore it"}
|
||||
/>
|
||||
);
|
||||
}
|
||||
7
cognee-frontend/src/ui/elements/TextArea.tsx
Normal file
7
cognee-frontend/src/ui/elements/TextArea.tsx
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
import { InputHTMLAttributes } from "react"
|
||||
|
||||
export default function TextArea(props: InputHTMLAttributes<HTMLTextAreaElement>) {
|
||||
return (
|
||||
<textarea className="block w-full mt-2 rounded-md bg-white px-3 py-1.5 text-base text-gray-900 outline-1 -outline-offset-1 outline-gray-300 placeholder:text-gray-400 focus:outline-2 focus:-outline-offset-2 focus:outline-indigo-600 sm:text-sm/6" {...props} />
|
||||
)
|
||||
}
|
||||
6
cognee-frontend/src/ui/elements/index.ts
Normal file
6
cognee-frontend/src/ui/elements/index.ts
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
export { default as Input } from "./Input";
|
||||
export { default as Select } from "./Select";
|
||||
export { default as TextArea } from "./TextArea";
|
||||
export { default as CTAButton } from "./CTAButton";
|
||||
export { default as NeutralButton } from "./NeutralButton";
|
||||
export { default as StatusIndicator } from "./StatusIndicator";
|
||||
|
|
@ -1,2 +1,3 @@
|
|||
export { default as fetch } from './fetch';
|
||||
export { default as handleServerErrors } from './handleServerErrors';
|
||||
export { default as fetch } from "./fetch";
|
||||
export { default as handleServerErrors } from "./handleServerErrors";
|
||||
export { default as useBoolean } from "./useBoolean";
|
||||
|
|
|
|||
14
cognee-frontend/src/utils/useBoolean.ts
Normal file
14
cognee-frontend/src/utils/useBoolean.ts
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
import { useState } from "react";
|
||||
|
||||
export default function useBoolean(initialValue: boolean) {
|
||||
const [value, setValue] = useState(initialValue);
|
||||
|
||||
const setTrue = () => setValue(true);
|
||||
const setFalse = () => setValue(false);
|
||||
|
||||
return {
|
||||
value,
|
||||
setTrue,
|
||||
setFalse,
|
||||
};
|
||||
}
|
||||
|
|
@ -1,6 +1,10 @@
|
|||
{
|
||||
"compilerOptions": {
|
||||
"lib": ["dom", "dom.iterable", "esnext"],
|
||||
"lib": [
|
||||
"dom",
|
||||
"dom.iterable",
|
||||
"esnext"
|
||||
],
|
||||
"allowJs": true,
|
||||
"skipLibCheck": true,
|
||||
"strict": true,
|
||||
|
|
@ -18,9 +22,19 @@
|
|||
}
|
||||
],
|
||||
"paths": {
|
||||
"@/*": ["./src/*"]
|
||||
}
|
||||
"@/*": [
|
||||
"./src/*"
|
||||
]
|
||||
},
|
||||
"target": "ES2017"
|
||||
},
|
||||
"include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", ".next/types/**/*.ts"],
|
||||
"exclude": ["node_modules"]
|
||||
"include": [
|
||||
"next-env.d.ts",
|
||||
"**/*.ts",
|
||||
"**/*.tsx",
|
||||
".next/types/**/*.ts"
|
||||
],
|
||||
"exclude": [
|
||||
"node_modules"
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,11 +2,17 @@
|
|||
|
||||
import os
|
||||
import uvicorn
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
import sentry_sdk
|
||||
from traceback import format_exc
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import Request
|
||||
from fastapi import FastAPI, status
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.api.v1.permissions.routers import get_permissions_router
|
||||
from cognee.api.v1.settings.routers import get_settings_router
|
||||
from cognee.api.v1.datasets.routers import get_datasets_router
|
||||
|
|
@ -15,11 +21,8 @@ from cognee.api.v1.search.routers import get_search_router
|
|||
from cognee.api.v1.add.routers import get_add_router
|
||||
from cognee.api.v1.delete.routers import get_delete_router
|
||||
from cognee.api.v1.responses.routers import get_responses_router
|
||||
from fastapi import Request
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from cognee.api.v1.crewai.routers import get_crewai_router
|
||||
from cognee.exceptions import CogneeApiError
|
||||
from traceback import format_exc
|
||||
from cognee.api.v1.users.routers import (
|
||||
get_auth_router,
|
||||
get_register_router,
|
||||
|
|
@ -28,7 +31,6 @@ from cognee.api.v1.users.routers import (
|
|||
get_users_router,
|
||||
get_visualize_router,
|
||||
)
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
|
@ -45,9 +47,10 @@ app_environment = os.getenv("ENV", "prod")
|
|||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# from cognee.modules.data.deletion import prune_system, prune_data
|
||||
# await prune_data()
|
||||
# await prune_system(metadata = True)
|
||||
from cognee.modules.data.deletion import prune_system, prune_data
|
||||
|
||||
await prune_data()
|
||||
await prune_system(metadata=True)
|
||||
# if app_environment == "local" or app_environment == "dev":
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
|
||||
|
|
@ -170,6 +173,8 @@ app.include_router(get_delete_router(), prefix="/api/v1/delete", tags=["delete"]
|
|||
|
||||
app.include_router(get_responses_router(), prefix="/api/v1/responses", tags=["responses"])
|
||||
|
||||
app.include_router(get_crewai_router(), prefix="/api/v1/crewai", tags=["crewai"])
|
||||
|
||||
codegraph_routes = get_code_pipeline_router()
|
||||
if codegraph_routes:
|
||||
app.include_router(codegraph_routes, prefix="/api/v1/code-pipeline", tags=["code-pipeline"])
|
||||
|
|
@ -185,7 +190,7 @@ def start_api_server(host: str = "0.0.0.0", port: int = 8000):
|
|||
try:
|
||||
logger.info("Starting server at %s:%s", host, port)
|
||||
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
uvicorn.run(app, host=host, port=port, loop="asyncio")
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to start server: {e}")
|
||||
# Here you could add any cleanup code or error recovery code.
|
||||
|
|
|
|||
|
|
@ -14,6 +14,11 @@ async def add(
|
|||
):
|
||||
tasks = [Task(resolve_data_directories), Task(ingest_data, dataset_name, user, node_set)]
|
||||
|
||||
await cognee_pipeline(
|
||||
pipeline_run_info = None
|
||||
|
||||
async for run_info in cognee_pipeline(
|
||||
tasks=tasks, datasets=dataset_name, data=data, user=user, pipeline_name="add_pipeline"
|
||||
)
|
||||
):
|
||||
pipeline_run_info = run_info
|
||||
|
||||
return pipeline_run_info
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ logger = get_logger()
|
|||
def get_add_router() -> APIRouter:
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/", response_model=None)
|
||||
@router.post("/", response_model=dict)
|
||||
async def add(
|
||||
data: List[UploadFile],
|
||||
datasetId: Optional[UUID] = Form(default=None),
|
||||
|
|
@ -56,7 +56,9 @@ def get_add_router() -> APIRouter:
|
|||
|
||||
return await cognee_add(file_data)
|
||||
else:
|
||||
await cognee_add(data, datasetName, user=user)
|
||||
add_run = await cognee_add(data, datasetName, user=user)
|
||||
|
||||
return add_run.model_dump()
|
||||
except Exception as error:
|
||||
return JSONResponse(status_code=409, content={"error": str(error)})
|
||||
|
||||
|
|
|
|||
|
|
@ -1,13 +1,21 @@
|
|||
import asyncio
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from typing import Union, Optional
|
||||
from pydantic import BaseModel
|
||||
from typing import Union, Optional
|
||||
|
||||
from cognee.infrastructure.llm import get_max_chunk_tokens
|
||||
from cognee.modules.ontology.rdf_xml.OntologyResolver import OntologyResolver
|
||||
from cognee.modules.pipelines.tasks.task import Task
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
from cognee.infrastructure.llm import get_max_chunk_tokens
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.pipelines import cognee_pipeline
|
||||
from cognee.modules.pipelines.tasks.task import Task
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from cognee.modules.ontology.rdf_xml.OntologyResolver import OntologyResolver
|
||||
from cognee.modules.pipelines.models.PipelineRunInfo import PipelineRunCompleted, PipelineRunStarted
|
||||
from cognee.modules.pipelines.queues.pipeline_run_info_queues import push_to_queue
|
||||
from cognee.modules.graph.operations import get_formatted_graph_data
|
||||
from cognee.modules.crewai.get_crewai_pipeline_run_id import get_crewai_pipeline_run_id
|
||||
|
||||
from cognee.tasks.documents import (
|
||||
check_permissions_on_documents,
|
||||
classify_documents,
|
||||
|
|
@ -16,8 +24,6 @@ from cognee.tasks.documents import (
|
|||
from cognee.tasks.graph import extract_graph_from_data
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.tasks.summarization import summarize_text
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from cognee.modules.pipelines import cognee_pipeline
|
||||
|
||||
logger = get_logger("cognify")
|
||||
|
||||
|
|
@ -31,13 +37,65 @@ async def cognify(
|
|||
chunker=TextChunker,
|
||||
chunk_size: int = None,
|
||||
ontology_file_path: Optional[str] = None,
|
||||
run_in_background: bool = False,
|
||||
is_stream_info_enabled: bool = False,
|
||||
):
|
||||
tasks = await get_default_tasks(user, graph_model, chunker, chunk_size, ontology_file_path)
|
||||
|
||||
return await cognee_pipeline(
|
||||
if not user:
|
||||
user = await get_default_user()
|
||||
|
||||
if run_in_background:
|
||||
return await run_cognify_as_background_process(tasks, user, datasets)
|
||||
else:
|
||||
return await run_cognify_blocking(tasks, user, datasets, is_stream_info_enabled)
|
||||
|
||||
|
||||
async def run_cognify_blocking(tasks, user, datasets, is_stream_info_enabled=False):
|
||||
pipeline_run_info = None
|
||||
|
||||
async for run_info in cognee_pipeline(
|
||||
tasks=tasks, datasets=datasets, user=user, pipeline_name="cognify_pipeline"
|
||||
):
|
||||
pipeline_run_info = run_info
|
||||
|
||||
if (
|
||||
is_stream_info_enabled
|
||||
and not isinstance(pipeline_run_info, PipelineRunStarted)
|
||||
and not isinstance(pipeline_run_info, PipelineRunCompleted)
|
||||
):
|
||||
pipeline_run_id = get_crewai_pipeline_run_id(user.id)
|
||||
pipeline_run_info.payload = await get_formatted_graph_data()
|
||||
push_to_queue(pipeline_run_id, pipeline_run_info)
|
||||
|
||||
return pipeline_run_info
|
||||
|
||||
|
||||
async def run_cognify_as_background_process(tasks, user, datasets):
|
||||
pipeline_run = cognee_pipeline(
|
||||
tasks=tasks, user=user, datasets=datasets, pipeline_name="cognify_pipeline"
|
||||
)
|
||||
|
||||
pipeline_run_started_info = await anext(pipeline_run)
|
||||
|
||||
async def handle_rest_of_the_run():
|
||||
while True:
|
||||
try:
|
||||
pipeline_run_info = await anext(pipeline_run)
|
||||
|
||||
pipeline_run_info.payload = await get_formatted_graph_data()
|
||||
|
||||
push_to_queue(pipeline_run_info.pipeline_run_id, pipeline_run_info)
|
||||
|
||||
if isinstance(pipeline_run_info, PipelineRunCompleted):
|
||||
break
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
asyncio.create_task(handle_rest_of_the_run())
|
||||
|
||||
return pipeline_run_started_info
|
||||
|
||||
|
||||
async def get_default_tasks( # TODO: Find out a better way to do this (Boris's comment)
|
||||
user: User = None,
|
||||
|
|
|
|||
|
|
@ -1,11 +1,23 @@
|
|||
from typing import List, Optional
|
||||
import asyncio
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
from fastapi import Depends
|
||||
from fastapi import APIRouter
|
||||
from typing import List, Optional
|
||||
from starlette.status import WS_1000_NORMAL_CLOSURE, WS_1008_POLICY_VIOLATION
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi import APIRouter, WebSocket, Depends, WebSocketDisconnect
|
||||
|
||||
from cognee.modules.graph.utils import deduplicate_nodes_and_edges, get_graph_from_model
|
||||
from cognee.modules.storage.utils import JSONEncoder
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_authenticated_user
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
from cognee.modules.users.methods import get_authenticated_user
|
||||
from cognee.modules.pipelines.models.PipelineRunInfo import PipelineRunCompleted, PipelineRunInfo
|
||||
from cognee.modules.pipelines.queues.pipeline_run_info_queues import (
|
||||
get_from_queue,
|
||||
initialize_queue,
|
||||
remove_queue,
|
||||
)
|
||||
|
||||
|
||||
class CognifyPayloadDTO(BaseModel):
|
||||
|
|
@ -22,8 +34,107 @@ def get_cognify_router() -> APIRouter:
|
|||
from cognee.api.v1.cognify import cognify as cognee_cognify
|
||||
|
||||
try:
|
||||
await cognee_cognify(payload.datasets, user, payload.graph_model)
|
||||
cognify_run = await cognee_cognify(
|
||||
payload.datasets, user, payload.graph_model, run_in_background=True
|
||||
)
|
||||
|
||||
return cognify_run.model_dump()
|
||||
except Exception as error:
|
||||
return JSONResponse(status_code=409, content={"error": str(error)})
|
||||
|
||||
@router.websocket("/subscribe/{pipeline_run_id}")
|
||||
async def subscribe_to_cognify_info(websocket: WebSocket, pipeline_run_id: str):
|
||||
await websocket.accept()
|
||||
|
||||
auth_message = await websocket.receive_json()
|
||||
|
||||
try:
|
||||
await get_authenticated_user(auth_message.get("Authorization"))
|
||||
except Exception:
|
||||
await websocket.close(code=WS_1008_POLICY_VIOLATION, reason="Unauthorized")
|
||||
return
|
||||
|
||||
pipeline_run_id = UUID(pipeline_run_id)
|
||||
|
||||
initialize_queue(pipeline_run_id)
|
||||
|
||||
while True:
|
||||
pipeline_run_info = get_from_queue(pipeline_run_id)
|
||||
|
||||
if not pipeline_run_info:
|
||||
await asyncio.sleep(2)
|
||||
continue
|
||||
|
||||
if not isinstance(pipeline_run_info, PipelineRunInfo):
|
||||
continue
|
||||
|
||||
try:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"pipeline_run_id": str(pipeline_run_info.pipeline_run_id),
|
||||
"status": pipeline_run_info.status,
|
||||
"payload": await get_nodes_and_edges(pipeline_run_info.payload)
|
||||
if pipeline_run_info.payload
|
||||
else None,
|
||||
}
|
||||
)
|
||||
|
||||
if isinstance(pipeline_run_info, PipelineRunCompleted):
|
||||
remove_queue(pipeline_run_id)
|
||||
await websocket.close(code=WS_1000_NORMAL_CLOSURE)
|
||||
break
|
||||
except WebSocketDisconnect:
|
||||
remove_queue(pipeline_run_id)
|
||||
break
|
||||
|
||||
return router
|
||||
|
||||
|
||||
async def get_nodes_and_edges(data_points):
|
||||
nodes = []
|
||||
edges = []
|
||||
|
||||
added_nodes = {}
|
||||
added_edges = {}
|
||||
visited_properties = {}
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[
|
||||
get_graph_from_model(
|
||||
data_point,
|
||||
added_nodes=added_nodes,
|
||||
added_edges=added_edges,
|
||||
visited_properties=visited_properties,
|
||||
)
|
||||
for data_point in data_points
|
||||
]
|
||||
)
|
||||
|
||||
for result_nodes, result_edges in results:
|
||||
nodes.extend(result_nodes)
|
||||
edges.extend(result_edges)
|
||||
|
||||
nodes, edges = deduplicate_nodes_and_edges(nodes, edges)
|
||||
|
||||
return {
|
||||
"nodes": list(
|
||||
map(
|
||||
lambda node: {
|
||||
"id": str(node.id),
|
||||
"label": node.name if hasattr(node, "name") else f"{node.type}_{str(node.id)}",
|
||||
"properties": {},
|
||||
},
|
||||
nodes,
|
||||
)
|
||||
),
|
||||
"edges": list(
|
||||
map(
|
||||
lambda edge: {
|
||||
"source": str(edge[0]),
|
||||
"target": str(edge[1]),
|
||||
"label": edge[2],
|
||||
},
|
||||
edges,
|
||||
)
|
||||
),
|
||||
}
|
||||
|
|
|
|||
1
cognee/api/v1/crewai/routers/__init__.py
Normal file
1
cognee/api/v1/crewai/routers/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
from .get_crewai_router import get_crewai_router
|
||||
117
cognee/api/v1/crewai/routers/get_crewai_router.py
Normal file
117
cognee/api/v1/crewai/routers/get_crewai_router.py
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
import os
|
||||
import asyncio
|
||||
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
|
||||
from starlette.status import WS_1000_NORMAL_CLOSURE, WS_1008_POLICY_VIOLATION
|
||||
|
||||
from cognee.api.DTO import InDTO
|
||||
from cognee.complex_demos.crewai_demo.src.crewai_demo.github_ingest_datapoints import (
|
||||
cognify_github_data_from_username,
|
||||
)
|
||||
from cognee.modules.crewai.get_crewai_pipeline_run_id import get_crewai_pipeline_run_id
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_authenticated_user
|
||||
from cognee.modules.pipelines.models import PipelineRunInfo, PipelineRunCompleted
|
||||
from cognee.complex_demos.crewai_demo.src.crewai_demo.main import (
|
||||
# run_github_ingestion,
|
||||
run_hiring_crew,
|
||||
)
|
||||
from cognee.modules.pipelines.queues.pipeline_run_info_queues import (
|
||||
get_from_queue,
|
||||
initialize_queue,
|
||||
remove_queue,
|
||||
)
|
||||
|
||||
|
||||
class CrewAIRunPayloadDTO(InDTO):
|
||||
username1: str
|
||||
username2: str
|
||||
|
||||
|
||||
class CrewAIFeedbackPayloadDTO(InDTO):
|
||||
feedback: str
|
||||
|
||||
|
||||
def get_crewai_router() -> APIRouter:
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/run", response_model=bool)
|
||||
async def run_crewai(
|
||||
payload: CrewAIRunPayloadDTO,
|
||||
user: User = Depends(get_authenticated_user),
|
||||
):
|
||||
# Run CrewAI with the provided usernames
|
||||
# run_future = run_github_ingestion(payload.username1, payload.username2)
|
||||
token = os.getenv("GITHUB_TOKEN")
|
||||
|
||||
await cognify_github_data_from_username(payload.username1, token)
|
||||
await cognify_github_data_from_username(payload.username2, token)
|
||||
|
||||
applicants = {
|
||||
"applicant_1": payload.username1,
|
||||
"applicant_2": payload.username2,
|
||||
}
|
||||
|
||||
run_hiring_crew(applicants=applicants, number_of_rounds=2)
|
||||
|
||||
return True
|
||||
|
||||
@router.post("/feedback", response_model=None)
|
||||
async def send_feedback(
|
||||
payload: CrewAIFeedbackPayloadDTO,
|
||||
user: User = Depends(
|
||||
get_authenticated_user,
|
||||
),
|
||||
):
|
||||
from cognee import add, cognify
|
||||
# from secrets import choice
|
||||
# from string import ascii_letters, digits
|
||||
|
||||
# hash6 = "".join(choice(ascii_letters + digits) for _ in range(6))
|
||||
dataset_name = "final_reports"
|
||||
await add(payload.feedback, node_set=["final_report"], dataset_name=dataset_name)
|
||||
await cognify(datasets=dataset_name, is_stream_info_enabled=True)
|
||||
|
||||
@router.websocket("/subscribe")
|
||||
async def subscribe_to_crewai_info(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
|
||||
auth_message = await websocket.receive_json()
|
||||
|
||||
try:
|
||||
user = await get_authenticated_user(auth_message.get("Authorization"))
|
||||
except Exception:
|
||||
await websocket.close(code=WS_1008_POLICY_VIOLATION, reason="Unauthorized")
|
||||
return
|
||||
|
||||
pipeline_run_id = get_crewai_pipeline_run_id(user.id)
|
||||
|
||||
initialize_queue(pipeline_run_id)
|
||||
|
||||
while True:
|
||||
pipeline_run_info = get_from_queue(pipeline_run_id)
|
||||
|
||||
if not pipeline_run_info:
|
||||
await asyncio.sleep(2)
|
||||
continue
|
||||
|
||||
if not isinstance(pipeline_run_info, PipelineRunInfo):
|
||||
continue
|
||||
|
||||
try:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"pipeline_run_id": str(pipeline_run_info.pipeline_run_id),
|
||||
"status": pipeline_run_info.status,
|
||||
"payload": pipeline_run_info.payload if pipeline_run_info.payload else None,
|
||||
}
|
||||
)
|
||||
|
||||
if isinstance(pipeline_run_info, PipelineRunCompleted):
|
||||
remove_queue(pipeline_run_id)
|
||||
await websocket.close(code=WS_1000_NORMAL_CLOSURE)
|
||||
break
|
||||
except WebSocketDisconnect:
|
||||
remove_queue(pipeline_run_id)
|
||||
break
|
||||
|
||||
return router
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
from cognee.modules.graph.operations import get_formatted_graph_data
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from fastapi import APIRouter
|
||||
from datetime import datetime
|
||||
|
|
@ -39,6 +40,23 @@ class DataDTO(OutDTO):
|
|||
raw_data_location: str
|
||||
|
||||
|
||||
class GraphNodeDTO(OutDTO):
|
||||
id: UUID
|
||||
label: str
|
||||
properties: dict
|
||||
|
||||
|
||||
class GraphEdgeDTO(OutDTO):
|
||||
source: UUID
|
||||
target: UUID
|
||||
label: str
|
||||
|
||||
|
||||
class GraphDTO(OutDTO):
|
||||
nodes: List[GraphNodeDTO]
|
||||
edges: List[GraphEdgeDTO]
|
||||
|
||||
|
||||
def get_datasets_router() -> APIRouter:
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -94,24 +112,18 @@ def get_datasets_router() -> APIRouter:
|
|||
|
||||
await delete_data(data)
|
||||
|
||||
@router.get("/{dataset_id}/graph", response_model=str)
|
||||
@router.get("/{dataset_id}/graph", response_model=GraphDTO)
|
||||
async def get_dataset_graph(dataset_id: UUID, user: User = Depends(get_authenticated_user)):
|
||||
from cognee.shared.utils import render_graph
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
|
||||
try:
|
||||
graph_client = await get_graph_engine()
|
||||
graph_url = await render_graph(graph_client.graph)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content=str(graph_url),
|
||||
content=await get_formatted_graph_data(),
|
||||
)
|
||||
except Exception as error:
|
||||
print(error)
|
||||
return JSONResponse(
|
||||
status_code=409,
|
||||
content="Graphistry credentials are not set. Please set them in your .env file.",
|
||||
content="Error retrieving dataset graph data.",
|
||||
)
|
||||
|
||||
@router.get(
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
from typing import Union
|
||||
from typing import Union, Optional, Type, List
|
||||
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.modules.users.exceptions import UserNotFoundError
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.search.methods import search as search_function
|
||||
|
||||
|
|
@ -13,6 +15,8 @@ async def search(
|
|||
datasets: Union[list[str], str, None] = None,
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
top_k: int = 10,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: List[Optional[str]] = None,
|
||||
) -> list:
|
||||
# We use lists from now on for datasets
|
||||
if isinstance(datasets, str):
|
||||
|
|
@ -28,6 +32,8 @@ async def search(
|
|||
user,
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
)
|
||||
|
||||
return filtered_search_results
|
||||
|
|
|
|||
0
cognee/complex_demos/__init__.py
Normal file
0
cognee/complex_demos/__init__.py
Normal file
49
cognee/complex_demos/crewai_demo/README
Normal file
49
cognee/complex_demos/crewai_demo/README
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
# CrewAI
|
||||
|
||||
This is a demo project to showcase and test how cognee and CrewAI can work together:
|
||||
|
||||
Short description:
|
||||
|
||||
We simulate the hiring process for a technical role. These are the steps of the pipeline:
|
||||
|
||||
1. First we ingest github data including:
|
||||
-commits, comments and other soft skill related information for each of the candidates.
|
||||
-source code and other technical skill related information for each of the candidates.
|
||||
|
||||
2. We hire 3 agents to make the decision using cognee's memory engine
|
||||
|
||||
1 - HR Expert Agent focusing on soft skills:
|
||||
- Analyzes the communication skills, clarity, engagement a kindness based on the commits, comments and github communication of the candidates.
|
||||
-To analyze the soft skills of the candidates, the agent performs multiple searches using cognee.search
|
||||
-The subgraph that the agent can use is limited to the "soft" nodeset subgraph
|
||||
- Scores each candidate from 0 to 1 and gives reasoning
|
||||
|
||||
2 - Technical Expert Agent focusing on technical skills:
|
||||
- Analyzes strictly code related and technical skills based on github commits and pull requests of the candidates.
|
||||
- To analyze the technical skills of the candidates, the agent performs multiple searches using cognee.search
|
||||
- The subgraph that the agent can use is limited to the "techical" nodeset subgraph
|
||||
- Scores each candidate from 0 to 1 and gives reasoning
|
||||
|
||||
3 - CEO/CTO agent who makes the final decision:
|
||||
- Given the output of the HR expert and Technical expert agents, the decision maker agent makes the final decision about the hiring procedure.
|
||||
- The agent will choose the best candidate to hire, and will give reasoning for each of the candidates (why hire/no_hire).
|
||||
|
||||
|
||||
The following tools were implemented:
|
||||
- Cognee build: cognifies the added data (Preliminary task, therefore it is not performed by agents.)
|
||||
- Cognee search: searches the cognee memory, limiting the subgraph using the nodeset subgraph retriever (Used by many agents)
|
||||
- In the case of technical and soft skills agents the tool gets instantiated with the restricted nodeset search capability
|
||||
|
||||
|
||||
The three agents are working together to simulate a hiring process, evaluating soft and technical skills, while the CEO/CTO agent
|
||||
makes the final decision (HIRE/NOHIRE) based on the outputs of the evaluation agents.
|
||||
|
||||
|
||||
## Run in UI
|
||||
|
||||
Note1: After each restart go to `localhost:3000/auth` and login again.
|
||||
Note2: Activity is not preserved in the DB, so it will be lost after page refresh.
|
||||
|
||||
1. Start FastAPI server by running `client.py` inside `cognee/api` directory
|
||||
2. Start NextJS server by running `npm run dev` inside `cognee-frontend` directory.
|
||||
3. If you are not logged-in, app will redirect to `/auth` page. Otherwise go there manually and login (if server is restarted).
|
||||
0
cognee/complex_demos/crewai_demo/__init__.py
Normal file
0
cognee/complex_demos/crewai_demo/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
User name is John Doe.
|
||||
User is an AI Engineer.
|
||||
User is interested in AI Agents.
|
||||
User is based in San Francisco, California.
|
||||
19
cognee/complex_demos/crewai_demo/pyproject.toml
Normal file
19
cognee/complex_demos/crewai_demo/pyproject.toml
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
[project]
|
||||
name = "crewai_demo"
|
||||
version = "0.1.0"
|
||||
description = "Cognee crewAI demo"
|
||||
authors = [{ name = "Laszlo Hajdu", email = "laszlo@topoteretes.com" }]
|
||||
requires-python = ">=3.10,<3.13"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.114.0,<1.0.0"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
run_crew = "association_layer_demo.main:run"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.crewai]
|
||||
type = "crew"
|
||||
0
cognee/complex_demos/crewai_demo/src/__init__.py
Normal file
0
cognee/complex_demos/crewai_demo/src/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
from .github_dev_profile import GitHubDevProfile
|
||||
from .github_dev_comments import GitHubDevComments
|
||||
from .github_dev_commits import GitHubDevCommits
|
||||
|
||||
__all__ = ["GitHubDevProfile", "GitHubDevComments", "GitHubDevCommits"]
|
||||
|
|
@ -0,0 +1,32 @@
|
|||
soft_skills_expert_agent:
|
||||
role: >
|
||||
Focused on communication, collaboration, and documentation excellence.
|
||||
goal: >
|
||||
Evaluate README clarity, issue discussions, and community engagement to score
|
||||
communication clarity and open-source culture participation.
|
||||
backstory: >
|
||||
You are an active OSS community manager who values clear writing, inclusive
|
||||
discussion, and strong documentation. You look for evidence of empathy,
|
||||
responsiveness, and collaborative spirit.
|
||||
|
||||
technical_expert_agent:
|
||||
role: >
|
||||
Specialized in evaluating technical skills and code quality.
|
||||
goal: >
|
||||
Analyze repository metadata and commit histories to score coding diversity,
|
||||
depth of contributions, and commit quality.
|
||||
backstory: >
|
||||
You are a seasoned software architect and open-source maintainer. You deeply
|
||||
understand python code structure, language ecosystems, and best practices.
|
||||
Your mission is to objectively rate each candidate’s technical excellence.
|
||||
|
||||
decision_maker_agent:
|
||||
role: >
|
||||
CTO/CEO-level decision maker who integrates expert feedback.
|
||||
goal: >
|
||||
Read the technical and soft-skills evaluations and decide whether to hire
|
||||
each candidate, justifying the decision.
|
||||
backstory: >
|
||||
You are the company’s CTO. You balance technical requirements, team culture,
|
||||
and long-term vision. You weigh both skill scores and communication ratings
|
||||
to make a final hire/no-hire call.
|
||||
|
|
@ -0,0 +1,153 @@
|
|||
soft_skills_assessment_applicant1_task:
|
||||
description: >
|
||||
Search cognee for comments authored by '{applicant_1}'.
|
||||
Use the "search_from_cognee" tool to collect information.
|
||||
Evaluate their communication clarity, community engagement, and kindness.
|
||||
Ask multiple questions if needed to uncover diverse interactions.
|
||||
Return a complete and reasoned assessment of their soft skills.
|
||||
|
||||
--- Example Output ---
|
||||
Input:
|
||||
applicant_1: Sarah Jennings
|
||||
|
||||
Output:
|
||||
- Name: Sarah Jennings
|
||||
- communication_clarity: 0.92
|
||||
- community_engagement: 0.88
|
||||
- kindness: 0.95
|
||||
- reasoning: >
|
||||
Sarah consistently communicates with clarity and structure. In several threads, her responses broke down complex issues into actionable steps,
|
||||
showing strong explanatory skills. She uses inclusive language like “let’s”, “we should”, and frequently thanks others for their input,
|
||||
which indicates a high degree of kindness. Sarah also initiates or joins collaborative threads, offering feedback or connecting people with
|
||||
relevant documentation. Her tone is encouraging and non-defensive, even when correcting misunderstandings. These patterns were observed across
|
||||
over 8 threads involving different team members over a 3-week span.
|
||||
|
||||
expected_output: >
|
||||
- Name: {applicant_1}
|
||||
- communication_clarity (0–1)
|
||||
- community_engagement (0–1)
|
||||
- kindness (0–1)
|
||||
- reasoning: (string)
|
||||
agent: soft_skills_expert_agent
|
||||
|
||||
soft_skills_assessment_applicant2_task:
|
||||
description: >
|
||||
Search cognee for comments authored by '{applicant_2}'.
|
||||
Use the "search_from_cognee" tool to collect information.
|
||||
Evaluate their communication clarity, community engagement, and kindness.
|
||||
Ask multiple questions if needed to uncover diverse interactions.
|
||||
Return a complete and reasoned assessment of their soft skills.
|
||||
|
||||
--- Example Output ---
|
||||
Input:
|
||||
applicant_1: Sarah Jennings
|
||||
|
||||
Output:
|
||||
- Name: Sarah Jennings
|
||||
- communication_clarity: 0.92
|
||||
- community_engagement: 0.88
|
||||
- kindness: 0.95
|
||||
- reasoning: >
|
||||
Sarah consistently communicates with clarity and structure. In several threads, her responses broke down complex issues into actionable steps,
|
||||
showing strong explanatory skills. She uses inclusive language like “let’s”, “we should”, and frequently thanks others for their input,
|
||||
which indicates a high degree of kindness. Sarah also initiates or joins collaborative threads, offering feedback or connecting people with
|
||||
relevant documentation. Her tone is encouraging and non-defensive, even when correcting misunderstandings. These patterns were observed across
|
||||
over 8 threads involving different team members over a 3-week span.
|
||||
|
||||
expected_output: >
|
||||
- Name: {applicant_2}
|
||||
- communication_clarity (0–1)
|
||||
- community_engagement (0–1)
|
||||
- kindness (0–1)
|
||||
- reasoning: (string)
|
||||
agent: soft_skills_expert_agent
|
||||
|
||||
technical_assessment_applicant1_task:
|
||||
description: >
|
||||
Analyze the repository metadata and commit history associated with '{applicant_1}'.
|
||||
Use the "search_from_cognee" tool to collect information.
|
||||
Score their code_diversity, depth_of_contribution, and commit_quality.
|
||||
Base your assessment strictly on technical input—ignore soft skills.
|
||||
|
||||
--- Example Output ---
|
||||
Input:
|
||||
applicant_1: Daniel Murphy
|
||||
|
||||
Output:
|
||||
- Name: Daniel Murphy
|
||||
- code_diversity: 0.87
|
||||
- depth_of_contribution: 0.91
|
||||
- commit_quality: 0.83
|
||||
- reasoning: >
|
||||
Daniel contributed to multiple areas of the codebase including frontend UI components, backend API endpoints, test coverage,
|
||||
and CI/CD configuration. His commit history spans over 6 weeks with consistent activity and includes thoughtful messages
|
||||
(e.g., “refactor auth flow to support multi-tenant login” or “add unit tests for pricing logic edge cases”).
|
||||
His pull requests often include both implementation and tests, showing technical completeness.
|
||||
Several commits show iterative problem-solving and cleanup after peer feedback, indicating thoughtful collaboration
|
||||
and improvement over time.
|
||||
expected_output: >
|
||||
- Name: {applicant_1}
|
||||
- code_diversity (0–1)
|
||||
- depth_of_contribution (0–1)
|
||||
- commit_quality (0–1)
|
||||
- reasoning: (string)
|
||||
agent: technical_expert_agent
|
||||
|
||||
technical_assessment_applicant2_task:
|
||||
description: >
|
||||
Analyze the repository metadata and commit history associated with '{applicant_2}'.
|
||||
Use the "search_from_cognee" tool to collect information.
|
||||
Score their code_diversity, depth_of_contribution, and commit_quality.
|
||||
Base your assessment strictly on technical input—ignore soft skills.
|
||||
|
||||
--- Example Output ---
|
||||
Input:
|
||||
applicant_1: Daniel Murphy
|
||||
|
||||
Output:
|
||||
- Name: Daniel Murphy
|
||||
- code_diversity: 0.87
|
||||
- depth_of_contribution: 0.91
|
||||
- commit_quality: 0.83
|
||||
- reasoning: >
|
||||
Daniel contributed to multiple areas of the codebase including frontend UI components, backend API endpoints, test coverage,
|
||||
and CI/CD configuration. His commit history spans over 6 weeks with consistent activity and includes thoughtful messages
|
||||
(e.g., “refactor auth flow to support multi-tenant login” or “add unit tests for pricing logic edge cases”).
|
||||
His pull requests often include both implementation and tests, showing technical completeness.
|
||||
Several commits show iterative problem-solving and cleanup after peer feedback, indicating thoughtful collaboration
|
||||
and improvement over time.
|
||||
|
||||
expected_output: >
|
||||
- Name: {applicant_2}
|
||||
- code_diversity (0–1)
|
||||
- depth_of_contribution (0–1)
|
||||
- commit_quality (0–1)
|
||||
- reasoning: (string)
|
||||
agent: technical_expert_agent
|
||||
|
||||
hiring_decision_task:
|
||||
description: >
|
||||
Review the technical and soft skill assessment task outputs for candidates: -{applicant_1} and -{applicant_2},
|
||||
then decide HIRE or NO_HIRE for each candidate with a detailed reasoning.
|
||||
The people to evaluate are:
|
||||
-{applicant_1}
|
||||
-{applicant_2}
|
||||
We have to hire one of them.
|
||||
|
||||
Prepare the final output for the ingest_hiring_decision_task.
|
||||
|
||||
|
||||
expected_output: >
|
||||
A string strictly containing the following for each person:
|
||||
- Person
|
||||
- decision: "HIRE" or "NO_HIRE",
|
||||
- reasoning: (string)
|
||||
agent: decision_maker_agent
|
||||
|
||||
ingest_hiring_decision_task:
|
||||
description: >
|
||||
Take the final hiring decision from the hiring_decision_task report and ingest it into Cognee using the "ingest_report_to_cognee" tool.
|
||||
Do not re-evaluate—just save the result using the tool you have.
|
||||
expected_output: >
|
||||
- confirmation: string message confirming successful ingestion into Cognee
|
||||
agent: decision_maker_agent
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
from crewai.tools import BaseTool
|
||||
|
||||
|
||||
class CogneeBuild(BaseTool):
    """CrewAI tool that resets cognee state and builds a knowledge graph from raw text."""

    name: str = "Cognee Build"
    description: str = "Creates a memory and builds a knowledge graph using cognee."

    def _run(self, inputs) -> str:
        """Prunes existing cognee data, ingests every entry of `inputs`, and cognifies.

        Args:
            inputs: mapping whose values are dicts with "file_content" (text to
                ingest) and "nodeset" (node-set tag attached to that text).

        Returns:
            A status string — success message, or an "Error: ..." /
            "Tool execution error: ..." string (exceptions are never raised).
        """
        import cognee
        import asyncio

        async def main():
            try:
                # Start from a clean slate: drop previous data and system metadata.
                await cognee.prune.prune_data()
                await cognee.prune.prune_system(metadata=True)

                for meta in inputs.values():
                    text = meta["file_content"]
                    node_set = meta["nodeset"]
                    await cognee.add(text, node_set=node_set)

                await cognee.cognify(is_stream_info_enabled=True)

                return "Knowledge Graph is done."
            except Exception as e:
                return f"Error: {str(e)}"

        try:
            loop = asyncio.get_event_loop()

            if not loop.is_running():
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)

            # BUG FIX: the original returned loop.create_task(main()). On a fresh,
            # non-running loop that task never executes, and the caller receives a
            # Task object instead of the declared str. Run the coroutine to
            # completion and return its string result instead.
            return loop.run_until_complete(main())
        except Exception as e:
            return f"Tool execution error: {str(e)}"
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
import asyncio
|
||||
import nest_asyncio
|
||||
from crewai.tools import BaseTool
|
||||
from typing import Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CogneeIngestionInput(BaseModel):
    # Argument schema for the ingestion tool; the description doubles as
    # formatting guidance for the LLM that fills in the arguments.
    text: str = Field(
        "",
        description="The text of the report The format you should follow is {'text': 'your report'}",
    )
|
||||
|
||||
|
||||
class CogneeIngestion(BaseTool):
    """CrewAI tool that persists a final hiring report into cognee memory."""

    name: str = "ingest_report_to_cognee"
    description: str = "This tool can be used to ingest the final hiring report into cognee"
    args_schema: Type[BaseModel] = CogneeIngestionInput
    # Node-set tag applied to every ingested report; supplied per instance.
    _nodeset_name: str

    def __init__(self, nodeset_name: str, **kwargs):
        super().__init__(**kwargs)
        self._nodeset_name = nodeset_name

    def _run(self, text: str) -> str:
        """Adds `text` to the "final_reports" dataset and runs cognify over it.

        Returns a human-readable status string; errors are caught and returned
        as strings rather than raised, so the calling agent always gets a result.
        """
        import cognee
        # from secrets import choice
        # from string import ascii_letters, digits

        async def main():
            try:
                # hash6 = "".join(choice(ascii_letters + digits) for _ in range(6))
                dataset_name = "final_reports"
                await cognee.add(text, node_set=[self._nodeset_name], dataset_name=dataset_name)
                await cognee.cognify(datasets=dataset_name, is_stream_info_enabled=True)

                return "Report ingested successfully into Cognee memory."
            except Exception as e:
                return f"Error during ingestion: {str(e)}"

        try:
            loop = asyncio.get_event_loop()

            if not loop.is_running():
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)

            # nest_asyncio allows run_until_complete even when a loop is already
            # running (e.g. inside the host framework's own event loop).
            nest_asyncio.apply(loop)

            result = loop.run_until_complete(main())

            return result
        except Exception as e:
            return f"Tool execution error: {str(e)}"
|
||||
|
|
@ -0,0 +1,63 @@
|
|||
import nest_asyncio
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from typing import Type
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
from cognee.modules.engine.models import NodeSet
|
||||
|
||||
|
||||
class CogneeSearchInput(BaseModel):
    # Argument schema for the search tool; the description doubles as
    # formatting guidance for the LLM that fills in the arguments.
    query: str = Field(
        "",
        description="The natural language question to ask the memory engine."
        "The format you should follow is {'query': 'your query'}",
    )
|
||||
|
||||
|
||||
class CogneeSearch(BaseTool):
    """CrewAI tool that answers natural-language queries against the cognee graph."""

    name: str = "search_from_cognee"
    # FIX: "files changes" -> "file changes" (typo in the LLM-facing description).
    description: str = (
        "Use this tool to search the Cognee memory graph. "
        "Provide a natural language query that describes the information you want to retrieve, "
        "such as comments authored or file changes by a specific person."
    )
    args_schema: Type[BaseModel] = CogneeSearchInput
    # Node-set the retriever is scoped to; supplied per instance.
    _nodeset_name: str = PrivateAttr()

    def __init__(self, nodeset_name: str, **kwargs):
        super().__init__(**kwargs)
        self._nodeset_name = nodeset_name

    def _run(self, query: str) -> str:
        """Runs a graph-completion retrieval scoped to this tool's node set.

        Returns the retriever's context on success, or an error string;
        exceptions are never raised so the calling agent always gets a result.
        (A stray debug `print(query)` from the original was removed.)
        """
        import asyncio
        from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever

        async def main():
            try:
                search_results = await GraphCompletionRetriever(
                    top_k=5,
                    node_type=NodeSet,
                    node_name=[self._nodeset_name],
                ).get_context(query=query)

                return search_results
            except Exception as e:
                return f"Error: {str(e)}"

        try:
            loop = asyncio.get_event_loop()

            if not loop.is_running():
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)

            # nest_asyncio permits run_until_complete inside an already-running loop.
            nest_asyncio.apply(loop)

            result = loop.run_until_complete(main())

            return result
        except Exception as e:
            return f"Tool execution error: {str(e)}"
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
from crewai.tools import BaseTool
|
||||
|
||||
from ..github_ingest_datapoints import cognify_github_data_from_username
|
||||
|
||||
|
||||
class GithubIngestion(BaseTool):
    """CrewAI tool that ingests two applicants' GitHub activity graphs into cognee."""

    name: str = "Github graph builder"
    description: str = "Ingests the github graph of a person into Cognee"

    def _run(self, applicant_1, applicant_2) -> str:
        """Ingests both applicants' GitHub data into cognee.

        Args:
            applicant_1: GitHub username of the first candidate.
            applicant_2: GitHub username of the second candidate.

        Returns:
            True on success (NOTE(review): declared return type is str — confirm
            callers tolerate the bool), or an "Error: ..." /
            "Tool execution error: ..." string on failure.
        """
        import asyncio

        # import cognee
        import os
        # from cognee.low_level import setup as cognee_setup

        async def main():
            try:
                # await cognee.prune.prune_data()
                # await cognee.prune.prune_system(metadata=True)
                # await cognee_setup()
                # Token for authenticated GitHub API calls, read from the environment.
                token = os.getenv("GITHUB_TOKEN")

                await cognify_github_data_from_username(applicant_1, token)
                await cognify_github_data_from_username(applicant_2, token)

                return True
            except Exception as e:
                return f"Error: {str(e)}"

        try:
            loop = asyncio.get_event_loop()

            if not loop.is_running():
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)

            # BUG FIX: the original returned loop.create_task(main()); on a fresh,
            # non-running loop the task never executes and the caller receives a
            # Task object. Run the coroutine to completion and return its result.
            return loop.run_until_complete(main())
        except Exception as e:
            return f"Tool execution error: {str(e)}"
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
from abc import ABC, abstractmethod
|
||||
import requests
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
GITHUB_API_URL = "https://api.github.com/graphql"
|
||||
|
||||
logger = get_logger("github_comments")
|
||||
|
||||
|
||||
class GitHubCommentBase(ABC):
    """Base class for GitHub comment providers.

    Subclasses supply the GraphQL query, the extraction of raw comment nodes
    from the response, and per-comment formatting; `get_comments` drives the
    whole pipeline (template-method pattern).
    """

    def __init__(self, token, username, limit=10):
        self.token = token
        self.username = username
        self.limit = limit

    def _run_query(self, query: str) -> dict:
        """Executes a GraphQL query against GitHub's API and returns its data payload."""
        auth_headers = {"Authorization": f"Bearer {self.token}"}
        response = requests.post(GITHUB_API_URL, json={"query": query}, headers=auth_headers)
        if response.status_code != 200:
            raise Exception(f"Query failed: {response.status_code} - {response.text}")
        return response.json()["data"]

    def get_comments(self):
        """Runs build -> query -> extract -> format; returns [] on any failure."""
        try:
            response_data = self._run_query(self._build_query())
            raw_entries = self._extract_comments(response_data)
            formatted = []
            for entry in raw_entries[: self.limit]:
                formatted.append(self._format_comment(entry))
            return formatted
        except Exception as e:
            logger.error(f"Error fetching {self._get_comment_type()} comments: {e}")
            return []

    @abstractmethod
    def _build_query(self) -> str:
        """Returns the GraphQL query string for this provider."""
        pass

    @abstractmethod
    def _extract_comments(self, data) -> list:
        """Pulls the raw comment nodes out of the GraphQL response."""
        pass

    @abstractmethod
    def _format_comment(self, item) -> dict:
        """Turns one raw comment node into the provider's output dict."""
        pass

    @abstractmethod
    def _get_comment_type(self) -> str:
        """Human-readable comment kind used in error messages."""
        pass
|
||||
|
|
@ -0,0 +1,298 @@
|
|||
from datetime import datetime, timedelta
|
||||
from cognee.complex_demos.crewai_demo.src.crewai_demo.github_comment_base import (
|
||||
GitHubCommentBase,
|
||||
logger,
|
||||
)
|
||||
|
||||
|
||||
class IssueCommentsProvider(GitHubCommentBase):
    """Provider for GitHub issue comments."""

    # GraphQL query: the user's most recently updated issue comments.
    # Double braces escape literal GraphQL braces for str.format().
    QUERY_TEMPLATE = """
    {{
      user(login: "{username}") {{
        issueComments(first: {limit}, orderBy: {{field: UPDATED_AT, direction: DESC}}) {{
          nodes {{
            body
            createdAt
            updatedAt
            url
            issue {{
              number
              title
              url
              repository {{
                nameWithOwner
              }}
              state
            }}
          }}
        }}
      }}
    }}
    """

    def _build_query(self) -> str:
        """Builds the GraphQL query for issue comments."""
        return self.QUERY_TEMPLATE.format(username=self.username, limit=self.limit)

    def _extract_comments(self, data) -> list:
        """Extracts issue comments from the GraphQL response."""
        return data["user"]["issueComments"]["nodes"]

    def _format_comment(self, comment) -> dict:
        """Formats an issue comment from GraphQL."""
        # The numeric comment id is not in the GraphQL payload; derive it from
        # the trailing path segment of the comment URL.
        comment_id = comment["url"].split("/")[-1] if comment["url"] else None

        return {
            "repo": comment["issue"]["repository"]["nameWithOwner"],
            "issue_number": comment["issue"]["number"],
            "comment_id": comment_id,
            "body": comment["body"],
            # "text" mirrors "body" — presumably for consumers expecting a
            # generic text field; confirm against downstream datapoint builders.
            "text": comment["body"],
            "created_at": comment["createdAt"],
            "updated_at": comment["updatedAt"],
            "html_url": comment["url"],
            "issue_url": comment["issue"]["url"],
            "author_association": "COMMENTER",
            "issue_title": comment["issue"]["title"],
            "issue_state": comment["issue"]["state"],
            "login": self.username,
            "type": "issue_comment",
        }

    def _get_comment_type(self) -> str:
        """Returns the comment type for error messages."""
        return "issue"
|
||||
|
||||
|
||||
class PrReviewsProvider(GitHubCommentBase):
    """Provider for GitHub PR reviews."""

    # GraphQL query: the user's pull-request review contributions.
    # Double braces escape literal GraphQL braces for str.format().
    QUERY_TEMPLATE = """
    {{
      user(login: "{username}") {{
        contributionsCollection {{
          pullRequestReviewContributions(first: {fetch_limit}) {{
            nodes {{
              pullRequestReview {{
                body
                createdAt
                updatedAt
                url
                state
                pullRequest {{
                  number
                  title
                  url
                  repository {{
                    nameWithOwner
                  }}
                  state
                }}
              }}
            }}
          }}
        }}
      }}
    }}
    """

    def __init__(self, token, username, limit=10, fetch_limit=None):
        """Initialize with token, username, and optional limits."""
        super().__init__(token, username, limit)
        # Over-fetch (10x by default): reviews with empty bodies are discarded
        # in _extract_comments, so more raw nodes are needed to fill `limit`.
        self.fetch_limit = fetch_limit if fetch_limit is not None else 10 * limit

    def _build_query(self) -> str:
        """Builds the GraphQL query for PR reviews."""
        return self.QUERY_TEMPLATE.format(username=self.username, fetch_limit=self.fetch_limit)

    def _extract_comments(self, data) -> list:
        """Extracts PR reviews from the GraphQL response, dropping empty-bodied ones."""
        contributions = data["user"]["contributionsCollection"]["pullRequestReviewContributions"][
            "nodes"
        ]
        return [
            node["pullRequestReview"] for node in contributions if node["pullRequestReview"]["body"]
        ]

    def _format_comment(self, review) -> dict:
        """Formats a PR review from GraphQL."""
        # Derive a review id from the trailing path segment of the review URL.
        review_id = review["url"].split("/")[-1] if review["url"] else None

        return {
            "repo": review["pullRequest"]["repository"]["nameWithOwner"],
            "issue_number": review["pullRequest"]["number"],
            "comment_id": review_id,
            "body": review["body"],
            # "text" mirrors "body" for consumers expecting a generic text field.
            "text": review["body"],
            "created_at": review["createdAt"],
            "updated_at": review["updatedAt"],
            "html_url": review["url"],
            "issue_url": review["pullRequest"]["url"],
            "author_association": "COMMENTER",
            "issue_title": review["pullRequest"]["title"],
            "issue_state": review["pullRequest"]["state"],
            "login": self.username,
            "review_state": review["state"],
            "type": "pr_review",
        }

    def _get_comment_type(self) -> str:
        """Returns the comment type for error messages."""
        return "PR review"
|
||||
|
||||
|
||||
class PrReviewCommentsProvider(GitHubCommentBase):
    """Provider for GitHub PR review comments (inline code comments).

    Works in two steps: first find PRs the user has reviewed, then fetch the
    user's inline review comments per PR with a second query.
    """

    # Step 1 query: PRs the user has contributed reviews to.
    PR_CONTRIBUTIONS_TEMPLATE = """
    {{
      user(login: "{username}") {{
        contributionsCollection {{
          pullRequestReviewContributions(first: {fetch_limit}) {{
            nodes {{
              pullRequestReview {{
                pullRequest {{
                  number
                  title
                  url
                  repository {{
                    nameWithOwner
                  }}
                  state
                }}
              }}
            }}
          }}
        }}
      }}
    }}
    """

    # Step 2 query: the user's inline review comments on one specific PR.
    PR_COMMENTS_TEMPLATE = """
    {{
      repository(owner: "{owner}", name: "{repo}") {{
        pullRequest(number: {pr_number}) {{
          reviews(first: {reviews_limit}, author: "{username}") {{
            nodes {{
              comments(first: {comments_limit}) {{
                nodes {{
                  body
                  createdAt
                  updatedAt
                  url
                }}
              }}
            }}
          }}
        }}
      }}
    }}
    """

    def __init__(
        self,
        token,
        username,
        limit=10,
        fetch_limit=None,
        reviews_limit=None,
        comments_limit=None,
        pr_limit=None,
    ):
        """Initialize with token, username, and optional limits.

        The defaults over-fetch relative to `limit` since many PRs/reviews
        yield few or no inline comments:
          fetch_limit    - review contributions pulled in step 1 (4x limit)
          reviews_limit  - reviews fetched per PR in step 2 (2x limit)
          comments_limit - inline comments fetched per review (3x limit)
          pr_limit       - max distinct PRs queried in step 2 (2x limit)
        """
        super().__init__(token, username, limit)
        self.fetch_limit = fetch_limit if fetch_limit is not None else 4 * limit
        self.reviews_limit = reviews_limit if reviews_limit is not None else 2 * limit
        self.comments_limit = comments_limit if comments_limit is not None else 3 * limit
        self.pr_limit = pr_limit if pr_limit is not None else 2 * limit

    def _build_query(self) -> str:
        """Builds the GraphQL query for PR contributions."""
        return self.PR_CONTRIBUTIONS_TEMPLATE.format(
            username=self.username, fetch_limit=self.fetch_limit
        )

    def _extract_comments(self, data) -> list:
        """Extracts PR review comments using a two-step approach."""
        prs = self._get_reviewed_prs(data)
        return self._fetch_comments_for_prs(prs)

    def _get_reviewed_prs(self, data) -> list:
        """Gets a deduplicated list of PRs the user has reviewed."""
        contributions = data["user"]["contributionsCollection"]["pullRequestReviewContributions"][
            "nodes"
        ]
        unique_prs = []

        # Deduplicate by PR URL; a user may have several review contributions
        # on the same PR.
        for node in contributions:
            pr = node["pullRequestReview"]["pullRequest"]
            if not any(existing_pr["url"] == pr["url"] for existing_pr in unique_prs):
                unique_prs.append(pr)

        return unique_prs[: min(self.pr_limit, len(unique_prs))]

    def _fetch_comments_for_prs(self, prs) -> list:
        """Fetches inline comments for each PR in the list."""
        all_comments = []

        for pr in prs:
            comments = self._get_comments_for_pr(pr)
            all_comments.extend(comments)

        return all_comments

    def _get_comments_for_pr(self, pr) -> list:
        """Fetches the inline comments for a specific PR; returns [] on error."""
        owner, repo = pr["repository"]["nameWithOwner"].split("/")

        pr_query = self.PR_COMMENTS_TEMPLATE.format(
            owner=owner,
            repo=repo,
            pr_number=pr["number"],
            username=self.username,
            reviews_limit=self.reviews_limit,
            comments_limit=self.comments_limit,
        )

        try:
            pr_comments = []
            pr_data = self._run_query(pr_query)
            reviews = pr_data["repository"]["pullRequest"]["reviews"]["nodes"]

            for review in reviews:
                for comment in review["comments"]["nodes"]:
                    # Attach the parent PR so _format_comment can reference it.
                    comment["_pr_data"] = pr
                    pr_comments.append(comment)

            return pr_comments
        except Exception as e:
            logger.error(f"Error fetching comments for PR #{pr['number']}: {e}")
            return []

    def _format_comment(self, comment) -> dict:
        """Formats a PR review comment from GraphQL."""
        pr = comment["_pr_data"]
        # Derive a comment id from the trailing path segment of the comment URL.
        comment_id = comment["url"].split("/")[-1] if comment["url"] else None

        return {
            "repo": pr["repository"]["nameWithOwner"],
            "issue_number": pr["number"],
            "comment_id": comment_id,
            "body": comment["body"],
            # "text" mirrors "body" for consumers expecting a generic text field.
            "text": comment["body"],
            "created_at": comment["createdAt"],
            "updated_at": comment["updatedAt"],
            "html_url": comment["url"],
            "issue_url": pr["url"],
            "author_association": "COMMENTER",
            "issue_title": pr["title"],
            "issue_state": pr["state"],
            "login": self.username,
            "type": "pr_review_comment",
        }

    def _get_comment_type(self) -> str:
        """Returns the comment type for error messages."""
        return "PR review comment"
|
||||
|
|
@ -0,0 +1,169 @@
|
|||
from uuid import uuid5, NAMESPACE_OID
|
||||
from typing import Dict, Any, List
|
||||
|
||||
from cognee.modules.engine.models.node_set import NodeSet
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.complex_demos.crewai_demo.src.crewai_demo.github_datapoints import (
|
||||
GitHubUser,
|
||||
Repository,
|
||||
File,
|
||||
FileChange,
|
||||
Comment,
|
||||
Issue,
|
||||
Commit,
|
||||
)
|
||||
|
||||
logger = get_logger("github_datapoints")
|
||||
|
||||
|
||||
def create_github_user_datapoint(user_data, nodesets: List[NodeSet]):
    """Creates just the GitHubUser DataPoint object from the user data, with node sets.

    Returns a list containing the user datapoint followed by the given node
    sets, or None if user_data is empty.
    """
    if not user_data:
        return None

    # Deterministic ID derived from the login, so re-ingesting the same user
    # maps onto the same graph node.
    user_id = uuid5(NAMESPACE_OID, user_data.get("login", ""))

    user = GitHubUser(
        id=user_id,
        name=user_data.get("login", ""),
        bio=user_data.get("bio"),
        company=user_data.get("company"),
        location=user_data.get("location"),
        public_repos=user_data.get("public_repos", 0),
        followers=user_data.get("followers", 0),
        following=user_data.get("following", 0),
        interacts_with=[],
        belongs_to_set=nodesets,
    )

    logger.debug(f"Created GitHubUser with ID: {user_id}")

    return [user] + nodesets
|
||||
|
||||
|
||||
def create_repository_datapoint(repo_name: str, nodesets: List[NodeSet]) -> Repository:
    """Creates a Repository DataPoint with a consistent ID.

    The ID is a uuid5 of the repo name, so the same repository always maps to
    the same graph node across ingestion runs.
    """
    repo_id = uuid5(NAMESPACE_OID, repo_name)
    repo = Repository(
        id=repo_id,
        name=repo_name,
        has_issue=[],
        has_commit=[],
        belongs_to_set=nodesets,
    )
    logger.debug(f"Created Repository with ID: {repo_id} for {repo_name}")
    return repo
|
||||
|
||||
|
||||
def create_file_datapoint(filename: str, repo_name: str, nodesets: List[NodeSet]) -> File:
    """Creates a File DataPoint with a consistent ID.

    The ID is a uuid5 of "repo:filename", so the same file in the same repo
    always maps onto the same node while identically named files in different
    repos stay distinct.
    """
    # FIX: the key must include the filename — keying on the repo name alone
    # would give every file in a repository the same UUID.
    file_key = f"{repo_name}:{filename}"
    file_id = uuid5(NAMESPACE_OID, file_key)
    file = File(
        id=file_id, name=filename, filename=filename, repo=repo_name, belongs_to_set=nodesets
    )
    logger.debug(f"Created File with ID: {file_id} for {filename}")
    return file
|
||||
|
||||
|
||||
def create_commit_datapoint(
    commit_data: Dict[str, Any], user: GitHubUser, nodesets: List[NodeSet]
) -> Commit:
    """Creates a Commit DataPoint with a consistent ID and connection to user.

    The ID is a uuid5 of the commit SHA, so re-ingesting the same commit maps
    onto the same graph node.
    """
    commit_id = uuid5(NAMESPACE_OID, commit_data.get("commit_sha", ""))
    commit = Commit(
        id=commit_id,
        name=commit_data.get("commit_sha", ""),
        commit_sha=commit_data.get("commit_sha", ""),
        # FIX: idiomatic str(...) call instead of the C-style (str)(...) cast.
        text="Commit message:" + str(commit_data.get("commit_message", "")),
        commit_date=commit_data.get("commit_date", ""),
        commit_url=commit_data.get("commit_url", ""),
        author_name=commit_data.get("login", ""),
        repo=commit_data.get("repo", ""),
        has_change=[],
        belongs_to_set=nodesets,
    )
    logger.debug(f"Created Commit with ID: {commit_id} for {commit_data.get('commit_sha', '')}")
    return commit
|
||||
|
||||
|
||||
def create_file_change_datapoint(
    fc_data: Dict[str, Any], user: GitHubUser, file: File, nodesets: List[NodeSet]
) -> FileChange:
    """Creates a FileChange DataPoint with a consistent ID.

    The ID is a uuid5 of "repo:commit_sha:filename", uniquely identifying one
    file's change within one commit.
    """
    fc_key = (
        f"{fc_data.get('repo', '')}:{fc_data.get('commit_sha', '')}:{fc_data.get('filename', '')}"
    )
    fc_id = uuid5(NAMESPACE_OID, fc_key)

    file_change = FileChange(
        id=fc_id,
        name=fc_data.get("filename", ""),
        filename=fc_data.get("filename", ""),
        status=fc_data.get("status", ""),
        additions=fc_data.get("additions", 0),
        deletions=fc_data.get("deletions", 0),
        changes=fc_data.get("changes", 0),
        # The diff text is what gets indexed for search (see FileChange.metadata).
        text=fc_data.get("diff", ""),
        commit_sha=fc_data.get("commit_sha", ""),
        repo=fc_data.get("repo", ""),
        # Stored as the filename string, not the File object (FileChange.modifies: str).
        modifies=file.filename,
        changed_by=user,
        belongs_to_set=nodesets,
    )
    logger.debug(f"Created FileChange with ID: {fc_id} for {fc_data.get('filename', '')}")
    return file_change
|
||||
|
||||
|
||||
def create_issue_datapoint(
    issue_data: Dict[str, Any], repo_name: str, nodesets: List[NodeSet]
) -> Issue:
    """Creates an Issue DataPoint with a consistent ID.

    The ID is a uuid5 of "repo:issue_number", so the same issue always maps
    onto the same graph node.
    """
    issue_key = f"{repo_name}:{issue_data.get('issue_number', '')}"
    issue_id = uuid5(NAMESPACE_OID, issue_key)

    issue = Issue(
        id=issue_id,
        name=str(issue_data.get("issue_number", 0)),
        number=issue_data.get("issue_number", 0),
        text=issue_data.get("issue_title", ""),
        state=issue_data.get("issue_state", ""),
        repository=repo_name,
        # Always False here — PRs are not distinguished by this builder.
        is_pr=False,
        has_comment=[],
        belongs_to_set=nodesets,
    )
    logger.debug(f"Created Issue with ID: {issue_id} for {issue_data.get('issue_title', '')}")
    return issue
|
||||
|
||||
|
||||
def create_comment_datapoint(
    comment_data: Dict[str, Any], user: GitHubUser, nodesets: List[NodeSet]
) -> Comment:
    """Creates a Comment DataPoint with a consistent ID and connection to user.

    The ID is a uuid5 of "repo:issue_number:comment_id", uniquely identifying
    one comment on one issue/PR.
    """
    comment_key = f"{comment_data.get('repo', '')}:{comment_data.get('issue_number', '')}:{comment_data.get('comment_id', '')}"
    comment_id = uuid5(NAMESPACE_OID, comment_key)

    comment = Comment(
        id=comment_id,
        name=str(comment_data.get("comment_id", "")),
        comment_id=str(comment_data.get("comment_id", "")),
        # The comment body is what gets indexed for search (see Comment.metadata).
        text=comment_data.get("body", ""),
        created_at=comment_data.get("created_at", ""),
        updated_at=comment_data.get("updated_at", ""),
        author_name=comment_data.get("login", ""),
        issue_number=comment_data.get("issue_number", 0),
        repo=comment_data.get("repo", ""),
        authored_by=user,
        belongs_to_set=nodesets,
    )
    logger.debug(f"Created Comment with ID: {comment_id}")
    return comment
|
||||
|
||||
|
||||
def create_github_datapoints(github_data, nodesets: List[NodeSet]):
    """Builds DataPoint objects from fetched GitHub data.

    Currently simplified: only the GitHubUser datapoint (plus node sets) is
    produced. Returns None when no data was supplied.
    """
    if github_data:
        return create_github_user_datapoint(github_data["user"], nodesets)
    return None
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
from uuid import uuid5, NAMESPACE_OID
|
||||
from typing import Optional, List
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
||||
|
||||
class File(DataPoint):
    """File is now a leaf node without any lists of other DataPoints"""

    filename: str
    name: str
    repo: str
    # Search/embedding index is built over the filename.
    metadata: dict = {"index_fields": ["filename"]}
|
||||
|
||||
|
||||
class GitHubUser(DataPoint):
    """A GitHub account with profile stats and its repository interactions."""

    name: Optional[str]
    bio: Optional[str]
    company: Optional[str]
    location: Optional[str]
    public_repos: int
    followers: int
    following: int
    # Forward reference: Repository is declared later in this module.
    interacts_with: List["Repository"] = []
    # Search/embedding index is built over the login name.
    metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
|
||||
class FileChange(DataPoint):
    """One file's change within one commit, linked to the user who made it."""

    filename: str
    name: str
    status: str
    additions: int
    deletions: int
    changes: int
    # Holds the diff text (see create_file_change_datapoint).
    text: str
    commit_sha: str
    repo: str
    # Filename of the modified File node (stored as a string, not a File).
    modifies: str
    changed_by: GitHubUser
    # Search/embedding index is built over the diff text.
    metadata: dict = {"index_fields": ["text"]}
|
||||
|
||||
|
||||
class Comment(DataPoint):
    """A comment on an issue or PR, linked to its author."""

    comment_id: str
    name: str
    # The comment body; indexed for search below.
    text: str
    created_at: str
    updated_at: str
    author_name: str
    issue_number: int
    repo: str
    authored_by: GitHubUser
    metadata: dict = {"index_fields": ["text"]}
|
||||
|
||||
|
||||
class Issue(DataPoint):
    """An issue (or PR, when is_pr is True) and its comments."""

    number: int
    name: str
    # Holds the issue title (see create_issue_datapoint).
    text: str
    state: str
    repository: str
    is_pr: bool
    has_comment: List[Comment] = []
    # NOTE(review): no metadata index_fields here, unlike Comment/FileChange —
    # confirm that leaving the title unindexed is intentional.
|
||||
|
||||
|
||||
class Commit(DataPoint):
    """A commit with its file changes."""

    commit_sha: str
    name: str
    # Holds "Commit message:<message>" (see create_commit_datapoint).
    text: str
    commit_date: str
    commit_url: str
    author_name: str
    repo: str
    has_change: List[FileChange] = []
|
||||
|
||||
|
||||
class Repository(DataPoint):
    """A repository aggregating its issues and commits."""

    name: str
    has_issue: List[Issue] = []
    has_commit: List[Commit] = []
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
from github import Github
|
||||
from datetime import datetime
|
||||
from cognee.complex_demos.crewai_demo.src.crewai_demo.github_comment_providers import (
|
||||
IssueCommentsProvider,
|
||||
PrReviewsProvider,
|
||||
PrReviewCommentsProvider,
|
||||
)
|
||||
from cognee.complex_demos.crewai_demo.src.crewai_demo.github_comment_base import logger
|
||||
|
||||
|
||||
class GitHubDevComments:
    """Facade class for working with a GitHub developer's comments.

    Aggregates issue comments, PR reviews, and inline PR review comments via
    the three provider classes, using the token/username from a profile object.
    """

    def __init__(self, profile, limit=10, include_issue_details=True):
        """Initialize with a GitHubDevProfile instance and default parameters."""
        self.profile = profile
        self.limit = limit
        # NOTE(review): include_issue_details is stored but never read in this
        # class — confirm whether it is still needed.
        self.include_issue_details = include_issue_details

    def get_issue_comments(self):
        """Fetches the most recent comments made by the user on issues and PRs across repositories.

        Returns a flat list of comment dicts (each tagged with a "type" key by
        its provider), or None when the profile has no resolved user.
        """
        if not self.profile.user:
            logger.warning(f"No user found for profile {self.profile.username}")
            return None

        logger.debug(f"Fetching comments for {self.profile.username} with limit={self.limit}")

        # Create providers with just the basic limit - they will handle their own multipliers
        issue_provider = IssueCommentsProvider(
            self.profile.token, self.profile.username, self.limit
        )
        pr_review_provider = PrReviewsProvider(
            self.profile.token, self.profile.username, self.limit
        )
        pr_comment_provider = PrReviewCommentsProvider(
            self.profile.token, self.profile.username, self.limit
        )

        issue_comments = issue_provider.get_comments()
        pr_reviews = pr_review_provider.get_comments()
        pr_review_comments = pr_comment_provider.get_comments()

        total_comments = issue_comments + pr_reviews + pr_review_comments
        logger.info(
            f"Retrieved {len(total_comments)} comments for {self.profile.username} "
            f"({len(issue_comments)} issue, {len(pr_reviews)} PR reviews, "
            f"{len(pr_review_comments)} PR review comments)"
        )

        return total_comments

    def set_limit(self, limit=None, include_issue_details=None):
        """Sets the limit for comments to retrieve."""
        if limit is not None:
            self.limit = limit
        if include_issue_details is not None:
            self.include_issue_details = include_issue_details
|
||||
|
|
@ -0,0 +1,195 @@
|
|||
from github import Github
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
class GitHubDevCommits:
|
||||
"""Class for working with a GitHub developer's commits in pull requests."""
|
||||
|
||||
    def __init__(
        self,
        profile,
        days=30,
        prs_limit=10,
        commits_per_pr=5,
        include_files=False,
        skip_no_diff=False,
    ):
        """Initialize with a GitHubDevProfile instance and default parameters.

        Args:
            profile: profile object exposing .user, .username, and .github client.
            days: how far back merged PRs are searched (see _get_date_filter).
            prs_limit: maximum number of PRs inspected.
            commits_per_pr: maximum commits taken from each PR.
            include_files: whether per-file change data is attached to commits.
            skip_no_diff: whether files without a diff are skipped — TODO confirm
                where this is honored; it is only stored here.
        """
        self.profile = profile
        self.days = days
        self.prs_limit = prs_limit
        self.commits_per_pr = commits_per_pr
        self.include_files = include_files
        self.skip_no_diff = skip_no_diff
        # Whitelist of per-file keys copied in get_user_file_changes().
        self.file_keys = ["filename", "status", "additions", "deletions", "changes", "diff"]
|
||||
|
||||
    def get_user_commits(self):
        """Fetches user's most recent commits from pull requests.

        Returns {"user": <profile info>, "commits": [...]} or None when the
        profile has no resolved user.
        """
        if not self.profile.user:
            return None

        commits = self._collect_user_pr_commits()
        return {"user": self.profile.get_user_info(), "commits": commits}
|
||||
|
||||
    def get_user_file_changes(self):
        """Returns a flat list of file changes with associated commit information from PRs.

        Each entry merges the whitelisted per-file keys (self.file_keys) with
        the commit/PR metadata of the commit it belongs to. Returns None when
        the profile has no resolved user.
        """
        if not self.profile.user:
            return None

        all_files = []
        commits = self._collect_user_pr_commits(include_files=True)

        for commit in commits:
            # Commits without attached file data are skipped.
            if "files" not in commit:
                continue

            # Commit-level fields repeated on every file-change row.
            commit_info = {
                "repo": commit["repo"],
                "commit_sha": commit["sha"],
                "commit_message": commit["message"],
                "commit_date": commit["date"],
                "commit_url": commit["url"],
                "pr_number": commit.get("pr_number"),
                "pr_title": commit.get("pr_title"),
            }

            file_changes = []
            for file in commit["files"]:
                file_data = {key: file.get(key) for key in self.file_keys}
                file_changes.append({**file_data, **commit_info})

            all_files.extend(file_changes)

        return all_files
|
||||
|
||||
def set_options(
|
||||
self, days=None, prs_limit=None, commits_per_pr=None, include_files=None, skip_no_diff=None
|
||||
):
|
||||
"""Sets commit search parameters."""
|
||||
if days is not None:
|
||||
self.days = days
|
||||
if prs_limit is not None:
|
||||
self.prs_limit = prs_limit
|
||||
if commits_per_pr is not None:
|
||||
self.commits_per_pr = commits_per_pr
|
||||
if include_files is not None:
|
||||
self.include_files = include_files
|
||||
if skip_no_diff is not None:
|
||||
self.skip_no_diff = skip_no_diff
|
||||
|
||||
def _get_date_filter(self, days):
|
||||
"""Creates a date filter string for GitHub search queries."""
|
||||
if not days:
|
||||
return ""
|
||||
|
||||
date_limit = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")
|
||||
return f" created:>={date_limit}"
|
||||
|
||||
    def _collect_user_pr_commits(self, include_files=None):
        """Collects and sorts a user's recent commits from pull requests they authored."""
        # Fall back to the instance-level setting when no override is given.
        include_files = include_files if include_files is not None else self.include_files

        prs = self._get_user_prs()

        if not prs:
            return []

        all_commits = []
        for pr in prs[: self.prs_limit]:
            pr_commits = self._get_commits_from_pr(pr, include_files)
            all_commits.extend(pr_commits)

        # Newest first, across all PRs.
        sorted_commits = sorted(all_commits, key=lambda x: x["date"], reverse=True)
        return sorted_commits
|
||||
|
||||
    def _get_user_prs(self):
        """Gets pull requests authored by the user.

        Only merged PRs within the configured time window are returned;
        failures are logged to stdout and yield an empty list.
        """
        date_filter = self._get_date_filter(self.days)
        query = f"author:{self.profile.username} is:pr is:merged{date_filter}"

        try:
            return list(self.profile.github.search_issues(query))
        except Exception as e:
            print(f"Error searching for PRs: {e}")
            return []
|
||||
|
||||
def _get_commits_from_pr(self, pr_issue, include_files=None):
|
||||
"""Gets commits by the user from a specific PR."""
|
||||
include_files = include_files if include_files is not None else self.include_files
|
||||
|
||||
pr_info = self._get_pull_request_object(pr_issue)
|
||||
if not pr_info:
|
||||
return []
|
||||
|
||||
repo_name, pr = pr_info
|
||||
|
||||
all_commits = self._get_all_pr_commits(pr, pr_issue.number)
|
||||
if not all_commits:
|
||||
return []
|
||||
|
||||
user_commits = [
|
||||
c
|
||||
for c in all_commits
|
||||
if c.author and hasattr(c.author, "login") and c.author.login == self.profile.username
|
||||
]
|
||||
|
||||
commit_data = [
|
||||
self._extract_commit_data(commit, repo_name, pr_issue, include_files)
|
||||
for commit in user_commits[: self.commits_per_pr]
|
||||
]
|
||||
|
||||
return commit_data
|
||||
|
||||
def _get_pull_request_object(self, pr_issue):
|
||||
"""Gets repository and pull request objects from an issue."""
|
||||
try:
|
||||
repo_name = pr_issue.repository.full_name
|
||||
repo = self.profile.github.get_repo(repo_name)
|
||||
pr = repo.get_pull(pr_issue.number)
|
||||
return (repo_name, pr)
|
||||
except Exception as e:
|
||||
print(f"Error accessing PR #{pr_issue.number}: {e}")
|
||||
return None
|
||||
|
||||
def _get_all_pr_commits(self, pr, pr_number):
|
||||
"""Gets all commits from a pull request."""
|
||||
try:
|
||||
return list(pr.get_commits())
|
||||
except Exception as e:
|
||||
print(f"Error retrieving commits from PR #{pr_number}: {e}")
|
||||
return None
|
||||
|
||||
def _extract_commit_data(self, commit, repo_name, pr_issue, include_files=None):
|
||||
"""Extracts relevant data from a commit object within a PR context."""
|
||||
commit_data = {
|
||||
"repo": repo_name,
|
||||
"sha": commit.sha,
|
||||
"message": commit.commit.message,
|
||||
"date": commit.commit.author.date,
|
||||
"url": commit.html_url,
|
||||
"pr_number": pr_issue.number,
|
||||
"pr_title": pr_issue.title,
|
||||
"pr_url": pr_issue.html_url,
|
||||
}
|
||||
|
||||
include_files = include_files if include_files is not None else self.include_files
|
||||
|
||||
if include_files:
|
||||
commit_data["files"] = self._extract_commit_files(commit)
|
||||
|
||||
return commit_data
|
||||
|
||||
def _extract_commit_files(self, commit):
|
||||
"""Extracts files changed in a commit, including diffs."""
|
||||
files = []
|
||||
for file in commit.files:
|
||||
if self.skip_no_diff and not file.patch:
|
||||
continue
|
||||
|
||||
file_data = {key: getattr(file, key, None) for key in self.file_keys}
|
||||
|
||||
if "diff" in self.file_keys:
|
||||
file_data["diff"] = file.patch if file.patch else "No diff available for this file"
|
||||
|
||||
files.append(file_data)
|
||||
return files
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
from github import Github
|
||||
from datetime import datetime
|
||||
import json
|
||||
import os
|
||||
from cognee.complex_demos.crewai_demo.src.crewai_demo.github_dev_comments import GitHubDevComments
|
||||
from cognee.complex_demos.crewai_demo.src.crewai_demo.github_dev_commits import GitHubDevCommits
|
||||
|
||||
|
||||
class GitHubDevProfile:
    """Facade over a GitHub developer's profile, commits, and issue/PR activity."""

    def __init__(self, username, token):
        """Connect to the GitHub API (anonymously when token is empty) and load the user."""
        self.github = Github(token) if token else Github()
        self.token = token
        self.username = username
        self.user = self._get_user(username)
        # Derived helpers are only wired up when the user lookup succeeded.
        self.user_info = self._extract_user_info() if self.user else None
        self.comments = GitHubDevComments(self) if self.user else None
        self.commits = GitHubDevCommits(self) if self.user else None

    def get_user_info(self):
        """Return the user-info dict cached at construction time."""
        return self.user_info

    def get_user_repos(self, limit=None):
        """Return the user's repositories, truncated to `limit` when given."""
        if not self.user:
            return []
        repos = list(self.user.get_repos())
        return repos[:limit] if limit else repos

    def get_user_commits(self, days=30, prs_limit=5, commits_per_pr=3, include_files=False):
        """Fetch the user's most recent commits from their pull requests."""
        if not self.commits:
            return None
        self.commits.set_options(
            days=days,
            prs_limit=prs_limit,
            commits_per_pr=commits_per_pr,
            include_files=include_files,
        )
        return self.commits.get_user_commits()

    def get_user_file_changes(self, days=30, prs_limit=5, commits_per_pr=3, skip_no_diff=True):
        """Return a flat list of per-file changes (with commit context) from recent PRs."""
        if not self.commits:
            return None
        self.commits.set_options(
            days=days,
            prs_limit=prs_limit,
            commits_per_pr=commits_per_pr,
            include_files=True,
            skip_no_diff=skip_no_diff,
        )
        return self.commits.get_user_file_changes()

    def get_issue_comments(self, limit=10, include_issue_details=True):
        """Fetch the user's most recent comments on issues and PRs across repositories."""
        if not self.comments:
            return None
        self.comments.set_limit(
            limit=limit,
            include_issue_details=include_issue_details,
        )
        return self.comments.get_issue_comments()

    def _get_user(self, username):
        """Look up a GitHub user object; returns None (and logs) on API failure."""
        try:
            return self.github.get_user(username)
        except Exception as e:
            print(f"Error connecting to GitHub API: {e}")
            return None

    def _extract_user_info(self):
        """Project the cached user object onto a plain dict of profile fields."""
        return {
            "login": self.user.login,
            "name": self.user.name,
            "bio": self.user.bio,
            "company": self.user.company,
            "location": self.user.location,
            "public_repos": self.user.public_repos,
            "followers": self.user.followers,
            "following": self.user.following,
        }
|
||||
|
|
@ -0,0 +1,137 @@
|
|||
import json
|
||||
import asyncio
|
||||
import cognee
|
||||
from cognee.complex_demos.crewai_demo.src.crewai_demo.github_dev_profile import GitHubDevProfile
|
||||
|
||||
|
||||
def get_github_profile_data(
    username, token=None, days=30, prs_limit=5, commits_per_pr=3, issues_limit=5, max_comments=3
):
    """Fetch a user's profile info, recent PR commits (with files), and recent comments.

    Returns None when the user cannot be resolved. (`issues_limit` is currently
    unused and kept only for interface compatibility.)
    """
    profile = GitHubDevProfile(username, token or "")
    if not profile.user:
        return None

    commits_result = profile.get_user_commits(
        days=days, prs_limit=prs_limit, commits_per_pr=commits_per_pr, include_files=True
    )
    comments = profile.get_issue_comments(limit=max_comments, include_issue_details=True)

    return {
        "user": profile.get_user_info(),
        "commits": commits_result["commits"] if commits_result else [],
        "comments": comments or [],
    }
|
||||
|
||||
|
||||
def get_github_file_changes(
    username, token=None, days=30, prs_limit=5, commits_per_pr=3, skip_no_diff=True
):
    """Fetch a flat list of a user's per-file PR changes plus their profile info.

    Returns None when the user cannot be resolved.
    """
    profile = GitHubDevProfile(username, token or "")
    if not profile.user:
        return None

    file_changes = profile.get_user_file_changes(
        days=days, prs_limit=prs_limit, commits_per_pr=commits_per_pr, skip_no_diff=skip_no_diff
    )
    return {"user": profile.get_user_info(), "file_changes": file_changes or []}
|
||||
|
||||
|
||||
def get_github_data_for_cognee(
    username,
    token=None,
    days=30,
    prs_limit=3,
    commits_per_pr=3,
    issues_limit=3,
    max_comments=3,
    skip_no_diff=True,
):
    """Fetch enriched GitHub data: user info merged into file changes and comments.

    Returns None when the user cannot be resolved.

    Fix: user-info keys no longer overwrite keys already present in a record.
    Previously the file-change merge (`item | user_info`) let profile fields
    clobber duplicate keys, while the comment merge filtered them out — the
    two paths are now consistent (record keys always win).
    """
    token = token or ""
    profile = GitHubDevProfile(username, token)

    if not profile.user:
        return None

    user_info = profile.get_user_info()

    def merge_without_clobber(record):
        # Only add user-info keys the record does not already define.
        safe_user_info = {k: v for k, v in user_info.items() if k not in record}
        return record | safe_user_info

    file_changes = profile.get_user_file_changes(
        days=days, prs_limit=prs_limit, commits_per_pr=commits_per_pr, skip_no_diff=skip_no_diff
    )
    enriched_file_changes = (
        [merge_without_clobber(item) for item in file_changes] if file_changes else []
    )

    comments = profile.get_issue_comments(limit=max_comments, include_issue_details=True)
    enriched_comments = [merge_without_clobber(comment) for comment in comments] if comments else []

    return {"user": user_info, "file_changes": enriched_file_changes, "comments": enriched_comments}
|
||||
|
||||
|
||||
async def cognify_github_profile(username, token=None):
    """Ingests GitHub data into Cognee with soft and technical node sets.

    Returns False when the user's data could not be fetched, True after a
    successful cognify run. Comments land in the "soft" node set, file changes
    in "technical", and the user record in both; every item is also tagged
    with the username so multiple profiles can coexist in one graph.
    """
    github_data = get_github_data_for_cognee(username=username, token=token)
    if not github_data:
        return False

    # default=str coerces datetimes (and any other non-JSON types) to strings.
    await cognee.add(
        json.dumps(github_data["user"], default=str), node_set=["soft", "technical", username]
    )

    for comment in github_data["comments"]:
        await cognee.add(
            "Comment: " + json.dumps(comment, default=str), node_set=["soft", username]
        )

    for file_change in github_data["file_changes"]:
        await cognee.add(
            "File Change: " + json.dumps(file_change, default=str), node_set=["technical", username]
        )

    await cognee.cognify()
    return True
|
||||
|
||||
|
||||
async def main(username):
    """Smoke-test entry point: cognify a single GitHub profile.

    Reads GITHUB_TOKEN from the environment (.env supported via dotenv).
    """
    import os
    import dotenv
    from cognee.api.v1.visualize.visualize import visualize_graph

    dotenv.load_dotenv()
    token = os.getenv("GITHUB_TOKEN")

    await cognify_github_profile(username, token)

    # Optional graph-to-HTML rendering, left disabled:
    # success = await cognify_github_profile(username, token)

    # if success:
    #     visualization_path = os.path.join(os.path.dirname(__file__), "./.artifacts/github_graph.html")
    #     await visualize_graph(visualization_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import os
    import dotenv

    dotenv.load_dotenv()

    # NOTE(review): fill in a real GitHub username before running this demo.
    username = ""
    asyncio.run(main(username))
    # Debug alternative: dump the raw fetched payload instead of cognifying it.
    # token = os.getenv("GITHUB_TOKEN")
    # github_data = get_github_data_for_cognee(username=username, token=token)
    # print(json.dumps(github_data, indent=2, default=str))
|
||||
|
|
@ -0,0 +1,317 @@
|
|||
import json
|
||||
import asyncio
|
||||
from uuid import uuid5, NAMESPACE_OID
|
||||
from typing import Optional, List, Dict, Any
|
||||
import cognee
|
||||
from cognee.low_level import DataPoint
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.modules.pipelines.tasks.task import Task
|
||||
from cognee.modules.pipelines import run_tasks
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.engine.models.node_set import NodeSet
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.complex_demos.crewai_demo.src.crewai_demo.github_ingest import (
|
||||
get_github_data_for_cognee,
|
||||
)
|
||||
from cognee.modules.pipelines.models.PipelineRunInfo import PipelineRunCompleted, PipelineRunStarted
|
||||
from cognee.modules.graph.operations import get_formatted_graph_data
|
||||
from cognee.modules.crewai.get_crewai_pipeline_run_id import get_crewai_pipeline_run_id
|
||||
|
||||
# Import DataPoint classes from github_datapoints.py
|
||||
from cognee.complex_demos.crewai_demo.src.crewai_demo.github_datapoints import (
|
||||
GitHubUser,
|
||||
Repository,
|
||||
File,
|
||||
Commit,
|
||||
)
|
||||
|
||||
# Import creator functions from github_datapoint_creators.py
|
||||
from cognee.complex_demos.crewai_demo.src.crewai_demo.github_datapoint_creators import (
|
||||
create_github_user_datapoint,
|
||||
create_repository_datapoint,
|
||||
create_file_datapoint,
|
||||
create_commit_datapoint,
|
||||
create_file_change_datapoint,
|
||||
create_issue_datapoint,
|
||||
create_comment_datapoint,
|
||||
)
|
||||
|
||||
logger = get_logger("github_ingest")
|
||||
|
||||
|
||||
def collect_repositories(
    section: List[Dict[str, Any]],
    repositories: Dict[str, Repository],
    user: GitHubUser,
    nodesets: List[NodeSet],
) -> None:
    """Register each repository named in `section` exactly once and attach it to the user.

    Mutates `repositories` (cache keyed by repo name) and `user.interacts_with`.
    """
    for entry in section:
        name = entry.get("repo", "")
        if not name or name in repositories:
            continue
        repository = create_repository_datapoint(name, nodesets)
        repositories[name] = repository
        user.interacts_with.append(repository)
|
||||
|
||||
|
||||
def get_or_create_repository(
    repo_name: str, repositories: Dict[str, Repository], user: GitHubUser, nodesets: List[NodeSet]
) -> Repository:
    """Return the cached Repository for `repo_name`, creating and linking it on first use."""
    existing = repositories.get(repo_name)
    if existing is not None:
        return existing
    repository = create_repository_datapoint(repo_name, nodesets)
    repositories[repo_name] = repository
    user.interacts_with.append(repository)
    return repository
|
||||
|
||||
|
||||
def get_or_create_file(
    filename: str,
    repo_name: str,
    files: Dict[str, File],
    technical_nodeset: NodeSet,
) -> File:
    """Return the cached File datapoint for (repo, filename), creating it on first use.

    Bug fix: the cache key previously did not include the filename, so every
    file within a repository collapsed onto a single File datapoint. The key
    is now "repo:filename".
    """
    file_key = f"{repo_name}:{filename}"
    if file_key in files:
        return files[file_key]
    file = create_file_datapoint(filename, repo_name, [technical_nodeset])
    files[file_key] = file
    return file
|
||||
|
||||
|
||||
def get_or_create_commit(
    commit_data: Dict[str, Any],
    user: GitHubUser,
    commits: Dict[str, Commit],
    repository: Repository,
    technical_nodeset: NodeSet,
) -> Commit:
    """Return the cached Commit for this SHA, creating and linking it to the repo on first use.

    NOTE(review): the cache key comes from commit_data["commit_sha"]; callers
    are expected to have filtered out records without a SHA first.
    """
    sha = commit_data.get("commit_sha", "")
    existing = commits.get(sha)
    if existing is not None:
        return existing
    commit = create_commit_datapoint(commit_data, user, [technical_nodeset])
    commits[sha] = commit
    link_commit_to_repo(commit, repository)
    return commit
|
||||
|
||||
|
||||
def link_file_to_repo(file: File, repository: Repository):
    """Attach `file` to the repository's `contains` edge list, skipping duplicates."""
    if file in repository.contains:
        return
    repository.contains.append(file)
|
||||
|
||||
|
||||
def link_commit_to_repo(commit: Commit, repository: Repository):
    """Attach `commit` to the repository's `has_commit` edge list, skipping duplicates."""
    if commit in repository.has_commit:
        return
    repository.has_commit.append(commit)
|
||||
|
||||
|
||||
def process_file_changes_data(
    github_data: Dict[str, Any],
    user: GitHubUser,
    repositories: Dict[str, Repository],
    technical_nodeset: NodeSet,
) -> List[DataPoint]:
    """Build Commit and FileChange datapoints from `github_data["file_changes"]`.

    Side effects: populates the shared `repositories` cache, links commits and
    files to their repository, and appends each file change to its commit.
    Returns the new Commit and FileChange datapoints (repositories are emitted
    by the caller via the shared dict).
    """
    file_changes = github_data.get("file_changes", [])
    if not file_changes:
        return []

    collect_repositories(file_changes, repositories, user, [technical_nodeset])

    files: Dict[str, File] = {}
    commits: Dict[str, Commit] = {}
    file_change_datapoints: List[DataPoint] = []

    for fc_data in file_changes:
        repo_name = fc_data.get("repo", "")
        filename = fc_data.get("filename", "")
        commit_sha = fc_data.get("commit_sha", "")
        # Records missing any identity field cannot be linked into the graph.
        if not repo_name or not filename or not commit_sha:
            continue

        repository = get_or_create_repository(repo_name, repositories, user, [technical_nodeset])
        file = get_or_create_file(filename, repo_name, files, technical_nodeset)
        commit = get_or_create_commit(fc_data, user, commits, repository, technical_nodeset)

        # Bug fix: File datapoints were created but never attached to their
        # repository, leaving `repository.contains` empty and the File nodes
        # reachable only through file changes.
        link_file_to_repo(file, repository)

        file_change = create_file_change_datapoint(fc_data, user, file, [technical_nodeset])
        file_change_datapoints.append(file_change)
        if file_change not in commit.has_change:
            commit.has_change.append(file_change)

    return list(commits.values()) + file_change_datapoints
|
||||
|
||||
|
||||
def process_comments_data(
    github_data: Dict[str, Any],
    user: GitHubUser,
    repositories: Dict[str, Repository],
    technical_nodeset: NodeSet,
    soft_nodeset: NodeSet,
) -> List[DataPoint]:
    """Build Issue and Comment datapoints from `github_data["comments"]`.

    Issues are deduplicated per (repo, issue number) and linked to their
    repository; comments are linked to their issue. (`technical_nodeset` is
    unused here; comments live in the soft node set.)
    """
    comments_data = github_data.get("comments", [])
    if not comments_data:
        return []

    collect_repositories(comments_data, repositories, user, [soft_nodeset])

    issues: Dict[str, Any] = {}
    comment_datapoints: List[DataPoint] = []

    for comment_data in comments_data:
        repo_name = comment_data.get("repo", "")
        issue_number = comment_data.get("issue_number", 0)
        if not repo_name or not issue_number:
            continue

        repository = get_or_create_repository(repo_name, repositories, user, [soft_nodeset])

        issue_key = f"{repo_name}:{issue_number}"
        issue = issues.get(issue_key)
        if issue is None:
            issue = create_issue_datapoint(comment_data, repo_name, [soft_nodeset])
            issues[issue_key] = issue
            if issue not in repository.has_issue:
                repository.has_issue.append(issue)

        comment = create_comment_datapoint(comment_data, user, [soft_nodeset])
        comment_datapoints.append(comment)
        if comment not in issue.has_comment:
            issue.has_comment.append(comment)

    return list(issues.values()) + comment_datapoints
|
||||
|
||||
|
||||
def build_github_datapoints_from_dict(github_data: Dict[str, Any]):
    """Turn a raw GitHub data dict into the full list of DataPoints for ingestion.

    Returns None when the payload is missing or has no "user" section, or when
    the user datapoint cannot be created.
    """
    if not github_data or "user" not in github_data:
        return None

    # Stable node-set ids so repeated runs reuse the same NodeSet nodes.
    soft_nodeset = NodeSet(id=uuid5(NAMESPACE_OID, "NodeSet:soft"), name="soft")
    technical_nodeset = NodeSet(id=uuid5(NAMESPACE_OID, "NodeSet:technical"), name="technical")

    user_datapoints = create_github_user_datapoint(
        github_data["user"], [soft_nodeset, technical_nodeset]
    )
    if not user_datapoints:
        return None
    user = user_datapoints[0]

    repositories: Dict[str, Any] = {}

    file_change_datapoints = process_file_changes_data(
        github_data, user, repositories, technical_nodeset
    )
    comment_datapoints = process_comments_data(
        github_data, user, repositories, technical_nodeset, soft_nodeset
    )

    return (
        user_datapoints
        + list(repositories.values())
        + file_change_datapoints
        + comment_datapoints
    )
|
||||
|
||||
|
||||
async def run_with_info_stream(tasks, user, data, dataset_id, pipeline_name):
    """Run a cognee pipeline and stream its run info to the CrewAI frontend queue.

    Intermediate updates (anything that is not a started/completed marker) get
    the current formatted graph attached as their payload before being queued.
    """
    from cognee.modules.pipelines.queues.pipeline_run_info_queues import push_to_queue

    pipeline_run = run_tasks(
        tasks=tasks,
        data=data,
        dataset_id=dataset_id,
        pipeline_name=pipeline_name,
        user=user,
    )

    # The queue is keyed per user id so the frontend can poll this user's run.
    pipeline_run_id = get_crewai_pipeline_run_id(user.id)

    async for pipeline_run_info in pipeline_run:
        # Attach the graph snapshot only to progress updates; start/completed
        # markers pass through unenriched.
        if not isinstance(pipeline_run_info, PipelineRunStarted) and not isinstance(
            pipeline_run_info, PipelineRunCompleted
        ):
            pipeline_run_info.payload = await get_formatted_graph_data()
        # NOTE(review): every event (markers included) is pushed to the queue —
        # confirm this matches the frontend's expectations.
        push_to_queue(pipeline_run_id, pipeline_run_info)
|
||||
|
||||
|
||||
async def cognify_github_data(github_data: dict):
    """Build DataPoints from a GitHub data dict and push them through the pipeline.

    Returns True on success and False when no datapoints could be built.
    (Previously the success path implicitly returned None, so callers could
    not distinguish success from failure by truthiness.)
    """
    all_datapoints = build_github_datapoints_from_dict(github_data)
    if not all_datapoints:
        logger.error("Failed to create datapoints")
        return False

    # Fixed dataset id: all GitHub ingests share one dataset.
    dataset_id = uuid5(NAMESPACE_OID, "GitHub")

    cognee_user = await get_default_user()
    tasks = [Task(add_data_points, task_config={"batch_size": 50})]

    await run_with_info_stream(
        tasks=tasks,
        data=all_datapoints,
        dataset_id=dataset_id,
        pipeline_name="github_pipeline",
        user=cognee_user,
    )

    logger.info(f"Done processing {len(all_datapoints)} datapoints")
    return True
|
||||
|
||||
|
||||
async def cognify_github_data_from_username(
    username: str,
    token: Optional[str] = None,
    days: int = 30,
    prs_limit: int = 3,
    commits_per_pr: int = 3,
    issues_limit: int = 3,
    max_comments: int = 3,
    skip_no_diff: bool = True,
):
    """Fetch GitHub data for `username` and run it through the DataPoint pipeline.

    Returns False when fetching fails; otherwise returns the pipeline result.
    (Previously the success path dropped the result and returned None while
    the failure path returned False — truthiness was inverted vs. intent.)
    """
    logger.info(f"Fetching GitHub data for user: {username}")

    github_data = get_github_data_for_cognee(
        username=username,
        token=token,
        days=days,
        prs_limit=prs_limit,
        commits_per_pr=commits_per_pr,
        issues_limit=issues_limit,
        max_comments=max_comments,
        skip_no_diff=skip_no_diff,
    )

    if not github_data:
        logger.error(f"Failed to fetch GitHub data for user: {username}")
        return False

    # Round-trip through JSON to coerce datetimes and other objects to strings,
    # so the payload matches what a file-loaded JSON dict would look like.
    github_data = json.loads(json.dumps(github_data, default=str))

    return await cognify_github_data(github_data)
|
||||
|
||||
|
||||
async def process_github_from_file(json_file_path: str):
    """Load GitHub data from a JSON file and ingest it.

    Returns False when the file cannot be read or parsed; otherwise returns
    the result of the ingestion pipeline.
    """
    logger.info(f"Processing GitHub data from file: {json_file_path}")
    try:
        with open(json_file_path, "r") as handle:
            github_data = json.load(handle)
    except Exception as e:
        logger.error(f"Error loading JSON file: {e}")
        return False

    return await cognify_github_data(github_data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import os
    import dotenv

    dotenv.load_dotenv()
    token = os.getenv("GITHUB_TOKEN")

    # NOTE(review): fill in a real GitHub username before running this demo.
    username = ""

    async def cognify_from_username(username, token):
        from cognee.infrastructure.databases.relational import create_db_and_tables

        # WARNING: prune_* wipes all previously ingested data and system metadata
        # before re-ingesting this profile.
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await create_db_and_tables()
        await cognify_github_data_from_username(username, token)

    # Run it
    asyncio.run(cognify_from_username(username, token))
|
||||
178
cognee/complex_demos/crewai_demo/src/crewai_demo/hiring_crew.py
Normal file
178
cognee/complex_demos/crewai_demo/src/crewai_demo/hiring_crew.py
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
from crewai import Agent, Crew, Process, Task
|
||||
from crewai.project import CrewBase, agent, crew, task
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cognee.complex_demos.crewai_demo.src.crewai_demo.custom_tools.cognee_ingestion import (
|
||||
CogneeIngestion,
|
||||
)
|
||||
from cognee.complex_demos.crewai_demo.src.crewai_demo.custom_tools.cognee_search import CogneeSearch
|
||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
|
||||
|
||||
class AgentConfig(BaseModel):
    """Structured-output schema for an LLM-refined agent definition."""

    # The agent's job title / function within the crew.
    role: str
    # What the agent is trying to achieve.
    goal: str
    # Persona context supplied to the LLM.
    backstory: str
|
||||
|
||||
|
||||
@CrewBase
class HiringCrew:
    """CrewAI crew that assesses two applicants (soft + technical skills) and
    records a hiring decision.

    Agent and task definitions come from the YAML config files; each task's
    description and expected output are templated with `inputs` (the applicant
    GitHub usernames).
    """

    agents_config = "config/agents.yaml"
    tasks_config = "config/tasks.yaml"

    def __init__(self, inputs):
        # `inputs` maps template placeholders (applicant_1/applicant_2) to values
        # used when formatting task configs. A stray no-op `self` expression
        # statement that followed this assignment has been removed.
        self.inputs = inputs

    def _formatted_task_config(self, task_name):
        """Format a task's description/expected_output templates with self.inputs, in place."""
        config = self.tasks_config[task_name]
        config["description"] = config["description"].format(**self.inputs)
        config["expected_output"] = config["expected_output"].format(**self.inputs)
        return config

    @agent
    def soft_skills_expert_agent(self) -> Agent:
        return Agent(
            config=self.agents_config["soft_skills_expert_agent"],
            tools=[CogneeSearch(nodeset_name="soft")],
            verbose=True,
        )

    @agent
    def technical_expert_agent(self) -> Agent:
        return Agent(
            config=self.agents_config["technical_expert_agent"],
            tools=[CogneeSearch(nodeset_name="technical")],
            verbose=True,
        )

    @agent
    def decision_maker_agent(self) -> Agent:
        return Agent(
            config=self.agents_config["decision_maker_agent"],
            tools=[CogneeIngestion(nodeset_name="final_report")],
            verbose=True,
        )

    @task
    def soft_skills_assessment_applicant1_task(self) -> Task:
        return Task(
            config=self._formatted_task_config("soft_skills_assessment_applicant1_task"),
            async_execution=False,
        )

    @task
    def soft_skills_assessment_applicant2_task(self) -> Task:
        return Task(
            config=self._formatted_task_config("soft_skills_assessment_applicant2_task"),
            async_execution=False,
        )

    @task
    def technical_assessment_applicant1_task(self) -> Task:
        return Task(
            config=self._formatted_task_config("technical_assessment_applicant1_task"),
            async_execution=False,
        )

    @task
    def technical_assessment_applicant2_task(self) -> Task:
        return Task(
            config=self._formatted_task_config("technical_assessment_applicant2_task"),
            async_execution=False,
        )

    @task
    def hiring_decision_task(self) -> Task:
        return Task(
            config=self._formatted_task_config("hiring_decision_task"),
            async_execution=False,
        )

    @task
    def ingest_hiring_decision_task(self) -> Task:
        return Task(
            config=self._formatted_task_config("ingest_hiring_decision_task"),
            async_execution=False,
        )

    def refine_agent_configs(self, agent_name: str = None):
        """Ask the LLM to rewrite agent role/goal/backstory definitions in place.

        Refines a single agent when `agent_name` is given, otherwise all agents.
        """
        system_prompt = (
            "You are an expert in improving agent definitions for autonomous AI systems. "
            "Given an agent's role, goal, and backstory, refine them to be:\n"
            "- Concise and well-written\n"
            "- Aligned with the agent’s function\n"
            "- Clear and professional\n"
            "- Consistent with multi-agent teamwork\n\n"
            "Return the updated definition as a JSON object with keys: role, goal, backstory."
        )

        agent_keys = [agent_name] if agent_name else self.agents_config.keys()

        for name in agent_keys:
            agent_def = self.agents_config[name]

            user_prompt = f"""Here is the current agent definition:
role: {agent_def["role"]}
goal: {agent_def["goal"]}
backstory: {agent_def["backstory"]}

Please improve it."""
            llm_client = get_llm_client()
            improved = llm_client.create_structured_output(
                text_input=user_prompt, system_prompt=system_prompt, response_model=AgentConfig
            )

            # .dict() is deprecated in pydantic v2 (model_dump) but kept for
            # compatibility with whichever pydantic version crewai pins.
            self.agents_config[name] = improved.dict()

    @crew
    def crew(self) -> Crew:
        return Crew(
            agents=self.agents,
            tasks=self.tasks,
            process=Process.sequential,
            verbose=True,
            share_crew=True,
            output_log_file="hiring_crew_log.txt",
        )
|
||||
53
cognee/complex_demos/crewai_demo/src/crewai_demo/main.py
Normal file
53
cognee/complex_demos/crewai_demo/src/crewai_demo/main.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
import warnings
|
||||
import os
|
||||
from .hiring_crew import HiringCrew
|
||||
from cognee.complex_demos.crewai_demo.src.crewai_demo.custom_tools.github_ingestion import (
|
||||
GithubIngestion,
|
||||
)
|
||||
|
||||
warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd")
|
||||
|
||||
|
||||
def print_environment():
    """Print every environment variable, sorted by name (debug helper).

    NOTE(review): this dumps secrets (e.g. GITHUB_TOKEN, LLM API keys) to
    stdout — consider redacting sensitive values before using this outside
    local debugging.
    """
    for key in sorted(os.environ):
        print(f"{key}={os.environ[key]}")
|
||||
|
||||
|
||||
def run_github_ingestion(applicant_1, applicant_2):
    """Ingest both applicants' GitHub activity via the GithubIngestion tool."""
    ingestion = GithubIngestion()
    return ingestion.run(applicant_1=applicant_1, applicant_2=applicant_2)
|
||||
|
||||
|
||||
def run_hiring_crew(applicants: dict, number_of_rounds: int = 1, llm_client=None):
    """Run the hiring crew for the requested number of rounds.

    From the second round on, the three agents' prompts are LLM-refined before
    kickoff. (`llm_client` is currently unused; kept for interface compatibility.)
    """
    for hiring_round in range(number_of_rounds):
        print(f"\nStarting hiring round {hiring_round + 1}...\n")
        crew = HiringCrew(inputs=applicants)

        if hiring_round > 0:
            print("Refining agent prompts for this round...")
            for agent_name in (
                "soft_skills_expert_agent",
                "technical_expert_agent",
                "decision_maker_agent",
            ):
                crew.refine_agent_configs(agent_name=agent_name)

        crew.crew().kickoff()
|
||||
|
||||
|
||||
def run(enable_ingestion=True, enable_crew=True):
    """Demo entry point: optionally ingest both applicants' GitHub data, then run the crew.

    Raises:
        Exception: wraps any failure with context. The original exception is
        now chained (`from e`) so the underlying traceback is preserved
        (previously it was discarded by re-raising a fresh Exception).
    """
    try:
        print_environment()

        applicants = {"applicant_1": "hajdul88", "applicant_2": "lxobr"}

        if enable_ingestion:
            run_github_ingestion(applicants["applicant_1"], applicants["applicant_2"])

        if enable_crew:
            run_hiring_crew(applicants=applicants, number_of_rounds=5)

    except Exception as e:
        raise Exception(f"An error occurred while running the process: {e}") from e
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Toggle the two phases independently when debugging the demo:
    # ingestion fetches/cognifies GitHub data; crew runs the hiring agents.
    enable_ingestion = True
    enable_crew = True

    run(enable_ingestion=enable_ingestion, enable_crew=enable_crew)
|
||||
|
|
@ -61,4 +61,7 @@ class CorpusBuilderExecutor:
|
|||
await cognee.add(self.raw_corpus)
|
||||
|
||||
tasks = await self.task_getter(chunk_size=chunk_size, chunker=chunker)
|
||||
await cognee_pipeline(tasks=tasks)
|
||||
pipeline_run = cognee_pipeline(tasks=tasks)
|
||||
|
||||
async for run_info in pipeline_run:
|
||||
print(run_info)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import inspect
|
|||
from functools import wraps
|
||||
from abc import abstractmethod, ABC
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Dict, Any, List, Tuple
|
||||
from typing import Protocol, Optional, Dict, Any, List, Type, Tuple
|
||||
from uuid import NAMESPACE_OID, UUID, uuid5
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
|
@ -189,3 +189,6 @@ class GraphDBInterface(ABC):
|
|||
) -> List[Tuple[NodeData, Dict[str, Any], NodeData]]:
|
||||
"""Get all nodes connected to a given node with their relationships."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_nodeset_subgraph(self, node_type, node_name):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import json
|
|||
import os
|
||||
import shutil
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Union, Optional, Tuple
|
||||
from typing import Dict, Any, List, Union, Optional, Tuple, Type
|
||||
from datetime import datetime, timezone
|
||||
from uuid import UUID
|
||||
from contextlib import asynccontextmanager
|
||||
|
|
@ -728,6 +728,66 @@ class KuzuAdapter(GraphDBInterface):
|
|||
logger.error(f"Failed to get graph data: {e}")
|
||||
raise
|
||||
|
||||
async def get_nodeset_subgraph(
|
||||
self, node_type: Type[Any], node_name: List[str]
|
||||
) -> Tuple[List[Tuple[str, dict]], List[Tuple[str, str, str, dict]]]:
|
||||
label = node_type.__name__
|
||||
primary_query = """
|
||||
UNWIND $names AS wantedName
|
||||
MATCH (n:Node)
|
||||
WHERE n.type = $label AND n.name = wantedName
|
||||
RETURN DISTINCT n.id
|
||||
"""
|
||||
primary_rows = await self.query(primary_query, {"names": node_name, "label": label})
|
||||
primary_ids = [row[0] for row in primary_rows]
|
||||
if not primary_ids:
|
||||
return [], []
|
||||
|
||||
neighbor_query = """
|
||||
MATCH (n:Node)-[:EDGE]-(nbr:Node)
|
||||
WHERE n.id IN $ids
|
||||
RETURN DISTINCT nbr.id
|
||||
"""
|
||||
nbr_rows = await self.query(neighbor_query, {"ids": primary_ids})
|
||||
neighbor_ids = [row[0] for row in nbr_rows]
|
||||
|
||||
all_ids = list({*primary_ids, *neighbor_ids})
|
||||
|
||||
nodes_query = """
|
||||
MATCH (n:Node)
|
||||
WHERE n.id IN $ids
|
||||
RETURN n.id, n.name, n.type, n.properties
|
||||
"""
|
||||
node_rows = await self.query(nodes_query, {"ids": all_ids})
|
||||
nodes: List[Tuple[str, dict]] = []
|
||||
for node_id, name, typ, props in node_rows:
|
||||
data = {"id": node_id, "name": name, "type": typ}
|
||||
if props:
|
||||
try:
|
||||
data.update(json.loads(props))
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse JSON props for node {node_id}")
|
||||
nodes.append((node_id, data))
|
||||
|
||||
edges_query = """
|
||||
MATCH (a:Node)-[r:EDGE]-(b:Node)
|
||||
WHERE a.id IN $ids AND b.id IN $ids
|
||||
RETURN a.id, b.id, r.relationship_name, r.properties
|
||||
"""
|
||||
edge_rows = await self.query(edges_query, {"ids": all_ids})
|
||||
edges: List[Tuple[str, str, str, dict]] = []
|
||||
for from_id, to_id, rel_type, props in edge_rows:
|
||||
data = {}
|
||||
if props:
|
||||
try:
|
||||
data = json.loads(props)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse JSON props for edge {from_id}->{to_id}")
|
||||
|
||||
edges.append((from_id, to_id, rel_type, data))
|
||||
|
||||
return nodes, edges
|
||||
|
||||
async def get_filtered_graph_data(
|
||||
self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
|
||||
):
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import json
|
|||
from cognee.shared.logging_utils import get_logger, ERROR
|
||||
import asyncio
|
||||
from textwrap import dedent
|
||||
from typing import Optional, Any, List, Dict
|
||||
from typing import Optional, Any, List, Dict, Type, Tuple
|
||||
from contextlib import asynccontextmanager
|
||||
from uuid import UUID
|
||||
from neo4j import AsyncSession
|
||||
|
|
@ -517,6 +517,58 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
return (nodes, edges)
|
||||
|
||||
async def get_nodeset_subgraph(
|
||||
self, node_type: Type[Any], node_name: List[str]
|
||||
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
|
||||
label = node_type.__name__
|
||||
|
||||
query = f"""
|
||||
UNWIND $names AS wantedName
|
||||
MATCH (n:`{label}`)
|
||||
WHERE n.name = wantedName
|
||||
WITH collect(DISTINCT n) AS primary
|
||||
|
||||
UNWIND primary AS p
|
||||
OPTIONAL MATCH (p)--(nbr)
|
||||
WITH primary, collect(DISTINCT nbr) AS nbrs
|
||||
WITH primary + nbrs AS nodelist
|
||||
|
||||
UNWIND nodelist AS node
|
||||
WITH collect(DISTINCT node) AS nodes
|
||||
|
||||
MATCH (a)-[r]-(b)
|
||||
WHERE a IN nodes AND b IN nodes
|
||||
WITH nodes, collect(DISTINCT r) AS rels
|
||||
|
||||
RETURN
|
||||
[n IN nodes |
|
||||
{{ id: n.id,
|
||||
properties: properties(n) }}] AS rawNodes,
|
||||
[r IN rels |
|
||||
{{ type: type(r),
|
||||
properties: properties(r) }}] AS rawRels
|
||||
"""
|
||||
|
||||
result = await self.query(query, {"names": node_name})
|
||||
if not result:
|
||||
return [], []
|
||||
|
||||
raw_nodes = result[0]["rawNodes"]
|
||||
raw_rels = result[0]["rawRels"]
|
||||
|
||||
nodes = [(n["properties"]["id"], n["properties"]) for n in raw_nodes]
|
||||
edges = [
|
||||
(
|
||||
r["properties"]["source_node_id"],
|
||||
r["properties"]["target_node_id"],
|
||||
r["type"],
|
||||
r["properties"],
|
||||
)
|
||||
for r in raw_rels
|
||||
]
|
||||
|
||||
return nodes, edges
|
||||
|
||||
async def get_filtered_graph_data(self, attribute_filters):
|
||||
"""
|
||||
Fetches nodes and relationships filtered by specified attribute values.
|
||||
|
|
|
|||
|
|
@ -250,14 +250,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
if len(vector_list) == 0:
|
||||
return []
|
||||
|
||||
# Normalize vector distance and add this as score information to vector_list
|
||||
normalized_values = normalize_distances(vector_list)
|
||||
for i in range(0, len(normalized_values)):
|
||||
vector_list[i]["score"] = normalized_values[i]
|
||||
|
||||
# Create and return ScoredResult objects
|
||||
return [
|
||||
ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score"))
|
||||
ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("_distance"))
|
||||
for row in vector_list
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ class OpenAIAdapter(LLMInterface):
|
|||
self.max_tokens = max_tokens
|
||||
self.streaming = streaming
|
||||
|
||||
@observe(as_type="generation")
|
||||
# @observe(as_type="generation")
|
||||
@sleep_and_retry_async()
|
||||
@rate_limit_async
|
||||
async def acreate_structured_output(
|
||||
|
|
@ -77,7 +77,7 @@ class OpenAIAdapter(LLMInterface):
|
|||
max_retries=self.MAX_RETRIES,
|
||||
)
|
||||
|
||||
@observe
|
||||
# @observe
|
||||
@sleep_and_retry_sync()
|
||||
@rate_limit_sync
|
||||
def create_structured_output(
|
||||
|
|
|
|||
11
cognee/modules/crewai/get_crewai_pipeline_run_id.py
Normal file
11
cognee/modules/crewai/get_crewai_pipeline_run_id.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
from uuid import NAMESPACE_OID, UUID, uuid5
|
||||
|
||||
from cognee.modules.pipelines.utils import generate_pipeline_id, generate_pipeline_run_id
|
||||
|
||||
|
||||
def get_crewai_pipeline_run_id(user_id: UUID):
|
||||
dataset_id = uuid5(NAMESPACE_OID, "GitHub")
|
||||
pipeline_id = generate_pipeline_id(user_id, "github_pipeline")
|
||||
pipeline_run_id = generate_pipeline_run_id(pipeline_id, dataset_id)
|
||||
|
||||
return pipeline_run_id
|
||||
|
|
@ -5,4 +5,3 @@ class NodeSet(DataPoint):
|
|||
"""NodeSet data point."""
|
||||
|
||||
name: str
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from cognee.shared.logging_utils import get_logger
|
||||
from typing import List, Dict, Union
|
||||
from typing import List, Dict, Union, Optional, Type
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
|
||||
|
|
@ -61,22 +61,27 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
node_dimension=1,
|
||||
edge_dimension=1,
|
||||
memory_fragment_filter=[],
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: List[Optional[str]] = None,
|
||||
) -> None:
|
||||
if node_dimension < 1 or edge_dimension < 1:
|
||||
raise InvalidValueError(message="Dimensions must be positive integers")
|
||||
|
||||
try:
|
||||
if len(memory_fragment_filter) == 0:
|
||||
if node_type is not None and node_name is not None:
|
||||
nodes_data, edges_data = await adapter.get_nodeset_subgraph(
|
||||
node_type=node_type, node_name=node_name
|
||||
)
|
||||
elif len(memory_fragment_filter) == 0:
|
||||
nodes_data, edges_data = await adapter.get_graph_data()
|
||||
else:
|
||||
nodes_data, edges_data = await adapter.get_filtered_graph_data(
|
||||
attribute_filters=memory_fragment_filter
|
||||
)
|
||||
|
||||
if not nodes_data:
|
||||
raise EntityNotFoundError(message="No node data retrieved from the database.")
|
||||
if not edges_data:
|
||||
raise EntityNotFoundError(message="No edge data retrieved from the database.")
|
||||
if not nodes_data or not edges_data:
|
||||
logger.warning("Empty projected graph.")
|
||||
return None
|
||||
|
||||
for node_id, properties in nodes_data:
|
||||
node_attributes = {key: properties.get(key) for key in node_properties_to_project}
|
||||
|
|
|
|||
1
cognee/modules/graph/operations/__init__.py
Normal file
1
cognee/modules/graph/operations/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
from .get_formatted_graph_data import get_formatted_graph_data
|
||||
36
cognee/modules/graph/operations/get_formatted_graph_data.py
Normal file
36
cognee/modules/graph/operations/get_formatted_graph_data.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
|
||||
|
||||
async def get_formatted_graph_data():
|
||||
graph_client = await get_graph_engine()
|
||||
(nodes, edges) = await graph_client.get_graph_data()
|
||||
|
||||
return {
|
||||
"nodes": list(
|
||||
map(
|
||||
lambda node: {
|
||||
"id": str(node[0]),
|
||||
"label": node[1]["name"]
|
||||
if ("name" in node[1] and node[1]["name"] != "")
|
||||
else f"{node[1]['type']}_{str(node[0])}",
|
||||
"type": node[1]["type"],
|
||||
"properties": {
|
||||
key: value
|
||||
for key, value in node[1].items()
|
||||
if key not in ["id", "type", "name"] and value is not None
|
||||
},
|
||||
},
|
||||
nodes,
|
||||
)
|
||||
),
|
||||
"edges": list(
|
||||
map(
|
||||
lambda edge: {
|
||||
"source": str(edge[0]),
|
||||
"target": str(edge[1]),
|
||||
"label": edge[2],
|
||||
},
|
||||
edges,
|
||||
)
|
||||
),
|
||||
}
|
||||
|
|
@ -144,6 +144,7 @@ def expand_with_nodes_and_edges(
|
|||
is_a=type_node,
|
||||
description=node.description,
|
||||
ontology_valid=ontology_validated_source_ent,
|
||||
belongs_to_set=data_chunk.belongs_to_set,
|
||||
)
|
||||
|
||||
added_nodes_map[entity_node_key] = entity_node
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from cognee.modules.data.models import Data
|
||||
from cognee.modules.data.models import GraphMetrics
|
||||
from cognee.modules.pipelines.models import PipelineRunInfo
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.modules.pipelines.models import PipelineRun
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
|
||||
|
||||
async def fetch_token_count(db_engine) -> int:
|
||||
|
|
@ -22,39 +23,39 @@ async def fetch_token_count(db_engine) -> int:
|
|||
return token_count_sum
|
||||
|
||||
|
||||
async def get_pipeline_run_metrics(pipeline_runs: list[PipelineRun], include_optional: bool):
|
||||
async def get_pipeline_run_metrics(pipeline_run: PipelineRunInfo, include_optional: bool):
|
||||
db_engine = get_relational_engine()
|
||||
graph_engine = await get_graph_engine()
|
||||
|
||||
metrics_for_pipeline_runs = []
|
||||
|
||||
async with db_engine.get_async_session() as session:
|
||||
for pipeline_run in pipeline_runs:
|
||||
existing_metrics = await session.execute(
|
||||
select(GraphMetrics).where(GraphMetrics.id == pipeline_run.pipeline_run_id)
|
||||
)
|
||||
existing_metrics = existing_metrics.scalars().first()
|
||||
existing_metrics = await session.execute(
|
||||
select(GraphMetrics).where(GraphMetrics.id == pipeline_run.pipeline_run_id)
|
||||
)
|
||||
existing_metrics = existing_metrics.scalars().first()
|
||||
|
||||
if existing_metrics:
|
||||
metrics_for_pipeline_runs.append(existing_metrics)
|
||||
else:
|
||||
graph_metrics = await graph_engine.get_graph_metrics(include_optional)
|
||||
metrics = GraphMetrics(
|
||||
id=pipeline_run.pipeline_run_id,
|
||||
num_tokens=await fetch_token_count(db_engine),
|
||||
num_nodes=graph_metrics["num_nodes"],
|
||||
num_edges=graph_metrics["num_edges"],
|
||||
mean_degree=graph_metrics["mean_degree"],
|
||||
edge_density=graph_metrics["edge_density"],
|
||||
num_connected_components=graph_metrics["num_connected_components"],
|
||||
sizes_of_connected_components=graph_metrics["sizes_of_connected_components"],
|
||||
num_selfloops=graph_metrics["num_selfloops"],
|
||||
diameter=graph_metrics["diameter"],
|
||||
avg_shortest_path_length=graph_metrics["avg_shortest_path_length"],
|
||||
avg_clustering=graph_metrics["avg_clustering"],
|
||||
)
|
||||
metrics_for_pipeline_runs.append(metrics)
|
||||
session.add(metrics)
|
||||
|
||||
if existing_metrics:
|
||||
metrics_for_pipeline_runs.append(existing_metrics)
|
||||
else:
|
||||
graph_metrics = await graph_engine.get_graph_metrics(include_optional)
|
||||
metrics = GraphMetrics(
|
||||
id=pipeline_run.pipeline_run_id,
|
||||
num_tokens=await fetch_token_count(db_engine),
|
||||
num_nodes=graph_metrics["num_nodes"],
|
||||
num_edges=graph_metrics["num_edges"],
|
||||
mean_degree=graph_metrics["mean_degree"],
|
||||
edge_density=graph_metrics["edge_density"],
|
||||
num_connected_components=graph_metrics["num_connected_components"],
|
||||
sizes_of_connected_components=graph_metrics["sizes_of_connected_components"],
|
||||
num_selfloops=graph_metrics["num_selfloops"],
|
||||
diameter=graph_metrics["diameter"],
|
||||
avg_shortest_path_length=graph_metrics["avg_shortest_path_length"],
|
||||
avg_clustering=graph_metrics["avg_clustering"],
|
||||
)
|
||||
metrics_for_pipeline_runs.append(metrics)
|
||||
session.add(metrics)
|
||||
await session.commit()
|
||||
|
||||
return metrics_for_pipeline_runs
|
||||
|
|
|
|||
33
cognee/modules/pipelines/models/PipelineRunInfo.py
Normal file
33
cognee/modules/pipelines/models/PipelineRunInfo.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
from typing import Any, Optional
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class PipelineRunInfo(BaseModel):
|
||||
status: str
|
||||
pipeline_run_id: UUID
|
||||
payload: Optional[Any] = None
|
||||
|
||||
model_config = {
|
||||
"arbitrary_types_allowed": True,
|
||||
}
|
||||
|
||||
|
||||
class PipelineRunStarted(PipelineRunInfo):
|
||||
status: str = "PipelineRunStarted"
|
||||
|
||||
|
||||
class PipelineRunYield(PipelineRunInfo):
|
||||
status: str = "PipelineRunYield"
|
||||
|
||||
|
||||
class PipelineRunCompleted(PipelineRunInfo):
|
||||
status: str = "PipelineRunCompleted"
|
||||
|
||||
|
||||
class PipelineRunErrored(PipelineRunInfo):
|
||||
status: str = "PipelineRunErrored"
|
||||
|
||||
|
||||
class PipelineRunActivity(BaseModel):
|
||||
status: str = "PipelineRunActivity"
|
||||
|
|
@ -1 +1,8 @@
|
|||
from .PipelineRun import PipelineRun, PipelineRunStatus
|
||||
from .PipelineRunInfo import (
|
||||
PipelineRunInfo,
|
||||
PipelineRunStarted,
|
||||
PipelineRunYield,
|
||||
PipelineRunCompleted,
|
||||
PipelineRunErrored,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
from uuid import UUID, uuid4
|
||||
from uuid import UUID
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.pipelines.models import PipelineRun, PipelineRunStatus
|
||||
from cognee.modules.pipelines.utils import generate_pipeline_run_id
|
||||
|
||||
|
||||
async def log_pipeline_run_initiated(pipeline_id: str, pipeline_name: str, dataset_id: UUID):
|
||||
pipeline_run = PipelineRun(
|
||||
pipeline_run_id=uuid4(),
|
||||
pipeline_run_id=generate_pipeline_run_id(pipeline_id, dataset_id),
|
||||
pipeline_name=pipeline_name,
|
||||
pipeline_id=pipeline_id,
|
||||
status=PipelineRunStatus.DATASET_PROCESSING_INITIATED,
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ from cognee.modules.data.models import Data
|
|||
from cognee.modules.pipelines.models import PipelineRun, PipelineRunStatus
|
||||
from typing import Any
|
||||
|
||||
from cognee.modules.pipelines.utils import generate_pipeline_run_id
|
||||
|
||||
|
||||
async def log_pipeline_run_start(pipeline_id: str, pipeline_name: str, dataset_id: UUID, data: Any):
|
||||
if not data:
|
||||
|
|
@ -13,7 +15,7 @@ async def log_pipeline_run_start(pipeline_id: str, pipeline_name: str, dataset_i
|
|||
else:
|
||||
data_info = str(data)
|
||||
|
||||
pipeline_run_id = uuid4()
|
||||
pipeline_run_id = generate_pipeline_run_id(pipeline_id, dataset_id)
|
||||
|
||||
pipeline_run = PipelineRun(
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
|
|
|
|||
|
|
@ -91,16 +91,11 @@ async def cognee_pipeline(
|
|||
|
||||
datasets = dataset_instances
|
||||
|
||||
awaitables = []
|
||||
|
||||
for dataset in datasets:
|
||||
awaitables.append(
|
||||
run_pipeline(
|
||||
dataset=dataset, user=user, tasks=tasks, data=data, pipeline_name=pipeline_name
|
||||
)
|
||||
)
|
||||
|
||||
return await asyncio.gather(*awaitables)
|
||||
async for run_info in run_pipeline(
|
||||
dataset=dataset, user=user, tasks=tasks, data=data, pipeline_name=pipeline_name
|
||||
):
|
||||
yield run_info
|
||||
|
||||
|
||||
async def run_pipeline(
|
||||
|
|
@ -161,12 +156,9 @@ async def run_pipeline(
|
|||
raise ValueError(f"Task {task} is not an instance of Task")
|
||||
|
||||
pipeline_run = run_tasks(tasks, dataset_id, data, user, pipeline_name)
|
||||
pipeline_run_status = None
|
||||
|
||||
async for run_status in pipeline_run:
|
||||
pipeline_run_status = run_status
|
||||
|
||||
return pipeline_run_status
|
||||
async for pipeline_run_info in pipeline_run:
|
||||
yield pipeline_run_info
|
||||
|
||||
|
||||
def check_dataset_name(dataset_name: str) -> str:
|
||||
|
|
|
|||
|
|
@ -1,18 +1,25 @@
|
|||
import json
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from typing import Any
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.pipelines.utils import generate_pipeline_id
|
||||
from cognee.modules.pipelines.models.PipelineRunInfo import (
|
||||
PipelineRunCompleted,
|
||||
PipelineRunErrored,
|
||||
PipelineRunStarted,
|
||||
PipelineRunYield,
|
||||
)
|
||||
|
||||
from cognee.modules.pipelines.operations import (
|
||||
log_pipeline_run_start,
|
||||
log_pipeline_run_complete,
|
||||
log_pipeline_run_error,
|
||||
)
|
||||
from cognee.modules.settings import get_current_settings
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from uuid import uuid5, NAMESPACE_OID
|
||||
|
||||
from .run_tasks_base import run_tasks_base
|
||||
from ..tasks.task import Task
|
||||
|
|
@ -76,29 +83,44 @@ async def run_tasks(
|
|||
pipeline_name: str = "unknown_pipeline",
|
||||
context: dict = None,
|
||||
):
|
||||
pipeline_id = uuid5(NAMESPACE_OID, pipeline_name)
|
||||
if not user:
|
||||
user = get_default_user()
|
||||
|
||||
pipeline_id = generate_pipeline_id(user.id, pipeline_name)
|
||||
|
||||
pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data)
|
||||
|
||||
yield pipeline_run
|
||||
pipeline_run_id = pipeline_run.pipeline_run_id
|
||||
|
||||
yield PipelineRunStarted(
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
payload=data,
|
||||
)
|
||||
|
||||
try:
|
||||
async for _ in run_tasks_with_telemetry(
|
||||
async for result in run_tasks_with_telemetry(
|
||||
tasks=tasks,
|
||||
data=data,
|
||||
user=user,
|
||||
pipeline_name=pipeline_id,
|
||||
context=context,
|
||||
):
|
||||
pass
|
||||
yield PipelineRunYield(
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
payload=result,
|
||||
)
|
||||
|
||||
yield await log_pipeline_run_complete(
|
||||
await log_pipeline_run_complete(
|
||||
pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
yield await log_pipeline_run_error(
|
||||
pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data, e
|
||||
yield PipelineRunCompleted(pipeline_run_id=pipeline_run_id)
|
||||
|
||||
except Exception as error:
|
||||
await log_pipeline_run_error(
|
||||
pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data, error
|
||||
)
|
||||
raise e
|
||||
|
||||
yield PipelineRunErrored(payload=error)
|
||||
|
||||
raise error
|
||||
|
|
|
|||
37
cognee/modules/pipelines/queues/pipeline_run_info_queues.py
Normal file
37
cognee/modules/pipelines/queues/pipeline_run_info_queues.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
from uuid import UUID
|
||||
from asyncio import Queue
|
||||
from typing import Optional
|
||||
|
||||
from cognee.modules.pipelines.models import PipelineRunInfo
|
||||
|
||||
|
||||
pipeline_run_info_queues = {}
|
||||
|
||||
|
||||
def initialize_queue(pipeline_run_id: UUID):
|
||||
pipeline_run_info_queues[str(pipeline_run_id)] = Queue()
|
||||
|
||||
|
||||
def get_queue(pipeline_run_id: UUID) -> Optional[Queue]:
|
||||
if str(pipeline_run_id) not in pipeline_run_info_queues:
|
||||
initialize_queue(pipeline_run_id)
|
||||
|
||||
return pipeline_run_info_queues.get(str(pipeline_run_id), None)
|
||||
|
||||
|
||||
def remove_queue(pipeline_run_id: UUID):
|
||||
pipeline_run_info_queues.pop(str(pipeline_run_id))
|
||||
|
||||
|
||||
def push_to_queue(pipeline_run_id: UUID, pipeline_run_info: PipelineRunInfo):
|
||||
queue = get_queue(pipeline_run_id)
|
||||
|
||||
if queue:
|
||||
queue.put_nowait(pipeline_run_info)
|
||||
|
||||
|
||||
def get_from_queue(pipeline_run_id: UUID) -> Optional[PipelineRunInfo]:
|
||||
queue = get_queue(pipeline_run_id)
|
||||
|
||||
item = queue.get_nowait() if queue and not queue.empty() else None
|
||||
return item
|
||||
2
cognee/modules/pipelines/utils/__init__.py
Normal file
2
cognee/modules/pipelines/utils/__init__.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
from .generate_pipeline_id import generate_pipeline_id
|
||||
from .generate_pipeline_run_id import generate_pipeline_run_id
|
||||
5
cognee/modules/pipelines/utils/generate_pipeline_id.py
Normal file
5
cognee/modules/pipelines/utils/generate_pipeline_id.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
from uuid import NAMESPACE_OID, UUID, uuid5
|
||||
|
||||
|
||||
def generate_pipeline_id(user_id: UUID, pipeline_name: str):
|
||||
return uuid5(NAMESPACE_OID, f"{str(user_id)}_{pipeline_name}")
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
from uuid import NAMESPACE_OID, UUID, uuid5
|
||||
|
||||
|
||||
def generate_pipeline_run_id(pipeline_id: UUID, dataset_id: UUID):
|
||||
return uuid5(NAMESPACE_OID, f"{str(pipeline_id)}_{str(dataset_id)}")
|
||||
|
|
@ -24,3 +24,13 @@ class CypherSearchError(CogneeApiError):
|
|||
|
||||
class NoDataError(CriticalError):
|
||||
message: str = "No data found in the system, please add data first."
|
||||
|
||||
|
||||
class CollectionDistancesNotFoundError(CogneeApiError):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "No collection distances found for the given query.",
|
||||
name: str = "CollectionDistancesNotFoundError",
|
||||
status_code: int = status.HTTP_404_NOT_FOUND,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Optional
|
||||
from typing import Any, Optional, Type, List
|
||||
from collections import Counter
|
||||
import string
|
||||
|
||||
|
|
@ -8,6 +8,9 @@ from cognee.modules.retrieval.base_retriever import BaseRetriever
|
|||
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
|
||||
from cognee.modules.retrieval.utils.completion import generate_completion
|
||||
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class GraphCompletionRetriever(BaseRetriever):
|
||||
|
|
@ -18,11 +21,15 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
user_prompt_path: str = "graph_context_for_question.txt",
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
top_k: Optional[int] = 5,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: List[Optional[str]] = None,
|
||||
):
|
||||
"""Initialize retriever with prompt paths and search parameters."""
|
||||
self.user_prompt_path = user_prompt_path
|
||||
self.system_prompt_path = system_prompt_path
|
||||
self.top_k = top_k if top_k is not None else 5
|
||||
self.node_type = node_type
|
||||
self.node_name = node_name
|
||||
|
||||
def _get_nodes(self, retrieved_edges: list) -> dict:
|
||||
"""Creates a dictionary of nodes with their names and content."""
|
||||
|
|
@ -68,7 +75,11 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
vector_index_collections.append(f"{subclass.__name__}_{field_name}")
|
||||
|
||||
found_triplets = await brute_force_triplet_search(
|
||||
query, top_k=self.top_k, collections=vector_index_collections or None
|
||||
query,
|
||||
top_k=self.top_k,
|
||||
collections=vector_index_collections or None,
|
||||
node_type=self.node_type,
|
||||
node_name=self.node_name,
|
||||
)
|
||||
|
||||
return found_triplets
|
||||
|
|
@ -78,6 +89,7 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
triplets = await self.get_triplets(query)
|
||||
|
||||
if len(triplets) == 0:
|
||||
logger.warning("Empty context was provided to the completion")
|
||||
return ""
|
||||
|
||||
return await self.resolve_edges_to_text(triplets)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import asyncio
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from cognee.shared.logging_utils import get_logger, ERROR
|
||||
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
|
||||
|
|
@ -55,6 +55,8 @@ def format_triplets(edges):
|
|||
|
||||
async def get_memory_fragment(
|
||||
properties_to_project: Optional[List[str]] = None,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: List[Optional[str]] = None,
|
||||
) -> CogneeGraph:
|
||||
"""Creates and initializes a CogneeGraph memory fragment with optional property projections."""
|
||||
graph_engine = await get_graph_engine()
|
||||
|
|
@ -68,6 +70,8 @@ async def get_memory_fragment(
|
|||
graph_engine,
|
||||
node_properties_to_project=properties_to_project,
|
||||
edge_properties_to_project=["relationship_name"],
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
)
|
||||
except EntityNotFoundError:
|
||||
pass
|
||||
|
|
@ -82,6 +86,8 @@ async def brute_force_triplet_search(
|
|||
collections: List[str] = None,
|
||||
properties_to_project: List[str] = None,
|
||||
memory_fragment: Optional[CogneeGraph] = None,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: List[Optional[str]] = None,
|
||||
) -> list:
|
||||
if user is None:
|
||||
user = await get_default_user()
|
||||
|
|
@ -93,6 +99,8 @@ async def brute_force_triplet_search(
|
|||
collections=collections,
|
||||
properties_to_project=properties_to_project,
|
||||
memory_fragment=memory_fragment,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
)
|
||||
return retrieved_results
|
||||
|
||||
|
|
@ -104,6 +112,8 @@ async def brute_force_search(
|
|||
collections: List[str] = None,
|
||||
properties_to_project: List[str] = None,
|
||||
memory_fragment: Optional[CogneeGraph] = None,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: List[Optional[str]] = None,
|
||||
) -> list:
|
||||
"""
|
||||
Performs a brute force search to retrieve the top triplets from the graph.
|
||||
|
|
@ -115,6 +125,8 @@ async def brute_force_search(
|
|||
collections (Optional[List[str]]): List of collections to query.
|
||||
properties_to_project (Optional[List[str]]): List of properties to project.
|
||||
memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse.
|
||||
node_type: node type to filter
|
||||
node_name: node name to filter
|
||||
|
||||
Returns:
|
||||
list: The top triplet results.
|
||||
|
|
@ -125,7 +137,9 @@ async def brute_force_search(
|
|||
raise ValueError("top_k must be a positive integer.")
|
||||
|
||||
if memory_fragment is None:
|
||||
memory_fragment = await get_memory_fragment(properties_to_project)
|
||||
memory_fragment = await get_memory_fragment(
|
||||
properties_to_project=properties_to_project, node_type=node_type, node_name=node_name
|
||||
)
|
||||
|
||||
if collections is None:
|
||||
collections = [
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import json
|
||||
from typing import Callable
|
||||
from typing import Callable, Optional, Type, List
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.infrastructure.engine.utils import parse_id
|
||||
|
|
@ -33,12 +33,20 @@ async def search(
|
|||
user: User,
|
||||
system_prompt_path="answer_simple_question.txt",
|
||||
top_k: int = 10,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: List[Optional[str]] = None,
|
||||
):
|
||||
query = await log_query(query_text, query_type.value, user.id)
|
||||
|
||||
own_document_ids = await get_document_ids_for_user(user.id, datasets)
|
||||
search_results = await specific_search(
|
||||
query_type, query_text, user, system_prompt_path=system_prompt_path, top_k=top_k
|
||||
query_type,
|
||||
query_text,
|
||||
user,
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
)
|
||||
|
||||
filtered_search_results = []
|
||||
|
|
@ -61,6 +69,8 @@ async def specific_search(
|
|||
user: User,
|
||||
system_prompt_path="answer_simple_question.txt",
|
||||
top_k: int = 10,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: List[Optional[str]] = None,
|
||||
) -> list:
|
||||
search_tasks: dict[SearchType, Callable] = {
|
||||
SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion,
|
||||
|
|
@ -73,6 +83,8 @@ async def specific_search(
|
|||
SearchType.GRAPH_COMPLETION: GraphCompletionRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
).get_completion,
|
||||
SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ class TestCogneeServerStart(unittest.TestCase):
|
|||
preexec_fn=os.setsid,
|
||||
)
|
||||
# Give the server some time to start
|
||||
time.sleep(20)
|
||||
time.sleep(30)
|
||||
|
||||
# Check if server started with errors
|
||||
if cls.server_process.poll() is not None:
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue