Merge branch 'dev' into feature/sqlalchemy-custom-connect-args
This commit is contained in:
commit
2de1bd977d
62 changed files with 2566 additions and 3591 deletions
30
.github/workflows/basic_tests.yml
vendored
30
.github/workflows/basic_tests.yml
vendored
|
|
@ -197,33 +197,3 @@ jobs:
|
|||
|
||||
- name: Run Simple Examples
|
||||
run: uv run python ./examples/python/simple_example.py
|
||||
|
||||
graph-tests:
|
||||
name: Run Basic Graph Tests
|
||||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_PROVIDER: openai
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
|
||||
EMBEDDING_PROVIDER: openai
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
steps:
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Cognee Setup
|
||||
uses: ./.github/actions/cognee_setup
|
||||
with:
|
||||
python-version: ${{ inputs.python-version }}
|
||||
|
||||
- name: Run Graph Tests
|
||||
run: uv run python ./examples/python/code_graph_example.py --repo_path ./cognee/tasks/graph
|
||||
|
|
|
|||
2475
cognee-frontend/package-lock.json
generated
2475
cognee-frontend/package-lock.json
generated
File diff suppressed because it is too large
Load diff
|
|
@ -9,13 +9,13 @@
|
|||
"lint": "next lint"
|
||||
},
|
||||
"dependencies": {
|
||||
"@auth0/nextjs-auth0": "^4.6.0",
|
||||
"@auth0/nextjs-auth0": "^4.13.1",
|
||||
"classnames": "^2.5.1",
|
||||
"culori": "^4.0.1",
|
||||
"d3-force-3d": "^3.0.6",
|
||||
"next": "15.3.3",
|
||||
"react": "^19.0.0",
|
||||
"react-dom": "^19.0.0",
|
||||
"next": "16.0.4",
|
||||
"react": "^19.2.0",
|
||||
"react-dom": "^19.2.0",
|
||||
"react-force-graph-2d": "^1.27.1",
|
||||
"uuid": "^9.0.1"
|
||||
},
|
||||
|
|
@ -24,11 +24,11 @@
|
|||
"@tailwindcss/postcss": "^4.1.7",
|
||||
"@types/culori": "^4.0.0",
|
||||
"@types/node": "^20",
|
||||
"@types/react": "^18",
|
||||
"@types/react-dom": "^18",
|
||||
"@types/react": "^19",
|
||||
"@types/react-dom": "^19",
|
||||
"@types/uuid": "^9.0.8",
|
||||
"eslint": "^9",
|
||||
"eslint-config-next": "^15.3.3",
|
||||
"eslint-config-next": "^16.0.4",
|
||||
"eslint-config-prettier": "^10.1.5",
|
||||
"tailwindcss": "^4.1.7",
|
||||
"typescript": "^5"
|
||||
|
|
|
|||
|
|
@ -1,119 +0,0 @@
|
|||
import { useState } from "react";
|
||||
import { fetch } from "@/utils";
|
||||
import { v4 as uuid4 } from "uuid";
|
||||
import { LoadingIndicator } from "@/ui/App";
|
||||
import { CTAButton, Input } from "@/ui/elements";
|
||||
|
||||
interface CrewAIFormPayload extends HTMLFormElement {
|
||||
username1: HTMLInputElement;
|
||||
username2: HTMLInputElement;
|
||||
}
|
||||
|
||||
interface CrewAITriggerProps {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
onData: (data: any) => void;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
onActivity: (activities: any) => void;
|
||||
}
|
||||
|
||||
export default function CrewAITrigger({ onData, onActivity }: CrewAITriggerProps) {
|
||||
const [isCrewAIRunning, setIsCrewAIRunning] = useState(false);
|
||||
|
||||
const handleRunCrewAI = (event: React.FormEvent<CrewAIFormPayload>) => {
|
||||
event.preventDefault();
|
||||
const formElements = event.currentTarget;
|
||||
|
||||
const crewAIConfig = {
|
||||
username1: formElements.username1.value,
|
||||
username2: formElements.username2.value,
|
||||
};
|
||||
|
||||
const backendApiUrl = process.env.NEXT_PUBLIC_BACKEND_API_URL;
|
||||
const wsUrl = backendApiUrl.replace(/^http(s)?/, "ws");
|
||||
|
||||
const websocket = new WebSocket(`${wsUrl}/v1/crewai/subscribe`);
|
||||
|
||||
onActivity([{ id: uuid4(), timestamp: Date.now(), activity: "Dispatching hiring crew agents" }]);
|
||||
|
||||
websocket.onmessage = (event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
|
||||
if (data.status === "PipelineRunActivity") {
|
||||
onActivity([data.payload]);
|
||||
return;
|
||||
}
|
||||
|
||||
onData({
|
||||
nodes: data.payload.nodes,
|
||||
links: data.payload.edges,
|
||||
});
|
||||
|
||||
const nodes_type_map: { [key: string]: number } = {};
|
||||
|
||||
for (let i = 0; i < data.payload.nodes.length; i++) {
|
||||
const node = data.payload.nodes[i];
|
||||
if (!nodes_type_map[node.type]) {
|
||||
nodes_type_map[node.type] = 0;
|
||||
}
|
||||
nodes_type_map[node.type] += 1;
|
||||
}
|
||||
|
||||
const activityMessage = Object.entries(nodes_type_map).reduce((message, [type, count]) => {
|
||||
return `${message}\n | ${type}: ${count}`;
|
||||
}, "Graph updated:");
|
||||
|
||||
onActivity([{
|
||||
id: uuid4(),
|
||||
timestamp: Date.now(),
|
||||
activity: activityMessage,
|
||||
}]);
|
||||
|
||||
if (data.status === "PipelineRunCompleted") {
|
||||
websocket.close();
|
||||
}
|
||||
};
|
||||
|
||||
onData(null);
|
||||
setIsCrewAIRunning(true);
|
||||
|
||||
return fetch("/v1/crewai/run", {
|
||||
method: "POST",
|
||||
body: JSON.stringify(crewAIConfig),
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(() => {
|
||||
onActivity([{ id: uuid4(), timestamp: Date.now(), activity: "Hiring crew agents made a decision" }]);
|
||||
})
|
||||
.catch(() => {
|
||||
onActivity([{ id: uuid4(), timestamp: Date.now(), activity: "Hiring crew agents had problems while executing" }]);
|
||||
})
|
||||
.finally(() => {
|
||||
websocket.close();
|
||||
setIsCrewAIRunning(false);
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<form className="w-full flex flex-col gap-2" onSubmit={handleRunCrewAI}>
|
||||
<h1 className="text-2xl text-white">Cognee Dev Mexican Standoff</h1>
|
||||
<span className="text-white">Agents compare GitHub profiles, and make a decision who is a better developer</span>
|
||||
<div className="flex flex-row gap-2">
|
||||
<div className="flex flex-col w-full flex-1/2">
|
||||
<label className="block mb-1 text-white" htmlFor="username1">GitHub username</label>
|
||||
<Input name="username1" type="text" placeholder="Github Username" required defaultValue="hajdul88" />
|
||||
</div>
|
||||
<div className="flex flex-col w-full flex-1/2">
|
||||
<label className="block mb-1 text-white" htmlFor="username2">GitHub username</label>
|
||||
<Input name="username2" type="text" placeholder="Github Username" required defaultValue="lxobr" />
|
||||
</div>
|
||||
</div>
|
||||
<CTAButton type="submit" disabled={isCrewAIRunning} className="whitespace-nowrap">
|
||||
Start Mexican Standoff
|
||||
{isCrewAIRunning && <LoadingIndicator />}
|
||||
</CTAButton>
|
||||
</form>
|
||||
);
|
||||
}
|
||||
|
|
@ -6,7 +6,6 @@ import { NodeObject, LinkObject } from "react-force-graph-2d";
|
|||
import { ChangeEvent, useEffect, useImperativeHandle, useRef, useState } from "react";
|
||||
|
||||
import { DeleteIcon } from "@/ui/Icons";
|
||||
// import { FeedbackForm } from "@/ui/Partials";
|
||||
import { CTAButton, Input, NeutralButton, Select } from "@/ui/elements";
|
||||
|
||||
interface GraphControlsProps {
|
||||
|
|
@ -111,7 +110,7 @@ export default function GraphControls({ data, isAddNodeFormOpen, onGraphShapeCha
|
|||
};
|
||||
|
||||
const [isAuthShapeChangeEnabled, setIsAuthShapeChangeEnabled] = useState(true);
|
||||
const shapeChangeTimeout = useRef<number | null>();
|
||||
const shapeChangeTimeout = useRef<number | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
onGraphShapeChange(DEFAULT_GRAPH_SHAPE);
|
||||
|
|
@ -230,12 +229,6 @@ export default function GraphControls({ data, isAddNodeFormOpen, onGraphShapeCha
|
|||
)}
|
||||
</>
|
||||
{/* )} */}
|
||||
|
||||
{/* {selectedTab === "feedback" && (
|
||||
<div className="flex flex-col gap-2">
|
||||
<FeedbackForm onSuccess={() => {}} />
|
||||
</div>
|
||||
)} */}
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"use client";
|
||||
|
||||
import { useCallback, useRef, useState, MutableRefObject } from "react";
|
||||
import { useCallback, useRef, useState, RefObject } from "react";
|
||||
|
||||
import Link from "next/link";
|
||||
import { TextLogo } from "@/ui/App";
|
||||
|
|
@ -47,11 +47,11 @@ export default function GraphView() {
|
|||
updateData(newData);
|
||||
}, []);
|
||||
|
||||
const graphRef = useRef<GraphVisualizationAPI>();
|
||||
const graphRef = useRef<GraphVisualizationAPI>(null);
|
||||
|
||||
const graphControls = useRef<GraphControlsAPI>();
|
||||
const graphControls = useRef<GraphControlsAPI>(null);
|
||||
|
||||
const activityLog = useRef<ActivityLogAPI>();
|
||||
const activityLog = useRef<ActivityLogAPI>(null);
|
||||
|
||||
return (
|
||||
<main className="flex flex-col h-full">
|
||||
|
|
@ -74,21 +74,18 @@ export default function GraphView() {
|
|||
<div className="w-full h-full relative overflow-hidden">
|
||||
<GraphVisualization
|
||||
key={data?.nodes.length}
|
||||
ref={graphRef as MutableRefObject<GraphVisualizationAPI>}
|
||||
ref={graphRef as RefObject<GraphVisualizationAPI>}
|
||||
data={data}
|
||||
graphControls={graphControls as MutableRefObject<GraphControlsAPI>}
|
||||
graphControls={graphControls as RefObject<GraphControlsAPI>}
|
||||
/>
|
||||
|
||||
<div className="absolute top-2 left-2 flex flex-col gap-2">
|
||||
<div className="bg-gray-500 pt-4 pr-4 pb-4 pl-4 rounded-md w-sm">
|
||||
<CogneeAddWidget onData={onDataChange} />
|
||||
</div>
|
||||
{/* <div className="bg-gray-500 pt-4 pr-4 pb-4 pl-4 rounded-md w-sm">
|
||||
<CrewAITrigger onData={onDataChange} onActivity={(activities) => activityLog.current?.updateActivityLog(activities)} />
|
||||
</div> */}
|
||||
<div className="bg-gray-500 pt-4 pr-4 pb-4 pl-4 rounded-md w-sm">
|
||||
<h2 className="text-xl text-white mb-4">Activity Log</h2>
|
||||
<ActivityLog ref={activityLog as MutableRefObject<ActivityLogAPI>} />
|
||||
<ActivityLog ref={activityLog as RefObject<ActivityLogAPI>} />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
|
@ -96,7 +93,7 @@ export default function GraphView() {
|
|||
<div className="bg-gray-500 pt-4 pr-4 pb-4 pl-4 rounded-md w-110">
|
||||
<GraphControls
|
||||
data={data}
|
||||
ref={graphControls as MutableRefObject<GraphControlsAPI>}
|
||||
ref={graphControls as RefObject<GraphControlsAPI>}
|
||||
isAddNodeFormOpen={isAddNodeFormOpen}
|
||||
onFitIntoView={() => graphRef.current!.zoomToFit(1000, 50)}
|
||||
onGraphShapeChange={(shape) => graphRef.current!.setGraphShape(shape)}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"use client";
|
||||
|
||||
import classNames from "classnames";
|
||||
import { MutableRefObject, useEffect, useImperativeHandle, useRef, useState, useCallback } from "react";
|
||||
import { RefObject, useEffect, useImperativeHandle, useRef, useState, useCallback } from "react";
|
||||
import { forceCollide, forceManyBody } from "d3-force-3d";
|
||||
import dynamic from "next/dynamic";
|
||||
import { GraphControlsAPI } from "./GraphControls";
|
||||
|
|
@ -16,9 +16,9 @@ const ForceGraph = dynamic(() => import("react-force-graph-2d"), {
|
|||
import type { ForceGraphMethods, GraphData, LinkObject, NodeObject } from "react-force-graph-2d";
|
||||
|
||||
interface GraphVisuzaliationProps {
|
||||
ref: MutableRefObject<GraphVisualizationAPI>;
|
||||
ref: RefObject<GraphVisualizationAPI>;
|
||||
data?: GraphData<NodeObject, LinkObject>;
|
||||
graphControls: MutableRefObject<GraphControlsAPI>;
|
||||
graphControls: RefObject<GraphControlsAPI>;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
|
|
@ -205,7 +205,7 @@ export default function GraphVisualization({ ref, data, graphControls, className
|
|||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
function handleDagError(loopNodeIds: (string | number)[]) {}
|
||||
|
||||
const graphRef = useRef<ForceGraphMethods>();
|
||||
const graphRef = useRef<ForceGraphMethods>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (data && graphRef.current) {
|
||||
|
|
@ -224,6 +224,7 @@ export default function GraphVisualization({ ref, data, graphControls, className
|
|||
) => {
|
||||
if (!graphRef.current) {
|
||||
console.warn("GraphVisualization: graphRef not ready yet");
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
return undefined as any;
|
||||
}
|
||||
|
||||
|
|
@ -239,7 +240,7 @@ export default function GraphVisualization({ ref, data, graphControls, className
|
|||
return (
|
||||
<div ref={containerRef} className={classNames("w-full h-full", className)} id="graph-container">
|
||||
<ForceGraph
|
||||
ref={graphRef}
|
||||
ref={graphRef as RefObject<ForceGraphMethods>}
|
||||
width={dimensions.width}
|
||||
height={dimensions.height}
|
||||
dagMode={graphShape as unknown as undefined}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
"use server";
|
||||
"use client";
|
||||
|
||||
import Dashboard from "./Dashboard";
|
||||
|
||||
export default async function Page() {
|
||||
export default function Page() {
|
||||
const accessToken = "";
|
||||
|
||||
return (
|
||||
|
|
|
|||
|
|
@ -1,3 +1,3 @@
|
|||
export { default } from "./dashboard/page";
|
||||
|
||||
// export const dynamic = "force-dynamic";
|
||||
export const dynamic = "force-dynamic";
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ export interface Dataset {
|
|||
|
||||
function useDatasets(useCloud = false) {
|
||||
const [datasets, setDatasets] = useState<Dataset[]>([]);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
// const statusTimeout = useRef<any>(null);
|
||||
|
||||
// const fetchDatasetStatuses = useCallback((datasets: Dataset[]) => {
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import { NextResponse, type NextRequest } from "next/server";
|
|||
// import { auth0 } from "./modules/auth/auth0";
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
export async function middleware(request: NextRequest) {
|
||||
export async function proxy(request: NextRequest) {
|
||||
// if (process.env.USE_AUTH0_AUTHORIZATION?.toLowerCase() === "true") {
|
||||
// if (request.nextUrl.pathname === "/auth/token") {
|
||||
// return NextResponse.next();
|
||||
|
|
@ -1,69 +0,0 @@
|
|||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { LoadingIndicator } from "@/ui/App";
|
||||
import { fetch, useBoolean } from "@/utils";
|
||||
import { CTAButton, TextArea } from "@/ui/elements";
|
||||
|
||||
interface SignInFormPayload extends HTMLFormElement {
|
||||
feedback: HTMLTextAreaElement;
|
||||
}
|
||||
|
||||
interface FeedbackFormProps {
|
||||
onSuccess: () => void;
|
||||
}
|
||||
|
||||
export default function FeedbackForm({ onSuccess }: FeedbackFormProps) {
|
||||
const {
|
||||
value: isSubmittingFeedback,
|
||||
setTrue: disableFeedbackSubmit,
|
||||
setFalse: enableFeedbackSubmit,
|
||||
} = useBoolean(false);
|
||||
|
||||
const [feedbackError, setFeedbackError] = useState<string | null>(null);
|
||||
|
||||
const signIn = (event: React.FormEvent<SignInFormPayload>) => {
|
||||
event.preventDefault();
|
||||
const formElements = event.currentTarget;
|
||||
|
||||
setFeedbackError(null);
|
||||
disableFeedbackSubmit();
|
||||
|
||||
fetch("/v1/crewai/feedback", {
|
||||
method: "POST",
|
||||
body: JSON.stringify({
|
||||
feedback: formElements.feedback.value,
|
||||
}),
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(() => {
|
||||
onSuccess();
|
||||
formElements.feedback.value = "";
|
||||
})
|
||||
.catch(error => setFeedbackError(error.detail))
|
||||
.finally(() => enableFeedbackSubmit());
|
||||
};
|
||||
|
||||
return (
|
||||
<form onSubmit={signIn} className="flex flex-col gap-2">
|
||||
<div className="flex flex-col gap-2">
|
||||
<div className="mb-4">
|
||||
<label className="block text-white" htmlFor="feedback">Feedback on agent's reasoning</label>
|
||||
<TextArea id="feedback" name="feedback" type="text" placeholder="Your feedback" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<CTAButton type="submit">
|
||||
<span>Submit feedback</span>
|
||||
{isSubmittingFeedback && <LoadingIndicator />}
|
||||
</CTAButton>
|
||||
|
||||
{feedbackError && (
|
||||
<span className="text-s text-white">{feedbackError}</span>
|
||||
)}
|
||||
</form>
|
||||
)
|
||||
}
|
||||
|
|
@ -3,4 +3,3 @@ export { default as Footer } from "./Footer/Footer";
|
|||
export { default as SearchView } from "./SearchView/SearchView";
|
||||
export { default as IFrameView } from "./IFrameView/IFrameView";
|
||||
// export { default as Explorer } from "./Explorer/Explorer";
|
||||
export { default as FeedbackForm } from "./FeedbackForm";
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
import { v4 as uuid4 } from "uuid";
|
||||
import classNames from "classnames";
|
||||
import { Fragment, MouseEvent, MutableRefObject, useCallback, useEffect, useRef, useState } from "react";
|
||||
import { Fragment, MouseEvent, RefObject, useCallback, useEffect, useRef, useState } from "react";
|
||||
|
||||
import { useModal } from "@/ui/elements/Modal";
|
||||
import { CaretIcon, CloseIcon, PlusIcon } from "@/ui/Icons";
|
||||
|
|
@ -282,7 +282,7 @@ export default function Notebook({ notebook, updateNotebook, runCell }: Notebook
|
|||
function CellResult({ content }: { content: [] }) {
|
||||
const parsedContent = [];
|
||||
|
||||
const graphRef = useRef<GraphVisualizationAPI>();
|
||||
const graphRef = useRef<GraphVisualizationAPI>(null);
|
||||
const graphControls = useRef<GraphControlsAPI>({
|
||||
setSelectedNode: () => {},
|
||||
getSelectedNode: () => null,
|
||||
|
|
@ -298,7 +298,7 @@ function CellResult({ content }: { content: [] }) {
|
|||
<span className="text-sm pl-2 mb-4">reasoning graph</span>
|
||||
<GraphVisualization
|
||||
data={transformInsightsGraphData(line)}
|
||||
ref={graphRef as MutableRefObject<GraphVisualizationAPI>}
|
||||
ref={graphRef as RefObject<GraphVisualizationAPI>}
|
||||
graphControls={graphControls}
|
||||
className="min-h-80"
|
||||
/>
|
||||
|
|
@ -346,7 +346,7 @@ function CellResult({ content }: { content: [] }) {
|
|||
<span className="text-sm pl-2 mb-4">reasoning graph (datasets: {datasetName})</span>
|
||||
<GraphVisualization
|
||||
data={transformToVisualizationData(graph)}
|
||||
ref={graphRef as MutableRefObject<GraphVisualizationAPI>}
|
||||
ref={graphRef as RefObject<GraphVisualizationAPI>}
|
||||
graphControls={graphControls}
|
||||
className="min-h-80"
|
||||
/>
|
||||
|
|
@ -377,7 +377,7 @@ function CellResult({ content }: { content: [] }) {
|
|||
<span className="text-sm pl-2 mb-4">reasoning graph (datasets: {datasetName})</span>
|
||||
<GraphVisualization
|
||||
data={transformToVisualizationData(graph)}
|
||||
ref={graphRef as MutableRefObject<GraphVisualizationAPI>}
|
||||
ref={graphRef as RefObject<GraphVisualizationAPI>}
|
||||
graphControls={graphControls}
|
||||
className="min-h-80"
|
||||
/>
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ export default function NotebookCellHeader({
|
|||
setFalse: setIsNotRunningCell,
|
||||
} = useBoolean(false);
|
||||
|
||||
const [runInstance, setRunInstance] = useState<string>(isCloudEnvironment() ? "cloud" : "local");
|
||||
const [runInstance] = useState<string>(isCloudEnvironment() ? "cloud" : "local");
|
||||
|
||||
const handleCellRun = () => {
|
||||
if (runCell) {
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@
|
|||
"moduleResolution": "bundler",
|
||||
"resolveJsonModule": true,
|
||||
"isolatedModules": true,
|
||||
"jsx": "preserve",
|
||||
"jsx": "react-jsx",
|
||||
"incremental": true,
|
||||
"plugins": [
|
||||
{
|
||||
|
|
@ -32,7 +32,8 @@
|
|||
"next-env.d.ts",
|
||||
"**/*.ts",
|
||||
"**/*.tsx",
|
||||
".next/types/**/*.ts"
|
||||
".next/types/**/*.ts",
|
||||
".next/dev/types/**/*.ts"
|
||||
],
|
||||
"exclude": [
|
||||
"node_modules"
|
||||
|
|
|
|||
|
|
@ -90,97 +90,6 @@ async def health_check(request):
|
|||
return JSONResponse({"status": "ok"})
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def cognee_add_developer_rules(
|
||||
base_path: str = ".", graph_model_file: str = None, graph_model_name: str = None
|
||||
) -> list:
|
||||
"""
|
||||
Ingest core developer rule files into Cognee's memory layer.
|
||||
|
||||
This function loads a predefined set of developer-related configuration,
|
||||
rule, and documentation files from the base repository and assigns them
|
||||
to the special 'developer_rules' node set in Cognee. It ensures these
|
||||
foundational files are always part of the structured memory graph.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
base_path : str
|
||||
Root path to resolve relative file paths. Defaults to current directory.
|
||||
|
||||
graph_model_file : str, optional
|
||||
Optional path to a custom schema file for knowledge graph generation.
|
||||
|
||||
graph_model_name : str, optional
|
||||
Optional class name to use from the graph_model_file schema.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
A message indicating how many rule files were scheduled for ingestion,
|
||||
and how to check their processing status.
|
||||
|
||||
Notes
|
||||
-----
|
||||
- Each file is processed asynchronously in the background.
|
||||
- Files are attached to the 'developer_rules' node set.
|
||||
- Missing files are skipped with a logged warning.
|
||||
"""
|
||||
|
||||
developer_rule_paths = [
|
||||
".cursorrules",
|
||||
".cursor/rules",
|
||||
".same/todos.md",
|
||||
".windsurfrules",
|
||||
".clinerules",
|
||||
"CLAUDE.md",
|
||||
".sourcegraph/memory.md",
|
||||
"AGENT.md",
|
||||
"AGENTS.md",
|
||||
]
|
||||
|
||||
async def cognify_task(file_path: str) -> None:
|
||||
with redirect_stdout(sys.stderr):
|
||||
logger.info(f"Starting cognify for: {file_path}")
|
||||
try:
|
||||
await cognee_client.add(file_path, node_set=["developer_rules"])
|
||||
|
||||
model = None
|
||||
if graph_model_file and graph_model_name:
|
||||
if cognee_client.use_api:
|
||||
logger.warning(
|
||||
"Custom graph models are not supported in API mode, ignoring."
|
||||
)
|
||||
else:
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
|
||||
model = load_class(graph_model_file, graph_model_name)
|
||||
|
||||
await cognee_client.cognify(graph_model=model)
|
||||
logger.info(f"Cognify finished for: {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Cognify failed for {file_path}: {str(e)}")
|
||||
raise ValueError(f"Failed to cognify: {str(e)}")
|
||||
|
||||
tasks = []
|
||||
for rel_path in developer_rule_paths:
|
||||
abs_path = os.path.join(base_path, rel_path)
|
||||
if os.path.isfile(abs_path):
|
||||
tasks.append(asyncio.create_task(cognify_task(abs_path)))
|
||||
else:
|
||||
logger.warning(f"Skipped missing developer rule file: {abs_path}")
|
||||
log_file = get_log_file_location()
|
||||
return [
|
||||
types.TextContent(
|
||||
type="text",
|
||||
text=(
|
||||
f"Started cognify for {len(tasks)} developer rule files in background.\n"
|
||||
f"All are added to the `developer_rules` node set.\n"
|
||||
f"Use `cognify_status` or check logs at {log_file} to monitor progress."
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def cognify(
|
||||
data: str, graph_model_file: str = None, graph_model_name: str = None, custom_prompt: str = None
|
||||
|
|
@ -406,75 +315,6 @@ async def save_interaction(data: str) -> list:
|
|||
]
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def codify(repo_path: str) -> list:
|
||||
"""
|
||||
Analyze and generate a code-specific knowledge graph from a software repository.
|
||||
|
||||
This function launches a background task that processes the provided repository
|
||||
and builds a code knowledge graph. The function returns immediately while
|
||||
the processing continues in the background due to MCP timeout constraints.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
repo_path : str
|
||||
Path to the code repository to analyze. This can be a local file path or a
|
||||
relative path to a repository. The path should point to the root of the
|
||||
repository or a specific directory within it.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
A list containing a single TextContent object with information about the
|
||||
background task launch and how to check its status.
|
||||
|
||||
Notes
|
||||
-----
|
||||
- The function launches a background task and returns immediately
|
||||
- The code graph generation may take significant time for larger repositories
|
||||
- Use the codify_status tool to check the progress of the operation
|
||||
- Process results are logged to the standard Cognee log file
|
||||
- All stdout is redirected to stderr to maintain MCP communication integrity
|
||||
"""
|
||||
|
||||
if cognee_client.use_api:
|
||||
error_msg = "❌ Codify operation is not available in API mode. Please use direct mode for code graph pipeline."
|
||||
logger.error(error_msg)
|
||||
return [types.TextContent(type="text", text=error_msg)]
|
||||
|
||||
async def codify_task(repo_path: str):
|
||||
# NOTE: MCP uses stdout to communicate, we must redirect all output
|
||||
# going to stdout ( like the print function ) to stderr.
|
||||
with redirect_stdout(sys.stderr):
|
||||
logger.info("Codify process starting.")
|
||||
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
|
||||
|
||||
results = []
|
||||
async for result in run_code_graph_pipeline(repo_path, False):
|
||||
results.append(result)
|
||||
logger.info(result)
|
||||
if all(results):
|
||||
logger.info("Codify process finished succesfully.")
|
||||
else:
|
||||
logger.info("Codify process failed.")
|
||||
|
||||
asyncio.create_task(codify_task(repo_path))
|
||||
|
||||
log_file = get_log_file_location()
|
||||
text = (
|
||||
f"Background process launched due to MCP timeout limitations.\n"
|
||||
f"To check current codify status use the codify_status tool\n"
|
||||
f"or you can check the log file at: {log_file}"
|
||||
)
|
||||
|
||||
return [
|
||||
types.TextContent(
|
||||
type="text",
|
||||
text=text,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def search(search_query: str, search_type: str) -> list:
|
||||
"""
|
||||
|
|
@ -629,45 +469,6 @@ async def search(search_query: str, search_type: str) -> list:
|
|||
return [types.TextContent(type="text", text=search_results)]
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_developer_rules() -> list:
|
||||
"""
|
||||
Retrieve all developer rules that were generated based on previous interactions.
|
||||
|
||||
This tool queries the Cognee knowledge graph and returns a list of developer
|
||||
rules.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
None
|
||||
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
A list containing a single TextContent object with the retrieved developer rules.
|
||||
The format is plain text containing the developer rules in bulletpoints.
|
||||
|
||||
Notes
|
||||
-----
|
||||
- The specific logic for fetching rules is handled internally.
|
||||
- This tool does not accept any parameters and is intended for simple rule inspection use cases.
|
||||
"""
|
||||
|
||||
async def fetch_rules_from_cognee() -> str:
|
||||
"""Collect all developer rules from Cognee"""
|
||||
with redirect_stdout(sys.stderr):
|
||||
if cognee_client.use_api:
|
||||
logger.warning("Developer rules retrieval is not available in API mode")
|
||||
return "Developer rules retrieval is not available in API mode"
|
||||
|
||||
developer_rules = await get_existing_rules(rules_nodeset_name="coding_agent_rules")
|
||||
return developer_rules
|
||||
|
||||
rules_text = await fetch_rules_from_cognee()
|
||||
|
||||
return [types.TextContent(type="text", text=rules_text)]
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def list_data(dataset_id: str = None) -> list:
|
||||
"""
|
||||
|
|
@ -953,48 +754,6 @@ async def cognify_status():
|
|||
return [types.TextContent(type="text", text=error_msg)]
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def codify_status():
|
||||
"""
|
||||
Get the current status of the codify pipeline.
|
||||
|
||||
This function retrieves information about current and recently completed codify operations
|
||||
in the codebase dataset. It provides details on progress, success/failure status, and statistics
|
||||
about the processed code repositories.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
A list containing a single TextContent object with the status information as a string.
|
||||
The status includes information about active and completed jobs for the cognify_code_pipeline.
|
||||
|
||||
Notes
|
||||
-----
|
||||
- The function retrieves pipeline status specifically for the "cognify_code_pipeline" on the "codebase" dataset
|
||||
- Status information includes job progress, execution time, and completion status
|
||||
- The status is returned in string format for easy reading
|
||||
- This operation is not available in API mode
|
||||
"""
|
||||
with redirect_stdout(sys.stderr):
|
||||
try:
|
||||
from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
|
||||
user = await get_default_user()
|
||||
status = await cognee_client.get_pipeline_status(
|
||||
[await get_unique_dataset_id("codebase", user)], "cognify_code_pipeline"
|
||||
)
|
||||
return [types.TextContent(type="text", text=str(status))]
|
||||
except NotImplementedError:
|
||||
error_msg = "❌ Pipeline status is not available in API mode"
|
||||
logger.error(error_msg)
|
||||
return [types.TextContent(type="text", text=error_msg)]
|
||||
except Exception as e:
|
||||
error_msg = f"❌ Failed to get codify status: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
return [types.TextContent(type="text", text=error_msg)]
|
||||
|
||||
|
||||
def node_to_string(node):
|
||||
node_data = ", ".join(
|
||||
[f'{key}: "{value}"' for key, value in node.items() if key in ["id", "name"]]
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from cognee.api.v1.notebooks.routers import get_notebooks_router
|
|||
from cognee.api.v1.permissions.routers import get_permissions_router
|
||||
from cognee.api.v1.settings.routers import get_settings_router
|
||||
from cognee.api.v1.datasets.routers import get_datasets_router
|
||||
from cognee.api.v1.cognify.routers import get_code_pipeline_router, get_cognify_router
|
||||
from cognee.api.v1.cognify.routers import get_cognify_router
|
||||
from cognee.api.v1.search.routers import get_search_router
|
||||
from cognee.api.v1.ontologies.routers.get_ontology_router import get_ontology_router
|
||||
from cognee.api.v1.memify.routers import get_memify_router
|
||||
|
|
@ -278,10 +278,6 @@ app.include_router(get_responses_router(), prefix="/api/v1/responses", tags=["re
|
|||
|
||||
app.include_router(get_sync_router(), prefix="/api/v1/sync", tags=["sync"])
|
||||
|
||||
codegraph_routes = get_code_pipeline_router()
|
||||
if codegraph_routes:
|
||||
app.include_router(codegraph_routes, prefix="/api/v1/code-pipeline", tags=["code-pipeline"])
|
||||
|
||||
app.include_router(
|
||||
get_users_router(),
|
||||
prefix="/api/v1/users",
|
||||
|
|
|
|||
|
|
@ -1,119 +0,0 @@
|
|||
import os
|
||||
import pathlib
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
from cognee.shared.logging_utils import get_logger, setup_logging
|
||||
from cognee.modules.observability.get_observe import get_observe
|
||||
|
||||
from cognee.api.v1.search import SearchType, search
|
||||
from cognee.api.v1.visualize.visualize import visualize_graph
|
||||
from cognee.modules.cognify.config import get_cognify_config
|
||||
from cognee.modules.pipelines import run_tasks
|
||||
from cognee.modules.pipelines.tasks.task import Task
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
from cognee.modules.data.methods import create_dataset
|
||||
from cognee.tasks.documents import classify_documents, extract_chunks_from_documents
|
||||
from cognee.tasks.graph import extract_graph_from_data
|
||||
from cognee.tasks.ingestion import ingest_data
|
||||
from cognee.tasks.repo_processor import get_non_py_files, get_repo_file_dependencies
|
||||
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.tasks.summarization import summarize_text
|
||||
from cognee.infrastructure.llm import get_max_chunk_tokens
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
|
||||
observe = get_observe()
|
||||
|
||||
logger = get_logger("code_graph_pipeline")
|
||||
|
||||
|
||||
@observe
async def run_code_graph_pipeline(
    repo_path,
    include_docs=False,
    excluded_paths: Optional[list[str]] = None,
    supported_languages: Optional[list[str]] = None,
):
    """
    Build a code knowledge graph for the repository at ``repo_path``.

    Async generator: yields pipeline run-status objects as the underlying
    task pipelines progress, so callers can stream progress.

    WARNING: this is destructive — it prunes ALL existing cognee data and
    system metadata before running, then re-initializes storage via
    ``setup()``.

    Parameters:
        repo_path: Path to the repository to analyze.
        include_docs: When True, also ingest, chunk, graph and summarize
            non-code files first (slow; runs a full cognify pipeline over
            the documentation).
        excluded_paths: Optional path patterns skipped during dependency
            scanning.
        supported_languages: Optional restriction of languages to process.
    """
    import cognee
    from cognee.low_level import setup

    # Destructive reset: wipe prior data and metadata, then re-initialize.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    cognee_config = get_cognify_config()
    user = await get_default_user()
    detailed_extraction = True

    # Core code-analysis pipeline: scan file dependencies, then persist
    # the resulting data points to storage.
    tasks = [
        Task(
            get_repo_file_dependencies,
            detailed_extraction=detailed_extraction,
            supported_languages=supported_languages,
            excluded_paths=excluded_paths,
        ),
        # Task(summarize_code, task_config={"batch_size": 500}), # This task takes a long time to complete
        Task(add_data_points, task_config={"batch_size": 30}),
    ]

    if include_docs:
        # These tasks take a long time to complete
        non_code_tasks = [
            Task(get_non_py_files, task_config={"batch_size": 50}),
            Task(ingest_data, dataset_name="repo_docs", user=user),
            Task(classify_documents),
            Task(extract_chunks_from_documents, max_chunk_size=get_max_chunk_tokens()),
            Task(
                extract_graph_from_data,
                graph_model=KnowledgeGraph,
                task_config={"batch_size": 50},
            ),
            Task(
                summarize_text,
                summarization_model=cognee_config.summarization_model,
                task_config={"batch_size": 50},
            ),
        ]

    dataset_name = "codebase"

    # Save dataset to database
    db_engine = get_relational_engine()
    async with db_engine.get_async_session() as session:
        dataset = await create_dataset(dataset_name, user, session)

    # Documentation pipeline runs first (when requested); its status
    # updates are forwarded to the caller before the code pipeline starts.
    if include_docs:
        non_code_pipeline_run = run_tasks(
            non_code_tasks, dataset.id, repo_path, user, "cognify_pipeline"
        )
        async for run_status in non_code_pipeline_run:
            yield run_status

    # Run the code-analysis pipeline and forward its status updates.
    async for run_status in run_tasks(
        tasks, dataset.id, repo_path, user, "cognify_code_pipeline", incremental_loading=False
    ):
        yield run_status
|
||||
|
||||
|
||||
if __name__ == "__main__":

    async def main():
        # Stream status updates while the pipeline runs over the
        # repository ("REPO_PATH" is a placeholder to edit before use).
        async for run_status in run_code_graph_pipeline("REPO_PATH"):
            print(f"{run_status.pipeline_run_id}: {run_status.status}")

        # Render the resulting graph to a local HTML artifact next to
        # this script.
        file_path = os.path.join(
            pathlib.Path(__file__).parent, ".artifacts", "graph_visualization.html"
        )
        await visualize_graph(file_path)

        # Example code search over the freshly built graph.
        search_results = await search(
            query_type=SearchType.CODE,
            query_text="How is Relationship weight calculated?",
        )

        for file in search_results:
            print(file["name"])

    logger = setup_logging(name="code_graph_pipeline")
    asyncio.run(main())
|
||||
|
|
@ -1,2 +1 @@
|
|||
from .get_cognify_router import get_cognify_router
|
||||
from .get_code_pipeline_router import get_code_pipeline_router
|
||||
|
|
|
|||
|
|
@ -1,90 +0,0 @@
|
|||
import json
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import JSONResponse
|
||||
from cognee.api.DTO import InDTO
|
||||
from cognee.modules.retrieval.code_retriever import CodeRetriever
|
||||
from cognee.modules.storage.utils import JSONEncoder
|
||||
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class CodePipelineIndexPayloadDTO(InDTO):
    """Request body for the POST /index endpoint."""

    # Filesystem path of the repository to index.
    repo_path: str
    # When True, documentation (non-code) files are ingested as well.
    include_docs: bool = False
|
||||
|
||||
|
||||
class CodePipelineRetrievePayloadDTO(InDTO):
    """Request body for the POST /retrieve endpoint."""

    # Search query run against the code knowledge graph.
    query: str
    # Raw user input; the handler strips a leading "cognee " prefix from it.
    full_input: str
|
||||
|
||||
|
||||
def get_code_pipeline_router() -> APIRouter:
    """
    Build the code-pipeline API router.

    Returns None (and logs an error) when the optional codegraph
    dependencies are not installed, so the caller can skip mounting
    these routes instead of failing at import time.
    """
    try:
        import cognee.api.v1.cognify.code_graph_pipeline
    except ModuleNotFoundError:
        logger.error("codegraph dependencies not found. Skipping codegraph API routes.")
        return None

    router = APIRouter()

    @router.post("/index", response_model=None)
    async def code_pipeline_index(payload: CodePipelineIndexPayloadDTO):
        """
        Run indexation on a code repository.

        This endpoint processes a code repository to create a knowledge graph
        of the codebase structure, dependencies, and relationships.

        ## Request Parameters
        - **repo_path** (str): Path to the code repository
        - **include_docs** (bool): Whether to include documentation files (default: false)

        ## Response
        No content returned. Processing results are logged.

        ## Error Codes
        - **409 Conflict**: Error during indexation process
        """
        from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline

        try:
            # Drain the pipeline's async generator; each yielded status is
            # only logged, the client gets no streaming response.
            async for result in run_code_graph_pipeline(payload.repo_path, payload.include_docs):
                logger.info(result)
        except Exception as error:
            return JSONResponse(status_code=409, content={"error": str(error)})

    @router.post("/retrieve", response_model=list[dict])
    async def code_pipeline_retrieve(payload: CodePipelineRetrievePayloadDTO):
        """
        Retrieve context from the code knowledge graph.

        This endpoint searches the indexed code repository to find relevant
        context based on the provided query.

        ## Request Parameters
        - **query** (str): Search query for code context
        - **full_input** (str): Full input text for processing

        ## Response
        Returns a list of relevant code files and context as JSON.

        ## Error Codes
        - **409 Conflict**: Error during retrieval process
        """
        try:
            # NOTE(review): replace() strips every "cognee " occurrence in the
            # string, not only the leading prefix — confirm that is intended.
            query = (
                payload.full_input.replace("cognee ", "")
                if payload.full_input.startswith("cognee ")
                else payload.full_input
            )

            retriever = CodeRetriever()
            retrieved_files = await retriever.get_context(query)

            # Fix: the endpoint declares response_model=list[dict], but the
            # previous code returned json.dumps(...) — a str — which gets
            # double-encoded / rejected by response-model validation.
            # Serialize through JSONEncoder to handle non-JSON-native types,
            # then parse back so FastAPI receives an actual list.
            return json.loads(json.dumps(retrieved_files, cls=JSONEncoder))
        except Exception as error:
            return JSONResponse(status_code=409, content={"error": str(error)})

    return router
|
||||
360
cognee/api/v1/ui/node_setup.py
Normal file
360
cognee/api/v1/ui/node_setup.py
Normal file
|
|
@ -0,0 +1,360 @@
|
|||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def get_nvm_dir() -> Path:
    """Resolve the directory where nvm lives.

    Mirrors the lookup nvm's own installer performs: a non-empty
    ``XDG_CONFIG_HOME`` environment variable takes precedence, and the
    traditional ``~/.nvm`` location is used otherwise.
    """
    config_root = os.environ.get("XDG_CONFIG_HOME")
    if config_root:
        return Path(config_root) / "nvm"
    return Path.home() / ".nvm"
|
||||
|
||||
|
||||
def get_nvm_sh_path() -> Path:
    """Return the expected location of the ``nvm.sh`` bootstrap script."""
    nvm_root = get_nvm_dir()
    return nvm_root / "nvm.sh"
|
||||
|
||||
|
||||
def check_nvm_installed() -> bool:
    """Return True when nvm (Node Version Manager) can be invoked.

    On Windows, nvm-windows ships a real executable, so a direct probe
    works. On Unix-like systems ``nvm`` is a shell function that only
    exists after ``nvm.sh`` is sourced, so the probe runs inside a bash
    child process. Any failure (missing script, non-zero exit, exception)
    yields False.
    """
    try:
        if platform.system() == "Windows":
            probe = subprocess.run(
                ["nvm", "version"],
                capture_output=True,
                text=True,
                timeout=10,
                shell=True,
            )
            return probe.returncode == 0

        # Unix: the nvm shell function only exists after nvm.sh is sourced,
        # so first make sure the bootstrap script is present at all.
        script = get_nvm_sh_path()
        if not script.exists():
            logger.debug(f"nvm.sh not found at {script}")
            return False

        probe = subprocess.run(
            ["bash", "-c", f"source {script} && nvm --version"],
            capture_output=True,
            text=True,
            timeout=10,
        )
        if probe.returncode != 0:
            # Surface stderr to help diagnose broken shell configurations.
            if probe.stderr:
                logger.debug(f"nvm check failed: {probe.stderr.strip()}")
            return False

        return True
    except Exception as e:
        logger.debug(f"Exception checking nvm: {str(e)}")
        return False
|
||||
|
||||
|
||||
def install_nvm() -> bool:
    """
    Install nvm (Node Version Manager) on Unix-like systems.

    Downloads the pinned v0.40.3 installer script over HTTPS, writes it to
    a temporary file, executes it with bash, and verifies the nvm directory
    was created. Returns True on success, False on any failure. Windows is
    not supported (nvm-windows must be installed manually).
    """
    if platform.system() == "Windows":
        logger.error("nvm installation on Windows requires nvm-windows.")
        logger.error(
            "Please install nvm-windows manually from: https://github.com/coreybutler/nvm-windows"
        )
        return False

    logger.info("Installing nvm (Node Version Manager)...")

    try:
        # Download and install nvm (version pinned so the installer is reproducible)
        nvm_install_script = "https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.3/install.sh"
        logger.info(f"Downloading nvm installer from {nvm_install_script}...")

        response = requests.get(nvm_install_script, timeout=60)
        response.raise_for_status()

        # Create a temporary script file (delete=False so bash can open it
        # after this context manager closes the handle)
        with tempfile.NamedTemporaryFile(mode="w", suffix=".sh", delete=False) as f:
            f.write(response.text)
            install_script_path = f.name

        try:
            # Make the script executable and run it
            os.chmod(install_script_path, 0o755)
            result = subprocess.run(
                ["bash", install_script_path],
                capture_output=True,
                text=True,
                timeout=120,
            )

            if result.returncode == 0:
                logger.info("✓ nvm installed successfully")
                # Verify the installer actually produced the nvm directory;
                # a zero exit code alone is not trusted.
                nvm_dir = get_nvm_dir()
                if nvm_dir.exists():
                    return True
                else:
                    logger.warning(
                        f"nvm installation completed but nvm directory not found at {nvm_dir}"
                    )
                    return False
            else:
                logger.error(f"nvm installation failed: {result.stderr}")
                return False
        finally:
            # Clean up temporary script (best-effort; never mask the result)
            try:
                os.unlink(install_script_path)
            except Exception:
                pass

    except requests.exceptions.RequestException as e:
        logger.error(f"Failed to download nvm installer: {str(e)}")
        return False
    except Exception as e:
        logger.error(f"Failed to install nvm: {str(e)}")
        return False
|
||||
|
||||
|
||||
def _node_version_sort_key(version_dir: Path) -> tuple:
    """Sort key ordering nvm version directories (e.g. 'v10.2.1') numerically.

    Non-numeric components sort before any numeric ones so unexpected
    directory names never win over real versions.
    """
    parts = version_dir.name.lstrip("v").split(".")
    return tuple(int(part) if part.isdigit() else -1 for part in parts)


def install_node_with_nvm() -> bool:
    """
    Install the latest Node.js version using nvm.

    Runs ``nvm install node`` in a bash shell with nvm.sh sourced, marks
    the installed version as the default alias, and prepends its bin
    directory to PATH for the current process so node/npm are usable by
    subsequent subprocess calls. Returns True if installation succeeds,
    False otherwise. Windows is not supported.
    """
    if platform.system() == "Windows":
        logger.error("Node.js installation via nvm on Windows requires nvm-windows.")
        logger.error("Please install Node.js manually from: https://nodejs.org/")
        return False

    logger.info("Installing latest Node.js version using nvm...")

    try:
        # Source nvm and install latest Node.js
        nvm_path = get_nvm_sh_path()
        if not nvm_path.exists():
            logger.error(f"nvm.sh not found at {nvm_path}. nvm may not be properly installed.")
            return False

        nvm_source_cmd = f"source {nvm_path}"
        install_cmd = f"{nvm_source_cmd} && nvm install node"

        result = subprocess.run(
            ["bash", "-c", install_cmd],
            capture_output=True,
            text=True,
            timeout=300,  # 5 minutes timeout for Node.js installation
        )

        if result.returncode == 0:
            logger.info("✓ Node.js installed successfully via nvm")

            # Set as default version (best-effort; failure is not fatal)
            use_cmd = f"{nvm_source_cmd} && nvm alias default node"
            subprocess.run(
                ["bash", "-c", use_cmd],
                capture_output=True,
                text=True,
                timeout=30,
            )

            # Add nvm to PATH for current session
            # This ensures node/npm are available in subsequent commands
            nvm_dir = get_nvm_dir()
            if nvm_dir.exists():
                nvm_bin = nvm_dir / "versions" / "node"
                if nvm_bin.exists():
                    # Fix: sort numerically, not lexicographically — a plain
                    # sorted(..., reverse=True) ranks 'v9.x' above 'v10.x'
                    # and would prepend an old version's bin dir to PATH.
                    versions = sorted(
                        nvm_bin.iterdir(), key=_node_version_sort_key, reverse=True
                    )
                    if versions:
                        latest_node_bin = versions[0] / "bin"
                        if latest_node_bin.exists():
                            current_path = os.environ.get("PATH", "")
                            os.environ["PATH"] = f"{latest_node_bin}:{current_path}"

            return True
        else:
            logger.error(f"Failed to install Node.js: {result.stderr}")
            return False

    except subprocess.TimeoutExpired:
        logger.error("Timeout installing Node.js (this can take several minutes)")
        return False
    except Exception as e:
        logger.error(f"Error installing Node.js: {str(e)}")
        return False
|
||||
|
||||
|
||||
def check_node_npm() -> tuple[bool, str]:  # (is_available, error_message)
    """
    Check if Node.js and npm are available.
    If not available, attempts to install nvm and Node.js automatically.

    Returns a tuple ``(is_available, message)``: on success the message
    describes the detected versions, otherwise it explains the failure and
    what the user should do.
    """

    try:
        # Check Node.js - try direct command first, then with nvm if needed
        result = subprocess.run(["node", "--version"], capture_output=True, text=True, timeout=10)
        if result.returncode != 0:
            # If direct command fails, try with nvm sourced (in case nvm is installed but not in PATH)
            nvm_path = get_nvm_sh_path()
            if nvm_path.exists():
                result = subprocess.run(
                    ["bash", "-c", f"source {nvm_path} && node --version"],
                    capture_output=True,
                    text=True,
                    timeout=10,
                )
                if result.returncode != 0 and result.stderr:
                    logger.debug(f"Failed to source nvm or run node: {result.stderr.strip()}")
        if result.returncode != 0:
            # Node.js is not installed, try to install it
            logger.info("Node.js is not installed. Attempting to install automatically...")

            # Check if nvm is installed
            if not check_nvm_installed():
                logger.info("nvm is not installed. Installing nvm first...")
                if not install_nvm():
                    return (
                        False,
                        "Failed to install nvm. Please install Node.js manually from https://nodejs.org/",
                    )

            # Install Node.js using nvm
            if not install_node_with_nvm():
                return (
                    False,
                    "Failed to install Node.js. Please install Node.js manually from https://nodejs.org/",
                )

            # Verify installation after automatic setup
            # Try with nvm sourced first
            nvm_path = get_nvm_sh_path()
            if nvm_path.exists():
                result = subprocess.run(
                    ["bash", "-c", f"source {nvm_path} && node --version"],
                    capture_output=True,
                    text=True,
                    timeout=10,
                )
                if result.returncode != 0 and result.stderr:
                    logger.debug(
                        f"Failed to verify node after installation: {result.stderr.strip()}"
                    )
            else:
                result = subprocess.run(
                    ["node", "--version"], capture_output=True, text=True, timeout=10
                )
            if result.returncode != 0:
                nvm_path = get_nvm_sh_path()
                return (
                    False,
                    f"Node.js installation completed but node command is not available. Please restart your terminal or source {nvm_path}",
                )

        node_version = result.stdout.strip()
        logger.debug(f"Found Node.js version: {node_version}")

        # Check npm - handle Windows PowerShell scripts
        if platform.system() == "Windows":
            # On Windows, npm might be a PowerShell script, so we need to use shell=True
            result = subprocess.run(
                ["npm", "--version"], capture_output=True, text=True, timeout=10, shell=True
            )
        else:
            # On Unix-like systems, if we just installed via nvm, we may need to source nvm
            # Try direct command first
            result = subprocess.run(
                ["npm", "--version"], capture_output=True, text=True, timeout=10
            )
            if result.returncode != 0:
                # Try with nvm sourced
                nvm_path = get_nvm_sh_path()
                if nvm_path.exists():
                    result = subprocess.run(
                        ["bash", "-c", f"source {nvm_path} && npm --version"],
                        capture_output=True,
                        text=True,
                        timeout=10,
                    )
                    if result.returncode != 0 and result.stderr:
                        logger.debug(f"Failed to source nvm or run npm: {result.stderr.strip()}")

        if result.returncode != 0:
            return False, "npm is not installed or not in PATH"

        npm_version = result.stdout.strip()
        logger.debug(f"Found npm version: {npm_version}")

        return True, f"Node.js {node_version}, npm {npm_version}"

    except subprocess.TimeoutExpired:
        return False, "Timeout checking Node.js/npm installation"
    except FileNotFoundError:
        # The node binary is missing entirely (not merely failing), which
        # raises FileNotFoundError from subprocess.run on the direct probe.
        # Node.js is not installed, try to install it
        logger.info("Node.js is not found. Attempting to install automatically...")

        # Check if nvm is installed
        if not check_nvm_installed():
            logger.info("nvm is not installed. Installing nvm first...")
            if not install_nvm():
                return (
                    False,
                    "Failed to install nvm. Please install Node.js manually from https://nodejs.org/",
                )

        # Install Node.js using nvm
        if not install_node_with_nvm():
            return (
                False,
                "Failed to install Node.js. Please install Node.js manually from https://nodejs.org/",
            )

        # Retry checking Node.js after installation
        try:
            result = subprocess.run(
                ["node", "--version"], capture_output=True, text=True, timeout=10
            )
            if result.returncode == 0:
                node_version = result.stdout.strip()
                # Check npm
                nvm_path = get_nvm_sh_path()
                if nvm_path.exists():
                    result = subprocess.run(
                        ["bash", "-c", f"source {nvm_path} && npm --version"],
                        capture_output=True,
                        text=True,
                        timeout=10,
                    )
                    if result.returncode == 0:
                        npm_version = result.stdout.strip()
                        return True, f"Node.js {node_version}, npm {npm_version}"
                    elif result.stderr:
                        logger.debug(f"Failed to source nvm or run npm: {result.stderr.strip()}")
        except Exception as e:
            logger.debug(f"Exception retrying node/npm check: {str(e)}")

        return False, "Node.js/npm not found. Please install Node.js from https://nodejs.org/"
    except Exception as e:
        return False, f"Error checking Node.js/npm: {str(e)}"
|
||||
50
cognee/api/v1/ui/npm_utils.py
Normal file
50
cognee/api/v1/ui/npm_utils.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
import platform
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from .node_setup import get_nvm_sh_path
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def run_npm_command(cmd: List[str], cwd: Path, timeout: int = 300) -> subprocess.CompletedProcess:
    """Execute an npm command and return its CompletedProcess.

    On Windows the command runs through the shell, because npm is often a
    script shim there. On Unix-like systems the command is tried directly
    first; if it fails and an nvm installation is present, it is retried
    inside a bash shell with ``nvm.sh`` sourced so nvm-managed binaries
    are on PATH.
    """
    run_opts = dict(cwd=cwd, capture_output=True, text=True, timeout=timeout)

    if platform.system() == "Windows":
        # On Windows, use shell=True for npm commands
        return subprocess.run(cmd, shell=True, **run_opts)

    outcome = subprocess.run(cmd, **run_opts)
    if outcome.returncode == 0:
        return outcome

    # Direct invocation failed; retry with nvm sourced when available.
    script = get_nvm_sh_path()
    if script.exists():
        sourced_cmd = f"source {script} && {' '.join(cmd)}"
        outcome = subprocess.run(["bash", "-c", sourced_cmd], **run_opts)
        if outcome.returncode != 0 and outcome.stderr:
            logger.debug(f"npm command failed with nvm: {outcome.stderr.strip()}")
    return outcome
|
||||
|
|
@ -15,6 +15,8 @@ import shutil
|
|||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.version import get_cognee_version
|
||||
from .node_setup import check_node_npm, get_nvm_dir, get_nvm_sh_path
|
||||
from .npm_utils import run_npm_command
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
|
@ -285,48 +287,6 @@ def find_frontend_path() -> Optional[Path]:
|
|||
return None
|
||||
|
||||
|
||||
def check_node_npm() -> tuple[bool, str]:
|
||||
"""
|
||||
Check if Node.js and npm are available.
|
||||
Returns (is_available, error_message)
|
||||
"""
|
||||
|
||||
try:
|
||||
# Check Node.js
|
||||
result = subprocess.run(["node", "--version"], capture_output=True, text=True, timeout=10)
|
||||
if result.returncode != 0:
|
||||
return False, "Node.js is not installed or not in PATH"
|
||||
|
||||
node_version = result.stdout.strip()
|
||||
logger.debug(f"Found Node.js version: {node_version}")
|
||||
|
||||
# Check npm - handle Windows PowerShell scripts
|
||||
if platform.system() == "Windows":
|
||||
# On Windows, npm might be a PowerShell script, so we need to use shell=True
|
||||
result = subprocess.run(
|
||||
["npm", "--version"], capture_output=True, text=True, timeout=10, shell=True
|
||||
)
|
||||
else:
|
||||
result = subprocess.run(
|
||||
["npm", "--version"], capture_output=True, text=True, timeout=10
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
return False, "npm is not installed or not in PATH"
|
||||
|
||||
npm_version = result.stdout.strip()
|
||||
logger.debug(f"Found npm version: {npm_version}")
|
||||
|
||||
return True, f"Node.js {node_version}, npm {npm_version}"
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return False, "Timeout checking Node.js/npm installation"
|
||||
except FileNotFoundError:
|
||||
return False, "Node.js/npm not found. Please install Node.js from https://nodejs.org/"
|
||||
except Exception as e:
|
||||
return False, f"Error checking Node.js/npm: {str(e)}"
|
||||
|
||||
|
||||
def install_frontend_dependencies(frontend_path: Path) -> bool:
|
||||
"""
|
||||
Install frontend dependencies if node_modules doesn't exist.
|
||||
|
|
@ -341,24 +301,7 @@ def install_frontend_dependencies(frontend_path: Path) -> bool:
|
|||
logger.info("Installing frontend dependencies (this may take a few minutes)...")
|
||||
|
||||
try:
|
||||
# Use shell=True on Windows for npm commands
|
||||
if platform.system() == "Windows":
|
||||
result = subprocess.run(
|
||||
["npm", "install"],
|
||||
cwd=frontend_path,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300, # 5 minutes timeout
|
||||
shell=True,
|
||||
)
|
||||
else:
|
||||
result = subprocess.run(
|
||||
["npm", "install"],
|
||||
cwd=frontend_path,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300, # 5 minutes timeout
|
||||
)
|
||||
result = run_npm_command(["npm", "install"], frontend_path, timeout=300)
|
||||
|
||||
if result.returncode == 0:
|
||||
logger.info("Frontend dependencies installed successfully")
|
||||
|
|
@ -642,6 +585,21 @@ def start_ui(
|
|||
env["HOST"] = "localhost"
|
||||
env["PORT"] = str(port)
|
||||
|
||||
# If nvm is installed, ensure it's available in the environment
|
||||
nvm_path = get_nvm_sh_path()
|
||||
if platform.system() != "Windows" and nvm_path.exists():
|
||||
# Add nvm to PATH for the subprocess
|
||||
nvm_dir = get_nvm_dir()
|
||||
# Find the latest Node.js version installed via nvm
|
||||
nvm_versions = nvm_dir / "versions" / "node"
|
||||
if nvm_versions.exists():
|
||||
versions = sorted(nvm_versions.iterdir(), reverse=True)
|
||||
if versions:
|
||||
latest_node_bin = versions[0] / "bin"
|
||||
if latest_node_bin.exists():
|
||||
current_path = env.get("PATH", "")
|
||||
env["PATH"] = f"{latest_node_bin}:{current_path}"
|
||||
|
||||
# Start the development server
|
||||
logger.info(f"Starting frontend server at http://localhost:{port}")
|
||||
logger.info("This may take a moment to compile and start...")
|
||||
|
|
@ -659,14 +617,26 @@ def start_ui(
|
|||
shell=True,
|
||||
)
|
||||
else:
|
||||
process = subprocess.Popen(
|
||||
["npm", "run", "dev"],
|
||||
cwd=frontend_path,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
|
||||
)
|
||||
# On Unix-like systems, use bash with nvm sourced if available
|
||||
if nvm_path.exists():
|
||||
# Use bash to source nvm and run npm
|
||||
process = subprocess.Popen(
|
||||
["bash", "-c", f"source {nvm_path} && npm run dev"],
|
||||
cwd=frontend_path,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
|
||||
)
|
||||
else:
|
||||
process = subprocess.Popen(
|
||||
["npm", "run", "dev"],
|
||||
cwd=frontend_path,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
|
||||
)
|
||||
|
||||
# Start threads to stream frontend output with prefix
|
||||
_stream_process_output(process, "stdout", "[FRONTEND]", "\033[33m") # Yellow
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from cognee.infrastructure.databases.exceptions import EmbeddingException
|
|||
from cognee.infrastructure.llm.tokenizer.TikToken import (
|
||||
TikTokenTokenizer,
|
||||
)
|
||||
from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager
|
||||
|
||||
litellm.set_verbose = False
|
||||
logger = get_logger("FastembedEmbeddingEngine")
|
||||
|
|
@ -68,7 +69,7 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
|
|||
|
||||
@retry(
|
||||
stop=stop_after_delay(128),
|
||||
wait=wait_exponential_jitter(2, 128),
|
||||
wait=wait_exponential_jitter(8, 128),
|
||||
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
|
|
@ -96,11 +97,12 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
|
|||
if self.mock:
|
||||
return [[0.0] * self.dimensions for _ in text]
|
||||
else:
|
||||
embeddings = self.embedding_model.embed(
|
||||
text,
|
||||
batch_size=len(text),
|
||||
parallel=None,
|
||||
)
|
||||
async with embedding_rate_limiter_context_manager():
|
||||
embeddings = self.embedding_model.embed(
|
||||
text,
|
||||
batch_size=len(text),
|
||||
parallel=None,
|
||||
)
|
||||
|
||||
return list(embeddings)
|
||||
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ from cognee.infrastructure.llm.tokenizer.Mistral import (
|
|||
from cognee.infrastructure.llm.tokenizer.TikToken import (
|
||||
TikTokenTokenizer,
|
||||
)
|
||||
from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager
|
||||
|
||||
litellm.set_verbose = False
|
||||
logger = get_logger("LiteLLMEmbeddingEngine")
|
||||
|
|
@ -109,13 +110,14 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
response = {"data": [{"embedding": [0.0] * self.dimensions} for _ in text]}
|
||||
return [data["embedding"] for data in response["data"]]
|
||||
else:
|
||||
response = await litellm.aembedding(
|
||||
model=self.model,
|
||||
input=text,
|
||||
api_key=self.api_key,
|
||||
api_base=self.endpoint,
|
||||
api_version=self.api_version,
|
||||
)
|
||||
async with embedding_rate_limiter_context_manager():
|
||||
response = await litellm.aembedding(
|
||||
model=self.model,
|
||||
input=text,
|
||||
api_key=self.api_key,
|
||||
api_base=self.endpoint,
|
||||
api_version=self.api_version,
|
||||
)
|
||||
|
||||
return [data["embedding"] for data in response.data]
|
||||
|
||||
|
|
|
|||
|
|
@ -18,10 +18,7 @@ from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import Em
|
|||
from cognee.infrastructure.llm.tokenizer.HuggingFace import (
|
||||
HuggingFaceTokenizer,
|
||||
)
|
||||
from cognee.infrastructure.databases.vector.embeddings.embedding_rate_limiter import (
|
||||
embedding_rate_limit_async,
|
||||
embedding_sleep_and_retry_async,
|
||||
)
|
||||
from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager
|
||||
from cognee.shared.utils import create_secure_ssl_context
|
||||
|
||||
logger = get_logger("OllamaEmbeddingEngine")
|
||||
|
|
@ -101,7 +98,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
|||
|
||||
@retry(
|
||||
stop=stop_after_delay(128),
|
||||
wait=wait_exponential_jitter(2, 128),
|
||||
wait=wait_exponential_jitter(8, 128),
|
||||
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
|
|
@ -120,14 +117,15 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
|||
ssl_context = create_secure_ssl_context()
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||
async with aiohttp.ClientSession(connector=connector) as session:
|
||||
async with session.post(
|
||||
self.endpoint, json=payload, headers=headers, timeout=60.0
|
||||
) as response:
|
||||
data = await response.json()
|
||||
if "embeddings" in data:
|
||||
return data["embeddings"][0]
|
||||
else:
|
||||
return data["data"][0]["embedding"]
|
||||
async with embedding_rate_limiter_context_manager():
|
||||
async with session.post(
|
||||
self.endpoint, json=payload, headers=headers, timeout=60.0
|
||||
) as response:
|
||||
data = await response.json()
|
||||
if "embeddings" in data:
|
||||
return data["embeddings"][0]
|
||||
else:
|
||||
return data["data"][0]["embedding"]
|
||||
|
||||
def get_vector_size(self) -> int:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,544 +0,0 @@
|
|||
import threading
|
||||
import logging
|
||||
import functools
|
||||
import os
|
||||
import time
|
||||
import asyncio
|
||||
import random
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.llm.config import get_llm_config
|
||||
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
# Common error patterns that indicate rate limiting.
# Substrings intended to be matched against provider error messages;
# presumably the matching happens in the retry helpers defined later in
# this module — TODO(review): confirm at the call sites.
RATE_LIMIT_ERROR_PATTERNS = [
    "rate limit",
    "rate_limit",
    "ratelimit",
    "too many requests",
    "retry after",
    "capacity",
    "quota",
    "limit exceeded",
    "tps limit exceeded",
    "request limit exceeded",
    "maximum requests",
    "exceeded your current quota",
    "throttled",
    "throttling",
]

# Default retry settings
DEFAULT_MAX_RETRIES = 5  # maximum attempts before giving up
DEFAULT_INITIAL_BACKOFF = 1.0  # seconds
DEFAULT_BACKOFF_FACTOR = 2.0  # exponential backoff multiplier
DEFAULT_JITTER = 0.1  # 10% jitter to avoid thundering herd
|
||||
|
||||
|
||||
class EmbeddingRateLimiter:
|
||||
"""
|
||||
Rate limiter for embedding API calls.
|
||||
|
||||
This class implements a singleton pattern to ensure that rate limiting
|
||||
is consistent across all embedding requests. It uses the limits
|
||||
library with a moving window strategy to control request rates.
|
||||
|
||||
The rate limiter uses the same configuration as the LLM API rate limiter
|
||||
but uses a separate key to track embedding API calls independently.
|
||||
|
||||
Public Methods:
|
||||
- get_instance
|
||||
- reset_instance
|
||||
- hit_limit
|
||||
- wait_if_needed
|
||||
- async_wait_if_needed
|
||||
|
||||
Instance Variables:
|
||||
- enabled
|
||||
- requests_limit
|
||||
- interval_seconds
|
||||
- request_times
|
||||
- lock
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
lock = threading.Lock()
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
"""
|
||||
Retrieve the singleton instance of the EmbeddingRateLimiter.
|
||||
|
||||
This method ensures that only one instance of the class exists and
|
||||
is thread-safe. It lazily initializes the instance if it doesn't
|
||||
already exist.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
The singleton instance of the EmbeddingRateLimiter class.
|
||||
"""
|
||||
if cls._instance is None:
|
||||
with cls.lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def reset_instance(cls):
|
||||
"""
|
||||
Reset the singleton instance of the EmbeddingRateLimiter.
|
||||
|
||||
This method is thread-safe and sets the instance to None, allowing
|
||||
for a new instance to be created when requested again.
|
||||
"""
|
||||
with cls.lock:
|
||||
cls._instance = None
|
||||
|
||||
def __init__(self):
|
||||
config = get_llm_config()
|
||||
self.enabled = config.embedding_rate_limit_enabled
|
||||
self.requests_limit = config.embedding_rate_limit_requests
|
||||
self.interval_seconds = config.embedding_rate_limit_interval
|
||||
self.request_times = []
|
||||
self.lock = threading.Lock()
|
||||
|
||||
logging.info(
|
||||
f"EmbeddingRateLimiter initialized: enabled={self.enabled}, "
|
||||
f"requests_limit={self.requests_limit}, interval_seconds={self.interval_seconds}"
|
||||
)
|
||||
|
||||
def hit_limit(self) -> bool:
|
||||
"""
|
||||
Check if the current request would exceed the rate limit.
|
||||
|
||||
This method checks if the rate limiter is enabled and evaluates
|
||||
the number of requests made in the elapsed interval.
|
||||
|
||||
Returns:
|
||||
- bool: True if the rate limit would be exceeded, False otherwise.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- bool: True if the rate limit would be exceeded, otherwise False.
|
||||
"""
|
||||
if not self.enabled:
|
||||
return False
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
with self.lock:
|
||||
# Remove expired request times
|
||||
cutoff_time = current_time - self.interval_seconds
|
||||
self.request_times = [t for t in self.request_times if t > cutoff_time]
|
||||
|
||||
# Check if adding a new request would exceed the limit
|
||||
if len(self.request_times) >= self.requests_limit:
|
||||
logger.info(
|
||||
f"Rate limit hit: {len(self.request_times)} requests in the last {self.interval_seconds} seconds"
|
||||
)
|
||||
return True
|
||||
|
||||
# Otherwise, we're under the limit
|
||||
return False
|
||||
|
||||
def wait_if_needed(self) -> float:
|
||||
"""
|
||||
Block until a request can be made without exceeding the rate limit.
|
||||
|
||||
This method will wait if the current request would exceed the
|
||||
rate limit and returns the time waited in seconds.
|
||||
|
||||
Returns:
|
||||
- float: Time waited in seconds before a request is allowed.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- float: Time waited in seconds before proceeding.
|
||||
"""
|
||||
if not self.enabled:
|
||||
return 0
|
||||
|
||||
wait_time = 0
|
||||
start_time = time.time()
|
||||
|
||||
while self.hit_limit():
|
||||
time.sleep(0.5) # Poll every 0.5 seconds
|
||||
wait_time = time.time() - start_time
|
||||
|
||||
# Record this request
|
||||
with self.lock:
|
||||
self.request_times.append(time.time())
|
||||
|
||||
return wait_time
|
||||
|
||||
async def async_wait_if_needed(self) -> float:
|
||||
"""
|
||||
Asynchronously wait until a request can be made without exceeding the rate limit.
|
||||
|
||||
This method will wait if the current request would exceed the
|
||||
rate limit and returns the time waited in seconds.
|
||||
|
||||
Returns:
|
||||
- float: Time waited in seconds before a request is allowed.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- float: Time waited in seconds before proceeding.
|
||||
"""
|
||||
if not self.enabled:
|
||||
return 0
|
||||
|
||||
wait_time = 0
|
||||
start_time = time.time()
|
||||
|
||||
while self.hit_limit():
|
||||
await asyncio.sleep(0.5) # Poll every 0.5 seconds
|
||||
wait_time = time.time() - start_time
|
||||
|
||||
# Record this request
|
||||
with self.lock:
|
||||
self.request_times.append(time.time())
|
||||
|
||||
return wait_time
|
||||
|
||||
|
||||
def embedding_rate_limit_sync(func):
|
||||
"""
|
||||
Apply rate limiting to a synchronous embedding function.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- func: Function to decorate with rate limiting logic.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
Returns the decorated function that applies rate limiting.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
"""
|
||||
Wrap the given function with rate limiting logic to control the embedding API usage.
|
||||
|
||||
Checks if the rate limit has been exceeded before allowing the function to execute. If
|
||||
the limit is hit, it logs a warning and raises an EmbeddingException. Otherwise, it
|
||||
updates the request count and proceeds to call the original function.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- *args: Variable length argument list for the wrapped function.
|
||||
- **kwargs: Keyword arguments for the wrapped function.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
Returns the result of the wrapped function if rate limiting conditions are met.
|
||||
"""
|
||||
limiter = EmbeddingRateLimiter.get_instance()
|
||||
|
||||
# Check if rate limiting is enabled and if we're at the limit
|
||||
if limiter.hit_limit():
|
||||
error_msg = "Embedding API rate limit exceeded"
|
||||
logger.warning(error_msg)
|
||||
|
||||
# Create a custom embedding rate limit exception
|
||||
from cognee.infrastructure.databases.exceptions import EmbeddingException
|
||||
|
||||
raise EmbeddingException(error_msg)
|
||||
|
||||
# Add this request to the counter and proceed
|
||||
limiter.wait_if_needed()
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def embedding_rate_limit_async(func):
|
||||
"""
|
||||
Decorator that applies rate limiting to an asynchronous embedding function.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- func: Async function to decorate.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
Returns the decorated async function that applies rate limiting.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
"""
|
||||
Handle function calls with embedding rate limiting.
|
||||
|
||||
This asynchronous wrapper checks if the embedding API rate limit is exceeded before
|
||||
allowing the function to execute. If the limit is exceeded, it logs a warning and raises
|
||||
an EmbeddingException. If not, it waits as necessary and proceeds with the function
|
||||
call.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- *args: Positional arguments passed to the wrapped function.
|
||||
- **kwargs: Keyword arguments passed to the wrapped function.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
Returns the result of the wrapped function after handling rate limiting.
|
||||
"""
|
||||
limiter = EmbeddingRateLimiter.get_instance()
|
||||
|
||||
# Check if rate limiting is enabled and if we're at the limit
|
||||
if limiter.hit_limit():
|
||||
error_msg = "Embedding API rate limit exceeded"
|
||||
logger.warning(error_msg)
|
||||
|
||||
# Create a custom embedding rate limit exception
|
||||
from cognee.infrastructure.databases.exceptions import EmbeddingException
|
||||
|
||||
raise EmbeddingException(error_msg)
|
||||
|
||||
# Add this request to the counter and proceed
|
||||
await limiter.async_wait_if_needed()
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def embedding_sleep_and_retry_sync(max_retries=5, base_backoff=1.0, jitter=0.5):
|
||||
"""
|
||||
Add retry with exponential backoff for synchronous embedding functions.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- max_retries: Maximum number of retries before giving up. (default 5)
|
||||
- base_backoff: Base backoff time in seconds for retry intervals. (default 1.0)
|
||||
- jitter: Jitter factor to randomize the backoff time to avoid collision. (default
|
||||
0.5)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
A decorator that retries the wrapped function on rate limit errors, applying
|
||||
exponential backoff with jitter.
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
"""
|
||||
Wraps a function to apply retry logic on rate limit errors.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- func: The function to be wrapped with retry logic.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
Returns the wrapped function with retry logic applied.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
"""
|
||||
Retry the execution of a function with backoff on failure due to rate limit errors.
|
||||
|
||||
This wrapper function will call the specified function and if it raises an exception, it
|
||||
will handle retries according to defined conditions. It will check the environment for a
|
||||
DISABLE_RETRIES flag to determine whether to retry or propagate errors immediately
|
||||
during tests. If the error is identified as a rate limit error, it will apply an
|
||||
exponential backoff strategy with jitter before retrying, up to a maximum number of
|
||||
retries. If the retries are exhausted, it raises the last encountered error.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- *args: Positional arguments passed to the wrapped function.
|
||||
- **kwargs: Keyword arguments passed to the wrapped function.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
Returns the result of the wrapped function if successful; otherwise, raises the last
|
||||
error encountered after maximum retries are exhausted.
|
||||
"""
|
||||
# If DISABLE_RETRIES is set, don't retry for testing purposes
|
||||
disable_retries = os.environ.get("DISABLE_RETRIES", "false").lower() in (
|
||||
"true",
|
||||
"1",
|
||||
"yes",
|
||||
)
|
||||
|
||||
retries = 0
|
||||
last_error = None
|
||||
|
||||
while retries <= max_retries:
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
# Check if this is a rate limit error
|
||||
error_str = str(e).lower()
|
||||
error_type = type(e).__name__
|
||||
is_rate_limit = any(
|
||||
pattern in error_str.lower() for pattern in RATE_LIMIT_ERROR_PATTERNS
|
||||
)
|
||||
|
||||
if disable_retries:
|
||||
# For testing, propagate the exception immediately
|
||||
raise
|
||||
|
||||
if is_rate_limit and retries < max_retries:
|
||||
# Calculate backoff with jitter
|
||||
backoff = (
|
||||
base_backoff * (2**retries) * (1 + random.uniform(-jitter, jitter))
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
f"Embedding rate limit hit, retrying in {backoff:.2f}s "
|
||||
f"(attempt {retries + 1}/{max_retries}): "
|
||||
f"({error_str!r}, {error_type!r})"
|
||||
)
|
||||
|
||||
time.sleep(backoff)
|
||||
retries += 1
|
||||
last_error = e
|
||||
else:
|
||||
# Not a rate limit error or max retries reached, raise
|
||||
raise
|
||||
|
||||
# If we exit the loop due to max retries, raise the last error
|
||||
if last_error:
|
||||
raise last_error
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def embedding_sleep_and_retry_async(max_retries=5, base_backoff=1.0, jitter=0.5):
|
||||
"""
|
||||
Add retry logic with exponential backoff for asynchronous embedding functions.
|
||||
|
||||
This decorator retries the wrapped asynchronous function upon encountering rate limit
|
||||
errors, utilizing exponential backoff with optional jitter to space out retry attempts.
|
||||
It allows for a maximum number of retries before giving up and raising the last error
|
||||
encountered.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- max_retries: Maximum number of retries allowed before giving up. (default 5)
|
||||
- base_backoff: Base amount of time in seconds to wait before retrying after a rate
|
||||
limit error. (default 1.0)
|
||||
- jitter: Amount of randomness to add to the backoff duration to help mitigate burst
|
||||
issues on retries. (default 0.5)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
Returns a decorated asynchronous function that implements the retry logic on rate
|
||||
limit errors.
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
"""
|
||||
Handle retries for an async function with exponential backoff and jitter.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- func: An asynchronous function to be wrapped with retry logic.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
Returns the wrapper function that manages the retry behavior for the wrapped async
|
||||
function.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
"""
|
||||
Handle retries for an async function with exponential backoff and jitter.
|
||||
|
||||
If the environment variable DISABLE_RETRIES is set to true, 1, or yes, the function will
|
||||
not retry on errors.
|
||||
It attempts to call the wrapped function until it succeeds or the maximum number of
|
||||
retries is reached. If an exception occurs, it checks if it's a rate limit error to
|
||||
determine if a retry is needed.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- *args: Positional arguments passed to the wrapped function.
|
||||
- **kwargs: Keyword arguments passed to the wrapped function.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
Returns the result of the wrapped async function if successful; raises the last
|
||||
encountered error if all retries fail.
|
||||
"""
|
||||
# If DISABLE_RETRIES is set, don't retry for testing purposes
|
||||
disable_retries = os.environ.get("DISABLE_RETRIES", "false").lower() in (
|
||||
"true",
|
||||
"1",
|
||||
"yes",
|
||||
)
|
||||
|
||||
retries = 0
|
||||
last_error = None
|
||||
|
||||
while retries <= max_retries:
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
# Check if this is a rate limit error
|
||||
error_str = str(e).lower()
|
||||
error_type = type(e).__name__
|
||||
is_rate_limit = any(
|
||||
pattern in error_str.lower() for pattern in RATE_LIMIT_ERROR_PATTERNS
|
||||
)
|
||||
|
||||
if disable_retries:
|
||||
# For testing, propagate the exception immediately
|
||||
raise
|
||||
|
||||
if is_rate_limit and retries < max_retries:
|
||||
# Calculate backoff with jitter
|
||||
backoff = (
|
||||
base_backoff * (2**retries) * (1 + random.uniform(-jitter, jitter))
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
f"Embedding rate limit hit, retrying in {backoff:.2f}s "
|
||||
f"(attempt {retries + 1}/{max_retries}): "
|
||||
f"({error_str!r}, {error_type!r})"
|
||||
)
|
||||
|
||||
await asyncio.sleep(backoff)
|
||||
retries += 1
|
||||
last_error = e
|
||||
else:
|
||||
# Not a rate limit error or max retries reached, raise
|
||||
raise
|
||||
|
||||
# If we exit the loop due to max retries, raise the last error
|
||||
if last_error:
|
||||
raise last_error
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
|
@ -193,6 +193,8 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
for (data_point_index, data_point) in enumerate(data_points)
|
||||
]
|
||||
|
||||
lance_data_points = list({dp.id: dp for dp in lance_data_points}.values())
|
||||
|
||||
async with self.VECTOR_DB_LOCK:
|
||||
await (
|
||||
collection.merge_insert("id")
|
||||
|
|
|
|||
|
|
@ -74,6 +74,41 @@ class LLMConfig(BaseSettings):
|
|||
|
||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def strip_quotes_from_strings(self) -> "LLMConfig":
|
||||
"""
|
||||
Strip surrounding quotes from specific string fields that often come from
|
||||
environment variables with extra quotes (e.g., via Docker's --env-file).
|
||||
|
||||
Only applies to known config keys where quotes are invalid or cause issues.
|
||||
"""
|
||||
string_fields_to_strip = [
|
||||
"llm_api_key",
|
||||
"llm_endpoint",
|
||||
"llm_api_version",
|
||||
"baml_llm_api_key",
|
||||
"baml_llm_endpoint",
|
||||
"baml_llm_api_version",
|
||||
"fallback_api_key",
|
||||
"fallback_endpoint",
|
||||
"fallback_model",
|
||||
"llm_provider",
|
||||
"llm_model",
|
||||
"baml_llm_provider",
|
||||
"baml_llm_model",
|
||||
]
|
||||
|
||||
cls = self.__class__
|
||||
for field_name in string_fields_to_strip:
|
||||
if field_name not in cls.model_fields:
|
||||
continue
|
||||
value = getattr(self, field_name, None)
|
||||
if isinstance(value, str) and len(value) >= 2:
|
||||
if value[0] == value[-1] and value[0] in ("'", '"'):
|
||||
setattr(self, field_name, value[1:-1])
|
||||
|
||||
return self
|
||||
|
||||
def model_post_init(self, __context) -> None:
|
||||
"""Initialize the BAML registry after the model is created."""
|
||||
# Check if BAML is selected as structured output framework but not available
|
||||
|
|
|
|||
|
|
@ -1,7 +1,15 @@
|
|||
import asyncio
|
||||
from typing import Type
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from pydantic import BaseModel
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_delay,
|
||||
wait_exponential_jitter,
|
||||
retry_if_not_exception_type,
|
||||
before_sleep_log,
|
||||
)
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.llm.config import get_llm_config
|
||||
from cognee.infrastructure.llm.structured_output_framework.baml.baml_src.extraction.create_dynamic_baml_type import (
|
||||
create_dynamic_baml_type,
|
||||
|
|
@ -10,12 +18,18 @@ from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.type
|
|||
TypeBuilder,
|
||||
)
|
||||
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client import b
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
||||
import logging
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_delay(128),
|
||||
wait=wait_exponential_jitter(8, 128),
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
)
|
||||
async def acreate_structured_output(
|
||||
text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||
):
|
||||
|
|
@ -45,11 +59,12 @@ async def acreate_structured_output(
|
|||
tb = TypeBuilder()
|
||||
type_builder = create_dynamic_baml_type(tb, tb.ResponseModel, response_model)
|
||||
|
||||
result = await b.AcreateStructuredOutput(
|
||||
text_input=text_input,
|
||||
system_prompt=system_prompt,
|
||||
baml_options={"client_registry": config.baml_registry, "tb": type_builder},
|
||||
)
|
||||
async with llm_rate_limiter_context_manager():
|
||||
result = await b.AcreateStructuredOutput(
|
||||
text_input=text_input,
|
||||
system_prompt=system_prompt,
|
||||
baml_options={"client_registry": config.baml_registry, "tb": type_builder},
|
||||
)
|
||||
|
||||
# Transform BAML response to proper pydantic reponse model
|
||||
if response_model is str:
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from tenacity import (
|
|||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
||||
LLMInterface,
|
||||
)
|
||||
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
||||
from cognee.infrastructure.llm.config import get_llm_config
|
||||
|
||||
logger = get_logger()
|
||||
|
|
@ -45,7 +46,7 @@ class AnthropicAdapter(LLMInterface):
|
|||
|
||||
@retry(
|
||||
stop=stop_after_delay(128),
|
||||
wait=wait_exponential_jitter(2, 128),
|
||||
wait=wait_exponential_jitter(8, 128),
|
||||
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
|
|
@ -69,17 +70,17 @@ class AnthropicAdapter(LLMInterface):
|
|||
|
||||
- BaseModel: An instance of BaseModel containing the structured response.
|
||||
"""
|
||||
|
||||
return await self.aclient(
|
||||
model=self.model,
|
||||
max_tokens=4096,
|
||||
max_retries=5,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""Use the given format to extract information
|
||||
from the following input: {text_input}. {system_prompt}""",
|
||||
}
|
||||
],
|
||||
response_model=response_model,
|
||||
)
|
||||
async with llm_rate_limiter_context_manager():
|
||||
return await self.aclient(
|
||||
model=self.model,
|
||||
max_tokens=4096,
|
||||
max_retries=2,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""Use the given format to extract information
|
||||
from the following input: {text_input}. {system_prompt}""",
|
||||
}
|
||||
],
|
||||
response_model=response_model,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll
|
|||
LLMInterface,
|
||||
)
|
||||
import logging
|
||||
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from tenacity import (
|
||||
retry,
|
||||
|
|
@ -73,7 +74,7 @@ class GeminiAdapter(LLMInterface):
|
|||
|
||||
@retry(
|
||||
stop=stop_after_delay(128),
|
||||
wait=wait_exponential_jitter(2, 128),
|
||||
wait=wait_exponential_jitter(8, 128),
|
||||
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
|
|
@ -105,24 +106,25 @@ class GeminiAdapter(LLMInterface):
|
|||
"""
|
||||
|
||||
try:
|
||||
return await self.aclient.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""{text_input}""",
|
||||
},
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
],
|
||||
api_key=self.api_key,
|
||||
max_retries=5,
|
||||
api_base=self.endpoint,
|
||||
api_version=self.api_version,
|
||||
response_model=response_model,
|
||||
)
|
||||
async with llm_rate_limiter_context_manager():
|
||||
return await self.aclient.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""{text_input}""",
|
||||
},
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
],
|
||||
api_key=self.api_key,
|
||||
max_retries=2,
|
||||
api_base=self.endpoint,
|
||||
api_version=self.api_version,
|
||||
response_model=response_model,
|
||||
)
|
||||
except (
|
||||
ContentFilterFinishReasonError,
|
||||
ContentPolicyViolationError,
|
||||
|
|
@ -140,23 +142,24 @@ class GeminiAdapter(LLMInterface):
|
|||
)
|
||||
|
||||
try:
|
||||
return await self.aclient.chat.completions.create(
|
||||
model=self.fallback_model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""{text_input}""",
|
||||
},
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
],
|
||||
max_retries=5,
|
||||
api_key=self.fallback_api_key,
|
||||
api_base=self.fallback_endpoint,
|
||||
response_model=response_model,
|
||||
)
|
||||
async with llm_rate_limiter_context_manager():
|
||||
return await self.aclient.chat.completions.create(
|
||||
model=self.fallback_model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""{text_input}""",
|
||||
},
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
],
|
||||
max_retries=2,
|
||||
api_key=self.fallback_api_key,
|
||||
api_base=self.fallback_endpoint,
|
||||
response_model=response_model,
|
||||
)
|
||||
except (
|
||||
ContentFilterFinishReasonError,
|
||||
ContentPolicyViolationError,
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll
|
|||
LLMInterface,
|
||||
)
|
||||
import logging
|
||||
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from tenacity import (
|
||||
retry,
|
||||
|
|
@ -73,7 +74,7 @@ class GenericAPIAdapter(LLMInterface):
|
|||
|
||||
@retry(
|
||||
stop=stop_after_delay(128),
|
||||
wait=wait_exponential_jitter(2, 128),
|
||||
wait=wait_exponential_jitter(8, 128),
|
||||
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
|
|
@ -105,23 +106,24 @@ class GenericAPIAdapter(LLMInterface):
|
|||
"""
|
||||
|
||||
try:
|
||||
return await self.aclient.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""{text_input}""",
|
||||
},
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
],
|
||||
max_retries=5,
|
||||
api_key=self.api_key,
|
||||
api_base=self.endpoint,
|
||||
response_model=response_model,
|
||||
)
|
||||
async with llm_rate_limiter_context_manager():
|
||||
return await self.aclient.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""{text_input}""",
|
||||
},
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
],
|
||||
max_retries=2,
|
||||
api_key=self.api_key,
|
||||
api_base=self.endpoint,
|
||||
response_model=response_model,
|
||||
)
|
||||
except (
|
||||
ContentFilterFinishReasonError,
|
||||
ContentPolicyViolationError,
|
||||
|
|
@ -139,23 +141,24 @@ class GenericAPIAdapter(LLMInterface):
|
|||
) from error
|
||||
|
||||
try:
|
||||
return await self.aclient.chat.completions.create(
|
||||
model=self.fallback_model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""{text_input}""",
|
||||
},
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
],
|
||||
max_retries=5,
|
||||
api_key=self.fallback_api_key,
|
||||
api_base=self.fallback_endpoint,
|
||||
response_model=response_model,
|
||||
)
|
||||
async with llm_rate_limiter_context_manager():
|
||||
return await self.aclient.chat.completions.create(
|
||||
model=self.fallback_model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""{text_input}""",
|
||||
},
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
],
|
||||
max_retries=2,
|
||||
api_key=self.fallback_api_key,
|
||||
api_base=self.fallback_endpoint,
|
||||
response_model=response_model,
|
||||
)
|
||||
except (
|
||||
ContentFilterFinishReasonError,
|
||||
ContentPolicyViolationError,
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll
|
|||
LLMInterface,
|
||||
)
|
||||
from cognee.infrastructure.llm.config import get_llm_config
|
||||
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
||||
|
||||
import logging
|
||||
from tenacity import (
|
||||
|
|
@ -62,7 +63,7 @@ class MistralAdapter(LLMInterface):
|
|||
|
||||
@retry(
|
||||
stop=stop_after_delay(128),
|
||||
wait=wait_exponential_jitter(2, 128),
|
||||
wait=wait_exponential_jitter(8, 128),
|
||||
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
|
|
@ -97,13 +98,14 @@ class MistralAdapter(LLMInterface):
|
|||
},
|
||||
]
|
||||
try:
|
||||
response = await self.aclient.chat.completions.create(
|
||||
model=self.model,
|
||||
max_tokens=self.max_completion_tokens,
|
||||
max_retries=5,
|
||||
messages=messages,
|
||||
response_model=response_model,
|
||||
)
|
||||
async with llm_rate_limiter_context_manager():
|
||||
response = await self.aclient.chat.completions.create(
|
||||
model=self.model,
|
||||
max_tokens=self.max_completion_tokens,
|
||||
max_retries=2,
|
||||
messages=messages,
|
||||
response_model=response_model,
|
||||
)
|
||||
if response.choices and response.choices[0].message.content:
|
||||
content = response.choices[0].message.content
|
||||
return response_model.model_validate_json(content)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll
|
|||
)
|
||||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_delay,
|
||||
|
|
@ -68,7 +70,7 @@ class OllamaAPIAdapter(LLMInterface):
|
|||
|
||||
@retry(
|
||||
stop=stop_after_delay(128),
|
||||
wait=wait_exponential_jitter(2, 128),
|
||||
wait=wait_exponential_jitter(8, 128),
|
||||
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
|
|
@ -95,28 +97,28 @@ class OllamaAPIAdapter(LLMInterface):
|
|||
|
||||
- BaseModel: A structured output that conforms to the specified response model.
|
||||
"""
|
||||
|
||||
response = self.aclient.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"{text_input}",
|
||||
},
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
],
|
||||
max_retries=5,
|
||||
response_model=response_model,
|
||||
)
|
||||
async with llm_rate_limiter_context_manager():
|
||||
response = self.aclient.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"{text_input}",
|
||||
},
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
],
|
||||
max_retries=2,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@retry(
|
||||
stop=stop_after_delay(128),
|
||||
wait=wait_exponential_jitter(2, 128),
|
||||
wait=wait_exponential_jitter(8, 128),
|
||||
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll
|
|||
from cognee.infrastructure.llm.exceptions import (
|
||||
ContentPolicyFilterError,
|
||||
)
|
||||
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
||||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||
from cognee.modules.observability.get_observe import get_observe
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
|
@ -105,7 +106,7 @@ class OpenAIAdapter(LLMInterface):
|
|||
@observe(as_type="generation")
|
||||
@retry(
|
||||
stop=stop_after_delay(128),
|
||||
wait=wait_exponential_jitter(2, 128),
|
||||
wait=wait_exponential_jitter(8, 128),
|
||||
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
|
|
@ -135,34 +136,9 @@ class OpenAIAdapter(LLMInterface):
|
|||
"""
|
||||
|
||||
try:
|
||||
return await self.aclient.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""{text_input}""",
|
||||
},
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
],
|
||||
api_key=self.api_key,
|
||||
api_base=self.endpoint,
|
||||
api_version=self.api_version,
|
||||
response_model=response_model,
|
||||
max_retries=self.MAX_RETRIES,
|
||||
)
|
||||
except (
|
||||
ContentFilterFinishReasonError,
|
||||
ContentPolicyViolationError,
|
||||
InstructorRetryException,
|
||||
) as e:
|
||||
if not (self.fallback_model and self.fallback_api_key):
|
||||
raise e
|
||||
try:
|
||||
async with llm_rate_limiter_context_manager():
|
||||
return await self.aclient.chat.completions.create(
|
||||
model=self.fallback_model,
|
||||
model=self.model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
|
|
@ -173,11 +149,38 @@ class OpenAIAdapter(LLMInterface):
|
|||
"content": system_prompt,
|
||||
},
|
||||
],
|
||||
api_key=self.fallback_api_key,
|
||||
# api_base=self.fallback_endpoint,
|
||||
api_key=self.api_key,
|
||||
api_base=self.endpoint,
|
||||
api_version=self.api_version,
|
||||
response_model=response_model,
|
||||
max_retries=self.MAX_RETRIES,
|
||||
)
|
||||
except (
|
||||
ContentFilterFinishReasonError,
|
||||
ContentPolicyViolationError,
|
||||
InstructorRetryException,
|
||||
) as e:
|
||||
if not (self.fallback_model and self.fallback_api_key):
|
||||
raise e
|
||||
try:
|
||||
async with llm_rate_limiter_context_manager():
|
||||
return await self.aclient.chat.completions.create(
|
||||
model=self.fallback_model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""{text_input}""",
|
||||
},
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
],
|
||||
api_key=self.fallback_api_key,
|
||||
# api_base=self.fallback_endpoint,
|
||||
response_model=response_model,
|
||||
max_retries=self.MAX_RETRIES,
|
||||
)
|
||||
except (
|
||||
ContentFilterFinishReasonError,
|
||||
ContentPolicyViolationError,
|
||||
|
|
|
|||
|
|
@ -211,24 +211,10 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
node.add_attribute("vector_distance", score)
|
||||
mapped_nodes += 1
|
||||
|
||||
async def map_vector_distances_to_graph_edges(
|
||||
self, vector_engine, query_vector, edge_distances
|
||||
) -> None:
|
||||
async def map_vector_distances_to_graph_edges(self, edge_distances) -> None:
|
||||
try:
|
||||
if query_vector is None or len(query_vector) == 0:
|
||||
raise ValueError("Failed to generate query embedding.")
|
||||
|
||||
if edge_distances is None:
|
||||
start_time = time.time()
|
||||
edge_distances = await vector_engine.search(
|
||||
collection_name="EdgeType_relationship_name",
|
||||
query_vector=query_vector,
|
||||
limit=None,
|
||||
)
|
||||
projection_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"Edge collection distances were calculated separately from nodes in {projection_time:.2f}s"
|
||||
)
|
||||
return
|
||||
|
||||
embedding_map = {result.payload["text"]: result.score for result in edge_distances}
|
||||
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
from cognee.modules.retrieval.code_retriever import CodeRetriever
|
||||
|
||||
|
|
|
|||
|
|
@ -1,232 +0,0 @@
|
|||
from typing import Any, Optional, List
|
||||
import asyncio
|
||||
import aiofiles
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
|
||||
logger = get_logger("CodeRetriever")
|
||||
|
||||
|
||||
class CodeRetriever(BaseRetriever):
|
||||
"""Retriever for handling code-based searches."""
|
||||
|
||||
class CodeQueryInfo(BaseModel):
|
||||
"""
|
||||
Model for representing the result of a query related to code files.
|
||||
|
||||
This class holds a list of filenames and the corresponding source code extracted from a
|
||||
query. It is used to encapsulate response data in a structured format.
|
||||
"""
|
||||
|
||||
filenames: List[str] = []
|
||||
sourcecode: str
|
||||
|
||||
def __init__(self, top_k: int = 3):
|
||||
"""Initialize retriever with search parameters."""
|
||||
self.top_k = top_k
|
||||
self.file_name_collections = ["CodeFile_name"]
|
||||
self.classes_and_functions_collections = [
|
||||
"ClassDefinition_source_code",
|
||||
"FunctionDefinition_source_code",
|
||||
]
|
||||
|
||||
async def _process_query(self, query: str) -> "CodeRetriever.CodeQueryInfo":
|
||||
"""Process the query using LLM to extract file names and source code parts."""
|
||||
logger.debug(
|
||||
f"Processing query with LLM: '{query[:100]}{'...' if len(query) > 100 else ''}'"
|
||||
)
|
||||
|
||||
system_prompt = read_query_prompt("codegraph_retriever_system.txt")
|
||||
|
||||
try:
|
||||
result = await LLMGateway.acreate_structured_output(
|
||||
text_input=query,
|
||||
system_prompt=system_prompt,
|
||||
response_model=self.CodeQueryInfo,
|
||||
)
|
||||
logger.info(
|
||||
f"LLM extracted {len(result.filenames)} filenames and {len(result.sourcecode)} chars of source code"
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retrieve structured output from LLM: {str(e)}")
|
||||
raise RuntimeError("Failed to retrieve structured output from LLM") from e
|
||||
|
||||
async def get_context(self, query: str) -> Any:
|
||||
"""Find relevant code files based on the query."""
|
||||
logger.info(
|
||||
f"Starting code retrieval for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
|
||||
)
|
||||
|
||||
if not query or not isinstance(query, str):
|
||||
logger.error("Invalid query: must be a non-empty string")
|
||||
raise ValueError("The query must be a non-empty string.")
|
||||
|
||||
try:
|
||||
vector_engine = get_vector_engine()
|
||||
graph_engine = await get_graph_engine()
|
||||
logger.debug("Successfully initialized vector and graph engines")
|
||||
except Exception as e:
|
||||
logger.error(f"Database initialization error: {str(e)}")
|
||||
raise RuntimeError("Database initialization error in code_graph_retriever, ") from e
|
||||
|
||||
files_and_codeparts = await self._process_query(query)
|
||||
|
||||
similar_filenames = []
|
||||
similar_codepieces = []
|
||||
|
||||
if not files_and_codeparts.filenames or not files_and_codeparts.sourcecode:
|
||||
logger.info("No specific files/code extracted from query, performing general search")
|
||||
|
||||
for collection in self.file_name_collections:
|
||||
logger.debug(f"Searching {collection} collection with general query")
|
||||
search_results_file = await vector_engine.search(
|
||||
collection, query, limit=self.top_k
|
||||
)
|
||||
logger.debug(f"Found {len(search_results_file)} results in {collection}")
|
||||
for res in search_results_file:
|
||||
similar_filenames.append(
|
||||
{"id": res.id, "score": res.score, "payload": res.payload}
|
||||
)
|
||||
|
||||
existing_collection = []
|
||||
for collection in self.classes_and_functions_collections:
|
||||
if await vector_engine.has_collection(collection):
|
||||
existing_collection.append(collection)
|
||||
|
||||
if not existing_collection:
|
||||
raise RuntimeError("No collection found for code retriever")
|
||||
|
||||
for collection in existing_collection:
|
||||
logger.debug(f"Searching {collection} collection with general query")
|
||||
search_results_code = await vector_engine.search(
|
||||
collection, query, limit=self.top_k
|
||||
)
|
||||
logger.debug(f"Found {len(search_results_code)} results in {collection}")
|
||||
for res in search_results_code:
|
||||
similar_codepieces.append(
|
||||
{"id": res.id, "score": res.score, "payload": res.payload}
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Using extracted filenames ({len(files_and_codeparts.filenames)}) and source code for targeted search"
|
||||
)
|
||||
|
||||
for collection in self.file_name_collections:
|
||||
for file_from_query in files_and_codeparts.filenames:
|
||||
logger.debug(f"Searching {collection} for specific file: {file_from_query}")
|
||||
search_results_file = await vector_engine.search(
|
||||
collection, file_from_query, limit=self.top_k
|
||||
)
|
||||
logger.debug(
|
||||
f"Found {len(search_results_file)} results for file {file_from_query}"
|
||||
)
|
||||
for res in search_results_file:
|
||||
similar_filenames.append(
|
||||
{"id": res.id, "score": res.score, "payload": res.payload}
|
||||
)
|
||||
|
||||
for collection in self.classes_and_functions_collections:
|
||||
logger.debug(f"Searching {collection} with extracted source code")
|
||||
search_results_code = await vector_engine.search(
|
||||
collection, files_and_codeparts.sourcecode, limit=self.top_k
|
||||
)
|
||||
logger.debug(f"Found {len(search_results_code)} results for source code search")
|
||||
for res in search_results_code:
|
||||
similar_codepieces.append(
|
||||
{"id": res.id, "score": res.score, "payload": res.payload}
|
||||
)
|
||||
|
||||
total_items = len(similar_filenames) + len(similar_codepieces)
|
||||
logger.info(
|
||||
f"Total search results: {total_items} items ({len(similar_filenames)} filenames, {len(similar_codepieces)} code pieces)"
|
||||
)
|
||||
|
||||
if total_items == 0:
|
||||
logger.warning("No search results found, returning empty list")
|
||||
return []
|
||||
|
||||
logger.debug("Getting graph connections for all search results")
|
||||
relevant_triplets = await asyncio.gather(
|
||||
*[
|
||||
graph_engine.get_connections(similar_piece["id"])
|
||||
for similar_piece in similar_filenames + similar_codepieces
|
||||
]
|
||||
)
|
||||
logger.info(f"Retrieved graph connections for {len(relevant_triplets)} items")
|
||||
|
||||
paths = set()
|
||||
for i, sublist in enumerate(relevant_triplets):
|
||||
logger.debug(f"Processing connections for item {i}: {len(sublist)} connections")
|
||||
for tpl in sublist:
|
||||
if isinstance(tpl, tuple) and len(tpl) >= 3:
|
||||
if "file_path" in tpl[0]:
|
||||
paths.add(tpl[0]["file_path"])
|
||||
if "file_path" in tpl[2]:
|
||||
paths.add(tpl[2]["file_path"])
|
||||
|
||||
logger.info(f"Found {len(paths)} unique file paths to read")
|
||||
|
||||
retrieved_files = {}
|
||||
read_tasks = []
|
||||
for file_path in paths:
|
||||
|
||||
async def read_file(fp):
|
||||
try:
|
||||
logger.debug(f"Reading file: {fp}")
|
||||
async with aiofiles.open(fp, "r", encoding="utf-8") as f:
|
||||
content = await f.read()
|
||||
retrieved_files[fp] = content
|
||||
logger.debug(f"Successfully read {len(content)} characters from {fp}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading {fp}: {e}")
|
||||
retrieved_files[fp] = ""
|
||||
|
||||
read_tasks.append(read_file(file_path))
|
||||
|
||||
await asyncio.gather(*read_tasks)
|
||||
logger.info(
|
||||
f"Successfully read {len([f for f in retrieved_files.values() if f])} files (out of {len(paths)} total)"
|
||||
)
|
||||
|
||||
result = [
|
||||
{
|
||||
"name": file_path,
|
||||
"description": file_path,
|
||||
"content": retrieved_files[file_path],
|
||||
}
|
||||
for file_path in paths
|
||||
]
|
||||
|
||||
logger.info(f"Returning {len(result)} code file contexts")
|
||||
return result
|
||||
|
||||
async def get_completion(
|
||||
self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
|
||||
) -> Any:
|
||||
"""
|
||||
Returns the code files context.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- query (str): The query string to retrieve code context for.
|
||||
- context (Optional[Any]): Optional pre-fetched context; if None, it retrieves
|
||||
the context for the query. (default None)
|
||||
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||
defaults to 'default_session'. (default None)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- Any: The code files context, either provided or retrieved.
|
||||
"""
|
||||
if context is None:
|
||||
context = await self.get_context(query)
|
||||
return context
|
||||
10
cognee/modules/retrieval/register_retriever.py
Normal file
10
cognee/modules/retrieval/register_retriever.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
from typing import Type
|
||||
|
||||
from .base_retriever import BaseRetriever
|
||||
from .registered_community_retrievers import registered_community_retrievers
|
||||
from ..search.types import SearchType
|
||||
|
||||
|
||||
def use_retriever(search_type: SearchType, retriever: Type[BaseRetriever]):
|
||||
"""Register a retriever class for a given search type."""
|
||||
registered_community_retrievers[search_type] = retriever
|
||||
|
|
@ -0,0 +1 @@
|
|||
registered_community_retrievers = {}
|
||||
|
|
@ -137,6 +137,9 @@ async def brute_force_triplet_search(
|
|||
"DocumentChunk_text",
|
||||
]
|
||||
|
||||
if "EdgeType_relationship_name" not in collections:
|
||||
collections.append("EdgeType_relationship_name")
|
||||
|
||||
try:
|
||||
vector_engine = get_vector_engine()
|
||||
except Exception as e:
|
||||
|
|
@ -197,9 +200,7 @@ async def brute_force_triplet_search(
|
|||
)
|
||||
|
||||
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
|
||||
await memory_fragment.map_vector_distances_to_graph_edges(
|
||||
vector_engine=vector_engine, query_vector=query_vector, edge_distances=edge_distances
|
||||
)
|
||||
await memory_fragment.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
||||
results = await memory_fragment.calculate_top_triplet_importances(k=top_k)
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@ from cognee.modules.retrieval.graph_completion_cot_retriever import GraphComplet
|
|||
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
||||
GraphCompletionContextExtensionRetriever,
|
||||
)
|
||||
from cognee.modules.retrieval.code_retriever import CodeRetriever
|
||||
from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
|
||||
from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
|
||||
|
||||
|
|
@ -162,10 +161,6 @@ async def get_search_type_tools(
|
|||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
).get_context,
|
||||
],
|
||||
SearchType.CODE: [
|
||||
CodeRetriever(top_k=top_k).get_completion,
|
||||
CodeRetriever(top_k=top_k).get_context,
|
||||
],
|
||||
SearchType.CYPHER: [
|
||||
CypherSearchRetriever().get_completion,
|
||||
CypherSearchRetriever().get_context,
|
||||
|
|
@ -208,7 +203,19 @@ async def get_search_type_tools(
|
|||
):
|
||||
raise UnsupportedSearchTypeError("Cypher query search types are disabled.")
|
||||
|
||||
search_type_tools = search_tasks.get(query_type)
|
||||
from cognee.modules.retrieval.registered_community_retrievers import (
|
||||
registered_community_retrievers,
|
||||
)
|
||||
|
||||
if query_type in registered_community_retrievers:
|
||||
retriever = registered_community_retrievers[query_type]
|
||||
retriever_instance = retriever(top_k=top_k)
|
||||
search_type_tools = [
|
||||
retriever_instance.get_completion,
|
||||
retriever_instance.get_context,
|
||||
]
|
||||
else:
|
||||
search_type_tools = search_tasks.get(query_type)
|
||||
|
||||
if not search_type_tools:
|
||||
raise UnsupportedSearchTypeError(str(query_type))
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ class SearchType(Enum):
|
|||
TRIPLET_COMPLETION = "TRIPLET_COMPLETION"
|
||||
GRAPH_COMPLETION = "GRAPH_COMPLETION"
|
||||
GRAPH_SUMMARY_COMPLETION = "GRAPH_SUMMARY_COMPLETION"
|
||||
CODE = "CODE"
|
||||
CYPHER = "CYPHER"
|
||||
NATURAL_LANGUAGE = "NATURAL_LANGUAGE"
|
||||
GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT"
|
||||
|
|
|
|||
30
cognee/shared/rate_limiting.py
Normal file
30
cognee/shared/rate_limiting.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
from aiolimiter import AsyncLimiter
|
||||
from contextlib import nullcontext
|
||||
from cognee.infrastructure.llm.config import get_llm_config
|
||||
|
||||
llm_config = get_llm_config()
|
||||
|
||||
llm_rate_limiter = AsyncLimiter(
|
||||
llm_config.llm_rate_limit_requests, llm_config.embedding_rate_limit_interval
|
||||
)
|
||||
embedding_rate_limiter = AsyncLimiter(
|
||||
llm_config.embedding_rate_limit_requests, llm_config.embedding_rate_limit_interval
|
||||
)
|
||||
|
||||
|
||||
def llm_rate_limiter_context_manager():
|
||||
global llm_rate_limiter
|
||||
if llm_config.llm_rate_limit_enabled:
|
||||
return llm_rate_limiter
|
||||
else:
|
||||
# Return a no-op context manager if rate limiting is disabled
|
||||
return nullcontext()
|
||||
|
||||
|
||||
def embedding_rate_limiter_context_manager():
|
||||
global embedding_rate_limiter
|
||||
if llm_config.embedding_rate_limit_enabled:
|
||||
return embedding_rate_limiter
|
||||
else:
|
||||
# Return a no-op context manager if rate limiting is disabled
|
||||
return nullcontext()
|
||||
|
|
@ -1,35 +0,0 @@
|
|||
import os
|
||||
import asyncio
|
||||
import argparse
|
||||
from cognee.tasks.repo_processor.get_repo_file_dependencies import get_repo_file_dependencies
|
||||
from cognee.tasks.repo_processor.enrich_dependency_graph import enrich_dependency_graph
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Execute the main logic of the dependency graph processor.
|
||||
|
||||
This function sets up argument parsing to retrieve the repository path, checks the
|
||||
existence of the specified path, and processes the repository to produce a dependency
|
||||
graph. If the repository path does not exist, it logs an error message and terminates
|
||||
without further execution.
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("repo_path", help="Path to the repository")
|
||||
args = parser.parse_args()
|
||||
|
||||
repo_path = args.repo_path
|
||||
if not os.path.exists(repo_path):
|
||||
print(f"Error: The provided repository path does not exist: {repo_path}")
|
||||
return
|
||||
|
||||
graph = asyncio.run(get_repo_file_dependencies(repo_path))
|
||||
graph = asyncio.run(enrich_dependency_graph(graph))
|
||||
for node in graph.nodes:
|
||||
print(f"Node: {node}")
|
||||
for _, target, data in graph.out_edges(node, data=True):
|
||||
print(f" Edge to {target}, data: {data}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,20 +0,0 @@
|
|||
import argparse
|
||||
import asyncio
|
||||
from cognee.tasks.repo_processor.get_local_dependencies import get_local_script_dependencies
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Get local script dependencies.")
|
||||
|
||||
# Suggested path: .../cognee/examples/python/simple_example.py
|
||||
parser.add_argument("script_path", type=str, help="Absolute path to the Python script file")
|
||||
|
||||
# Suggested path: .../cognee
|
||||
parser.add_argument("repo_path", type=str, help="Absolute path to the repository root")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
dependencies = asyncio.run(get_local_script_dependencies(args.script_path, args.repo_path))
|
||||
|
||||
print("Dependencies:")
|
||||
for dependency in dependencies:
|
||||
print(dependency)
|
||||
|
|
@ -1,35 +0,0 @@
|
|||
import os
|
||||
import asyncio
|
||||
import argparse
|
||||
from cognee.tasks.repo_processor.get_repo_file_dependencies import get_repo_file_dependencies
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Parse the command line arguments and print the repository file dependencies.
|
||||
|
||||
This function sets up an argument parser to retrieve the path of a repository. It checks
|
||||
if the provided path exists and if it doesn’t, it prints an error message and exits. If
|
||||
the path is valid, it calls an asynchronous function to get the dependencies and prints
|
||||
the nodes and their relations in the dependency graph.
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("repo_path", help="Path to the repository")
|
||||
args = parser.parse_args()
|
||||
|
||||
repo_path = args.repo_path
|
||||
if not os.path.exists(repo_path):
|
||||
print(f"Error: The provided repository path does not exist: {repo_path}")
|
||||
return
|
||||
|
||||
graph = asyncio.run(get_repo_file_dependencies(repo_path))
|
||||
|
||||
for node in graph.nodes:
|
||||
print(f"Node: {node}")
|
||||
edges = graph.edges(node, data=True)
|
||||
for _, target, data in edges:
|
||||
print(f" Edge to {target}, Relation: {data.get('relation')}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,2 +0,0 @@
|
|||
from .get_non_code_files import get_non_py_files
|
||||
from .get_repo_file_dependencies import get_repo_file_dependencies
|
||||
|
|
@ -1,335 +0,0 @@
|
|||
import os
|
||||
import aiofiles
|
||||
import importlib
|
||||
from typing import AsyncGenerator, Optional
|
||||
from uuid import NAMESPACE_OID, uuid5
|
||||
import tree_sitter_python as tspython
|
||||
from tree_sitter import Language, Node, Parser, Tree
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from cognee.low_level import DataPoint
|
||||
from cognee.shared.CodeGraphEntities import (
|
||||
CodeFile,
|
||||
ImportStatement,
|
||||
FunctionDefinition,
|
||||
ClassDefinition,
|
||||
)
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class FileParser:
|
||||
"""
|
||||
Handles the parsing of files into source code and an abstract syntax tree
|
||||
representation. Public methods include:
|
||||
|
||||
- parse_file: Parses a file and returns its source code and syntax tree representation.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.parsed_files = {}
|
||||
|
||||
async def parse_file(self, file_path: str) -> tuple[str, Tree]:
|
||||
"""
|
||||
Parse a file and return its source code along with its syntax tree representation.
|
||||
|
||||
If the file has already been parsed, retrieve the result from memory instead of reading
|
||||
the file again.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- file_path (str): The path of the file to parse.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- tuple[str, Tree]: A tuple containing the source code of the file and its
|
||||
corresponding syntax tree representation.
|
||||
"""
|
||||
PY_LANGUAGE = Language(tspython.language())
|
||||
source_code_parser = Parser(PY_LANGUAGE)
|
||||
|
||||
if file_path not in self.parsed_files:
|
||||
source_code = await get_source_code(file_path)
|
||||
source_code_tree = source_code_parser.parse(bytes(source_code, "utf-8"))
|
||||
self.parsed_files[file_path] = (source_code, source_code_tree)
|
||||
|
||||
return self.parsed_files[file_path]
|
||||
|
||||
|
||||
async def get_source_code(file_path: str):
|
||||
"""
|
||||
Read source code from a file asynchronously.
|
||||
|
||||
This function attempts to open a file specified by the given file path, read its
|
||||
contents, and return the source code. In case of any errors during the file reading
|
||||
process, it logs an error message and returns None.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- file_path (str): The path to the file from which to read the source code.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
Returns the contents of the file as a string if successful, or None if an error
|
||||
occurs.
|
||||
"""
|
||||
try:
|
||||
async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
|
||||
source_code = await f.read()
|
||||
return source_code
|
||||
except Exception as error:
|
||||
logger.error(f"Error reading file {file_path}: {str(error)}")
|
||||
return None
|
||||
|
||||
|
||||
def resolve_module_path(module_name):
|
||||
"""
|
||||
Find the file path of a module.
|
||||
|
||||
Return the file path of the specified module if found, or return None if the module does
|
||||
not exist or cannot be located.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- module_name: The name of the module whose file path is to be resolved.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
The file path of the module as a string or None if the module is not found.
|
||||
"""
|
||||
try:
|
||||
spec = importlib.util.find_spec(module_name)
|
||||
if spec and spec.origin:
|
||||
return spec.origin
|
||||
except ModuleNotFoundError:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def find_function_location(
|
||||
module_path: str, function_name: str, parser: FileParser
|
||||
) -> Optional[tuple[str, str]]:
|
||||
"""
|
||||
Find the location of a function definition in a specified module.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- module_path (str): The path to the module where the function is defined.
|
||||
- function_name (str): The name of the function whose location is to be found.
|
||||
- parser (FileParser): An instance of FileParser used to parse the module's source
|
||||
code.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- Optional[tuple[str, str]]: Returns a tuple containing the module path and the
|
||||
start point of the function if found; otherwise, returns None.
|
||||
"""
|
||||
if not module_path or not os.path.exists(module_path):
|
||||
return None
|
||||
|
||||
source_code, tree = parser.parse_file(module_path)
|
||||
root_node: Node = tree.root_node
|
||||
|
||||
for node in root_node.children:
|
||||
if node.type == "function_definition":
|
||||
func_name_node = node.child_by_field_name("name")
|
||||
|
||||
if func_name_node and func_name_node.text.decode() == function_name:
|
||||
return (module_path, node.start_point) # (line, column)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def get_local_script_dependencies(
|
||||
repo_path: str, script_path: str, detailed_extraction: bool = False
|
||||
) -> CodeFile:
|
||||
"""
|
||||
Retrieve local script dependencies and create a CodeFile object.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- repo_path (str): The path to the repository that contains the script.
|
||||
- script_path (str): The path of the script for which dependencies are being
|
||||
extracted.
|
||||
- detailed_extraction (bool): A flag indicating whether to perform a detailed
|
||||
extraction of code components.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- CodeFile: Returns a CodeFile object containing information about the script,
|
||||
including its dependencies and definitions.
|
||||
"""
|
||||
code_file_parser = FileParser()
|
||||
source_code, source_code_tree = await code_file_parser.parse_file(script_path)
|
||||
|
||||
file_path_relative_to_repo = script_path[len(repo_path) + 1 :]
|
||||
|
||||
if not detailed_extraction:
|
||||
code_file_node = CodeFile(
|
||||
id=uuid5(NAMESPACE_OID, script_path),
|
||||
name=file_path_relative_to_repo,
|
||||
source_code=source_code,
|
||||
file_path=script_path,
|
||||
language="python",
|
||||
)
|
||||
return code_file_node
|
||||
|
||||
code_file_node = CodeFile(
|
||||
id=uuid5(NAMESPACE_OID, script_path),
|
||||
name=file_path_relative_to_repo,
|
||||
source_code=None,
|
||||
file_path=script_path,
|
||||
language="python",
|
||||
)
|
||||
|
||||
async for part in extract_code_parts(source_code_tree.root_node, script_path=script_path):
|
||||
part.file_path = script_path
|
||||
|
||||
if isinstance(part, FunctionDefinition):
|
||||
code_file_node.provides_function_definition.append(part)
|
||||
if isinstance(part, ClassDefinition):
|
||||
code_file_node.provides_class_definition.append(part)
|
||||
if isinstance(part, ImportStatement):
|
||||
code_file_node.depends_on.append(part)
|
||||
|
||||
return code_file_node
|
||||
|
||||
|
||||
def find_node(nodes: list[Node], condition: callable) -> Node:
|
||||
"""
|
||||
Find and return the first node that satisfies the given condition.
|
||||
|
||||
Iterate through the provided list of nodes and return the first node for which the
|
||||
condition callable returns True. If no such node is found, return None.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- nodes (list[Node]): A list of Node objects to search through.
|
||||
- condition (callable): A callable that takes a Node and returns a boolean
|
||||
indicating if the node meets specified criteria.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- Node: The first Node that matches the condition, or None if no such node exists.
|
||||
"""
|
||||
for node in nodes:
|
||||
if condition(node):
|
||||
return node
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def extract_code_parts(
|
||||
tree_root: Node, script_path: str, existing_nodes: list[DataPoint] = {}
|
||||
) -> AsyncGenerator[DataPoint, None]:
|
||||
"""
|
||||
Extract code parts from a given AST node tree asynchronously.
|
||||
|
||||
Iteratively yields DataPoint nodes representing import statements, function definitions,
|
||||
and class definitions found in the children of the specified tree root. The function
|
||||
checks
|
||||
if nodes are already present in the existing_nodes dictionary to prevent duplicates.
|
||||
This function has to be used in an asynchronous context, and it requires a valid
|
||||
tree_root
|
||||
and proper initialization of existing_nodes.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- tree_root (Node): The root node of the AST tree containing code parts to extract.
|
||||
- script_path (str): The file path of the script from which the AST was generated.
|
||||
- existing_nodes (list[DataPoint]): A dictionary that holds already extracted
|
||||
DataPoint nodes to avoid duplicates. (default {})
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
Yields DataPoint nodes representing imported modules, functions, and classes.
|
||||
"""
|
||||
for child_node in tree_root.children:
|
||||
if child_node.type == "import_statement" or child_node.type == "import_from_statement":
|
||||
parts = child_node.text.decode("utf-8").split()
|
||||
|
||||
if parts[0] == "import":
|
||||
module_name = parts[1]
|
||||
function_name = None
|
||||
elif parts[0] == "from":
|
||||
module_name = parts[1]
|
||||
function_name = parts[3]
|
||||
|
||||
if " as " in function_name:
|
||||
function_name = function_name.split(" as ")[0]
|
||||
|
||||
if " as " in module_name:
|
||||
module_name = module_name.split(" as ")[0]
|
||||
|
||||
if function_name and "import " + function_name not in existing_nodes:
|
||||
import_statement_node = ImportStatement(
|
||||
name=function_name,
|
||||
module=module_name,
|
||||
start_point=child_node.start_point,
|
||||
end_point=child_node.end_point,
|
||||
file_path=script_path,
|
||||
source_code=child_node.text,
|
||||
)
|
||||
existing_nodes["import " + function_name] = import_statement_node
|
||||
|
||||
if function_name:
|
||||
yield existing_nodes["import " + function_name]
|
||||
|
||||
if module_name not in existing_nodes:
|
||||
import_statement_node = ImportStatement(
|
||||
name=module_name,
|
||||
module=module_name,
|
||||
start_point=child_node.start_point,
|
||||
end_point=child_node.end_point,
|
||||
file_path=script_path,
|
||||
source_code=child_node.text,
|
||||
)
|
||||
existing_nodes[module_name] = import_statement_node
|
||||
|
||||
yield existing_nodes[module_name]
|
||||
|
||||
if child_node.type == "function_definition":
|
||||
function_node = find_node(child_node.children, lambda node: node.type == "identifier")
|
||||
function_node_name = function_node.text
|
||||
|
||||
if function_node_name not in existing_nodes:
|
||||
function_definition_node = FunctionDefinition(
|
||||
name=function_node_name,
|
||||
start_point=child_node.start_point,
|
||||
end_point=child_node.end_point,
|
||||
file_path=script_path,
|
||||
source_code=child_node.text,
|
||||
)
|
||||
existing_nodes[function_node_name] = function_definition_node
|
||||
|
||||
yield existing_nodes[function_node_name]
|
||||
|
||||
if child_node.type == "class_definition":
|
||||
class_name_node = find_node(child_node.children, lambda node: node.type == "identifier")
|
||||
class_name_node_name = class_name_node.text
|
||||
|
||||
if class_name_node_name not in existing_nodes:
|
||||
class_definition_node = ClassDefinition(
|
||||
name=class_name_node_name,
|
||||
start_point=child_node.start_point,
|
||||
end_point=child_node.end_point,
|
||||
file_path=script_path,
|
||||
source_code=child_node.text,
|
||||
)
|
||||
existing_nodes[class_name_node_name] = class_definition_node
|
||||
|
||||
yield existing_nodes[class_name_node_name]
|
||||
|
|
@ -1,158 +0,0 @@
|
|||
import os
|
||||
|
||||
|
||||
async def get_non_py_files(repo_path):
    """
    Get files that are not .py files and their contents.

    Check if the specified repository path exists and if so, traverse the directory,
    collecting the paths of files that do not have a .py extension and meet the
    criteria set in the allowed and ignored patterns. Return a list of paths to
    those files.

    Parameters:
    -----------

    - repo_path: The file system path to the repository to scan for non-Python files.

    Returns:
    --------

    A list of file paths that are not Python files and meet the specified criteria.
    An empty list when the repository path does not exist.
    """
    # Bug fix: this previously returned an empty dict ({}) while every other
    # path returns a list; keep the return type consistent for callers.
    if not os.path.exists(repo_path):
        return []

    # Substring patterns: any path containing one of these is skipped.
    IGNORED_PATTERNS = {
        ".git",
        "__pycache__",
        "*.pyc",
        "*.pyo",
        "*.pyd",
        "node_modules",
        "*.egg-info",
    }

    # Extensions we are willing to ingest (duplicates from the original
    # literal removed; set semantics are unchanged).
    ALLOWED_EXTENSIONS = {
        ".txt",
        ".md",
        ".csv",
        ".json",
        ".xml",
        ".yaml",
        ".yml",
        ".html",
        ".css",
        ".js",
        ".ts",
        ".jsx",
        ".tsx",
        ".sql",
        ".log",
        ".ini",
        ".toml",
        ".properties",
        ".sh",
        ".bash",
        ".dockerfile",
        ".gitignore",
        ".gitattributes",
        ".makefile",
        ".pyproject",
        ".requirements",
        ".env",
        ".pdf",
        ".doc",
        ".docx",
        ".dot",
        ".dotx",
        ".rtf",
        ".wps",
        ".wpd",
        ".odt",
        ".ott",
        ".ottx",
        ".wp",
        ".sdw",
        ".sdx",
        ".docm",
        ".dotm",
        # Additional extensions for other programming languages
        ".java",
        ".c",
        ".cpp",
        ".h",
        ".cs",
        ".go",
        ".php",
        ".rb",
        ".swift",
        ".pl",
        ".lua",
        ".rs",
        ".scala",
        ".kt",
        ".v",
        ".asm",
        ".pas",
        ".d",
        ".ml",
        ".clj",
        ".cljs",
        ".erl",
        ".ex",
        ".exs",
        ".f",
        ".fs",
        ".r",
        ".pyi",
        ".pdb",
        ".ipynb",
        ".rmd",
        ".cabal",
        ".hs",
        ".nim",
        ".vhdl",
        ".verilog",
        ".svelte",
        ".scss",
        ".less",
        ".json5",
    }

    def should_process(path):
        """
        Determine if a file should be processed based on its extension and path patterns.

        This function checks if the file extension is in the allowed list and ensures that none
        of the ignored patterns are present in the provided file path.

        Parameters:
        -----------

        - path: The file path to check for processing eligibility.

        Returns:
        --------

        Returns True if the file should be processed; otherwise, False.
        """
        _, ext = os.path.splitext(path)
        return ext in ALLOWED_EXTENSIONS and not any(
            pattern in path for pattern in IGNORED_PATTERNS
        )

    non_py_files_paths = [
        os.path.join(root, file)
        for root, _, files in os.walk(repo_path)
        for file in files
        if not file.endswith(".py") and should_process(os.path.join(root, file))
    ]
    return non_py_files_paths
|
||||
|
|
@ -1,243 +0,0 @@
|
|||
import asyncio
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Set
|
||||
from typing import AsyncGenerator, Optional, List
|
||||
from uuid import NAMESPACE_OID, uuid5
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.shared.CodeGraphEntities import CodeFile, Repository
|
||||
|
||||
# constant, declared only once
|
||||
EXCLUDED_DIRS: Set[str] = {
|
||||
".venv",
|
||||
"venv",
|
||||
"env",
|
||||
".env",
|
||||
"site-packages",
|
||||
"node_modules",
|
||||
"dist",
|
||||
"build",
|
||||
".git",
|
||||
"tests",
|
||||
"test",
|
||||
}
|
||||
|
||||
|
||||
async def get_source_code_files(
|
||||
repo_path,
|
||||
language_config: dict[str, list[str]] | None = None,
|
||||
excluded_paths: Optional[List[str]] = None,
|
||||
):
|
||||
"""
|
||||
Retrieve Python source code files from the specified repository path.
|
||||
|
||||
This function scans the given repository path for files that have the .py extension
|
||||
while excluding test files and files within a virtual environment. It returns a list of
|
||||
absolute paths to the source code files that are not empty.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
- repo_path: Root path of the repository to search
|
||||
- language_config: dict mapping language names to file extensions, e.g.,
|
||||
{'python': ['.py'], 'javascript': ['.js', '.jsx'], ...}
|
||||
- excluded_paths: Optional list of path fragments or glob patterns to exclude
|
||||
|
||||
Returns:
|
||||
--------
|
||||
A list of (absolute_path, language) tuples for source code files.
|
||||
"""
|
||||
|
||||
def _get_language_from_extension(file, language_config):
|
||||
for lang, exts in language_config.items():
|
||||
for ext in exts:
|
||||
if file.endswith(ext):
|
||||
return lang
|
||||
return None
|
||||
|
||||
# Default config if not provided
|
||||
if language_config is None:
|
||||
language_config = {
|
||||
"python": [".py"],
|
||||
"javascript": [".js", ".jsx"],
|
||||
"typescript": [".ts", ".tsx"],
|
||||
"java": [".java"],
|
||||
"csharp": [".cs"],
|
||||
"go": [".go"],
|
||||
"rust": [".rs"],
|
||||
"cpp": [".cpp", ".c", ".h", ".hpp"],
|
||||
}
|
||||
|
||||
if not os.path.exists(repo_path):
|
||||
return []
|
||||
|
||||
source_code_files = set()
|
||||
for root, _, files in os.walk(repo_path):
|
||||
for file in files:
|
||||
lang = _get_language_from_extension(file, language_config)
|
||||
if lang is None:
|
||||
continue
|
||||
# Exclude tests, common build/venv directories and files provided in exclude_paths
|
||||
excluded_dirs = EXCLUDED_DIRS
|
||||
excluded_paths = {Path(p).resolve() for p in (excluded_paths or [])} # full paths
|
||||
|
||||
root_path = Path(root).resolve()
|
||||
root_parts = set(root_path.parts) # same as before
|
||||
base_name, _ext = os.path.splitext(file)
|
||||
if (
|
||||
base_name.startswith("test_")
|
||||
or base_name.endswith("_test")
|
||||
or ".test." in file
|
||||
or ".spec." in file
|
||||
or (excluded_dirs & root_parts) # name match
|
||||
or any(
|
||||
root_path.is_relative_to(p) # full-path match
|
||||
for p in excluded_paths
|
||||
)
|
||||
):
|
||||
continue
|
||||
file_path = os.path.abspath(os.path.join(root, file))
|
||||
if os.path.getsize(file_path) == 0:
|
||||
continue
|
||||
source_code_files.add((file_path, lang))
|
||||
|
||||
return sorted(list(source_code_files))
|
||||
|
||||
|
||||
def run_coroutine(coroutine_func, *args, **kwargs):
    """
    Run a coroutine function until it completes.

    This function creates a new asyncio event loop, sets it as the current loop, and
    executes the given coroutine function with the provided arguments. Once the coroutine
    completes, the loop is closed. Intended for use in environments where an existing event
    loop is not available or desirable.

    Parameters:
    -----------

    - coroutine_func: The coroutine function to be run.
    - *args: Positional arguments to pass to the coroutine function.
    - **kwargs: Keyword arguments to pass to the coroutine function.

    Returns:
    --------

    The result returned by the coroutine after completion.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        # Bug fix: close the loop even when the coroutine raises, so the
        # loop's resources are not leaked on error paths.
        return loop.run_until_complete(coroutine_func(*args, **kwargs))
    finally:
        loop.close()
|
||||
|
||||
|
||||
async def get_repo_file_dependencies(
    repo_path: str,
    detailed_extraction: bool = False,
    supported_languages: list = None,
    excluded_paths: Optional[List[str]] = None,
) -> AsyncGenerator[DataPoint, None]:
    """
    Generate a dependency graph for source files (multi-language) in the given repository path.

    Check the validity of the repository path and yield a repository object followed by the
    dependencies of source files within that repository. Raise a FileNotFoundError if the
    provided path does not exist. The extraction of detailed dependencies can be controlled
    via the `detailed_extraction` argument. Languages considered can be restricted via
    the `supported_languages` argument.

    Parameters:
    -----------

    - repo_path (str): The file path to the repository to process.
    - detailed_extraction (bool): Whether to perform a detailed extraction of code parts.
    - supported_languages (list | None): Subset of languages to include; if None, use defaults.
    - excluded_paths (Optional[List[str]]): Path fragments to exclude from the scan.
    """

    # Tolerate a single-element list being passed where a string is expected.
    if isinstance(repo_path, list) and len(repo_path) == 1:
        repo_path = repo_path[0]

    if not os.path.exists(repo_path):
        raise FileNotFoundError(f"Repository path {repo_path} does not exist.")

    # Build language config from supported_languages
    default_language_config = {
        "python": [".py"],
        "javascript": [".js", ".jsx"],
        "typescript": [".ts", ".tsx"],
        "java": [".java"],
        "csharp": [".cs"],
        "go": [".go"],
        "rust": [".rs"],
        "cpp": [".cpp", ".c", ".h", ".hpp"],
        "c": [".c", ".h"],
    }
    if supported_languages is not None:
        language_config = {
            k: v for k, v in default_language_config.items() if k in supported_languages
        }
    else:
        language_config = default_language_config

    source_code_files = await get_source_code_files(
        repo_path, language_config=language_config, excluded_paths=excluded_paths
    )

    # uuid5 of the path gives the repository a deterministic identity across runs.
    repo = Repository(
        id=uuid5(NAMESPACE_OID, repo_path),
        path=repo_path,
    )

    yield repo

    # Process files in fixed-size batches so the asyncio.gather fan-out below is bounded.
    chunk_size = 100
    number_of_chunks = math.ceil(len(source_code_files) / chunk_size)
    chunk_ranges = [
        (
            chunk_number * chunk_size,
            min((chunk_number + 1) * chunk_size, len(source_code_files)) - 1,
        )
        for chunk_number in range(number_of_chunks)
    ]

    # Import dependency extractors for each language (Python for now, extend later)
    from cognee.tasks.repo_processor.get_local_dependencies import get_local_script_dependencies
    import aiofiles
    # TODO: Add other language extractors here

    for start_range, end_range in chunk_ranges:
        tasks = []
        for file_path, lang in source_code_files[start_range : end_range + 1]:
            # For now, only Python is supported; extend with other languages
            if lang == "python":
                tasks.append(
                    get_local_script_dependencies(repo_path, file_path, detailed_extraction)
                )
            else:
                # Placeholder: create a minimal CodeFile for other languages.
                # The default arguments capture the current loop values —
                # without them every stub would close over the last
                # file_path/lang of the loop.
                async def make_codefile_stub(file_path=file_path, lang=lang):
                    async with aiofiles.open(
                        file_path, "r", encoding="utf-8", errors="replace"
                    ) as f:
                        source = await f.read()
                    return CodeFile(
                        id=uuid5(NAMESPACE_OID, file_path),
                        name=os.path.relpath(file_path, repo_path),
                        file_path=file_path,
                        language=lang,
                        source_code=source,
                    )

                tasks.append(make_codefile_stub())

        results: list[CodeFile] = await asyncio.gather(*tasks)

        for source_code_file in results:
            source_code_file.part_of = repo
            # Backfill the language for Python extractor results that did not set it.
            if getattr(
                source_code_file, "language", None
            ) is None and source_code_file.file_path.endswith(".py"):
                source_code_file.language = "python"
            yield source_code_file
|
||||
46
cognee/tests/unit/infrastructure/llm/test_llm_config.py
Normal file
46
cognee/tests/unit/infrastructure/llm/test_llm_config.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
import pytest
|
||||
|
||||
from cognee.infrastructure.llm.config import LLMConfig
|
||||
|
||||
|
||||
def test_strip_quotes_from_strings():
    """
    Test if the LLMConfig.strip_quotes_from_strings model validator behaves as expected.

    Each constructor field below exercises one quoting scenario; the expected
    post-validation values are asserted after construction.
    """
    config = LLMConfig(
        # Strings with surrounding double quotes ("value" → value)
        llm_api_key='"double_value"',
        # Strings with surrounding single quotes ('value' → value)
        llm_endpoint="'single_value'",
        # Strings without quotes (value → value)
        llm_api_version="no_quotes_value",
        # Empty quoted strings ("" → empty string)
        fallback_model='""',
        # None values (should remain None)
        baml_llm_api_key=None,
        # Mixed quotes ("value' → unchanged)
        fallback_endpoint="\"mixed_quote'",
        # Strings with internal quotes ("internal\"quotes" → internal"quotes")
        baml_llm_model='"internal"quotes"',
    )

    # Strings with surrounding double quotes ("value" → value)
    assert config.llm_api_key == "double_value"

    # Strings with surrounding single quotes ('value' → value)
    assert config.llm_endpoint == "single_value"

    # Strings without quotes (value → value)
    assert config.llm_api_version == "no_quotes_value"

    # Empty quoted strings ("" → empty string)
    assert config.fallback_model == ""

    # None values (should remain None)
    assert config.baml_llm_api_key is None

    # Mixed quotes ("value' → unchanged)
    assert config.fallback_endpoint == "\"mixed_quote'"

    # Strings with internal quotes ("internal\"quotes" → internal"quotes")
    assert config.baml_llm_model == 'internal"quotes'
|
||||
|
|
@ -4,10 +4,7 @@ from typing import List
|
|||
from cognee.infrastructure.databases.vector.embeddings.LiteLLMEmbeddingEngine import (
|
||||
LiteLLMEmbeddingEngine,
|
||||
)
|
||||
from cognee.infrastructure.databases.vector.embeddings.embedding_rate_limiter import (
|
||||
embedding_rate_limit_async,
|
||||
embedding_sleep_and_retry_async,
|
||||
)
|
||||
from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager
|
||||
|
||||
|
||||
class MockEmbeddingEngine(LiteLLMEmbeddingEngine):
|
||||
|
|
@ -34,8 +31,6 @@ class MockEmbeddingEngine(LiteLLMEmbeddingEngine):
|
|||
self.fail_every_n_requests = fail_every_n_requests
|
||||
self.add_delay = add_delay
|
||||
|
||||
@embedding_sleep_and_retry_async()
|
||||
@embedding_rate_limit_async
|
||||
async def embed_text(self, text: List[str]) -> List[List[float]]:
|
||||
"""
|
||||
Mock implementation that returns fixed embeddings and can
|
||||
|
|
@ -52,4 +47,5 @@ class MockEmbeddingEngine(LiteLLMEmbeddingEngine):
|
|||
raise Exception(f"Mock failure on request #{self.request_count}")
|
||||
|
||||
# Return mock embeddings of the correct dimension
|
||||
return [[0.1] * self.dimensions for _ in text]
|
||||
async with embedding_rate_limiter_context_manager():
|
||||
return [[0.1] * self.dimensions for _ in text]
|
||||
|
|
|
|||
|
|
@ -6,9 +6,6 @@ import logging
|
|||
from cognee.infrastructure.llm.config import (
|
||||
get_llm_config,
|
||||
)
|
||||
from cognee.infrastructure.databases.vector.embeddings.embedding_rate_limiter import (
|
||||
EmbeddingRateLimiter,
|
||||
)
|
||||
from cognee.tests.unit.infrastructure.mock_embedding_engine import MockEmbeddingEngine
|
||||
|
||||
# Configure logging
|
||||
|
|
@ -33,7 +30,6 @@ async def test_embedding_rate_limiting_realistic():
|
|||
|
||||
# Clear the config and rate limiter caches to ensure our settings are applied
|
||||
get_llm_config.cache_clear()
|
||||
EmbeddingRateLimiter.reset_instance()
|
||||
|
||||
# Create a fresh config instance and verify settings
|
||||
config = get_llm_config()
|
||||
|
|
@ -170,7 +166,6 @@ async def test_with_mock_failures():
|
|||
|
||||
# Clear caches
|
||||
get_llm_config.cache_clear()
|
||||
EmbeddingRateLimiter.reset_instance()
|
||||
|
||||
# Create a mock engine configured to fail every 3rd request
|
||||
engine = MockEmbeddingEngine()
|
||||
|
|
|
|||
|
|
@ -305,7 +305,7 @@ async def test_map_vector_distances_multiple_categories(setup_graph):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph, mock_vector_engine):
|
||||
async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph):
|
||||
"""Test mapping vector distances to edges when edge_distances provided."""
|
||||
graph = setup_graph
|
||||
|
||||
|
|
@ -325,48 +325,13 @@ async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph, moc
|
|||
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
edge_distances=edge_distances,
|
||||
)
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 0.92
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_to_graph_edges_search(setup_graph, mock_vector_engine):
|
||||
"""Test mapping edge distances when searching for them."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
|
||||
edge = Edge(
|
||||
node1,
|
||||
node2,
|
||||
attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"},
|
||||
)
|
||||
graph.add_edge(edge)
|
||||
|
||||
mock_vector_engine.search.return_value = [
|
||||
MockScoredResult("e1", 0.88, payload={"text": "CONNECTS_TO"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
edge_distances=None,
|
||||
)
|
||||
|
||||
mock_vector_engine.search.assert_called_once()
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 0.88
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_partial_edge_coverage(setup_graph, mock_vector_engine):
|
||||
async def test_map_vector_distances_partial_edge_coverage(setup_graph):
|
||||
"""Test mapping edge distances when only some edges have results."""
|
||||
graph = setup_graph
|
||||
|
||||
|
|
@ -386,20 +351,14 @@ async def test_map_vector_distances_partial_edge_coverage(setup_graph, mock_vect
|
|||
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
edge_distances=edge_distances,
|
||||
)
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 0.92
|
||||
assert graph.edges[1].attributes.get("vector_distance") == 3.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_edges_fallback_to_relationship_type(
|
||||
setup_graph, mock_vector_engine
|
||||
):
|
||||
async def test_map_vector_distances_edges_fallback_to_relationship_type(setup_graph):
|
||||
"""Test that edge mapping falls back to relationship_type when edge_text is missing."""
|
||||
graph = setup_graph
|
||||
|
||||
|
|
@ -419,17 +378,13 @@ async def test_map_vector_distances_edges_fallback_to_relationship_type(
|
|||
MockScoredResult("e1", 0.85, payload={"text": "KNOWS"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
edge_distances=edge_distances,
|
||||
)
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 0.85
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_no_edge_matches(setup_graph, mock_vector_engine):
|
||||
async def test_map_vector_distances_no_edge_matches(setup_graph):
|
||||
"""Test edge mapping when no edges match the distance results."""
|
||||
graph = setup_graph
|
||||
|
||||
|
|
@ -449,26 +404,22 @@ async def test_map_vector_distances_no_edge_matches(setup_graph, mock_vector_eng
|
|||
MockScoredResult("e1", 0.92, payload={"text": "SOME_OTHER_EDGE"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
edge_distances=edge_distances,
|
||||
)
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 3.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_invalid_query_vector(setup_graph, mock_vector_engine):
|
||||
"""Test that invalid query vector raises error."""
|
||||
async def test_map_vector_distances_none_returns_early(setup_graph):
|
||||
"""Test that edge_distances=None returns early without error."""
|
||||
graph = setup_graph
|
||||
graph.add_node(Node("1"))
|
||||
graph.add_node(Node("2"))
|
||||
graph.add_edge(Edge(graph.get_node("1"), graph.get_node("2")))
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to generate query embedding"):
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[],
|
||||
edge_distances=None,
|
||||
)
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=None)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 3.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -127,6 +127,7 @@ async def test_brute_force_triplet_search_default_collections():
|
|||
"TextSummary_text",
|
||||
"EntityType_name",
|
||||
"DocumentChunk_text",
|
||||
"EdgeType_relationship_name",
|
||||
]
|
||||
|
||||
call_collections = [
|
||||
|
|
@ -154,7 +155,32 @@ async def test_brute_force_triplet_search_custom_collections():
|
|||
call_collections = [
|
||||
call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
|
||||
]
|
||||
assert call_collections == custom_collections
|
||||
assert set(call_collections) == set(custom_collections) | {"EdgeType_relationship_name"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_always_includes_edge_collection():
    """Test that EdgeType_relationship_name is always searched even when not in collections."""
    # Vector-engine stub: embedding and search never hit a real backend.
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[])

    # Deliberately omit the edge collection from the caller-provided list.
    collections_without_edge = ["Entity_name", "TextSummary_text"]

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        await brute_force_triplet_search(query="test", collections=collections_without_edge)

    # Collection names actually queried, in call order.
    call_collections = [
        call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
    ]
    assert "EdgeType_relationship_name" in call_collections
    assert set(call_collections) == set(collections_without_edge) | {
        "EdgeType_relationship_name"
    }
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -1,63 +0,0 @@
|
|||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
import cognee
|
||||
from cognee import SearchType
|
||||
from cognee.shared.logging_utils import setup_logging, ERROR
|
||||
|
||||
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
|
||||
|
||||
|
||||
async def main(repo_path, include_docs):
    """
    Run the code-graph pipeline over ``repo_path`` and smoke-test CODE search.

    Returns the last status yielded by the pipeline generator.
    """
    # Disable permissions feature for this example
    os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"

    run_status = False
    # Drain the pipeline generator; only the last yielded status is kept.
    async for run_status in run_code_graph_pipeline(repo_path, include_docs=include_docs):
        run_status = run_status

    # Test CODE search
    search_results = await cognee.search(query_type=SearchType.CODE, query_text="test")
    assert len(search_results) != 0, "The search results list is empty."
    print("\n\nSearch results are:\n")
    for result in search_results:
        print(f"{result}\n")

    return run_status
|
||||
|
||||
|
||||
def parse_args():
    """
    Parse command-line options for the code-graph example.

    Returns:
    --------

    An argparse.Namespace carrying ``repo_path`` (str, required),
    ``include_docs`` (bool, default False) and ``time`` (bool, default True).
    """

    def _to_bool(raw):
        # Treat "true"/"1" (any case) as True; everything else is False.
        return raw.lower() in ("true", "1")

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--repo_path", type=str, required=True, help="Path to the repository")
    arg_parser.add_argument(
        "--include_docs",
        type=_to_bool,
        default=False,
        help="Whether or not to process non-code files",
    )
    arg_parser.add_argument(
        "--time",
        type=_to_bool,
        default=True,
        help="Whether or not to time the pipeline run",
    )
    return arg_parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Keep log noise down for the example run.
    logger = setup_logging(log_level=ERROR)

    args = parse_args()

    if args.time:
        # Local import: timing is only needed on this branch.
        import time

        start_time = time.time()
        asyncio.run(main(args.repo_path, args.include_docs))
        end_time = time.time()
        print("\n" + "=" * 50)
        print(f"Pipeline Execution Time: {end_time - start_time:.2f} seconds")
        print("=" * 50 + "\n")
    else:
        asyncio.run(main(args.repo_path, args.include_docs))
|
||||
52
poetry.lock
generated
52
poetry.lock
generated
|
|
@ -1,4 +1,4 @@
|
|||
# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "accelerate"
|
||||
|
|
@ -242,6 +242,18 @@ files = [
|
|||
{file = "aioitertools-0.13.0.tar.gz", hash = "sha256:620bd241acc0bbb9ec819f1ab215866871b4bbd1f73836a55f799200ee86950c"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aiolimiter"
|
||||
version = "1.2.1"
|
||||
description = "asyncio rate limiter, a leaky bucket implementation"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "aiolimiter-1.2.1-py3-none-any.whl", hash = "sha256:d3f249e9059a20badcb56b61601a83556133655c11d1eb3dd3e04ff069e5f3c7"},
|
||||
{file = "aiolimiter-1.2.1.tar.gz", hash = "sha256:e02a37ea1a855d9e832252a105420ad4d15011505512a1a1d814647451b5cca9"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aiosignal"
|
||||
version = "1.4.0"
|
||||
|
|
@ -3309,8 +3321,6 @@ files = [
|
|||
{file = "greenlet-3.2.4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c2ca18a03a8cfb5b25bc1cbe20f3d9a4c80d8c3b13ba3df49ac3961af0b1018d"},
|
||||
{file = "greenlet-3.2.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9fe0a28a7b952a21e2c062cd5756d34354117796c6d9215a87f55e38d15402c5"},
|
||||
{file = "greenlet-3.2.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8854167e06950ca75b898b104b63cc646573aa5fef1353d4508ecdd1ee76254f"},
|
||||
{file = "greenlet-3.2.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f47617f698838ba98f4ff4189aef02e7343952df3a615f847bb575c3feb177a7"},
|
||||
{file = "greenlet-3.2.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:af41be48a4f60429d5cad9d22175217805098a9ef7c40bfef44f7669fb9d74d8"},
|
||||
{file = "greenlet-3.2.4-cp310-cp310-win_amd64.whl", hash = "sha256:73f49b5368b5359d04e18d15828eecc1806033db5233397748f4ca813ff1056c"},
|
||||
{file = "greenlet-3.2.4-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:96378df1de302bc38e99c3a9aa311967b7dc80ced1dcc6f171e99842987882a2"},
|
||||
{file = "greenlet-3.2.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1ee8fae0519a337f2329cb78bd7a8e128ec0f881073d43f023c7b8d4831d5246"},
|
||||
|
|
@ -3320,8 +3330,6 @@ files = [
|
|||
{file = "greenlet-3.2.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2523e5246274f54fdadbce8494458a2ebdcdbc7b802318466ac5606d3cded1f8"},
|
||||
{file = "greenlet-3.2.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1987de92fec508535687fb807a5cea1560f6196285a4cde35c100b8cd632cc52"},
|
||||
{file = "greenlet-3.2.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:55e9c5affaa6775e2c6b67659f3a71684de4c549b3dd9afca3bc773533d284fa"},
|
||||
{file = "greenlet-3.2.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c9c6de1940a7d828635fbd254d69db79e54619f165ee7ce32fda763a9cb6a58c"},
|
||||
{file = "greenlet-3.2.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:03c5136e7be905045160b1b9fdca93dd6727b180feeafda6818e6496434ed8c5"},
|
||||
{file = "greenlet-3.2.4-cp311-cp311-win_amd64.whl", hash = "sha256:9c40adce87eaa9ddb593ccb0fa6a07caf34015a29bf8d344811665b573138db9"},
|
||||
{file = "greenlet-3.2.4-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:3b67ca49f54cede0186854a008109d6ee71f66bd57bb36abd6d0a0267b540cdd"},
|
||||
{file = "greenlet-3.2.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ddf9164e7a5b08e9d22511526865780a576f19ddd00d62f8a665949327fde8bb"},
|
||||
|
|
@ -3331,8 +3339,6 @@ files = [
|
|||
{file = "greenlet-3.2.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b3812d8d0c9579967815af437d96623f45c0f2ae5f04e366de62a12d83a8fb0"},
|
||||
{file = "greenlet-3.2.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:abbf57b5a870d30c4675928c37278493044d7c14378350b3aa5d484fa65575f0"},
|
||||
{file = "greenlet-3.2.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:20fb936b4652b6e307b8f347665e2c615540d4b42b3b4c8a321d8286da7e520f"},
|
||||
{file = "greenlet-3.2.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ee7a6ec486883397d70eec05059353b8e83eca9168b9f3f9a361971e77e0bcd0"},
|
||||
{file = "greenlet-3.2.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:326d234cbf337c9c3def0676412eb7040a35a768efc92504b947b3e9cfc7543d"},
|
||||
{file = "greenlet-3.2.4-cp312-cp312-win_amd64.whl", hash = "sha256:a7d4e128405eea3814a12cc2605e0e6aedb4035bf32697f72deca74de4105e02"},
|
||||
{file = "greenlet-3.2.4-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:1a921e542453fe531144e91e1feedf12e07351b1cf6c9e8a3325ea600a715a31"},
|
||||
{file = "greenlet-3.2.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cd3c8e693bff0fff6ba55f140bf390fa92c994083f838fece0f63be121334945"},
|
||||
|
|
@ -3342,8 +3348,6 @@ files = [
|
|||
{file = "greenlet-3.2.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:23768528f2911bcd7e475210822ffb5254ed10d71f4028387e5a99b4c6699671"},
|
||||
{file = "greenlet-3.2.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:00fadb3fedccc447f517ee0d3fd8fe49eae949e1cd0f6a611818f4f6fb7dc83b"},
|
||||
{file = "greenlet-3.2.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:d25c5091190f2dc0eaa3f950252122edbbadbb682aa7b1ef2f8af0f8c0afefae"},
|
||||
{file = "greenlet-3.2.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6e343822feb58ac4d0a1211bd9399de2b3a04963ddeec21530fc426cc121f19b"},
|
||||
{file = "greenlet-3.2.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ca7f6f1f2649b89ce02f6f229d7c19f680a6238af656f61e0115b24857917929"},
|
||||
{file = "greenlet-3.2.4-cp313-cp313-win_amd64.whl", hash = "sha256:554b03b6e73aaabec3745364d6239e9e012d64c68ccd0b8430c64ccc14939a8b"},
|
||||
{file = "greenlet-3.2.4-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:49a30d5fda2507ae77be16479bdb62a660fa51b1eb4928b524975b3bde77b3c0"},
|
||||
{file = "greenlet-3.2.4-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:299fd615cd8fc86267b47597123e3f43ad79c9d8a22bebdce535e53550763e2f"},
|
||||
|
|
@ -3351,8 +3355,6 @@ files = [
|
|||
{file = "greenlet-3.2.4-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:b4a1870c51720687af7fa3e7cda6d08d801dae660f75a76f3845b642b4da6ee1"},
|
||||
{file = "greenlet-3.2.4-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:061dc4cf2c34852b052a8620d40f36324554bc192be474b9e9770e8c042fd735"},
|
||||
{file = "greenlet-3.2.4-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:44358b9bf66c8576a9f57a590d5f5d6e72fa4228b763d0e43fee6d3b06d3a337"},
|
||||
{file = "greenlet-3.2.4-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:2917bdf657f5859fbf3386b12d68ede4cf1f04c90c3a6bc1f013dd68a22e2269"},
|
||||
{file = "greenlet-3.2.4-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:015d48959d4add5d6c9f6c5210ee3803a830dce46356e3bc326d6776bde54681"},
|
||||
{file = "greenlet-3.2.4-cp314-cp314-win_amd64.whl", hash = "sha256:e37ab26028f12dbb0ff65f29a8d3d44a765c61e729647bf2ddfbbed621726f01"},
|
||||
{file = "greenlet-3.2.4-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:b6a7c19cf0d2742d0809a4c05975db036fdff50cd294a93632d6a310bf9ac02c"},
|
||||
{file = "greenlet-3.2.4-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:27890167f55d2387576d1f41d9487ef171849ea0359ce1510ca6e06c8bece11d"},
|
||||
|
|
@ -3362,8 +3364,6 @@ files = [
|
|||
{file = "greenlet-3.2.4-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c9913f1a30e4526f432991f89ae263459b1c64d1608c0d22a5c79c287b3c70df"},
|
||||
{file = "greenlet-3.2.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b90654e092f928f110e0007f572007c9727b5265f7632c2fa7415b4689351594"},
|
||||
{file = "greenlet-3.2.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:81701fd84f26330f0d5f4944d4e92e61afe6319dcd9775e39396e39d7c3e5f98"},
|
||||
{file = "greenlet-3.2.4-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:28a3c6b7cd72a96f61b0e4b2a36f681025b60ae4779cc73c1535eb5f29560b10"},
|
||||
{file = "greenlet-3.2.4-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:52206cd642670b0b320a1fd1cbfd95bca0e043179c1d8a045f2c6109dfe973be"},
|
||||
{file = "greenlet-3.2.4-cp39-cp39-win32.whl", hash = "sha256:65458b409c1ed459ea899e939f0e1cdb14f58dbc803f2f93c5eab5694d32671b"},
|
||||
{file = "greenlet-3.2.4-cp39-cp39-win_amd64.whl", hash = "sha256:d2e685ade4dafd447ede19c31277a224a239a0a1a4eca4e6390efedf20260cfb"},
|
||||
{file = "greenlet-3.2.4.tar.gz", hash = "sha256:0dca0d95ff849f9a364385f36ab49f50065d76964944638be9691e1832e9f86d"},
|
||||
|
|
@ -4413,8 +4413,6 @@ groups = ["main"]
|
|||
markers = "extra == \"dlt\""
|
||||
files = [
|
||||
{file = "jsonpath-ng-1.7.0.tar.gz", hash = "sha256:f6f5f7fd4e5ff79c785f1573b394043b39849fb2bb47bcead935d12b00beab3c"},
|
||||
{file = "jsonpath_ng-1.7.0-py2-none-any.whl", hash = "sha256:898c93fc173f0c336784a3fa63d7434297544b7198124a68f9a3ef9597b0ae6e"},
|
||||
{file = "jsonpath_ng-1.7.0-py3-none-any.whl", hash = "sha256:f3d7f9e848cba1b6da28c55b1c26ff915dc9e0b1ba7e752a53d6da8d5cbd00b6"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
|
@ -5688,11 +5686,8 @@ files = [
|
|||
{file = "lxml-5.4.0-cp36-cp36m-win_amd64.whl", hash = "sha256:7ce1a171ec325192c6a636b64c94418e71a1964f56d002cc28122fceff0b6121"},
|
||||
{file = "lxml-5.4.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:795f61bcaf8770e1b37eec24edf9771b307df3af74d1d6f27d812e15a9ff3872"},
|
||||
{file = "lxml-5.4.0-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:29f451a4b614a7b5b6c2e043d7b64a15bd8304d7e767055e8ab68387a8cacf4e"},
|
||||
{file = "lxml-5.4.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:891f7f991a68d20c75cb13c5c9142b2a3f9eb161f1f12a9489c82172d1f133c0"},
|
||||
{file = "lxml-5.4.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4aa412a82e460571fad592d0f93ce9935a20090029ba08eca05c614f99b0cc92"},
|
||||
{file = "lxml-5.4.0-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:ac7ba71f9561cd7d7b55e1ea5511543c0282e2b6450f122672a2694621d63b7e"},
|
||||
{file = "lxml-5.4.0-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:c5d32f5284012deaccd37da1e2cd42f081feaa76981f0eaa474351b68df813c5"},
|
||||
{file = "lxml-5.4.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:ce31158630a6ac85bddd6b830cffd46085ff90498b397bd0a259f59d27a12188"},
|
||||
{file = "lxml-5.4.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:31e63621e073e04697c1b2d23fcb89991790eef370ec37ce4d5d469f40924ed6"},
|
||||
{file = "lxml-5.4.0-cp37-cp37m-win32.whl", hash = "sha256:be2ba4c3c5b7900246a8f866580700ef0d538f2ca32535e991027bdaba944063"},
|
||||
{file = "lxml-5.4.0-cp37-cp37m-win_amd64.whl", hash = "sha256:09846782b1ef650b321484ad429217f5154da4d6e786636c38e434fa32e94e49"},
|
||||
|
|
@ -9343,10 +9338,8 @@ files = [
|
|||
{file = "psycopg2_binary-2.9.11-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:c47676e5b485393f069b4d7a811267d3168ce46f988fa602658b8bb901e9e64d"},
|
||||
{file = "psycopg2_binary-2.9.11-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:a28d8c01a7b27a1e3265b11250ba7557e5f72b5ee9e5f3a2fa8d2949c29bf5d2"},
|
||||
{file = "psycopg2_binary-2.9.11-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5f3f2732cf504a1aa9e9609d02f79bea1067d99edf844ab92c247bbca143303b"},
|
||||
{file = "psycopg2_binary-2.9.11-cp310-cp310-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:865f9945ed1b3950d968ec4690ce68c55019d79e4497366d36e090327ce7db14"},
|
||||
{file = "psycopg2_binary-2.9.11-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:91537a8df2bde69b1c1db01d6d944c831ca793952e4f57892600e96cee95f2cd"},
|
||||
{file = "psycopg2_binary-2.9.11-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:4dca1f356a67ecb68c81a7bc7809f1569ad9e152ce7fd02c2f2036862ca9f66b"},
|
||||
{file = "psycopg2_binary-2.9.11-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:0da4de5c1ac69d94ed4364b6cbe7190c1a70d325f112ba783d83f8440285f152"},
|
||||
{file = "psycopg2_binary-2.9.11-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:37d8412565a7267f7d79e29ab66876e55cb5e8e7b3bbf94f8206f6795f8f7e7e"},
|
||||
{file = "psycopg2_binary-2.9.11-cp310-cp310-win_amd64.whl", hash = "sha256:c665f01ec8ab273a61c62beeb8cce3014c214429ced8a308ca1fc410ecac3a39"},
|
||||
{file = "psycopg2_binary-2.9.11-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0e8480afd62362d0a6a27dd09e4ca2def6fa50ed3a4e7c09165266106b2ffa10"},
|
||||
|
|
@ -9354,10 +9347,8 @@ files = [
|
|||
{file = "psycopg2_binary-2.9.11-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2e164359396576a3cc701ba8af4751ae68a07235d7a380c631184a611220d9a4"},
|
||||
{file = "psycopg2_binary-2.9.11-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:d57c9c387660b8893093459738b6abddbb30a7eab058b77b0d0d1c7d521ddfd7"},
|
||||
{file = "psycopg2_binary-2.9.11-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2c226ef95eb2250974bf6fa7a842082b31f68385c4f3268370e3f3870e7859ee"},
|
||||
{file = "psycopg2_binary-2.9.11-cp311-cp311-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:a311f1edc9967723d3511ea7d2708e2c3592e3405677bf53d5c7246753591fbb"},
|
||||
{file = "psycopg2_binary-2.9.11-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ebb415404821b6d1c47353ebe9c8645967a5235e6d88f914147e7fd411419e6f"},
|
||||
{file = "psycopg2_binary-2.9.11-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:f07c9c4a5093258a03b28fab9b4f151aa376989e7f35f855088234e656ee6a94"},
|
||||
{file = "psycopg2_binary-2.9.11-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:00ce1830d971f43b667abe4a56e42c1e2d594b32da4802e44a73bacacb25535f"},
|
||||
{file = "psycopg2_binary-2.9.11-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:cffe9d7697ae7456649617e8bb8d7a45afb71cd13f7ab22af3e5c61f04840908"},
|
||||
{file = "psycopg2_binary-2.9.11-cp311-cp311-win_amd64.whl", hash = "sha256:304fd7b7f97eef30e91b8f7e720b3db75fee010b520e434ea35ed1ff22501d03"},
|
||||
{file = "psycopg2_binary-2.9.11-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:be9b840ac0525a283a96b556616f5b4820e0526addb8dcf6525a0fa162730be4"},
|
||||
|
|
@ -9365,10 +9356,8 @@ files = [
|
|||
{file = "psycopg2_binary-2.9.11-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ab8905b5dcb05bf3fb22e0cf90e10f469563486ffb6a96569e51f897c750a76a"},
|
||||
{file = "psycopg2_binary-2.9.11-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:bf940cd7e7fec19181fdbc29d76911741153d51cab52e5c21165f3262125685e"},
|
||||
{file = "psycopg2_binary-2.9.11-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fa0f693d3c68ae925966f0b14b8edda71696608039f4ed61b1fe9ffa468d16db"},
|
||||
{file = "psycopg2_binary-2.9.11-cp312-cp312-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:a1cf393f1cdaf6a9b57c0a719a1068ba1069f022a59b8b1fe44b006745b59757"},
|
||||
{file = "psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ef7a6beb4beaa62f88592ccc65df20328029d721db309cb3250b0aae0fa146c3"},
|
||||
{file = "psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:31b32c457a6025e74d233957cc9736742ac5a6cb196c6b68499f6bb51390bd6a"},
|
||||
{file = "psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:edcb3aeb11cb4bf13a2af3c53a15b3d612edeb6409047ea0b5d6a21a9d744b34"},
|
||||
{file = "psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:62b6d93d7c0b61a1dd6197d208ab613eb7dcfdcca0a49c42ceb082257991de9d"},
|
||||
{file = "psycopg2_binary-2.9.11-cp312-cp312-win_amd64.whl", hash = "sha256:b33fabeb1fde21180479b2d4667e994de7bbf0eec22832ba5d9b5e4cf65b6c6d"},
|
||||
{file = "psycopg2_binary-2.9.11-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b8fb3db325435d34235b044b199e56cdf9ff41223a4b9752e8576465170bb38c"},
|
||||
|
|
@ -9376,10 +9365,8 @@ files = [
|
|||
{file = "psycopg2_binary-2.9.11-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8c55b385daa2f92cb64b12ec4536c66954ac53654c7f15a203578da4e78105c0"},
|
||||
{file = "psycopg2_binary-2.9.11-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:c0377174bf1dd416993d16edc15357f6eb17ac998244cca19bc67cdc0e2e5766"},
|
||||
{file = "psycopg2_binary-2.9.11-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5c6ff3335ce08c75afaed19e08699e8aacf95d4a260b495a4a8545244fe2ceb3"},
|
||||
{file = "psycopg2_binary-2.9.11-cp313-cp313-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:84011ba3109e06ac412f95399b704d3d6950e386b7994475b231cf61eec2fc1f"},
|
||||
{file = "psycopg2_binary-2.9.11-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ba34475ceb08cccbdd98f6b46916917ae6eeb92b5ae111df10b544c3a4621dc4"},
|
||||
{file = "psycopg2_binary-2.9.11-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:b31e90fdd0f968c2de3b26ab014314fe814225b6c324f770952f7d38abf17e3c"},
|
||||
{file = "psycopg2_binary-2.9.11-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:d526864e0f67f74937a8fce859bd56c979f5e2ec57ca7c627f5f1071ef7fee60"},
|
||||
{file = "psycopg2_binary-2.9.11-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:04195548662fa544626c8ea0f06561eb6203f1984ba5b4562764fbeb4c3d14b1"},
|
||||
{file = "psycopg2_binary-2.9.11-cp313-cp313-win_amd64.whl", hash = "sha256:efff12b432179443f54e230fdf60de1f6cc726b6c832db8701227d089310e8aa"},
|
||||
{file = "psycopg2_binary-2.9.11-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:92e3b669236327083a2e33ccfa0d320dd01b9803b3e14dd986a4fc54aa00f4e1"},
|
||||
|
|
@ -9387,10 +9374,8 @@ files = [
|
|||
{file = "psycopg2_binary-2.9.11-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b52a3f9bb540a3e4ec0f6ba6d31339727b2950c9772850d6545b7eae0b9d7c5"},
|
||||
{file = "psycopg2_binary-2.9.11-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:db4fd476874ccfdbb630a54426964959e58da4c61c9feba73e6094d51303d7d8"},
|
||||
{file = "psycopg2_binary-2.9.11-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:47f212c1d3be608a12937cc131bd85502954398aaa1320cb4c14421a0ffccf4c"},
|
||||
{file = "psycopg2_binary-2.9.11-cp314-cp314-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:e35b7abae2b0adab776add56111df1735ccc71406e56203515e228a8dc07089f"},
|
||||
{file = "psycopg2_binary-2.9.11-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:fcf21be3ce5f5659daefd2b3b3b6e4727b028221ddc94e6c1523425579664747"},
|
||||
{file = "psycopg2_binary-2.9.11-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:9bd81e64e8de111237737b29d68039b9c813bdf520156af36d26819c9a979e5f"},
|
||||
{file = "psycopg2_binary-2.9.11-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:32770a4d666fbdafab017086655bcddab791d7cb260a16679cc5a7338b64343b"},
|
||||
{file = "psycopg2_binary-2.9.11-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:c3cb3a676873d7506825221045bd70e0427c905b9c8ee8d6acd70cfcbd6e576d"},
|
||||
{file = "psycopg2_binary-2.9.11-cp314-cp314-win_amd64.whl", hash = "sha256:4012c9c954dfaccd28f94e84ab9f94e12df76b4afb22331b1f0d3154893a6316"},
|
||||
{file = "psycopg2_binary-2.9.11-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:20e7fb94e20b03dcc783f76c0865f9da39559dcc0c28dd1a3fce0d01902a6b9c"},
|
||||
|
|
@ -9398,10 +9383,8 @@ files = [
|
|||
{file = "psycopg2_binary-2.9.11-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9d3a9edcfbe77a3ed4bc72836d466dfce4174beb79eda79ea155cc77237ed9e8"},
|
||||
{file = "psycopg2_binary-2.9.11-cp39-cp39-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:44fc5c2b8fa871ce7f0023f619f1349a0aa03a0857f2c96fbc01c657dcbbdb49"},
|
||||
{file = "psycopg2_binary-2.9.11-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9c55460033867b4622cda1b6872edf445809535144152e5d14941ef591980edf"},
|
||||
{file = "psycopg2_binary-2.9.11-cp39-cp39-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:2d11098a83cca92deaeaed3d58cfd150d49b3b06ee0d0852be466bf87596899e"},
|
||||
{file = "psycopg2_binary-2.9.11-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:691c807d94aecfbc76a14e1408847d59ff5b5906a04a23e12a89007672b9e819"},
|
||||
{file = "psycopg2_binary-2.9.11-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:8b81627b691f29c4c30a8f322546ad039c40c328373b11dff7490a3e1b517855"},
|
||||
{file = "psycopg2_binary-2.9.11-cp39-cp39-musllinux_1_2_riscv64.whl", hash = "sha256:b637d6d941209e8d96a072d7977238eea128046effbf37d1d8b2c0764750017d"},
|
||||
{file = "psycopg2_binary-2.9.11-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:41360b01c140c2a03d346cec3280cf8a71aa07d94f3b1509fa0161c366af66b4"},
|
||||
{file = "psycopg2_binary-2.9.11-cp39-cp39-win_amd64.whl", hash = "sha256:875039274f8a2361e5207857899706da840768e2a775bf8c65e82f60b197df02"},
|
||||
]
|
||||
|
|
@ -10692,13 +10675,6 @@ optional = false
|
|||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "PyYAML-6.0.3-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:c2514fceb77bc5e7a2f7adfaa1feb2fb311607c9cb518dbc378688ec73d8292f"},
|
||||
{file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9c57bb8c96f6d1808c030b1687b9b5fb476abaa47f0db9c0101f5e9f394e97f4"},
|
||||
{file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:efd7b85f94a6f21e4932043973a7ba2613b059c4a000551892ac9f1d11f5baf3"},
|
||||
{file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:22ba7cfcad58ef3ecddc7ed1db3409af68d023b7f940da23c6c2a1890976eda6"},
|
||||
{file = "PyYAML-6.0.3-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:6344df0d5755a2c9a276d4473ae6b90647e216ab4757f8426893b5dd2ac3f369"},
|
||||
{file = "PyYAML-6.0.3-cp38-cp38-win32.whl", hash = "sha256:3ff07ec89bae51176c0549bc4c63aa6202991da2d9a6129d7aef7f1407d3f295"},
|
||||
{file = "PyYAML-6.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:5cf4e27da7e3fbed4d6c3d8e797387aaad68102272f8f9752883bc32d61cb87b"},
|
||||
{file = "pyyaml-6.0.3-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:214ed4befebe12df36bcc8bc2b64b396ca31be9304b8f59e25c11cf94a4c033b"},
|
||||
{file = "pyyaml-6.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:02ea2dfa234451bbb8772601d7b8e426c2bfa197136796224e50e35a78777956"},
|
||||
{file = "pyyaml-6.0.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b30236e45cf30d2b8e7b3e85881719e98507abed1011bf463a8fa23e9c3e98a8"},
|
||||
|
|
@ -14468,4 +14444,4 @@ scraping = ["APScheduler", "beautifulsoup4", "lxml", "lxml", "playwright", "prot
|
|||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.14"
|
||||
content-hash = "6c8f26955a23ff510ddd0ba4eac4046fb2738e8b5787c5eb3b7abca91fec6905"
|
||||
content-hash = "09f7040236a62a2d610e79e92394bb0c23e13ed41ba4de92c064ab4d5430b84e"
|
||||
|
|
|
|||
|
|
@ -59,6 +59,7 @@ dependencies = [
|
|||
"tenacity>=9.0.0",
|
||||
"fakeredis[lua]>=2.32.0",
|
||||
"diskcache>=5.6.3",
|
||||
"aiolimiter>=1.2.1",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
|
|
|||
13
uv.lock
generated
13
uv.lock
generated
|
|
@ -1,5 +1,5 @@
|
|||
version = 1
|
||||
revision = 3
|
||||
revision = 2
|
||||
requires-python = ">=3.10, <3.14"
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.13' and platform_python_implementation != 'PyPy' and sys_platform == 'darwin'",
|
||||
|
|
@ -187,6 +187,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/10/a1/510b0a7fadc6f43a6ce50152e69dbd86415240835868bb0bd9b5b88b1e06/aioitertools-0.13.0-py3-none-any.whl", hash = "sha256:0be0292b856f08dfac90e31f4739432f4cb6d7520ab9eb73e143f4f2fa5259be", size = 24182, upload-time = "2025-11-06T22:17:06.502Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aiolimiter"
|
||||
version = "1.2.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f1/23/b52debf471f7a1e42e362d959a3982bdcb4fe13a5d46e63d28868807a79c/aiolimiter-1.2.1.tar.gz", hash = "sha256:e02a37ea1a855d9e832252a105420ad4d15011505512a1a1d814647451b5cca9", size = 7185, upload-time = "2024-12-08T15:31:51.496Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f3/ba/df6e8e1045aebc4778d19b8a3a9bc1808adb1619ba94ca354d9ba17d86c3/aiolimiter-1.2.1-py3-none-any.whl", hash = "sha256:d3f249e9059a20badcb56b61601a83556133655c11d1eb3dd3e04ff069e5f3c7", size = 6711, upload-time = "2024-12-08T15:31:49.874Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aiosignal"
|
||||
version = "1.4.0"
|
||||
|
|
@ -942,6 +951,7 @@ source = { editable = "." }
|
|||
dependencies = [
|
||||
{ name = "aiofiles" },
|
||||
{ name = "aiohttp" },
|
||||
{ name = "aiolimiter" },
|
||||
{ name = "aiosqlite" },
|
||||
{ name = "alembic" },
|
||||
{ name = "diskcache" },
|
||||
|
|
@ -1113,6 +1123,7 @@ scraping = [
|
|||
requires-dist = [
|
||||
{ name = "aiofiles", specifier = ">=23.2.1" },
|
||||
{ name = "aiohttp", specifier = ">=3.11.14,<4.0.0" },
|
||||
{ name = "aiolimiter", specifier = ">=1.2.1" },
|
||||
{ name = "aiosqlite", specifier = ">=0.20.0,<1.0.0" },
|
||||
{ name = "alembic", specifier = ">=1.13.3,<2" },
|
||||
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.27" },
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue