ragflow/api/db/services/memory_service.py
2025-12-04 14:44:49 +08:00

136 lines
5.5 KiB
Python

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List
import json
from api.db.db_models import DB, Memory, User
from api.db.services import duplicate_name
from api.db.services.common_service import CommonService
from api.utils.api_utils import get_error_argument_result
from api.constants import MEMORY_NAME_LIMIT
from common.misc_utils import get_uuid
from common.time_utils import get_format_time, current_timestamp
class MemoryService(CommonService):
    """Service class for managing ``Memory`` rows.

    Every public method is wrapped in ``DB.connection_context()``, so each call
    opens (and closes) its own connection context.
    """

    model = Memory

    @classmethod
    @DB.connection_context()
    def get_by_memory_id(cls, memory_id: str) -> Memory:
        """Return the memory whose ``memory_id`` equals *memory_id*, or ``None`` if absent."""
        return cls.model.select().where(cls.model.memory_id == memory_id).first()

    @classmethod
    @DB.connection_context()
    def get_by_filter(cls, filter_dict: dict, keywords: str, page: int = 1, page_size: int = 50):
        """List memories matching *filter_dict* and *keywords*, one page at a time.

        Each row is joined to ``User`` on ``tenant_id == User.id`` (peewee's
        default join type is INNER, so memories without a matching user are
        omitted) and exposes the user's nickname as ``owner_name``.

        Args:
            filter_dict: Optional filters; keys with falsy values are ignored.
                ``"tenant_id"``: collection of tenant ids, matched with ``IN``.
                ``"memory_type"``: collection of type names; a row matches when
                its stored ``memory_type`` text contains ANY of them.
                ``"storage_type"``: exact-match value.
            keywords: Substring searched for in ``memory_name``; ignored when empty.
            page: Page number handed to peewee's ``paginate`` (1-based).
            page_size: Number of rows per page.

        Returns:
            ``(rows, total)``: ``rows`` is a list of dicts for the requested page,
            newest ``update_time`` first; ``total`` is the number of matching rows
            before pagination.
        """
        fields = [
            cls.model.memory_id,
            cls.model.memory_name,
            cls.model.avatar,
            cls.model.tenant_id,
            User.nickname.alias("owner_name"),
            cls.model.memory_type,
            cls.model.storage_type,
            cls.model.embedding,
            cls.model.llm,
            cls.model.permissions,
            cls.model.description,
            cls.model.memory_size,
            cls.model.forgetting_policy,
            cls.model.temperature,
            cls.model.system_prompt,
            cls.model.user_prompt,
        ]
        memories = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id))
        if filter_dict.get("tenant_id"):
            memories = memories.where(cls.model.tenant_id.in_(filter_dict["tenant_id"]))

        memory_types = filter_dict.get("memory_type")
        if memory_types:
            # memory_type is persisted as JSON text (json.dumps in create_memory),
            # so each requested type is a substring match; the clauses are ORed.
            # Any number of types is accepted, so the result shape is always
            # (rows, total) and never an API error response.
            type_clause = None
            for wanted_type in memory_types:
                clause = cls.model.memory_type.contains(wanted_type)
                type_clause = clause if type_clause is None else (type_clause | clause)
            memories = memories.where(type_clause)

        if filter_dict.get("storage_type"):
            memories = memories.where(cls.model.storage_type == filter_dict["storage_type"])
        if keywords:
            memories = memories.where(cls.model.memory_name.contains(keywords))

        # Count before ordering/pagination so `total` reflects all matches.
        count = memories.count()
        memories = memories.order_by(cls.model.update_time.desc())
        memories = memories.paginate(page, page_size)
        return list(memories.dicts()), count

    @classmethod
    @DB.connection_context()
    def create_memory(cls, tenant_id: str, name: str, memory_type: List[str], embedding: str, llm: str):
        """Insert a new memory for *tenant_id* and return the stored row.

        Args:
            tenant_id: Owning tenant.
            name: Requested display name; may be suffixed by ``duplicate_name``.
            memory_type: List of type names, stored JSON-encoded.
            embedding: Embedding model identifier.
            llm: LLM identifier.

        Returns:
            ``(False, error_message)`` if the resulting name exceeds
            ``MEMORY_NAME_LIMIT`` or nothing was saved; otherwise
            ``(saved, row)``, where ``saved`` is the truthy value returned by
            ``save`` and ``row`` is the freshly selected ``Memory``.
        """
        # Deduplicate the name within the tenant.
        # NOTE(review): duplicate_name receives ``name=`` while this model's
        # name column is ``memory_name``; confirm the lookup really filters by name.
        memory_name = duplicate_name(
            cls.query,
            name=name,
            tenant_id=tenant_id,
        )
        if len(memory_name) > MEMORY_NAME_LIMIT:
            return False, f"Memory name {memory_name} exceeds limit of {MEMORY_NAME_LIMIT}."

        # Take the timestamps once so the create_* and update_* columns agree exactly.
        now_ts = current_timestamp()
        now_date = get_format_time()
        memory_info = {
            "memory_id": get_uuid(),
            "memory_name": memory_name,
            "memory_type": json.dumps(memory_type),
            "tenant_id": tenant_id,
            "embedding": embedding,
            "llm": llm,
            "create_time": now_ts,
            "create_date": now_date,
            "update_time": now_ts,
            "update_date": now_date,
        }
        # memory_id is generated here rather than by the database. If it is the
        # primary key (model not visible here), a plain save() would issue an
        # UPDATE because the key is already populated; force_insert=True always
        # performs the INSERT.
        saved = cls.model(**memory_info).save(force_insert=True)
        if not saved:
            return False, "Could not create new memory."
        db_row = cls.model.select().where(cls.model.memory_id == memory_info["memory_id"]).first()
        return saved, db_row

    @classmethod
    @DB.connection_context()
    def update_memory(cls, memory_id: str, update_dict: dict):
        """Apply *update_dict* to the memory identified by *memory_id*.

        Before writing:
        - a non-empty list ``memory_type`` is JSON-encoded;
        - a string ``temperature`` is decoded with ``json.loads``;
        - ``update_time`` and ``update_date`` are refreshed.

        The caller's dict is left unmodified.

        Returns:
            Number of rows updated; ``0`` when *update_dict* is empty.
        """
        if not update_dict:
            return 0
        # Work on a shallow copy so normalisation does not leak back to the caller.
        changes = dict(update_dict)
        if changes.get("memory_type") and isinstance(changes["memory_type"], list):
            changes["memory_type"] = json.dumps(changes["memory_type"])
        if "temperature" in changes and isinstance(changes["temperature"], str):
            changes["temperature"] = json.loads(changes["temperature"])
        changes.update({
            "update_time": current_timestamp(),
            "update_date": get_format_time(),
        })
        return cls.model.update(changes).where(cls.model.memory_id == memory_id).execute()

    @classmethod
    @DB.connection_context()
    def delete_memory(cls, memory_id: str):
        """Delete the memory identified by *memory_id*; return the number of rows deleted."""
        return cls.model.delete().where(cls.model.memory_id == memory_id).execute()