diff --git a/cmake/Modules/SourceFiles.cmake b/cmake/Modules/SourceFiles.cmake index 1a754ff846..80c3f0c876 100644 --- a/cmake/Modules/SourceFiles.cmake +++ b/cmake/Modules/SourceFiles.cmake @@ -100,6 +100,7 @@ set(VALKEY_SERVER_SRCS ${CMAKE_SOURCE_DIR}/src/script_lua.c ${CMAKE_SOURCE_DIR}/src/script.c ${CMAKE_SOURCE_DIR}/src/functions.c + ${CMAKE_SOURCE_DIR}/src/scripting_engine.c ${CMAKE_SOURCE_DIR}/src/function_lua.c ${CMAKE_SOURCE_DIR}/src/commands.c ${CMAKE_SOURCE_DIR}/src/strl.c diff --git a/src/Makefile b/src/Makefile index e52f4f08d3..9e4075660d 100644 --- a/src/Makefile +++ b/src/Makefile @@ -374,7 +374,7 @@ else endef endif -# Determine install/uninstall Redis symlinks for compatibility when +# Determine install/uninstall Redis symlinks for compatibility when # installing/uninstalling Valkey binaries (defaulting to `yes`) USE_REDIS_SYMLINKS?=yes ifeq ($(USE_REDIS_SYMLINKS),yes) @@ -416,7 +416,7 @@ endif ENGINE_NAME=valkey SERVER_NAME=$(ENGINE_NAME)-server$(PROG_SUFFIX) ENGINE_SENTINEL_NAME=$(ENGINE_NAME)-sentinel$(PROG_SUFFIX) -ENGINE_SERVER_OBJ=threads_mngr.o adlist.o quicklist.o ae.o anet.o dict.o hashtable.o kvstore.o server.o sds.o zmalloc.o lzf_c.o lzf_d.o pqsort.o zipmap.o sha1.o ziplist.o release.o memory_prefetch.o io_threads.o networking.o util.o object.o db.o replication.o rdb.o t_string.o t_list.o t_set.o t_zset.o t_hash.o config.o aof.o pubsub.o multi.o debug.o sort.o intset.o syncio.o cluster.o cluster_legacy.o cluster_slot_stats.o crc16.o endianconv.o slowlog.o eval.o bio.o rio.o rand.o memtest.o syscheck.o crcspeed.o crccombine.o crc64.o bitops.o sentinel.o notify.o setproctitle.o blocked.o hyperloglog.o latency.o sparkline.o valkey-check-rdb.o valkey-check-aof.o geo.o lazyfree.o module.o evict.o expire.o geohash.o geohash_helper.o childinfo.o allocator_defrag.o defrag.o siphash.o rax.o t_stream.o listpack.o localtime.o lolwut.o lolwut5.o lolwut6.o acl.o tracking.o socket.o tls.o sha256.o timeout.o setcpuaffinity.o monotonic.o mt19937-64.o resp_parser.o call_reply.o script_lua.o script.o functions.o function_lua.o commands.o strl.o connection.o unix.o logreqres.o rdma.o +ENGINE_SERVER_OBJ=threads_mngr.o adlist.o quicklist.o ae.o anet.o dict.o hashtable.o kvstore.o server.o sds.o zmalloc.o lzf_c.o lzf_d.o pqsort.o zipmap.o sha1.o ziplist.o release.o memory_prefetch.o io_threads.o networking.o util.o object.o db.o replication.o rdb.o t_string.o t_list.o t_set.o t_zset.o t_hash.o config.o aof.o pubsub.o multi.o debug.o sort.o intset.o syncio.o cluster.o cluster_legacy.o cluster_slot_stats.o crc16.o endianconv.o slowlog.o eval.o bio.o rio.o rand.o memtest.o syscheck.o crcspeed.o crccombine.o crc64.o bitops.o sentinel.o notify.o setproctitle.o blocked.o hyperloglog.o latency.o sparkline.o valkey-check-rdb.o valkey-check-aof.o geo.o lazyfree.o module.o evict.o expire.o geohash.o geohash_helper.o childinfo.o allocator_defrag.o defrag.o siphash.o rax.o t_stream.o listpack.o localtime.o lolwut.o lolwut5.o lolwut6.o acl.o tracking.o socket.o tls.o sha256.o timeout.o setcpuaffinity.o monotonic.o mt19937-64.o resp_parser.o call_reply.o script_lua.o script.o functions.o function_lua.o commands.o strl.o connection.o unix.o logreqres.o rdma.o scripting_engine.o ENGINE_CLI_NAME=$(ENGINE_NAME)-cli$(PROG_SUFFIX) ENGINE_CLI_OBJ=anet.o adlist.o dict.o valkey-cli.o zmalloc.o release.o ae.o serverassert.o crcspeed.o crccombine.o crc64.o siphash.o crc16.o monotonic.o cli_common.o mt19937-64.o strl.o cli_commands.o ENGINE_BENCHMARK_NAME=$(ENGINE_NAME)-benchmark$(PROG_SUFFIX) diff --git a/src/function_lua.c b/src/function_lua.c index b535528906..59c16eae54 100644 --- a/src/function_lua.c +++ b/src/function_lua.c @@ -39,6 +39,7 @@ * Uses script_lua.c to run the Lua code. */ +#include "scripting_engine.h" #include "functions.h" #include "script_lua.h" #include @@ -121,7 +122,7 @@ static compiledFunction **luaEngineCreate(ValkeyModuleCtx *module_ctx, const char *code, size_t timeout, size_t *out_num_compiled_functions, - char **err) { + robj **err) { /* The lua engine is implemented in the core, and not in a Valkey Module */ serverAssert(module_ctx == NULL); @@ -139,7 +140,8 @@ static compiledFunction **luaEngineCreate(ValkeyModuleCtx *module_ctx, /* compile the code */ if (luaL_loadbuffer(lua, code, strlen(code), "@user_function")) { - *err = valkey_asprintf("Error compiling function: %s", lua_tostring(lua, -1)); + sds error = sdscatfmt(sdsempty(), "Error compiling function: %s", lua_tostring(lua, -1)); + *err = createObject(OBJ_STRING, error); lua_pop(lua, 1); /* pops the error */ goto done; } @@ -157,7 +159,8 @@ static compiledFunction **luaEngineCreate(ValkeyModuleCtx *module_ctx, if (lua_pcall(lua, 0, 0, 0)) { errorInfo err_info = {0}; luaExtractErrorInformation(lua, &err_info); - *err = valkey_asprintf("Error registering functions: %s", err_info.msg); + sds error = sdscatfmt(sdsempty(), "Error registering functions: %s", err_info.msg); + *err = createObject(OBJ_STRING, error); lua_pop(lua, 1); /* pops the error */ luaErrorInformationDiscard(&err_info); listIter *iter = listGetIterator(load_ctx.functions, AL_START_HEAD); @@ -557,8 +560,8 @@ int luaEngineInitEngine(void) { .get_memory_info = luaEngineGetMemoryInfo, }; - return functionsRegisterEngine(LUA_ENGINE_NAME, - NULL, - lua_engine_ctx, - &lua_engine_methods); + return scriptingEngineManagerRegister(LUA_ENGINE_NAME, + NULL, + lua_engine_ctx, + &lua_engine_methods); } diff --git a/src/functions.c b/src/functions.c index 0d003f7fac..14d8c5296e 100644 --- a/src/functions.c +++ b/src/functions.c @@ -31,7 +31,6 @@ #include "sds.h" #include "dict.h" #include "adlist.h" -#include "module.h" #define LOAD_TIMEOUT_MS 500 @@ -41,8 +40,6 @@ typedef enum { restorePolicy_Replace } restorePolicy; -static size_t engine_cache_memory = 0; - /* Forward declaration */ static void engineFunctionDispose(void *obj); static void engineStatsDispose(void *obj); @@ -67,15 +64,6 @@ typedef struct functionsLibMetaData { sds code; } functionsLibMetaData; -dictType engineDictType = { - dictSdsCaseHash, /* hash function */ - dictSdsDup, /* key dup */ - dictSdsKeyCaseCompare, /* key compare */ - dictSdsDestructor, /* key destructor */ - NULL, /* val destructor */ - NULL /* allow to expand */ -}; - dictType functionDictType = { dictSdsCaseHash, /* hash function */ dictSdsDup, /* key dup */ @@ -112,34 +100,14 @@ dictType librariesDictType = { NULL /* allow to expand */ }; -/* Dictionary of engines */ -static dict *engines = NULL; - /* Libraries Ctx. */ static functionsLibCtx *curr_functions_lib_ctx = NULL; -static void setupEngineModuleCtx(engineInfo *ei, client *c) { - if (ei->engineModule != NULL) { - serverAssert(ei->module_ctx != NULL); - moduleScriptingEngineInitContext(ei->module_ctx, ei->engineModule, c); - } -} - -static void teardownEngineModuleCtx(engineInfo *ei) { - if (ei->engineModule != NULL) { - serverAssert(ei->module_ctx != NULL); - moduleFreeContext(ei->module_ctx); - } -} - static size_t functionMallocSize(functionInfo *fi) { - setupEngineModuleCtx(fi->li->ei, NULL); - size_t size = zmalloc_size(fi) + - sdsAllocSize(fi->name) + - (fi->desc ? sdsAllocSize(fi->desc) : 0) + - fi->li->ei->engine->get_function_memory_overhead(fi->li->ei->module_ctx, fi->function); - teardownEngineModuleCtx(fi->li->ei); - return size; + return zmalloc_size(fi) + + sdsAllocSize(fi->name) + + (fi->desc ? sdsAllocSize(fi->desc) : 0) + + scriptingEngineCallGetFunctionMemoryOverhead(fi->li->engine, fi->function); } static size_t libraryMallocSize(functionLibInfo *li) { @@ -161,12 +129,8 @@ static void engineFunctionDispose(void *obj) { if (fi->desc) { sdsfree(fi->desc); } - setupEngineModuleCtx(fi->li->ei, NULL); - engine *engine = fi->li->ei->engine; - engine->free_function(fi->li->ei->module_ctx, - engine->engine_ctx, - fi->function); - teardownEngineModuleCtx(fi->li->ei); + + scriptingEngineCallFreeFunction(fi->li->engine, fi->function); zfree(fi); } @@ -239,30 +203,30 @@ functionsLibCtx *functionsLibCtxGetCurrent(void) { return curr_functions_lib_ctx; } +static void initializeFunctionsLibEngineStats(scriptingEngine *engine, + void *context) { + functionsLibCtx *lib_ctx = (functionsLibCtx *)context; + functionsLibEngineStats *stats = zcalloc(sizeof(*stats)); + dictAdd(lib_ctx->engines_stats, scriptingEngineGetName(engine), stats); +} + /* Create a new functions ctx */ functionsLibCtx *functionsLibCtxCreate(void) { functionsLibCtx *ret = zmalloc(sizeof(functionsLibCtx)); ret->libraries = dictCreate(&librariesDictType); ret->functions = dictCreate(&functionDictType); ret->engines_stats = dictCreate(&engineStatsDictType); - dictIterator *iter = dictGetIterator(engines); - dictEntry *entry = NULL; - while ((entry = dictNext(iter))) { - engineInfo *ei = dictGetVal(entry); - functionsLibEngineStats *stats = zcalloc(sizeof(*stats)); - dictAdd(ret->engines_stats, ei->name, stats); - } - dictReleaseIterator(iter); + scriptingEngineManagerForEachEngine(initializeFunctionsLibEngineStats, ret); ret->cache_memory = 0; return ret; } -void functionsAddEngineStats(engineInfo *ei) { +void functionsAddEngineStats(sds engine_name) { serverAssert(curr_functions_lib_ctx != NULL); - dictEntry *entry = dictFind(curr_functions_lib_ctx->engines_stats, ei->name); + dictEntry *entry = dictFind(curr_functions_lib_ctx->engines_stats, engine_name); if (entry == NULL) { functionsLibEngineStats *stats = zcalloc(sizeof(*stats)); - dictAdd(curr_functions_lib_ctx->engines_stats, ei->name, stats); + dictAdd(curr_functions_lib_ctx->engines_stats, engine_name, stats); } } @@ -312,12 +276,12 @@ static int functionLibCreateFunction(robj *name, return C_OK; } -static functionLibInfo *engineLibraryCreate(sds name, engineInfo *ei, sds code) { +static functionLibInfo *engineLibraryCreate(sds name, scriptingEngine *e, sds code) { functionLibInfo *li = zmalloc(sizeof(*li)); *li = (functionLibInfo){ .name = sdsdup(name), .functions = dictCreate(&libraryFunctionDictType), - .ei = ei, + .engine = e, .code = sdsdup(code), }; return li; @@ -339,7 +303,7 @@ static void libraryUnlink(functionsLibCtx *lib_ctx, functionLibInfo *li) { lib_ctx->cache_memory -= libraryMallocSize(li); /* update stats */ - functionsLibEngineStats *stats = dictFetchValue(lib_ctx->engines_stats, li->ei->name); + functionsLibEngineStats *stats = dictFetchValue(lib_ctx->engines_stats, scriptingEngineGetName(li->engine)); serverAssert(stats); stats->n_lib--; stats->n_functions -= dictSize(li->functions); @@ -359,7 +323,7 @@ static void libraryLink(functionsLibCtx *lib_ctx, functionLibInfo *li) { lib_ctx->cache_memory += libraryMallocSize(li); /* update stats */ - functionsLibEngineStats *stats = dictFetchValue(lib_ctx->engines_stats, li->ei->name); + functionsLibEngineStats *stats = dictFetchValue(lib_ctx->engines_stats, scriptingEngineGetName(li->engine)); serverAssert(stats); stats->n_lib++; stats->n_functions += dictSize(li->functions); @@ -446,107 +410,29 @@ libraryJoin(functionsLibCtx *functions_lib_ctx_dst, functionsLibCtx *functions_l return ret; } -/* Register an engine, should be called once by the engine on startup and give - * the following: - * - * - engine_name - name of the engine to register - * - * - engine_module - the valkey module that implements this engine - * - * - engine_ctx - the engine ctx that should be used by the server to interact - * with the engine. - * - * - engine_methods - the struct with the scripting engine callback functions - * pointers. - * - */ -int functionsRegisterEngine(const char *engine_name, - ValkeyModule *engine_module, - engineCtx *engine_ctx, - engineMethods *engine_methods) { - sds engine_name_sds = sdsnew(engine_name); - if (dictFetchValue(engines, engine_name_sds)) { - serverLog(LL_WARNING, "Same engine was registered twice"); - sdsfree(engine_name_sds); - return C_ERR; - } - - engine *eng = zmalloc(sizeof(engine)); - *eng = (engine){ - .engine_ctx = engine_ctx, - .create = engine_methods->create_functions_library, - .call = engine_methods->call_function, - .get_function_memory_overhead = engine_methods->get_function_memory_overhead, - .free_function = engine_methods->free_function, - .get_memory_info = engine_methods->get_memory_info, - }; - - client *c = createClient(NULL); - c->flag.deny_blocking = 1; - c->flag.script = 1; - c->flag.fake = 1; - engineInfo *ei = zmalloc(sizeof(*ei)); - *ei = (engineInfo){ - .name = engine_name_sds, - .engineModule = engine_module, - .module_ctx = engine_module ? moduleAllocateContext() : NULL, - .engine = eng, - .c = c, - }; - - dictAdd(engines, engine_name_sds, ei); - - functionsAddEngineStats(ei); - - setupEngineModuleCtx(ei, NULL); - engineMemoryInfo mem_info = eng->get_memory_info(ei->module_ctx, - eng->engine_ctx); - engine_cache_memory += zmalloc_size(ei) + - sdsAllocSize(ei->name) + - zmalloc_size(eng) + - mem_info.engine_memory_overhead; - - teardownEngineModuleCtx(ei); - - return C_OK; +static void replyEngineStats(scriptingEngine *engine, void *context) { + client *c = (client *)context; + addReplyBulkCString(c, scriptingEngineGetName(engine)); + addReplyMapLen(c, 2); + functionsLibEngineStats *e_stats = + dictFetchValue(curr_functions_lib_ctx->engines_stats, scriptingEngineGetName(engine)); + addReplyBulkCString(c, "libraries_count"); + addReplyLongLong(c, e_stats ? e_stats->n_lib : 0); + addReplyBulkCString(c, "functions_count"); + addReplyLongLong(c, e_stats ? e_stats->n_functions : 0); } -/* Removes a scripting engine from the server. - * - * - engine_name - name of the engine to remove - */ -int functionsUnregisterEngine(const char *engine_name) { - sds engine_name_sds = sdsnew(engine_name); - dictEntry *entry = dictFind(engines, engine_name_sds); - if (entry == NULL) { - serverLog(LL_WARNING, "There's no engine registered with name %s", engine_name); - sdsfree(engine_name_sds); - return C_ERR; - } - - engineInfo *ei = dictGetVal(entry); - +void functionsRemoveLibFromEngine(scriptingEngine *engine) { dictIterator *iter = dictGetSafeIterator(curr_functions_lib_ctx->libraries); + dictEntry *entry = NULL; while ((entry = dictNext(iter))) { functionLibInfo *li = dictGetVal(entry); - if (li->ei == ei) { + if (li->engine == engine) { libraryUnlink(curr_functions_lib_ctx, li); engineLibraryFree(li); } } dictReleaseIterator(iter); - - zfree(ei->engine); - sdsfree(ei->name); - freeClient(ei->c); - if (ei->engineModule != NULL) { - serverAssert(ei->module_ctx != NULL); - zfree(ei->module_ctx); - } - zfree(ei); - - sdsfree(engine_name_sds); - return C_OK; } /* @@ -578,20 +464,8 @@ void functionStatsCommand(client *c) { } addReplyBulkCString(c, "engines"); - addReplyMapLen(c, dictSize(engines)); - dictIterator *iter = dictGetIterator(engines); - dictEntry *entry = NULL; - while ((entry = dictNext(iter))) { - engineInfo *ei = dictGetVal(entry); - addReplyBulkCString(c, ei->name); - addReplyMapLen(c, 2); - functionsLibEngineStats *e_stats = dictFetchValue(curr_functions_lib_ctx->engines_stats, ei->name); - addReplyBulkCString(c, "libraries_count"); - addReplyLongLong(c, e_stats->n_lib); - addReplyBulkCString(c, "functions_count"); - addReplyLongLong(c, e_stats->n_functions); - } - dictReleaseIterator(iter); + addReplyMapLen(c, scriptingEngineManagerGetNumEngines()); + scriptingEngineManagerForEachEngine(replyEngineStats, c); } static void functionListReplyFlags(client *c, functionInfo *fi) { @@ -667,7 +541,8 @@ void functionListCommand(client *c) { addReplyBulkCString(c, "library_name"); addReplyBulkCBuffer(c, li->name, sdslen(li->name)); addReplyBulkCString(c, "engine"); - addReplyBulkCBuffer(c, li->ei->name, sdslen(li->ei->name)); + sds engine_name = scriptingEngineGetName(li->engine); + addReplyBulkCBuffer(c, engine_name, sdslen(engine_name)); addReplyBulkCString(c, "functions"); addReplyArrayLen(c, dictSize(li->functions)); @@ -747,7 +622,7 @@ static void fcallCommandGeneric(client *c, int ro) { return; } functionInfo *fi = dictGetVal(de); - engine *engine = fi->li->ei->engine; + scriptingEngine *engine = fi->li->engine; long long numkeys; /* Get the number of arguments that are keys */ @@ -764,19 +639,16 @@ static void fcallCommandGeneric(client *c, int ro) { } scriptRunCtx run_ctx; - if (scriptPrepareForRun(&run_ctx, fi->li->ei->c, c, fi->name, fi->f_flags, ro) != C_OK) return; - setupEngineModuleCtx(fi->li->ei, run_ctx.original_client); - - engine->call(fi->li->ei->module_ctx, - engine->engine_ctx, - &run_ctx, - fi->function, - c->argv + 3, - numkeys, - c->argv + 3 + numkeys, - c->argc - 3 - numkeys); - - teardownEngineModuleCtx(fi->li->ei); + if (scriptPrepareForRun(&run_ctx, scriptingEngineGetClient(engine), c, fi->name, fi->f_flags, ro) != C_OK) return; + + scriptingEngineCallFunction(engine, + &run_ctx, + run_ctx.original_client, + fi->function, + c->argv + 3, + numkeys, + c->argv + 3 + numkeys, + c->argc - 3 - numkeys); scriptResetRun(&run_ctx); } @@ -1076,12 +948,10 @@ void functionFreeLibMetaData(functionsLibMetaData *md) { if (md->engine) sdsfree(md->engine); } -static void freeCompiledFunctions(engineInfo *ei, +static void freeCompiledFunctions(scriptingEngine *engine, compiledFunction **compiled_functions, size_t num_compiled_functions, size_t free_function_from_idx) { - setupEngineModuleCtx(ei, NULL); - for (size_t i = 0; i < num_compiled_functions; i++) { compiledFunction *func = compiled_functions[i]; decrRefCount(func->name); @@ -1089,16 +959,12 @@ static void freeCompiledFunctions(engineInfo *ei, decrRefCount(func->desc); } if (i >= free_function_from_idx) { - ei->engine->free_function(ei->module_ctx, - ei->engine->engine_ctx, - func->function); + scriptingEngineCallFreeFunction(engine, func->function); } zfree(func); } zfree(compiled_functions); - - teardownEngineModuleCtx(ei); } /* Compile and save the given library, return the loaded library name on success @@ -1120,12 +986,13 @@ sds functionsCreateWithLibraryCtx(sds code, int replace, sds *err, functionsLibC goto error; } - engineInfo *ei = dictFetchValue(engines, md.engine); - if (!ei) { + scriptingEngine *engine = scriptingEngineManagerFind(md.engine); + if (!engine) { *err = sdscatfmt(sdsempty(), "Engine '%S' not found", md.engine); goto error; } - engine *engine = ei->engine; + + functionsAddEngineStats(md.engine); old_li = dictFetchValue(lib_ctx->libraries, md.name); if (old_li && !replace) { @@ -1138,26 +1005,25 @@ sds functionsCreateWithLibraryCtx(sds code, int replace, sds *err, functionsLibC libraryUnlink(lib_ctx, old_li); } - new_li = engineLibraryCreate(md.name, ei, code); + new_li = engineLibraryCreate(md.name, engine, code); size_t num_compiled_functions = 0; - char *compile_error = NULL; - setupEngineModuleCtx(ei, NULL); + robj *compile_error = NULL; compiledFunction **compiled_functions = - engine->create(ei->module_ctx, - engine->engine_ctx, - md.code, - timeout, - &num_compiled_functions, - &compile_error); - teardownEngineModuleCtx(ei); + scriptingEngineCallCreateFunctionsLibrary(engine, + md.code, + timeout, + &num_compiled_functions, + &compile_error); if (compiled_functions == NULL) { serverAssert(num_compiled_functions == 0); serverAssert(compile_error != NULL); - *err = sdsnew(compile_error); - zfree(compile_error); + *err = sdsdup(compile_error->ptr); + decrRefCount(compile_error); goto error; } + serverAssert(compile_error == NULL); + for (size_t i = 0; i < num_compiled_functions; i++) { compiledFunction *func = compiled_functions[i]; int ret = functionLibCreateFunction(func->name, @@ -1167,7 +1033,7 @@ sds functionsCreateWithLibraryCtx(sds code, int replace, sds *err, functionsLibC func->f_flags, err); if (ret == C_ERR) { - freeCompiledFunctions(ei, + freeCompiledFunctions(engine, compiled_functions, num_compiled_functions, i); @@ -1175,7 +1041,7 @@ sds functionsCreateWithLibraryCtx(sds code, int replace, sds *err, functionsLibC } } - freeCompiledFunctions(ei, + freeCompiledFunctions(engine, compiled_functions, num_compiled_functions, num_compiled_functions); @@ -1259,32 +1125,26 @@ void functionLoadCommand(client *c) { addReplyBulkSds(c, library_name); } +static void getEngineUsedMemory(scriptingEngine *engine, void *context) { + size_t *engines_memory = (size_t *)context; + engineMemoryInfo mem_info = scriptingEngineCallGetMemoryInfo(engine); + *engines_memory += mem_info.used_memory; +} + /* Return memory usage of all the engines combine */ unsigned long functionsMemory(void) { - dictIterator *iter = dictGetIterator(engines); - dictEntry *entry = NULL; size_t engines_memory = 0; - while ((entry = dictNext(iter))) { - engineInfo *ei = dictGetVal(entry); - engine *engine = ei->engine; - setupEngineModuleCtx(ei, NULL); - engineMemoryInfo mem_info = engine->get_memory_info(ei->module_ctx, - engine->engine_ctx); - engines_memory += mem_info.used_memory; - teardownEngineModuleCtx(ei); - } - dictReleaseIterator(iter); - + scriptingEngineManagerForEachEngine(getEngineUsedMemory, &engines_memory); return engines_memory; } /* Return memory overhead of all the engines combine */ unsigned long functionsMemoryOverhead(void) { - size_t memory_overhead = dictMemUsage(engines); + size_t memory_overhead = scriptingEngineManagerGetMemoryUsage(); memory_overhead += dictMemUsage(curr_functions_lib_ctx->functions); memory_overhead += sizeof(functionsLibCtx); memory_overhead += curr_functions_lib_ctx->cache_memory; - memory_overhead += engine_cache_memory; + memory_overhead += scriptingEngineManagerGetTotalMemoryOverhead(); return memory_overhead; } @@ -1309,8 +1169,6 @@ size_t functionsLibCtxFunctionsLen(functionsLibCtx *functions_ctx) { /* Initialize engine data structures. * Should be called once on server initialization */ int functionsInit(void) { - engines = dictCreate(&engineDictType); - curr_functions_lib_ctx = functionsLibCtxCreate(); if (luaEngineInitEngine() != C_OK) { diff --git a/src/functions.h b/src/functions.h index a48ff1b8db..7f6d144365 100644 --- a/src/functions.h +++ b/src/functions.h @@ -49,73 +49,19 @@ */ #include "server.h" +#include "scripting_engine.h" #include "script.h" #include "valkeymodule.h" typedef struct functionLibInfo functionLibInfo; -/* ValkeyModule type aliases for scripting engine structs and types. */ -typedef struct ValkeyModule ValkeyModule; -typedef ValkeyModuleScriptingEngineCtx engineCtx; -typedef ValkeyModuleScriptingEngineFunctionCtx functionCtx; -typedef ValkeyModuleScriptingEngineCompiledFunction compiledFunction; -typedef ValkeyModuleScriptingEngineMemoryInfo engineMemoryInfo; -typedef ValkeyModuleScriptingEngineMethods engineMethods; - -typedef struct engine { - /* engine specific context */ - engineCtx *engine_ctx; - - /* Compiles the script code and returns an array of compiled functions - * registered in the script./ - * - * Returns NULL on error and set err to be the error message */ - compiledFunction **(*create)( - ValkeyModuleCtx *module_ctx, - engineCtx *engine_ctx, - const char *code, - size_t timeout, - size_t *out_num_compiled_functions, - char **err); - - /* Invoking a function, func_ctx is an opaque object (from engine POV). - * The func_ctx should be used by the engine to interaction with the server, - * such interaction could be running commands, set resp, or set - * replication mode - */ - void (*call)(ValkeyModuleCtx *module_ctx, - engineCtx *engine_ctx, - functionCtx *func_ctx, - void *compiled_function, - robj **keys, - size_t nkeys, - robj **args, - size_t nargs); - - /* free the given function */ - void (*free_function)(ValkeyModuleCtx *module_ctx, - engineCtx *engine_ctx, - void *compiled_function); - - /* Return memory overhead for a given function, - * such memory is not counted as engine memory but as general - * structs memory that hold different information */ - size_t (*get_function_memory_overhead)(ValkeyModuleCtx *module_ctx, - void *compiled_function); - - /* Get the current used memory by the engine */ - engineMemoryInfo (*get_memory_info)(ValkeyModuleCtx *module_ctx, - engineCtx *engine_ctx); - -} engine; - /* Hold information about an engine. * Used on rdb.c so it must be declared here. */ typedef struct engineInfo { sds name; /* Name of the engine */ ValkeyModule *engineModule; /* the module that implements the scripting engine */ ValkeyModuleCtx *module_ctx; /* Scripting engine module context */ - engine *engine; /* engine callbacks that allows to interact with the engine */ + scriptingEngine *engine; /* engine callbacks that allows to interact with the engine */ client *c; /* Client that is used to run commands */ } engineInfo; @@ -133,18 +79,12 @@ typedef struct functionInfo { /* Hold information about the specific library. * Used on rdb.c so it must be declared here. */ struct functionLibInfo { - sds name; /* Library name */ - dict *functions; /* Functions dictionary */ - engineInfo *ei; /* Pointer to the function engine */ - sds code; /* Library code */ + sds name; /* Library name */ + dict *functions; /* Functions dictionary */ + scriptingEngine *engine; /* Pointer to the scripting engine */ + sds code; /* Library code */ }; -int functionsRegisterEngine(const char *engine_name, - ValkeyModule *engine_module, - void *engine_ctx, - engineMethods *engine_methods); -int functionsUnregisterEngine(const char *engine_name); - sds functionsCreateWithLibraryCtx(sds code, int replace, sds *err, functionsLibCtx *lib_ctx, size_t timeout); unsigned long functionsMemory(void); unsigned long functionsMemoryOverhead(void); @@ -159,6 +99,8 @@ void functionsLibCtxFree(functionsLibCtx *functions_lib_ctx); void functionsLibCtxClear(functionsLibCtx *lib_ctx, void(callback)(dict *)); void functionsLibCtxSwapWithCurrent(functionsLibCtx *new_lib_ctx, int async); +void functionsRemoveLibFromEngine(scriptingEngine *engine); + int luaEngineInitEngine(void); int functionsInit(void); diff --git a/src/module.c b/src/module.c index fa60335837..75dcd81cd6 100644 --- a/src/module.c +++ b/src/module.c @@ -62,8 +62,8 @@ #include "crc16_slottable.h" #include "valkeymodule.h" #include "io_threads.h" -#include "functions.h" #include "module.h" +#include "scripting_engine.h" #include #include #include @@ -13165,10 +13165,10 @@ int VM_RegisterScriptingEngine(ValkeyModuleCtx *module_ctx, return VALKEYMODULE_ERR; } - if (functionsRegisterEngine(engine_name, - module_ctx->module, - engine_ctx, - engine_methods) != C_OK) { + if (scriptingEngineManagerRegister(engine_name, + module_ctx->module, + engine_ctx, + engine_methods) != C_OK) { return VALKEYMODULE_ERR; } @@ -13184,7 +13184,9 @@ int VM_RegisterScriptingEngine(ValkeyModuleCtx *module_ctx, */ int VM_UnregisterScriptingEngine(ValkeyModuleCtx *ctx, const char *engine_name) { UNUSED(ctx); - functionsUnregisterEngine(engine_name); + if (scriptingEngineManagerUnregister(engine_name) != C_OK) { + return VALKEYMODULE_ERR; + } return VALKEYMODULE_OK; } diff --git a/src/scripting_engine.c b/src/scripting_engine.c new file mode 100644 index 0000000000..9488f5ef93 --- /dev/null +++ b/src/scripting_engine.c @@ -0,0 +1,284 @@ +#include "scripting_engine.h" +#include "dict.h" +#include "functions.h" +#include "module.h" + +typedef struct scriptingEngineImpl { + /* Engine specific context */ + engineCtx *ctx; + + /* Callback functions implemented by the scripting engine module */ + engineMethods methods; +} scriptingEngineImpl; + +typedef struct scriptingEngine { + sds name; /* Name of the engine */ + ValkeyModule *module; /* the module that implements the scripting engine */ + scriptingEngineImpl impl; /* engine context and callbacks to interact with the engine */ + client *c; /* Client that is used to run commands */ + ValkeyModuleCtx *module_ctx; /* Cache of the module context object */ +} scriptingEngine; + + +typedef struct engineManger { + dict *engines; /* engines dictionary */ + size_t total_memory_overhead; /* the sum of the memory overhead of all registered scripting engines */ +} engineManager; + + +static engineManager engineMgr = { + .engines = NULL, + .total_memory_overhead = 0, +}; + +static uint64_t dictStrCaseHash(const void *key) { + return dictGenCaseHashFunction((unsigned char *)key, strlen((char *)key)); +} + +dictType engineDictType = { + dictStrCaseHash, /* hash function */ + NULL, /* key dup */ + dictSdsKeyCaseCompare, /* key compare */ + NULL, /* key destructor */ + NULL, /* val destructor */ + NULL /* allow to expand */ +}; + +/* Initializes the scripting engine manager. + * The engine manager is responsible for managing the several scripting engines + * that are loaded in the server and implemented by Valkey Modules. + * + * Returns C_ERR if some error occurs during the initialization. + */ +int scriptingEngineManagerInit(void) { + engineMgr.engines = dictCreate(&engineDictType); + return C_OK; +} + +/* Returns the amount of memory overhead consumed by all registered scripting + engines. */ +size_t scriptingEngineManagerGetTotalMemoryOverhead(void) { + return engineMgr.total_memory_overhead; +} + +size_t scriptingEngineManagerGetNumEngines(void) { + return dictSize(engineMgr.engines); +} + +size_t scriptingEngineManagerGetMemoryUsage(void) { + return dictMemUsage(engineMgr.engines) + sizeof(engineMgr); +} + +/* Registers a new scripting engine in the engine manager. + * + * - `engine_name`: the name of the scripting engine. This name will match + * against the engine name specified in the script header using a shebang. + * + * - `ctx`: engine specific context pointer. + * + * - engine_methods - the struct with the scripting engine callback functions + * pointers. + * + * Returns C_ERR in case of an error during registration. + */ +int scriptingEngineManagerRegister(const char *engine_name, + ValkeyModule *engine_module, + engineCtx *engine_ctx, + engineMethods *engine_methods) { + sds engine_name_sds = sdsnew(engine_name); + + if (dictFetchValue(engineMgr.engines, engine_name_sds)) { + serverLog(LL_WARNING, "Scripting engine '%s' is already registered in the server", engine_name_sds); + sdsfree(engine_name_sds); + return C_ERR; + } + + client *c = createClient(NULL); + c->flag.deny_blocking = 1; + c->flag.script = 1; + c->flag.fake = 1; + + scriptingEngine *e = zmalloc(sizeof(*e)); + *e = (scriptingEngine){ + .name = engine_name_sds, + .module = engine_module, + .impl = { + .ctx = engine_ctx, + .methods = { + .create_functions_library = engine_methods->create_functions_library, + .call_function = engine_methods->call_function, + .free_function = engine_methods->free_function, + .get_function_memory_overhead = engine_methods->get_function_memory_overhead, + .get_memory_info = engine_methods->get_memory_info, + }, + }, + .c = c, + .module_ctx = engine_module ? moduleAllocateContext() : NULL, + }; + + dictAdd(engineMgr.engines, engine_name_sds, e); + + engineMemoryInfo mem_info = scriptingEngineCallGetMemoryInfo(e); + engineMgr.total_memory_overhead += zmalloc_size(e) + + sdsAllocSize(e->name) + + mem_info.engine_memory_overhead; + + return C_OK; +} + +/* Removes a scripting engine from the engine manager. + * + * - `engine_name`: name of the engine to remove + */ +int scriptingEngineManagerUnregister(const char *engine_name) { + dictEntry *entry = dictUnlink(engineMgr.engines, engine_name); + if (entry == NULL) { + serverLog(LL_WARNING, "There's no engine registered with name %s", engine_name); + return C_ERR; + } + + scriptingEngine *e = dictGetVal(entry); + + functionsRemoveLibFromEngine(e); + + engineMemoryInfo mem_info = scriptingEngineCallGetMemoryInfo(e); + engineMgr.total_memory_overhead -= zmalloc_size(e) + + sdsAllocSize(e->name) + + mem_info.engine_memory_overhead; + + sdsfree(e->name); + freeClient(e->c); + if (e->module_ctx) { + serverAssert(e->module != NULL); + zfree(e->module_ctx); + } + zfree(e); + + dictFreeUnlinkedEntry(engineMgr.engines, entry); + + return C_OK; +} + +/* + * Lookups the engine with `engine_name` in the engine manager and returns it if + * it exists. Otherwise returns `NULL`. + */ +scriptingEngine *scriptingEngineManagerFind(sds engine_name) { + dictEntry *entry = dictFind(engineMgr.engines, engine_name); + if (entry) { + return dictGetVal(entry); + } + return NULL; +} + +sds scriptingEngineGetName(scriptingEngine *engine) { + return engine->name; +} + +client *scriptingEngineGetClient(scriptingEngine *engine) { + return engine->c; +} + +ValkeyModule *scriptingEngineGetModule(scriptingEngine *engine) { + return engine->module; +} + +/* + * Iterates the list of engines registered in the engine manager and calls the + * callback function with each engine. + * + * The `context` pointer is also passed in each callback call. + */ +void scriptingEngineManagerForEachEngine(engineIterCallback callback, + void *context) { + dictIterator *iter = dictGetIterator(engineMgr.engines); + dictEntry *entry = NULL; + while ((entry = dictNext(iter))) { + scriptingEngine *e = dictGetVal(entry); + callback(e, context); + } + dictReleaseIterator(iter); +} + +static void engineSetupModuleCtx(scriptingEngine *e, client *c) { + if (e->module != NULL) { + serverAssert(e->module_ctx != NULL); + moduleScriptingEngineInitContext(e->module_ctx, e->module, c); + } +} + +static void engineTeardownModuleCtx(scriptingEngine *e) { + if (e->module != NULL) { + serverAssert(e->module_ctx != NULL); + moduleFreeContext(e->module_ctx); + } +} + +compiledFunction **scriptingEngineCallCreateFunctionsLibrary(scriptingEngine *engine, + const char *code, + size_t timeout, + size_t *out_num_compiled_functions, + robj **err) { + engineSetupModuleCtx(engine, NULL); + + compiledFunction **functions = engine->impl.methods.create_functions_library( + engine->module_ctx, + engine->impl.ctx, + code, + timeout, + out_num_compiled_functions, + err); + + engineTeardownModuleCtx(engine); + + return functions; +} + +void scriptingEngineCallFunction(scriptingEngine *engine, + functionCtx *func_ctx, + client *caller, + void *compiled_function, + robj **keys, + size_t nkeys, + robj **args, + size_t nargs) { + engineSetupModuleCtx(engine, caller); + + engine->impl.methods.call_function( + engine->module_ctx, + engine->impl.ctx, + func_ctx, + compiled_function, + keys, + nkeys, + args, + nargs); + + engineTeardownModuleCtx(engine); +} + +void scriptingEngineCallFreeFunction(scriptingEngine *engine, + void *compiled_func) { + engineSetupModuleCtx(engine, NULL); + engine->impl.methods.free_function(engine->module_ctx, + engine->impl.ctx, + compiled_func); + engineTeardownModuleCtx(engine); +} + +size_t scriptingEngineCallGetFunctionMemoryOverhead(scriptingEngine *engine, + void *compiled_function) { + engineSetupModuleCtx(engine, NULL); + size_t mem = engine->impl.methods.get_function_memory_overhead( + engine->module_ctx, compiled_function); + engineTeardownModuleCtx(engine); + return mem; +} + +engineMemoryInfo scriptingEngineCallGetMemoryInfo(scriptingEngine *engine) { + engineSetupModuleCtx(engine, NULL); + engineMemoryInfo mem_info = engine->impl.methods.get_memory_info( + engine->module_ctx, engine->impl.ctx); + engineTeardownModuleCtx(engine); + return mem_info; +} diff --git a/src/scripting_engine.h b/src/scripting_engine.h new file mode 100644 index 0000000000..0ed49e6f88 --- /dev/null +++ b/src/scripting_engine.h @@ -0,0 +1,73 @@ +#ifndef _SCRIPTING_ENGINE_H_ +#define _SCRIPTING_ENGINE_H_ + +#include "server.h" + +// Forward declaration of the engine structure. +typedef struct scriptingEngine scriptingEngine; + +/* ValkeyModule type aliases for scripting engine structs and types. */ +typedef struct ValkeyModule ValkeyModule; +typedef ValkeyModuleScriptingEngineCtx engineCtx; +typedef ValkeyModuleScriptingEngineFunctionCtx functionCtx; +typedef ValkeyModuleScriptingEngineCompiledFunction compiledFunction; +typedef ValkeyModuleScriptingEngineMemoryInfo engineMemoryInfo; +typedef ValkeyModuleScriptingEngineMethods engineMethods; + +/* + * Callback function used to iterate the list of engines registered in the + * engine manager. + * + * - `engine`: the scripting engine in the current iteration. + * + * - `context`: a generic pointer to a context object. + * + */ +typedef void (*engineIterCallback)(scriptingEngine *engine, void *context); + +/* + * Engine manager API functions. + */ +int scriptingEngineManagerInit(void); +size_t scriptingEngineManagerGetTotalMemoryOverhead(void); +size_t scriptingEngineManagerGetNumEngines(void); +size_t scriptingEngineManagerGetMemoryUsage(void); +int scriptingEngineManagerRegister(const char *engine_name, + ValkeyModule *engine_module, + engineCtx *engine_ctx, + engineMethods *engine_methods); +int scriptingEngineManagerUnregister(const char *engine_name); +scriptingEngine *scriptingEngineManagerFind(sds engine_name); +void scriptingEngineManagerForEachEngine(engineIterCallback callback, + void *context); + +/* + * Engine API functions. + */ +sds scriptingEngineGetName(scriptingEngine *engine); +client *scriptingEngineGetClient(scriptingEngine *engine); +ValkeyModule *scriptingEngineGetModule(scriptingEngine *engine); + +/* + * API to call engine callback functions. + */ +compiledFunction **scriptingEngineCallCreateFunctionsLibrary(scriptingEngine *engine, + const char *code, + size_t timeout, + size_t *out_num_compiled_functions, + robj **err); +void scriptingEngineCallFunction(scriptingEngine *engine, + functionCtx *func_ctx, + client *caller, + void *compiled_function, + robj **keys, + size_t nkeys, + robj **args, + size_t nargs); +void scriptingEngineCallFreeFunction(scriptingEngine *engine, + void *compiled_func); +size_t scriptingEngineCallGetFunctionMemoryOverhead(scriptingEngine *engine, + void *compiled_function); +engineMemoryInfo scriptingEngineCallGetMemoryInfo(scriptingEngine *engine); + +#endif /* _SCRIPTING_ENGINE_H_ */ diff --git a/src/server.c b/src/server.c index 8255b57e25..144841eff9 100644 --- a/src/server.c +++ b/src/server.c @@ -43,6 +43,7 @@ #include "io_threads.h" #include "sds.h" #include "module.h" +#include "scripting_engine.h" #include #include @@ -2895,12 +2896,15 @@ void initServer(void) { server.maxmemory_policy = MAXMEMORY_NO_EVICTION; } + if (scriptingEngineManagerInit() == C_ERR) { + serverPanic("Scripting engine manager initialization failed, check the server logs."); + } + /* Initialize the LUA scripting engine. */ scriptingInit(1); /* Initialize the functions engine based off of LUA initialization. */ if (functionsInit() == C_ERR) { serverPanic("Functions initialization failed, check the server logs."); - exit(1); } slowlogInit(); latencyMonitorInit(); diff --git a/src/util.c b/src/util.c index 6e44392ce1..ea4f7d72d7 100644 --- a/src/util.c +++ b/src/util.c @@ -1381,23 +1381,3 @@ int snprintf_async_signal_safe(char *to, size_t n, const char *fmt, ...) { va_end(args); return result; } - -/* A printf-like function that returns a freshly allocated string. - * - * This function is similar to asprintf function, but it uses zmalloc for - * allocating the string buffer. */ -char *valkey_asprintf(char const *fmt, ...) { - va_list args; - - va_start(args, fmt); - size_t str_len = vsnprintf(NULL, 0, fmt, args) + 1; - va_end(args); - - char *str = zmalloc(str_len); - - va_start(args, fmt); - vsnprintf(str, str_len, fmt, args); - va_end(args); - - return str; -} diff --git a/src/util.h b/src/util.h index 61095ddb65..51eb38f0b4 100644 --- a/src/util.h +++ b/src/util.h @@ -99,6 +99,5 @@ int snprintf_async_signal_safe(char *to, size_t n, const char *fmt, ...); #endif size_t valkey_strlcpy(char *dst, const char *src, size_t dsize); size_t valkey_strlcat(char *dst, const char *src, size_t dsize); -char *valkey_asprintf(char const *fmt, ...); #endif diff --git a/src/valkeymodule.h b/src/valkeymodule.h index 1d99d2ff7a..c501b373fd 100644 --- a/src/valkeymodule.h +++ b/src/valkeymodule.h @@ -818,7 +818,12 @@ typedef struct ValkeyModuleScriptingEngineCompiledFunction { } ValkeyModuleScriptingEngineCompiledFunction; /* This struct is used to return the memory information of the scripting - * engine. */ + * engine. + * + * IMPORTANT: If we ever need to add/remove fields from this struct, we need + * to bump the version number defined in the + * `VALKEYMODULE_SCRIPTING_ENGINE_ABI_VERSION` constant. + */ typedef struct ValkeyModuleScriptingEngineMemoryInfo { /* The memory used by the scripting engine runtime. */ size_t used_memory; @@ -826,14 +831,55 @@ typedef struct ValkeyModuleScriptingEngineMemoryInfo { size_t engine_memory_overhead; } ValkeyModuleScriptingEngineMemoryInfo; +/* The callback function called when `FUNCTION LOAD` command is called to load + * a library of functions. + * This callback function evaluates the source code passed to `FUNCTION LOAD` + * and registers the functions declared in the source code. + * + * - `engine_ctx`: the engine specific context pointer. + * + * - `code`: string pointer to the source code. + * + * - `timeout`: timeout for the library creation (0 for no timeout). + * + * - `out_num_compiled_functions`: out param with the number of objects + * returned by this function. + * + * - `err` - out param with the description of error (if occurred). + * + * Returns an array of compiled function objects, or `NULL` if some error + * occurred. + */ typedef ValkeyModuleScriptingEngineCompiledFunction **(*ValkeyModuleScriptingEngineCreateFunctionsLibraryFunc)( ValkeyModuleCtx *module_ctx, ValkeyModuleScriptingEngineCtx *engine_ctx, const char *code, size_t timeout, size_t *out_num_compiled_functions, - char **err); + ValkeyModuleString **err); +/* The callback function called when `FCALL` command is called on a function + * registered in the scripting engine. + * This callback function executes the `compiled_function` code. + * + * - `module_ctx`: the module runtime context. + * + * - `engine_ctx`: the engine specific context pointer. + * + * - `func_ctx`: the context opaque structure that represents the runtime + * context for the function. + * + * - `compiled_function`: pointer to the compiled function registered by the + * engine. + * + * - `keys`: the array of key strings passed in the `FCALL` command. + * + * - `nkeys`: the number of elements present in the `keys` array. + * + * - `args`: the array of string arguments passed in the `FCALL` command. + * + * - `nargs`: the number of elements present in the `args` array. + */ typedef void (*ValkeyModuleScriptingEngineCallFunctionFunc)( ValkeyModuleCtx *module_ctx, ValkeyModuleScriptingEngineCtx *engine_ctx, @@ -844,10 +890,15 @@ typedef void (*ValkeyModuleScriptingEngineCallFunctionFunc)( ValkeyModuleString **args, size_t nargs); + +/* Return memory overhead for a given function, such memory is not counted as + * engine memory but as general structs memory that hold different information + */ typedef size_t (*ValkeyModuleScriptingEngineGetFunctionMemoryOverheadFunc)( ValkeyModuleCtx *module_ctx, void *compiled_function); +/* Free the given function */ typedef void (*ValkeyModuleScriptingEngineFreeFunctionFunc)( ValkeyModuleCtx *module_ctx, ValkeyModuleScriptingEngineCtx *engine_ctx, @@ -865,13 +916,13 @@ typedef struct ValkeyModuleScriptingEngineMethodsV1 { * ValkeyModuleScriptingEngineCompiledFunc objects. */ ValkeyModuleScriptingEngineCreateFunctionsLibraryFunc create_functions_library; - /* Function callback to free the memory of a registered engine function. */ - ValkeyModuleScriptingEngineFreeFunctionFunc free_function; - /* The callback function called when `FCALL` command is called on a function * registered in this engine. */ ValkeyModuleScriptingEngineCallFunctionFunc call_function; + /* Function callback to free the memory of a registered engine function. */ + ValkeyModuleScriptingEngineFreeFunctionFunc free_function; + /* Function callback to return memory overhead for a given function. */ ValkeyModuleScriptingEngineGetFunctionMemoryOverheadFunc get_function_memory_overhead; diff --git a/tests/modules/helloscripting.c b/tests/modules/helloscripting.c index c912164bda..5a34e89f68 100644 --- a/tests/modules/helloscripting.c +++ b/tests/modules/helloscripting.c @@ -72,6 +72,7 @@ typedef struct HelloFunc { char *name; HelloInst instructions[256]; uint32_t num_instructions; + uint32_t index; } HelloFunc; /* @@ -151,8 +152,9 @@ static void helloLangParseArgs(HelloFunc *func) { /* * Parses an HELLO program source code. */ -static HelloProgram *helloLangParseCode(const char *code, - HelloProgram *program) { +static int helloLangParseCode(const char *code, + HelloProgram *program, + ValkeyModuleString **err) { char *_code = ValkeyModule_Alloc(sizeof(char) * strlen(code) + 1); strcpy(_code, code); @@ -171,6 +173,7 @@ static HelloProgram *helloLangParseCode(const char *code, ValkeyModule_Assert(currentFunc == NULL); currentFunc = ValkeyModule_Alloc(sizeof(HelloFunc)); memset(currentFunc, 0, sizeof(HelloFunc)); + currentFunc->index = program->num_functions; program->functions[program->num_functions++] = currentFunc; helloLangParseFunction(currentFunc); break; @@ -188,7 +191,9 @@ static HelloProgram *helloLangParseCode(const char *code, currentFunc = NULL; break; default: - ValkeyModule_Assert(0); + *err = ValkeyModule_CreateStringPrintf(NULL, "Failed to parse instruction: '%s'", token); + ValkeyModule_Free(_code); + return -1; } token = strtok(NULL, " \n"); @@ -196,7 +201,7 @@ static HelloProgram *helloLangParseCode(const char *code, ValkeyModule_Free(_code); - return program; + return 0; } /* @@ -223,6 +228,7 @@ static uint32_t executeHelloLangFunction(HelloFunc *func, break; } case RETURN: { + ValkeyModule_Assert(sp > 0); uint32_t val = stack[--sp]; ValkeyModule_Assert(sp == 0); return val; @@ -248,8 +254,10 @@ static ValkeyModuleScriptingEngineMemoryInfo engineGetMemoryInfo(ValkeyModuleCtx for (uint32_t i = 0; i < ctx->program->num_functions; i++) { HelloFunc *func = ctx->program->functions[i]; - mem_info.used_memory += ValkeyModule_MallocSize(func); - mem_info.used_memory += ValkeyModule_MallocSize(func->name); + if (func != NULL) { + mem_info.used_memory += ValkeyModule_MallocSize(func); + mem_info.used_memory += ValkeyModule_MallocSize(func->name); + } } } @@ -273,7 +281,9 @@ static void engineFreeFunction(ValkeyModuleCtx *module_ctx, void *compiled_function) { VALKEYMODULE_NOT_USED(module_ctx); VALKEYMODULE_NOT_USED(engine_ctx); + HelloLangCtx *ctx = (HelloLangCtx *)engine_ctx; HelloFunc *func = (HelloFunc *)compiled_function; + ctx->program->functions[func->index] = NULL; ValkeyModule_Free(func->name); func->name = NULL; ValkeyModule_Free(func); @@ -284,7 +294,7 @@ static ValkeyModuleScriptingEngineCompiledFunction **createHelloLangEngine(Valke const char *code, size_t timeout, size_t *out_num_compiled_functions, - char **err) { + ValkeyModuleString **err) { VALKEYMODULE_NOT_USED(module_ctx); VALKEYMODULE_NOT_USED(timeout); VALKEYMODULE_NOT_USED(err); @@ -298,7 +308,17 @@ static ValkeyModuleScriptingEngineCompiledFunction **createHelloLangEngine(Valke ctx->program->num_functions = 0; } - ctx->program = helloLangParseCode(code, ctx->program); + int ret = helloLangParseCode(code, ctx->program, err); + if (ret < 0) { + for (uint32_t i = 0; i < ctx->program->num_functions; i++) { + HelloFunc *func = ctx->program->functions[i]; + ValkeyModule_Free(func->name); + ValkeyModule_Free(func); + ctx->program->functions[i] = NULL; + } + ctx->program->num_functions = 0; + return NULL; + } ValkeyModuleScriptingEngineCompiledFunction **compiled_functions = ValkeyModule_Alloc(sizeof(ValkeyModuleScriptingEngineCompiledFunction *) * ctx->program->num_functions); @@ -341,7 +361,8 @@ callHelloLangFunction(ValkeyModuleCtx *module_ctx, ValkeyModule_ReplyWithLongLong(module_ctx, result); } -int ValkeyModule_OnLoad(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, +int ValkeyModule_OnLoad(ValkeyModuleCtx *ctx, + ValkeyModuleString **argv, int argc) { VALKEYMODULE_NOT_USED(argv); VALKEYMODULE_NOT_USED(argc); diff --git a/tests/unit/moduleapi/scriptingengine.tcl b/tests/unit/moduleapi/scriptingengine.tcl index c350633dd8..3a37339ea8 100644 --- a/tests/unit/moduleapi/scriptingengine.tcl +++ b/tests/unit/moduleapi/scriptingengine.tcl @@ -51,6 +51,10 @@ start_server {tags {"modules"}} { assert_error {ERR Function already exists in the library} {r function load "#!hello name=mylib2\nFUNCTION foo\nARGS 0\nRETURN\nFUNCTION foo\nARGS 0\nRETURN"} } + test {Load script with syntax error} { + assert_error {ERR Failed to parse instruction: 'SEND'} {r function load replace "#!hello name=mylib3\nFUNCTION foo\nARGS 0\nSEND"} + } + test {Call scripting engine function: calling foo works} { r fcall foo 0 134 } {134}