Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions rex_xai/explanation/explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,15 @@ def __repr__(self) -> str:
def extract(self, method: Strategy):
self.blank()
if method == Strategy.Global:
self.__global()
self._global()
if method == Strategy.Spatial:
if self.data.mode == "spectral":
logger.warning(
"spatial search not yet implemented for spectral data, so defaulting to global search"
)
self.__global()
self._global()
else:
_ = self.__spatial()
_ = self._spatial()

if isinstance(self.final_mask, tt.Tensor):
self.final_mask = self.final_mask.detach().cpu().numpy()
Expand All @@ -115,7 +115,7 @@ def set_to_true(self, coords, mask=None):
mask, self.data.mode, self.data.model_order, coords
)

def __global(self, map=None, wipe=False):
def _global(self, map=None, wipe=False):
if map is None:
map = self.target_map
ranking = get_map_locations(map)
Expand Down Expand Up @@ -155,7 +155,7 @@ def __global(self, map=None, wipe=False):
return p.confidence
masks = []

def __generate_circle_coordinates(self, centre, radius: int):
def _generate_circle_coordinates(self, centre, radius: int):
assert self.data.model_height is not None
assert self.data.model_width is not None
Y, X = tt.meshgrid(
Expand All @@ -174,13 +174,13 @@ def __generate_circle_coordinates(self, centre, radius: int):

return circle_mask

def __draw_circle(self, centre, start_radius=None):
def _draw_circle(self, centre, start_radius=None):
if start_radius is None:
start_radius = self.args.spatial_initial_radius
mask = tt.zeros(
self.data.model_shape[1:], dtype=tt.bool, device=self.data.device
)
circle_mask = self.__generate_circle_coordinates(centre, start_radius)
circle_mask = self._generate_circle_coordinates(centre, start_radius)
if self.data.model_order == "first":
mask[:, circle_mask] = True
else:
Expand Down Expand Up @@ -212,13 +212,13 @@ def compute_masked_responsibility(self, mask):
)
return tt.mean(masked_responsibility).item()

def __spatial(self, centre=None, expansion_limit=None):
def _spatial(self, centre=None, expansion_limit=None):
# we don't have a search location to start from, so we try to isolate one
map = self.target_map
if centre is None:
centre = tt.unravel_index(tt.argmax(map), map.shape) # type: ignore

start_radius, circle, mask = self.__draw_circle(centre)
start_radius, circle, mask = self._draw_circle(centre)

if self.args.spotlight_objective_function == "none":
masked_responsibility = None
Expand All @@ -243,10 +243,10 @@ def __spatial(self, centre=None, expansion_limit=None):
and p.confidence
>= self.data.target.confidence * self.args.minimum_confidence_threshold # type: ignore
):
conf = self.__global(map=tt.where(circle, map, 0)) # type: ignore
conf = self._global(map=tt.where(circle, map, 0)) # type: ignore
return SpatialSearch.Found, masked_responsibility, conf
start_radius = int(start_radius * (1 + self.args.spatial_radius_eta))
_, circle, _ = self.__draw_circle(centre, start_radius)
_, circle, _ = self._draw_circle(centre, start_radius)
if self.data.model_order == "first":
mask[:, circle] = True
else:
Expand Down
10 changes: 5 additions & 5 deletions rex_xai/explanation/multi_explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def extract(self, method=None):
self.blank()
# we start with the global max explanation
logger.info("spotlight number 1 (global max)")
conf = self._Explanation__global() # type: ignore
conf = self._global()
if self.final_mask is not None:
self.explanations.append(self.final_mask)
self.explanation_confidences.append(conf)
Expand Down Expand Up @@ -251,15 +251,15 @@ def spotlight_search(self, origin=None):
else:
centre = origin

ret, resp, conf = self._Explanation__spatial( # type: ignore
ret, resp, conf = self._spatial( # type: ignore
centre=centre, expansion_limit=self.args.no_expansions
)

steps = 0
while ret == SpatialSearch.NotFound and steps < self.args.max_spotlight_budget:
if self.args.spotlight_objective_function == "none":
centre = self.__random_location()
ret, resp, conf = self._Explanation__spatial( # type: ignore
ret, resp, conf = self._spatial( # type: ignore
centre=centre, expansion_limit=self.args.no_expansions
)
else:
Expand All @@ -271,12 +271,12 @@ def spotlight_search(self, origin=None):
self.data.model_width,
step=self.args.spotlight_step,
)
ret, new_resp, conf = self._Explanation__spatial( # type: ignore
ret, new_resp, conf = self._spatial( # type: ignore
centre=centre, expansion_limit=self.args.no_expansions
)
if ret == SpatialSearch.Found:
return conf
ret, resp, conf = self._Explanation__spatial( # type: ignore
ret, resp, conf = self._spatial( # type: ignore
centre=centre, expansion_limit=self.args.no_expansions
)
steps += 1
Expand Down