Skip to content
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion src/spikeinterface/sorters/external/kilosort4.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import warnings
from pathlib import Path
from packaging import version

from spikeinterface.core import write_binary_recording

from spikeinterface.core import write_binary_recording, Motion, BaseRecording
from spikeinterface.sorters.basesorter import BaseSorter, get_job_kwargs
from .kilosortbase import KilosortBase
from spikeinterface.sorters.basesorter import get_job_kwargs
Expand Down Expand Up @@ -453,6 +455,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
if (sorter_output_folder / "recording.dat").is_file():
(sorter_output_folder / "recording.dat").unlink()

# close logger
for handler in logger.handlers.copy():
logger.removeHandler(handler)
handler.close()

@classmethod
def _get_result_from_folder(cls, sorter_output_folder):
return KilosortBase._get_result_from_folder(sorter_output_folder)
Expand Down Expand Up @@ -484,3 +491,47 @@ def _setup_json_probe_map(cls, recording, sorter_output_folder):
"n_chan": n_chan,
}
save_probe(probe, str(sorter_output_folder / "chanMap.json"))


def read_kilosort4_motion(sorter_output_folder: str | Path, recording: BaseRecording | None = None) -> Motion:
"""Reads the motion information from a Kilosort4 output folder and returns a Motion object.

Parameters
----------
sorter_output_folder: str or Path
The path to the Kilosort4 output folder.
recording: BaseRecording, optional
The recording object. If provided, the temporal bins will be estimated based on the recording's
start and end times. If not provided, the temporal bins will be estimated based on the number
of batches in the ops file.

Returns
-------
Motion
A Motion object containing the displacement, temporal bins, and spatial bins.

"""
sorter_output_folder = Path(sorter_output_folder)
ops_file = sorter_output_folder / "ops.npy"
if not ops_file.is_file():
raise FileNotFoundError("'ops.npy' file not found!")
ops = np.load(ops_file, allow_pickle=True).item()
yblk = ops.get("yblk")
dshift = ops.get("dshift")
if yblk is None or dshift is None:
raise Exception("'yblk' and 'dshift' fields not found in ops file!")
displacement = dshift + yblk
spatial_bins_um = yblk
# estimate temporal bins
batch_size = ops["batch_size"]
fs = ops["fs"]
t_bin = batch_size / fs
if recording is not None:
t_start = recording.get_start_time()
t_end = recording.get_end_time()
temporal_bins_s = np.linspace(t_start + t_bin / 2, t_end - t_bin / 2)
else:
temporal_bins_s = np.arange(displacement.shape[0]) * t_bin + t_bin / 2

motion = Motion(displacement=displacement, temporal_bins_s=temporal_bins_s, spatial_bins_um=spatial_bins_um)
return motion
2 changes: 1 addition & 1 deletion src/spikeinterface/sorters/sorterlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .external.kilosort2 import Kilosort2Sorter
from .external.kilosort2_5 import Kilosort2_5Sorter
from .external.kilosort3 import Kilosort3Sorter
from .external.kilosort4 import Kilosort4Sorter
from .external.kilosort4 import Kilosort4Sorter, read_kilosort4_motion
from .external.pykilosort import PyKilosortSorter
from .external.klusta import KlustaSorter
from .external.mountainsort4 import Mountainsort4Sorter
Expand Down
Loading