Skip to content
This repository has been archived by the owner on Jun 6, 2023. It is now read-only.

[WIP] Support Provider Authentication Tokens (JWT) #88

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
*.p12
*.pem
*.p8
*.cer
c.out
*.pass
Expand Down
112 changes: 112 additions & 0 deletions example/jwt/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package main

import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"flag"
"io/ioutil"
"log"
"net/http"
"sync"
"time"

"golang.org/x/net/http2"

"github.com/RobotsAndPencils/buford/push"

"github.com/dgrijalva/jwt-go"
)

func NewClient() (*http.Client, error) {
config := &tls.Config{}
transport := &http.Transport{TLSClientConfig: config}

if err := http2.ConfigureTransport(transport); err != nil {
return nil, err
}

return &http.Client{Transport: transport}, nil
}

func main() {
var deviceToken, filename, keyID, teamID, bundleID string
var number int

flag.StringVar(&deviceToken, "d", "", "Device token")
flag.StringVar(&filename, "k", "", "Path to private signing key")
flag.StringVar(&keyID, "kid", "", "Key ID")
flag.StringVar(&teamID, "t", "", "TeamID")
flag.StringVar(&bundleID, "b", "", "Bundle ID for app")
flag.IntVar(&number, "n", 100, "Number of notifications to send")
flag.Parse()

privateBytes, err := ioutil.ReadFile(filename)
exitOnError(err)

block, _ := pem.Decode(privateBytes)
if block == nil {
log.Fatal("Key file must be PEM encoded.")
}

privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
exitOnError(err)

client, err := NewClient()
exitOnError(err)

service := push.NewService(client, push.Development)

token := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{
"iss": teamID,
"iat": time.Now().Unix(),
})
token.Header["kid"] = keyID
log.Printf("%#v\n", token)

// push the notification:
tokenString, err := token.SignedString(privateKey)
exitOnError(err)
log.Println(tokenString)

queue := push.NewQueue(service, 20)
var wg sync.WaitGroup

// process responses
// NOTE: Responses may be received in any order.
go func() {
count := 1
for resp := range queue.Responses {
if resp.Err != nil {
log.Printf("(%d) device: %s, error: %v", count, resp.DeviceToken, resp.Err)
} else {
log.Printf("(%d) device: %s, apns-id: %s", count, resp.DeviceToken, resp.ID)
}
count++
wg.Done()
}
}()

h := &push.Headers{Authorization: tokenString, Topic: bundleID}
b := []byte(`{"aps":{"alert":"Hello HTTP/2"}}`)

// synchronous send to prime stream
id, err := service.Push(deviceToken, h, []byte(`{"aps":{"alert":"Hello HTTP/2"}}`))
exitOnError(err)
log.Println("apns-id:", id)

// concurrent send
for i := 0; i < number; i++ {
wg.Add(1)
queue.Push(deviceToken, h, b)
}
// done sending notifications, wait for all responses and shutdown:
wg.Wait()
queue.Close()
}

func exitOnError(err error) {
if err != nil {
log.Fatal(err)
}
}
22 changes: 22 additions & 0 deletions push/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ var (
ErrBadPriority = errors.New("BadPriority")
ErrBadTopic = errors.New("BadTopic")

// Token authentication errors.
ErrMissingProviderToken = errors.New("MissingProviderToken")
ErrInvalidProviderToken = errors.New("InvalidProviderToken")
ErrExpiredProviderToken = errors.New("ExpiredProviderToken")
ErrTooManyProviderTokenUpdates = errors.New("TooManyProviderTokenUpdates")

// Certificate and topic errors.
ErrBadCertificate = errors.New("BadCertificate")
ErrBadCertificateEnvironment = errors.New("BadCertificateEnvironment")
Expand Down Expand Up @@ -80,6 +86,14 @@ func mapErrorReason(reason string) error {
e = ErrUnregistered
case "DuplicateHeaders":
e = ErrDuplicateHeaders
case "MissingProviderToken":
e = ErrMissingProviderToken
case "InvalidProviderToken":
e = ErrInvalidProviderToken
case "ExpiredProviderToken":
e = ErrExpiredProviderToken
case "TooManyProviderTokenUpdates":
e = ErrTooManyProviderTokenUpdates
case "BadCertificateEnvironment":
e = ErrBadCertificateEnvironment
case "BadCertificate":
Expand Down Expand Up @@ -128,6 +142,14 @@ func (e *Error) Error() string {
return "the apns-priority value is bad"
case ErrBadTopic:
return "the Topic header was invalid"
case ErrMissingProviderToken:
return "the Authorization header was missing"
case ErrInvalidProviderToken:
return "the JWT authentication token is invalid"
case ErrExpiredProviderToken:
return "the JWT authentication token expired"
case ErrTooManyProviderTokenUpdates:
return "the JWT authentication token is being updated too often"
case ErrBadCertificate:
return "the certificate was bad"
case ErrBadCertificateEnvironment:
Expand Down
7 changes: 7 additions & 0 deletions push/header.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ type Headers struct {

// Topic for certificates with multiple topics.
Topic string

// Authorization for Token Authentication (JWT)
Authorization string
}

// set headers for an HTTP request
Expand Down Expand Up @@ -54,4 +57,8 @@ func (h *Headers) set(reqHeader http.Header) {
reqHeader.Set("apns-topic", h.Topic)
}

if h.Authorization != "" {
reqHeader.Set("authorization", "bearer "+h.Authorization)
}

}
14 changes: 9 additions & 5 deletions push/header_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ import (

func TestHeaders(t *testing.T) {
headers := Headers{
ID: "uuid",
CollapseID: "game1.score.identifier",
Expiration: time.Unix(12622780800, 0),
LowPriority: true,
Topic: "bundle-id",
ID: "uuid",
CollapseID: "game1.score.identifier",
Expiration: time.Unix(12622780800, 0),
LowPriority: true,
Topic: "bundle-id",
Authorization: "eyJhbGciOiJFUzI1N",
}

reqHeader := http.Header{}
Expand All @@ -23,6 +24,7 @@ func TestHeaders(t *testing.T) {
testHeader(t, reqHeader, "apns-expiration", "12622780800")
testHeader(t, reqHeader, "apns-priority", "5")
testHeader(t, reqHeader, "apns-topic", "bundle-id")
testHeader(t, reqHeader, "authorization", "bearer eyJhbGciOiJFUzI1N")
}

func TestNilHeader(t *testing.T) {
Expand All @@ -35,6 +37,7 @@ func TestNilHeader(t *testing.T) {
testHeader(t, reqHeader, "apns-expiration", "")
testHeader(t, reqHeader, "apns-priority", "")
testHeader(t, reqHeader, "apns-topic", "")
testHeader(t, reqHeader, "authorization", "")
}

func TestEmptyHeaders(t *testing.T) {
Expand All @@ -47,6 +50,7 @@ func TestEmptyHeaders(t *testing.T) {
testHeader(t, reqHeader, "apns-expiration", "")
testHeader(t, reqHeader, "apns-priority", "")
testHeader(t, reqHeader, "apns-topic", "")
testHeader(t, reqHeader, "authorization", "")
}

func testHeader(t *testing.T, reqHeader http.Header, key, expected string) {
Expand Down
2 changes: 2 additions & 0 deletions push/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ func (s *Service) Push(deviceToken string, headers *Headers, payload []byte) (st

if err != nil {
if e, ok := err.(*url.Error); ok {
// log.Printf("%#v", e)
if e, ok := e.Err.(http2.GoAwayError); ok {
// ErrCode:0x7, DebugData:"Maximum active streams violated for this endpoint."
// parse DebugData as JSON. no status code known (0)
return "", parseErrorResponse(strings.NewReader(e.DebugData), 0)
}
Expand Down