wfansh 2 months ago
parent
commit
aefa0f20a6

+ 30 - 2
jeecg-module-system/jeecg-system-biz/src/main/java/org/jeecg/modules/adweb/gpt/controller/ChatController.java

@@ -1,8 +1,14 @@
 package org.jeecg.modules.adweb.gpt.controller;
 
+import lombok.extern.slf4j.Slf4j;
+
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.shiro.SecurityUtils;
 import org.jeecg.common.api.vo.Result;
+import org.jeecg.common.system.vo.LoginUser;
 import org.jeecg.modules.adweb.gpt.service.ChatService;
 import org.jeecg.modules.adweb.gpt.vo.ChatHistoryVO;
+import org.jeecg.modules.adweb.quota.service.IResourceQuotaService;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.stereotype.Controller;
 import org.springframework.web.bind.annotation.*;
@@ -15,10 +21,13 @@ import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
  */
 @Controller
 @RequestMapping("/ai/chat")
+@Slf4j
 public class ChatController {
 
     @Autowired private ChatService chatService;
 
+    @Autowired private IResourceQuotaService resourceQuotaService;
+
     /** 创建sse连接 */
     @GetMapping(value = "/send")
     public SseEmitter createConnect(
@@ -40,7 +49,8 @@ public class ChatController {
     @PostMapping(value = "/history/save")
     @ResponseBody
     public Result<?> saveHistory(@RequestBody ChatHistoryVO chatHistoryVO) {
-        return chatService.saveHistory(chatHistoryVO);
+        chatService.saveHistory(chatHistoryVO);
+        return Result.OK("保存成功");
     }
 
     /**
@@ -53,7 +63,7 @@ public class ChatController {
     @GetMapping(value = "/history/get")
     @ResponseBody
     public Result<ChatHistoryVO> getHistoryByTopic() {
-        return chatService.getHistoryByTopic();
+        return Result.OK(chatService.getHistoryByTopic());
     }
 
     /** 关闭连接 */
@@ -61,4 +71,22 @@ public class ChatController {
     public void closeConnect() {
         chatService.closeChat();
     }
+
+    /** 检查AI算力额度 */
+    @GetMapping(value = "checkQuota")
+    public Result<String> checkResourceQuota() {
+        String uid = ((LoginUser) SecurityUtils.getSubject().getPrincipal()).getId();
+
+        Pair<Integer, Integer> aiPowerQuota = resourceQuotaService.getAIPowerQuota(uid);
+        if (aiPowerQuota.getLeft() <= aiPowerQuota.getRight()) {
+            log.warn("用户 {} AI算力额度不足", uid);
+            return Result.error(
+                    String.format(
+                            "AI算力不足。用户资源额度为%d,已使用%d",
+                            aiPowerQuota.getLeft(), aiPowerQuota.getRight()),
+                    null);
+        }
+
+        return Result.OK("OK");
+    }
 }

+ 2 - 3
jeecg-module-system/jeecg-system-biz/src/main/java/org/jeecg/modules/adweb/gpt/service/ChatService.java

@@ -1,6 +1,5 @@
 package org.jeecg.modules.adweb.gpt.service;
 
-import org.jeecg.common.api.vo.Result;
 import org.jeecg.modules.adweb.gpt.vo.ChatHistoryVO;
 import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
 
@@ -36,7 +35,7 @@ public interface ChatService {
      * @author chenrui
      * @date 2024/2/22 13:37
      */
-    Result<?> saveHistory(ChatHistoryVO chatHistoryVO);
+    void saveHistory(ChatHistoryVO chatHistoryVO);
 
     /**
      * 查询聊天记录
@@ -45,5 +44,5 @@ public interface ChatService {
      * @author chenrui
      * @date 2024/2/22 13:59
      */
-    Result<ChatHistoryVO> getHistoryByTopic();
+    ChatHistoryVO getHistoryByTopic();
 }

+ 3 - 5
jeecg-module-system/jeecg-system-biz/src/main/java/org/jeecg/modules/adweb/gpt/service/impl/ChatServiceImpl.java

@@ -165,7 +165,7 @@ public class ChatServiceImpl implements ChatService {
     }
 
     @Override
-    public Result<?> saveHistory(ChatHistoryVO chatHistoryVO) {
+    public void saveHistory(ChatHistoryVO chatHistoryVO) {
         // String uid = getUserId();
         // String cacheKey = CACHE_KEY_PREFIX + CACHE_KEY_MSG_HISTORY + ":" + uid;
         // redisTemplate.opsForValue().set(cacheKey, chatHistoryVO.getContent());
@@ -175,12 +175,10 @@ public class ChatServiceImpl implements ChatService {
         ChatHistory chatHistory = chatHistoryService.getChatHistoryOfWeek(this.getUserId());
         chatHistory.setContent(chatHistoryVO.getContent());
         chatHistoryService.updateById(chatHistory);
-
-        return Result.OK("保存成功");
     }
 
     @Override
-    public Result<ChatHistoryVO> getHistoryByTopic() {
+    public ChatHistoryVO getHistoryByTopic() {
         // String uid = getUserId();
         // String cacheKey = CACHE_KEY_PREFIX + CACHE_KEY_MSG_HISTORY + ":" + uid;
         // String historyContent = (String) redisTemplate.opsForValue().get(cacheKey);
@@ -194,7 +192,7 @@ public class ChatServiceImpl implements ChatService {
         ChatHistoryVO chatHistoryVO = new ChatHistoryVO();
         chatHistoryVO.setContent(Objects.nonNull(chatHistory) ? chatHistory.getContent() : null);
 
-        return Result.OK(chatHistoryVO);
+        return chatHistoryVO;
     }
 
     /** 获取当前登陆用户ID */