diff --git a/lightningd/plugin_hook.c b/lightningd/plugin_hook.c index e1fd750a307f..271389c49dee 100644 --- a/lightningd/plugin_hook.c +++ b/lightningd/plugin_hook.c @@ -156,13 +156,14 @@ static void db_hook_response(const char *buffer, const jsmntok_t *toks, io_break(ph_req); } -void plugin_hook_db_sync(struct db *db, const char **changes, const char *final) +void plugin_hook_db_sync(struct db *db) { const struct plugin_hook *hook = &db_write_hook; struct jsonrpc_request *req; struct plugin_hook_request *ph_req; void *ret; + const char **changes = db_changes(db); if (!hook->plugin) return; @@ -174,11 +175,11 @@ void plugin_hook_db_sync(struct db *db, const char **changes, const char *final) ph_req->hook = hook; ph_req->db = db; + json_add_num(req->stream, "data_version", db_data_version_get(db)); + json_array_start(req->stream, "writes"); for (size_t i = 0; i < tal_count(changes); i++) json_add_string(req->stream, NULL, changes[i]); - if (final) - json_add_string(req->stream, NULL, final); json_array_end(req->stream); jsonrpc_request_end(req); diff --git a/lightningd/plugin_hook.h b/lightningd/plugin_hook.h index 4869d69e9af3..d5ed4d659711 100644 --- a/lightningd/plugin_hook.h +++ b/lightningd/plugin_hook.h @@ -108,8 +108,7 @@ bool plugin_hook_unregister(struct plugin *plugin, const char *method); /* Unregister all hooks a plugin has registered for */ void plugin_hook_unregister_all(struct plugin *plugin); -/* Special sync plugin hook for db: changes[] are SQL statements, with optional - * final command appended. */ -void plugin_hook_db_sync(struct db *db, const char **changes, const char *final); +/* Special sync plugin hook for db. */ +void plugin_hook_db_sync(struct db *db); #endif /* LIGHTNING_LIGHTNINGD_PLUGIN_HOOK_H */ diff --git a/tests/test_db.py b/tests/test_db.py index ee942eed44cc..c1e0c4edc4ad 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -1,8 +1,10 @@ from fixtures import * # noqa: F401,F403 -from utils import wait_for, sync_blockheight, COMPAT from fixtures import TEST_NETWORK - +from pyln.client import RpcError +from utils import wait_for, sync_blockheight, COMPAT import os +import pytest +import time import unittest @@ -136,3 +138,23 @@ def test_scid_upgrade(node_factory, bitcoind): assert l1.db_query('SELECT short_channel_id from channels;') == [{'short_channel_id': '103x1x1'}] assert l1.db_query('SELECT failchannel from payments;') == [{'failchannel': '103x1x1'}] + + +def test_optimistic_locking(node_factory, bitcoind): + """Have a node run against a DB, then change it under its feet, crashing it. + + We start a node, wait for it to settle its write so we have a window where + we can interfere, and watch the world burn (safely). + """ + l1 = node_factory.get_node(may_fail=True, allow_broken_log=True) + + sync_blockheight(bitcoind, [l1]) + l1.rpc.getinfo() + time.sleep(1) + l1.db.execute("UPDATE vars SET intval = intval + 1 WHERE name = 'data_version';") + + # Now trigger any DB write and we should be crashing. + with pytest.raises(RpcError, match=r'Connection to RPC server lost.'): + l1.rpc.newaddr() + + assert(l1.daemon.is_in_log(r'Optimistic lock on the database failed')) diff --git a/wallet/db.c b/wallet/db.c index c6ed48b3fbaf..edfe3f4475e5 100644 --- a/wallet/db.c +++ b/wallet/db.c @@ -589,6 +589,7 @@ static struct migration dbmigrations[] = { " SELECT id, 11, local_feerate_per_kw FROM channels WHERE funder = 1 and local_feerate_per_kw != remote_feerate_per_kw;"), NULL}, /* FIXME: Remove now-unused local_feerate_per_kw and remote_feerate_per_kw from channels */ + {SQL("INSERT INTO vars (name, intval) VALUES ('data_version', 0);"), NULL}, }; /* Leak tracking. */ @@ -763,8 +764,13 @@ static void db_report_changes(struct db *db, const char *final, size_t min) assert(db->changes); assert(tal_count(db->changes) >= min); + /* Having changes implies that we have a dirty TX. The opposite is + * currently not true, e.g., the postgres driver doesn't record + * changes yet. */ + assert(!tal_count(db->changes) || db->dirty); + if (tal_count(db->changes) > min) - plugin_hook_db_sync(db, db->changes, final); + plugin_hook_db_sync(db); db->changes = tal_free(db->changes); } @@ -785,6 +791,9 @@ void db_begin_transaction_(struct db *db, const char *location) if (db->in_transaction) db_fatal("Already in transaction from %s", db->in_transaction); + /* No writes yet. */ + db->dirty = false; + db_prepare_for_changes(db); ok = db->config->begin_tx_fn(db); if (!ok) @@ -793,11 +802,44 @@ void db_begin_transaction_(struct db *db, const char *location) db->in_transaction = location; } +/* By making the update conditional on the current value we expect we + * are implementing an optimistic lock: if the update results in + * changes on the DB we know that the data_version did not change + * under our feet and no other transaction ran in the meantime. + * + * Notice that this update effectively locks the row, so that other + * operations attempting to change this outside the transaction will + * wait for this transaction to complete. The external change will + * ultimately fail the changes test below, it'll just delay its abort + * until our transaction is committed. + */ +static void db_data_version_incr(struct db *db) +{ + struct db_stmt *stmt = db_prepare_v2( + db, SQL("UPDATE vars " + "SET intval = intval + 1 " + "WHERE name = 'data_version'" + " AND intval = ?")); + db_bind_int(stmt, 0, db->data_version); + db_exec_prepared_v2(stmt); + if (db_count_changes(stmt) != 1) + fatal("Optimistic lock on the database failed. There may be a " + "concurrent access to the database. Aborting since " + "concurrent access is unsafe."); + tal_free(stmt); + db->data_version++; +} + void db_commit_transaction(struct db *db) { bool ok; assert(db->in_transaction); db_assert_no_outstanding_statements(db); + + /* Increment before reporting changes to an eventual plugin. */ + if (db->dirty) + db_data_version_incr(db); + db_report_changes(db, NULL, 0); ok = db->config->commit_tx_fn(db); @@ -805,6 +847,7 @@ void db_commit_transaction(struct db *db) db_fatal("Failed to commit DB transaction: %s", db->error); db->in_transaction = NULL; + db->dirty = false; } static struct db_config *db_config_find(const char *dsn) @@ -905,8 +948,6 @@ static void db_migrate(struct lightningd *ld, struct db *db) int current, orig, available; struct db_stmt *stmt; - db_begin_transaction(db); - orig = current = db_get_version(db); available = ARRAY_SIZE(dbmigrations) - 1; @@ -946,15 +987,31 @@ static void db_migrate(struct lightningd *ld, struct db *db) db_exec_prepared_v2(stmt); tal_free(stmt); } +} - db_commit_transaction(db); +u32 db_data_version_get(struct db *db) +{ + struct db_stmt *stmt; + u32 version; + stmt = db_prepare_v2(db, SQL("SELECT intval FROM vars WHERE name = 'data_version'")); + db_query_prepared(stmt); + db_step(stmt); + version = db_column_int(stmt, 0); + tal_free(stmt); + return version; } struct db *db_setup(const tal_t *ctx, struct lightningd *ld) { struct db *db = db_open(ctx, ld->wallet_dsn); db->log = new_log(db, ld->log_book, NULL, "database"); + + db_begin_transaction(db); + db_migrate(ld, db); + + db->data_version = db_data_version_get(db); + db_commit_transaction(db); return db; } @@ -1355,6 +1412,10 @@ void db_column_txid(struct db_stmt *stmt, int pos, struct bitcoin_txid *t) bool db_exec_prepared_v2(struct db_stmt *stmt TAKES) { bool ret = stmt->db->config->exec_fn(stmt); + + /* If this was a write we need to bump the data_version upon commit. */ + stmt->db->dirty = stmt->db->dirty || !stmt->query->readonly; + stmt->executed = true; list_del_from(&stmt->db->pending_statements, &stmt->list); @@ -1399,3 +1460,8 @@ void db_changes_add(struct db_stmt *stmt, const char * expanded) tal_arr_expand(&db->changes, tal_strdup(db->changes, expanded)); } + +const char **db_changes(struct db *db) +{ + return db->changes; +} diff --git a/wallet/db.h b/wallet/db.h index 2a0c2b6565cb..5f352be2b5c6 100644 --- a/wallet/db.h +++ b/wallet/db.h @@ -224,4 +224,12 @@ struct db_stmt *db_prepare_v2_(const char *location, struct db *db, #define db_prepare_v2(db,query) \ db_prepare_v2_(__FILE__ ":" stringify(__LINE__), db, query) +/** + * Access pending changes that have been added to the current transaction. + */ +const char **db_changes(struct db *db); + +/* Get the current data version. */ +u32 db_data_version_get(struct db *db); + #endif /* LIGHTNING_WALLET_DB_H */ diff --git a/wallet/db_common.h b/wallet/db_common.h index 0fa1a09ed9b9..69fdd9da028b 100644 --- a/wallet/db_common.h +++ b/wallet/db_common.h @@ -30,6 +30,14 @@ struct db { char *error; struct log *log; + + /* Were there any modifying statements in the current transaction? + * Used to bump the data_version in the DB.*/ + bool dirty; + + /* The current DB version we expect to update if changes are + * committed. */ + u32 data_version; }; struct db_query { diff --git a/wallet/test/run-db.c b/wallet/test/run-db.c index a6859bddcd7c..f45be567557d 100644 --- a/wallet/test/run-db.c +++ b/wallet/test/run-db.c @@ -47,7 +47,7 @@ static void db_test_fatal(const char *fmt, ...) va_end(ap); } -void plugin_hook_db_sync(struct db *db UNNEEDED, const char **changes UNNEEDED, const char *final UNNEEDED) +void plugin_hook_db_sync(struct db *db UNNEEDED) { } @@ -63,6 +63,7 @@ static struct db *create_test_db(void) dsn = tal_fmt(NULL, "sqlite3://%s", filename); db = db_open(NULL, dsn); + db->data_version = 0; tal_free(dsn); return db; } @@ -73,8 +74,9 @@ static bool test_empty_db_migrate(struct lightningd *ld) CHECK(db); db_begin_transaction(db); CHECK(db_get_version(db) == -1); - db_commit_transaction(db); db_migrate(ld, db); + db_commit_transaction(db); + db_begin_transaction(db); CHECK(db_get_version(db) == ARRAY_SIZE(dbmigrations) - 1); db_commit_transaction(db); @@ -106,6 +108,10 @@ static bool test_primitives(void) CHECK_MSG(db_err, "Failing SQL command"); tal_free(stmt); db_err = tal_free(db_err); + + /* We didn't migrate the DB, so don't have the vars table. Pretend we + * didn't change anything so we don't bump the data_version. */ + db->dirty = false; db_commit_transaction(db); CHECK(!db->in_transaction); tal_free(db); @@ -118,9 +124,9 @@ static bool test_vars(struct lightningd *ld) struct db *db = create_test_db(); char *varname = "testvar"; CHECK(db); - db_migrate(ld, db); db_begin_transaction(db); + db_migrate(ld, db); /* Check default behavior */ CHECK(db_get_intvar(db, varname, 42) == 42); diff --git a/wallet/test/run-wallet.c b/wallet/test/run-wallet.c index 02f8fe95c577..7928c3e5592f 100644 --- a/wallet/test/run-wallet.c +++ b/wallet/test/run-wallet.c @@ -650,7 +650,7 @@ u8 *wire_sync_read(const tal_t *ctx UNNEEDED, int fd UNNEEDED) { return NULL; } -void plugin_hook_db_sync(struct db *db UNNEEDED, const char **changes UNNEEDED, const char *final UNNEEDED) +void plugin_hook_db_sync(struct db *db UNNEEDED) { } bool fromwire_hsm_get_channel_basepoints_reply(const void *p UNNEEDED, @@ -747,7 +747,10 @@ static struct wallet *create_test_wallet(struct lightningd *ld, const tal_t *ctx w->bip32_base) == WALLY_OK); CHECK_MSG(w->db, "Failed opening the db"); + db_begin_transaction(w->db); db_migrate(ld, w->db); + w->db->data_version = 0; + db_commit_transaction(w->db); CHECK_MSG(!wallet_err, "DB migration failed"); w->max_channel_dbid = 0;