diff --git a/src/deepforest/main.py b/src/deepforest/main.py index 4ff14bca5..2be2499e2 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -525,12 +525,15 @@ def predict_dataloader(self, ds, batch_size=None): ) return loader - def predict_image(self, image: np.ndarray | None = None, path: str | None = None): + def predict_image( + self, image: np.ndarray | None = None, path: str | None = None, crop_model=None + ): """Predict a single image with a deepforest model. Args: image: a float32 numpy array of a RGB with channels last format path: optional path to read image from disk instead of passing image arg + crop_model: optional CropModel (or list) to classify detected crops; requires path Returns: result: A pandas dataframe of predictions (Default) @@ -540,6 +543,10 @@ def predict_image(self, image: np.ndarray | None = None, path: str | None = None if path: image = np.array(Image.open(path).convert("RGB")).astype("float32") + elif image is None: + raise ValueError( + "Either image or path must be provided for single image prediction" + ) # sanity checks on input images if not isinstance(image, np.ndarray): @@ -559,19 +566,33 @@ def predict_image(self, image: np.ndarray | None = None, path: str | None = None ) image = image.astype("float32") - result = predict._predict_image_( - model=self.model, - image=image, + ds = prediction.SingleImage( path=path, - iou_threshold=self.config.nms_thresh, - nms_distance_thresh=self.config.point.nms_distance_thresh, + image=image, + patch_overlap=self.config.patch_overlap, + patch_size=max(image.shape[0], image.shape[1]), + return_metadata=True, + ) + dataloader = self.predict_dataloader(ds, batch_size=self.config.batch_size) + + results = predict._dataloader_wrapper_( + model=self, + trainer=self.trainer, + dataloader=dataloader, + crop_model=crop_model, + root_dir=os.path.dirname(path) if path else None, ) # If there were no predictions, return None - if result is None: + if results.empty: return None - else: - 
result["label"] = result.label.apply(lambda x: self.numeric_to_label_dict[x]) + + # Drop column offsets for single image + results = results.drop(columns=["window_xmin", "window_ymin"], errors="ignore") + + results["label"] = results.label.apply( + lambda x: self.numeric_to_label_dict.get(x, x) + ) if path is None: warnings.warn( @@ -580,11 +601,11 @@ def predict_image(self, image: np.ndarray | None = None, path: str | None = None "please assign results.root_dir = ", stacklevel=2, ) + results = results.drop(columns=["image_path"], errors="ignore") else: - root_dir = os.path.dirname(path) - result = utilities.read_file(result, root_dir=root_dir) + results = utilities.read_file(results, root_dir=os.path.dirname(path)) - return result + return results def predict_file( self, @@ -612,12 +633,15 @@ def predict_file( dataloader = self.predict_dataloader(ds, batch_size=self.config.batch_size) results = predict._dataloader_wrapper_( model=self, - crop_model=crop_model, trainer=self.trainer, dataloader=dataloader, + crop_model=crop_model, root_dir=root_dir, ) + if not results.empty: + results = utilities.read_file(results, root_dir) + results.root_dir = root_dir return results diff --git a/src/deepforest/predict.py b/src/deepforest/predict.py index a4422a25f..7dac7603a 100644 --- a/src/deepforest/predict.py +++ b/src/deepforest/predict.py @@ -9,54 +9,8 @@ from shapely import affinity from torchvision.ops import nms -from deepforest import distributed, utilities +from deepforest import distributed from deepforest.datasets import cropmodel -from deepforest.utilities import read_file - - -def _predict_image_( - model, - image: np.ndarray | None = None, - path: str | None = None, - iou_threshold: float = 0.15, - nms_distance_thresh: float = 5.0, -): - """Predict a single image with a deepforest model. 
def _apply_nms(image_results, config, task="box"):
    """Apply the task-appropriate non-max suppression to one image's predictions.

    Boxes receive across-class NMS, and only when more than one label is
    present. Points receive distance-based suppression. Any other task
    (including "polygon") is returned unchanged.

    Args:
        image_results: predictions for one image.
        config: model configuration providing the NMS thresholds
            (``nms_thresh`` for boxes, ``point.nms_distance_thresh`` for points).
        task: model task, "box", "point" or "polygon".

    Returns:
        Reduced predictions for the image.
    """
    if task == "box":
        # A single-class image has nothing to suppress across classes.
        if image_results.label.nunique() > 1:
            return across_class_nms(image_results, iou_threshold=config.nms_thresh)
        return image_results

    if task == "point":
        return reduce_points(
            image_results, nms_thresh=config.point.nms_distance_thresh
        )

    # "polygon" and any unrecognised task: no suppression is applied.
    return image_results


def _dataloader_wrapper_(model, trainer, dataloader, crop_model=None, root_dir=None):
    """Run inference over a dataloader and reduce predictions per image.

    Returns a plain dataframe of image-space predictions with numeric labels.
    Callers are responsible for label mapping and read_file/root_dir formatting.

    Args:
        model: deepforest.main object
        trainer: a pytorch lightning trainer object
        dataloader: pytorch dataloader object
        crop_model: Optional CropModel (or list) to classify detected crops.
            Requires root_dir.
        root_dir: directory of images on disk

    Returns:
        results: pandas dataframe with predictions for each image in the
            dataloader; an empty dataframe when nothing was predicted.

    Raises:
        ValueError: if crop_model is given without root_dir.
    """
    # Validate up front: an invalid argument combination should fail before
    # the (expensive) prediction pass, and regardless of whether any
    # detections happen to be produced.
    if crop_model is not None and root_dir is None:
        raise ValueError("crop_model requires a path/root_dir ")

    batched_results = trainer.predict(model, dataloader)
    results = distributed.gather_dataframe(_flatten_prediction_batches_(batched_results))

    if results.empty:
        return pd.DataFrame()

    task = model.model.task
    # dropna=False keeps a null image_path when image_path is not available.
    # sort=False keeps images in order of first appearance instead of
    # reordering them alphabetically by path.
    per_image = [
        _apply_nms(group, model.config, task=task)
        for _, group in results.groupby("image_path", sort=False, dropna=False)
    ]
    results = pd.concat(per_image, ignore_index=True)

    if crop_model is not None:
        # NOTE(review): root_dir is attached as a plain attribute, presumably
        # for _crop_models_wrapper_ to locate images on disk - confirm there.
        results.root_dir = root_dir
        results = _crop_models_wrapper_(
            crop_models=crop_model, trainer=trainer, results=results
        )

    return results