Skip to content

Commit

Permalink
lua/Class: add WrapMethod()
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxKellermann committed Dec 4, 2023
1 parent 7a8fe67 commit 63bd1f8
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 74 deletions.
16 changes: 16 additions & 0 deletions src/lua/Class.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,22 @@ struct Class {
return *(pointer)luaL_checkudata(L, idx, name);
}

/**
* Generate a wrapper function which invokes Cast() and calls
* the specified method.
*/
template<auto method>
static constexpr lua_CFunction WrapMethod() noexcept
requires std::is_class_v<T> &&
std::is_member_function_pointer_v<decltype(method)> {
static_assert(std::is_same_v<decltype(method), int (T::*)(lua_State *)>);

return [](lua_State *L) {
reference object = Cast(L, 1);
return (object.*method)(L);
};
}

private:
static int l_gc(lua_State *L) {
const ScopeCheckStack check_stack(L);
Expand Down
29 changes: 9 additions & 20 deletions src/lua/io/XattrTable.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -19,42 +19,31 @@ class XattrTable final {
explicit XattrTable(UniqueFileDescriptor &&_fd) noexcept
:fd(std::move(_fd)) {}

static int Close(lua_State *L);
static int Index(lua_State *L);

private:
int _Close() {
fd.Close();
return 0;
}

int _Index(lua_State *L, const char *name);
int Close(lua_State *L);
int Index(lua_State *L);
};

static constexpr char lua_xattr_table_class[] = "io.XattrTable";
using XattrTableClass = Lua::Class<XattrTable, lua_xattr_table_class>;

int
inline int
XattrTable::Close(lua_State *L)
{
if (lua_gettop(L) != 1)
return luaL_error(L, "Invalid parameters");

return XattrTableClass::Cast(L, 1)._Close();
fd.Close();
return 0;
}

int
inline int
XattrTable::Index(lua_State *L)
{
if (lua_gettop(L) != 2)
return luaL_error(L, "Invalid parameters");

return XattrTableClass::Cast(L, 1)._Index(L, luaL_checkstring(L, 2));
}
const char *const name = luaL_checkstring(L, 2);

inline int
XattrTable::_Index(lua_State *L, const char *name)
{
if (!fd.IsDefined())
luaL_error(L, "Stale object");

Expand Down Expand Up @@ -86,8 +75,8 @@ void
InitXattrTable(lua_State *L) noexcept
{
XattrTableClass::Register(L);
SetField(L, RelativeStackIndex{-1}, "__index", XattrTable::Index);
SetField(L, RelativeStackIndex{-1}, "__close", XattrTable::Close);
SetField(L, RelativeStackIndex{-1}, "__index", XattrTableClass::WrapMethod<&XattrTable::Index>());
SetField(L, RelativeStackIndex{-1}, "__close", XattrTableClass::WrapMethod<&XattrTable::Close>());
lua_pop(L, 1);
}

Expand Down
60 changes: 23 additions & 37 deletions src/lua/pg/Connection.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -125,28 +125,25 @@ class PgConnection final : Pg::SharedConnectionHandler {
return connection.GetEventLoop();
}

private:
static int Execute(lua_State *L);
int Execute(lua_State *L, int sql, int params);
static int Listen(lua_State *L);
int Listen(lua_State *L, int name_idx, int handler_idx);
int Execute(lua_State *L);
int Listen(lua_State *L);

private:
/* virtual methods from class Pg::SharedConnectionHandler */
void OnPgConnect() override;
void OnPgNotify(const char *name) override;
void OnPgError(std::exception_ptr e) noexcept override;

public:
static constexpr struct luaL_Reg methods [] = {
{"execute", Execute},
{"listen", Listen},
{nullptr, nullptr}
};
};

static constexpr char lua_pg_connection_class[] = "pg.Connection";
using PgConnectionClass = Lua::Class<PgConnection, lua_pg_connection_class>;

static constexpr struct luaL_Reg lua_pg_connection_methods[] = {
{"execute", PgConnectionClass::WrapMethod<&PgConnection::Execute>()},
{"listen", PgConnectionClass::WrapMethod<&PgConnection::Listen>()},
{nullptr, nullptr}
};

class PgRequest final
: public Pg::SharedConnectionQuery, Pg::AsyncResultHandler
{
Expand Down Expand Up @@ -243,15 +240,6 @@ static constexpr char lua_pg_request_class[] = "pg.Request";
using PgRequestClass = Lua::Class<PgRequest, lua_pg_request_class>;

inline int
PgConnection::Execute(lua_State *L, int sql, int params)
{
auto *request = PgRequestClass::New(L, L, connection,
sql, params);
connection.ScheduleQuery(*request);
return lua_yield(L, 1);
}

int
PgConnection::Execute(lua_State *L)
{
if (lua_gettop(L) < 2)
Expand All @@ -268,13 +256,23 @@ PgConnection::Execute(lua_State *L)
params = 3;
}

auto &connection = PgConnectionClass::Cast(L, 1);
return connection.Execute(L, sql, params);
auto *request = PgRequestClass::New(L, L, connection,
sql, params);
connection.ScheduleQuery(*request);
return lua_yield(L, 1);
}

inline int
PgConnection::Listen(lua_State *L, int name_idx, int handler_idx)
PgConnection::Listen(lua_State *L)
{
if (lua_gettop(L) < 3)
return luaL_error(L, "Not enough parameters");
if (lua_gettop(L) > 3)
return luaL_error(L, "Too many parameters");

constexpr int name_idx = 2;
constexpr int handler_idx = 3;

const char *name = luaL_checkstring(L, name_idx);
luaL_checktype(L, 3, LUA_TFUNCTION);

Expand All @@ -290,18 +288,6 @@ PgConnection::Listen(lua_State *L, int name_idx, int handler_idx)
return 0;
}

int
PgConnection::Listen(lua_State *L)
{
if (lua_gettop(L) < 3)
return luaL_error(L, "Not enough parameters");
if (lua_gettop(L) > 3)
return luaL_error(L, "Too many parameters");

auto &connection = PgConnectionClass::Cast(L, 1);
return connection.Listen(L, 2, 3);
}

void
PgConnection::OnPgConnect()
{
Expand Down Expand Up @@ -457,7 +443,7 @@ void
InitPgConnection(lua_State *L) noexcept
{
PgConnectionClass::Register(L);
luaL_newlib(L, PgConnection::methods);
luaL_newlib(L, lua_pg_connection_methods);
lua_setfield(L, -2, "__index");
lua_pop(L, 1);

Expand Down
24 changes: 7 additions & 17 deletions src/lua/pg/Result.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -23,29 +23,19 @@ class PgResult final {
explicit PgResult(Pg::Result &&_result) noexcept
:result(std::move(_result)) {}

private:
static int Fetch(lua_State *L);
int _Fetch(lua_State *L);

public:
static constexpr struct luaL_Reg methods [] = {
{"fetch", Fetch},
{nullptr, nullptr}
};
int Fetch(lua_State *L);
};

static constexpr char lua_pg_result_class[] = "pg.Result";
using PgResultClass = Lua::Class<PgResult, lua_pg_result_class>;

int
PgResult::Fetch(lua_State *L)
{
auto &result = PgResultClass::Cast(L, 1);
return result._Fetch(L);
}
static constexpr struct luaL_Reg lua_pg_result_methods [] = {
{"fetch", PgResultClass::WrapMethod<&PgResult::Fetch>()},
{nullptr, nullptr}
};

inline int
PgResult::_Fetch(lua_State *L)
PgResult::Fetch(lua_State *L)
{
if (next_row >= result.GetRowCount())
return 0;
Expand Down Expand Up @@ -86,7 +76,7 @@ void
InitPgResult(lua_State *L) noexcept
{
PgResultClass::Register(L);
luaL_newlib(L, PgResult::methods);
luaL_newlib(L, lua_pg_result_methods);
lua_setfield(L, -2, "__index");
lua_pop(L, 1);
}
Expand Down

0 comments on commit 63bd1f8

Please sign in to comment.