| | package types |
| |
|
| | import ( |
| | "errors" |
| | "fmt" |
| | "net/http" |
| | "strings" |
| |
|
| | "github.com/QuantumNous/new-api/common" |
| | ) |
| |
|
// OpenAIError is an error payload in the OpenAI API shape
// (message / type / param / code).
//
// Code is typed as any: WithOpenAIError accepts a string, nil, or any other
// value (stringified with %v), and the constructors in this file store
// ErrorCode values in it.
type OpenAIError struct {
	Message string `json:"message"`
	Type    string `json:"type"`
	Param   string `json:"param"`
	Code    any    `json:"code"`
}
| |
|
// ClaudeError is an error payload in the Claude API shape (type / message).
// Both fields are omitted from JSON output when empty.
type ClaudeError struct {
	Type    string `json:"type,omitempty"`
	Message string `json:"message,omitempty"`
}
| |
|
// ErrorType identifies which family a NewAPIError belongs to. ToOpenAIError
// and ToClaudeError switch on it to decide how RelayError is interpreted.
type ErrorType string

// Known error types. Only NewAPIError, OpenAIError and ClaudeError are
// referenced by the conversion logic in this file; the others are declared
// here for use elsewhere (presumably per upstream provider — verify at call sites).
const (
	ErrorTypeNewAPIError     ErrorType = "new_api_error"
	ErrorTypeOpenAIError     ErrorType = "openai_error"
	ErrorTypeClaudeError     ErrorType = "claude_error"
	ErrorTypeMidjourneyError ErrorType = "midjourney_error"
	ErrorTypeGeminiError     ErrorType = "gemini_error"
	ErrorTypeRerankError     ErrorType = "rerank_error"
	ErrorTypeUpstreamError   ErrorType = "upstream_error"
)
| |
|
// ErrorCode is a machine-readable identifier for why a request failed.
// Codes with the "channel:" prefix are classified as channel errors by
// IsChannelError.
type ErrorCode string

const (
	// Request validation.
	ErrorCodeInvalidRequest         ErrorCode = "invalid_request"
	ErrorCodeSensitiveWordsDetected ErrorCode = "sensitive_words_detected"

	// Internal processing failures while preparing or sending a relay request.
	ErrorCodeCountTokenFailed   ErrorCode = "count_token_failed"
	ErrorCodeModelPriceError    ErrorCode = "model_price_error"
	ErrorCodeInvalidApiType     ErrorCode = "invalid_api_type"
	ErrorCodeJsonMarshalFailed  ErrorCode = "json_marshal_failed"
	ErrorCodeDoRequestFailed    ErrorCode = "do_request_failed"
	ErrorCodeGetChannelFailed   ErrorCode = "get_channel_failed"
	ErrorCodeGenRelayInfoFailed ErrorCode = "gen_relay_info_failed"

	// Channel errors: the "channel:" prefix is what IsChannelError matches.
	ErrorCodeChannelNoAvailableKey        ErrorCode = "channel:no_available_key"
	ErrorCodeChannelParamOverrideInvalid  ErrorCode = "channel:param_override_invalid"
	ErrorCodeChannelHeaderOverrideInvalid ErrorCode = "channel:header_override_invalid"
	ErrorCodeChannelModelMappedError      ErrorCode = "channel:model_mapped_error"
	ErrorCodeChannelAwsClientError        ErrorCode = "channel:aws_client_error"
	ErrorCodeChannelInvalidKey            ErrorCode = "channel:invalid_key"
	ErrorCodeChannelResponseTimeExceeded  ErrorCode = "channel:response_time_exceeded"

	// Client request handling.
	ErrorCodeReadRequestBodyFailed ErrorCode = "read_request_body_failed"
	ErrorCodeConvertRequestFailed  ErrorCode = "convert_request_failed"
	ErrorCodeAccessDenied          ErrorCode = "access_denied"

	// Malformed client request body.
	ErrorCodeBadRequestBody ErrorCode = "bad_request_body"

	// Upstream response handling.
	ErrorCodeReadResponseBodyFailed ErrorCode = "read_response_body_failed"
	ErrorCodeBadResponseStatusCode  ErrorCode = "bad_response_status_code"
	ErrorCodeBadResponse            ErrorCode = "bad_response"
	ErrorCodeBadResponseBody        ErrorCode = "bad_response_body"
	ErrorCodeEmptyResponse          ErrorCode = "empty_response"
	ErrorCodeAwsInvokeError         ErrorCode = "aws_invoke_error"
	ErrorCodeModelNotFound          ErrorCode = "model_not_found"
	ErrorCodePromptBlocked          ErrorCode = "prompt_blocked"

	// Data access (presumably the database layer — confirm at call sites).
	ErrorCodeQueryDataError  ErrorCode = "query_data_error"
	ErrorCodeUpdateDataError ErrorCode = "update_data_error"

	// Quota accounting.
	ErrorCodeInsufficientUserQuota      ErrorCode = "insufficient_user_quota"
	ErrorCodePreConsumeTokenQuotaFailed ErrorCode = "pre_consume_token_quota_failed"
)
| |
|
// NewAPIError is this package's error value. It carries:
//   - the underlying Go error,
//   - an optional provider-shaped payload,
//   - an HTTP status code,
//   - flags controlling retry and error-log recording.
//
// The constructors in this file are NewError, NewOpenAIError,
// InitOpenAIError, NewErrorWithStatusCode, WithOpenAIError and WithClaudeError.
type NewAPIError struct {
	// Err is the underlying error; Error() returns its text, or the error
	// code when Err is nil.
	Err error
	// RelayError holds a provider-shaped payload (OpenAIError or ClaudeError
	// in this file). ToOpenAIError/ToClaudeError may return it directly
	// depending on errorType.
	RelayError any
	// skipRetry is set by ErrOptionWithSkipRetry and read by IsSkipRetryError.
	skipRetry bool
	// recordErrorLog is tri-state: nil means "record" (IsRecordErrorLog
	// defaults to true); otherwise the pointed-to value decides.
	recordErrorLog *bool
	// errorType selects the conversion branch in ToOpenAIError/ToClaudeError.
	errorType ErrorType
	// errorCode is the machine-readable failure code.
	errorCode ErrorCode
	// StatusCode is the HTTP status associated with the error
	// (NewError defaults it to http.StatusInternalServerError).
	StatusCode int
}
| |
|
| | func (e *NewAPIError) GetErrorCode() ErrorCode { |
| | if e == nil { |
| | return "" |
| | } |
| | return e.errorCode |
| | } |
| |
|
| | func (e *NewAPIError) GetErrorType() ErrorType { |
| | if e == nil { |
| | return "" |
| | } |
| | return e.errorType |
| | } |
| |
|
| | func (e *NewAPIError) Error() string { |
| | if e == nil { |
| | return "" |
| | } |
| | if e.Err == nil { |
| | |
| | return string(e.errorCode) |
| | } |
| | return e.Err.Error() |
| | } |
| |
|
| | func (e *NewAPIError) MaskSensitiveError() string { |
| | if e == nil { |
| | return "" |
| | } |
| | if e.Err == nil { |
| | return string(e.errorCode) |
| | } |
| | errStr := e.Err.Error() |
| | if e.errorCode == ErrorCodeCountTokenFailed { |
| | return errStr |
| | } |
| | return common.MaskSensitiveInfo(errStr) |
| | } |
| |
|
| | func (e *NewAPIError) SetMessage(message string) { |
| | e.Err = errors.New(message) |
| | } |
| |
|
| | func (e *NewAPIError) ToOpenAIError() OpenAIError { |
| | var result OpenAIError |
| | switch e.errorType { |
| | case ErrorTypeOpenAIError: |
| | if openAIError, ok := e.RelayError.(OpenAIError); ok { |
| | result = openAIError |
| | } |
| | case ErrorTypeClaudeError: |
| | if claudeError, ok := e.RelayError.(ClaudeError); ok { |
| | result = OpenAIError{ |
| | Message: e.Error(), |
| | Type: claudeError.Type, |
| | Param: "", |
| | Code: e.errorCode, |
| | } |
| | } |
| | default: |
| | result = OpenAIError{ |
| | Message: e.Error(), |
| | Type: string(e.errorType), |
| | Param: "", |
| | Code: e.errorCode, |
| | } |
| | } |
| | if e.errorCode != ErrorCodeCountTokenFailed { |
| | result.Message = common.MaskSensitiveInfo(result.Message) |
| | } |
| | if result.Message == "" { |
| | result.Message = string(e.errorType) |
| | } |
| | return result |
| | } |
| |
|
| | func (e *NewAPIError) ToClaudeError() ClaudeError { |
| | var result ClaudeError |
| | switch e.errorType { |
| | case ErrorTypeOpenAIError: |
| | if openAIError, ok := e.RelayError.(OpenAIError); ok { |
| | result = ClaudeError{ |
| | Message: e.Error(), |
| | Type: fmt.Sprintf("%v", openAIError.Code), |
| | } |
| | } |
| | case ErrorTypeClaudeError: |
| | if claudeError, ok := e.RelayError.(ClaudeError); ok { |
| | result = claudeError |
| | } |
| | default: |
| | result = ClaudeError{ |
| | Message: e.Error(), |
| | Type: string(e.errorType), |
| | } |
| | } |
| | if e.errorCode != ErrorCodeCountTokenFailed { |
| | result.Message = common.MaskSensitiveInfo(result.Message) |
| | } |
| | if result.Message == "" { |
| | result.Message = string(e.errorType) |
| | } |
| | return result |
| | } |
| |
|
// NewAPIErrorOptions is a functional option that mutates a NewAPIError. The
// constructors in this file apply options after building the error. When
// NewError or NewOpenAIError receive an err that already wraps a
// *NewAPIError, the options are applied to that existing value instead.
type NewAPIErrorOptions func(*NewAPIError)
| |
|
| | func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPIError { |
| | var newErr *NewAPIError |
| | |
| | if errors.As(err, &newErr) { |
| | for _, op := range ops { |
| | op(newErr) |
| | } |
| | return newErr |
| | } |
| | e := &NewAPIError{ |
| | Err: err, |
| | RelayError: nil, |
| | errorType: ErrorTypeNewAPIError, |
| | StatusCode: http.StatusInternalServerError, |
| | errorCode: errorCode, |
| | } |
| | for _, op := range ops { |
| | op(e) |
| | } |
| | return e |
| | } |
| |
|
| | func NewOpenAIError(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { |
| | var newErr *NewAPIError |
| | |
| | if errors.As(err, &newErr) { |
| | if newErr.RelayError == nil { |
| | openaiError := OpenAIError{ |
| | Message: newErr.Error(), |
| | Type: string(errorCode), |
| | Code: errorCode, |
| | } |
| | newErr.RelayError = openaiError |
| | } |
| | for _, op := range ops { |
| | op(newErr) |
| | } |
| | return newErr |
| | } |
| | openaiError := OpenAIError{ |
| | Message: err.Error(), |
| | Type: string(errorCode), |
| | Code: errorCode, |
| | } |
| | return WithOpenAIError(openaiError, statusCode, ops...) |
| | } |
| |
|
| | func InitOpenAIError(errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { |
| | openaiError := OpenAIError{ |
| | Type: string(errorCode), |
| | Code: errorCode, |
| | } |
| | return WithOpenAIError(openaiError, statusCode, ops...) |
| | } |
| |
|
| | func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { |
| | e := &NewAPIError{ |
| | Err: err, |
| | RelayError: OpenAIError{ |
| | Message: err.Error(), |
| | Type: string(errorCode), |
| | }, |
| | errorType: ErrorTypeNewAPIError, |
| | StatusCode: statusCode, |
| | errorCode: errorCode, |
| | } |
| | for _, op := range ops { |
| | op(e) |
| | } |
| |
|
| | return e |
| | } |
| |
|
| | func WithOpenAIError(openAIError OpenAIError, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { |
| | code, ok := openAIError.Code.(string) |
| | if !ok { |
| | if openAIError.Code != nil { |
| | code = fmt.Sprintf("%v", openAIError.Code) |
| | } else { |
| | code = "unknown_error" |
| | } |
| | } |
| | if openAIError.Type == "" { |
| | openAIError.Type = "upstream_error" |
| | } |
| | e := &NewAPIError{ |
| | RelayError: openAIError, |
| | errorType: ErrorTypeOpenAIError, |
| | StatusCode: statusCode, |
| | Err: errors.New(openAIError.Message), |
| | errorCode: ErrorCode(code), |
| | } |
| | for _, op := range ops { |
| | op(e) |
| | } |
| | return e |
| | } |
| |
|
| | func WithClaudeError(claudeError ClaudeError, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { |
| | if claudeError.Type == "" { |
| | claudeError.Type = "upstream_error" |
| | } |
| | e := &NewAPIError{ |
| | RelayError: claudeError, |
| | errorType: ErrorTypeClaudeError, |
| | StatusCode: statusCode, |
| | Err: errors.New(claudeError.Message), |
| | errorCode: ErrorCode(claudeError.Type), |
| | } |
| | for _, op := range ops { |
| | op(e) |
| | } |
| | return e |
| | } |
| |
|
| | func IsChannelError(err *NewAPIError) bool { |
| | if err == nil { |
| | return false |
| | } |
| | return strings.HasPrefix(string(err.errorCode), "channel:") |
| | } |
| |
|
| | func IsSkipRetryError(err *NewAPIError) bool { |
| | if err == nil { |
| | return false |
| | } |
| |
|
| | return err.skipRetry |
| | } |
| |
|
| | func ErrOptionWithSkipRetry() NewAPIErrorOptions { |
| | return func(e *NewAPIError) { |
| | e.skipRetry = true |
| | } |
| | } |
| |
|
| | func ErrOptionWithNoRecordErrorLog() NewAPIErrorOptions { |
| | return func(e *NewAPIError) { |
| | e.recordErrorLog = common.GetPointer(false) |
| | } |
| | } |
| |
|
| | func ErrOptionWithHideErrMsg(replaceStr string) NewAPIErrorOptions { |
| | return func(e *NewAPIError) { |
| | if common.DebugEnabled { |
| | fmt.Printf("ErrOptionWithHideErrMsg: %s, origin error: %s", replaceStr, e.Err) |
| | } |
| | e.Err = errors.New(replaceStr) |
| | } |
| | } |
| |
|
| | func IsRecordErrorLog(e *NewAPIError) bool { |
| | if e == nil { |
| | return false |
| | } |
| | if e.recordErrorLog == nil { |
| | |
| | return true |
| | } |
| | return *e.recordErrorLog |
| | } |
| |
|