|
@@ -1,8 +1,14 @@
|
|
|
package org.jeecg.modules.adweb.gpt.controller;
|
|
|
|
|
|
+import lombok.extern.slf4j.Slf4j;
|
|
|
+
|
|
|
+import org.apache.commons.lang3.tuple.Pair;
|
|
|
+import org.apache.shiro.SecurityUtils;
|
|
|
import org.jeecg.common.api.vo.Result;
|
|
|
+import org.jeecg.common.system.vo.LoginUser;
|
|
|
import org.jeecg.modules.adweb.gpt.service.ChatService;
|
|
|
import org.jeecg.modules.adweb.gpt.vo.ChatHistoryVO;
|
|
|
+import org.jeecg.modules.adweb.quota.service.IResourceQuotaService;
|
|
|
import org.springframework.beans.factory.annotation.Autowired;
|
|
|
import org.springframework.stereotype.Controller;
|
|
|
import org.springframework.web.bind.annotation.*;
|
|
@@ -15,10 +21,13 @@ import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
|
|
*/
|
|
|
@Controller
|
|
|
@RequestMapping("/ai/chat")
|
|
|
+@Slf4j
|
|
|
public class ChatController {
|
|
|
|
|
|
@Autowired private ChatService chatService;
|
|
|
|
|
|
+ @Autowired private IResourceQuotaService resourceQuotaService;
|
|
|
+
|
|
|
/** 创建sse连接 */
|
|
|
@GetMapping(value = "/send")
|
|
|
public SseEmitter createConnect(
|
|
@@ -40,7 +49,8 @@ public class ChatController {
|
|
|
@PostMapping(value = "/history/save")
|
|
|
@ResponseBody
|
|
|
public Result<?> saveHistory(@RequestBody ChatHistoryVO chatHistoryVO) {
|
|
|
- return chatService.saveHistory(chatHistoryVO);
|
|
|
+ chatService.saveHistory(chatHistoryVO);
|
|
|
+ return Result.OK("保存成功");
|
|
|
}
|
|
|
|
|
|
/**
|
|
@@ -53,7 +63,7 @@ public class ChatController {
|
|
|
@GetMapping(value = "/history/get")
|
|
|
@ResponseBody
|
|
|
public Result<ChatHistoryVO> getHistoryByTopic() {
|
|
|
- return chatService.getHistoryByTopic();
|
|
|
+ return Result.OK(chatService.getHistoryByTopic());
|
|
|
}
|
|
|
|
|
|
/** 关闭连接 */
|
|
@@ -61,4 +71,22 @@ public class ChatController {
|
|
|
public void closeConnect() {
|
|
|
chatService.closeChat();
|
|
|
}
|
|
|
+
|
|
|
+ /** 检查AI算力额度 */
|
|
|
+ @GetMapping(value = "checkQuota")
|
|
|
+ public Result<String> checkResourceQuota() {
|
|
|
+ String uid = ((LoginUser) SecurityUtils.getSubject().getPrincipal()).getId();
|
|
|
+
|
|
|
+ Pair<Integer, Integer> aiPowerQuota = resourceQuotaService.getAIPowerQuota(uid);
|
|
|
+ if (aiPowerQuota.getLeft() <= aiPowerQuota.getRight()) {
|
|
|
+ log.warn("用户 {} AI算力额度不足", uid);
|
|
|
+ return Result.error(
|
|
|
+ String.format(
|
|
|
+ "AI算力不足。用户资源额度为%d,已使用%d",
|
|
|
+ aiPowerQuota.getLeft(), aiPowerQuota.getRight()),
|
|
|
+ null);
|
|
|
+ }
|
|
|
+
|
|
|
+ return Result.OK("OK");
|
|
|
+ }
|
|
|
}
|