From f2a202057362cb4884ac51713510f24162a3d505 Mon Sep 17 00:00:00 2001 From: "Liz Ing-Simmons (k2474365)" Date: Thu, 13 Mar 2025 11:01:26 +0000 Subject: [PATCH 1/3] fix: replace double underscores with single underscore in key Explanation methods to allow inheritance --- rex_xai/explanation/explanation.py | 12 ++++++------ rex_xai/explanation/multi_explanation.py | 10 +++++----- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/rex_xai/explanation/explanation.py b/rex_xai/explanation/explanation.py index d5f81e21..2cb923f4 100644 --- a/rex_xai/explanation/explanation.py +++ b/rex_xai/explanation/explanation.py @@ -87,15 +87,15 @@ def __repr__(self) -> str: def extract(self, method: Strategy): self.blank() if method == Strategy.Global: - self.__global() + self._global() if method == Strategy.Spatial: if self.data.mode == "spectral": logger.warning( "spatial search not yet implemented for spectral data, so defaulting to global search" ) - self.__global() + self._global() else: - _ = self.__spatial() + _ = self._spatial() if isinstance(self.final_mask, tt.Tensor): self.final_mask = self.final_mask.detach().cpu().numpy() @@ -115,7 +115,7 @@ def set_to_true(self, coords, mask=None): mask, self.data.mode, self.data.model_order, coords ) - def __global(self, map=None, wipe=False): + def _global(self, map=None, wipe=False): if map is None: map = self.target_map ranking = get_map_locations(map) @@ -212,7 +212,7 @@ def compute_masked_responsibility(self, mask): ) return tt.mean(masked_responsibility).item() - def __spatial(self, centre=None, expansion_limit=None): + def _spatial(self, centre=None, expansion_limit=None): # we don't have a search location to start from, so we try to isolate one map = self.target_map if centre is None: @@ -243,7 +243,7 @@ def __spatial(self, centre=None, expansion_limit=None): and p.confidence >= self.data.target.confidence * self.args.minimum_confidence_threshold # type: ignore ): - conf = self.__global(map=tt.where(circle, map, 0)) # type: ignore + conf = self._global(map=tt.where(circle, map, 0)) # type: ignore return SpatialSearch.Found, masked_responsibility, conf start_radius = int(start_radius * (1 + self.args.spatial_radius_eta)) _, circle, _ = self.__draw_circle(centre, start_radius) diff --git a/rex_xai/explanation/multi_explanation.py b/rex_xai/explanation/multi_explanation.py index aca514ca..0a1122e1 100644 --- a/rex_xai/explanation/multi_explanation.py +++ b/rex_xai/explanation/multi_explanation.py @@ -123,7 +123,7 @@ def extract(self, method=None): self.blank() # we start with the global max explanation logger.info("spotlight number 1 (global max)") - conf = self._Explanation__global() # type: ignore + conf = self._global() if self.final_mask is not None: self.explanations.append(self.final_mask) self.explanation_confidences.append(conf) @@ -251,7 +251,7 @@ def spotlight_search(self, origin=None): else: centre = origin - ret, resp, conf = self._Explanation__spatial( # type: ignore + ret, resp, conf = self._spatial( centre=centre, expansion_limit=self.args.no_expansions ) @@ -259,7 +259,7 @@ def spotlight_search(self, origin=None): while ret == SpatialSearch.NotFound and steps < self.args.max_spotlight_budget: if self.args.spotlight_objective_function == "none": centre = self.__random_location() - ret, resp, conf = self._Explanation__spatial( # type: ignore + ret, resp, conf = self._spatial( centre=centre, expansion_limit=self.args.no_expansions ) else: @@ -271,12 +271,12 @@ def spotlight_search(self, origin=None): self.data.model_width, step=self.args.spotlight_step, ) - ret, new_resp, conf = self._Explanation__spatial( # type: ignore + ret, new_resp, conf = self._spatial( # type: ignore centre=centre, expansion_limit=self.args.no_expansions ) if ret == SpatialSearch.Found: return conf - ret, resp, conf = self._Explanation__spatial( # type: ignore + ret, resp, conf = self._spatial( # type: ignore centre=centre, expansion_limit=self.args.no_expansions ) steps += 1 From 49343e865e4c5ea8e22b98c281bd714a9aa826c5 Mon Sep 17 00:00:00 2001 From: "Liz Ing-Simmons (k2474365)" Date: Thu, 13 Mar 2025 11:03:56 +0000 Subject: [PATCH 2/3] fix: replace double underscores with single underscore in additional Explanation methods --- rex_xai/explanation/explanation.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/rex_xai/explanation/explanation.py b/rex_xai/explanation/explanation.py index 2cb923f4..35685492 100644 --- a/rex_xai/explanation/explanation.py +++ b/rex_xai/explanation/explanation.py @@ -155,7 +155,7 @@ def _global(self, map=None, wipe=False): return p.confidence masks = [] - def __generate_circle_coordinates(self, centre, radius: int): + def _generate_circle_coordinates(self, centre, radius: int): assert self.data.model_height is not None assert self.data.model_width is not None Y, X = tt.meshgrid( @@ -174,13 +174,13 @@ def __generate_circle_coordinates(self, centre, radius: int): return circle_mask - def __draw_circle(self, centre, start_radius=None): + def _draw_circle(self, centre, start_radius=None): if start_radius is None: start_radius = self.args.spatial_initial_radius mask = tt.zeros( self.data.model_shape[1:], dtype=tt.bool, device=self.data.device ) - circle_mask = self.__generate_circle_coordinates(centre, start_radius) + circle_mask = self._generate_circle_coordinates(centre, start_radius) if self.data.model_order == "first": mask[:, circle_mask] = True else: @@ -218,7 +218,7 @@ def _spatial(self, centre=None, expansion_limit=None): if centre is None: centre = tt.unravel_index(tt.argmax(map), map.shape) # type: ignore - start_radius, circle, mask = self.__draw_circle(centre) + start_radius, circle, mask = self._draw_circle(centre) if self.args.spotlight_objective_function == "none": masked_responsibility = None @@ -246,7 +246,7 @@ def _spatial(self, centre=None, expansion_limit=None): conf = self._global(map=tt.where(circle, map, 0)) # type: ignore return SpatialSearch.Found, masked_responsibility, conf start_radius = int(start_radius * (1 + self.args.spatial_radius_eta)) - _, circle, _ = self.__draw_circle(centre, start_radius) + _, circle, _ = self._draw_circle(centre, start_radius) if self.data.model_order == "first": mask[:, circle] = True else: From 9fb5764a49f080c571971500f5e40a5bc3dfc8b4 Mon Sep 17 00:00:00 2001 From: "Liz Ing-Simmons (k2474365)" Date: Thu, 13 Mar 2025 13:40:27 +0000 Subject: [PATCH 3/3] fix: add back 'type: ignore' statements for remaining problems unrelated to method inheritance --- rex_xai/explanation/multi_explanation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rex_xai/explanation/multi_explanation.py b/rex_xai/explanation/multi_explanation.py index 0a1122e1..397e2aa7 100644 --- a/rex_xai/explanation/multi_explanation.py +++ b/rex_xai/explanation/multi_explanation.py @@ -251,7 +251,7 @@ def spotlight_search(self, origin=None): else: centre = origin - ret, resp, conf = self._spatial( + ret, resp, conf = self._spatial( # type: ignore centre=centre, expansion_limit=self.args.no_expansions ) @@ -259,7 +259,7 @@ def spotlight_search(self, origin=None): while ret == SpatialSearch.NotFound and steps < self.args.max_spotlight_budget: if self.args.spotlight_objective_function == "none": centre = self.__random_location() - ret, resp, conf = self._spatial( + ret, resp, conf = self._spatial( # type: ignore centre=centre, expansion_limit=self.args.no_expansions ) else: