Skip to content

Commit

Permalink
generic callback and login handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
jesperkha committed Apr 23, 2024
1 parent 3454429 commit 7ec7e55
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 85 deletions.
109 changes: 109 additions & 0 deletions internal/auth/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package auth

import (
"context"
"encoding/json"
"io"
"log"
"net/http"

"golang.org/x/oauth2"
)

type Provider struct {
config oauth2.Config
}

func New(clientId, clientSecret, authUrl, tokenUrl string) Provider {
config := oauth2.Config{
// Default auth callback for testing. Remove
RedirectURL: "http://localhost:8080/auth/{provider}/callback",
ClientID: clientId,
ClientSecret: clientSecret,
Scopes: []string{},
Endpoint: oauth2.Endpoint{
AuthURL: authUrl,
TokenURL: tokenUrl,
},
}

return Provider{config: config}
}

type User struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
RefreshToken string `json:"refresh_token"`
Scope string `json:"scope"`
Error string `json:"error"`
ErrorDesc string `json:"error_description"`
ErrorUri string `json:"error_uri"`
}

type Session struct {
User User
Writer http.ResponseWriter
Request *http.Request
}

// LoginHandler return a http.HandlerFunc used to handle the endpoint /auth/{provider}
func LoginHandler(providers map[string]Provider) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
providerName := r.PathValue("provider")
if p, ok := providers[providerName]; ok {
url := p.config.AuthCodeURL("randomstate")
http.Redirect(w, r, url, http.StatusSeeOther)
} else {
log.Println("login: provider not found: ", providerName)
w.WriteHeader(http.StatusNotFound)
}
}
}

// CallbackHandler returns a http.HandlerFunc used to handle the endpoint /auth/{provider}/callback
func CallbackHandler(providers map[string]Provider, userCallback func(Session)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
providerName := r.PathValue("provider")

if p, ok := providers[providerName]; ok {
userData, err := p.fetchUserData(r)
if err != nil {
log.Println(err)
return
}

var user User
if err = json.Unmarshal(userData, &user); err != nil {
log.Println(err)
return
}

userCallback(Session{
User: user,
Request: r,
Writer: w,
})
} else {
log.Println("callback: provider not found: ", providerName)
w.WriteHeader(http.StatusNotFound)
}
}
}

func (p *Provider) fetchUserData(r *http.Request) ([]byte, error) {
code := r.FormValue("code")

token, err := p.config.Exchange(context.Background(), code)
if err != nil {
return nil, err
}

token_url := p.config.Endpoint.TokenURL + "?access_token=" + token.AccessToken
resp, err := http.Get(token_url)
if err != nil {
return nil, err
}

return io.ReadAll(resp.Body)
}
91 changes: 11 additions & 80 deletions internal/auth/sample/sample.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
package sample

import (
"context"
"encoding/json"
"io"
"log"
"net/http"

"golang.org/x/oauth2"
"github.com/echo-webkom/goat/internal/auth"
)

func resJson(w http.ResponseWriter, j any) {
Expand All @@ -23,7 +21,7 @@ func resJson(w http.ResponseWriter, j any) {
}

// Mount endpoints handled by provider for testing
func mountExampleHandlers(s *http.ServeMux) {
func MountExampleHandlers(s *http.ServeMux) {
// Example login page, will be replaced with provider URL
s.HandleFunc("GET /sample/auth", func(w http.ResponseWriter, r *http.Request) {
http.ServeFile(w, r, "internal/auth/sample/sample_auth.html")
Expand All @@ -32,94 +30,27 @@ func mountExampleHandlers(s *http.ServeMux) {
// Used for token exchange
s.HandleFunc("POST /sample/tokenUrl", func(w http.ResponseWriter, r *http.Request) {
resJson(w, map[string]any{
"access_token": "abcdef",
"access_token": "VeryCoolAccessToken",
"token_type": "bearer",
"expires_in": 3600,
"refresh_token": "ghijklmno",
"scope": "",
"refresh_token": "CoolerRefreshToken",
"scope": "CoolSCope",
})
})

// Used to fetch user data with generated token
s.HandleFunc("GET /sample/tokenUrl", func(w http.ResponseWriter, r *http.Request) {
resJson(w, map[string]any{
"username": "bob",
"access_token": r.URL.Query().Get("access_token"),
})
})
}

// Todo: create generic newProvider function

func New(s *http.ServeMux) {
mountExampleHandlers(s)

const (
// Load from .env
CLIENT_ID = "john"
CLIENT_SECRET = "1234"

AUTH_URL = "http://localhost:8080/sample/auth"
TOKEN_URL = "http://localhost:8080/sample/tokenUrl"
func New() auth.Provider {
return auth.New(
"cooluserid",
"coolusersecret",
"http://localhost:8080/sample/auth",
"http://localhost:8080/sample/tokenUrl",
)

config := oauth2.Config{
RedirectURL: "http://localhost:8080/sample_callback",
ClientID: CLIENT_ID,
ClientSecret: CLIENT_SECRET,
Scopes: []string{},
Endpoint: oauth2.Endpoint{
AuthURL: AUTH_URL,
TokenURL: TOKEN_URL,
},
}

s.Handle("GET /sample_login", login(config))
s.Handle("POST /sample_callback", callback(config))
}

// Creates new login handler. Should redirect to providers auth URL with
// generated state. URL is given client id/secret, redirect uri, callback
// uri and state.
func login(config oauth2.Config) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
url := config.AuthCodeURL("randomstate")
http.Redirect(w, r, url, http.StatusSeeOther)
}
}

// Creates a new callback handler for the auth provider. Verifies state
// and creates access token from code given by provider. This handler simply
// responds with the user data json.
func callback(config oauth2.Config) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
state := r.FormValue("state")
if state != "randomstate" {
w.WriteHeader(http.StatusInternalServerError)
return
}

code := r.FormValue("code")

token, err := config.Exchange(context.Background(), code)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}

token_url := config.Endpoint.TokenURL + "?access_token=" + token.AccessToken
resp, err := http.Get(token_url)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}

userData, err := io.ReadAll(resp.Body)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}

w.Write(userData)
}
}
6 changes: 3 additions & 3 deletions internal/auth/sample/sample_auth.html
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
<title>Document</title>
</head>
<body>
<p>hello auth</p>
<p>Very cool auth provider login screen thing.com</p>

<form action="/sample_callback" method="post">
<form action="http://localhost:8080/auth/sample/callback" method="post">
<label for="code">Code:</label><br>
<input type="text" id="code" name="code" value="1234"><br>
<input type="text" id="code" name="code" value="some_code"><br>
<label for="state">State:</label><br>
<input type="text" id="state" name="state" value="randomstate"><br><br>
<input type="submit" value="Submit">
Expand Down
16 changes: 14 additions & 2 deletions internal/server/server.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package server

import (
"encoding/json"
"net/http"

"github.com/echo-webkom/goat/internal/auth"
"github.com/echo-webkom/goat/internal/auth/sample"
)

Expand Down Expand Up @@ -49,6 +51,16 @@ func (s *Server) MountHandlers() {

s.Router.Handle("GET /test", ToHttpHandlerFunc(middleware(handler)))

// Sample oauth2 flow, go to /sample_login
sample.New(s.Router)
// Sample oauth2 flow, go to /auth/sample
ps := map[string]auth.Provider{
"sample": sample.New(),
}

s.Router.HandleFunc("/auth/{provider}", auth.LoginHandler(ps))
s.Router.HandleFunc("/auth/{provider}/callback", auth.CallbackHandler(ps, func(s auth.Session) {
d, _ := json.Marshal(s.User)
s.Writer.Write(d)
}))

sample.MountExampleHandlers(s.Router)
}

0 comments on commit 7ec7e55

Please sign in to comment.