Add OpenAI-compatible chat and responses API endpoints with function calling support
parent 7d7df1876e
commit 2018850dff
8 changed files with 516 additions and 0 deletions
@@ -14,6 +14,7 @@ from cognee.api.v1.cognify.routers import get_code_pipeline_router, get_cognify_
 from cognee.api.v1.search.routers import get_search_router
 from cognee.api.v1.add.routers import get_add_router
 from cognee.api.v1.delete.routers import get_delete_router
+from cognee.api.v1.responses.routers import get_responses_router
 from fastapi import Request
 from fastapi.encoders import jsonable_encoder
 from fastapi.exceptions import RequestValidationError
@@ -167,6 +168,8 @@ app.include_router(get_visualize_router(), prefix="/api/v1/visualize", tags=["vi

 app.include_router(get_delete_router(), prefix="/api/v1/delete", tags=["delete"])

+app.include_router(get_responses_router(), prefix="/api/v1/responses", tags=["responses"])
+
 codegraph_routes = get_code_pipeline_router()
 if codegraph_routes:
     app.include_router(codegraph_routes, prefix="/api/v1/code-pipeline", tags=["code-pipeline"])
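Together, these two hunks import the new router and mount it on the main FastAPI app. For local experimentation the same wiring can be reproduced standalone; a minimal sketch, assuming FastAPI and uvicorn are installed (the bare app and the host/port are illustrative, not part of this commit):

import uvicorn
from fastapi import FastAPI

from cognee.api.v1.responses.routers import get_responses_router

# Mount only the responses router, mirroring the prefix used in the diff above.
app = FastAPI()
app.include_router(get_responses_router(), prefix="/api/v1/responses", tags=["responses"])

if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8000)  # illustrative host/port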
cognee/api/v1/responses/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from cognee.api.v1.responses.routers import get_responses_router

__all__ = ["get_responses_router"]
cognee/api/v1/responses/default_tools.py (new file, 66 lines)
@@ -0,0 +1,66 @@
DEFAULT_TOOLS = [
    {
        "type": "function",
        "name": "search",
        "description": "Search for information within the knowledge graph",
        "parameters": {
            "type": "object",
            "properties": {
                "search_query": {
                    "type": "string",
                    "description": "The query to search for in the knowledge graph",
                },
                "search_type": {
                    "type": "string",
                    "description": "Type of search to perform",
                    "enum": [
                        "INSIGHTS",
                        "CODE",
                        "GRAPH_COMPLETION",
                        "SEMANTIC",
                        "NATURAL_LANGUAGE",
                    ],
                },
                "top_k": {
                    "type": "integer",
                    "description": "Maximum number of results to return",
                    "default": 10,
                },
                "datasets": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Optional list of dataset names to search within",
                },
            },
            "required": ["search_query"],
        },
    },
    {
        "type": "function",
        "name": "cognify_text",
        "description": "Convert text into a knowledge graph or process all added content",
        "parameters": {
            "type": "object",
            "properties": {
                "text": {
                    "type": "string",
                    "description": "Text content to be converted into a knowledge graph",
                },
                "graph_model_name": {
                    "type": "string",
                    "description": "Name of the graph model to use",
                },
                "graph_model_file": {
                    "type": "string",
                    "description": "Path to a custom graph model file",
                },
            },
        },
    },
    # Commented as dangerous
    # {
    #     "type": "function",
    #     "name": "prune",
    #     "description": "Prune memory",
    # },
]
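Each entry above is a plain dict whose "parameters" member is standard JSON Schema, so tool-call arguments can be validated before dispatch. A minimal sketch, assuming the third-party jsonschema package (not a dependency added by this commit); the query values are made up:

import json

from jsonschema import ValidationError, validate

from cognee.api.v1.responses.default_tools import DEFAULT_TOOLS

# Pick the "search" tool and check a model-produced arguments string against it.
search_tool = next(tool for tool in DEFAULT_TOOLS if tool["name"] == "search")
arguments = json.loads('{"search_query": "climate policies", "top_k": 5}')

try:
    validate(instance=arguments, schema=search_tool["parameters"])
    print("arguments match the declared schema")
except ValidationError as error:
    print(f"invalid arguments: {error.message}")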
cognee/api/v1/responses/dispatch_function.py (new file, 107 lines)
@@ -0,0 +1,107 @@
import json
import logging
from typing import Any, Dict, Union

from cognee.api.v1.responses.models import ToolCall
from cognee.modules.search.types import SearchType
from cognee.api.v1.add import add
from cognee.api.v1.search import search
from cognee.api.v1.cognify import cognify
from cognee.api.v1.prune import prune

from cognee.modules.users.methods import get_default_user
from cognee.api.v1.responses.default_tools import DEFAULT_TOOLS

logger = logging.getLogger(__name__)


async def dispatch_function(tool_call: Union[ToolCall, Dict[str, Any]]) -> Any:
    """
    Dispatches a function call to the appropriate Cognee function.
    """
    if isinstance(tool_call, dict):
        function_data = tool_call.get("function", {})
        function_name = function_data.get("name", "")
        arguments_str = function_data.get("arguments", "{}")
    else:
        function_name = tool_call.function.name
        arguments_str = tool_call.function.arguments

    arguments = json.loads(arguments_str)

    logger.info(f"Dispatching function: {function_name} with args: {arguments}")

    user = await get_default_user()

    if function_name == "search":
        return await handle_search(arguments, user)
    elif function_name == "cognify_text":
        return await handle_cognify(arguments, user)
    elif function_name == "prune":
        return await handle_prune(arguments, user)
    else:
        return f"Error: Unknown function {function_name}"


async def handle_search(arguments: Dict[str, Any], user) -> Any:
    """Handle search function call"""
    search_tool = next((tool for tool in DEFAULT_TOOLS if tool["name"] == "search"), None)
    required_params = (
        search_tool["parameters"].get("required", []) if search_tool else ["search_query"]
    )

    query = arguments.get("search_query")
    if not query and "search_query" in required_params:
        return "Error: Missing required 'search_query' parameter"

    search_type_str = arguments.get("search_type", "GRAPH_COMPLETION")
    valid_search_types = (
        search_tool["parameters"]["properties"]["search_type"]["enum"]
        if search_tool
        else ["INSIGHTS", "CODE", "GRAPH_COMPLETION", "SEMANTIC", "NATURAL_LANGUAGE"]
    )

    if search_type_str not in valid_search_types:
        logger.warning(f"Invalid search_type: {search_type_str}, defaulting to GRAPH_COMPLETION")
        search_type_str = "GRAPH_COMPLETION"

    query_type = search_type_str

    top_k = arguments.get("top_k")
    datasets = arguments.get("datasets")
    system_prompt_path = arguments.get("system_prompt_path", "answer_simple_question.txt")

    results = await search(
        query_text=query,
        query_type=query_type,
        datasets=datasets,
        user=user,
        system_prompt_path=system_prompt_path,
        top_k=top_k if isinstance(top_k, int) else 10,
    )

    return results


async def handle_cognify(arguments: Dict[str, Any], user) -> str:
    """Handle cognify function call"""
    text = arguments.get("text")
    graph_model_file = arguments.get("graph_model_file")

    if text:
        await add(data=text, user=user)

    await cognify(user=user, ontology_file_path=graph_model_file if graph_model_file else None)

    return (
        "Text successfully converted into knowledge graph."
        if text
        else "Knowledge graph successfully updated with new information."
    )


async def handle_prune(arguments: Dict[str, Any], user) -> str:
    """Handle prune function call"""
    await prune()
    return "Memory has been pruned successfully."
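Because dispatch_function accepts either a ToolCall model or a raw dict, a recorded tool call can be replayed without constructing Pydantic objects. A minimal sketch (query and top_k values are illustrative; note that "arguments" must be a JSON string, exactly as a model emits it):

import asyncio

from cognee.api.v1.responses.dispatch_function import dispatch_function

# Raw-dict form of a tool call, mirroring what the router builds from model output.
tool_call = {
    "id": "call_example",
    "type": "function",
    "function": {
        "name": "search",
        "arguments": '{"search_query": "What is cognee?", "top_k": 3}',
    },
}

results = asyncio.run(dispatch_function(tool_call))
print(results)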
cognee/api/v1/responses/models.py (new file, 102 lines)
@@ -0,0 +1,102 @@
import time
import uuid
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Field

from enum import Enum

from cognee.api.DTO import InDTO, OutDTO


class CogneeModel(str, Enum):
    """Enum for supported model types"""

    COGNEEV1 = "cognee-v1"


class FunctionParameters(BaseModel):
    """JSON Schema for function parameters"""

    type: str = "object"
    properties: Dict[str, Dict[str, Any]]
    required: Optional[List[str]] = None


class Function(BaseModel):
    """Function definition compatible with OpenAI's format"""

    name: str
    description: str
    parameters: FunctionParameters


class ToolFunction(BaseModel):
    """Tool function wrapper (for OpenAI compatibility)"""

    type: str = "function"
    function: Function


class FunctionCall(BaseModel):
    """Function call made by the assistant"""

    name: str
    arguments: str


class ToolCall(BaseModel):
    """Tool call made by the assistant"""

    id: str = Field(default_factory=lambda: f"call_{uuid.uuid4().hex}")
    type: str = "function"
    function: FunctionCall


class ChatUsage(BaseModel):
    """Token usage information"""

    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0


class ResponseRequest(InDTO):
    """Request body for the new responses endpoint (OpenAI Responses API format)"""

    model: CogneeModel = CogneeModel.COGNEEV1
    input: str
    tools: Optional[List[ToolFunction]] = None
    tool_choice: Optional[Union[str, Dict[str, Any]]] = "auto"
    user: Optional[str] = None
    temperature: Optional[float] = 1.0
    max_tokens: Optional[int] = None


class ToolCallOutput(BaseModel):
    """Output of a tool call in the responses API"""

    status: str = "success"  # success/error
    data: Optional[Dict[str, Any]] = None


class ResponseToolCall(BaseModel):
    """Tool call in a response"""

    id: str = Field(default_factory=lambda: f"call_{uuid.uuid4().hex}")
    type: str = "function"
    function: FunctionCall
    output: Optional[ToolCallOutput] = None


class ResponseResponse(OutDTO):
    """Response body for the new responses endpoint"""

    id: str = Field(default_factory=lambda: f"resp_{uuid.uuid4().hex}")
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    object: str = "response"
    status: str = "completed"
    tool_calls: List[ResponseToolCall]
    usage: Optional[ChatUsage] = None
    metadata: Optional[Dict[str, Any]] = None
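A quick sketch of how these models compose, building a request and a minimal response by hand. It assumes the InDTO/OutDTO base classes accept snake_case field names at construction time; all field values are illustrative:

from cognee.api.v1.responses.models import (
    FunctionCall,
    ResponseRequest,
    ResponseResponse,
    ResponseToolCall,
    ToolCallOutput,
)

# model defaults to CogneeModel.COGNEEV1 and tool_choice to "auto".
request = ResponseRequest(input="What do we know about climate policies?")

response = ResponseResponse(
    model=request.model,
    tool_calls=[
        ResponseToolCall(
            function=FunctionCall(
                name="search",
                arguments='{"search_query": "climate policies"}',
            ),
            output=ToolCallOutput(status="success", data={"result": []}),
        )
    ],
)
print(response.model_dump_json(indent=2))  # id and created are auto-generated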
cognee/api/v1/responses/routers/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from cognee.api.v1.responses.routers.get_responses_router import get_responses_router

__all__ = ["get_responses_router"]
cognee/api/v1/responses/routers/default_tools.py (new file, 86 lines)
@@ -0,0 +1,86 @@
DEFAULT_TOOLS = [
    {
        "type": "function",
        "name": "search",
        "description": "Search for information within the knowledge graph",
        "parameters": {
            "type": "object",
            "properties": {
                "search_query": {
                    "type": "string",
                    "description": "The query to search for in the knowledge graph",
                },
                "search_type": {
                    "type": "string",
                    "description": "Type of search to perform",
                    "enum": [
                        "INSIGHTS",
                        "CODE",
                        "GRAPH_COMPLETION",
                        "SEMANTIC",
                        "NATURAL_LANGUAGE",
                    ],
                },
                "top_k": {
                    "type": "integer",
                    "description": "Maximum number of results to return",
                    "default": 10,
                },
                "datasets": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Optional list of dataset names to search within",
                },
            },
            "required": ["search_query"],
        },
    },
    {
        "type": "function",
        "name": "cognify_text",
        "description": "Convert text into a knowledge graph or process all added content",
        "parameters": {
            "type": "object",
            "properties": {
                "text": {
                    "type": "string",
                    "description": "Text content to be converted into a knowledge graph",
                },
                "graph_model_name": {
                    "type": "string",
                    "description": "Name of the graph model to use",
                },
                "graph_model_file": {
                    "type": "string",
                    "description": "Path to a custom graph model file",
                },
            },
        },
    },
    {
        "type": "function",
        "name": "prune",
        "description": "Remove unnecessary or outdated information from the knowledge graph",
        "parameters": {
            "type": "object",
            "properties": {
                "prune_strategy": {
                    "type": "string",
                    "enum": ["light", "moderate", "aggressive"],
                    "description": "Strategy for pruning the knowledge graph",
                    "default": "moderate",
                },
                "min_confidence": {
                    "type": "number",
                    "description": "Minimum confidence score to retain (0-1)",
                    "minimum": 0,
                    "maximum": 1,
                },
                "older_than": {
                    "type": "string",
                    "description": "ISO date string - prune nodes older than this date",
                },
            },
        },
    },
]
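This copy of the tool list also exposes a prune tool with a richer schema than the commented-out stub in the other file; note that handle_prune in dispatch_function.py ignores all arguments and prunes unconditionally, so these fields only shape what a model emits. A hypothetical arguments payload:

import json

# Hypothetical "prune" arguments; every field is optional per the schema,
# and none are currently read by handle_prune.
prune_arguments = {
    "prune_strategy": "light",
    "min_confidence": 0.4,
    "older_than": "2024-01-01",
}
print(json.dumps(prune_arguments))  # the JSON string a model would emit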
cognee/api/v1/responses/routers/get_responses_router.py (new file, 146 lines)
@@ -0,0 +1,146 @@
"""
Get router for the OpenAI-compatible responses API.
"""

import logging
import uuid
from typing import Dict, List, Optional, Any
import openai
from fastapi import APIRouter
from cognee.api.v1.responses.models import (
    ResponseRequest,
    ResponseResponse,
    ResponseToolCall,
    ChatUsage,
    FunctionCall,
    ToolCallOutput,
)
from cognee.api.v1.responses.dispatch_function import dispatch_function
from cognee.api.v1.responses.default_tools import DEFAULT_TOOLS
from cognee.infrastructure.llm.config import get_llm_config


def get_responses_router() -> APIRouter:
    """
    Returns the FastAPI router for OpenAI-compatible responses.

    This implementation follows the new OpenAI Responses API format as described in:
    https://platform.openai.com/docs/api-reference/responses/create
    """

    router = APIRouter()
    logger = logging.getLogger(__name__)

    def _get_model_client():
        """
        Get the appropriate client based on the model name
        """
        llm_config = get_llm_config()
        return openai.OpenAI(api_key=llm_config.llm_api_key)

    async def call_openai_api_for_model(
        input_text: str,
        model: str,
        tools: Optional[List[Dict[str, Any]]] = DEFAULT_TOOLS,
        tool_choice: Any = "auto",
        temperature: float = 1.0,
    ) -> Dict[str, Any]:
        """
        Call the appropriate model API based on the model name
        """

        # TODO: Support other models (e.g. cognee-v1-openai-gpt-3.5-turbo, etc.)
        model = "gpt-4o"

        client = _get_model_client()

        logger.debug(f"Using model: {model}")

        response = client.responses.create(
            model=model,
            input=input_text,
            temperature=temperature,
            tools=tools,
            tool_choice=tool_choice,
        )
        logger.info(f"Response: {response}")
        return response.model_dump()

    @router.post("/", response_model=ResponseResponse)
    async def create_response(
        request: ResponseRequest,
    ) -> ResponseResponse:
        """
        OpenAI-compatible responses endpoint with function calling support
        """
        # Use default tools if none provided; convert request models to plain dicts
        tools = [tool.model_dump() for tool in request.tools] if request.tools else DEFAULT_TOOLS

        # Call the API
        response = await call_openai_api_for_model(
            input_text=request.input,
            model=request.model,
            tools=tools,
            tool_choice=request.tool_choice,
            temperature=request.temperature,
        )

        # Use the response ID from the API or generate a new one
        response_id = response.get("id", f"resp_{uuid.uuid4().hex}")

        # Check if there are function tool calls in the output
        output = response.get("output", [])

        processed_tool_calls = []

        # Process any function tool calls from the output
        for item in output:
            if isinstance(item, dict) and item.get("type") == "function_call":
                # This is a tool call in the new format
                function_name = item.get("name", "")
                arguments_str = item.get("arguments", "{}")
                call_id = item.get("call_id", f"call_{uuid.uuid4().hex}")

                # Create a format the dispatcher can handle
                tool_call = {
                    "id": call_id,
                    "function": {"name": function_name, "arguments": arguments_str},
                    "type": "function",
                }

                # Dispatch the function
                try:
                    function_result = await dispatch_function(tool_call)
                    output_status = "success"
                except Exception as e:
                    logger.exception(f"Error executing function {function_name}: {e}")
                    function_result = f"Error executing {function_name}: {str(e)}"
                    output_status = "error"

                processed_call = ResponseToolCall(
                    id=call_id,
                    type="function",
                    function=FunctionCall(name=function_name, arguments=arguments_str),
                    output=ToolCallOutput(status=output_status, data={"result": function_result}),
                )

                processed_tool_calls.append(processed_call)

        # Get usage data from the response if available
        usage = response.get("usage", {})

        # Create the response object with all processed tool calls
        response_obj = ResponseResponse(
            id=response_id,
            model=request.model,
            tool_calls=processed_tool_calls,
            usage=ChatUsage(
                prompt_tokens=usage.get("input_tokens", 0),
                completion_tokens=usage.get("output_tokens", 0),
                total_tokens=usage.get("total_tokens", 0),
            ),
        )

        return response_obj

    return router
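With the router mounted under /api/v1/responses, the endpoint can be exercised end to end. A hypothetical invocation, assuming a cognee server on localhost:8000 with LLM credentials configured server-side; the response key casing depends on the OutDTO serializer, so both spellings are tried:

import requests

response = requests.post(
    "http://localhost:8000/api/v1/responses/",
    json={"input": "Search the knowledge graph for anything about climate policies"},
    timeout=120,
)
response.raise_for_status()
body = response.json()

tool_calls = body.get("tool_calls") or body.get("toolCalls") or []
print(body.get("status"), [call["function"]["name"] for call in tool_calls])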