From f8cd018a6ac3dae65961bb922f0afad6955a8eb9 Mon Sep 17 00:00:00 2001 From: Whitney Young Date: Tue, 30 Jun 2015 14:53:37 -0700 Subject: [PATCH] User functions via isolate. Fixes #140. --- src/database.cc | 143 ++++++++++++++++++++++++++++++++++++ src/database.h | 18 +++++ test/user_functions.test.js | 107 +++++++++++++++++++++++++++ 3 files changed, 268 insertions(+) create mode 100644 test/user_functions.test.js diff --git a/src/database.cc b/src/database.cc index d34b865a8..bcd093e13 100644 --- a/src/database.cc +++ b/src/database.cc @@ -5,6 +5,10 @@ #include "database.h" #include "statement.h" +#ifndef SQLITE_DETERMINISTIC +#define SQLITE_DETERMINISTIC 0x800 +#endif + using namespace node_sqlite3; Persistent Database::constructor_template; @@ -24,6 +28,7 @@ void Database::Init(Handle target) { NODE_SET_PROTOTYPE_METHOD(t, "serialize", Serialize); NODE_SET_PROTOTYPE_METHOD(t, "parallelize", Parallelize); NODE_SET_PROTOTYPE_METHOD(t, "configure", Configure); + NODE_SET_PROTOTYPE_METHOD(t, "registerFunction", RegisterFunction); NODE_SET_GETTER(t, "open", OpenGetter); @@ -356,6 +361,144 @@ NAN_METHOD(Database::Configure) { NanReturnValue(args.This()); } +NAN_METHOD(Database::RegisterFunction) { + NanScope(); + Database* db = ObjectWrap::Unwrap(args.This()); + + REQUIRE_ARGUMENTS(2); + REQUIRE_ARGUMENT_STRING(0, functionName); + REQUIRE_ARGUMENT_FUNCTION(1, callback); + + std::string str = "(" + std::string(*String::Utf8Value(callback->ToString())) + ")"; + + Isolate *isolate = v8::Isolate::New(); + isolate->Enter(); + { + Locker locker(isolate); + Isolate::Scope isolate_scope(isolate); + HandleScope handle_scope(isolate); + Local context = Context::New(isolate); + Context::Scope context_scope(context); + + Local global = NanGetCurrentContext()->Global(); + Local eval = Local::Cast(global->Get(NanNew("eval"))); + + // Local str = String::Concat(String::Concat(NanNew("("), callback->ToString()), NanNew(")")); + Local argv[] = { NanNew(str.c_str(), str.length()) }; + // Local function = Local::Cast(TRY_CATCH_CALL(global, eval, 1, argv)); + Local function = Local::Cast(eval->Call(global, 1, argv)); + + FunctionEnvironment *fn = new FunctionEnvironment(isolate, *functionName, function); + sqlite3_create_function( + db->_handle, + *functionName, + -1, // arbitrary number of args + SQLITE_UTF8 | SQLITE_DETERMINISTIC, + fn, + FunctionIsolate, + NULL, + NULL); + } + isolate->Exit(); + + NanReturnValue(args.This()); +} + +void Database::FunctionIsolate(sqlite3_context *context, int argc, sqlite3_value **argv) { + FunctionEnvironment *fn = (FunctionEnvironment *)sqlite3_user_data(context); + + Isolate *isolate = fn->isolate; + isolate->Enter(); + { + Locker locker(isolate); + HandleScope handle_scope(isolate); + Local v8context = Context::New(isolate); + Context::Scope context_scope(v8context); + + Database::FunctionExecute(fn, context, argc, argv); + } + isolate->Exit(); +} + +void Database::FunctionExecute(FunctionEnvironment *fn, sqlite3_context *context, int argc, sqlite3_value **argv) { + NanScope(); + + Local cb = NanNew(fn->callback); + sqlite3_value **values = argv; + + if (!cb.IsEmpty() && cb->IsFunction()) { + + // build the argument list for the function call + typedef Local LocalValue; + std::vector argv; + for (int i = 0; i < argc; i++) { + sqlite3_value *value = values[i]; + int type = sqlite3_value_type(value); + Local arg; + switch(type) { + case SQLITE_INTEGER: { + arg = NanNew(sqlite3_value_int64(value)); + } break; + case SQLITE_FLOAT: { + arg = NanNew(sqlite3_value_double(value)); + } break; + case SQLITE_TEXT: { + const char* text = (const char*)sqlite3_value_text(value); + int length = sqlite3_value_bytes(value); + arg = NanNew(text, length); + } break; + case SQLITE_BLOB: { + const void *blob = sqlite3_value_blob(value); + int length = sqlite3_value_bytes(value); + arg = NanNew(NanNewBufferHandle((char *)blob, length)); + } break; + case SQLITE_NULL: { + arg = NanNew(NanNull()); + } break; + } + + argv.push_back(arg); + } + + TryCatch trycatch; + Local result = cb->Call(NanGetCurrentContext()->Global(), argc, argv.data()); + + // process the result + if (trycatch.HasCaught()) { + String::Utf8Value message(trycatch.Message()->Get()); + sqlite3_result_error(context, *message, message.length()); + } + else if (result->IsString() || result->IsRegExp()) { + String::Utf8Value value(result->ToString()); + sqlite3_result_text(context, *value, value.length(), SQLITE_TRANSIENT); + } + else if (result->IsInt32()) { + sqlite3_result_int(context, result->Int32Value()); + } + else if (result->IsNumber() || result->IsDate()) { + sqlite3_result_double(context, result->NumberValue()); + } + else if (result->IsBoolean()) { + sqlite3_result_int(context, result->BooleanValue()); + } + else if (result->IsNull() || result->IsUndefined()) { + sqlite3_result_null(context); + } + else if (Buffer::HasInstance(result)) { + Local buffer = result->ToObject(); + sqlite3_result_blob(context, + Buffer::Data(buffer), + Buffer::Length(buffer), + SQLITE_TRANSIENT); + } + else { + std::string message("invalid return type in user function"); + message = message + " " + fn->name; + sqlite3_result_error(context, message.c_str(), message.length()); + } + } +} + void Database::SetBusyTimeout(Baton* baton) { assert(baton->db->open); assert(baton->db->_handle); diff --git a/src/database.h b/src/database.h index af83ee715..9111aab33 100644 --- a/src/database.h +++ b/src/database.h @@ -69,6 +69,20 @@ class Database : public ObjectWrap { Baton(db_, cb_), filename(filename_) {} }; + struct FunctionEnvironment { + Isolate *isolate; + std::string name; + Persistent callback; + + FunctionEnvironment(Isolate* isolate_, const char* name_, Handle cb_) : + isolate(isolate_), name(name_) { + NanAssignPersistent(callback, cb_); + } + virtual ~FunctionEnvironment() { + NanDisposePersistent(callback); + } + }; + typedef void (*Work_Callback)(Baton* baton); struct Call { @@ -152,6 +166,10 @@ class Database : public ObjectWrap { static NAN_METHOD(Configure); + static NAN_METHOD(RegisterFunction); + static void FunctionIsolate(sqlite3_context *context, int argc, sqlite3_value **argv); + static void FunctionExecute(FunctionEnvironment *baton, sqlite3_context *context, int argc, sqlite3_value **argv); + static void SetBusyTimeout(Baton* baton); static void RegisterTraceCallback(Baton* baton); diff --git a/test/user_functions.test.js b/test/user_functions.test.js new file mode 100644 index 000000000..a19eec66c --- /dev/null +++ b/test/user_functions.test.js @@ -0,0 +1,107 @@ +var sqlite3 = require('..'); +var assert = require('assert'); + +describe('user functions', function() { + var db; + before(function(done) { db = new sqlite3.Database(':memory:', done); }); + + it('should allow registration of user functions', function() { + db.registerFunction('MY_UPPERCASE', function(value) { + return value.toUpperCase(); + }); + db.registerFunction('MY_STRING_JOIN', function(value1, value2) { + return [value1, value2].join(' '); + }); + db.registerFunction('MY_Add', function(value1, value2) { + return value1 + value2; + }); + db.registerFunction('MY_REGEX', function(regex, value) { + return !!value.match(new RegExp(regex)); + }); + db.registerFunction('MY_REGEX_VALUE', function(regex, value) { + return /match things/i; + }); + db.registerFunction('MY_ERROR', function(value) { + throw new Error('This function always throws'); + }); + db.registerFunction('MY_UNHANDLED_TYPE', function(value) { + return {}; + }); + db.registerFunction('MY_NOTHING', function(value) { + + }); + }); + + it('should process user functions with one arg', function(done) { + db.all('SELECT MY_UPPERCASE("hello") AS txt', function(err, rows) { + if (err) throw err; + assert.equal(rows.length, 1); + assert.equal(rows[0].txt, 'HELLO') + done(); + }); + }); + + it('should process user functions with two args', function(done) { + db.all('SELECT MY_STRING_JOIN("hello", "world") AS val', function(err, rows) { + if (err) throw err; + assert.equal(rows.length, 1); + assert.equal(rows[0].val, 'hello world'); + done(); + }); + }); + + it('should process user functions with number args', function(done) { + db.all('SELECT MY_ADD(1, 2) AS val', function(err, rows) { + if (err) throw err; + assert.equal(rows.length, 1); + assert.equal(rows[0].val, 3); + done(); + }); + }); + + it('allows writing of a regex function', function(done) { + db.all('SELECT MY_REGEX("colou?r", "color") AS val', function(err, rows) { + if (err) throw err; + assert.equal(rows.length, 1); + assert.equal(Boolean(rows[0].val), true); + done(); + }); + }); + + it('converts returned regex instances to strings', function(done) { + db.all('SELECT MY_REGEX_VALUE() AS val', function(err, rows) { + if (err) throw err; + assert.equal(rows.length, 1); + assert.equal(rows[0].val, '/match things/i'); + done(); + }); + }); + + it('reports errors thrown in functions', function(done) { + db.all('SELECT MY_ERROR() AS val', function(err, rows) { + assert.equal(err.message, 'SQLITE_ERROR: Uncaught Error: This function always throws'); + assert.equal(rows, undefined); + done(); + }); + }); + + it('reports errors for unhandled types', function(done) { + db.all('SELECT MY_UNHANDLED_TYPE() AS val', function(err, rows) { + assert.equal(err.message, 'SQLITE_ERROR: invalid return type in ' + + 'user function MY_UNHANDLED_TYPE'); + assert.equal(rows, undefined); + done(); + }); + }); + + it('allows no return value from functions', function(done) { + db.all('SELECT MY_NOTHING() AS val', function(err, rows) { + if (err) throw err; + assert.equal(rows.length, 1); + assert.equal(rows[0].val, undefined); + done(); + }); + }); + + after(function(done) { db.close(done); }); +});