From 18bb579dadbf66f29dba9f6546045b9735715459 Mon Sep 17 00:00:00 2001 From: Oleg Kovalov Date: Sun, 26 Nov 2023 18:58:19 +0100 Subject: [PATCH] Better API --- example_test.go | 41 +++++++++++++++++++++------ hotp.go | 52 +++++++++++++++++++--------------- hotp_test.go | 42 +++++++++++++++++++++++----- totp.go | 74 ++++++++++++++++++++++++++++--------------------- totp_test.go | 56 ++++++++++++++++++++++++++++++++----- 5 files changed, 190 insertions(+), 75 deletions(-) diff --git a/example_test.go b/example_test.go index 57bef19..4653055 100644 --- a/example_test.go +++ b/example_test.go @@ -2,20 +2,20 @@ package otp_test import ( "fmt" + "time" "github.com/cristalhq/otp" ) -func Example() { - secretInBase32 := "JBSWY3DPEHPK3PXP" - - algo := otp.AlgorithmSHA1 - digits := otp.Digits(10) - issuer := "cristalhq" - - hotp, err := otp.NewHOTP(algo, digits, issuer) +func ExampleHOTP() { + hotp, err := otp.NewHOTP(otp.HOTPConfig{ + Algo: otp.AlgorithmSHA1, + Digits: otp.Digits(10), + Issuer: "cristalhq", + }) checkErr(err) + secretInBase32 := "JBSWY3DPEHPK3PXP" code, err := hotp.GenerateCode(42, secretInBase32) checkErr(err) @@ -28,6 +28,31 @@ func Example() { // 0979090604 } +func ExampleTOTP() { + totp, err := otp.NewTOTP(otp.TOTPConfig{ + Algo: otp.AlgorithmSHA1, + Digits: otp.Digits(10), + Issuer: "cristalhq", + Period: 30, + Skew: 2, + }) + checkErr(err) + + secretInBase32 := "JBSWY3DPEHPK3PXP" + at := time.Date(2023, 11, 26, 12, 15, 18, 0, time.UTC) + + code, err := totp.GenerateCode(secretInBase32, at) + checkErr(err) + + fmt.Println(code) + + err = totp.Validate(code, at, secretInBase32) + checkErr(err) + + // Output: + // 0462778229 +} + func checkErr(err error) { if err != nil { panic(err) diff --git a/hotp.go b/hotp.go index 9578d84..93280bf 100644 --- a/hotp.go +++ b/hotp.go @@ -10,38 +10,46 @@ import ( // HOTP represents HOTP codes generator and validator. type HOTP struct { - algo Algorithm - digits Digits - issuer string + cfg HOTPConfig } -// NewHOTP creates new HOTP. -func NewHOTP(algo Algorithm, digits Digits, issuer string) (*HOTP, error) { - if algo < 0 || algo >= algorithmMax { - return nil, ErrUnsupportedAlgorithm +type HOTPConfig struct { + Algo Algorithm + Digits Digits + Issuer string +} + +func (cfg HOTPConfig) Validate() error { + switch { + case cfg.Algo < 0 || cfg.Algo >= algorithmMax: + return ErrUnsupportedAlgorithm + case cfg.Issuer == "": + return ErrEmptyIssuer + default: + return nil } - if issuer == "" { - return nil, ErrEmptyIssuer +} + +// NewHOTP creates new HOTP. +func NewHOTP(cfg HOTPConfig) (*HOTP, error) { + if err := cfg.Validate(); err != nil { + return nil, err } - return &HOTP{ - algo: algo, - digits: digits, - issuer: issuer, - }, nil + return &HOTP{cfg: cfg}, nil } // GenerateURL for the account for a given secret. func (h *HOTP) GenerateURL(account string, secret []byte) string { v := url.Values{} - v.Set("algorithm", h.algo.String()) - v.Set("digits", h.digits.String()) - v.Set("issuer", h.issuer) + v.Set("algorithm", h.cfg.Algo.String()) + v.Set("digits", h.cfg.Digits.String()) + v.Set("issuer", h.cfg.Issuer) v.Set("secret", b32EncNoPadding(secret)) u := url.URL{ Scheme: "otpauth", Host: "hotp", - Path: "/" + h.issuer + ":" + account, + Path: "/" + h.cfg.Issuer + ":" + account, RawQuery: v.Encode(), } return u.String() @@ -57,7 +65,7 @@ func (h *HOTP) GenerateCode(counter uint64, secret string) (string, error) { buf := make([]byte, 8) binary.BigEndian.PutUint64(buf, counter) - mac := hmac.New(h.algo.Hash, secretBytes) + mac := hmac.New(h.cfg.Algo.Hash, secretBytes) mac.Write(buf) sum := mac.Sum(nil) @@ -69,13 +77,13 @@ func (h *HOTP) GenerateCode(counter uint64, secret string) (string, error) { value |= int64(sum[offset+2]&0xff) << 8 value |= int64(sum[offset+3] & 0xff) - length := int64(math.Pow10(h.digits.Length())) - return h.digits.Format(int(value % length)), nil + length := int64(math.Pow10(h.cfg.Digits.Length())) + return h.cfg.Digits.Format(int(value % length)), nil } // Validate the given passcode, counter and secret. func (h *HOTP) Validate(passcode string, counter uint64, secret string) error { - if len(passcode) != h.digits.Length() { + if len(passcode) != h.cfg.Digits.Length() { return ErrCodeLengthMismatch } diff --git a/hotp_test.go b/hotp_test.go index f1fd9f7..d415b20 100644 --- a/hotp_test.go +++ b/hotp_test.go @@ -25,7 +25,11 @@ func TestHOTP(t *testing.T) { } for _, tc := range hotpRFCTestCases { - hotp, err := NewHOTP(tc.algo, Digits(6), "cristalhq") + hotp, err := NewHOTP(HOTPConfig{ + Algo: tc.algo, + Digits: Digits(6), + Issuer: "cristalhq", + }) mustOk(t, err) code, err := hotp.GenerateCode(tc.counter, tc.secret) @@ -38,15 +42,27 @@ func TestHOTP(t *testing.T) { } func TestNewHOTP(t *testing.T) { - _, err := NewHOTP(-1, Digits(8), "cristalhq") + _, err := NewHOTP(HOTPConfig{ + Algo: -1, + Digits: Digits(8), + Issuer: "cristalhq", + }) mustEqual(t, err, ErrUnsupportedAlgorithm) - _, err = NewHOTP(1, Digits(8), "") + _, err = NewHOTP(HOTPConfig{ + Algo: 1, + Digits: Digits(8), + Issuer: "", + }) mustEqual(t, err, ErrEmptyIssuer) } func TestHOTPGenerateURL(t *testing.T) { - hotp, err := NewHOTP(AlgorithmSHA1, Digits(8), "cristalhq") + hotp, err := NewHOTP(HOTPConfig{ + Algo: AlgorithmSHA1, + Digits: Digits(8), + Issuer: "cristalhq", + }) mustOk(t, err) url := hotp.GenerateURL("alice@bob.com", []byte("SECRET_STRING")) @@ -57,7 +73,11 @@ func TestHOTPGenerateURL(t *testing.T) { } func BenchmarkHOTP_GenerateURL(b *testing.B) { - hotp, err := NewHOTP(AlgorithmSHA1, Digits(8), "cristalhq") + hotp, err := NewHOTP(HOTPConfig{ + Algo: AlgorithmSHA1, + Digits: Digits(8), + Issuer: "cristalhq", + }) mustOk(b, err) account := "otp@cristalhq.dev" @@ -75,7 +95,11 @@ func BenchmarkHOTP_GenerateURL(b *testing.B) { } func BenchmarkHOTP_GenerateCode(b *testing.B) { - hotp, err := NewHOTP(AlgorithmSHA1, Digits(8), "cristalhq") + hotp, err := NewHOTP(HOTPConfig{ + Algo: AlgorithmSHA1, + Digits: Digits(8), + Issuer: "cristalhq", + }) mustOk(b, err) b.ResetTimer() @@ -90,7 +114,11 @@ func BenchmarkHOTP_GenerateCode(b *testing.B) { } func BenchmarkHOTP_Validate(b *testing.B) { - hotp, err := NewHOTP(AlgorithmSHA1, Digits(8), "cristalhq") + hotp, err := NewHOTP(HOTPConfig{ + Algo: AlgorithmSHA1, + Digits: Digits(8), + Issuer: "cristalhq", + }) mustOk(b, err) secret := secretSha1 diff --git a/totp.go b/totp.go index 045db1e..aa9f3ba 100644 --- a/totp.go +++ b/totp.go @@ -9,50 +9,62 @@ import ( // TOTP represents TOTP codes generator and validator. type TOTP struct { - *HOTP - period uint - skew uint + hotp HOTP + cfg TOTPConfig } -// NewTOTP creates new TOTP. -func NewTOTP(algo Algorithm, digits Digits, issuer string, period, skew uint) (*TOTP, error) { - if algo < 0 || algo >= algorithmMax { - return nil, ErrUnsupportedAlgorithm - } - if issuer == "" { - return nil, ErrEmptyIssuer - } - if period < 1 { - return nil, ErrPeriodNotValid +type TOTPConfig struct { + Algo Algorithm + Digits Digits + Issuer string + Period uint + Skew uint +} + +func (cfg TOTPConfig) Validate() error { + switch { + case cfg.Algo < 0 || cfg.Algo >= algorithmMax: + return ErrUnsupportedAlgorithm + case cfg.Issuer == "": + return ErrEmptyIssuer + default: + return nil } - if skew < 1 { - return nil, ErrSkewNotValid +} + +// NewTOTP creates new TOTP. +func NewTOTP(cfg TOTPConfig) (*TOTP, error) { + if err := cfg.Validate(); err != nil { + return nil, err } - hotp, err := NewHOTP(algo, digits, issuer) + hotp, err := NewHOTP(HOTPConfig{ + Algo: cfg.Algo, + Digits: cfg.Digits, + Issuer: cfg.Issuer, + }) if err != nil { return nil, err } return &TOTP{ - HOTP: hotp, - period: period, - skew: skew, + hotp: *hotp, + cfg: cfg, }, nil } // GenerateURL for the account for a given secret. func (t *TOTP) GenerateURL(account string, secret []byte) string { v := url.Values{} - v.Set("algorithm", t.algo.String()) - v.Set("digits", t.digits.String()) - v.Set("issuer", t.issuer) + v.Set("algorithm", t.cfg.Algo.String()) + v.Set("digits", t.cfg.Digits.String()) + v.Set("issuer", t.cfg.Issuer) v.Set("secret", b32EncNoPadding(secret)) - v.Set("period", strconv.FormatUint(uint64(t.period), 10)) + v.Set("period", strconv.FormatUint(uint64(t.cfg.Period), 10)) u := url.URL{ Scheme: "otpauth", Host: "totp", - Path: "/" + t.issuer + ":" + account, + Path: "/" + t.cfg.Issuer + ":" + account, RawQuery: v.Encode(), } return u.String() @@ -60,8 +72,8 @@ func (t *TOTP) GenerateURL(account string, secret []byte) string { // GenerateCode for the given counter and secret. func (t *TOTP) GenerateCode(secret string, at time.Time) (string, error) { - counter := uint64(math.Floor(float64(at.Unix()) / float64(t.period))) - code, err := t.HOTP.GenerateCode(counter, secret) + counter := uint64(math.Floor(float64(at.Unix()) / float64(t.cfg.Period))) + code, err := t.hotp.GenerateCode(counter, secret) if err != nil { return "", err } @@ -70,20 +82,20 @@ func (t *TOTP) GenerateCode(secret string, at time.Time) (string, error) { // Validate the given passcode, time and secret. func (t *TOTP) Validate(passcode string, at time.Time, secret string) error { - if len(passcode) != t.digits.Length() { + if len(passcode) != t.cfg.Digits.Length() { return ErrCodeLengthMismatch } - counters := make([]uint64, 0, 2*t.skew+1) - counter := int64(math.Floor(float64(at.Unix()) / float64(t.period))) + counters := make([]uint64, 0, 2*t.cfg.Skew+1) + counter := int64(math.Floor(float64(at.Unix()) / float64(t.cfg.Period))) counters = append(counters, uint64(counter)) - for i := uint(1); i <= t.skew; i++ { + for i := uint(1); i <= t.cfg.Skew; i++ { counters = append(counters, uint64(counter+int64(i)), uint64(counter-int64(i))) } for _, counter := range counters { - err := t.HOTP.Validate(passcode, counter, secret) + err := t.hotp.Validate(passcode, counter, secret) if err == nil { return nil } diff --git a/totp_test.go b/totp_test.go index efc052a..b300916 100644 --- a/totp_test.go +++ b/totp_test.go @@ -34,7 +34,13 @@ func TestTOTP(t *testing.T) { } for _, tc := range totpRFCTestCases { - totp, err := NewTOTP(tc.algo, Digits(8), "cristalhq", 30, 1) + totp, err := NewTOTP(TOTPConfig{ + Algo: tc.algo, + Digits: Digits(8), + Issuer: "cristalhq", + Period: 30, + Skew: 1, + }) mustOk(t, err) at := time.Unix(tc.ts, 0).UTC() @@ -48,15 +54,33 @@ func TestTOTP(t *testing.T) { } func TestNewTOTP(t *testing.T) { - _, err := NewTOTP(-1, Digits(8), "cristalhq", 30, 1) + _, err := NewTOTP(TOTPConfig{ + Algo: -1, + Digits: Digits(8), + Issuer: "cristalhq", + Period: 30, + Skew: 1, + }) mustEqual(t, err, ErrUnsupportedAlgorithm) - _, err = NewTOTP(1, Digits(8), "", 30, 1) + _, err = NewTOTP(TOTPConfig{ + Algo: 1, + Digits: Digits(8), + Issuer: "", + Period: 30, + Skew: 1, + }) mustEqual(t, err, ErrEmptyIssuer) } func TestTOTPGenerateURL(t *testing.T) { - totp, err := NewTOTP(AlgorithmSHA1, Digits(8), "cristalhq", 30, 1) + totp, err := NewTOTP(TOTPConfig{ + Algo: AlgorithmSHA1, + Digits: Digits(8), + Issuer: "cristalhq", + Period: 30, + Skew: 1, + }) mustOk(t, err) url := totp.GenerateURL("alice@bob.com", []byte("SECRET_STRING")) @@ -67,7 +91,13 @@ func TestTOTPGenerateURL(t *testing.T) { } func BenchmarkTOTP_GenerateURL(b *testing.B) { - totp, err := NewTOTP(AlgorithmSHA1, Digits(8), "cristalhq", 30, 1) + totp, err := NewTOTP(TOTPConfig{ + Algo: AlgorithmSHA1, + Digits: Digits(8), + Issuer: "cristalhq", + Period: 30, + Skew: 1, + }) mustOk(b, err) account := "otp@cristalhq.dev" @@ -85,7 +115,13 @@ func BenchmarkTOTP_GenerateURL(b *testing.B) { } func BenchmarkTOTP_GenerateCode(b *testing.B) { - totp, err := NewTOTP(AlgorithmSHA1, Digits(8), "cristalhq", 30, 1) + totp, err := NewTOTP(TOTPConfig{ + Algo: AlgorithmSHA1, + Digits: Digits(8), + Issuer: "cristalhq", + Period: 30, + Skew: 1, + }) mustOk(b, err) secret := secretSha1 @@ -102,7 +138,13 @@ func BenchmarkTOTP_GenerateCode(b *testing.B) { } func BenchmarkTOTP_Validate(b *testing.B) { - totp, err := NewTOTP(AlgorithmSHA1, Digits(8), "cristalhq", 30, 1) + totp, err := NewTOTP(TOTPConfig{ + Algo: AlgorithmSHA1, + Digits: Digits(8), + Issuer: "cristalhq", + Period: 30, + Skew: 1, + }) mustOk(b, err) secret := secretSha1