Skip to content

Commit

Permalink
Extract out db.QuestionMarks function (#6568)
Browse files Browse the repository at this point in the history
We use this pattern in several places: there is a query that needs to
have a variable number of placeholders (question marks) in it, depending
on how many items we are inserting or querying for. For instance, when
issuing a precertificate we add that precertificate's names to the
"issuedNames" table. To make things more efficient, we do that in a
single query, whether there is one name on the certificate or a hundred.
That means interpolating into the query string with series of question
marks that matches the number of names.

We have a helper type MultiInserter that solves this problem for simple
inserts, but it does not solve the problem for selects or more complex
inserts, and we still have a number of places that generate their
sequence of question marks manually.

This change updates addIssuedNames to use MultiInserter. To enable that,
it also narrows the interface required by MultiInserter.Insert, so it's
easier to mock in tests.

This change adds the new function db.QuestionMarks, which generates e.g.
`?,?,?` depending on the input N.

In a few places I had to rename a function parameter named `db` to avoid
shadowing the `db` package.
  • Loading branch information
jsha authored Jan 10, 2023
1 parent 1e7c64e commit 4be76af
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 56 deletions.
4 changes: 1 addition & 3 deletions cmd/expiration-mailer/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,16 +212,14 @@ func (m *mailer) updateLastNagTimestamps(ctx context.Context, certs []*x509.Cert

// updateLastNagTimestampsChunk processes a single chunk (up to 65k) of certificates.
func (m *mailer) updateLastNagTimestampsChunk(ctx context.Context, certs []*x509.Certificate) {
qmarks := make([]string, len(certs))
params := make([]interface{}, len(certs)+1)
for i, cert := range certs {
qmarks[i] = "?"
params[i+1] = core.SerialToString(cert.SerialNumber)
}

query := fmt.Sprintf(
"UPDATE certificateStatus SET lastExpirationNagSent = ? WHERE serial IN (%s)",
strings.Join(qmarks, ","),
db.QuestionMarks(len(certs)),
)
params[0] = m.clk.Now()

Expand Down
8 changes: 8 additions & 0 deletions db/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@ type Executor interface {
Query(string, ...interface{}) (*sql.Rows, error)
}

// Queryer offers the Query method. Note that this is not read-only (i.e. not
// Selector), since a Query can be `INSERT`, `UPDATE`, etc. The difference
// between Query and Exec is that Query can return rows. So for instance it is
// suitable for inserting rows and getting back ids.
type Queryer interface {
Query(string, ...interface{}) (*sql.Rows, error)
}

// Transaction extends an Executor and adds Rollback, Commit, and WithContext.
type Transaction interface {
Executor
Expand Down
39 changes: 23 additions & 16 deletions db/multi.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,23 @@ type MultiInserter struct {
}

// NewMultiInserter creates a new MultiInserter, checking for reasonable table
// name and list of fields.
func NewMultiInserter(table string, fields string, retCol string) (*MultiInserter, error) {
// name and list of fields. returningColumn is the name of a column to be used
// in a `RETURNING xyz` clause at the end. If it is empty, no `RETURNING xyz`
// clause is used. If returningColumn is present, it must refer to a column
// that can be parsed into an int64.
func NewMultiInserter(table string, fields string, returningColumn string) (*MultiInserter, error) {
numFields := len(strings.Split(fields, ","))
if len(table) == 0 || len(fields) == 0 || numFields == 0 {
return nil, fmt.Errorf("empty table name or fields list")
}
if strings.Contains(retCol, ",") {
return nil, fmt.Errorf("return column must be singular, but got %q", retCol)
if strings.Contains(returningColumn, ",") {
return nil, fmt.Errorf("return column must be singular, but got %q", returningColumn)
}

return &MultiInserter{
table: table,
fields: fields,
retCol: retCol,
retCol: returningColumn,
numFields: numFields,
values: make([][]interface{}, 0),
}, nil
Expand All @@ -50,12 +53,10 @@ func (mi *MultiInserter) Add(row []interface{}) error {
// for gorp to use in place of the query's question marks. Currently only
// used by .Insert(), below.
func (mi *MultiInserter) query() (string, []interface{}) {
questionsRow := strings.TrimRight(strings.Repeat("?,", mi.numFields), ",")

var questionsBuf strings.Builder
var queryArgs []interface{}
for _, row := range mi.values {
fmt.Fprintf(&questionsBuf, "(%s),", questionsRow)
fmt.Fprintf(&questionsBuf, "(%s),", QuestionMarks(mi.numFields))
queryArgs = append(queryArgs, row...)
}

Expand All @@ -71,12 +72,12 @@ func (mi *MultiInserter) query() (string, []interface{}) {
}

// Insert performs the action represented by .query() on the provided database,
// which is assumed to already have a context attached. If a non-empty retCol
// was provided, then it returns the list of values from that column returned
// by the query.
func (mi *MultiInserter) Insert(exec Executor) ([]int64, error) {
// which is assumed to already have a context attached. If a non-empty
// returningColumn was provided, then it returns the list of values from that
// column returned by the query.
func (mi *MultiInserter) Insert(queryer Queryer) ([]int64, error) {
query, queryArgs := mi.query()
rows, err := exec.Query(query, queryArgs...)
rows, err := queryer.Query(query, queryArgs...)
if err != nil {
return nil, err
}
Expand All @@ -94,9 +95,15 @@ func (mi *MultiInserter) Insert(exec Executor) ([]int64, error) {
}
}

err = rows.Close()
if err != nil {
return nil, err
// Hack: sometimes in unittests we make a mock Queryer that returns a nil
// `*sql.Rows`. A nil `*sql.Rows` is not actually valid— calling `Close()`
// on it will panic— but here we choose to treat it like an empty list,
// and skip calling `Close()` to avoid the panic.
if rows != nil {
err = rows.Close()
if err != nil {
return nil, err
}
}

return ids, nil
Expand Down
21 changes: 21 additions & 0 deletions db/qmarks.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package db

import "strings"

// QuestionMarks returns a string consisting of N question marks, joined by
// commas. If n is <= 0, panics.
func QuestionMarks(n int) string {
if n <= 0 {
panic("db.QuestionMarks called with n <=0")
}
var qmarks strings.Builder
qmarks.Grow(2 * n)
for i := 0; i < n; i++ {
if i == 0 {
qmarks.WriteString("?")
} else {
qmarks.WriteString(",?")
}
}
return qmarks.String()
}
19 changes: 19 additions & 0 deletions db/qmarks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package db

import (
"testing"

"github.com/letsencrypt/boulder/test"
)

func TestQuestionMarks(t *testing.T) {
test.AssertEquals(t, QuestionMarks(1), "?")
test.AssertEquals(t, QuestionMarks(2), "?,?")
test.AssertEquals(t, QuestionMarks(3), "?,?,?")
}

func TestQuestionMarksPanic(t *testing.T) {
defer func() { recover() }()
QuestionMarks(0)
t.Errorf("calling QuestionMarks(0) did not panic as expected")
}
26 changes: 15 additions & 11 deletions sa/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -765,22 +765,27 @@ func deleteOrderFQDNSet(
return nil
}

func addIssuedNames(db db.Execer, cert *x509.Certificate, isRenewal bool) error {
func addIssuedNames(queryer db.Queryer, cert *x509.Certificate, isRenewal bool) error {
if len(cert.DNSNames) == 0 {
return berrors.InternalServerError("certificate has no DNSNames")
}
var qmarks []string
var values []interface{}

multiInserter, err := db.NewMultiInserter("issuedNames", "reversedName, serial, notBefore, renewal", "")
if err != nil {
return err
}
for _, name := range cert.DNSNames {
values = append(values,
err = multiInserter.Add([]interface{}{
ReverseName(name),
core.SerialToString(cert.SerialNumber),
cert.NotBefore,
isRenewal)
qmarks = append(qmarks, "(?, ?, ?, ?)")
isRenewal,
})
if err != nil {
return err
}
}
query := `INSERT INTO issuedNames (reversedName, serial, notBefore, renewal) VALUES ` + strings.Join(qmarks, ", ") + `;`
_, err := db.Exec(query, values...)
_, err = multiInserter.Insert(queryer)
return err
}

Expand Down Expand Up @@ -932,10 +937,8 @@ type authzValidity struct {
// status and expiration date of each of them. It assumes that the provided
// database selector already has a context associated with it.
func getAuthorizationStatuses(s db.Selector, ids []int64) ([]authzValidity, error) {
var qmarks []string
var params []interface{}
for _, id := range ids {
qmarks = append(qmarks, "?")
params = append(params, id)
}
var validityInfo []struct {
Expand All @@ -944,7 +947,8 @@ func getAuthorizationStatuses(s db.Selector, ids []int64) ([]authzValidity, erro
}
_, err := s.Select(
&validityInfo,
fmt.Sprintf("SELECT status, expires FROM authz2 WHERE id IN (%s)", strings.Join(qmarks, ",")),
fmt.Sprintf("SELECT status, expires FROM authz2 WHERE id IN (%s)",
db.QuestionMarks(len(ids))),
params...,
)
if err != nil {
Expand Down
11 changes: 5 additions & 6 deletions sa/sa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ func createPendingAuthorization(t *testing.T, sa *SQLStorageAuthority, domain st
err = sa.dbMap.Insert(&am)
test.AssertNotError(t, err, "creating test authorization")

t.Log(am.ID)
return am.ID
}

Expand Down Expand Up @@ -717,12 +716,12 @@ func TestFQDNSetsExists(t *testing.T) {
test.Assert(t, exists.Exists, "FQDN set does exist")
}

type execRecorder struct {
type queryRecorder struct {
query string
args []interface{}
}

func (e *execRecorder) Exec(query string, args ...interface{}) (sql.Result, error) {
func (e *queryRecorder) Query(query string, args ...interface{}) (*sql.Rows, error) {
e.query = query
e.args = args
return nil, nil
Expand All @@ -732,7 +731,7 @@ func TestAddIssuedNames(t *testing.T) {
serial := big.NewInt(1)
expectedSerial := "000000000000000000000000000000000001"
notBefore := time.Date(2018, 2, 14, 12, 0, 0, 0, time.UTC)
placeholdersPerName := "(?, ?, ?, ?)"
placeholdersPerName := "(?,?,?,?)"
baseQuery := "INSERT INTO issuedNames (reversedName, serial, notBefore, renewal) VALUES"

testCases := []struct {
Expand Down Expand Up @@ -807,7 +806,7 @@ func TestAddIssuedNames(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
var e execRecorder
var e queryRecorder
err := addIssuedNames(
&e,
&x509.Certificate{
Expand All @@ -819,7 +818,7 @@ func TestAddIssuedNames(t *testing.T) {
test.AssertNotError(t, err, "addIssuedNames failed")
expectedPlaceholders := placeholdersPerName
for i := 0; i < len(tc.IssuedNames)-1; i++ {
expectedPlaceholders = fmt.Sprintf("%s, %s", expectedPlaceholders, placeholdersPerName)
expectedPlaceholders = fmt.Sprintf("%s,%s", expectedPlaceholders, placeholdersPerName)
}
expectedQuery := fmt.Sprintf("%s %s;", baseQuery, expectedPlaceholders)
test.AssertEquals(t, e.query, expectedQuery)
Expand Down
39 changes: 19 additions & 20 deletions sa/saro.go
Original file line number Diff line number Diff line change
Expand Up @@ -760,10 +760,8 @@ func (ssa *SQLStorageAuthorityRO) GetAuthorizations2(ctx context.Context, req *s
identifierTypeToUint[string(identifier.DNS)],
}

qmarks := make([]string, len(req.Domains))
for i, n := range req.Domains {
qmarks[i] = "?"
params = append(params, n)
for _, name := range req.Domains {
params = append(params, name)
}

query := fmt.Sprintf(
Expand All @@ -775,7 +773,7 @@ func (ssa *SQLStorageAuthorityRO) GetAuthorizations2(ctx context.Context, req *s
identifierType = ? AND
identifierValue IN (%s)`,
authzFields,
strings.Join(qmarks, ","),
db.QuestionMarks(len(req.Domains)),
)

_, err := ssa.dbReadOnlyMap.Select(
Expand Down Expand Up @@ -965,30 +963,31 @@ func (ssa *SQLStorageAuthorityRO) GetValidAuthorizations2(ctx context.Context, r
return nil, errIncompleteRequest
}

var authzModels []authzModel
query := fmt.Sprintf(
`SELECT %s FROM authz2 WHERE
registrationID = ? AND
status = ? AND
expires > ? AND
identifierType = ? AND
identifierValue IN (%s)`,
authzFields,
db.QuestionMarks(len(req.Domains)),
)

params := []interface{}{
req.RegistrationID,
statusUint(core.StatusValid),
time.Unix(0, req.Now),
identifierTypeToUint[string(identifier.DNS)],
}
qmarks := make([]string, len(req.Domains))
for i, n := range req.Domains {
qmarks[i] = "?"
params = append(params, n)
for _, domain := range req.Domains {
params = append(params, domain)
}

var authzModels []authzModel
_, err := ssa.dbReadOnlyMap.Select(
&authzModels,
fmt.Sprintf(
`SELECT %s FROM authz2 WHERE
registrationID = ? AND
status = ? AND
expires > ? AND
identifierType = ? AND
identifierValue IN (%s)`,
authzFields,
strings.Join(qmarks, ","),
),
query,
params...,
)
if err != nil {
Expand Down

0 comments on commit 4be76af

Please sign in to comment.