diff --git a/MANIFEST.in b/MANIFEST.in index 1d5922be3..c70e324fb 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,7 @@ include src/deepforest/conf/config.yaml +include src/deepforest/conf/point.yaml +include src/deepforest/conf/point_pretrain.yaml include src/deepforest/data/testfile_deepforest.csv include src/deepforest/data/testfile_multi.csv include src/deepforest/data/classes.csv diff --git a/pyproject.toml b/pyproject.toml index e86c0c405..2911fcc84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ dependencies = [ "rasterio", "safetensors<0.6.0", "scikit-image>=0.25.2", + "scipy>=1.15.3", "setuptools", "shapely>2.0.0", "slidingwindow", diff --git a/src/deepforest/conf/point_pretrain.yaml b/src/deepforest/conf/point_pretrain.yaml new file mode 100644 index 000000000..683231d88 --- /dev/null +++ b/src/deepforest/conf/point_pretrain.yaml @@ -0,0 +1,63 @@ +# NEON pretraining config for point model. + +defaults: + - config + - _self_ + +architecture: treeformer + +# Train from ImageNet PvT weights, not the released point checkpoint. +model: + name: null + revision: main + +num_classes: 1 +label_dict: + Tree: 0 + +batch_size: 32 +workers: 8 +accelerator: auto +precision: bf16-mixed +matmul_precision: high +score_thresh: 0.15 +patch_size: 512 + +# Replace the CSV paths here with the training dataset e.g. 
full NEON pretrain (all_annotations_train_fold_0.csv) +train: + csv_file: MISSING + root_dir: MISSING + lr: 2e-4 + optimizer: + type: AdamW + weight_decay: 1e-5 + epochs: 40 + scheduler: + type: cosine + params: + T_max: 40 + eta_min: 1e-6 + augmentations: + - HorizontalFlip: {p: 0.5} + - VerticalFlip: {p: 0.5} + - Rotate: {degrees: 180, p: 0.5} + +validation: + csv_file: MISSING + root_dir: MISSING + val_accuracy_interval: 1 + +point: + backbone: OpenGVLab/pvt_v2_b3 + score_integration_radius: 2 + distance_threshold: 10.0 + density_sigma: 3.0 + mae_weight: 0.025 + ot_weight: 0.1 + density_l1_weight: 0.05 + count_cls_weight: 0.1 + sinkhorn_reg: 1.0 + num_of_iter_in_ot: 100 + losses: [count, ot, density_l1] + norm_cood: true + enforce_count: false diff --git a/src/deepforest/conf/schema.py b/src/deepforest/conf/schema.py index e43ea2f99..7952a388a 100644 --- a/src/deepforest/conf/schema.py +++ b/src/deepforest/conf/schema.py @@ -133,13 +133,34 @@ class CropModelConfig: @dataclass class PointConfig: - """Configuration for point models.""" + """Configuration for point models. + + The loss fields configure training for density/point models such as + TreeFormer; defaults mirror the ``TreeFormerModel`` constructor. + ``losses`` selects the active terms (``None`` enables all of + ``count``, ``ot``, ``density_l1``, ``count_cls``). ``norm_cood`` + normalises OT coordinates to [-1, 1] (global transport) and affects + only the OT loss, not inference; ``enforce_count`` rescales the + density map to a predicted count and does affect inference. + """ backbone: str = "pvt_v2_b3" score_integration_radius: int = 5 nms_distance_thresh: float = 5.0 distance_threshold: float = 10.0 + # Training loss hyperparameters (used by TreeFormer / density models). 
+ density_sigma: float = 5.0 + mae_weight: float = 1.0 + ot_weight: float = 0.1 + density_l1_weight: float = 0.01 + count_cls_weight: float = 1.0 + sinkhorn_reg: float = 1.0 + num_of_iter_in_ot: int = 100 + losses: list[str] | None = None + norm_cood: bool = False + enforce_count: bool = True + @dataclass class Config: diff --git a/src/deepforest/losses/__init__.py b/src/deepforest/losses/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/deepforest/losses/ot_loss.py b/src/deepforest/losses/ot_loss.py new file mode 100644 index 000000000..c2a9a8192 --- /dev/null +++ b/src/deepforest/losses/ot_loss.py @@ -0,0 +1,240 @@ +"""PyTorch implementation of Sinkhorn-Knopp for optimal transport. + +Based on `ot.bregman.sinkhorn` from the Python Optimal Transport library, +https://pythonot.github.io, rewritten in PyTorch and adapted from the +DM-Count/TreeFormer repositories. + +Some small changes have been made to the original code for stability and +logging. + +Original implementation: +https://github.com/cvlab-stonybrook/DM-Count/blob/master/losses/ot_loss.py + +Reference: M. Cuturi, "Sinkhorn Distances: Lightspeed Computation of Optimal +Transport", NeurIPS 2013. +""" + +from typing import cast + +import torch +from torch.nn import Module + +M_EPS = 1e-16 + + +def sinkhorn_knopp( + a: torch.Tensor, + b: torch.Tensor, + C: torch.Tensor, + reg: float = 1e-1, + maxIter: int = 1000, + stopThr: float = 1e-9, + log: bool = False, + eval_freq: int = 10, + warm_start: dict | None = None, +) -> tuple[torch.Tensor, dict] | torch.Tensor: + """Solve entropic-regularized OT via Sinkhorn-Knopp matrix scaling. + + Minimises ``<gamma, C>_F + reg * sum(gamma * log(gamma))`` + subject to ``gamma @ 1 = a`` and ``gamma.T @ 1 = b``. + + Args: + a: (na,) source measure (sums to 1). + b: (nb,) target measure (sums to 1). + C: (na, nb) cost matrix. + reg: Entropic regularization strength (> 0). + maxIter: Maximum number of Sinkhorn iterations. 
+ stopThr: Early-stop threshold on marginal error. + log: If True, return ``(P, log_dict)``; otherwise return ``P``. + eval_freq: Check convergence every this many iterations. + warm_start: Optional dict with keys ``"u"`` and ``"v"`` to resume + from a previous solve. + + Returns: + P (na, nb) optimal transport plan, and optionally a log dict + containing ``"u"``, ``"v"``, ``"alpha"``, ``"beta"`` (dual variables) + and ``"err"`` (list of marginal errors). + """ + device = a.device + na, nb = C.shape + + assert na == a.shape[0] and nb == b.shape[0], "Shape of a or b doesn't match C" + assert reg > 0, "reg must be > 0" + + log_dict: dict = {"err": []} + + if warm_start is not None: + u = warm_start["u"] + v = warm_start["v"] + else: + u = torch.ones(na, dtype=a.dtype, device=device) / na + v = torch.ones(nb, dtype=b.dtype, device=device) / nb + + K = torch.exp(C / -reg) + + it = 1 + err = 1.0 + while err > stopThr and it <= maxIter: + upre, vpre = u, v + v = b / (torch.mv(K.t(), u) + M_EPS) + u = a / (torch.mv(K, v) + M_EPS) + + if torch.any(torch.isnan(u) | torch.isinf(u)) or torch.any( + torch.isnan(v) | torch.isinf(v) + ): + u, v = upre, vpre + break + + if it % eval_freq == 0: + b_hat = torch.mv(K.t(), u) * v + err = (b - b_hat).pow(2).sum().item() + log_dict["err"].append(err) + + it += 1 + + if log: + log_dict["u"] = u + log_dict["v"] = v + log_dict["alpha"] = reg * torch.log(u + M_EPS) + log_dict["beta"] = reg * torch.log(v + M_EPS) + log_dict["its"] = it - 1  # ``it`` is 1-based and one past the last completed iteration + log_dict["K_min"] = K.min().item() + # Final marginal error (recompute if not already stored from last eval). 
+ b_hat = torch.mv(K.t(), u) * v + log_dict["sinkhorn_err"] = (b - b_hat).pow(2).sum().item() + + P = u.unsqueeze(1) * K * v.unsqueeze(0) + return (P, log_dict) if log else P + + +class OT_Loss(Module): + def __init__(self, norm_coord, device, num_of_iter_in_ot=100, reg=1.0): + super().__init__() + self.device = device + self.norm_coord = norm_coord + self.num_of_iter_in_ot = num_of_iter_in_ot + self.reg = reg + + @torch.autocast("cuda", enabled=False) + @torch.autocast("cpu", enabled=False) + def forward(self, normed_density, unnormed_density, points): + # Disable AMP autocast and force float32 to prevent float16 + # underflow in Sinkhorn. K = exp(C / -reg) underflows to zero + # in float16 for squared distances > ~11, collapsing the + # transport plan entirely. + normed_density = normed_density.float() + unnormed_density = unnormed_density.float() + points = [p.float() for p in points] + + batch_size = normed_density.size(0) + assert len(points) == batch_size + output_h = normed_density.size(2) + output_w = normed_density.size(3) + + # Define grid coordinates (centered) + x_cood = ( + torch.arange(output_w, dtype=torch.float32, device=self.device) + 0.5 + ).unsqueeze(0) + y_cood = ( + torch.arange(output_h, dtype=torch.float32, device=self.device) + 0.5 + ).unsqueeze(0) + + # Optionally normalize coordinates to [-1, 1]. We generally + # recommend this. 
+ if self.norm_coord: + x_cood = x_cood / output_w * 2 - 1 + y_cood = y_cood / output_h * 2 - 1 + + loss_terms = [] + ot_obj_values = torch.zeros([1]).to(self.device) + wd = 0 # Wasserstein distance + n_active = 0 # Total number of points over all images + total_its = 0 # Accumulated Sinkhorn iterations + total_K_min = 0.0 # Accumulated K_min for diagnostics + total_beta_abs_max = 0.0 # Accumulated max |beta| for divergence canary + total_sinkhorn_err = 0.0 # Accumulated final marginal error + + for idx, im_points in enumerate(points): + if len(im_points) > 0: + n_active += 1 + + # compute l2 square distance, it should be source target distance. [#gt, #cood * #cood] + if self.norm_coord: + x = im_points[:, 0].unsqueeze(1) / output_w * 2 - 1 + y = im_points[:, 1].unsqueeze(1) / output_h * 2 - 1 + else: + x = im_points[:, 0].unsqueeze(1) + y = im_points[:, 1].unsqueeze(1) + + x_dis = ( + -2 * torch.matmul(x, x_cood) + x * x + x_cood * x_cood + ) # [#gt, #cood] + y_dis = -2 * torch.matmul(y, y_cood) + y * y + y_cood * y_cood + y_dis.unsqueeze_(2) + x_dis.unsqueeze_(1) + dis = y_dis + x_dis + dis = dis.view((dis.size(0), -1)) # size of [#gt, #cood * #cood] + + source_prob = normed_density[idx][0].view([-1]).detach() + target_prob = (torch.ones([len(im_points)]) / len(im_points)).to( + self.device + ) + + # use sinkhorn to solve OT, compute optimal beta. 
+ P, log = sinkhorn_knopp( + target_prob, + source_prob, + dis, + self.reg, + maxIter=self.num_of_iter_in_ot, + log=True, + ) + wd += torch.sum(dis * P).item() + + log = cast(dict[str, torch.Tensor], log) + total_its += log["its"] + total_K_min += log["K_min"] + total_sinkhorn_err += log["sinkhorn_err"] + # beta (dual variable = reg * log(v + eps)) + beta = log["beta"] # [#cood * #cood] + total_beta_abs_max = max(total_beta_abs_max, beta.abs().max().item()) + ot_obj_values = ot_obj_values + torch.sum( + normed_density[idx] * beta.view([1, output_h, output_w]) + ) + # compute the gradient of OT loss to predicted density (unnormed_density). + # im_grad = beta / source_count - <beta, source_density> / (source_count)^2 + source_density = unnormed_density[idx][0].view([-1]).detach() + source_count = source_density.sum() + im_grad_1 = ( + (source_count) / (source_count * source_count + 1e-8) * beta + ) # size of [#cood * #cood] + im_grad_2 = (source_density * beta).sum() / ( + source_count * source_count + 1e-8 + ) # size of 1 + im_grad = im_grad_1 - im_grad_2 + im_grad = im_grad.detach().view([1, output_h, output_w]) + # Define loss = <im_grad, predicted density>. The gradient of loss w.r.t predicted density is im_grad. + loss_terms.append(torch.sum(unnormed_density[idx] * im_grad)) + + if n_active > 0: + loss = torch.stack(loss_terms).sum() / n_active + ot_obj_values = ot_obj_values / n_active + else: + # All images in this batch have zero points. Keep loss connected + # to unnormed_density so DDP gradient buckets fire on every rank; + # the 0.0 multiplier means no actual gradient flows. 
+ loss = 0.0 * unnormed_density.sum() + + avg_its = total_its / n_active if n_active > 0 else 0 + avg_K_min = total_K_min / n_active if n_active > 0 else 0.0 + beta_abs_max = total_beta_abs_max  # already a max over images; must not be averaged + avg_sinkhorn_err = total_sinkhorn_err / n_active if n_active > 0 else 0.0 + return ( + loss, + wd, + ot_obj_values, + avg_its, + avg_K_min, + beta_abs_max, + avg_sinkhorn_err, + ) diff --git a/src/deepforest/main.py b/src/deepforest/main.py index 4ff14bca5..41da9ae76 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -805,8 +805,25 @@ def training_step(self, batch, batch_idx): images, targets, image_names = batch loss_dict = self.model.forward(images, targets) - # sum of regression and classification loss - losses = sum(loss_dict.values()) + # Models that compute a combined total (e.g. TreeFormer) return it under + # "loss" alongside diagnostics. The diagnostics are only for Lightning logging; + # the combined "loss" is the one term that is backpropagated. + losses = loss_dict["loss"] if "loss" in loss_dict else sum(loss_dict.values()) + + # Guard against non-finite losses + if not torch.isfinite(losses): + non_finite = { + k: v.item() + for k, v in loss_dict.items() + if torch.is_tensor(v) and not torch.isfinite(v) + } + warnings.warn( + f"Non-finite loss {losses.item()}; non-finite components: " + f"{non_finite}. Skipping batch.", + stacklevel=2, + ) + trainable = [p for p in self.model.parameters() if p.requires_grad] + return 0.0 * sum(p.sum() for p in trainable) if trainable else None # Log loss for key, value in loss_dict.items(): @@ -839,8 +856,8 @@ def validation_step(self, batch, batch_idx): with torch.no_grad(): loss_dict = self.model.forward(images, targets) - # sum of regression and classification loss - losses = sum(loss_dict.values()) + # As in training_step, prefer the combined "loss" entry when present. 
+ losses = loss_dict["loss"] if "loss" in loss_dict else sum(loss_dict.values()) # Log losses for key, value in loss_dict.items(): diff --git a/src/deepforest/models/treeformer.py b/src/deepforest/models/treeformer.py index 3e3222075..fde428415 100644 --- a/src/deepforest/models/treeformer.py +++ b/src/deepforest/models/treeformer.py @@ -6,12 +6,15 @@ See: 10.1109/TGRS.2023.3295802 """ +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from huggingface_hub import PyTorchModelHubMixin +from scipy.ndimage import gaussian_filter from transformers import AutoConfig, AutoImageProcessor, PvtV2Model +from deepforest.losses.ot_loss import OT_Loss from deepforest.model import BaseModel from deepforest.models.treeformer_decoder import Regression from deepforest.utilities import density_to_points @@ -69,7 +72,7 @@ def __init__( do_resize=False, ) - # Instaniate architecture but don't pull weights + # Instantiate architecture but don't pull weights self.backbone = PvtV2Model(AutoConfig.from_pretrained(backbone)) self.backbone.gradient_checkpointing_enable( @@ -125,6 +128,10 @@ def __init__( if losses is not None else ["count", "ot", "density_l1", "count_cls"] ) + self.active_losses = set(self.losses) + + # OT_Loss is created lazily on first use, once the device is known. 
+ self._ot_loss: OT_Loss | None = None self.kwargs = kwargs self.update_config() @@ -154,24 +161,53 @@ def update_config(self): def device(self) -> torch.device: return next(self.parameters()).device + def _get_ot_loss(self) -> OT_Loss: + """Return the optimal-transport loss, created on first call with the + current device.""" + if self._ot_loss is None: + self._ot_loss = OT_Loss( + self.norm_cood, + self.device, + self.ot_iter, + self.sinkhorn_reg, + ) + return self._ot_loss + def _normalize_density( - self, score_map: torch.Tensor, cls_count: torch.Tensor + self, + score_map: torch.Tensor, + cls_count: torch.Tensor | None, + gt_count: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - """Return (density_map, normed_density) from raw score map and count - scalar. - - If the model is in enforce_count mode, the raw density map is - scaled to match the count prediction; otherwise it's returned - as-is. + """Return (density_map, normed_density) from a raw score map and count. + + With ``enforce_count`` the density map is rescaled so its sum matches a + count: during training the ground-truth ``gt_count`` is used when + provided, otherwise the CLS prediction ``cls_count``. ``cls_count`` may + be ``None`` when ``gt_count`` is always supplied (the training path). + Without ``enforce_count`` the raw ``score_map`` is returned unscaled. + ``normed_density`` is the spatially normalised map (sums to 1), used by + the OT and density-L1 losses. 
""" B = score_map.size(0) score_sum = score_map.view(B, -1).sum(1).view(B, 1, 1, 1) normed = score_map / (score_sum + 1e-4) if self.enforce_count: - count = cls_count.view(B, 1, 1, 1).abs().clamp(min=1e-4) + if self.training and gt_count is not None: + count = gt_count.view(B, 1, 1, 1).clamp(min=1e-4) + else: + count = cls_count.view(B, 1, 1, 1).abs().clamp(min=1e-4) return normed * count, normed return score_map, normed + def _count_areas(self, image_shapes: list[tuple[int, int]]) -> torch.Tensor: + """Per-image pixel areas, used to convert between density and count.""" + return torch.tensor( + [image_h * image_w for image_h, image_w in image_shapes], + dtype=torch.float32, + device=self.device, + ) + def _output_shapes( self, image_shapes: list[tuple[int, int]] ) -> list[tuple[int, int]]: @@ -184,6 +220,21 @@ def _output_shapes( for image_h, image_w in image_shapes ] + def _build_output_mask( + self, + image_shapes: list[tuple[int, int]], + out_h: int, + out_w: int, + dtype: torch.dtype = torch.float32, + ) -> torch.Tensor: + """Mask padded decoder outputs so losses ignore batch padding.""" + mask = torch.zeros( + len(image_shapes), 1, out_h, out_w, device=self.device, dtype=dtype + ) + for index, (valid_h, valid_w) in enumerate(self._output_shapes(image_shapes)): + mask[index, :, :valid_h, :valid_w] = 1.0 + return mask + def _cls_outputs_to_count( self, cls_output: torch.Tensor, @@ -192,12 +243,44 @@ def _cls_outputs_to_count( """Convert CLS-head count-density prediction to absolute count by multiplying by the area of each image.""" count_density = cls_output.reshape(len(image_shapes)) - count_areas = torch.tensor( - [image_h * image_w for image_h, image_w in image_shapes], - dtype=torch.float32, - device=self.device, - ) - return count_density * count_areas + return count_density * self._count_areas(image_shapes) + + def _scale_points_to_output( + self, + points: list, + image_shapes: list[tuple[int, int]], + output_shapes: list[tuple[int, int]], + ) -> list: + 
"""Scale image-space point coordinates into output-map coordinates.""" + scaled_points = [] + for p, (image_h, image_w), (out_h, out_w) in zip( + points, image_shapes, output_shapes, strict=True + ): + scale = torch.tensor( + [out_w / image_w, out_h / image_h], + dtype=torch.float32, + device=self.device, + ) + if len(p) == 0: + scaled_points.append(p.clone()) + else: + scaled_points.append(p.to(dtype=torch.float32) * scale) + return scaled_points + + def _cast_points(self, targets: list) -> list[torch.Tensor]: + """Cast training targets to a list of (N, 2) point tensors.""" + points = [] + for target in targets: + if isinstance(target, dict): + point_tensor = target.get("points") + if point_tensor is None: + raise ValueError("Each target dict must include a 'points' entry") + else: + point_tensor = target + if not isinstance(point_tensor, torch.Tensor): + point_tensor = torch.as_tensor(point_tensor, dtype=torch.float32) + points.append(point_tensor.to(device=self.device, dtype=torch.float32)) + return points def _forward_features(self, x: torch.Tensor): """Run backbone and project stage outputs to REG_DIMS. @@ -219,11 +302,11 @@ def postprocess_density(self, density_map, images) -> list[dict]: prediction dataloader which returns a list via its collate_fn). Args: - density_outputs: tuple ``(density_map, normed_density)`` from - ``TreeFormerModel.forward``. When ``images`` is a list, - ``density_map`` is also a list of per-image tensors. + density_map: list of per-image density tensors (each + ``(1, 1, H, W)``) or a ``(B, 1, H, W)`` batch tensor. images: batch tensor ``(B, C, H_img, W_img)`` or list of - ``(C, H_img, W_img)`` tensors. + ``(C, H_img, W_img)`` tensors, used to rescale points to + image-pixel coordinates. 
Returns: List of dicts with ``"points"`` (N, 2), ``"scores"`` (N,), @@ -250,42 +333,187 @@ def postprocess_density(self, density_map, images) -> list[dict]: return preds + def rasterize_points( + self, im_height: int, im_width: int, points: np.ndarray + ) -> np.ndarray: + """Rasterize (N, 2) (x, y) points (in output-map space) into an impulse + map of per-pixel point counts.""" + discrete_map = np.zeros([im_height, im_width], dtype=np.float32) + points_np = np.asarray(points, dtype=np.float32).reshape(-1, 2) + points_np = np.rint(points_np).astype(int) + p_h = np.clip(points_np[:, 1], 0, im_height - 1) + p_w = np.clip(points_np[:, 0], 0, im_width - 1) + np.add.at(discrete_map, (p_h, p_w), 1) + return discrete_map + + def _make_gt_density(self, points: list, out_h: int, out_w: int) -> torch.Tensor: + """Build a batched Gaussian density map (B, 1, H, W) from point lists. + + Rasterizes points to impulses, applies a Gaussian blur with + ``density_sigma``, and rescales so the smoothed map preserves + the discrete point count. + """ + sigma = self.density_sigma + maps = [] + for p in points: + p_np = p.cpu().numpy() + discrete = self.rasterize_points(out_h, out_w, p_np) + if discrete.sum() > 0: + smoothed = gaussian_filter(discrete, sigma=sigma) + smoothed = smoothed * (discrete.sum() / smoothed.sum()) + else: + smoothed = discrete + maps.append(smoothed) + return torch.from_numpy(np.stack(maps)).unsqueeze(1).float().to(self.device) + def compute_loss( self, - density_maps: list[torch.Tensor], - normed_density: list[torch.Tensor], + density_maps: list, + normed_density: torch.Tensor, cls_outputs: list, targets: list, image_shapes: list[tuple[int, int]], + gt_density: torch.Tensor | None = None, + points: list | None = None, ) -> dict: - """Compute training loss: L1 between density-map count and GT count. + """Compute the supervised TreeFormer/DM-Count losses. 
+ + Spatial structure is trained by optimal transport (``ot``) and a + pixel-wise L1 against a Gaussian GT density (``density_l1``). The total + count is trained by MAE on the density-map sum (``count``, in log + space) and optionally an auxiliary CLS count head (``count_cls``). The + CLS head predicts count density (count / area) so it transfers across + image sizes. Individual terms are enabled via ``self.active_losses``. Args: - density_maps: list of multi-scale density maps; index 0 is primary. - normed_density: spatially normalised version of density_maps[0]. - cls_outputs: CLS-head count predictions (unused here). - targets: list of target dicts with ``"points"`` tensors. - image_shapes: original ``(H, W)`` for each image. + density_maps: ``[y0, y1, y2]`` multi-scale outputs; index 0 primary. + normed_density: ``density_maps[0]`` normalised to sum to 1, (B,1,H,W). + cls_outputs: ``[yc0, yc1, yc2]`` count-density scalars from CLS head. + targets: list of B target dicts (or point tensors) in image space. + image_shapes: original ``(H, W)`` for each image in the batch. + gt_density: optional precomputed ``(B, 1, H, W)`` Gaussian density; + computed from points when ``None``. + points: optional pre-cast list of per-image point tensors; cast from + ``targets`` when ``None``. Returns: - dict with ``"loss"`` scalar tensor (supports backprop). + dict with ``"loss"`` (total) plus individual named terms/diagnostics. 
""" - true_counts = torch.tensor( - [float(len(t["points"])) for t in targets], - dtype=torch.float32, - device=self.device, + density_map = density_maps[0] # (B, 1, H', W') + B, _, H, W = density_map.shape + if points is None: + points = self._cast_points(targets) + output_shapes = self._output_shapes(image_shapes) + areas = self._count_areas(image_shapes) + + point_counts = torch.tensor( + [len(p) for p in points], dtype=torch.float32, device=self.device ) + scaled_points = self._scale_points_to_output(points, image_shapes, output_shapes) + + if gt_density is None: + gt_density = self._make_gt_density(scaled_points, H, W) - # TODO: This is a placeholder to verify integration, full loss - # computation as used with the model releases will be implemented - # in a future PR. - pred_counts = self._cls_outputs_to_count(cls_outputs[0], image_shapes) - count_loss = self.cls_l1(pred_counts, true_counts) * self.mae_weight + active = self.active_losses + zero = density_map.new_zeros(()) + pred_sum = density_map.view(B, -1).sum(1) - return {"loss": count_loss} + # Unweighted count diagnostics, so calibration is visible in the logs. + count_mae = self.cls_l1(pred_sum, point_counts) + cls_preds = torch.stack([c.reshape(B) for c in cls_outputs]) # (3, B) + gt_counts = point_counts.unsqueeze(0).expand(3, -1) # (3, B) + + # ---- MAE count loss ------------------------------------------------- + if "count" in active: + count_loss = ( + self.cls_l1(torch.log1p(pred_sum), torch.log1p(point_counts)) + * self.mae_weight + ) + else: + count_loss = zero + + # ---- Optimal transport loss ----------------------------------------- + if "ot" in active: + ( + ot_raw, + ot_wd_val, + _, + ot_avg_its, + ot_K_min, + ot_beta_abs_max, + ot_sinkhorn_err, + ) = self._get_ot_loss()(normed_density, density_map, scaled_points) + ot_loss = ot_raw * self.ot_weight + # Wasserstein distance + Sinkhorn diagnostics (not backpropagated). 
+ ot_wd = torch.tensor( + ot_wd_val, device=density_map.device, dtype=torch.float32 + ) + sinkhorn_its = torch.tensor( + ot_avg_its, device=density_map.device, dtype=torch.float32 + ) + sinkhorn_K_min = torch.tensor( + ot_K_min, device=density_map.device, dtype=torch.float32 + ) + sinkhorn_beta_abs_max = torch.tensor( + ot_beta_abs_max, device=density_map.device, dtype=torch.float32 + ) + sinkhorn_err = torch.tensor( + ot_sinkhorn_err, device=density_map.device, dtype=torch.float32 + ) + else: + ot_loss = zero + ot_wd = zero + sinkhorn_its = zero + sinkhorn_K_min = zero + sinkhorn_beta_abs_max = zero + sinkhorn_err = zero + + # ---- Density pixel-wise L1 loss between normalised density maps - + if "density_l1" in active: + gt_density_normed = gt_density / (point_counts.view(B, 1, 1, 1) + 1e-4) + per_pixel = self.density_l1(normed_density, gt_density_normed) + density_l1_loss = ( + per_pixel.sum(dim=[1, 2, 3]).mul(point_counts).mean() + * self.density_l1_weight + ) + else: + density_l1_loss = zero + + # ---- Auxiliary CLS count regression in density space --------- + if "count_cls" in active: + gt_counts_normed = gt_counts / areas.unsqueeze(0) + count_cls_loss = ( + self.cls_l1(cls_preds, gt_counts_normed) * self.count_cls_weight + ) + cls_pred_counts = cls_preds * areas.unsqueeze(0) + count_cls_mae = self.cls_l1(cls_pred_counts, gt_counts) + else: + count_cls_loss = zero + + total = count_loss + ot_loss + density_l1_loss + count_cls_loss + result = { + "loss": total, + "count_mae": count_mae, + "count_loss": count_loss, + "ot_loss": ot_loss, + "ot_wd": ot_wd, + "sinkhorn_its": sinkhorn_its, + "sinkhorn_K_min": sinkhorn_K_min, + "sinkhorn_beta_abs_max": sinkhorn_beta_abs_max, + "sinkhorn_err": sinkhorn_err, + "density_l1_loss": density_l1_loss, + } + if "count_cls" in active: + result["count_cls_mae"] = count_cls_mae + result["count_cls_loss"] = count_cls_loss + + return result def forward( - self, inputs: torch.Tensor | list[torch.Tensor], targets: list | None = 
None + self, + inputs: torch.Tensor | list[torch.Tensor], + targets: list | None = None, + gt_density: torch.Tensor | None = None, ): """Forward pass. @@ -295,6 +523,8 @@ def forward( Args: inputs: ``(B, C, H, W)`` tensor or list of ``(C, H, W)`` tensors. targets: list of target dicts (required during training). + gt_density: optional precomputed ``(B, 1, H', W')`` Gaussian density + target for the density-L1 loss; computed from points if omitted. """ # Batch-pad variable-size images; record original sizes. if isinstance(inputs, list): @@ -322,24 +552,43 @@ def forward( label_feats, l_cls = self._forward_features(encoded) out_L, out_cls_l = self.regression(label_feats, l_cls) - # Crop each output to its valid spatial extent and normalise. - # Cropping naturally excludes batch padding, which is correct for both - # training losses and inference peak-finding. - density_list, normed_list = [], [] - for i, (valid_h, valid_w) in enumerate(self._output_shapes(shapes)): - crop = out_L[0][i : i + 1, :, :valid_h, :valid_w].contiguous() - cls_count = self._cls_outputs_to_count(out_cls_l[0][i : i + 1], [shapes[i]]) - dm, nd = self._normalize_density(crop, cls_count) - density_list.append(dm) - normed_list.append(nd) - if self.training: if targets is None: raise ValueError("targets must be provided in training mode") + # Train on the full batched output, masking the padded region. + output_mask = self._build_output_mask( + shapes, out_L[0].shape[-2], out_L[0].shape[-1], dtype=out_L[0].dtype + ) + points = self._cast_points(targets) + gt_counts = torch.tensor( + [len(p) for p in points], dtype=torch.float32, device=self.device + ) + # In training the map is scaled by the ground-truth counts when + # enforce_count is set, and left raw otherwise. The CLS count + # prediction is only used at inference, so pass None here. 
+ density_map, label_normed = self._normalize_density( + out_L[0] * output_mask, None, gt_count=gt_counts + ) return self.compute_loss( - density_list, normed_list, out_cls_l, targets, image_shapes=shapes + [density_map] + out_L[1:], + label_normed, + out_cls_l, + targets, + image_shapes=shapes, + gt_density=gt_density, + points=points, ) + # Eval: crop each output to its valid spatial extent and normalise. + # Cropping naturally excludes batch padding, which is required for + # inference peak-finding. + density_list = [] + for i, (valid_h, valid_w) in enumerate(self._output_shapes(shapes)): + crop = out_L[0][i : i + 1, :, :valid_h, :valid_w].contiguous() + cls_count = self._cls_outputs_to_count(out_cls_l[0][i : i + 1], [shapes[i]]) + dm, _ = self._normalize_density(crop, cls_count) + density_list.append(dm) + return self.postprocess_density(density_list, batch) @@ -372,6 +621,23 @@ def create_model( ) backbone = self.config.point.backbone + # Loss/normalisation hyperparameters from config, passed only when + # training from scratch. A loaded checkpoint keeps its own saved config. 
+ point_cfg = self.config.point + loss_kwargs = { + "density_sigma": point_cfg.density_sigma, + "mae_weight": point_cfg.mae_weight, + "ot_weight": point_cfg.ot_weight, + "density_l1_weight": point_cfg.density_l1_weight, + "count_cls_weight": point_cfg.count_cls_weight, + "sinkhorn_reg": point_cfg.sinkhorn_reg, + "num_of_iter_in_ot": point_cfg.num_of_iter_in_ot, + "losses": list(point_cfg.losses) if point_cfg.losses is not None else None, + "norm_cood": point_cfg.norm_cood, + "enforce_count": point_cfg.enforce_count, + } + scratch_args = {**loss_kwargs, **hf_args} + # Load fully trained backbone + head from hub if pretrained: if label_dict is not None: @@ -393,7 +659,7 @@ def create_model( label_dict=label_dict, score_thresh=self.config.score_thresh, score_integration_radius=self.config.point.score_integration_radius, - **hf_args, + **scratch_args, ) model.backbone = PvtV2Model.from_pretrained( model.backbone_name, ignore_mismatched_sizes=True @@ -405,7 +671,7 @@ def create_model( label_dict=label_dict, score_thresh=self.config.score_thresh, score_integration_radius=self.config.point.score_integration_radius, - **hf_args, + **scratch_args, ) if map_location is not None: diff --git a/src/deepforest/visualize.py b/src/deepforest/visualize.py index a74cdc1c0..bf657ad24 100644 --- a/src/deepforest/visualize.py +++ b/src/deepforest/visualize.py @@ -478,8 +478,22 @@ def _plot_image_with_geometry(df, image, sv_color, thickness=1, radius=3): detections=detections, ) elif geom_type == "point": - point_annotator = sv.VertexAnnotator(color=sv_color, radius=radius) + # VertexAnnotator can't take a palette, so we use circle annotations instead. 
+ # Convert points to "boxes" for visualization + xy = detections.xy.reshape(-1, 2) + xyxy = np.stack( + [xy[:, 0] - radius, xy[:, 1] - radius, xy[:, 0] + radius, xy[:, 1] + radius], + axis=1, + ) + point_detections = sv.Detections( + xyxy=xyxy, + class_id=detections.class_id, + confidence=detections.confidence.reshape(-1) + if detections.confidence is not None + else None, + ) + point_annotator = sv.CircleAnnotator(color=sv_color, thickness=thickness) annotated_frame = point_annotator.annotate( - scene=image.copy(), key_points=detections + scene=image.copy(), detections=point_detections ) return annotated_frame diff --git a/tests/conftest.py b/tests/conftest.py index d64dd156c..d7b34273d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -93,3 +93,21 @@ def m(download_release, tmp_path_factory): m.load_model("weecology/deepforest-tree") return m + +@pytest.fixture(scope="session") +def m_point(download_release, tmp_path_factory): + m = main.deepforest(config="point") + m.config.train.csv_file = get_data("example.csv") + m.config.train.root_dir = os.path.dirname(get_data("example.csv")) + m.config.train.fast_dev_run = True + m.config.batch_size = 2 + m.config.validation.csv_file = get_data("example.csv") + m.config.validation.root_dir = os.path.dirname(get_data("example.csv")) + m.config.workers = 0 + m.config.validation.val_accuracy_interval = 1 + m.config.train.epochs = 2 + m.config.log_root = str(tmp_path_factory.mktemp("logs")) + + m.create_trainer() + + return m diff --git a/tests/test_point.py b/tests/test_point.py index d9b651d2f..8f9e14dd6 100644 --- a/tests/test_point.py +++ b/tests/test_point.py @@ -78,8 +78,10 @@ def test_eval_single(point_model): assert 0.3 <= metrics["point_recall"] <= 1.0, f"Expected point_recall to be in [0.3, 1.0], got {metrics['point_recall']:.2f}" def test_treeformer_forward_pass_train(): - """Training forward pass returns a loss dict with a scalar, differentiable loss.""" - model = TreeFormerModel(backbone="pvt_v2_b0", 
num_classes=1) + """Training forward pass returns the full loss dict; the loss is a scalar, + differentiable, and gradients reach the backbone and regression head.""" + # norm_cood=True gives non-degenerate (global) OT, matching the NEON config. + model = TreeFormerModel(backbone="pvt_v2_b0", num_classes=1, norm_cood=True) model.train() images = torch.rand(2, 3, 128, 128) @@ -91,9 +93,29 @@ def test_treeformer_forward_pass_train(): output = model(images, targets) assert isinstance(output, dict) - assert "loss" in output - assert output["loss"].ndim == 0 - assert output["loss"].requires_grad + for key in ("loss", "count_loss", "ot_loss", "density_l1_loss", "count_cls_loss"): + assert key in output, f"missing loss term: {key}" + + loss = output["loss"] + assert loss.ndim == 0 + assert loss.requires_grad + assert torch.isfinite(loss) + + # With norm_cood=True the OT term is non-degenerate: finite Wasserstein + # distance and Sinkhorn converges within the iteration budget. + assert torch.isfinite(output["ot_wd"]) + assert output["sinkhorn_its"] < model.ot_iter + + loss.backward() + backbone_grads = regression_grads = 0 + for name, param in model.named_parameters(): + if param.grad is None: + continue + assert torch.isfinite(param.grad).all(), f"non-finite grad in {name}" + backbone_grads += name.startswith("backbone") + regression_grads += name.startswith("regression") + assert backbone_grads > 0, "no gradient reached the backbone" + assert regression_grads > 0, "no gradient reached the regression head" def test_treeformer_forward_pass_val(): @@ -117,3 +139,28 @@ def test_treeformer_forward_pass_val(): assert pred["points"].shape == (n, 2) assert pred["scores"].shape == (n,) assert pred["labels"].shape == (n,) + + +def test_treeformer_loss_terms_toggle(): + """active_losses gating: disabled terms are zero, count_cls is omitted, and + the total equals the single active term.""" + model = TreeFormerModel( + backbone="pvt_v2_b0", num_classes=1, losses=["count"], 
enforce_count=False + ) + model.train() + + images = torch.rand(2, 3, 96, 96) + targets = [ + {"points": torch.rand(4, 2) * 96, "labels": torch.zeros(4, dtype=torch.int64)}, + {"points": torch.rand(2, 2) * 96, "labels": torch.zeros(2, dtype=torch.int64)}, + ] + + output = model(images, targets) + + assert output["ot_loss"].item() == 0.0 + assert output["density_l1_loss"].item() == 0.0 + assert "count_cls_loss" not in output + assert torch.allclose(output["loss"], output["count_loss"]) + assert output["loss"].requires_grad + assert torch.isfinite(output["loss"]) + output["loss"].backward() diff --git a/tests/test_visualize.py b/tests/test_visualize.py index dbf599afa..75a24dcfa 100644 --- a/tests/test_visualize.py +++ b/tests/test_visualize.py @@ -60,13 +60,19 @@ def gdf_box(): return gdf -def test_predict_image_and_plot(m, tmp_path): +def test_predict_boxes_and_plot(m, tmp_path): sample_image_path = get_data("OSBS_029.png") results = m.predict_image(path=sample_image_path) visualize.plot_results(results, savedir=tmp_path) assert os.path.exists(os.path.join(tmp_path, "OSBS_029.png")) +def test_predict_points_and_plot(m_point, tmp_path): + sample_image_path = get_data("OSBS_029.png") + results = m_point.predict_image(path=sample_image_path) + visualize.plot_results(results, savedir=tmp_path) + + assert os.path.exists(os.path.join(tmp_path, "OSBS_029.png")) def test_predict_tile_and_plot(m, tmp_path): sample_image_path = get_data("OSBS_029.png")