From 2018850dff8cc4437b151aa59d45583ff6e56522 Mon Sep 17 00:00:00 2001
From: Dmitrii Galkin
Date: Sat, 26 Apr 2025 20:14:35 +0400
Subject: [PATCH] Add OpenAI-compatible responses API endpoint with function
 calling support

---
 cognee/api/client.py                          |   3 +
 cognee/api/v1/responses/__init__.py           |   3 +
 cognee/api/v1/responses/default_tools.py      |  66 ++++++++
 cognee/api/v1/responses/dispatch_function.py  | 107 +++++++++++++
 cognee/api/v1/responses/models.py             | 102 ++++++++++++
 cognee/api/v1/responses/routers/__init__.py   |   3 +
 .../api/v1/responses/routers/default_tools.py |  86 +++++++++++
 .../responses/routers/get_responses_router.py | 146 ++++++++++++++++++
 8 files changed, 516 insertions(+)
 create mode 100644 cognee/api/v1/responses/__init__.py
 create mode 100644 cognee/api/v1/responses/default_tools.py
 create mode 100644 cognee/api/v1/responses/dispatch_function.py
 create mode 100644 cognee/api/v1/responses/models.py
 create mode 100644 cognee/api/v1/responses/routers/__init__.py
 create mode 100644 cognee/api/v1/responses/routers/default_tools.py
 create mode 100644 cognee/api/v1/responses/routers/get_responses_router.py

diff --git a/cognee/api/client.py b/cognee/api/client.py
index b91e149c1..8d40b5fce 100644
--- a/cognee/api/client.py
+++ b/cognee/api/client.py
@@ -14,6 +14,7 @@ from cognee.api.v1.cognify.routers import get_code_pipeline_router, get_cognify_
 from cognee.api.v1.search.routers import get_search_router
 from cognee.api.v1.add.routers import get_add_router
 from cognee.api.v1.delete.routers import get_delete_router
+from cognee.api.v1.responses.routers import get_responses_router
 from fastapi import Request
 from fastapi.encoders import jsonable_encoder
 from fastapi.exceptions import RequestValidationError
@@ -167,6 +168,8 @@ app.include_router(get_visualize_router(), prefix="/api/v1/visualize", tags=["vi
 
 app.include_router(get_delete_router(), prefix="/api/v1/delete", tags=["delete"])
 
+app.include_router(get_responses_router(), prefix="/api/v1/responses", tags=["responses"])
+
 codegraph_routes = get_code_pipeline_router()
 if codegraph_routes:
     app.include_router(codegraph_routes, prefix="/api/v1/code-pipeline", tags=["code-pipeline"])
diff --git a/cognee/api/v1/responses/__init__.py b/cognee/api/v1/responses/__init__.py
new file mode 100644
index 000000000..73d14a3bc
--- /dev/null
+++ b/cognee/api/v1/responses/__init__.py
@@ -0,0 +1,3 @@
+from cognee.api.v1.responses.routers import get_responses_router
+
+__all__ = ["get_responses_router"]
diff --git a/cognee/api/v1/responses/default_tools.py b/cognee/api/v1/responses/default_tools.py
new file mode 100644
index 000000000..1ac589c68
--- /dev/null
+++ b/cognee/api/v1/responses/default_tools.py
@@ -0,0 +1,66 @@
+DEFAULT_TOOLS = [
+    {
+        "type": "function",
+        "name": "search",
+        "description": "Search for information within the knowledge graph",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "search_query": {
+                    "type": "string",
+                    "description": "The query to search for in the knowledge graph",
+                },
+                "search_type": {
+                    "type": "string",
+                    "description": "Type of search to perform",
+                    "enum": [
+                        "INSIGHTS",
+                        "CODE",
+                        "GRAPH_COMPLETION",
+                        "SEMANTIC",
+                        "NATURAL_LANGUAGE",
+                    ],
+                },
+                "top_k": {
+                    "type": "integer",
+                    "description": "Maximum number of results to return",
+                    "default": 10,
+                },
+                "datasets": {
+                    "type": "array",
+                    "items": {"type": "string"},
+                    "description": "Optional list of dataset names to search within",
+                },
+            },
+            "required": ["search_query"],
+        },
+    },
+    {
+        "type": "function",
"cognify_text", + "description": "Convert text into a knowledge graph or process all added content", + "parameters": { + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "Text content to be converted into a knowledge graph", + }, + "graph_model_name": { + "type": "string", + "description": "Name of the graph model to use", + }, + "graph_model_file": { + "type": "string", + "description": "Path to a custom graph model file", + }, + }, + }, + }, + # Commented as dangerous + # { + # "type": "function", + # "name": "prune", + # "description": "Prune memory", + # }, +] diff --git a/cognee/api/v1/responses/dispatch_function.py b/cognee/api/v1/responses/dispatch_function.py new file mode 100644 index 000000000..3ab5a5a66 --- /dev/null +++ b/cognee/api/v1/responses/dispatch_function.py @@ -0,0 +1,107 @@ +import json +import logging +from typing import Any, Dict, Union + +from cognee.api.v1.responses.models import ToolCall +from cognee.modules.search.types import SearchType +from cognee.api.v1.add import add +from cognee.api.v1.search import search +from cognee.api.v1.cognify import cognify +from cognee.api.v1.prune import prune + + +from cognee.modules.users.methods import get_default_user +from cognee.api.v1.responses.default_tools import DEFAULT_TOOLS + +logger = logging.getLogger(__name__) + + +async def dispatch_function(tool_call: Union[ToolCall, Dict[str, Any]]) -> str: + """ + Dispatches a function call to the appropriate Cognee function. + """ + if isinstance(tool_call, dict): + function_data = tool_call.get("function", {}) + function_name = function_data.get("name", "") + arguments_str = function_data.get("arguments", "{}") + else: + function_name = tool_call.function.name + arguments_str = tool_call.function.arguments + + arguments = json.loads(arguments_str) + + logger.info(f"Dispatching function: {function_name} with args: {arguments}") + + user = await get_default_user() + + if function_name == "search": + return await handle_search(arguments, user) + elif function_name == "cognify_text": + return await handle_cognify(arguments, user) + elif function_name == "prune": + return await handle_prune(arguments, user) + else: + return f"Error: Unknown function {function_name}" + + +async def handle_search(arguments: Dict[str, Any], user) -> list: + """Handle search function call""" + search_tool = next((tool for tool in DEFAULT_TOOLS if tool["name"] == "search"), None) + required_params = ( + search_tool["parameters"].get("required", []) if search_tool else ["search_query"] + ) + + query = arguments.get("search_query") + if not query and "search_query" in required_params: + return "Error: Missing required 'search_query' parameter" + + search_type_str = arguments.get("search_type", "GRAPH_COMPLETION") + valid_search_types = ( + search_tool["parameters"]["properties"]["search_type"]["enum"] + if search_tool + else ["INSIGHTS", "CODE", "GRAPH_COMPLETION", "SEMANTIC", "NATURAL_LANGUAGE"] + ) + + if search_type_str not in valid_search_types: + logger.warning(f"Invalid search_type: {search_type_str}, defaulting to GRAPH_COMPLETION") + search_type_str = "GRAPH_COMPLETION" + + query_type = search_type_str + + top_k = arguments.get("top_k") + datasets = arguments.get("datasets") + system_prompt_path = arguments.get("system_prompt_path", "answer_simple_question.txt") + + results = await search( + query_text=query, + query_type=query_type, + datasets=datasets, + user=user, + system_prompt_path=system_prompt_path, + top_k=top_k if isinstance(top_k, int) else 10, + ) + 
+    return results
+
+
+async def handle_cognify(arguments: Dict[str, Any], user) -> str:
+    """Handle cognify function call"""
+    text = arguments.get("text")
+    graph_model_file = arguments.get("graph_model_file")
+
+    if text:
+        await add(data=text, user=user)
+
+    await cognify(user=user, ontology_file_path=graph_model_file if graph_model_file else None)
+
+    return (
+        "Text successfully converted into knowledge graph."
+        if text
+        else "Knowledge graph successfully updated with new information."
+    )
+
+
+async def handle_prune(arguments: Dict[str, Any], user) -> str:
+    """Handle prune function call"""
+    await prune()
+    return "Memory has been pruned successfully."
diff --git a/cognee/api/v1/responses/models.py b/cognee/api/v1/responses/models.py
new file mode 100644
index 000000000..aa33e4e48
--- /dev/null
+++ b/cognee/api/v1/responses/models.py
@@ -0,0 +1,102 @@
+import time
+import uuid
+from typing import Any, Dict, List, Optional, Union
+
+from pydantic import BaseModel, Field
+
+from enum import Enum
+
+from cognee.api.DTO import InDTO, OutDTO
+
+
+class CogneeModel(str, Enum):
+    """Enum for supported model types"""
+
+    COGNEEV1 = "cognee-v1"
+
+
+class FunctionParameters(BaseModel):
+    """JSON Schema for function parameters"""
+
+    type: str = "object"
+    properties: Dict[str, Dict[str, Any]]
+    required: Optional[List[str]] = None
+
+
+class Function(BaseModel):
+    """Function definition compatible with OpenAI's format"""
+
+    name: str
+    description: str
+    parameters: FunctionParameters
+
+
+class ToolFunction(BaseModel):
+    """Tool function wrapper (for OpenAI compatibility)"""
+
+    type: str = "function"
+    function: Function
+
+
+class FunctionCall(BaseModel):
+    """Function call made by the assistant"""
+
+    name: str
+    arguments: str
+
+
+class ToolCall(BaseModel):
+    """Tool call made by the assistant"""
+
+    id: str = Field(default_factory=lambda: f"call_{uuid.uuid4().hex}")
+    type: str = "function"
+    function: FunctionCall
+
+
+class ChatUsage(BaseModel):
+    """Token usage information"""
+
+    prompt_tokens: int = 0
+    completion_tokens: int = 0
+    total_tokens: int = 0
+
+
+class ResponseRequest(InDTO):
+    """Request body for the new responses endpoint (OpenAI Responses API format)"""
+
+    model: CogneeModel = CogneeModel.COGNEEV1
+    input: str
+    tools: Optional[List[ToolFunction]] = None
+    tool_choice: Optional[Union[str, Dict[str, Any]]] = "auto"
+    user: Optional[str] = None
+    temperature: Optional[float] = 1.0
+    max_tokens: Optional[int] = None
+
+
+class ToolCallOutput(BaseModel):
+    """Output of a tool call in the responses API"""
+
+    status: str = "success"  # success/error
+    data: Optional[Dict[str, Any]] = None
+
+
+class ResponseToolCall(BaseModel):
+    """Tool call in a response"""
+
+    id: str = Field(default_factory=lambda: f"call_{uuid.uuid4().hex}")
+    type: str = "function"
+    function: FunctionCall
+    output: Optional[ToolCallOutput] = None
+
+
+class ResponseResponse(OutDTO):
+    """Response body for the new responses endpoint"""
+
+    id: str = Field(default_factory=lambda: f"resp_{uuid.uuid4().hex}")
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    object: str = "response"
+    status: str = "completed"
+    tool_calls: List[ResponseToolCall]
+    usage: Optional[ChatUsage] = None
+    metadata: Optional[Dict[str, Any]] = None
diff --git a/cognee/api/v1/responses/routers/__init__.py b/cognee/api/v1/responses/routers/__init__.py
new file mode 100644
index 000000000..6d484e8f4
--- /dev/null
+++ b/cognee/api/v1/responses/routers/__init__.py
@@ -0,0 +1,3 @@
+from cognee.api.v1.responses.routers.get_responses_router import get_responses_router
+
+__all__ = ["get_responses_router"]
diff --git a/cognee/api/v1/responses/routers/default_tools.py b/cognee/api/v1/responses/routers/default_tools.py
new file mode 100644
index 000000000..75663829b
--- /dev/null
+++ b/cognee/api/v1/responses/routers/default_tools.py
@@ -0,0 +1,86 @@
+DEFAULT_TOOLS = [
+    {
+        "type": "function",
+        "name": "search",
+        "description": "Search for information within the knowledge graph",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "search_query": {
+                    "type": "string",
+                    "description": "The query to search for in the knowledge graph",
+                },
+                "search_type": {
+                    "type": "string",
+                    "description": "Type of search to perform",
+                    "enum": [
+                        "INSIGHTS",
+                        "CODE",
+                        "GRAPH_COMPLETION",
+                        "SEMANTIC",
+                        "NATURAL_LANGUAGE",
+                    ],
+                },
+                "top_k": {
+                    "type": "integer",
+                    "description": "Maximum number of results to return",
+                    "default": 10,
+                },
+                "datasets": {
+                    "type": "array",
+                    "items": {"type": "string"},
+                    "description": "Optional list of dataset names to search within",
+                },
+            },
+            "required": ["search_query"],
+        },
+    },
+    {
+        "type": "function",
+        "name": "cognify_text",
+        "description": "Convert text into a knowledge graph or process all added content",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "text": {
+                    "type": "string",
+                    "description": "Text content to be converted into a knowledge graph",
+                },
+                "graph_model_name": {
+                    "type": "string",
+                    "description": "Name of the graph model to use",
+                },
+                "graph_model_file": {
+                    "type": "string",
+                    "description": "Path to a custom graph model file",
+                },
+            },
+        },
+    },
+    {
+        "type": "function",
+        "name": "prune",
+        "description": "Remove unnecessary or outdated information from the knowledge graph",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "prune_strategy": {
+                    "type": "string",
+                    "enum": ["light", "moderate", "aggressive"],
+                    "description": "Strategy for pruning the knowledge graph",
+                    "default": "moderate",
+                },
+                "min_confidence": {
+                    "type": "number",
+                    "description": "Minimum confidence score to retain (0-1)",
+                    "minimum": 0,
+                    "maximum": 1,
+                },
+                "older_than": {
+                    "type": "string",
+                    "description": "ISO date string - prune nodes older than this date",
+                },
+            },
+        },
+    },
+]
diff --git a/cognee/api/v1/responses/routers/get_responses_router.py b/cognee/api/v1/responses/routers/get_responses_router.py
new file mode 100644
index 000000000..35a5988ca
--- /dev/null
+++ b/cognee/api/v1/responses/routers/get_responses_router.py
@@ -0,0 +1,146 @@
+"""
+Get router for the OpenAI-compatible responses API.
+"""
+
+import logging
+import uuid
+from typing import Dict, List, Optional, Any
+import openai
+from fastapi import APIRouter
+from cognee.api.v1.responses.models import (
+    ResponseRequest,
+    ResponseResponse,
+    ResponseToolCall,
+    ChatUsage,
+    FunctionCall,
+    ToolCallOutput,
+)
+from cognee.api.v1.responses.dispatch_function import dispatch_function
+from cognee.api.v1.responses.default_tools import DEFAULT_TOOLS
+from cognee.infrastructure.llm.config import get_llm_config
+
+
+def get_responses_router() -> APIRouter:
+    """
+    Returns the FastAPI router for OpenAI-compatible responses.
+
+    This implementation follows the new OpenAI Responses API format as described in:
+    https://platform.openai.com/docs/api-reference/responses/create
+    """
+
+    router = APIRouter()
+    logger = logging.getLogger(__name__)
+
+    def _get_model_client():
+        """
+        Get appropriate client based on model name
+        """
+        llm_config = get_llm_config()
+        return openai.AsyncOpenAI(api_key=llm_config.llm_api_key)
+
+    async def call_openai_api_for_model(
+        input_text: str,
+        model: str,
+        tools: Optional[List[Dict[str, Any]]] = DEFAULT_TOOLS,
+        tool_choice: Any = "auto",
+        temperature: float = 1.0,
+    ) -> Dict[str, Any]:
+        """
+        Call appropriate model API based on model name
+        """
+
+        # TODO: Support other models (e.g. cognee-v1-openai-gpt-3.5-turbo, etc.)
+        model = "gpt-4o"
+
+        client = _get_model_client()
+
+        logger.debug(f"Using model: {model}")
+
+        response = await client.responses.create(
+            model=model,
+            input=input_text,
+            temperature=temperature,
+            tools=tools,
+            tool_choice=tool_choice,
+        )
+        logger.info(f"Response: {response}")
+        return response.model_dump()
+
+    @router.post("/", response_model=ResponseResponse)
+    async def create_response(
+        request: ResponseRequest,
+    ) -> ResponseResponse:
+        """
+        OpenAI-compatible responses endpoint with function calling support
+        """
+        # Flatten client-supplied ToolFunction models to the Responses API tool format, or fall back to defaults
+        tools = [{"type": "function", **t.function.model_dump(exclude_none=True)} for t in request.tools] if request.tools else DEFAULT_TOOLS
+
+        # Call the API
+        response = await call_openai_api_for_model(
+            input_text=request.input,
+            model=request.model,
+            tools=tools,
+            tool_choice=request.tool_choice,
+            temperature=request.temperature,
+        )
+
+        # Use the response ID from the API or generate a new one
+        response_id = response.get("id", f"resp_{uuid.uuid4().hex}")
+
+        # Check if there are function tool calls in the output
+        output = response.get("output", [])
+
+        processed_tool_calls = []
+
+        # Process any function tool calls from the output
+        for item in output:
+            if isinstance(item, dict) and item.get("type") == "function_call":
+                # This is a tool call from the new format
+                function_name = item.get("name", "")
+                arguments_str = item.get("arguments", "{}")
+                call_id = item.get("call_id", f"call_{uuid.uuid4().hex}")
+
+                # Create a format the dispatcher can handle
+                tool_call = {
+                    "id": call_id,
+                    "function": {"name": function_name, "arguments": arguments_str},
+                    "type": "function",
+                }
+
+                # Dispatch the function
+                try:
+                    function_result = await dispatch_function(tool_call)
+                    output_status = "success"
+                except Exception as e:
+                    logger.exception(f"Error executing function {function_name}: {e}")
+                    function_result = f"Error executing {function_name}: {str(e)}"
+                    output_status = "error"
+
+                processed_call = ResponseToolCall(
+                    id=call_id,
+                    type="function",
+                    function=FunctionCall(name=function_name, arguments=arguments_str),
+                    output=ToolCallOutput(status=output_status, data={"result": function_result}),
+                )
+
+                processed_tool_calls.append(processed_call)
+
+        # Get usage data from the response if available
+        usage = response.get("usage") or {}
+
+        # Create the response object with all processed tool calls
+        response_obj = ResponseResponse(
+            id=response_id,
+            model=request.model,
+            tool_calls=processed_tool_calls,
+            usage=ChatUsage(
+                prompt_tokens=usage.get("input_tokens", 0),
+                completion_tokens=usage.get("output_tokens", 0),
+                total_tokens=usage.get("total_tokens", 0),
+            ),
+        )
+
+        return response_obj
+
+    return router
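
Reviewer note: below is a minimal client sketch (not part of the patch) for exercising the new endpoint by hand. It assumes a cognee API server running at http://localhost:8000 with no auth middleware in front of the responses router, and it assumes the `requests` package is available; the route path and body fields come from the code above, while the base URL and the example input are purely illustrative.

import requests

BASE_URL = "http://localhost:8000"  # assumption: local dev server

payload = {
    "model": "cognee-v1",  # CogneeModel.COGNEEV1 from models.py
    "input": "What do we know about knowledge graphs?",
    # "tools" is omitted on purpose: the endpoint then falls back to DEFAULT_TOOLS.
    "tool_choice": "auto",
    "temperature": 1.0,
}

resp = requests.post(f"{BASE_URL}/api/v1/responses/", json=payload, timeout=120)
resp.raise_for_status()
body = resp.json()

print(body["id"], body["status"])
# ResponseResponse extends OutDTO, so the serialized key casing depends on
# cognee's DTO config; check both snake_case and camelCase to stay robust.
for call in body.get("tool_calls") or body.get("toolCalls") or []:
    print(call["function"]["name"], "->", call.get("output"))

Each entry in that loop is one function call the model made and dispatch_function executed server-side (search, cognify_text, or prune), with the handler's result embedded under output.data.result.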