#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import json
import operator
from functools import reduce
from typing import List, Optional

from api.db.db_models import DB, Memory, User
from api.db.services import duplicate_name
from api.db.services.common_service import CommonService
from api.utils.api_utils import get_error_argument_result
from api.constants import MEMORY_NAME_LIMIT
from common.misc_utils import get_uuid
from common.time_utils import get_format_time, current_timestamp


class MemoryService(CommonService):
    """Service class for managing memory operations on the ``Memory`` model."""

    model = Memory

    @classmethod
    @DB.connection_context()
    def get_by_memory_id(cls, memory_id: str) -> Optional[Memory]:
        """Return the memory whose ``memory_id`` matches, or ``None`` if there is none."""
        return cls.model.select().where(cls.model.memory_id == memory_id).first()

    @classmethod
    @DB.connection_context()
    def get_by_filter(cls, filter_dict: dict, keywords: str, page: int = 1, page_size: int = 50):
        """List memories matching the given filters, most recently updated first.

        Args:
            filter_dict: Optional filters; missing or falsy entries are ignored.
                - ``"tenant_id"``: collection of tenant ids, matched with SQL ``IN``.
                - ``"memory_type"``: list of memory types (a single string is
                  also accepted); a row matches if its stored type value
                  contains ANY of the requested types.
                - ``"storage_type"``: storage type, matched exactly.
            keywords: Substring searched for in ``memory_name``; ignored if falsy.
            page: Page number passed to ``paginate``.
            page_size: Number of rows per page.

        Returns:
            A ``(rows, total)`` tuple: ``rows`` is a list of dicts for the
            requested page (the owner's ``User.nickname`` is exposed as
            ``owner_name``), and ``total`` is the matching-row count before
            pagination.
        """
        fields = [
            cls.model.memory_id,
            cls.model.memory_name,
            cls.model.avatar,
            cls.model.tenant_id,
            User.nickname.alias("owner_name"),
            cls.model.memory_type,
            cls.model.storage_type,
            cls.model.embedding,
            cls.model.llm,
            cls.model.permissions,
            cls.model.description,
            cls.model.memory_size,
            cls.model.forgetting_policy,
            cls.model.temperature,
            cls.model.system_prompt,
            cls.model.user_prompt,
        ]
        # The owner of a memory is the user whose id equals the memory's tenant_id.
        memories = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id))
        if filter_dict.get("tenant_id"):
            memories = memories.where(cls.model.tenant_id.in_(filter_dict["tenant_id"]))

        memory_types = filter_dict.get("memory_type")
        if memory_types:
            if isinstance(memory_types, str):
                # Otherwise a bare string would be iterated character by character.
                memory_types = [memory_types]
            # memory_type is stored as a JSON-encoded list (see create_memory), so
            # this is a substring match OR-ed across every requested type; it
            # works for any number of types rather than a fixed handful.
            # NOTE(review): a type name that is a substring of another type name
            # would also match rows of that other type.
            type_condition = reduce(
                operator.or_,
                (cls.model.memory_type.contains(memory_type) for memory_type in memory_types),
            )
            memories = memories.where(type_condition)

        if filter_dict.get("storage_type"):
            memories = memories.where(cls.model.storage_type == filter_dict["storage_type"])
        if keywords:
            memories = memories.where(cls.model.memory_name.contains(keywords))

        # Count before ordering/pagination so ``total`` reflects all matches.
        count = memories.count()
        memories = memories.order_by(cls.model.update_time.desc())
        memories = memories.paginate(page, page_size)

        return list(memories.dicts()), count

    @classmethod
    @DB.connection_context()
    def create_memory(cls, tenant_id: str, name: str, memory_type: List[str], embedding: str, llm: str):
        """Create a new memory for ``tenant_id``.

        The requested ``name`` is first de-duplicated within the tenant via
        ``duplicate_name``; ``memory_type`` is stored JSON-encoded.

        Returns:
            ``(False, error_message)`` if the (de-duplicated) name exceeds
            ``MEMORY_NAME_LIMIT`` or the row could not be saved; otherwise
            ``(rows_saved, db_row)`` where ``rows_saved`` is the truthy value
            returned by ``save()`` and ``db_row`` is the freshly read row.
        """
        # Deduplicate name within tenant
        memory_name = duplicate_name(
            cls.query,
            name=name,
            tenant_id=tenant_id,
        )
        if len(memory_name) > MEMORY_NAME_LIMIT:
            return False, f"Memory name {memory_name} exceeds limit of {MEMORY_NAME_LIMIT}."

        # Take each timestamp once so create_* and update_* agree on a new row.
        now_ts = current_timestamp()
        now_date = get_format_time()
        memory_info = {
            "memory_id": get_uuid(),
            "memory_name": memory_name,
            "memory_type": json.dumps(memory_type),
            "tenant_id": tenant_id,
            "embedding": embedding,
            "llm": llm,
            "create_time": now_ts,
            "create_date": now_date,
            "update_time": now_ts,
            "update_date": now_date,
        }
        # force_insert guarantees an INSERT: peewee's save() issues an UPDATE
        # instead when the model's primary key is already populated (e.g. if
        # Memory's primary key is a non-auto field set here), which would
        # create nothing.
        obj = cls.model(**memory_info).save(force_insert=True)

        if not obj:
            return False, "Could not create new memory."

        db_row = cls.model.select().where(cls.model.memory_id == memory_info["memory_id"]).first()

        return obj, db_row

    @classmethod
    @DB.connection_context()
    def update_memory(cls, memory_id: str, update_dict: dict):
        """Apply ``update_dict`` to the memory identified by ``memory_id``.

        A list ``memory_type`` is JSON-encoded and a string ``temperature`` is
        parsed with ``json.loads`` before writing; ``update_time`` and
        ``update_date`` are always refreshed. The caller's dict is not modified.

        Returns:
            The number of rows updated, or ``0`` if ``update_dict`` is empty.
        """
        if not update_dict:
            return 0
        # Work on a copy so the caller's dict is not mutated by the conversions below.
        update_dict = dict(update_dict)
        if update_dict.get("memory_type") and isinstance(update_dict["memory_type"], list):
            update_dict["memory_type"] = json.dumps(update_dict["memory_type"])
        if "temperature" in update_dict and isinstance(update_dict["temperature"], str):
            update_dict["temperature"] = json.loads(update_dict["temperature"])
        update_dict.update({
            "update_time": current_timestamp(),
            "update_date": get_format_time(),
        })

        return cls.model.update(update_dict).where(cls.model.memory_id == memory_id).execute()

    @classmethod
    @DB.connection_context()
    def delete_memory(cls, memory_id: str):
        """Delete the memory identified by ``memory_id``; return the number of rows deleted."""
        return cls.model.delete().where(cls.model.memory_id == memory_id).execute()