Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better API #8

Merged
merged 1 commit into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 33 additions & 8 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,20 @@ package otp_test

import (
"fmt"
"time"

"github.com/cristalhq/otp"
)

func Example() {
secretInBase32 := "JBSWY3DPEHPK3PXP"

algo := otp.AlgorithmSHA1
digits := otp.Digits(10)
issuer := "cristalhq"

hotp, err := otp.NewHOTP(algo, digits, issuer)
func ExampleHOTP() {
hotp, err := otp.NewHOTP(otp.HOTPConfig{
Algo: otp.AlgorithmSHA1,
Digits: otp.Digits(10),
Issuer: "cristalhq",
})
checkErr(err)

secretInBase32 := "JBSWY3DPEHPK3PXP"
code, err := hotp.GenerateCode(42, secretInBase32)
checkErr(err)

Expand All @@ -28,6 +28,31 @@ func Example() {
// 0979090604
}

func ExampleTOTP() {
totp, err := otp.NewTOTP(otp.TOTPConfig{
Algo: otp.AlgorithmSHA1,
Digits: otp.Digits(10),
Issuer: "cristalhq",
Period: 30,
Skew: 2,
})
checkErr(err)

secretInBase32 := "JBSWY3DPEHPK3PXP"
at := time.Date(2023, 11, 26, 12, 15, 18, 0, time.UTC)

code, err := totp.GenerateCode(secretInBase32, at)
checkErr(err)

fmt.Println(code)

err = totp.Validate(code, at, secretInBase32)
checkErr(err)

// Output:
// 0462778229
}

func checkErr(err error) {
if err != nil {
panic(err)
Expand Down
52 changes: 30 additions & 22 deletions hotp.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,38 +10,46 @@ import (

// HOTP represents HOTP codes generator and validator.
type HOTP struct {
algo Algorithm
digits Digits
issuer string
cfg HOTPConfig
}

// NewHOTP creates new HOTP.
func NewHOTP(algo Algorithm, digits Digits, issuer string) (*HOTP, error) {
if algo < 0 || algo >= algorithmMax {
return nil, ErrUnsupportedAlgorithm
type HOTPConfig struct {
Algo Algorithm
Digits Digits
Issuer string
}
cristaloleg marked this conversation as resolved.
Show resolved Hide resolved

func (cfg HOTPConfig) Validate() error {
switch {
case cfg.Algo < 0 || cfg.Algo >= algorithmMax:
return ErrUnsupportedAlgorithm
case cfg.Issuer == "":
return ErrEmptyIssuer
default:
return nil
}
if issuer == "" {
return nil, ErrEmptyIssuer
}

// NewHOTP creates new HOTP.
func NewHOTP(cfg HOTPConfig) (*HOTP, error) {
if err := cfg.Validate(); err != nil {
return nil, err
}
return &HOTP{
algo: algo,
digits: digits,
issuer: issuer,
}, nil
return &HOTP{cfg: cfg}, nil
}

// GenerateURL for the account for a given secret.
func (h *HOTP) GenerateURL(account string, secret []byte) string {
v := url.Values{}
v.Set("algorithm", h.algo.String())
v.Set("digits", h.digits.String())
v.Set("issuer", h.issuer)
v.Set("algorithm", h.cfg.Algo.String())
v.Set("digits", h.cfg.Digits.String())
v.Set("issuer", h.cfg.Issuer)
v.Set("secret", b32EncNoPadding(secret))

u := url.URL{
Scheme: "otpauth",
Host: "hotp",
Path: "/" + h.issuer + ":" + account,
Path: "/" + h.cfg.Issuer + ":" + account,
RawQuery: v.Encode(),
}
return u.String()
Expand All @@ -57,7 +65,7 @@ func (h *HOTP) GenerateCode(counter uint64, secret string) (string, error) {
buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, counter)

mac := hmac.New(h.algo.Hash, secretBytes)
mac := hmac.New(h.cfg.Algo.Hash, secretBytes)
mac.Write(buf)
sum := mac.Sum(nil)

Expand All @@ -69,13 +77,13 @@ func (h *HOTP) GenerateCode(counter uint64, secret string) (string, error) {
value |= int64(sum[offset+2]&0xff) << 8
value |= int64(sum[offset+3] & 0xff)

length := int64(math.Pow10(h.digits.Length()))
return h.digits.Format(int(value % length)), nil
length := int64(math.Pow10(h.cfg.Digits.Length()))
return h.cfg.Digits.Format(int(value % length)), nil
}

// Validate the given passcode, counter and secret.
func (h *HOTP) Validate(passcode string, counter uint64, secret string) error {
if len(passcode) != h.digits.Length() {
if len(passcode) != h.cfg.Digits.Length() {
return ErrCodeLengthMismatch
}

Expand Down
42 changes: 35 additions & 7 deletions hotp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ func TestHOTP(t *testing.T) {
}

for _, tc := range hotpRFCTestCases {
hotp, err := NewHOTP(tc.algo, Digits(6), "cristalhq")
hotp, err := NewHOTP(HOTPConfig{
Algo: tc.algo,
Digits: Digits(6),
Issuer: "cristalhq",
})
mustOk(t, err)

code, err := hotp.GenerateCode(tc.counter, tc.secret)
Expand All @@ -38,15 +42,27 @@ func TestHOTP(t *testing.T) {
}

func TestNewHOTP(t *testing.T) {
_, err := NewHOTP(-1, Digits(8), "cristalhq")
_, err := NewHOTP(HOTPConfig{
Algo: -1,
Digits: Digits(8),
Issuer: "cristalhq",
})
mustEqual(t, err, ErrUnsupportedAlgorithm)

_, err = NewHOTP(1, Digits(8), "")
_, err = NewHOTP(HOTPConfig{
Algo: 1,
Digits: Digits(8),
Issuer: "",
})
mustEqual(t, err, ErrEmptyIssuer)
}

func TestHOTPGenerateURL(t *testing.T) {
hotp, err := NewHOTP(AlgorithmSHA1, Digits(8), "cristalhq")
hotp, err := NewHOTP(HOTPConfig{
Algo: AlgorithmSHA1,
Digits: Digits(8),
Issuer: "cristalhq",
})
mustOk(t, err)

url := hotp.GenerateURL("alice@bob.com", []byte("SECRET_STRING"))
Expand All @@ -57,7 +73,11 @@ func TestHOTPGenerateURL(t *testing.T) {
}

func BenchmarkHOTP_GenerateURL(b *testing.B) {
hotp, err := NewHOTP(AlgorithmSHA1, Digits(8), "cristalhq")
hotp, err := NewHOTP(HOTPConfig{
Algo: AlgorithmSHA1,
Digits: Digits(8),
Issuer: "cristalhq",
})
mustOk(b, err)

account := "otp@cristalhq.dev"
Expand All @@ -75,7 +95,11 @@ func BenchmarkHOTP_GenerateURL(b *testing.B) {
}

func BenchmarkHOTP_GenerateCode(b *testing.B) {
hotp, err := NewHOTP(AlgorithmSHA1, Digits(8), "cristalhq")
hotp, err := NewHOTP(HOTPConfig{
Algo: AlgorithmSHA1,
Digits: Digits(8),
Issuer: "cristalhq",
})
mustOk(b, err)

b.ResetTimer()
Expand All @@ -90,7 +114,11 @@ func BenchmarkHOTP_GenerateCode(b *testing.B) {
}

func BenchmarkHOTP_Validate(b *testing.B) {
hotp, err := NewHOTP(AlgorithmSHA1, Digits(8), "cristalhq")
hotp, err := NewHOTP(HOTPConfig{
Algo: AlgorithmSHA1,
Digits: Digits(8),
Issuer: "cristalhq",
})
mustOk(b, err)

secret := secretSha1
Expand Down
74 changes: 43 additions & 31 deletions totp.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,59 +9,71 @@ import (

// TOTP represents TOTP codes generator and validator.
type TOTP struct {
*HOTP
period uint
skew uint
hotp HOTP
cfg TOTPConfig
}

// NewTOTP creates new TOTP.
func NewTOTP(algo Algorithm, digits Digits, issuer string, period, skew uint) (*TOTP, error) {
if algo < 0 || algo >= algorithmMax {
return nil, ErrUnsupportedAlgorithm
}
if issuer == "" {
return nil, ErrEmptyIssuer
}
if period < 1 {
return nil, ErrPeriodNotValid
type TOTPConfig struct {
Algo Algorithm
Digits Digits
Issuer string
Period uint
Skew uint
}

func (cfg TOTPConfig) Validate() error {
switch {
case cfg.Algo < 0 || cfg.Algo >= algorithmMax:
return ErrUnsupportedAlgorithm
case cfg.Issuer == "":
return ErrEmptyIssuer
default:
return nil
}
if skew < 1 {
return nil, ErrSkewNotValid
}

// NewTOTP creates new TOTP.
func NewTOTP(cfg TOTPConfig) (*TOTP, error) {
if err := cfg.Validate(); err != nil {
return nil, err
}

hotp, err := NewHOTP(algo, digits, issuer)
hotp, err := NewHOTP(HOTPConfig{
Algo: cfg.Algo,
Digits: cfg.Digits,
Issuer: cfg.Issuer,
})
if err != nil {
return nil, err
}
return &TOTP{
HOTP: hotp,
period: period,
skew: skew,
hotp: *hotp,
cfg: cfg,
}, nil
}

// GenerateURL for the account for a given secret.
func (t *TOTP) GenerateURL(account string, secret []byte) string {
v := url.Values{}
v.Set("algorithm", t.algo.String())
v.Set("digits", t.digits.String())
v.Set("issuer", t.issuer)
v.Set("algorithm", t.cfg.Algo.String())
v.Set("digits", t.cfg.Digits.String())
v.Set("issuer", t.cfg.Issuer)
v.Set("secret", b32EncNoPadding(secret))
v.Set("period", strconv.FormatUint(uint64(t.period), 10))
v.Set("period", strconv.FormatUint(uint64(t.cfg.Period), 10))

u := url.URL{
Scheme: "otpauth",
Host: "totp",
Path: "/" + t.issuer + ":" + account,
Path: "/" + t.cfg.Issuer + ":" + account,
RawQuery: v.Encode(),
}
return u.String()
}

// GenerateCode for the given counter and secret.
func (t *TOTP) GenerateCode(secret string, at time.Time) (string, error) {
counter := uint64(math.Floor(float64(at.Unix()) / float64(t.period)))
code, err := t.HOTP.GenerateCode(counter, secret)
counter := uint64(math.Floor(float64(at.Unix()) / float64(t.cfg.Period)))
code, err := t.hotp.GenerateCode(counter, secret)
if err != nil {
return "", err
}
Expand All @@ -70,20 +82,20 @@ func (t *TOTP) GenerateCode(secret string, at time.Time) (string, error) {

// Validate the given passcode, time and secret.
func (t *TOTP) Validate(passcode string, at time.Time, secret string) error {
if len(passcode) != t.digits.Length() {
if len(passcode) != t.cfg.Digits.Length() {
return ErrCodeLengthMismatch
}

counters := make([]uint64, 0, 2*t.skew+1)
counter := int64(math.Floor(float64(at.Unix()) / float64(t.period)))
counters := make([]uint64, 0, 2*t.cfg.Skew+1)
counter := int64(math.Floor(float64(at.Unix()) / float64(t.cfg.Period)))
counters = append(counters, uint64(counter))

for i := uint(1); i <= t.skew; i++ {
for i := uint(1); i <= t.cfg.Skew; i++ {
counters = append(counters, uint64(counter+int64(i)), uint64(counter-int64(i)))
}

for _, counter := range counters {
err := t.HOTP.Validate(passcode, counter, secret)
err := t.hotp.Validate(passcode, counter, secret)
if err == nil {
return nil
}
Expand Down
Loading
Loading