Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/fixes #63

Merged
merged 4 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion example.env
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
STAGE=development

# Log level
LOG_LEVEL=debug
LOG_LEVEL=DEBUG

# Gorm
DATABASE_URL="root:root@tcp(localhost:3306)/chat-app?charset=utf8mb4&parseTime=True&loc=Local"
Expand Down
2 changes: 1 addition & 1 deletion scripts/setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ then
then
brew install pre-commit
else
pip install pre-commit
pipx install pre-commit
fi
fi

Expand Down
17 changes: 13 additions & 4 deletions src/api/auth/auth.controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,16 @@ func LoginController(c *gin.Context) {
return
}

user, err := LoginService(db, cfg, payload, logger)
tokens, err := LoginService(db, cfg, payload, logger)
if err != nil {
c.JSON(err.Code, err)
return
}

c.JSON(200, user)
c.SetCookie("access_token", tokens.AccessToken, cfg.JwtExpirationTime, "/", "", cfg.Stage == "production", true)
c.SetCookie("refresh_token", tokens.RefreshToken, cfg.RefreshExpirationTime, "/", "", cfg.Stage == "production", true)

c.JSON(200, gin.H{})
}

func CheckLoginController(c *gin.Context) {
Expand Down Expand Up @@ -74,14 +77,17 @@ func RefreshController(c *gin.Context) {
return
}

tokens, err := RefreshService(db, cfg, payload.(*JWTPayload), logger)
tokens, err := RefreshService(db, cfg, payload.(*JWTAccessTokenPayload), logger)

if err != nil {
c.JSON(err.Code, err)
return
}

c.JSON(200, tokens)
c.SetCookie("access_token", tokens.AccessToken, cfg.JwtExpirationTime, "/", "", cfg.Stage == "production", true)
c.SetCookie("refresh_token", tokens.RefreshToken, cfg.RefreshExpirationTime, "/", "", cfg.Stage == "production", true)

c.JSON(200, gin.H{})
}

func LogoutController(c *gin.Context) {
Expand Down Expand Up @@ -120,5 +126,8 @@ func LogoutController(c *gin.Context) {
return
}

c.SetCookie("access_token", "", 0, "/", "", cfg.Stage == "production", true)
c.SetCookie("refresh_token", "", 0, "/", "", cfg.Stage == "production", true)

c.JSON(200, gin.H{})
}
18 changes: 0 additions & 18 deletions src/api/auth/auth.dto.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package auth

import "time"

type LoginRequest struct {
Email string `json:"email" validate:"required,email"`
Password string `json:"password" validate:"required"`
Expand All @@ -11,19 +9,3 @@ type RefreshTokenResponse struct {
JWTPair
AccessTokenExpires int64 `json:"accessTokenExpires"`
}

type UserWithTokens struct {
Id string `json:"id"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
Email string `json:"email"`
Name string `json:"name"`
Bio *string `json:"bio"`
Iv string `json:"iv"`
PublicKey string `json:"publicKey"`
PrivateKey string `json:"privateKey"`
ProfilePicture *string `json:"profilePicture"`
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
AccessTokenExpires int64 `json:"accessTokenExpires"`
}
45 changes: 21 additions & 24 deletions src/api/auth/auth.guard.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"easyflow-backend/src/enum"
"errors"
"net/http"
"strings"

"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
Expand All @@ -26,8 +25,17 @@ func AuthGuard() gin.HandlerFunc {
return
}

// Get access_token from header
accessToken := strings.TrimPrefix(c.GetHeader("Authorization"), "Bearer ")
// Get access_token from cookies
accessToken, err := c.Cookie("access_token")
if err != nil {
logger.PrintfWarning("Error while getting access token cookie: %s", err.Error())
c.JSON(http.StatusUnauthorized, api.ApiError{
Code: http.StatusUnauthorized,
Error: enum.Unauthorized,
})
c.Abort()
return
}

if accessToken == "" {
logger.PrintfWarning("No access token provided")
Expand Down Expand Up @@ -70,16 +78,6 @@ func AuthGuard() gin.HandlerFunc {
return
}

if !payload.IsAccess {
logger.PrintfWarning("Invalid token type")
c.JSON(http.StatusUnauthorized, api.ApiError{
Code: http.StatusUnauthorized,
Error: enum.InvalidCookie,
})
c.Abort()
return
}

// Set user payload in context
c.Set("user", payload)
c.Next()
Expand All @@ -99,7 +97,16 @@ func RefreshAuthGuard() gin.HandlerFunc {
return
}

refreshToken := strings.TrimPrefix(c.GetHeader("Authorization"), "Bearer ")
refreshToken, err := c.Cookie("refresh_token")
if err != nil {
logger.PrintfWarning("Error while getting refresh token cookie: %s", err.Error())
c.JSON(http.StatusUnauthorized, api.ApiError{
Code: http.StatusUnauthorized,
Error: enum.Unauthorized,
})
c.Abort()
return
}

if refreshToken == "" {
logger.PrintfWarning("No refresh token provided")
Expand Down Expand Up @@ -140,16 +147,6 @@ func RefreshAuthGuard() gin.HandlerFunc {
return
}

if token.IsAccess {
logger.PrintfWarning("Invalid token type")
c.JSON(http.StatusUnauthorized, api.ApiError{
Code: http.StatusUnauthorized,
Error: enum.InvalidCookie,
})
c.Abort()
return
}

if err := db.First(&database.UserKeys{}, "user_id = ? AND random = ?", token.UserId, token.RefreshRand).Error; err != nil {
logger.PrintfWarning("Invalid refresh token")
c.JSON(http.StatusUnauthorized, api.ApiError{
Expand Down
87 changes: 37 additions & 50 deletions src/api/auth/auth.service.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"easyflow-backend/src/common"
"easyflow-backend/src/database"
"easyflow-backend/src/enum"
"fmt"
"net/http"
"net/url"
"strconv"
Expand All @@ -17,34 +18,42 @@ import (
"gorm.io/gorm"
)

func generateJwt(cfg *common.Config, payload *JWTPayload) (string, error) {
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
token, err := claims.SignedString([]byte(cfg.JwtSecret))
func generateJwt[T interface{ jwt.Claims }](cfg *common.Config, payload T) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
signedToken, err := token.SignedString([]byte(cfg.JwtSecret))
if err != nil {
return "", err
return "", fmt.Errorf("failed to sign token: %w", err)
}

return token, nil
return signedToken, nil
}

func ValidateToken(cfg *common.Config, token string) (*JWTPayload, error) {
claims := &JWTPayload{}
_, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
return []byte(cfg.JwtSecret), nil
})
func ValidateToken(cfg *common.Config, token string) (*JWTAccessTokenPayload, error) {
var claims JWTAccessTokenPayload
_, err := jwt.ParseWithClaims(
token,
&claims,
func(token *jwt.Token) (interface{}, error) {
// Verify that the signing method is what we expect
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(cfg.JwtSecret), nil
},
)

if err != nil {
return nil, err
}

return claims, nil
return &claims, nil
}

func LoginService(db *gorm.DB, cfg *common.Config, payload *LoginRequest, logger *common.Logger) (UserWithTokens, *api.ApiError) {
func LoginService(db *gorm.DB, cfg *common.Config, payload *LoginRequest, logger *common.Logger) (JWTPair, *api.ApiError) {
var user database.User
if err := db.Where("email = ?", payload.Email).First(&user).Error; err != nil {
logger.PrintfWarning("User with email: %s not found", payload.Email)
return UserWithTokens{}, &api.ApiError{
return JWTPair{}, &api.ApiError{
Code: http.StatusUnauthorized,
Error: enum.WrongCredentials,
Details: err,
Expand All @@ -54,7 +63,7 @@ func LoginService(db *gorm.DB, cfg *common.Config, payload *LoginRequest, logger
//check password
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(payload.Password)); err != nil {
logger.PrintfWarning("Wrong password for user with email: %s", payload.Email)
return UserWithTokens{}, &api.ApiError{
return JWTPair{}, &api.ApiError{
Code: http.StatusUnauthorized,
Error: enum.WrongCredentials,
Details: err,
Expand All @@ -65,46 +74,42 @@ func LoginService(db *gorm.DB, cfg *common.Config, payload *LoginRequest, logger
expires := time.Now().Add(time.Duration(cfg.JwtExpirationTime) * time.Second)
refreshExpires := time.Now().Add(time.Duration(cfg.RefreshExpirationTime) * time.Second)

accessTokenPayload := JWTPayload{
accessTokenPayload := JWTAccessTokenPayload{
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expires),
Issuer: "easyflow",
IssuedAt: jwt.NewNumericDate(time.Now()),
},
UserId: user.Id,
Email: user.Email,
RefreshRand: &random,
IsAccess: true,
}

refreshTokenPayload := JWTPayload{
refreshTokenPayload := JWTAccessTokenPayload{
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(refreshExpires),
Issuer: "easyflow",
IssuedAt: jwt.NewNumericDate(time.Now()),
},
UserId: user.Id,
Email: user.Email,
RefreshRand: &random,
IsAccess: false,
}

accessToken, err := generateJwt(cfg, &accessTokenPayload)
accessToken, err := generateJwt[JWTAccessTokenPayload](cfg, accessTokenPayload)

if err != nil {
logger.PrintfError("Error generating jwt: %s", err)
return UserWithTokens{}, &api.ApiError{
return JWTPair{}, &api.ApiError{
Code: http.StatusInternalServerError,
Error: enum.ApiError,
Details: err,
}
}

refreshToken, err := generateJwt(cfg, &refreshTokenPayload)
refreshToken, err := generateJwt[JWTAccessTokenPayload](cfg, refreshTokenPayload)

if err != nil {
logger.PrintfError("Error generating jwt: %s", err)
return UserWithTokens{}, &api.ApiError{
return JWTPair{}, &api.ApiError{
Code: http.StatusInternalServerError,
Error: enum.ApiError,
Details: err,
Expand All @@ -120,7 +125,7 @@ func LoginService(db *gorm.DB, cfg *common.Config, payload *LoginRequest, logger

if err := db.Save(&entry).Error; err != nil {
logger.PrintfError("Error updating user key: %s", err)
return UserWithTokens{}, &api.ApiError{
return JWTPair{}, &api.ApiError{
Code: http.StatusInternalServerError,
Error: enum.ApiError,
Details: err,
Expand Down Expand Up @@ -159,27 +164,13 @@ func LoginService(db *gorm.DB, cfg *common.Config, payload *LoginRequest, logger

logger.Printf("Logged in user: %s", user.Id)

logger.PrintfDebug("Access: %s", accessToken)
logger.PrintfDebug("Refresh: %s", refreshToken)

return UserWithTokens{
Id: user.Id,
CreatedAt: user.CreatedAt,
UpdatedAt: user.UpdatedAt,
Email: user.Email,
Name: user.Name,
Bio: user.Bio,
Iv: user.Iv,
PublicKey: user.PublicKey,
PrivateKey: user.PrivateKey,
ProfilePicture: user.ProfilePicture,
AccessToken: accessToken,
RefreshToken: refreshToken,
AccessTokenExpires: expires.Unix(),
return JWTPair{
RefreshToken: refreshToken,
AccessToken: accessToken,
}, nil
}

func RefreshService(db *gorm.DB, cfg *common.Config, payload *JWTPayload, logger *common.Logger) (RefreshTokenResponse, *api.ApiError) {
func RefreshService(db *gorm.DB, cfg *common.Config, payload *JWTAccessTokenPayload, logger *common.Logger) (RefreshTokenResponse, *api.ApiError) {
//get user from db
var user database.User
if err := db.First(&user, "id = ?", payload.UserId).Error; err != nil {
Expand All @@ -195,28 +186,24 @@ func RefreshService(db *gorm.DB, cfg *common.Config, payload *JWTPayload, logger
expires := time.Now().Add(time.Duration(cfg.JwtExpirationTime) * time.Second)
refreshExpires := time.Now().Add(time.Duration(cfg.RefreshExpirationTime) * time.Second)

accessTokenPayload := JWTPayload{
accessTokenPayload := JWTAccessTokenPayload{
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expires),
Issuer: "easyflow",
IssuedAt: jwt.NewNumericDate(time.Now()),
},
UserId: user.Id,
Email: user.Email,
RefreshRand: &random,
IsAccess: true,
}

refreshTokenPayload := JWTPayload{
refreshTokenPayload := JWTAccessTokenPayload{
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(refreshExpires),
Issuer: "easyflow",
IssuedAt: jwt.NewNumericDate(time.Now()),
},
UserId: user.Id,
Email: user.Email,
RefreshRand: &random,
IsAccess: false,
}

accessToken, err := generateJwt(cfg, &accessTokenPayload)
Expand Down Expand Up @@ -267,7 +254,7 @@ func RefreshService(db *gorm.DB, cfg *common.Config, payload *JWTPayload, logger
}, nil
}

func LogoutService(db *gorm.DB, payload *JWTPayload, logger *common.Logger) *api.ApiError {
func LogoutService(db *gorm.DB, payload *JWTAccessTokenPayload, logger *common.Logger) *api.ApiError {
if err := db.Delete(&database.UserKeys{}, payload.RefreshRand).Error; err != nil {
logger.PrintfError("Could not delete Refresh Token with random: %s and user id: %s", payload.RefreshRand, payload.UserId)
return &api.ApiError{
Expand Down
4 changes: 1 addition & 3 deletions src/api/auth/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@ import (
"github.com/google/uuid"
)

type JWTPayload struct {
type JWTAccessTokenPayload struct {
jwt.RegisteredClaims
UserId string `json:"userId"`
Email string `json:"email"`
RefreshRand *uuid.UUID `json:"refreshRand"`
IsAccess bool `json:"isAccess"`
}

type JWTPair struct {
Expand Down
Loading
Loading