diff --git a/app/route/middleware/sign/aes/aes.go b/app/route/middleware/sign/aes/aes.go new file mode 100644 index 00000000..d392d964 --- /dev/null +++ b/app/route/middleware/sign/aes/aes.go @@ -0,0 +1,120 @@ +package sign_aes + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "go-gin-api/app/config" + "go-gin-api/app/util" + "net/url" + "sort" + "strconv" + "strings" + "time" +) + +var AppSecret string + +// AES 对称加密 +func SetUp() gin.HandlerFunc { + + return func(c *gin.Context) { + utilGin := util.Gin{Ctx: c} + + sign, err := verifySign(c) + + if sign != nil { + utilGin.Response(-1, "Debug Sign", sign) + c.Abort() + return + } + + if err != nil { + utilGin.Response(-1, err.Error(), sign) + c.Abort() + return + } + + c.Next() + } +} + +// 验证签名 +func verifySign(c *gin.Context) (map[string]string, error) { + _ = c.Request.ParseForm() + req := c.Request.Form + debug := strings.Join(c.Request.Form["debug"], "") + ak := strings.Join(c.Request.Form["ak"], "") + sn := strings.Join(c.Request.Form["sn"], "") + ts := strings.Join(c.Request.Form["ts"], "") + + // 验证来源 + value, ok := config.ApiAuthConfig[ak] + if ok { + AppSecret = value["aes"] + } else { + return nil, errors.New("ak Error") + } + + if debug == "1" { + currentUnix := util.GetCurrentUnix() + req.Set("ts", strconv.FormatInt(currentUnix, 10)) + + sn, err := createSign(req) + if err != nil { + return nil, errors.New("sn Exception") + } + + res := map[string]string{ + "ts": strconv.FormatInt(currentUnix, 10), + "sn": sn, + } + return res, nil + } + + // 验证过期时间 + timestamp := time.Now().Unix() + exp, _ := strconv.ParseInt(config.AppSignExpiry, 10, 64) + tsInt, _ := strconv.ParseInt(ts, 10, 64) + if tsInt > timestamp || timestamp - tsInt >= exp { + return nil, errors.New("ts Error") + } + + // 验证签名 + if sn == "" { + return nil, errors.New("sn Error") + } + + decryptStr, decryptErr := util.AesDecrypt(sn, []byte(AppSecret), AppSecret) + if decryptErr != nil { + return nil, errors.New(decryptErr.Error()) + } + if decryptStr != createEncryptStr(req) { + return nil, errors.New("sn Error") + } + return nil, nil +} + +// 创建签名 +func createSign(params url.Values) (string, error) { + return util.AesEncrypt(createEncryptStr(params), []byte(AppSecret), AppSecret) +} + +func createEncryptStr(params url.Values) string { + var key []string + var str = "" + for k := range params { + if k != "sn" && k != "debug" { + key = append(key, k) + } + } + sort.Strings(key) + for i := 0; i < len(key); i++ { + if i == 0 { + str = fmt.Sprintf("%v=%v", key[i], params.Get(key[i])) + } else { + str = str + fmt.Sprintf("&%v=%v", key[i], params.Get(key[i])) + } + } + return str +} diff --git a/app/util/aes.go b/app/util/aes.go new file mode 100644 index 00000000..b8bfce8f --- /dev/null +++ b/app/util/aes.go @@ -0,0 +1,58 @@ +package util + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "encoding/base64" +) + +// 加密 aes_128_cbc +func AesEncrypt (encryptStr string, key []byte, iv string) (string, error) { + encryptBytes := []byte(encryptStr) + block, err := aes.NewCipher(key) + if err != nil { + return "", err + } + + blockSize := block.BlockSize() + encryptBytes = pkcs5Padding(encryptBytes, blockSize) + + blockMode := cipher.NewCBCEncrypter(block, []byte(iv)) + encrypted := make([]byte, len(encryptBytes)) + blockMode.CryptBlocks(encrypted, encryptBytes) + return base64.URLEncoding.EncodeToString(encrypted), nil +} + +// 解密 +func AesDecrypt (decryptStr string, key []byte, iv string) (string, error) { + decryptBytes, err := base64.URLEncoding.DecodeString(decryptStr) + if err != nil { + return "", err + } + + block, err := aes.NewCipher(key) + if err != nil { + return "", err + } + + blockMode := cipher.NewCBCDecrypter(block, []byte(iv)) + decrypted := make([]byte, len(decryptBytes)) + + blockMode.CryptBlocks(decrypted, decryptBytes) + decrypted = pkcs5UnPadding(decrypted) + return string(decrypted), nil +} + +func pkcs5Padding (cipherText []byte, blockSize int) []byte { + padding := blockSize - len(cipherText)%blockSize + padText := bytes.Repeat([]byte{byte(padding)}, padding) + return append(cipherText, padText...) +} + +func pkcs5UnPadding (decrypted []byte) []byte { + length := len(decrypted) + unPadding := int(decrypted[length-1]) + return decrypted[:(length - unPadding)] +} +