diff --git a/docs/user_guide/13_annotation.md b/docs/user_guide/13_annotation.md index 543836fb9..011c90d6b 100644 --- a/docs/user_guide/13_annotation.md +++ b/docs/user_guide/13_annotation.md @@ -73,7 +73,7 @@ Yes! Object detection models use non-annotated areas of an image as negative dat 1. **Select Important Images**: Duplicate backgrounds or objects contribute little to model generalization. Focus on gathering a wide variety of object appearances. 2. **Avoid Over-splitting Labels**: Often, using a superclass for detection followed by a separate model for classification is more effective. See the [`CropModel`](03_cropmodels.md) for an example. -3. **Balance Accuracy and Practicality**: Depending on the goal (e.g., object counting or detection), keypoints can sometimes be used instead of precise boxes to simplify the process. +3. **Balance Accuracy and Practicality**: Depending on the goal (e.g., object counting or detection), point/keypoint annotations can sometimes be used instead of precise boxes to simplify the process. ## Quick Video on Annotating Images diff --git a/pyproject.toml b/pyproject.toml index 5155df2a0..6b8f8ca63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ dependencies = [ "pyyaml>=5.1.0", "rasterio", "safetensors<0.6.0", + "scikit-image>=0.25.2", "setuptools", "shapely>2.0.0", "slidingwindow", diff --git a/src/deepforest/conf/config.yaml b/src/deepforest/conf/config.yaml index fa93a94c4..63ef1a137 100644 --- a/src/deepforest/conf/config.yaml +++ b/src/deepforest/conf/config.yaml @@ -142,3 +142,8 @@ cropmodel: normalize: # Number of pixels to expand bbox crop windows for better prediction context. expand: 0 + +point: + score_integration_radius: 5 + nms_distance_thresh: 5.0 + distance_threshold: 10 diff --git a/src/deepforest/conf/point.yaml b/src/deepforest/conf/point.yaml new file mode 100644 index 000000000..b7163ddce --- /dev/null +++ b/src/deepforest/conf/point.yaml @@ -0,0 +1,22 @@ +architecture: treeformer + +model: + name: 'weecology/deepforest-tree-point' + revision: 'main' + +# This model was trained on 10 cm/px data and has been tested +# at resolutions around 5-20 cm/px. If your data is significantly +# different, for example < 5 cm/px drone data, we suggest first +# downsampling your orthomosaics. +patch_size: 512 +patch_overlap: 0.1 + +# Setting a threshold below 0.3 is likely to include a lot of +# noisy/low-confidence predictions. +score_thresh: 0.3 + +point: + backbone: 'pvt_v2_b3' + score_integration_radius: 5 + nms_distance_thresh: 5.0 + distance_threshold: 10.0 diff --git a/src/deepforest/conf/schema.py b/src/deepforest/conf/schema.py index 0ef3b4cbc..b89254379 100644 --- a/src/deepforest/conf/schema.py +++ b/src/deepforest/conf/schema.py @@ -41,7 +41,7 @@ class SchedulerConfig: """Set the type of scheduler, by default DeepForest uses a stepped learning function reducing at "milestones" during training.""" - type: str | None = "StepLR" + type: str | None = "stepLR" params: SchedulerParamsConfig = field(default_factory=SchedulerParamsConfig) @@ -57,8 +57,10 @@ class OptimizerConfig: @dataclass class TrainConfig: - """Main training configuration. The CSV file and root directory are - required to specify the location of the training dataset. + """Main training configuration. + + The CSV file and root directory are required to specify the location + of the training dataset. The default learning rate may need to be changed for certain architectures, such as transformers-based models which sometimes @@ -84,8 +86,10 @@ class TrainConfig: @dataclass class ValidationConfig: - """Main validation configuration. As with training data, it's required that - you set a CSV file and root directory. + """Main validation configuration. + + As with training data, it's required that you set a CSV file and + root directory. Validation during training is important to identify if the model has converged or is overfitting. @@ -125,11 +129,22 @@ class CropModelConfig: expand: int = 0 +@dataclass +class PointConfig: + """Configuration for point models.""" + + backbone: str = "pvt_v2_b3" + score_integration_radius: int = 5 + nms_distance_thresh: float = 5.0 + distance_threshold: float = 10.0 + + @dataclass class Config: - """General DeepForest configuration. Some parameters here are shared - between dataloaders, for example the batch size, accelerator and number of - workers. + """General DeepForest configuration. + + Some parameters here are shared between dataloaders, for example the + batch size, accelerator and number of workers. Here we also set the architecture, which can be one of "retinanet" or "DeformableDetr" currently. If you modify the number of classes @@ -173,3 +188,4 @@ class Config: validation: ValidationConfig = field(default_factory=ValidationConfig) predict: PredictConfig = field(default_factory=PredictConfig) cropmodel: CropModelConfig = field(default_factory=CropModelConfig) + point: PointConfig = field(default_factory=PointConfig) diff --git a/src/deepforest/datasets/training.py b/src/deepforest/datasets/training.py index 7ed188f45..8c2a4ccb5 100644 --- a/src/deepforest/datasets/training.py +++ b/src/deepforest/datasets/training.py @@ -90,9 +90,10 @@ def _validate_labels(self) -> None: @abstractmethod def _validate_coordinates(self) -> None: - """Validate geometries in the annotation data. Must be overidden by - child classes to implement task-specific checks (e.g., boxes vs - points). + """Validate geometries in the annotation data. + + Must be overidden by child classes to implement task-specific + checks (e.g., boxes vs points). Should raise ValueError with details if any invalid geometries are found. @@ -285,8 +286,8 @@ def __getitem__(self, idx) -> tuple: return image, targets, self.image_names[idx] -class KeypointDataset(TrainingDataset): - """Dataset for keypoint detection tasks.""" +class PointDataset(TrainingDataset): + """Dataset for point detection tasks.""" _data_keys = [DataKey.IMAGE, DataKey.KEYPOINTS] diff --git a/src/deepforest/main.py b/src/deepforest/main.py index f5e3f0ac2..4d8d53988 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -7,6 +7,7 @@ import pandas as pd import pytorch_lightning as pl import torch +import torchmetrics from omegaconf import DictConfig, OmegaConf from PIL import Image from pytorch_lightning.callbacks import LearningRateMonitor @@ -102,10 +103,19 @@ def setup_metrics(self): self.mAP_metric = MeanAveragePrecision(backend="faster_coco_eval") - self.precision_recall_metric = RecallPrecision( - iou_threshold=self.config.validation.iou_threshold, - label_dict=self.label_dict, - ) + self.precision_recall_metric = RecallPrecision( + iou_threshold=self.config.validation.iou_threshold, + label_dict=self.label_dict, + task=self.model.task, + ) + # Point Metrics + elif self.model.task == "point": + self.mae_metric = torchmetrics.MeanAbsoluteError() + self.precision_recall_metric = RecallPrecision( + distance_threshold=self.config.point.distance_threshold, + label_dict=self.label_dict, + task=self.model.task, + ) def load_model(self, model_name=None, revision=None): """Loads a model that has already been pretrained for a specific task, @@ -125,7 +135,6 @@ def load_model(self, model_name=None, revision=None): Returns: None """ - if model_name is None: model_name = self.config.model.name @@ -173,8 +182,9 @@ def set_labels(self, label_dict): self.numeric_to_label_dict = {v: k for k, v in label_dict.items()} def create_model(self, initialize_model=False): - """Initialize a deepforest architecture. This can be done in two ways. - Passed as the model argument to deepforest __init__(), or as a named + """Initialize a deepforest architecture. + + This can be done in two ways. Passed as the model argument to deepforest __init__(), or as a named architecture in config.architecture, which corresponds to a file in models/, as is a subclass of model.Model(). The config args in the .yaml are specified. @@ -199,7 +209,6 @@ def create_trainer(self, logger=None, callbacks=None, **kwargs): callbacks: Optional list of callbacks **kwargs: Additional trainer arguments """ - # Setup metrics which may have changed if the config was modified self.setup_metrics() @@ -326,8 +335,9 @@ def load_dataset( preload_images=False, batch_size=1, ): - """Create a dataset for inference or training. Csv file format is .csv - file with the columns "image_path", "xmin","ymin","xmax","ymax" for the + """Create a dataset for inference or training. + + Csv file format is .csv file with the columns "image_path", "xmin","ymin","xmax","ymax" for the image name and bounding box position. Image_path is the relative filename, not absolute path, which is in the root_dir directory. One bounding box per line. @@ -342,7 +352,6 @@ def load_dataset( Returns: ds: a pytorch dataset """ - if self.model.task == "box": ds = training.BoxDataset( csv_file=csv_file, @@ -352,8 +361,8 @@ def load_dataset( augmentations=augmentations, preload_images=preload_images, ) - elif self.model.task == "keypoint": - ds = training.KeypointDataset( + elif self.model.task == "point": + ds = training.PointDataset( csv_file=csv_file, root_dir=root_dir, transforms=transforms, @@ -363,7 +372,7 @@ def load_dataset( ) else: raise ValueError( - f"Invalid task type: {self.model.task}, expected 'box' or 'keypoint'" + f"Invalid task type: {self.model.task}, expected 'box' or 'point'" ) if len(ds) == 0: @@ -484,7 +493,11 @@ def predict_image(self, image: np.ndarray | None = None, path: str | None = None image = image.astype("float32") result = predict._predict_image_( - model=self.model, image=image, path=path, nms_thresh=self.config.nms_thresh + model=self.model, + image=image, + path=path, + iou_threshold=self.config.nms_thresh, + nms_distance_thresh=self.config.point.nms_distance_thresh, ) # If there were no predictions, return None @@ -526,7 +539,6 @@ def predict_file( Returns: df: pandas dataframe with bounding boxes, label and scores for each image in the csv file """ - ds = prediction.FromCSVFile(csv_file=csv_file, root_dir=root_dir) dataloader = self.predict_dataloader(ds, batch_size=self.config.batch_size) results = predict._dataloader_wrapper_( @@ -624,12 +636,11 @@ def predict_tile( # Flatten list from batched prediction # Track global window index across batches global_window_idx = 0 - for _idx, batch in enumerate(batched_results): - for _window_idx, window_result in enumerate(batch): - formatted_result = ds.postprocess( - window_result, global_window_idx + for batch in batched_results: + for window_result in batch: + image_results.append( + ds.postprocess(window_result, global_window_idx) ) - image_results.append(formatted_result) global_window_idx += 1 if not image_results: @@ -668,17 +679,27 @@ def predict_tile( # Perform mosaic for each image_path, or all if image_path is None mosaic_results = [] if results["image_path"].isnull().all(): - mosaic_results.append(predict.mosiac(results, iou_threshold=iou_threshold)) + mosaic_results.append( + predict.mosaic( + results, + iou_threshold=iou_threshold, + nms_distance_thresh=self.config.point.nms_distance_thresh, + ) + ) else: for image_path in results["image_path"].unique(): image_results = results[results["image_path"] == image_path] - image_mosaic = predict.mosiac(image_results, iou_threshold=iou_threshold) + image_mosaic = predict.mosaic( + image_results, + iou_threshold=iou_threshold, + nms_distance_thresh=self.config.point.nms_distance_thresh, + ) image_mosaic["image_path"] = image_path mosaic_results.append(image_mosaic) mosaic_results = pd.concat(mosaic_results) mosaic_results["label"] = mosaic_results.label.apply( - lambda x: self.numeric_to_label_dict[x] + lambda x: self.numeric_to_label_dict.get(x, x) ) if paths[0] is not None: @@ -762,16 +783,30 @@ def validation_step(self, batch, batch_idx): # Compute precision, recall and empty frame metrics. self.precision_recall_metric.update(preds, targets, image_names) - # Filter out empty frames for IoU/mAP metrics. pred + target - non_empty_pred = [] - non_empty_target = [] - for pred, target in zip(preds, targets, strict=True): - if not (target["boxes"].numel() == 0 or torch.all(target["boxes"] == 0)): - non_empty_pred.append(pred) - non_empty_target.append(target) - - self.iou_metric.update(non_empty_pred, non_empty_target) - self.mAP_metric.update(non_empty_pred, non_empty_target) + if self.model.task == "box": + # Filter out empty frames for IoU/mAP metrics. pred + target + non_empty_pred = [] + non_empty_target = [] + for pred, target in zip(preds, targets, strict=True): + if not (target["boxes"].numel() == 0 or torch.all(target["boxes"] == 0)): + non_empty_pred.append(pred) + non_empty_target.append(target) + + self.iou_metric.update(non_empty_pred, non_empty_target) + self.mAP_metric.update(non_empty_pred, non_empty_target) + elif self.model.task == "point": + device = targets[0]["points"].device if targets else torch.device("cpu") + pred_counts = torch.tensor( + [float(len(p["points"])) for p in preds], + dtype=torch.float32, + device=device, + ) + true_counts = torch.tensor( + [float(len(t["points"])) for t in targets], + dtype=torch.float32, + device=device, + ) + self.mae_metric.update(pred_counts, true_counts) # Log the predictions if you want to use them for evaluation logs for i, result in enumerate(preds): @@ -792,17 +827,23 @@ def _compute_epoch_metrics(self) -> dict: """ metrics = {} - # IoU and mAP - if len(self.iou_metric.groundtruth_labels) > 0: - metrics.update(self.iou_metric.compute()) - # Lightning bug: claims this is a warning but it's not. See issue #16218 in Lightning-AI/pytorch-lightning - output = self.mAP_metric.compute() - - # Remove classes from output dict - output = {key: value for key, value in output.items() if not key == "classes"} - metrics.update(output) - - # Box recall/precision + if self.model.task == "box": + # IoU and mAP + if len(self.iou_metric.groundtruth_labels) > 0: + metrics.update(self.iou_metric.compute()) + # Lightning bug: claims this is a warning but it's not. See issue #16218 in Lightning-AI/pytorch-lightning + output = self.mAP_metric.compute() + + # Remove classes from output dict + output = { + key: value for key, value in output.items() if not key == "classes" + } + metrics.update(output) + + elif self.model.task == "point": + metrics["val_mae"] = self.mae_metric.compute() + + # Recall/precision for all tasks metrics.update(self.precision_recall_metric.compute()) return metrics @@ -821,12 +862,16 @@ def on_validation_epoch_end(self): # Manual reset. Lightning does not do this automatically # unless we log the metric objects directly self.precision_recall_metric.reset() - self.iou_metric.reset() - self.mAP_metric.reset() + if self.model.task == "box": + self.iou_metric.reset() + self.mAP_metric.reset() + elif self.model.task == "point": + self.mae_metric.reset() def predict_step(self, batch, batch_idx): - """Predict a batch of images with the deepforest model. If batch is a - list, concatenate the images, predict and then split the results, + """Predict a batch of images with the deepforest model. + + If batch is a list, concatenate the images, predict and then split the results, useful for main.predict_tile. Args: @@ -846,6 +891,7 @@ def predict_step(self, batch, batch_idx): self.model.eval() with torch.no_grad(): preds = self.model.forward(images) + return preds def predict_batch(self, images, preprocess_fn=None): diff --git a/src/deepforest/metrics.py b/src/deepforest/metrics.py index 0136e8b49..c8be031c2 100644 --- a/src/deepforest/metrics.py +++ b/src/deepforest/metrics.py @@ -20,15 +20,19 @@ def __init__( self, task="box", iou_threshold: float = 0.4, + distance_threshold: float = 10.0, label_dict: dict | None = None, **kwargs, ) -> None: - """This metric performs DeepForest's box recall and precision - evaluation. + """This metric performs DeepForest's recall and precision evaluation. Args: - task (str, optional): The type of task to evaluate. Defaults to "box". - iou_threshold (float, optional): IOU threshold for evaluation. Defaults to 0.4. + task (str, optional): The type of task to evaluate. One of + ``"box"``, ``"polygon"``, or ``"point"``. Defaults to ``"box"``. + iou_threshold (float, optional): IoU threshold for box/polygon matching. + Defaults to 0.4. + distance_threshold (float, optional): Pixel distance threshold for + point matching. Defaults to 10.0. label_dict (dict | None): Mapping of class name to numeric ID. When provided and more than one class is present, per-class recall and precision are included in the ``compute()`` output. @@ -36,12 +40,20 @@ def __init__( super().__init__(**kwargs) self.iou_threshold = iou_threshold + self.distance_threshold = distance_threshold self.task = task self.label_dict = label_dict or {} self.numeric_to_label_dict = {v: k for k, v in self.label_dict.items()} - if task != "box": - raise NotImplementedError("Only 'box' task is currently supported.") + if task not in ("box", "point"): + raise ValueError(f"Unsupported task: {task}. Use 'box' or 'point'.") + + if self.task == "box": + self.pred_key = "boxes" + self.geom_type = "box" + elif self.task == "point": + self.pred_key = "points" + self.geom_type = "point" self.add_state("precision", default=torch.tensor(0.0), dist_reduce_fx="sum") self.add_state("recall", default=torch.tensor(0.0), dist_reduce_fx="sum") @@ -85,12 +97,11 @@ def _update_single( ) -> None: """Update metric state for a single image.""" self.num_images += 1 - - n_pred = len(pred["boxes"]) - n_target = len(target["boxes"]) + n_pred = len(pred[self.pred_key]) + n_target = len(target[self.pred_key]) # Early exit for prediction/target base cases. - is_empty_frame = n_target == 0 or torch.all(target["boxes"] == 0) + is_empty_frame = n_target == 0 or torch.all(target[self.pred_key] == 0) if is_empty_frame: self.num_empty_frames += 1 if n_pred == 0: @@ -107,14 +118,16 @@ def _update_single( self.num_images_with_predictions += 1 # Note: format_geometry handles detach + CPU. IoU == 0 represents not matched. - ground_df = utilities.format_geometry(target, scores=False) - pred_df = utilities.format_geometry(pred, scores=True) + ground_df = utilities.format_geometry( + target, scores=False, geom_type=self.geom_type + ) + pred_df = utilities.format_geometry(pred, scores=True, geom_type=self.geom_type) ground_df["image_path"] = image_path pred_df["image_path"] = image_path result = match_predictions( predictions=pred_df, ground_df=ground_df, - task=self.task, + task=self.geom_type, ) result["image_path"] = image_path @@ -151,13 +164,12 @@ def _sync_dist(self, dist_sync_fn=None, process_group=None) -> None: def compute(self) -> dict: """Computes the recall/precision metrics. - DataFrames (match results and class recall) are stored as instance - attributes ``_all_results`` and ``_class_recall`` for callers that - need them. Only loggable scalar/tensor values are returned. - Per-class recall and precision are included when more than one class - is present in ``label_dict``. + DataFrames (match results and class recall) are stored as + instance attributes ``_all_results`` and ``_class_recall`` for + callers that need them. Only loggable scalar/tensor values are + returned. Per-class recall and precision are included when more + than one class is present in ``label_dict``. """ - # Map numeric label IDs to strings if self.results: self._all_results = pd.concat(self.results, ignore_index=True) @@ -170,9 +182,11 @@ def compute(self) -> dict: if pd.notna(x) else x ) - self._class_recall = compute_class_recall( - self._all_results[self._all_results["match"]] - ) + # TODO Check why this fails for point predictions + if self.task == "box" and len(self.label_dict) > 1: + self._class_recall = compute_class_recall( + self._all_results[self._all_results["match"]] + ) else: self._all_results = pd.DataFrame() self._class_recall = None diff --git a/src/deepforest/models/treeformer.py b/src/deepforest/models/treeformer.py new file mode 100644 index 000000000..3e3222075 --- /dev/null +++ b/src/deepforest/models/treeformer.py @@ -0,0 +1,413 @@ +"""Point density prediction model based on TreeFormer. + +PvT-V2 backbone + multi-scale Regression head producing a density map at +1/4 input resolution. + +See: 10.1109/TGRS.2023.3295802 +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from huggingface_hub import PyTorchModelHubMixin +from transformers import AutoConfig, AutoImageProcessor, PvtV2Model + +from deepforest.model import BaseModel +from deepforest.models.treeformer_decoder import Regression +from deepforest.utilities import density_to_points + + +class TreeFormerModel(nn.Module, PyTorchModelHubMixin): + """PvT-V2 backbone + Regression head for density estimation.""" + + task = "point" + + # Native output channel dims for each PvtV2 variant. + HIDDEN_SIZES = { + "pvt_v2_b0": [32, 64, 160, 256], + "pvt_v2_b1": [64, 128, 320, 512], + "pvt_v2_b2": [64, 128, 320, 512], + "pvt_v2_b3": [64, 128, 320, 512], + "pvt_v2_b4": [64, 128, 320, 512], + "pvt_v2_b5": [64, 128, 320, 512], + } + + # Fixed dims Regression expects + REG_DIMS = [128, 256, 512, 1024] + + def __init__( + self, + backbone: str = "pvt_v2_b3", + num_classes: int = 1, + label_dict: dict | None = None, + num_of_iter_in_ot: int = 100, + sinkhorn_reg: float = 1.0, + density_sigma: float = 5.0, + mae_weight: float = 1.0, + ot_weight: float = 0.1, + density_l1_weight: float = 0.01, + count_cls_weight: float = 1.0, + losses: list | None = None, + norm_cood: bool = False, + enforce_count: bool = True, + score_thresh: float = 0.5, + score_integration_radius: int = 5, + **kwargs, + ): + """Initialize TreeFormerModel.""" + super().__init__() + if "/" not in backbone: + backbone = f"OpenGVLab/{backbone}" + self.backbone_name = backbone + + # Processor handles ImageNet normalization + self.processor = AutoImageProcessor.from_pretrained( + backbone, + use_fast=True, + do_normalize=True, + do_rescale=False, + do_resize=False, + ) + + # Instaniate architecture but don't pull weights + self.backbone = PvtV2Model(AutoConfig.from_pretrained(backbone)) + + self.backbone.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": False} + ) + + # Suppress some noisy warnings that show in DDP. + torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(False) + for module in self.backbone.modules(): + if ( + isinstance(module, nn.Conv2d) + and module.groups > 1 + and module.groups == module.in_channels + ): + module.weight.register_hook(lambda grad: grad.contiguous()) + + variant = backbone.split("/")[-1] + src = self.HIDDEN_SIZES.get(variant, None) + if src is None: + raise ValueError( + f"Backbone variant {variant} isn't supported. Please use one of {list(self.HIDDEN_SIZES.keys())}" + ) + + self.proj = nn.ModuleList( + [nn.Conv2d(s, d, 1) for s, d in zip(src, self.REG_DIMS, strict=True)] + ) + self.num_classes = num_classes + self.label_dict = label_dict + self.regression = Regression(num_classes=num_classes) + + # Fixed output stride for PvtV2 + self.downsample_ratio = 4 + + self.enforce_count = enforce_count + self.norm_cood = norm_cood + + # Losses + self.density_l1 = nn.L1Loss(reduction="none") + self.cls_l1 = nn.L1Loss() + + # Training params + self.ot_iter = num_of_iter_in_ot + self.sinkhorn_reg = sinkhorn_reg + self.density_sigma = density_sigma + self.mae_weight = mae_weight + self.ot_weight = ot_weight + self.density_l1_weight = density_l1_weight + self.count_cls_weight = count_cls_weight + self.score_thresh = score_thresh + self.score_integration_radius = score_integration_radius + self.losses = ( + list(losses) + if losses is not None + else ["count", "ot", "density_l1", "count_cls"] + ) + + self.kwargs = kwargs + self.update_config() + + def update_config(self): + # Stored as config on HF + self._hub_mixin_config = { + "backbone": self.backbone_name, + "num_classes": self.num_classes, + "label_dict": self.label_dict, + "num_of_iter_in_ot": self.ot_iter, + "sinkhorn_reg": self.sinkhorn_reg, + "density_sigma": self.density_sigma, + "mae_weight": self.mae_weight, + "ot_weight": self.ot_weight, + "density_l1_weight": self.density_l1_weight, + "count_cls_weight": self.count_cls_weight, + "losses": self.losses, + "norm_cood": self.norm_cood, + "enforce_count": self.enforce_count, + "score_thresh": self.score_thresh, + "score_integration_radius": self.score_integration_radius, + **self.kwargs, + } + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + def _normalize_density( + self, score_map: torch.Tensor, cls_count: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """Return (density_map, normed_density) from raw score map and count + scalar. + + If the model is in enforce_count mode, the raw density map is + scaled to match the count prediction; otherwise it's returned + as-is. + """ + B = score_map.size(0) + score_sum = score_map.view(B, -1).sum(1).view(B, 1, 1, 1) + normed = score_map / (score_sum + 1e-4) + if self.enforce_count: + count = cls_count.view(B, 1, 1, 1).abs().clamp(min=1e-4) + return normed * count, normed + return score_map, normed + + def _output_shapes( + self, image_shapes: list[tuple[int, int]] + ) -> list[tuple[int, int]]: + """Return valid output-map extent (H//4, W//4) for each input image.""" + return [ + ( + max(image_h // self.downsample_ratio, 1), + max(image_w // self.downsample_ratio, 1), + ) + for image_h, image_w in image_shapes + ] + + def _cls_outputs_to_count( + self, + cls_output: torch.Tensor, + image_shapes: list[tuple[int, int]], + ) -> torch.Tensor: + """Convert CLS-head count-density prediction to absolute count by + multiplying by the area of each image.""" + count_density = cls_output.reshape(len(image_shapes)) + count_areas = torch.tensor( + [image_h * image_w for image_h, image_w in image_shapes], + dtype=torch.float32, + device=self.device, + ) + return count_density * count_areas + + def _forward_features(self, x: torch.Tensor): + """Run backbone and project stage outputs to REG_DIMS. + + Returns: + feats: list of 4 spatial tensors (B, REG_DIMS[i], H/4^i, W/4^i) + cls: list of 3 vectors (B, REG_DIMS[1..3]) — GAP of stages 1-3 + """ + out = self.backbone(x, output_hidden_states=True) + feats = [p(h) for p, h in zip(self.proj, out.hidden_states, strict=False)] + cls = [feats[i].mean(dim=[2, 3]) for i in range(1, 4)] + return feats, cls + + def postprocess_density(self, density_map, images) -> list[dict]: + """Convert TreeFormer density map outputs to per-image point dicts. + + Scales point coordinates from density-map space to image-pixel space. + Handles both batched tensor inputs and list inputs (from the default + prediction dataloader which returns a list via its collate_fn). + + Args: + density_outputs: tuple ``(density_map, normed_density)`` from + ``TreeFormerModel.forward``. When ``images`` is a list, + ``density_map`` is also a list of per-image tensors. + images: batch tensor ``(B, C, H_img, W_img)`` or list of + ``(C, H_img, W_img)`` tensors. + + Returns: + List of dicts with ``"points"`` (N, 2), ``"scores"`` (N,), + ``"labels"`` (N,) per image, with (x, y) in image-pixel coordinates. + """ + preds = density_to_points( + density_map, + score_thresh=self.score_thresh, + score_integration_radius=self.score_integration_radius, + ) + + if not isinstance(images, (list, tuple)): + images = [images[i] for i in range(images.shape[0])] + + # Per-image density maps; scale each individually. + for pred, dm, img in zip(preds, density_map, images, strict=False): + H_img, W_img = img.shape[-2], img.shape[-1] + H_dm, W_dm = dm.shape[-2], dm.shape[-1] + if pred["points"].shape[0] > 0: + pts = pred["points"].clone() + pts[:, 0] = pts[:, 0] * (W_img / W_dm) + pts[:, 1] = pts[:, 1] * (H_img / H_dm) + pred["points"] = pts + + return preds + + def compute_loss( + self, + density_maps: list[torch.Tensor], + normed_density: list[torch.Tensor], + cls_outputs: list, + targets: list, + image_shapes: list[tuple[int, int]], + ) -> dict: + """Compute training loss: L1 between density-map count and GT count. + + Args: + density_maps: list of multi-scale density maps; index 0 is primary. + normed_density: spatially normalised version of density_maps[0]. + cls_outputs: CLS-head count predictions (unused here). + targets: list of target dicts with ``"points"`` tensors. + image_shapes: original ``(H, W)`` for each image. + + Returns: + dict with ``"loss"`` scalar tensor (supports backprop). + """ + true_counts = torch.tensor( + [float(len(t["points"])) for t in targets], + dtype=torch.float32, + device=self.device, + ) + + # TODO: This is a placeholder to verify integration, full loss + # computation as used with the model releases will be implemented + # in a future PR. + pred_counts = self._cls_outputs_to_count(cls_outputs[0], image_shapes) + count_loss = self.cls_l1(pred_counts, true_counts) * self.mae_weight + + return {"loss": count_loss} + + def forward( + self, inputs: torch.Tensor | list[torch.Tensor], targets: list | None = None + ): + """Forward pass. + + In training mode returns a loss dict; in eval mode returns a list of + per-image prediction dicts (see ``postprocess_density``). + + Args: + inputs: ``(B, C, H, W)`` tensor or list of ``(C, H, W)`` tensors. + targets: list of target dicts (required during training). + """ + # Batch-pad variable-size images; record original sizes. + if isinstance(inputs, list): + shapes = [(img.shape[-2], img.shape[-1]) for img in inputs] + H = max(h for h, _ in shapes) + W = max(w for _, w in shapes) + batch = inputs[0].new_zeros(len(inputs), inputs[0].shape[0], H, W) + for i, img in enumerate(inputs): + batch[i, :, : shapes[i][0], : shapes[i][1]] = img + else: + shapes = [(inputs.shape[2], inputs.shape[3])] * inputs.shape[0] + batch = inputs + + # Pad to next multiple of 32 for PvT stride compatibility. + H, W = batch.shape[2:] + batch = F.pad(batch, (0, (32 - W % 32) % 32, 0, (32 - H % 32) % 32)) + + encoded = self.processor.preprocess( + images=batch, + return_tensors="pt", + do_rescale=False, + do_resize=False, + )["pixel_values"].to(self.device) + + label_feats, l_cls = self._forward_features(encoded) + out_L, out_cls_l = self.regression(label_feats, l_cls) + + # Crop each output to its valid spatial extent and normalise. + # Cropping naturally excludes batch padding, which is correct for both + # training losses and inference peak-finding. + density_list, normed_list = [], [] + for i, (valid_h, valid_w) in enumerate(self._output_shapes(shapes)): + crop = out_L[0][i : i + 1, :, :valid_h, :valid_w].contiguous() + cls_count = self._cls_outputs_to_count(out_cls_l[0][i : i + 1], [shapes[i]]) + dm, nd = self._normalize_density(crop, cls_count) + density_list.append(dm) + normed_list.append(nd) + + if self.training: + if targets is None: + raise ValueError("targets must be provided in training mode") + return self.compute_loss( + density_list, normed_list, out_cls_l, targets, image_shapes=shapes + ) + + return self.postprocess_density(density_list, batch) + + +class Model(BaseModel): + """DeepForest model wrapper for TreeFormer. + + Selected via ``config.architecture = "treeformer"``. + """ + + def create_model( + self, + pretrained: str | None = None, + revision: str | None = None, + map_location: str | torch.device | None = None, + **hf_args, + ) -> TreeFormerModel: + """Create or load a TreeFormerModel. + + Args: + pretrained: HuggingFace repo ID to load weights from, or None. + revision: Model revision/tag on the Hub. + map_location: Device to move model to after loading. + + Returns: + Configured TreeFormerModel instance. + """ + label_dict = dict(self.config.label_dict) if self.config.label_dict else None + num_classes = ( + len(label_dict) if label_dict is not None else self.config.num_classes or 1 + ) + backbone = self.config.point.backbone + + # Load fully trained backbone + head from hub + if pretrained: + if label_dict is not None: + hf_args["label_dict"] = label_dict + + model = TreeFormerModel.from_pretrained( + pretrained, + revision=revision, + num_classes=num_classes, + score_thresh=self.config.score_thresh, + score_integration_radius=self.config.point.score_integration_radius, + **hf_args, + ) + # Architecture from config, backbone weights from ImageNet. + elif backbone: + model = TreeFormerModel( + backbone=backbone, + num_classes=num_classes, + label_dict=label_dict, + score_thresh=self.config.score_thresh, + score_integration_radius=self.config.point.score_integration_radius, + **hf_args, + ) + model.backbone = PvtV2Model.from_pretrained( + model.backbone_name, ignore_mismatched_sizes=True + ) + # Random init + else: + model = TreeFormerModel( + num_classes=num_classes, + label_dict=label_dict, + score_thresh=self.config.score_thresh, + score_integration_radius=self.config.point.score_integration_radius, + **hf_args, + ) + + if map_location is not None: + model = model.to(map_location) + return model diff --git a/src/deepforest/models/treeformer_decoder.py b/src/deepforest/models/treeformer_decoder.py new file mode 100644 index 000000000..61d480f82 --- /dev/null +++ b/src/deepforest/models/treeformer_decoder.py @@ -0,0 +1,221 @@ +"""Decoder modules for the TreeFormer density-map prediction head. + +Contains the multi-scale Regression head and its helper blocks. This +code is a reimplementation of the code found in the TreeFormer +repository, but updated to follow more modern PyTorch practices. +""" + +import numpy as np +import torch +import torch.nn as nn +from torch.distributions.uniform import Uniform + + +class ChannelAttention(nn.Module): + def __init__(self, in_planes, ratio=16): + super().__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + + self.fc = nn.Sequential( + nn.Linear(in_planes, in_planes // ratio, bias=False), + nn.ReLU(inplace=True), + nn.Linear(in_planes // ratio, in_planes, bias=False), + ) + self.sigmoid = nn.Sigmoid() + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + + def forward(self, in_feature): + x = in_feature + b, c, _, _ = in_feature.size() + avg_out = self.fc(self.avg_pool(x).view(b, c)).view(b, c, 1, 1) + out = avg_out + return self.sigmoid(out).expand_as(in_feature) * in_feature + + +class FeatureDropDecoder(nn.Module): + def __init__(self, upscale, conv_in_ch, num_classes): + super().__init__() + + def feature_dropout(self, x): + attention = torch.mean(x, dim=1, keepdim=True) + max_val, _ = torch.max(attention.view(x.size(0), -1), dim=1, keepdim=True) + threshold = max_val * np.random.uniform(0.7, 0.9) + threshold = threshold.view(x.size(0), 1, 1, 1).expand_as(attention) + drop_mask = (attention < threshold).float() + return x.mul(drop_mask) + + def forward(self, x): + x = self.feature_dropout(x) + return x + + +class FeatureNoiseDecoder(nn.Module): + def __init__(self, upscale, conv_in_ch, num_classes, uniform_range=0.3): + super().__init__() + self.uni_dist = Uniform(-uniform_range, uniform_range) + + def feature_based_noise(self, x): + noise_vector = self.uni_dist.sample(x.shape[1:]).to(x.device).unsqueeze(0) + x_noise = x.mul(noise_vector) + x + return x_noise + + def forward(self, x): + x = self.feature_based_noise(x) + return x + + +class DropOutDecoder(nn.Module): + def __init__( + self, upscale, conv_in_ch, num_classes, drop_rate=0.3, spatial_dropout=True + ): + super().__init__() + self.dropout = ( + nn.Dropout2d(p=drop_rate) if spatial_dropout else nn.Dropout(drop_rate) + ) + + def forward(self, x): + x = self.dropout(x) + return x + + +class Regression(nn.Module): + def __init__(self, num_classes: int = 1): + super().__init__() + self.num_classes = num_classes + + self.v1 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(256, 128, 3, padding=1, dilation=1), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + ) + + self.v2 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(512, 256, 3, padding=1, dilation=1), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + ) + + self.v3 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(1024, 512, 3, padding=1, dilation=1), + nn.BatchNorm2d(512), + nn.ReLU(inplace=True), + ) + + self.ca2 = nn.Sequential( + ChannelAttention(512), + nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(512), + nn.ReLU(inplace=True), + ) + + self.ca1 = nn.Sequential( + ChannelAttention(256), + nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + ) + + self.ca0 = nn.Sequential( + ChannelAttention(128), + nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + ) + + self.res2 = nn.Sequential( + nn.Conv2d(512, 256, 3, padding=1, dilation=1), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.Conv2d(256, 128, 3, padding=1, dilation=1), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.Conv2d(128, num_classes, 3, padding=1, dilation=1), + nn.ReLU(inplace=True), + ) + + self.res1 = nn.Sequential( + nn.Conv2d(256, 128, 3, padding=1, dilation=1), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.Conv2d(128, 64, 3, padding=1, dilation=1), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True), + nn.Conv2d(64, num_classes, 3, padding=1, dilation=1), + nn.ReLU(inplace=True), + ) + + self.res0 = nn.Sequential( + nn.Conv2d(128, 64, 3, padding=1, dilation=1), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True), + nn.Conv2d(64, num_classes, 3, padding=1, dilation=1), + nn.ReLU(inplace=True), + ) + + self.noise2 = nn.Dropout2d(p=0.3) + self.noise1 = FeatureDropDecoder(1, 256, 256) + self.noise0 = FeatureNoiseDecoder(1, 128, 128) + self.noise_cls = nn.Dropout(p=0.3) + + self.upsam2 = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) + self.upsam4 = nn.Upsample(scale_factor=4, mode="bilinear", align_corners=True) + + self.cls_lin1 = nn.Linear(1024, 512, bias=False) + self.cls_lin2 = nn.Linear(512, 256, bias=False) + self.cls_lin3 = nn.Linear(256, 128, bias=False) + self.cls_lin4 = nn.Linear(128, num_classes, bias=True) + + self.init_param() + + def forward(self, x, cls): + x0, x1, x2, x3 = x[0], x[1], x[2], x[3] + + x2_1 = self.ca2(x2) + self.v3(x3) + x1_1 = self.ca1(x1) + self.v2(x2_1) + x0_1 = self.ca0(x0) + self.v1(x1_1) + + if self.training: + lin1_out = self.cls_lin1(cls[2]) + yc2 = self.cls_lin4( + self.cls_lin3(self.cls_lin2(self.noise_cls(lin1_out))) + ).squeeze(-1) + + lin2_out = self.cls_lin2(cls[1]) + lin2_noisy = self.noise1(lin2_out[:, :, None, None]).squeeze(-1).squeeze(-1) + yc1 = self.cls_lin4(self.cls_lin3(lin2_noisy)).squeeze(-1) + + lin3_out = self.cls_lin3(cls[0]) + lin3_noisy = self.noise0(lin3_out[:, :, None, None]).squeeze(-1).squeeze(-1) + yc0 = self.cls_lin4(lin3_noisy).squeeze(-1) + + y2 = self.res2(self.upsam4(self.noise2(x2_1))) + y1 = self.res1(self.upsam2(self.noise1(x1_1))) + y0 = self.res0(self.noise0(x0_1)) + + else: + yc2 = self.cls_lin4( + self.cls_lin3(self.cls_lin2(self.cls_lin1(cls[2]))) + ).squeeze(-1) + yc1 = self.cls_lin4(self.cls_lin3(self.cls_lin2(cls[1]))).squeeze(-1) + yc0 = self.cls_lin4(self.cls_lin3(cls[0])).squeeze(-1) + + y2 = self.res2(self.upsam4(x2_1)) + y1 = self.res1(self.upsam2(x1_1)) + y0 = self.res0(x0_1) + + return [y0, y1, y2], [yc0, yc1, yc2] + + def init_param(self): + for m in self.modules(): + if isinstance(m, (nn.Conv2d, nn.Linear)): + nn.init.normal_(m.weight, std=0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) diff --git a/src/deepforest/predict.py b/src/deepforest/predict.py index 0dfad10a0..7396a023f 100644 --- a/src/deepforest/predict.py +++ b/src/deepforest/predict.py @@ -3,7 +3,10 @@ import numpy as np import pandas as pd +import shapely import torch +from scipy.spatial import cKDTree +from shapely import affinity from torchvision.ops import nms from deepforest import utilities @@ -15,7 +18,8 @@ def _predict_image_( model, image: np.ndarray | None = None, path: str | None = None, - nms_thresh: float = 0.15, + iou_threshold: float = 0.15, + nms_distance_thresh: float = 5.0, ): """Predict a single image with a deepforest model. @@ -23,26 +27,30 @@ def _predict_image_( model: a deepforest.main.model object image: a tensor of shape (channels, height, width) path: optional path to read image from disk instead of passing image arg - nms_thresh: Non-max suppression threshold, see config.nms_thresh + iou_threshold: IoU threshold for box non-max suppression + nms_distance_thresh: Distance threshold in pixels for point NMS, see config.point.nms_distance_thresh Returns: df: A pandas dataframe of predictions (Default) img: The input with predictions overlaid (Optional) """ - image = torch.tensor(image).permute(2, 0, 1) image = image / 255 with torch.no_grad(): prediction = model(image.unsqueeze(0)) + prediction = prediction[0] + geom_type = utilities.determine_geometry_type(prediction) + df = utilities.format_geometry(prediction, geom_type=geom_type) + # return None for no predictions - if len(prediction[0]["boxes"]) == 0: + if df is None: return None - df = utilities.format_boxes(prediction[0]) - - if df.label.nunique() > 1: - df = across_class_nms(df, iou_threshold=nms_thresh) + if geom_type == "box" and df.label.nunique() > 1: + df = across_class_nms(df, iou_threshold=iou_threshold) + elif geom_type == "point": + df = reduce_points(df, nms_thresh=nms_distance_thresh) # Add image path if provided if path is not None: @@ -51,103 +59,134 @@ def _predict_image_( return df -def transform_coordinates(boxes): - """Transform box coordinates from window space to original image space. +def translate_predictions(predictions: pd.DataFrame) -> pd.DataFrame: + """Shift window-relative predictions into image coordinates using geometry. Args: - boxes: DataFrame of predictions with xmin, ymin, xmax, ymax, window_xmin, window_ymin columns + predictions: DataFrame with geometry and window_xmin/window_ymin offset columns. Returns: - DataFrame with transformed coordinates + DataFrame with geometry (and coordinate columns) shifted by the window origin. """ - boxes = boxes.copy() - boxes["xmin"] += boxes["window_xmin"] - boxes["xmax"] += boxes["window_xmin"] - boxes["ymin"] += boxes["window_ymin"] - boxes["ymax"] += boxes["window_ymin"] + predictions = predictions.copy() + is_box = {"xmin", "ymin", "xmax", "ymax"}.issubset(predictions.columns) + + predictions["geometry"] = [ + affinity.translate(geom, xoff=dx, yoff=dy) + for geom, dx, dy in zip( + predictions.geometry, + predictions.window_xmin, + predictions.window_ymin, + strict=True, + ) + ] - # Cast to int - boxes["xmin"] = boxes["xmin"].astype(int) - boxes["ymin"] = boxes["ymin"].astype(int) - boxes["xmax"] = boxes["xmax"].astype(int) - boxes["ymax"] = boxes["ymax"].astype(int) + if is_box: + bounds = shapely.bounds(np.array(predictions["geometry"])) + predictions[["xmin", "ymin", "xmax", "ymax"]] = bounds.astype(int) + else: + coords = shapely.get_coordinates(np.array(predictions["geometry"])) + predictions["x"] = coords[:, 0] + predictions["y"] = coords[:, 1] - return boxes + return predictions.drop(columns=["window_xmin", "window_ymin"]).reset_index(drop=True) -def apply_nms(boxes, scores, labels, iou_threshold): - """Apply non-maximum suppression to boxes. +def reduce_boxes(predictions: pd.DataFrame, iou_threshold: float) -> pd.DataFrame: + """Reduce overlapping box predictions with torchvision NMS. Args: - boxes: tensor of shape (N, 4) containing box coordinates - scores: tensor of shape (N,) containing confidence scores - labels: array of shape (N,) containing labels - iou_threshold: IoU threshold for NMS + predictions: DataFrame of image-space box predictions. + iou_threshold: IoU threshold for NMS. Returns: - DataFrame with filtered boxes + DataFrame containing the filtered box predictions in the public box schema. """ - bbox_left_idx = nms(boxes=boxes, scores=scores, iou_threshold=iou_threshold) - bbox_left_idx = bbox_left_idx.numpy() - - new_boxes = boxes[bbox_left_idx].type(torch.int) - new_labels = labels[bbox_left_idx] - new_scores = scores[bbox_left_idx] + box_output_columns = ["xmin", "ymin", "xmax", "ymax", "label", "score"] + if predictions.shape[0] <= 1: + return predictions[box_output_columns].reset_index(drop=True).copy() - # Recreate box dataframe - image_detections = np.concatenate( - [ - new_boxes, - np.expand_dims(new_labels, axis=1), - np.expand_dims(new_scores, axis=1), - ], - axis=1, + print( + f"{predictions.shape[0]} predictions in overlapping windows, applying non-max suppression" ) - return pd.DataFrame( - image_detections, columns=["xmin", "ymin", "xmax", "ymax", "label", "score"] + boxes = torch.tensor( + predictions[["xmin", "ymin", "xmax", "ymax"]].values, dtype=torch.float32 ) + scores = torch.tensor(predictions.score.values, dtype=torch.float32) + keep_idx = nms(boxes=boxes, scores=scores, iou_threshold=iou_threshold).numpy() + + filtered_predictions = predictions.iloc[keep_idx].reset_index(drop=True) + print(f"{filtered_predictions.shape[0]} predictions kept after non-max suppression") + return filtered_predictions[box_output_columns].reset_index(drop=True).copy() + +def reduce_points(predictions: pd.DataFrame, nms_thresh: float) -> pd.DataFrame: + """Reduce nearby point predictions with distance-based suppression. -def mosiac(predictions, iou_threshold=0.1): + Args: + predictions: DataFrame of image-space point predictions. + nms_thresh: Distance threshold in pixels used to suppress duplicates. + + Returns: + Filtered point predictions with all non-coordinate columns preserved. + """ + predictions = predictions.reset_index(drop=True) + if nms_thresh <= 0 or len(predictions) <= 1: + return predictions + + coords = predictions[["x", "y"]].values + scores = predictions["score"].values + tree = cKDTree(coords) + order = np.argsort(scores)[::-1] + kept = np.ones(len(coords), dtype=bool) + + for idx in order: + if not kept[idx]: + continue + + for neighbor_idx in tree.query_ball_point(coords[idx], r=nms_thresh): + if neighbor_idx != idx: + kept[neighbor_idx] = False + + return predictions.iloc[np.flatnonzero(kept)].reset_index(drop=True) + + +def mosaic( + predictions: pd.DataFrame, + iou_threshold: float = 0.1, + nms_distance_thresh: float = 5.0, +) -> pd.DataFrame: """Mosaic predictions from overlapping windows. Args: predictions: A pandas dataframe containing predictions from overlapping windows from a single image. - iou_threshold: The IoU threshold for non-max suppression. + iou_threshold: The IoU threshold for non-max suppression (box predictions). + nms_distance_thresh: Distance in pixels below which two points are duplicates (point predictions). Returns: A pandas dataframe of predictions. """ - predicted_boxes = transform_coordinates(predictions) - - # Skip NMS if there's is one or less prediction - if predicted_boxes.shape[0] <= 1: - return predicted_boxes + if predictions.empty: + return predictions.copy() - print( - f"{predicted_boxes.shape[0]} predictions in overlapping windows, applying non-max suppression" - ) + is_box_predictions = {"xmin", "ymin", "xmax", "ymax"}.issubset(predictions.columns) + is_point_predictions = {"x", "y"}.issubset(predictions.columns) + translated_predictions = translate_predictions(predictions) - # Convert to tensors - boxes = torch.tensor( - predicted_boxes[["xmin", "ymin", "xmax", "ymax"]].values, dtype=torch.float32 - ) - scores = torch.tensor(predicted_boxes.score.values, dtype=torch.float32) - labels = predicted_boxes.label.values + if is_box_predictions: + return reduce_boxes(translated_predictions, iou_threshold=iou_threshold) - # Apply NMS - filtered_boxes = apply_nms(boxes, scores, labels, iou_threshold) - print(f"{filtered_boxes.shape[0]} predictions kept after non-max suppression") + if is_point_predictions: + return reduce_points(translated_predictions, nms_thresh=nms_distance_thresh) - return filtered_boxes + raise ValueError("Predictions must include either box or point coordinates.") def across_class_nms(predicted_boxes, iou_threshold=0.15): """Perform non-max suppression for a dataframe of results (see visualize.format_boxes) to remove boxes that overlap by iou_thresholdold of IoU.""" - # Skip NMS if there's is one or less prediction if predicted_boxes.shape[0] <= 1: return predicted_boxes diff --git a/src/deepforest/utilities.py b/src/deepforest/utilities.py index f467bbc6c..46043710b 100644 --- a/src/deepforest/utilities.py +++ b/src/deepforest/utilities.py @@ -7,9 +7,11 @@ import pandas as pd import rasterio import shapely +import torch import xmltodict from omegaconf import DictConfig, OmegaConf from PIL import Image +from skimage.feature import peak_local_max from tqdm import tqdm from deepforest import _ROOT @@ -39,7 +41,6 @@ def load_config( Returns: config (DictConfig): composed configuration """ - if not config_name.endswith(".yaml"): config_name += ".yaml" @@ -357,6 +358,12 @@ def determine_geometry_type(df): geometry_type = "polygon" elif "points" in df.keys(): geometry_type = "point" + else: + raise ValueError( + f"Could not determine geometry type from dict keys {list(df.keys())}" + ) + else: + raise ValueError(f"Could not determine geometry type from type {type(df)}") return geometry_type @@ -370,20 +377,16 @@ def format_geometry(predictions, scores=True, geom_type=None): df: a pandas dataframe None if the dataframe is empty """ - # Detect geometry type if geom_type is None: geom_type = determine_geometry_type(predictions) if geom_type == "box": df = format_boxes(predictions, scores=scores) - if df is None: - return None - elif geom_type == "polygon": raise ValueError("Polygon predictions are not yet supported for formatting") elif geom_type == "point": - raise ValueError("Point predictions are not yet supported for formatting") + df = format_points(predictions, scores=scores) return df @@ -416,6 +419,34 @@ def format_boxes(prediction, scores=True): return df +def format_points(prediction: dict, scores: bool = True) -> pd.DataFrame | None: + """Convert a single density_to_points dict to a DataFrame. + + Args: + prediction: dict with ``"points"`` (N, 2), ``"scores"`` (N,), + and ``"labels"`` (N,) tensors. + scores: Whether to include the score column. + + Returns: + DataFrame with columns ``x``, ``y``, ``label`` (and ``score`` if + requested), or ``None`` when there are no detections. + """ + if len(prediction["points"]) == 0: + return None + + df = pd.DataFrame( + prediction["points"].cpu().detach().numpy(), + columns=["x", "y"], + ) + df["label"] = prediction["labels"].cpu().detach().numpy() + + if scores: + df["score"] = prediction["scores"].cpu().detach().numpy() + + df["geometry"] = df.apply(lambda x: shapely.geometry.Point(x.x, x.y), axis=1) + return df + + def read_coco(json_file): """Read a COCO format JSON file and return a pandas dataframe. @@ -747,7 +778,6 @@ def geo_to_image_coordinates(gdf, image_bounds, image_resolution): Returns: gdf: a geopandas dataframe with the transformed to image origin. CRS is removed """ - if len(image_bounds) != 4: raise ValueError("image_bounds must be a tuple of (left, bottom, right, top)") @@ -766,7 +796,6 @@ def geo_to_image_coordinates(gdf, image_bounds, image_resolution): def round_with_floats(x): """Check if string x is float or int, return int, rounded if needed.""" - try: result = int(x) except BaseException: @@ -905,3 +934,65 @@ def image_to_geo_coordinates(gdf, root_dir=None, flip_y_axis=False): def collate_fn(batch): batch = list(filter(lambda x: x is not None, batch)) return tuple(zip(*batch, strict=False)) + + +def density_to_points( + density_map, + score_thresh: float = 0.1, + score_integration_radius: int = 5, +) -> list[dict]: + """Extract peak point predictions from a batch of density maps. + + Normalises the density map to [0, 1] per image then finds local maxima + using ``skimage.feature.peak_local_max`` with a relative threshold. + + Args: + density_map: ``(B, C, H, W)`` density-map tensor or a list of + per-image ``(1, C, H_i, W_i)`` tensors. + score_thresh: Relative peak threshold in [0, 1]. Peaks below + ``score_thresh * max(density_map)`` are discarded. + score_integration_radius: ``min_distance`` for ``peak_local_max`` — + suppresses duplicate detections within this many density-map pixels. + + Returns: + List of dicts with ``"points"`` (N, 2) in density-map pixel coords + (x, y order), ``"scores"`` (N,), and ``"labels"`` (N,) per image. + """ + if isinstance(density_map, list): + results = [] + for dm in density_map: + results.extend(density_to_points(dm, score_thresh, score_integration_radius)) + return results + + # Normalise to [0, 1] per image so threshold_rel is scale-invariant. + density_norm = density_map / density_map.amax(dim=(2, 3), keepdim=True).clamp( + min=1e-6 + ) + + results = [] + for b in range(density_map.shape[0]): + density_np = density_norm[b, 0].detach().cpu().float().numpy() + + coords = peak_local_max( + density_np, + min_distance=score_integration_radius, + threshold_rel=score_thresh, + ) # (N, 2) in (row, col) = (y, x) + + if len(coords) > 0: + yx = torch.from_numpy(coords).float() + xy = yx[:, [1, 0]] + scores = density_norm[b, 0][coords[:, 0], coords[:, 1]].detach() + else: + xy = torch.zeros((0, 2)) + scores = torch.zeros(0, device=density_map.device) + + n = xy.shape[0] + results.append( + { + "points": xy.to(density_map.device), + "scores": scores, + "labels": torch.zeros(n, dtype=torch.long, device=density_map.device), + } + ) + return results diff --git a/tests/test_datasets_prediction.py b/tests/test_datasets_prediction.py index f98ab3f72..a0f2519f3 100644 --- a/tests/test_datasets_prediction.py +++ b/tests/test_datasets_prediction.py @@ -6,6 +6,9 @@ import pytest from deepforest import get_data +from deepforest import predict +from shapely import geometry +import pandas as pd from deepforest.datasets.prediction import TiledRaster, SingleImage, MultiImage, FromCSVFile, PredictionDataset @@ -59,3 +62,57 @@ def test_FromCSVFile(): ds = FromCSVFile(csv_file=get_data("example.csv"), root_dir=os.path.dirname(get_data("example.csv"))) assert len(ds) == 1 + + +def test_translate_predictions_boxes(): + predictions = pd.DataFrame( + { + "xmin": [1, 5], + "ymin": [2, 6], + "xmax": [3, 9], + "ymax": [4, 10], + "label": [0, 1], + "score": [0.9, 0.8], + "window_xmin": [10, 100], + "window_ymin": [20, 200], + "geometry": [ + geometry.box(1, 2, 3, 4), + geometry.box(5, 6, 9, 10), + ], + } + ) + + translated = predict.translate_predictions(predictions) + + assert "window_xmin" not in translated.columns + assert "window_ymin" not in translated.columns + assert translated[["xmin", "ymin", "xmax", "ymax"]].values.tolist() == [ + [11, 22, 13, 24], + [105, 206, 109, 210], + ] + assert translated.geometry.iloc[0].bounds == (11.0, 22.0, 13.0, 24.0) + assert translated.geometry.iloc[1].bounds == (105.0, 206.0, 109.0, 210.0) + + +def test_translate_predictions_points(): + predictions = pd.DataFrame( + { + "x": [5.0, 10.0], + "y": [6.0, 12.0], + "label": [0, 0], + "score": [0.9, 0.8], + "patch_id": ["a", "b"], + "window_xmin": [100, 200], + "window_ymin": [300, 400], + "geometry": [geometry.Point(5.0, 6.0), geometry.Point(10.0, 12.0)], + } + ) + + translated = predict.translate_predictions(predictions) + + assert translated[["x", "y"]].values.tolist() == [[105.0, 306.0], [210.0, 412.0]] + assert translated["patch_id"].tolist() == ["a", "b"] + assert translated.geometry.iloc[0].x == pytest.approx(105.0) + assert translated.geometry.iloc[0].y == pytest.approx(306.0) + assert translated.geometry.iloc[1].x == pytest.approx(210.0) + assert translated.geometry.iloc[1].y == pytest.approx(412.0) diff --git a/tests/test_datasets_training_keypoint.py b/tests/test_datasets_training_point.py similarity index 77% rename from tests/test_datasets_training_keypoint.py rename to tests/test_datasets_training_point.py index cd16fdecb..eec79e5f1 100644 --- a/tests/test_datasets_training_keypoint.py +++ b/tests/test_datasets_training_point.py @@ -1,4 +1,4 @@ -"""Tests for KeypointDataset.""" +"""Tests for PointDataset.""" import os @@ -8,16 +8,16 @@ import torch from deepforest import get_data -from deepforest.datasets.training import KeypointDataset +from deepforest.datasets.training import PointDataset @pytest.fixture() -def keypoint_csv(): +def point_csv(): return get_data("2019_BLAN_3_751000_4330000_image_crop_keypoints.csv") @pytest.fixture() -def keypoint_root_dir(): +def point_root_dir(): return os.path.dirname( get_data("2019_BLAN_3_751000_4330000_image_crop_keypoints.csv") ) @@ -34,12 +34,12 @@ def box_root_dir(): return os.path.dirname(get_data("OSBS_029.png")) -def test_keypoint_dataset_centroid(keypoint_csv, keypoint_root_dir): +def test_point_dataset_centroid(point_csv, point_root_dir): """Basic construction, iteration, and output format.""" - ds = KeypointDataset( - csv_file=keypoint_csv, root_dir=keypoint_root_dir, label_dict={"Tree": 0}, output="centroid" + ds = PointDataset( + csv_file=point_csv, root_dir=point_root_dir, label_dict={"Tree": 0}, output="centroid" ) - raw = pd.read_csv(keypoint_csv) + raw = pd.read_csv(point_csv) assert len(ds) == len(raw.image_path.unique()) @@ -59,11 +59,11 @@ def test_keypoint_dataset_centroid(keypoint_csv, keypoint_root_dir): assert targets["labels"].shape == (raw.shape[0],) assert targets["labels"].dtype == torch.int64 -def test_keypoint_dataset_density(keypoint_csv, keypoint_root_dir): +def test_point_dataset_density(point_csv, point_root_dir): """Density output mode should return a class-first tensor.""" - ds = KeypointDataset( - csv_file=keypoint_csv, - root_dir=keypoint_root_dir, + ds = PointDataset( + csv_file=point_csv, + root_dir=point_root_dir, label_dict={"Tree": 0}, output="density", ) @@ -80,9 +80,9 @@ def test_keypoint_dataset_density(keypoint_csv, keypoint_root_dir): assert targets["labels"].max() > 0 -def test_keypoint_dataset_from_boxes(box_csv, box_root_dir): +def test_point_dataset_from_boxes(box_csv, box_root_dir): """When given bounding box geometry, annotations_for_path should extract centroids.""" - ds = KeypointDataset( + ds = PointDataset( csv_file=box_csv, root_dir=box_root_dir, label_dict={"Tree": 0} ) raw = pd.read_csv(box_csv) @@ -102,14 +102,14 @@ def test_keypoint_dataset_from_boxes(box_csv, box_root_dir): ) -def test_keypoint_dataset_hflip(keypoint_csv, keypoint_root_dir): +def test_point_dataset_hflip(point_csv, point_root_dir): """Test that augmentation works by performing a horizontal flip augmentation, checking it correctly flips x coordinates and leaves y unchanged.""" - ds_orig = KeypointDataset( - csv_file=keypoint_csv, root_dir=keypoint_root_dir, + ds_orig = PointDataset( + csv_file=point_csv, root_dir=point_root_dir, ) - ds_flip = KeypointDataset( - csv_file=keypoint_csv, root_dir=keypoint_root_dir, + ds_flip = PointDataset( + csv_file=point_csv, root_dir=point_root_dir, augmentations=[{"HorizontalFlip": {"p": 1.0}}], ) @@ -125,7 +125,7 @@ def test_keypoint_dataset_hflip(keypoint_csv, keypoint_root_dir): ) -def test_keypoint_dataset_validate_coordinates_oob(tmp_path, keypoint_root_dir): +def test_point_dataset_validate_coordinates_oob(tmp_path, point_root_dir): """Out-of-bounds points should raise ValueError.""" image_name = "2019_BLAN_3_751000_4330000_image_crop.jpg" @@ -141,12 +141,12 @@ def test_keypoint_dataset_validate_coordinates_oob(tmp_path, keypoint_root_dir): df.to_csv(csv_path, index=False) with pytest.raises(ValueError, match="exceeds image dimensions"): - KeypointDataset( - csv_file=csv_path, root_dir=keypoint_root_dir, label_dict={"Tree": 0} + PointDataset( + csv_file=csv_path, root_dir=point_root_dir, label_dict={"Tree": 0} ) -def test_keypoint_dataset_validate_coordinates_negative(tmp_path, keypoint_root_dir): +def test_point_dataset_validate_coordinates_negative(tmp_path, point_root_dir): """Negative coordinates should raise ValueError.""" image_name = "2019_BLAN_3_751000_4330000_image_crop.jpg" @@ -162,12 +162,12 @@ def test_keypoint_dataset_validate_coordinates_negative(tmp_path, keypoint_root_ df.to_csv(csv_path, index=False) with pytest.raises(ValueError, match="exceeds image dimensions"): - KeypointDataset( - csv_file=csv_path, root_dir=keypoint_root_dir, label_dict={"Tree": 0} + PointDataset( + csv_file=csv_path, root_dir=point_root_dir, label_dict={"Tree": 0} ) -def test_keypoint_dataset_empty_annotations(tmp_path, keypoint_root_dir): +def test_point_dataset_empty_annotations(tmp_path, point_root_dir): """Empty annotations (0,0) should produce empty targets.""" image_name = "2019_BLAN_3_751000_4330000_image_crop.jpg" @@ -182,19 +182,19 @@ def test_keypoint_dataset_empty_annotations(tmp_path, keypoint_root_dir): ) df.to_csv(csv_path, index=False) - ds = KeypointDataset( - csv_file=csv_path, root_dir=keypoint_root_dir, label_dict={"Tree": 0} + ds = PointDataset( + csv_file=csv_path, root_dir=point_root_dir, label_dict={"Tree": 0} ) image, targets, path = ds[0] assert targets["points"].shape == (0, 2) assert targets["labels"].shape == (0,) -def test_keypoint_dataset_filter_points(): +def test_point_dataset_filter_points(): """filter_points should remove out-of-bounds points.""" ds_csv = get_data("2019_BLAN_3_751000_4330000_image_crop_keypoints.csv") root_dir = os.path.dirname(ds_csv) - ds = KeypointDataset(csv_file=ds_csv, root_dir=root_dir, label_dict={"Tree": 0}) + ds = PointDataset(csv_file=ds_csv, root_dir=root_dir, label_dict={"Tree": 0}) points = torch.tensor([[10.0, 20.0], [-5.0, 30.0], [50.0, 60.0], [200.0, 300.0]]) labels = torch.tensor([0, 0, 0, 0]) @@ -209,11 +209,11 @@ def test_keypoint_dataset_filter_points(): ) -def test_keypoint_dataset_density_map(keypoint_csv, keypoint_root_dir): +def test_point_dataset_density_map(point_csv, point_root_dir): """Density map should place class-specific peaks at point locations.""" - ds = KeypointDataset( - csv_file=keypoint_csv, - root_dir=keypoint_root_dir, + ds = PointDataset( + csv_file=point_csv, + root_dir=point_root_dir, label_dict={"Tree": 0, "Shrub": 1}, density_sigma=2, output="density", @@ -229,11 +229,11 @@ def test_keypoint_dataset_density_map(keypoint_csv, keypoint_root_dir): assert torch.argmax(density[1]).item() == (15 * 100 + 70) -def test_keypoint_dataset_density_ignores_oob_points(keypoint_csv, keypoint_root_dir): +def test_point_dataset_density_ignores_oob_points(point_csv, point_root_dir): """Out-of-bounds points should not contribute to density map.""" - ds = KeypointDataset( - csv_file=keypoint_csv, - root_dir=keypoint_root_dir, + ds = PointDataset( + csv_file=point_csv, + root_dir=point_root_dir, label_dict={"Tree": 0}, density_sigma=2, output="density", @@ -254,11 +254,11 @@ def test_keypoint_dataset_density_ignores_oob_points(keypoint_csv, keypoint_root assert abs(peak_y - 16.0) <= 1 -def test_gaussian_density_count_normalization(keypoint_csv, keypoint_root_dir): +def test_gaussian_density_count_normalization(point_csv, point_root_dir): """Each class channel of the density map should sum to the number of points in that class.""" - ds = KeypointDataset( - csv_file=keypoint_csv, - root_dir=keypoint_root_dir, + ds = PointDataset( + csv_file=point_csv, + root_dir=point_root_dir, label_dict={"Tree": 0, "Shrub": 1}, output="density", ) diff --git a/tests/test_hf_models.py b/tests/test_hf_models.py index de693f634..5e10df73a 100644 --- a/tests/test_hf_models.py +++ b/tests/test_hf_models.py @@ -19,6 +19,9 @@ "weecology/cropmodel-deadtrees", ] +POINT_MODELS = [ + "weecology/deepforest-tree-point", +] @pytest.mark.parametrize("repo_id", CROP_MODELS) def test_load_crop_models(repo_id): @@ -37,3 +40,11 @@ def test_load_box_models(repo_id): assert df.model is not None # detection models should have label_dict on the underlying model assert getattr(df.model, "label_dict", None) is not None + +@pytest.mark.parametrize("repo_id", POINT_MODELS) +def test_load_point_models(repo_id): + df = main.deepforest(config="point") + df.load_model(model_name=repo_id) + assert df.model is not None + # detection models should have label_dict on the underlying model + assert getattr(df.model, "label_dict", None) is not None diff --git a/tests/test_point.py b/tests/test_point.py new file mode 100644 index 000000000..d9b651d2f --- /dev/null +++ b/tests/test_point.py @@ -0,0 +1,119 @@ +import os + +import pandas as pd +import pytest +import torch + +from deepforest import evaluate, get_data +from deepforest.main import deepforest +from deepforest.models.treeformer import TreeFormerModel + +@pytest.fixture() +def point_model(tmp_path_factory): + csv_file = get_data("2019_BLAN_3_751000_4330000_image_crop_keypoints.csv") + root_dir = os.path.dirname(csv_file) + m = deepforest(config="point", config_args={"train": + {"csv_file": csv_file, + "root_dir": root_dir, + "fast_dev_run": False}, + "validation": + {"csv_file": csv_file, + "root_dir": root_dir}, + "log_root": str(tmp_path_factory.mktemp("logs"))}) + return m + + +def test_point_prediction(point_model): + path = get_data("2019_BLAN_3_751000_4330000_image_crop.jpg") + prediction = point_model.predict_image(path=path) + + assert isinstance(prediction, pd.DataFrame) + assert not prediction.empty + assert set(prediction.columns) == {"x", "y", "label", "score", "image_path", "geometry"} + assert (prediction["score"] >= 0).all() and (prediction["score"] <= 1).all() + assert (prediction["label"] == "Tree").all() + + +def test_point_evaluation(point_model): + """Predict the sample image and check precision/recall against bundled ground truth.""" + path = get_data("2019_BLAN_3_751000_4330000_image_crop.jpg") + prediction = point_model.predict_image(path=path) + assert prediction is not None, "Model returned no predictions for the sample image" + + ground_df = pd.read_csv( + get_data("2019_BLAN_3_751000_4330000_image_crop_keypoints.csv") + ) + + results = evaluate.evaluate_geometry( + predictions=prediction, + ground_df=ground_df, + geometry_type="point", + distance_threshold=40, + ) + + assert results["point_recall"] >= 0.7, ( + f"point_recall {results['point_recall']:.2f} is below 0.7" + ) + assert results["point_precision"] >= 0.5, ( + f"point_precision {results['point_precision']:.2f} is below 0.5" + ) + +# Test train +def test_train_single(point_model): + point_model.create_trainer(limit_train_batches=1) + point_model.trainer.fit(point_model) + +def test_eval_single(point_model): + point_model.config.validation.val_accuracy_interval = 1 + point_model.create_trainer(limit_train_batches=1) + results = point_model.trainer.validate(point_model) + + assert len(results) == 1 + metrics = results[0] + assert "val_mae" in metrics + assert "point_precision" in metrics + assert "point_recall" in metrics + assert metrics["val_mae"] <= 5.0, f"Expected val_mae to be <= 5.0, got {metrics['val_mae']:.2f}" + assert 0.3 <= metrics["point_precision"] <= 1.0, f"Expected point_precision to be in [0.3, 1.0], got {metrics['point_precision']:.2f}" + assert 0.3 <= metrics["point_recall"] <= 1.0, f"Expected point_recall to be in [0.3, 1.0], got {metrics['point_recall']:.2f}" + +def test_treeformer_forward_pass_train(): + """Training forward pass returns a loss dict with a scalar, differentiable loss.""" + model = TreeFormerModel(backbone="pvt_v2_b0", num_classes=1) + model.train() + + images = torch.rand(2, 3, 128, 128) + targets = [ + {"points": torch.rand(5, 2) * 128, "labels": torch.zeros(5, dtype=torch.int64)}, + {"points": torch.rand(3, 2) * 128, "labels": torch.zeros(3, dtype=torch.int64)}, + ] + + output = model(images, targets) + + assert isinstance(output, dict) + assert "loss" in output + assert output["loss"].ndim == 0 + assert output["loss"].requires_grad + + +def test_treeformer_forward_pass_val(): + """Eval forward pass returns one prediction dict per image.""" + model = TreeFormerModel(backbone="pvt_v2_b0", num_classes=1) + model.eval() + + images = torch.rand(2, 3, 128, 128) + + with torch.no_grad(): + output = model(images) + + assert isinstance(output, list) + assert len(output) == 2 + + for pred in output: + assert "points" in pred + assert "scores" in pred + assert "labels" in pred + n = pred["points"].shape[0] + assert pred["points"].shape == (n, 2) + assert pred["scores"].shape == (n,) + assert pred["labels"].shape == (n,) diff --git a/tests/test_utilities.py b/tests/test_utilities.py index 0be8a665b..30c5a6c73 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -11,8 +11,7 @@ import torch from shapely import geometry -from deepforest import get_data -from deepforest import utilities +from deepforest import get_data, utilities @pytest.fixture() @@ -726,9 +725,15 @@ def test_format_geometry_point(): "scores": torch.tensor([0.9, 0.8]) } - # Format geometry should raise ValueError since point predictions are not supported - with pytest.raises(ValueError, match="Point predictions are not yet supported for formatting"): - utilities.format_geometry(prediction, geom_type="point") + result = utilities.format_geometry(prediction, geom_type="point") + + assert isinstance(result, pd.DataFrame) + assert list(result.columns) == ["x", "y", "label", "score", "geometry"] + assert len(result) == 2 + assert result.iloc[0]["x"] == 10 + assert result.iloc[0]["y"] == 20 + assert result.iloc[0]["score"] == pytest.approx(0.9) + assert isinstance(result.iloc[0]["geometry"], geometry.Point) def test_format_geometry_polygon():