Skip to content

Commit

Permalink
User functions via isolate. Fixes TryGhost#140.
Browse files Browse the repository at this point in the history
  • Loading branch information
wbyoung committed Jun 30, 2015
1 parent 1127c27 commit f8cd018
Show file tree
Hide file tree
Showing 3 changed files with 268 additions and 0 deletions.
143 changes: 143 additions & 0 deletions src/database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
#include "database.h"
#include "statement.h"

#ifndef SQLITE_DETERMINISTIC
#define SQLITE_DETERMINISTIC 0x800
#endif

using namespace node_sqlite3;

Persistent<FunctionTemplate> Database::constructor_template;
Expand All @@ -24,6 +28,7 @@ void Database::Init(Handle<Object> target) {
NODE_SET_PROTOTYPE_METHOD(t, "serialize", Serialize);
NODE_SET_PROTOTYPE_METHOD(t, "parallelize", Parallelize);
NODE_SET_PROTOTYPE_METHOD(t, "configure", Configure);
NODE_SET_PROTOTYPE_METHOD(t, "registerFunction", RegisterFunction);

NODE_SET_GETTER(t, "open", OpenGetter);

Expand Down Expand Up @@ -356,6 +361,144 @@ NAN_METHOD(Database::Configure) {
NanReturnValue(args.This());
}

NAN_METHOD(Database::RegisterFunction) {
NanScope();
Database* db = ObjectWrap::Unwrap<Database>(args.This());

REQUIRE_ARGUMENTS(2);
REQUIRE_ARGUMENT_STRING(0, functionName);
REQUIRE_ARGUMENT_FUNCTION(1, callback);

std::string str = "(" + std::string(*String::Utf8Value(callback->ToString())) + ")";

Isolate *isolate = v8::Isolate::New();
isolate->Enter();
{
Locker locker(isolate);
Isolate::Scope isolate_scope(isolate);
HandleScope handle_scope(isolate);
Local<Context> context = Context::New(isolate);
Context::Scope context_scope(context);

Local<Object> global = NanGetCurrentContext()->Global();
Local<Function> eval = Local<Function>::Cast(global->Get(NanNew<String>("eval")));

// Local<String> str = String::Concat(String::Concat(NanNew<String>("("), callback->ToString()), NanNew<String>(")"));
Local<Value> argv[] = { NanNew<String>(str.c_str(), str.length()) };
// Local<Function> function = Local<Function>::Cast(TRY_CATCH_CALL(global, eval, 1, argv));
Local<Function> function = Local<Function>::Cast(eval->Call(global, 1, argv));

FunctionEnvironment *fn = new FunctionEnvironment(isolate, *functionName, function);
sqlite3_create_function(
db->_handle,
*functionName,
-1, // arbitrary number of args
SQLITE_UTF8 | SQLITE_DETERMINISTIC,
fn,
FunctionIsolate,
NULL,
NULL);
}
isolate->Exit();

NanReturnValue(args.This());
}

void Database::FunctionIsolate(sqlite3_context *context, int argc, sqlite3_value **argv) {
FunctionEnvironment *fn = (FunctionEnvironment *)sqlite3_user_data(context);

Isolate *isolate = fn->isolate;
isolate->Enter();
{
Locker locker(isolate);
HandleScope handle_scope(isolate);
Local<Context> v8context = Context::New(isolate);
Context::Scope context_scope(v8context);

Database::FunctionExecute(fn, context, argc, argv);
}
isolate->Exit();
}

void Database::FunctionExecute(FunctionEnvironment *fn, sqlite3_context *context, int argc, sqlite3_value **argv) {
NanScope();

Local<Function> cb = NanNew(fn->callback);
sqlite3_value **values = argv;

if (!cb.IsEmpty() && cb->IsFunction()) {

// build the argument list for the function call
typedef Local<Value> LocalValue;
std::vector<LocalValue> argv;
for (int i = 0; i < argc; i++) {
sqlite3_value *value = values[i];
int type = sqlite3_value_type(value);
Local<Value> arg;
switch(type) {
case SQLITE_INTEGER: {
arg = NanNew<Number>(sqlite3_value_int64(value));
} break;
case SQLITE_FLOAT: {
arg = NanNew<Number>(sqlite3_value_double(value));
} break;
case SQLITE_TEXT: {
const char* text = (const char*)sqlite3_value_text(value);
int length = sqlite3_value_bytes(value);
arg = NanNew<String>(text, length);
} break;
case SQLITE_BLOB: {
const void *blob = sqlite3_value_blob(value);
int length = sqlite3_value_bytes(value);
arg = NanNew(NanNewBufferHandle((char *)blob, length));
} break;
case SQLITE_NULL: {
arg = NanNew(NanNull());
} break;
}

argv.push_back(arg);
}

TryCatch trycatch;
Local<Value> result = cb->Call(NanGetCurrentContext()->Global(), argc, argv.data());

// process the result
if (trycatch.HasCaught()) {
String::Utf8Value message(trycatch.Message()->Get());
sqlite3_result_error(context, *message, message.length());
}
else if (result->IsString() || result->IsRegExp()) {
String::Utf8Value value(result->ToString());
sqlite3_result_text(context, *value, value.length(), SQLITE_TRANSIENT);
}
else if (result->IsInt32()) {
sqlite3_result_int(context, result->Int32Value());
}
else if (result->IsNumber() || result->IsDate()) {
sqlite3_result_double(context, result->NumberValue());
}
else if (result->IsBoolean()) {
sqlite3_result_int(context, result->BooleanValue());
}
else if (result->IsNull() || result->IsUndefined()) {
sqlite3_result_null(context);
}
else if (Buffer::HasInstance(result)) {
Local<Object> buffer = result->ToObject();
sqlite3_result_blob(context,
Buffer::Data(buffer),
Buffer::Length(buffer),
SQLITE_TRANSIENT);
}
else {
std::string message("invalid return type in user function");
message = message + " " + fn->name;
sqlite3_result_error(context, message.c_str(), message.length());
}
}
}

void Database::SetBusyTimeout(Baton* baton) {
assert(baton->db->open);
assert(baton->db->_handle);
Expand Down
18 changes: 18 additions & 0 deletions src/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,20 @@ class Database : public ObjectWrap {
Baton(db_, cb_), filename(filename_) {}
};

struct FunctionEnvironment {
Isolate *isolate;
std::string name;
Persistent<Function> callback;

FunctionEnvironment(Isolate* isolate_, const char* name_, Handle<Function> cb_) :
isolate(isolate_), name(name_) {
NanAssignPersistent(callback, cb_);
}
virtual ~FunctionEnvironment() {
NanDisposePersistent(callback);
}
};

typedef void (*Work_Callback)(Baton* baton);

struct Call {
Expand Down Expand Up @@ -152,6 +166,10 @@ class Database : public ObjectWrap {

static NAN_METHOD(Configure);

static NAN_METHOD(RegisterFunction);
static void FunctionIsolate(sqlite3_context *context, int argc, sqlite3_value **argv);
static void FunctionExecute(FunctionEnvironment *baton, sqlite3_context *context, int argc, sqlite3_value **argv);

static void SetBusyTimeout(Baton* baton);

static void RegisterTraceCallback(Baton* baton);
Expand Down
107 changes: 107 additions & 0 deletions test/user_functions.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
var sqlite3 = require('..');
var assert = require('assert');

describe('user functions', function() {
var db;
before(function(done) { db = new sqlite3.Database(':memory:', done); });

it('should allow registration of user functions', function() {
db.registerFunction('MY_UPPERCASE', function(value) {
return value.toUpperCase();
});
db.registerFunction('MY_STRING_JOIN', function(value1, value2) {
return [value1, value2].join(' ');
});
db.registerFunction('MY_Add', function(value1, value2) {
return value1 + value2;
});
db.registerFunction('MY_REGEX', function(regex, value) {
return !!value.match(new RegExp(regex));
});
db.registerFunction('MY_REGEX_VALUE', function(regex, value) {
return /match things/i;
});
db.registerFunction('MY_ERROR', function(value) {
throw new Error('This function always throws');
});
db.registerFunction('MY_UNHANDLED_TYPE', function(value) {
return {};
});
db.registerFunction('MY_NOTHING', function(value) {

});
});

it('should process user functions with one arg', function(done) {
db.all('SELECT MY_UPPERCASE("hello") AS txt', function(err, rows) {
if (err) throw err;
assert.equal(rows.length, 1);
assert.equal(rows[0].txt, 'HELLO')
done();
});
});

it('should process user functions with two args', function(done) {
db.all('SELECT MY_STRING_JOIN("hello", "world") AS val', function(err, rows) {
if (err) throw err;
assert.equal(rows.length, 1);
assert.equal(rows[0].val, 'hello world');
done();
});
});

it('should process user functions with number args', function(done) {
db.all('SELECT MY_ADD(1, 2) AS val', function(err, rows) {
if (err) throw err;
assert.equal(rows.length, 1);
assert.equal(rows[0].val, 3);
done();
});
});

it('allows writing of a regex function', function(done) {
db.all('SELECT MY_REGEX("colou?r", "color") AS val', function(err, rows) {
if (err) throw err;
assert.equal(rows.length, 1);
assert.equal(Boolean(rows[0].val), true);
done();
});
});

it('converts returned regex instances to strings', function(done) {
db.all('SELECT MY_REGEX_VALUE() AS val', function(err, rows) {
if (err) throw err;
assert.equal(rows.length, 1);
assert.equal(rows[0].val, '/match things/i');
done();
});
});

it('reports errors thrown in functions', function(done) {
db.all('SELECT MY_ERROR() AS val', function(err, rows) {
assert.equal(err.message, 'SQLITE_ERROR: Uncaught Error: This function always throws');
assert.equal(rows, undefined);
done();
});
});

it('reports errors for unhandled types', function(done) {
db.all('SELECT MY_UNHANDLED_TYPE() AS val', function(err, rows) {
assert.equal(err.message, 'SQLITE_ERROR: invalid return type in ' +
'user function MY_UNHANDLED_TYPE');
assert.equal(rows, undefined);
done();
});
});

it('allows no return value from functions', function(done) {
db.all('SELECT MY_NOTHING() AS val', function(err, rows) {
if (err) throw err;
assert.equal(rows.length, 1);
assert.equal(rows[0].val, undefined);
done();
});
});

after(function(done) { db.close(done); });
});

0 comments on commit f8cd018

Please sign in to comment.