diff --git a/lua/lazydev/buf.lua b/lua/lazydev/buf.lua index afc45b2..d0f7f38 100644 --- a/lua/lazydev/buf.lua +++ b/lua/lazydev/buf.lua @@ -92,24 +92,13 @@ end ---@param first number ---@param last number function M.on_lines(buf, first, last) - if -- fast exit when no line contains "require" in the range - #vim.tbl_filter(function(line) - return line:find("require", 1, true) - end, vim.api.nvim_buf_get_lines(buf, first, last, false)) == 0 - then - return - end - - -- Find require calls in the range - local parser = vim.treesitter.get_parser(buf) local changes = {} ---@type string[] - for id, node in M.query:iter_captures(parser:trees()[1]:root(), buf, first, last) do - local capture = M.query.captures[id] - if capture == "modname" then - local text = vim.treesitter.get_node_text(node, buf) - if M.modules[text] == nil then - changes[#changes + 1] = text - end + + local lines = vim.api.nvim_buf_get_lines(buf, first, last, false) + for _, line in ipairs(lines) do + local module = Pkg.get_module(line) + if module then + changes[#changes + 1] = module end end diff --git a/lua/lazydev/pkg.lua b/lua/lazydev/pkg.lua index f970a04..35738f1 100644 --- a/lua/lazydev/pkg.lua +++ b/lua/lazydev/pkg.lua @@ -1,6 +1,13 @@ ---@class lazydev.Pkg local M = {} +M.PAT_MODULE_BASE = "%-%-%-%s*@module%s*[\"']([%w%.%-_]+)" +M.PAT_REQUIRE_BASE = "require%s*%(?%s*['\"]([%w%.%-_]+)" +M.PAT_MODULE_BEFORE = M.PAT_MODULE_BASE .. "$" +M.PAT_REQUIRE_BEFORE = M.PAT_REQUIRE_BASE .. "$" +M.PAT_MODULE = M.PAT_MODULE_BASE .. "[\"']" +M.PAT_REQUIRE = M.PAT_REQUIRE_BASE .. "[\"']" + local is_lazy = type(package.loaded.lazy) == "table" ---@param modname string @@ -38,4 +45,25 @@ end M.get_unloaded = is_lazy and M.lazy_unloaded or M.pack_unloaded +--- Get the module name from a line, +--- either `---@module "modname"` or `require "modname"` +---@param line string +---@param opts? {before?:boolean} +---@return string? +function M.get_module(line, opts) + local patterns = opts and opts.before and { + M.PAT_MODULE_BEFORE, + M.PAT_REQUIRE_BEFORE, + } or { + M.PAT_MODULE, + M.PAT_REQUIRE, + } + for _, pat in ipairs(patterns) do + local match = line:match(pat) + if match then + return match + end + end +end + return M