| | package controller |
| |
|
| | import ( |
| | "context" |
| | "encoding/json" |
| | "fmt" |
| | "io" |
| | "time" |
| |
|
| | "github.com/QuantumNous/new-api/common" |
| | "github.com/QuantumNous/new-api/constant" |
| | "github.com/QuantumNous/new-api/dto" |
| | "github.com/QuantumNous/new-api/logger" |
| | "github.com/QuantumNous/new-api/model" |
| | "github.com/QuantumNous/new-api/relay" |
| | "github.com/QuantumNous/new-api/relay/channel" |
| | relaycommon "github.com/QuantumNous/new-api/relay/common" |
| | "github.com/QuantumNous/new-api/setting/ratio_setting" |
| | ) |
| |
|
| | func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error { |
| | for channelId, taskIds := range taskChannelM { |
| | if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil { |
| | logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error())) |
| | } |
| | } |
| | return nil |
| | } |
| |
|
| | func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error { |
| | logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds))) |
| | if len(taskIds) == 0 { |
| | return nil |
| | } |
| | cacheGetChannel, err := model.CacheGetChannel(channelId) |
| | if err != nil { |
| | errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{ |
| | "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId), |
| | "status": "FAILURE", |
| | "progress": "100%", |
| | }) |
| | if errUpdate != nil { |
| | common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate)) |
| | } |
| | return fmt.Errorf("CacheGetChannel failed: %w", err) |
| | } |
| | adaptor := relay.GetTaskAdaptor(platform) |
| | if adaptor == nil { |
| | return fmt.Errorf("video adaptor not found") |
| | } |
| | info := &relaycommon.RelayInfo{} |
| | info.ChannelMeta = &relaycommon.ChannelMeta{ |
| | ChannelBaseUrl: cacheGetChannel.GetBaseURL(), |
| | } |
| | info.ApiKey = cacheGetChannel.Key |
| | adaptor.Init(info) |
| | for _, taskId := range taskIds { |
| | if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil { |
| | logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error())) |
| | } |
| | } |
| | return nil |
| | } |
| |
|
| | func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error { |
| | baseURL := constant.ChannelBaseURLs[channel.Type] |
| | if channel.GetBaseURL() != "" { |
| | baseURL = channel.GetBaseURL() |
| | } |
| |
|
| | task := taskM[taskId] |
| | if task == nil { |
| | logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId)) |
| | return fmt.Errorf("task %s not found", taskId) |
| | } |
| | resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{ |
| | "task_id": taskId, |
| | "action": task.Action, |
| | }) |
| | if err != nil { |
| | return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err) |
| | } |
| | |
| | |
| | |
| | defer resp.Body.Close() |
| | responseBody, err := io.ReadAll(resp.Body) |
| | if err != nil { |
| | return fmt.Errorf("readAll failed for task %s: %w", taskId, err) |
| | } |
| |
|
| | logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask response: %s", string(responseBody))) |
| |
|
| | taskResult := &relaycommon.TaskInfo{} |
| | |
| | var responseItems dto.TaskResponse[model.Task] |
| | if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() { |
| | logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask parsed as new api response format: %+v", responseItems)) |
| | t := responseItems.Data |
| | taskResult.TaskID = t.TaskID |
| | taskResult.Status = string(t.Status) |
| | taskResult.Url = t.FailReason |
| | taskResult.Progress = t.Progress |
| | taskResult.Reason = t.FailReason |
| | task.Data = t.Data |
| | } else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil { |
| | return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err) |
| | } else { |
| | task.Data = redactVideoResponseBody(responseBody) |
| | } |
| |
|
| | logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask taskResult: %+v", taskResult)) |
| |
|
| | now := time.Now().Unix() |
| | if taskResult.Status == "" { |
| | |
| | taskResult = relaycommon.FailTaskInfo("upstream returned empty status") |
| | } |
| |
|
| | |
| | shouldRefund := false |
| | quota := task.Quota |
| | preStatus := task.Status |
| |
|
| | task.Status = model.TaskStatus(taskResult.Status) |
| | switch taskResult.Status { |
| | case model.TaskStatusSubmitted: |
| | task.Progress = "10%" |
| | case model.TaskStatusQueued: |
| | task.Progress = "20%" |
| | case model.TaskStatusInProgress: |
| | task.Progress = "30%" |
| | if task.StartTime == 0 { |
| | task.StartTime = now |
| | } |
| | case model.TaskStatusSuccess: |
| | task.Progress = "100%" |
| | if task.FinishTime == 0 { |
| | task.FinishTime = now |
| | } |
| | if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") { |
| | task.FailReason = taskResult.Url |
| | } |
| |
|
| | |
| | if taskResult.TotalTokens > 0 { |
| | |
| | var taskData map[string]interface{} |
| | if err := json.Unmarshal(task.Data, &taskData); err == nil { |
| | if modelName, ok := taskData["model"].(string); ok && modelName != "" { |
| | |
| | modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName) |
| | |
| | if hasRatioSetting && modelRatio > 0 { |
| | |
| | group := task.Group |
| | if group == "" { |
| | user, err := model.GetUserById(task.UserId, false) |
| | if err == nil { |
| | group = user.Group |
| | } |
| | } |
| | if group != "" { |
| | groupRatio := ratio_setting.GetGroupRatio(group) |
| | userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group) |
| |
|
| | var finalGroupRatio float64 |
| | if hasUserGroupRatio { |
| | finalGroupRatio = userGroupRatio |
| | } else { |
| | finalGroupRatio = groupRatio |
| | } |
| |
|
| | |
| | actualQuota := int(float64(taskResult.TotalTokens) * modelRatio * finalGroupRatio) |
| |
|
| | |
| | preConsumedQuota := task.Quota |
| | quotaDelta := actualQuota - preConsumedQuota |
| |
|
| | if quotaDelta > 0 { |
| | |
| | logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后补扣费:%s(实际消耗:%s,预扣费:%s,tokens:%d)", |
| | task.TaskID, |
| | logger.LogQuota(quotaDelta), |
| | logger.LogQuota(actualQuota), |
| | logger.LogQuota(preConsumedQuota), |
| | taskResult.TotalTokens, |
| | )) |
| | if err := model.DecreaseUserQuota(task.UserId, quotaDelta); err != nil { |
| | logger.LogError(ctx, fmt.Sprintf("补扣费失败: %s", err.Error())) |
| | } else { |
| | model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta) |
| | model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta) |
| | task.Quota = actualQuota |
| |
|
| | |
| | logContent := fmt.Sprintf("视频任务成功补扣费,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,补扣费 %s", |
| | modelRatio, finalGroupRatio, taskResult.TotalTokens, |
| | logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(quotaDelta)) |
| | model.RecordLog(task.UserId, model.LogTypeSystem, logContent) |
| | } |
| | } else if quotaDelta < 0 { |
| | |
| | refundQuota := -quotaDelta |
| | logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后返还:%s(实际消耗:%s,预扣费:%s,tokens:%d)", |
| | task.TaskID, |
| | logger.LogQuota(refundQuota), |
| | logger.LogQuota(actualQuota), |
| | logger.LogQuota(preConsumedQuota), |
| | taskResult.TotalTokens, |
| | )) |
| | if err := model.IncreaseUserQuota(task.UserId, refundQuota, false); err != nil { |
| | logger.LogError(ctx, fmt.Sprintf("退还预扣费失败: %s", err.Error())) |
| | } else { |
| | task.Quota = actualQuota |
| |
|
| | |
| | logContent := fmt.Sprintf("视频任务成功退还多扣费用,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,退还 %s", |
| | modelRatio, finalGroupRatio, taskResult.TotalTokens, |
| | logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(refundQuota)) |
| | model.RecordLog(task.UserId, model.LogTypeSystem, logContent) |
| | } |
| | } else { |
| | |
| | logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)", |
| | task.TaskID, logger.LogQuota(actualQuota), taskResult.TotalTokens)) |
| | } |
| | } |
| | } |
| | } |
| | } |
| | } |
| | case model.TaskStatusFailure: |
| | logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task) |
| | task.Status = model.TaskStatusFailure |
| | task.Progress = "100%" |
| | if task.FinishTime == 0 { |
| | task.FinishTime = now |
| | } |
| | task.FailReason = taskResult.Reason |
| | logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason)) |
| | taskResult.Progress = "100%" |
| | if quota != 0 { |
| | if preStatus != model.TaskStatusFailure { |
| | shouldRefund = true |
| | } else { |
| | logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID)) |
| | } |
| | } |
| | default: |
| | return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId) |
| | } |
| | if taskResult.Progress != "" { |
| | task.Progress = taskResult.Progress |
| | } |
| | if err := task.Update(); err != nil { |
| | common.SysLog("UpdateVideoTask task error: " + err.Error()) |
| | shouldRefund = false |
| | } |
| |
|
| | if shouldRefund { |
| | |
| | if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil { |
| | logger.LogWarn(ctx, "Failed to increase user quota: "+err.Error()) |
| | } |
| | logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota)) |
| | model.RecordLog(task.UserId, model.LogTypeSystem, logContent) |
| | } |
| |
|
| | return nil |
| | } |
| |
|
| | func redactVideoResponseBody(body []byte) []byte { |
| | var m map[string]any |
| | if err := json.Unmarshal(body, &m); err != nil { |
| | return body |
| | } |
| | resp, _ := m["response"].(map[string]any) |
| | if resp != nil { |
| | delete(resp, "bytesBase64Encoded") |
| | if v, ok := resp["video"].(string); ok { |
| | resp["video"] = truncateBase64(v) |
| | } |
| | if vs, ok := resp["videos"].([]any); ok { |
| | for i := range vs { |
| | if vm, ok := vs[i].(map[string]any); ok { |
| | delete(vm, "bytesBase64Encoded") |
| | } |
| | } |
| | } |
| | } |
| | b, err := json.Marshal(m) |
| | if err != nil { |
| | return body |
| | } |
| | return b |
| | } |
| |
|
// truncateBase64 keeps only the first 256 bytes of s and appends "..." so a
// large inline base64 payload does not bloat stored data. Strings of 256
// bytes or fewer are returned as-is.
func truncateBase64(s string) string {
	const maxKeep = 256
	if len(s) > maxKeep {
		return s[:maxKeep] + "..."
	}
	return s
}
| |
|