Skip to content

Commit

Permalink
Adds support for scripting engines as Valkey modules
Browse files Browse the repository at this point in the history
This commit extends the module API to support the addition of different
scripting engines to run user defined functions.

The scripting engine can be implemented as a Valkey module, and can be
dynamically loaded with the `loadmodule` config directive, or with
the `MODULE LOAD` command.

This commit also adds an example of a dummy scripting engine module,
to show how to use the new module API.

The current module API support, only allows to load scripting engines to
run functions using `FCALL` command.

In a follow up PR, we will move the Lua scripting engine implmentation
into its own module.

Signed-off-by: Ricardo Dias <ricardo.dias@percona.com>
  • Loading branch information
rjd15372 committed Nov 8, 2024
1 parent 45d596e commit d3293dd
Show file tree
Hide file tree
Showing 11 changed files with 491 additions and 126 deletions.
2 changes: 1 addition & 1 deletion src/aof.c
Original file line number Diff line number Diff line change
Expand Up @@ -2175,7 +2175,7 @@ static int rewriteFunctions(rio *aof) {
dictIterator *iter = dictGetIterator(functions);
dictEntry *entry = NULL;
while ((entry = dictNext(iter))) {
functionLibInfo *li = dictGetVal(entry);
ValkeyModuleScriptingEngineFunctionLibrary *li = dictGetVal(entry);
if (rioWrite(aof, "*3\r\n", 4) == 0) goto werr;
char function_load[] = "$8\r\nFUNCTION\r\n$4\r\nLOAD\r\n";
if (rioWrite(aof, function_load, sizeof(function_load) - 1) == 0) goto werr;
Expand Down
15 changes: 8 additions & 7 deletions src/function_lua.c
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ typedef struct luaFunctionCtx {
} luaFunctionCtx;

typedef struct loadCtx {
functionLibInfo *li;
ValkeyModuleScriptingEngineFunctionLibrary *li;
monotime start_time;
size_t timeout;
} loadCtx;
Expand Down Expand Up @@ -100,7 +100,7 @@ static void luaEngineLoadHook(lua_State *lua, lua_Debug *ar) {
*
* Return NULL on compilation error and set the error to the err variable
*/
static int luaEngineCreate(void *engine_ctx, functionLibInfo *li, sds blob, size_t timeout, sds *err) {
static int luaEngineCreate(void *engine_ctx, ValkeyModuleScriptingEngineFunctionLibrary *li, const char *blob, size_t timeout, char **err) {
int ret = C_ERR;
luaEngineCtx *lua_engine_ctx = engine_ctx;
lua_State *lua = lua_engine_ctx->lua;
Expand All @@ -114,7 +114,7 @@ static int luaEngineCreate(void *engine_ctx, functionLibInfo *li, sds blob, size
lua_pop(lua, 1); /* pop the metatable */

/* compile the code */
if (luaL_loadbuffer(lua, blob, sdslen(blob), "@user_function")) {
if (luaL_loadbuffer(lua, blob, strlen(blob), "@user_function")) {
*err = sdscatprintf(sdsempty(), "Error compiling function: %s", lua_tostring(lua, -1));
lua_pop(lua, 1); /* pops the error */
goto done;
Expand Down Expand Up @@ -158,7 +158,7 @@ static int luaEngineCreate(void *engine_ctx, functionLibInfo *li, sds blob, size
/*
* Invole the give function with the given keys and args
*/
static void luaEngineCall(scriptRunCtx *run_ctx,
static void luaEngineCall(ValkeyModuleEngineFunctionCallCtx *func_ctx,
void *engine_ctx,
void *compiled_function,
robj **keys,
Expand All @@ -177,6 +177,7 @@ static void luaEngineCall(scriptRunCtx *run_ctx,

serverAssert(lua_isfunction(lua, -1));

scriptRunCtx *run_ctx = moduleGetScriptRunCtxFromFunctionCtx(func_ctx);
luaCallFunction(run_ctx, lua, keys, nkeys, args, nargs, 0);
lua_pop(lua, 1); /* Pop error handler */
}
Expand Down Expand Up @@ -495,8 +496,8 @@ int luaEngineInitEngine(void) {
lua_replace(lua_engine_ctx->lua, LUA_GLOBALSINDEX); /* set new global table as the new globals */


engine *lua_engine = zmalloc(sizeof(*lua_engine));
*lua_engine = (engine){
ValkeyModuleScriptingEngine *lua_engine = zmalloc(sizeof(*lua_engine));
*lua_engine = (ValkeyModuleScriptingEngine){
.engine_ctx = lua_engine_ctx,
.create = luaEngineCreate,
.call = luaEngineCall,
Expand All @@ -505,5 +506,5 @@ int luaEngineInitEngine(void) {
.get_engine_memory_overhead = luaEngineMemoryOverhead,
.free_function = luaEngineFreeFunction,
};
return functionsRegisterEngine(LUA_ENGINE_NAME, lua_engine);
return functionsRegisterEngine(LUA_ENGINE_NAME, NULL, lua_engine);
}
101 changes: 39 additions & 62 deletions src/functions.c
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ static size_t functionMallocSize(functionInfo *fi) {
fi->li->ei->engine->get_function_memory_overhead(fi->function);
}

static size_t libraryMallocSize(functionLibInfo *li) {
static size_t libraryMallocSize(ValkeyModuleScriptingEngineFunctionLibrary *li) {
return zmalloc_size(li) + sdsAllocSize(li->name) + sdsAllocSize(li->code);
}

Expand All @@ -143,12 +143,12 @@ static void engineFunctionDispose(dict *d, void *obj) {
if (fi->desc) {
sdsfree(fi->desc);
}
engine *engine = fi->li->ei->engine;
ValkeyModuleScriptingEngine *engine = fi->li->ei->engine;
engine->free_function(engine->engine_ctx, fi->function);
zfree(fi);
}

static void engineLibraryFree(functionLibInfo *li) {
static void engineLibraryFree(ValkeyModuleScriptingEngineFunctionLibrary *li) {
if (!li) {
return;
}
Expand Down Expand Up @@ -227,6 +227,15 @@ functionsLibCtx *functionsLibCtxCreate(void) {
return ret;
}

void functionsAddEngineStats(engineInfo *ei) {
serverAssert(curr_functions_lib_ctx != NULL);
dictEntry *entry = dictFind(curr_functions_lib_ctx->engines_stats, ei->name);
if (entry == NULL) {
functionsLibEngineStats *stats = zcalloc(sizeof(*stats));
dictAdd(curr_functions_lib_ctx->engines_stats, ei->name, stats);
}
}

/*
* Creating a function inside the given library.
* On success, return C_OK.
Expand All @@ -236,7 +245,7 @@ functionsLibCtx *functionsLibCtxCreate(void) {
* the function will verify that the given name is following the naming format
* and return an error if its not.
*/
int functionLibCreateFunction(sds name, void *function, functionLibInfo *li, sds desc, uint64_t f_flags, sds *err) {
int functionLibCreateFunction(sds name, void *function, ValkeyModuleScriptingEngineFunctionLibrary *li, sds desc, uint64_t f_flags, sds *err) {
if (functionsVerifyName(name) != C_OK) {
*err = sdsnew("Library names can only contain letters, numbers, or underscores(_) and must be at least one "
"character long");
Expand All @@ -263,9 +272,9 @@ int functionLibCreateFunction(sds name, void *function, functionLibInfo *li, sds
return C_OK;
}

static functionLibInfo *engineLibraryCreate(sds name, engineInfo *ei, sds code) {
functionLibInfo *li = zmalloc(sizeof(*li));
*li = (functionLibInfo){
static ValkeyModuleScriptingEngineFunctionLibrary *engineLibraryCreate(sds name, engineInfo *ei, sds code) {
ValkeyModuleScriptingEngineFunctionLibrary *li = zmalloc(sizeof(*li));
*li = (ValkeyModuleScriptingEngineFunctionLibrary){
.name = sdsdup(name),
.functions = dictCreate(&libraryFunctionDictType),
.ei = ei,
Expand All @@ -274,7 +283,7 @@ static functionLibInfo *engineLibraryCreate(sds name, engineInfo *ei, sds code)
return li;
}

static void libraryUnlink(functionsLibCtx *lib_ctx, functionLibInfo *li) {
static void libraryUnlink(functionsLibCtx *lib_ctx, ValkeyModuleScriptingEngineFunctionLibrary *li) {
dictIterator *iter = dictGetIterator(li->functions);
dictEntry *entry = NULL;
while ((entry = dictNext(iter))) {
Expand All @@ -296,7 +305,7 @@ static void libraryUnlink(functionsLibCtx *lib_ctx, functionLibInfo *li) {
stats->n_functions -= dictSize(li->functions);
}

static void libraryLink(functionsLibCtx *lib_ctx, functionLibInfo *li) {
static void libraryLink(functionsLibCtx *lib_ctx, ValkeyModuleScriptingEngineFunctionLibrary *li) {
dictIterator *iter = dictGetIterator(li->functions);
dictEntry *entry = NULL;
while ((entry = dictNext(iter))) {
Expand Down Expand Up @@ -332,8 +341,8 @@ libraryJoin(functionsLibCtx *functions_lib_ctx_dst, functionsLibCtx *functions_l
dictEntry *entry = NULL;
iter = dictGetIterator(functions_lib_ctx_src->libraries);
while ((entry = dictNext(iter))) {
functionLibInfo *li = dictGetVal(entry);
functionLibInfo *old_li = dictFetchValue(functions_lib_ctx_dst->libraries, li->name);
ValkeyModuleScriptingEngineFunctionLibrary *li = dictGetVal(entry);
ValkeyModuleScriptingEngineFunctionLibrary *old_li = dictFetchValue(functions_lib_ctx_dst->libraries, li->name);
if (old_li) {
if (!replace) {
/* library already exists, failed the restore. */
Expand Down Expand Up @@ -367,7 +376,7 @@ libraryJoin(functionsLibCtx *functions_lib_ctx_dst, functionsLibCtx *functions_l
/* No collision, it is safe to link all the new libraries. */
iter = dictGetIterator(functions_lib_ctx_src->libraries);
while ((entry = dictNext(iter))) {
functionLibInfo *li = dictGetVal(entry);
ValkeyModuleScriptingEngineFunctionLibrary *li = dictGetVal(entry);
libraryLink(functions_lib_ctx_dst, li);
dictSetVal(functions_lib_ctx_src->libraries, entry, NULL);
}
Expand All @@ -387,7 +396,7 @@ libraryJoin(functionsLibCtx *functions_lib_ctx_dst, functionsLibCtx *functions_l
/* Link back all libraries on tmp_l_ctx */
while (listLength(old_libraries_list) > 0) {
listNode *head = listFirst(old_libraries_list);
functionLibInfo *li = listNodeValue(head);
ValkeyModuleScriptingEngineFunctionLibrary *li = listNodeValue(head);
listNodeValue(head) = NULL;
libraryLink(functions_lib_ctx_dst, li);
listDelNode(old_libraries_list, head);
Expand All @@ -401,7 +410,9 @@ libraryJoin(functionsLibCtx *functions_lib_ctx_dst, functionsLibCtx *functions_l
*
* - engine_name - name of the engine to register
* - engine_ctx - the engine ctx that should be used by the server to interact with the engine */
int functionsRegisterEngine(const char *engine_name, engine *engine) {
int functionsRegisterEngine(const char *engine_name,
ValkeyModule *engine_module,
ValkeyModuleScriptingEngine *engine) {
sds engine_name_sds = sdsnew(engine_name);
if (dictFetchValue(engines, engine_name_sds)) {
serverLog(LL_WARNING, "Same engine was registered twice");
Expand All @@ -416,12 +427,15 @@ int functionsRegisterEngine(const char *engine_name, engine *engine) {
engineInfo *ei = zmalloc(sizeof(*ei));
*ei = (engineInfo){
.name = engine_name_sds,
.engineModule = engine_module,
.engine = engine,
.c = c,
};

dictAdd(engines, engine_name_sds, ei);

functionsAddEngineStats(ei);

engine_cache_memory += zmalloc_size(ei) + sdsAllocSize(ei->name) + zmalloc_size(engine) +
engine->get_engine_memory_overhead(engine->engine_ctx);

Expand Down Expand Up @@ -535,7 +549,7 @@ void functionListCommand(client *c) {
dictIterator *iter = dictGetIterator(curr_functions_lib_ctx->libraries);
dictEntry *entry = NULL;
while ((entry = dictNext(iter))) {
functionLibInfo *li = dictGetVal(entry);
ValkeyModuleScriptingEngineFunctionLibrary *li = dictGetVal(entry);
if (library_name) {
if (!stringmatchlen(library_name, sdslen(library_name), li->name, sdslen(li->name), 1)) {
continue;
Expand Down Expand Up @@ -584,7 +598,7 @@ void functionListCommand(client *c) {
*/
void functionDeleteCommand(client *c) {
robj *function_name = c->argv[2];
functionLibInfo *li = dictFetchValue(curr_functions_lib_ctx->libraries, function_name->ptr);
ValkeyModuleScriptingEngineFunctionLibrary *li = dictFetchValue(curr_functions_lib_ctx->libraries, function_name->ptr);
if (!li) {
addReplyError(c, "Library not found");
return;
Expand Down Expand Up @@ -614,55 +628,18 @@ uint64_t fcallGetCommandFlags(client *c, uint64_t cmd_flags) {
return scriptFlagsToCmdFlags(cmd_flags, script_flags);
}

static void fcallCommandGeneric(client *c, int ro) {
/* Functions need to be fed to monitors before the commands they execute. */
replicationFeedMonitors(c, server.monitors, c->db->id, c->argv, c->argc);

robj *function_name = c->argv[1];
dictEntry *de = c->cur_script;
if (!de) de = dictFind(curr_functions_lib_ctx->functions, function_name->ptr);
if (!de) {
addReplyError(c, "Function not found");
return;
}
functionInfo *fi = dictGetVal(de);
engine *engine = fi->li->ei->engine;

long long numkeys;
/* Get the number of arguments that are keys */
if (getLongLongFromObject(c->argv[2], &numkeys) != C_OK) {
addReplyError(c, "Bad number of keys provided");
return;
}
if (numkeys > (c->argc - 3)) {
addReplyError(c, "Number of keys can't be greater than number of args");
return;
} else if (numkeys < 0) {
addReplyError(c, "Number of keys can't be negative");
return;
}

scriptRunCtx run_ctx;

if (scriptPrepareForRun(&run_ctx, fi->li->ei->c, c, fi->name, fi->f_flags, ro) != C_OK) return;

engine->call(&run_ctx, engine->engine_ctx, fi->function, c->argv + 3, numkeys, c->argv + 3 + numkeys,
c->argc - 3 - numkeys);
scriptResetRun(&run_ctx);
}

/*
* FCALL <FUNCTION NAME> nkeys <key1 .. keyn> <arg1 .. argn>
*/
void fcallCommand(client *c) {
fcallCommandGeneric(c, 0);
fcallCommandGeneric(curr_functions_lib_ctx->functions, c, 0);
}

/*
* FCALL_RO <FUNCTION NAME> nkeys <key1 .. keyn> <arg1 .. argn>
*/
void fcallroCommand(client *c) {
fcallCommandGeneric(c, 1);
fcallCommandGeneric(curr_functions_lib_ctx->functions, c, 1);
}

/*
Expand Down Expand Up @@ -952,9 +929,10 @@ void functionFreeLibMetaData(functionsLibMetaData *md) {
sds functionsCreateWithLibraryCtx(sds code, int replace, sds *err, functionsLibCtx *lib_ctx, size_t timeout) {
dictIterator *iter = NULL;
dictEntry *entry = NULL;
functionLibInfo *new_li = NULL;
functionLibInfo *old_li = NULL;
ValkeyModuleScriptingEngineFunctionLibrary *old_li = NULL;
functionsLibMetaData md = {0};
ValkeyModuleScriptingEngineFunctionLibrary *new_li = NULL;

if (functionExtractLibMetaData(code, &md, err) != C_OK) {
return NULL;
}
Expand All @@ -970,7 +948,7 @@ sds functionsCreateWithLibraryCtx(sds code, int replace, sds *err, functionsLibC
*err = sdscatfmt(sdsempty(), "Engine '%S' not found", md.engine);
goto error;
}
engine *engine = ei->engine;
ValkeyModuleScriptingEngine *engine = ei->engine;

old_li = dictFetchValue(lib_ctx->libraries, md.name);
if (old_li && !replace) {
Expand Down Expand Up @@ -1073,7 +1051,7 @@ unsigned long functionsMemory(void) {
size_t engines_memory = 0;
while ((entry = dictNext(iter))) {
engineInfo *ei = dictGetVal(entry);
engine *engine = ei->engine;
ValkeyModuleScriptingEngine *engine = ei->engine;
engines_memory += engine->get_used_memory(engine->engine_ctx);
}
dictReleaseIterator(iter);
Expand Down Expand Up @@ -1114,12 +1092,11 @@ size_t functionsLibCtxFunctionsLen(functionsLibCtx *functions_ctx) {
int functionsInit(void) {
engines = dictCreate(&engineDictType);

curr_functions_lib_ctx = functionsLibCtxCreate();

if (luaEngineInitEngine() != C_OK) {
return C_ERR;
}

/* Must be initialized after engines initialization */
curr_functions_lib_ctx = functionsLibCtxCreate();

return C_OK;
}
Loading

0 comments on commit d3293dd

Please sign in to comment.