diff --git a/README.md b/README.md index 431100fd..a7321905 100644 --- a/README.md +++ b/README.md @@ -181,6 +181,14 @@ For extra features neotest provides consumers which interact with the state of t Some consumers will be passive while others can be interacted with. +### Watch Tests + +`:h neotest.watch` + +Watches files related to tests for changes and re-runs tests + +https://user-images.githubusercontent.com/24252670/229367494-6775d7f1-a8fb-461b-bbbd-d6124031293e.mp4 + ### Output Window `:h neotest.output` @@ -224,6 +232,8 @@ Displays the status of a test/namespace beside the beginning of the definition. ![image](https://user-images.githubusercontent.com/24252670/166143402-b318ef91-c053-4973-b929-5ee97572f2c2.png) +See the help doc for a list of all consumers and their documentation. + ## Strategies Strategies are methods of running tests. They provide the functionality to attach to running processes and so attaching diff --git a/doc/neotest.txt b/doc/neotest.txt index da7ef75e..93da8b1a 100644 --- a/doc/neotest.txt +++ b/doc/neotest.txt @@ -9,6 +9,7 @@ neotest *neotest* Output Consumer.............................................|neotest.output| Output Panel Consumer.................................|neotest.output_panel| Run Consumer...................................................|neotest.run| + Watch Consumer...............................................|neotest.watch| Status Consumer.............................................|neotest.status| Diagnostic Consumer.....................................|neotest.diagnostic| Summary Consumer...........................................|neotest.summary| @@ -101,7 +102,8 @@ Default values: skipped = "NeotestSkipped", target = "NeotestTarget", test = "NeotestTest", - unknown = "NeotestUnknown" + unknown = "NeotestUnknown", + watching = "NeotestWatching" }, icons = { child_indent = "│", @@ -116,7 +118,8 @@ Default values: running = "", running_animated = { "/", "|", "\\", "-", "/", "|", "\\", "-" }, skipped = "", - unknown = "" + unknown = "", + watching = "" }, jump = { enabled = true @@ -177,9 +180,18 @@ Default values: run_marked = "R", short = "O", stop = "u", - target = "t" + target = "t", + watch = "w" }, open = "botright vsplit | vertical resize 50" + }, + watch = { + enabled = true, + symbol_queries = { + elixir = , + lua = ' ;query\n ;Captures module names in require calls\n (function_call\n name: ((identifier) @function (#eq? @function "require"))\n arguments: (arguments (string) @symbol))\n ', + python = " ;query\n ;Captures imports and modules they're imported from\n (import_from_statement (_ (identifier) @symbol))\n (import_statement (_ (identifier) @symbol))\n " + } } } < @@ -219,6 +231,7 @@ Fields~ {quickfix} `(neotest.Config.quickfix)` {status} `(neotest.Config.status)` {state} `(neotest.Config.state)` +{watch} `(neotest.Config.watch)` {diagnostic} `(neotest.Config.diagnostic)` {projects} `(table)` Project specific settings, keys @@ -230,9 +243,8 @@ Fields~ {concurrent} `(integer)` Number of workers to parse files concurrently. 0 automatically assigns number based on CPU. Set to 1 if experiencing lag. {filter_dir} `(nil)` | fun(name: string, rel_path: string, root: string): -boolean -A function to filter directories when searching for test files. Receives the name, -path relative to project root and project root path +boolean A function to filter directories when searching for test files. +Receives the name, path relative to project root and project root path *neotest.Config.running* Fields~ @@ -293,6 +305,7 @@ for its adapter adapter {next_failed} `(string|string[])` Jump to the next failed position {prev_failed} `(string|string[])` Jump to the previous failed position +{watch} `(string|string[])` Toggle watching for changes *neotest.Config.output* Fields~ @@ -327,6 +340,17 @@ Fields~ {virtual_text} `(boolean)` Display status using virtual text {signs} `(boolean)` Display status using signs + *neotest.Config.watch* +Fields~ +{enabled} `(boolean)` +{symbol_queries} `(table)` Treesitter queries or functions to capture symbols that +are used for querying the LSP server for defintions to link files. If it is a +function then the return value should be a list of node ranges. +{filter_path?} `(fun(path: string, root: string): boolean)` Returns whether +the watcher should inspect a path for dependencies. Default ignores paths not +under root or common package manager directories. + ============================================================================== neotest.consumers *neotest.consumers* @@ -433,6 +457,7 @@ A consumer providing a simple interface to run tests. Inherits: `neotest.client.RunTreeArgs` Fields~ +{[1]} `(string?)` Position ID to run {suite} `(boolean)` Run the entire suite instead of a single position *neotest.run.run()* @@ -508,13 +533,6 @@ Parameters~ args then args[1] should be the position ID. - *neotest.run.adapters()* -`adapters`() - -Get the list of all known adapter IDs. -Return~ -`(string[])` - *neotest.run.get_last_run()* `get_last_run`() @@ -525,6 +543,65 @@ Return~ `(neotest.run.RunArgs|nil)` args +============================================================================== +neotest.watch *neotest.watch* + + +Allows watching tests and re-running them whenever related files are +changed. When watching a directory, all files are run in separate processes. +Otherwise the tests are run in the same process (if allowed by the adapter). + +Related files are determined through an LSP client through a "best effort" +which means there are cases where a file may not be determined as related +despite it having an effect on a test. + +To determine file relationships, a treesitter query is used to find symbols +that are queried for using the `textDocument/definition` LSP request. The +query can be configured through the watch consumer's config. Any captures +named `symbol` will be used. If your language is not present in the default +config, please submit a PR to add support out of the box! + + *neotest.watch.watch()* +`watch`({args}) + +Watch a position and run it whenever related files are changed. +Arguments are the same as the `neotest.run.run`, which allows +for custom runner arguments, env vars, strategy etc. If a position is +already being watched, the existing watcher will be stopped. +Parameters~ +{args?} `(neotest.run.RunArgs|string)` + + *neotest.watch.toggle()* +`toggle`({args}) + +Toggle watching a position and run it whenever related files are changed. +Arguments are the same as the `neotest.run.run`, which allows +for custom runner arguments, env vars, strategy etc. + +Toggle watching the current file +>vim + lua require("neotest").watch.toggle(vim.fn.expand("%")) +< +Parameters~ +{args?} `(neotest.run.RunArgs|string)` + + *neotest.watch.stop()* +`stop`({position_id}) + +Stop watching a position. If no position is provided, all watched positions are stopped. +Parameters~ +{position_id} `(string)` + + *neotest.watch.is_watching()* +`is_watching`({position_id}) + +Check if a position is being watched. +Parameters~ +{position_id} `(string)` +Return~ +`(boolean)` + + ============================================================================== neotest.status *neotest.status* @@ -575,9 +652,9 @@ Close the summary window the summary window - +>vim lua require("neotest").summary.toggle() - +< *neotest.summmary.RunMarkedArgs* Inherits: `neotest.run.RunArgs` @@ -1400,17 +1477,26 @@ Return~ Return~ `(neotest.Tree)` + *neotest.types.tree.IterNodesArgs* +Fields~ +{continue} `(fun(node: neotest.Tree): boolean)` A predicate for if the given +node's children should be iterated over. Defaults to `true`. + *neotest.Tree:iter_nodes()* -`Tree:iter_nodes`() +`Tree:iter_nodes`({args}) +Parameters~ +{args?} `(neotest.types.tree.IterNodesArgs)` Return~ -`(fun(): integer,neotest.Tree)` +`(fun():integer,neotest.Tree)` *neotest.Tree:iter()* -`Tree:iter`() +`Tree:iter`({args}) +Parameters~ +{args?} `(neotest.types.tree.IterNodesArgs)` Return~ -`(fun(): integer,neotest.Position)` +`(fun():integer,neotest.Position)` *neotest.Tree:node()* `Tree:node`({index}) diff --git a/lua/neotest/config/init.lua b/lua/neotest/config/init.lua index 78ebde8a..462e321f 100644 --- a/lua/neotest/config/init.lua +++ b/lua/neotest/config/init.lua @@ -1,3 +1,4 @@ +local lib = require("neotest.lib") ---@tag neotest.config ---@toc_entry Configuration Options @@ -7,7 +8,7 @@ local function define_highlights() hi default NeotestFailed ctermfg=Red guifg=#F70067 hi default NeotestRunning ctermfg=Yellow guifg=#FFEC63 hi default NeotestSkipped ctermfg=Cyan guifg=#00f1f5 - hi default link NeotestTest Normal + hi default link NeotestTest Normal hi default NeotestNamespace ctermfg=Magenta guifg=#D484FF hi default NeotestFocused gui=bold,underline cterm=bold,underline hi default NeotestFile ctermfg=Cyan guifg=#00f1f5 @@ -18,6 +19,7 @@ local function define_highlights() hi default NeotestWinSelect ctermfg=Cyan guifg=#00f1f5 gui=bold hi default NeotestMarked ctermfg=Brown guifg=#F79000 gui=bold hi default NeotestTarget ctermfg=Red guifg=#F70067 + hi default NeotestWatching ctermfg=Yellow guifg=#FFEC63 hi default link NeotestUnknown Normal ]]) end @@ -45,17 +47,15 @@ define_highlights() ---@field quickfix neotest.Config.quickfix ---@field status neotest.Config.status ---@field state neotest.Config.state +---@field watch neotest.Config.watch ---@field diagnostic neotest.Config.diagnostic ---@field projects table Project specific settings, keys --- are project root directories (e.g "~/Dev/my_project") ---@class neotest.Config.discovery ---@field enabled boolean ----@field concurrent integer Number of workers to parse files concurrently. 0 ---- automatically assigns number based on CPU. Set to 1 if experiencing lag. ----@field filter_dir nil | fun(name: string, rel_path: string, root: string): boolean ---- A function to filter directories when searching for test files. Receives the name, ---- path relative to project root and project root path +---@field concurrent integer Number of workers to parse files concurrently. 0 automatically assigns number based on CPU. Set to 1 if experiencing lag. +---@field filter_dir nil | fun(name: string, rel_path: string, root: string): boolean A function to filter directories when searching for test files. Receives the name, path relative to project root and project root path ---@class neotest.Config.running ---@field concurrent boolean Run tests concurrently when an adapter provides multiple commands to run @@ -100,6 +100,7 @@ define_highlights() ---@field clear_target string|string[] Clear the target position for the selected adapter ---@field next_failed string|string[] Jump to the next failed position ---@field prev_failed string|string[] Jump to the previous failed position +---@field watch string|string[] Toggle watching for changes ---@class neotest.Config.output ---@field enabled boolean @@ -126,6 +127,11 @@ define_highlights() ---@field virtual_text boolean Display status using virtual text ---@field signs boolean Display status using signs +---@class neotest.Config.watch +---@field enabled boolean +---@field symbol_queries table Treesitter queries or functions to capture symbols that are used for querying the LSP server for defintions to link files. If it is a function then the return value should be a list of node ranges. +---@field filter_path? fun(path: string, root: string): boolean Returns whether the watcher should inspect a path for dependencies. Default ignores paths not under root or common package manager directories. + ---@private ---@type neotest.Config local default_config = { @@ -167,6 +173,7 @@ local default_config = { final_child_prefix = "╰", child_indent = "│", final_child_indent = " ", + watching = "", }, highlights = { passed = "NeotestPassed", @@ -186,6 +193,7 @@ local default_config = { marked = "NeotestMarked", target = "NeotestTarget", unknown = "NeotestUnknown", + watching = "NeotestWatching", }, floating = { border = "rounded", @@ -224,6 +232,7 @@ local default_config = { clear_target = "T", next_failed = "J", prev_failed = "K", + watch = "w", }, }, benchmark = { @@ -259,6 +268,54 @@ local default_config = { state = { enabled = true, }, + watch = { + enabled = true, + symbol_queries = { + python = [[ + ;query + ;Captures imports and modules they're imported from + (import_from_statement (_ (identifier) @symbol)) + (import_statement (_ (identifier) @symbol)) + ]], + lua = [[ + ;query + ;Captures module names in require calls + (function_call + name: ((identifier) @function (#eq? @function "require")) + arguments: (arguments (string) @symbol)) + ]], + elixir = function(root, content) + local query = lib.treesitter.normalise_query( + "elixir", + [[;; query + (call (identifier) @_func_name + (arguments (alias) @symbol) + (#match? @_func_name "^(alias|require|import|use)") + (#gsub! @symbol ".*%.(.*)" "%1") + ) + ]] + ) + local symbols = {} + for _, match, metadata in query:iter_matches(root, content) do + for id, node in pairs(match) do + local name = query.captures[id] + + if name == "symbol" then + local start_row, start_col, end_row, end_col = node:range() + if metadata[id] ~= nil then + local real_symbol_length = string.len(metadata[id]["text"]) + start_col = end_col - real_symbol_length + end + + symbols[#symbols + 1] = { start_row, start_col, end_row, end_col } + end + end + end + return symbols + end, + }, + filter_path = nil, + }, projects = {}, } diff --git a/lua/neotest/consumers/init.lua b/lua/neotest/consumers/init.lua index fc5c1999..4a0d3845 100644 --- a/lua/neotest/consumers/init.lua +++ b/lua/neotest/consumers/init.lua @@ -16,11 +16,11 @@ local neotest = {} --- The client interface provides methods for interacting with tests, fetching --- results as well as event listeners. To listen to an event, just assign the event --- listener to a function: ---- >lua +--- ```lua --- client.listeners.discover_positions = function (adapter_id, tree) --- ... --- end ---- < +--- ``` --- Available events and the listener signatures are visible as properties on the --- `client.listeners` table --- @@ -38,6 +38,7 @@ neotest.consumers = { benchmark = require("neotest.consumers.benchmark"), quickfix = require("neotest.consumers.quickfix"), state = require("neotest.consumers.state"), + watch = require("neotest.consumers.watch"), } return neotest.consumers diff --git a/lua/neotest/consumers/jump.lua b/lua/neotest/consumers/jump.lua index f113ff56..a656ea1b 100644 --- a/lua/neotest/consumers/jump.lua +++ b/lua/neotest/consumers/jump.lua @@ -12,10 +12,10 @@ local neotest = {} --- A consumer that allows jumping between tests --- --- Example mappings to jump between test failures ---- >vim +--- ```vim --- nnoremap [n lua require("neotest").jump.prev({ status = "failed" }) --- nnoremap ]n lua require("neotest").jump.next({ status = "failed" }) ---- < +--- ``` ---@class neotest.consumers.jump neotest.jump = {} diff --git a/lua/neotest/consumers/output.lua b/lua/neotest/consumers/output.lua index 456da0d3..8d4db79d 100644 --- a/lua/neotest/consumers/output.lua +++ b/lua/neotest/consumers/output.lua @@ -145,11 +145,11 @@ end ---@field auto_close boolean Close output window when leaving it, or when cursor moves outside of window --- Open the output of a test result ---- >vim +--- ```vim --- lua require("neotest").output.open({ enter = true }) ---- < +--- ``` ---@param opts? neotest.consumers.output.OpenArgs -neotest.output.open = nio.create(function(opts) +function neotest.output.open(opts) opts = opts or {} if win then if opts.short ~= short_opened then @@ -188,7 +188,9 @@ neotest.output.open = nio.create(function(opts) return end open_output(result, opts) -end, 1) +end + +neotest.output.open = nio.create(neotest.output.open, 1) neotest.output = setmetatable(neotest.output, { __call = function(_, client_) diff --git a/lua/neotest/consumers/output_panel/init.lua b/lua/neotest/consumers/output_panel/init.lua index 1104f7ad..e58055ea 100644 --- a/lua/neotest/consumers/output_panel/init.lua +++ b/lua/neotest/consumers/output_panel/init.lua @@ -56,25 +56,25 @@ local init = function(client) end --- Open the output panel ---- >vim +--- ```vim --- lua require("neotest").output_panel.open() ---- < +--- ``` function neotest.output_panel.open() panel.win:open() end --- Close the output panel ---- >vim +--- ```vim --- lua require("neotest").output_panel.close() ---- < +--- ``` function neotest.output_panel.close() panel.win:close() end --- Toggle the output panel ---- >vim +--- ```vim --- lua require("neotest").output_panel.toggle() ---- < +--- ``` function neotest.output_panel.toggle() if panel.win:is_open() then neotest.output_panel.close() diff --git a/lua/neotest/consumers/run.lua b/lua/neotest/consumers/run.lua index 67e1b635..7a67676a 100644 --- a/lua/neotest/consumers/run.lua +++ b/lua/neotest/consumers/run.lua @@ -5,7 +5,6 @@ local lib = require("neotest.lib") ---@type neotest.Client local client local last_run -local client_ready = false local neotest = {} @@ -15,8 +14,13 @@ local neotest = {} ---@class neotest.consumers.run neotest.run = {} ----@private +---@package +---@nodoc function neotest.run.get_tree_from_args(args, store) + args = args or {} + if type(args) == "string" then + args = { args } + end local tree, adapter = (function() if args.suite then if not args.adapter then @@ -39,54 +43,53 @@ function neotest.run.get_tree_from_args(args, store) end ---@class neotest.run.RunArgs : neotest.client.RunTreeArgs +---@field [1] string? Position ID to run ---@field suite boolean Run the entire suite instead of a single position --- Run the given position or the nearest position if not given. --- All arguments are optional --- --- Run the current file ---- >vim +--- ```vim --- lua require("neotest").run.run(vim.fn.expand("%")) ---- < +--- ``` --- --- Run the nearest test ---- >vim +--- ```vim --- lua require("neotest").run.run() ---- < +--- ``` --- --- Debug the current file with nvim-dap ---- >vim +--- ```vim --- lua require("neotest").run.run({vim.fn.expand("%"), strategy = "dap"}) ---- < +--- ``` ---@param args string|neotest.run.RunArgs? Position ID to run or args. -neotest.run.run = nio.create(function(args) - args = args or {} - if type(args) == "string" then - args = { args } - end +function neotest.run.run(args) local tree = neotest.run.get_tree_from_args(args, true) if not tree then lib.notify("No tests found") return end client:run_tree(tree, args) -end, 1) +end + +neotest.run.run = nio.create(neotest.run.run, 1) --- Re-run the last position that was run. --- Arguments are optional --- --- Run the last position that was run with the same arguments and strategy ---- >vim +--- ```vim --- lua require("neotest").run.run_last() ---- < +--- ``` --- --- Run the last position that was run with the same arguments but debug with --- nvim-dap ---- >vim +--- ```vim --- lua require("neotest").run.run_last({ strategy = "dap" }) ---- < +--- ``` ---@param args neotest.run.RunArgs? Argument overrides -neotest.run.run_last = nio.create(function(args) +function neotest.run.run_last(args) args = args or {} if not last_run then lib.notify("No tests run yet") @@ -102,7 +105,9 @@ neotest.run.run_last = nio.create(function(args) end client:run_tree(tree, args) end) -end, 1) +end + +neotest.run.run_last = nio.create(neotest.run.run_last, 1) local function get_tree_interactive() local running = client:running_positions() @@ -125,7 +130,7 @@ end --- ---@param args string|neotest.run.StopArgs? Position ID to stop or args. If --- args then args[1] should be the position ID. -neotest.run.stop = nio.create(function(args) +function neotest.run.stop(args) args = args or {} if type(args) == "string" then args = { args } @@ -141,7 +146,8 @@ neotest.run.stop = nio.create(function(args) return end client:stop(pos, args) -end, 1) +end +neotest.run.stop = nio.create(neotest.run.stop, 1) ---@class neotest.run.AttachArgs : neotest.client.AttachArgs ---@field interactive boolean Select a running position interactively @@ -150,7 +156,7 @@ end, 1) --- ---@param args string|neotest.run.AttachArgs? Position ID to attach to or args. If args then --- args[1] should be the position ID. -neotest.run.attach = nio.create(function(args) +function neotest.run.attach(args) args = args or {} if type(args) == "string" then args = { args } @@ -166,15 +172,18 @@ neotest.run.attach = nio.create(function(args) return end client:attach(pos, args) -end, 1) +end +neotest.run.attach = nio.create(neotest.run.attach, 1) --- Get the list of all known adapter IDs. ---@return string[] +---@nodoc function neotest.run.adapters() - if not client_ready then - return {} - end - return client:get_adapters() + lib.notify( + "`neotest.run.adapters` is deprecated, please use `neotest.state.adapter_ids` instead", + vim.log.levels.WARN + ) + return require("neotest").state.adapter_ids() end --- Get last test position ID and args @@ -191,9 +200,6 @@ neotest.run = setmetatable(neotest.run, { ---@param client_ neotest.Client __call = function(_, client_) client = client_ - client.listeners.starting = function() - client_ready = true - end return neotest.run end, }) diff --git a/lua/neotest/consumers/summary/component.lua b/lua/neotest/consumers/summary/component.lua index ada91058..746cb897 100644 --- a/lua/neotest/consumers/summary/component.lua +++ b/lua/neotest/consumers/summary/component.lua @@ -178,6 +178,13 @@ function SummaryComponent:_render(canvas, tree, expanded, focused, indent) neotest.run.attach(position.id, { adapter = self.adapter_id }) end) ) + canvas:add_mapping( + "watch", + async_func(function() + neotest.watch.toggle({ position.id, adapter = self.adapter_id }) + neotest.summary.render() + end) + ) canvas:add_mapping( "output", async_func(function() @@ -228,7 +235,13 @@ function SummaryComponent:_render(canvas, tree, expanded, focused, indent) local status = self:_get_status(position) has_running = has_running or status == "running" + local state_icon, state_icon_group = self:_state_icon(status) + + if neotest.watch.is_watching(position.id) then + state_icon = config.icons.watching + end + canvas:write(" " .. state_icon .. " ", { group = state_icon_group }) local name_groups = { config.highlights[position.type] } diff --git a/lua/neotest/consumers/summary/init.lua b/lua/neotest/consumers/summary/init.lua index b76efe9e..2ceb2179 100644 --- a/lua/neotest/consumers/summary/init.lua +++ b/lua/neotest/consumers/summary/init.lua @@ -71,9 +71,9 @@ neotest.summary.render = function(positions) end --- Open the summary window ---- >vim +--- ```vim --- lua require("neotest").summary.open() ---- < +--- ``` function neotest.summary.open() if summary.win:is_open() then return @@ -83,18 +83,18 @@ function neotest.summary.open() end --- Close the summary window ---- >vim +--- ```vim --- lua require("neotest").summary.close() ---- < +--- ``` function neotest.summary.close() summary:close() end ---Toggle the summary window --- ---->vim +--- ```vim --- lua require("neotest").summary.toggle() ----< +--- ``` function neotest.summary.toggle() nio.run(function() if summary.win:is_open() then diff --git a/lua/neotest/consumers/watch/init.lua b/lua/neotest/consumers/watch/init.lua new file mode 100644 index 00000000..abcf94fb --- /dev/null +++ b/lua/neotest/consumers/watch/init.lua @@ -0,0 +1,191 @@ +local lib = require("neotest.lib") +local config = require("neotest.config") +local logger = require("neotest.logging") +local nio = require("nio") +local Watcher = require("neotest.consumers.watch.watcher") + +local watchers = {} +---@type table +---@private +local start_tasks = {} + +local neotest = {} + +---@toc_entry Watch Consumer +---@text +--- Allows watching tests and re-running them whenever related files are +--- changed. When watching a directory, all files are run in separate processes. +--- Otherwise the tests are run in the same process (if allowed by the adapter). +--- +--- Related files are determined through an LSP client through a "best effort" +--- which means there are cases where a file may not be determined as related +--- despite it having an effect on a test. +--- +--- To determine file relationships, a treesitter query is used to find symbols +--- that are queried for using the `textDocument/definition` LSP request. The +--- query can be configured through the watch consumer's config. Any captures +--- named `symbol` will be used. If your language is not present in the default +--- config, please submit a PR to add support out of the box! +---@class neotest.consumers.watch +neotest.watch = {} + +local function get_valid_client_id(bufnr) + local sync_clients = vim.lsp.get_active_clients({ bufnr = bufnr }) + for _, client in ipairs(sync_clients) do + ---@type nio.lsp.types.ServerCapabilities + local caps = client.server_capabilities + if caps.definitionProvider then + logger.debug("Found client", client.name, "for watch") + return client.id + end + end +end + +local function get_lsp_client(tree) + for _, buf in ipairs(nio.api.nvim_list_bufs()) do + local path = nio.fn.fnamemodify(nio.api.nvim_buf_get_name(buf), ":p") + if tree:get_key(path) then + local client_id = get_valid_client_id(buf) + if client_id then + return nio.lsp.client(client_id) + end + end + end +end + +local ignored_dirs = { + "venv", + ".venv", + "node_modules", +} + +---@type neotest.consumers.watch.watcher.WatchArgs +---@private +local default_args = { + symbol_queries = config.watch.symbol_queries, + filter_path = config.watch.filter_path or function(path, root) + if not vim.startswith(path, root) then + return false + end + for _, dir in ipairs(ignored_dirs) do + if vim.startswith(path, root .. lib.files.sep .. dir) then + return false + end + end + return true + end, +} + +--- Watch a position and run it whenever related files are changed. +--- Arguments are the same as the `neotest.run.run`, which allows +--- for custom runner arguments, env vars, strategy etc. If a position is +--- already being watched, the existing watcher will be stopped. +---@param args? neotest.run.RunArgs|string +function neotest.watch.watch(args) + args = args or {} + if type(args) == "string" then + args = { args } + end + args = vim.tbl_extend("keep", args, default_args) + + local run = require("neotest").run + local tree = run.get_tree_from_args(args, false) + + if not tree then + lib.notify(("No position found with args %s"):format(vim.inspect(args)), vim.log.levels.ERROR) + return + end + + local lsp_client = get_lsp_client(tree) + if not lsp_client then + lib.notify( + "No valid LSP client found for watching. Ensure that at least one test file is open and has an LSP client attached.", + vim.log.levels.ERROR + ) + return + end + + local watcher = Watcher:new(lsp_client) + + local pos_id = tree:data().id + if watchers[pos_id] then + neotest.watch.stop(pos_id) + end + watchers[pos_id] = watcher + + start_tasks[pos_id] = nio.run(function() + lib.notify(("Starting watcher for %s"):format(tree:data().name)) + watcher:watch(tree, args) + lib.notify(("Watcher running for %s"):format(tree:data().name)) + start_tasks[pos_id] = nil + end) +end + +neotest.watch.watch = nio.create(neotest.watch.watch, 1) + +--- Toggle watching a position and run it whenever related files are changed. +--- Arguments are the same as the `neotest.run.run`, which allows +--- for custom runner arguments, env vars, strategy etc. +--- +--- Toggle watching the current file +--- ```vim +--- lua require("neotest").watch.toggle(vim.fn.expand("%")) +--- ``` +---@param args? neotest.run.RunArgs|string +function neotest.watch.toggle(args) + local run = require("neotest").run + local tree = run.get_tree_from_args(args, false) + + if not tree then + lib.notify(("No position found with args %s"):format(vim.inspect(args)), vim.log.levels.ERROR) + return + end + + local position_id = tree:data().id + + if neotest.watch.is_watching(position_id) then + neotest.watch.stop(position_id) + else + neotest.watch.watch(args) + end +end + +neotest.watch.toggle = nio.create(neotest.watch.toggle, 1) + +--- Stop watching a position. If no position is provided, all watched positions are stopped. +---@param position_id string +function neotest.watch.stop(position_id) + if not position_id then + for watched in pairs(watchers) do + neotest.watch.stop(watched) + end + return + end + + if not watchers[position_id] then + lib.notify(("%s is not being watched"):format(position_id), vim.log.levels.WARN) + return + end + lib.notify(("Stopping watch for %s"):format(position_id), vim.log.levels.INFO) + watchers[position_id]:stop_watch() + watchers[position_id] = nil + if start_tasks[position_id] then + start_tasks[position_id].cancel() + start_tasks[position_id] = nil + end +end + +--- Check if a position is being watched. +---@param position_id string +---@return boolean +function neotest.watch.is_watching(position_id) + return watchers[position_id] ~= nil +end + +neotest.watch = setmetatable(neotest.watch, { + __call = function() + return neotest.watch + end, +}) + +return neotest.watch diff --git a/lua/neotest/consumers/watch/watcher.lua b/lua/neotest/consumers/watch/watcher.lua new file mode 100644 index 00000000..00b1c8d0 --- /dev/null +++ b/lua/neotest/consumers/watch/watcher.lua @@ -0,0 +1,202 @@ +local lib = require("neotest.lib") +local logger = require("neotest.logging") +local nio = require("nio") +local config = require("neotest.config") + +---@class neotest.consumers.watch.Watcher +---@field lsp_client nio.lsp.Client +---@field autocmd_id? string +local Watcher = {} + +function Watcher:new(lsp_client) + local obj = { lsp_client = lsp_client } + self.__index = self + return setmetatable(obj, self) +end + +---@return integer[][] +function Watcher._parse_symbols(path) + logger.debug("Parsing symbols for", path) + local content = lib.files.read(path) + local root, lang = lib.treesitter.get_parse_root(path, content, {}) + local query = config.watch.symbol_queries[lang] + if not query then + error("No symbols query for language: " .. lang) + end + if type(query) == "function" then + return query(root, content, path) + end + local parsed_query = lib.treesitter.normalise_query(lang, query) + local symbols = {} + for id, node in parsed_query:iter_captures(root, content) do + if parsed_query.captures[id] == "symbol" then + symbols[#symbols + 1] = { node:range() } + end + end + return symbols +end + +---@param args neotest.consumers.watch.watcher.WatchArgs +---@return string[] paths +function Watcher:_get_linked_files(path, root_path, args) + local symbols = lib.subprocess.enabled() + and lib.subprocess.call( + [[require("neotest.consumers.watch.watcher")._parse_symbols]], + { path } + ) + or self._parse_symbols(path) + local path_uri = vim.uri_from_fname(path) + local dependency_uris = {} + logger.debug("Getting symbol definitions for", path) + for _, range in ipairs(symbols) do + local err, defs = self.lsp_client.request.textDocument_definition({ + position = { line = range[1], character = range[2] }, + textDocument = { uri = path_uri }, + }, nil, { timeout = 1000 }) + + if err then + logger.debug("Error getting symbol definitions for", path, ":", err) + end + + if defs ~= nil and type(defs[1]) ~= "table" then + defs = { defs } + end + + for _, def in ipairs(defs or {}) do + dependency_uris[def.uri or def.targetUri] = true + end + end + local paths = { path } + for uri in pairs(dependency_uris) do + local p = vim.uri_to_fname(uri) + if uri ~= path_uri and args.filter_path(p, root_path) then + paths[#paths + 1] = p + end + end + logger.debug("Found", #paths, "linked files for", path) + return paths +end + +---@class neotest.consumers.watch.watcher.WatchArgs +---@field filter_path fun(root: string, path: string): boolean + +---@paam tree neotest.Tree +function Watcher:_files_in_tree(tree) + if tree:data().type ~= "dir" then + return { tree:data().path } + end + local paths = {} + for _, pos in + tree:iter({ + continue = function(node) + return node:data().type == "dir" + end, + }) + do + if pos.type == "file" then + paths[#paths + 1] = pos.path + end + end + return paths +end + +---@param root string +---@param paths string[] +---@param args neotest.consumers.watch.watcher.WatchArgs +function Watcher:_build_dependencies(root, paths, args, dependencies) + local count = 0 + local worker = function() + while #paths > 0 do + local path = table.remove(paths) + + if not dependencies[path] then + count = count + 1 + dependencies[path] = {} + local path_results = self:_get_linked_files(path, root, args) + dependencies[path] = path_results + + for _, p in ipairs(path_results) do + if not dependencies[p] then + paths[#paths + 1] = p + end + end + end + end + end + local num_workers = 4 + local workers = {} + for _ = 1, num_workers do + workers[#workers + 1] = worker + end + nio.gather(workers) +end + +---@param dependencies table +function Watcher:_build_dependants(dependencies) + local dependants = {} + for path, deps in pairs(dependencies) do + for _, dep in ipairs(deps) do + dependants[dep] = dependants[dep] or {} + dependants[dep][#dependants[dep] + 1] = path + end + end + return dependants +end + +function Watcher:watch(tree, args) + local run = require("neotest").run + local paths = self:_files_in_tree(tree) + + local start = vim.loop.now() + local dependencies = {} + self:_build_dependencies(tree:root():data().path, paths, args, dependencies) + local elapsed = vim.loop.now() - start + logger.debug("Built dependencies in", elapsed, "ms for", tree:data().id, ":", dependencies) + local dependants = self:_build_dependants(dependencies) + + self.autocmd_id = nio.api.nvim_create_autocmd("BufWritePost", { + callback = function(autocmd_args) + nio.run(function() + local path = nio.fn.expand(nio.api.nvim_buf_get_name(autocmd_args.buf), ":p") + + local buf_dependants = dependants[path] + if not buf_dependants then + return + end + + if tree:data().type ~= "dir" then + run.run(vim.tbl_extend("keep", { tree:data().id }, args)) + else + for _, dep in ipairs(buf_dependants) do + run.run(vim.tbl_extend("keep", { dep }, args)) + end + end + + if dependencies[path] then + dependencies[path] = nil + self:_build_dependencies(tree:root():data().path, { path }, args, dependencies) + logger.debug("Rebuilt dependencies for", tree:data().id, ":", dependencies) + dependants = self:_build_dependants(dependencies) + end + end, function(success, err) + if not success then + lib.notify(("Error watching %s: %s"):format(tree:data().name, err), vim.log.levels.ERROR) + end + end) + end, + }) + + run.run(vim.tbl_extend("keep", { tree:data().id }, args)) + logger.info("Starting watch of", tree:data().id) +end + +function Watcher:stop_watch() + if not self.autocmd_id then + logger.warn("Watcher never started, can't stop it") + return + end + logger.info("Stopping watch") + nio.api.nvim_del_autocmd(self.autocmd_id) +end + +return Watcher diff --git a/lua/neotest/init.lua b/lua/neotest/init.lua index 7f04a522..91bc032e 100644 --- a/lua/neotest/init.lua +++ b/lua/neotest/init.lua @@ -23,9 +23,9 @@ --- Each consumer can be accessed as a property of the neotest module --- See the table of contents for the consumers --- ---- >vim +--- ```vim --- lua require("neotest").summary.toggle() ---- < +--- ``` --- ---@class neotest @@ -37,6 +37,7 @@ ---@field diagnostic neotest.consumers.diagnostic ---@field jump neotest.consumers.jump ---@field state neotest.consumers.state +---@field watch neotest.consumers.watch ---@nodoc local neotest = {} diff --git a/lua/neotest/types/tree.lua b/lua/neotest/types/tree.lua index 65700cc3..fbd52d48 100644 --- a/lua/neotest/types/tree.lua +++ b/lua/neotest/types/tree.lua @@ -189,23 +189,32 @@ function neotest.Tree:root() return node end ----@return fun(): integer,neotest.Tree -function neotest.Tree:iter_nodes() +---@class neotest.types.tree.IterNodesArgs +---@field continue fun(node: neotest.Tree): boolean A predicate for if the given node's children should be iterated over. Defaults to `true`. + +---@param args? neotest.types.tree.IterNodesArgs +---@return fun():integer,neotest.Tree +function neotest.Tree:iter_nodes(args) + args = args or {} local child_i = 0 local total_i = 1 local child_iter = nil + local continue = not args.continue and true or args.continue(self) return function() if child_i == 0 then child_i = 1 return 1, self end + if not continue then + return nil + end while true do if not child_iter then if #self._children < child_i then return nil end - child_iter = self._children[child_i]:iter_nodes() + child_iter = self._children[child_i]:iter_nodes(args) end local _, child_data = child_iter() if child_data then @@ -218,9 +227,10 @@ function neotest.Tree:iter_nodes() end end ----@return fun(): integer,neotest.Position -function neotest.Tree:iter() - local node_iter = self:iter_nodes() +---@param args? neotest.types.tree.IterNodesArgs +---@return fun():integer,neotest.Position +function neotest.Tree:iter(args) + local node_iter = self:iter_nodes(args) return function() local i, node = node_iter() if not i then diff --git a/lua/nio/tasks.lua b/lua/nio/tasks.lua index edb83472..5b305043 100644 --- a/lua/nio/tasks.lua +++ b/lua/nio/tasks.lua @@ -91,7 +91,7 @@ function nio.tasks.run(func, cb) future.set_error(err) if cb then cb(false, err) - else + elseif not cancelled then error("Async task failed without callback: " .. err) end else diff --git a/scripts/gendocs.lua b/scripts/gendocs.lua index 9421a3e8..453ef267 100644 --- a/scripts/gendocs.lua +++ b/scripts/gendocs.lua @@ -846,6 +846,7 @@ minidoc.generate( "./lua/neotest/consumers/output.lua", "./lua/neotest/consumers/output_panel/init.lua", "./lua/neotest/consumers/run.lua", + "./lua/neotest/consumers/watch/init.lua", "./lua/neotest/consumers/status.lua", "./lua/neotest/consumers/diagnostic.lua", "./lua/neotest/consumers/summary/init.lua", diff --git a/tests/unit/types/tree_spec.lua b/tests/unit/types/tree_spec.lua index 7f9968e5..df976e21 100644 --- a/tests/unit/types/tree_spec.lua +++ b/tests/unit/types/tree_spec.lua @@ -31,4 +31,20 @@ describe("neotest tree", function() i = i + 1 end end) + it("iterates with predicate", function() + local data = { 1, { 2, { 0 }, { 3, { 0 } } } } + local tree = Tree.from_list(data, function(x) + return x + end) + local i = 1 + local iterator = tree:iter({ + continue = function(x) + return x:data() == 1 + end, + }) + for _, elem in iterator do + assert.are.same(elem, i) + i = i + 1 + end + end) end)