Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 37 additions & 13 deletions src/deepforest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,12 +525,15 @@ def predict_dataloader(self, ds, batch_size=None):
)
return loader

def predict_image(self, image: np.ndarray | None = None, path: str | None = None):
def predict_image(
self, image: np.ndarray | None = None, path: str | None = None, crop_model=None
):
"""Predict a single image with a deepforest model.

Args:
image: a float32 numpy array of a RGB with channels last format
path: optional path to read image from disk instead of passing image arg
crop_model: optional CropModel (or list) to classify detected crops; requires path

Returns:
result: A pandas dataframe of predictions (Default)
Expand All @@ -540,6 +543,10 @@ def predict_image(self, image: np.ndarray | None = None, path: str | None = None

if path:
image = np.array(Image.open(path).convert("RGB")).astype("float32")
elif image is None:
raise ValueError(
"Either image or path must be provided for single image prediction"
)

# sanity checks on input images
if not isinstance(image, np.ndarray):
Expand All @@ -559,19 +566,33 @@ def predict_image(self, image: np.ndarray | None = None, path: str | None = None
)
image = image.astype("float32")

result = predict._predict_image_(
model=self.model,
image=image,
ds = prediction.SingleImage(
path=path,
iou_threshold=self.config.nms_thresh,
nms_distance_thresh=self.config.point.nms_distance_thresh,
image=image,
patch_overlap=self.config.patch_overlap,
patch_size=max(image.shape[0], image.shape[1]),
return_metadata=True,
)
dataloader = self.predict_dataloader(ds, batch_size=self.config.batch_size)

results = predict._dataloader_wrapper_(
model=self,
trainer=self.trainer,
dataloader=dataloader,
crop_model=crop_model,
root_dir=os.path.dirname(path) if path else None,
)

# If there were no predictions, return None
if result is None:
if results.empty:
return None
else:
result["label"] = result.label.apply(lambda x: self.numeric_to_label_dict[x])

# Drop column offsets for single image
results = results.drop(columns=["window_xmin", "window_ymin"], errors="ignore")

results["label"] = results.label.apply(
lambda x: self.numeric_to_label_dict.get(x, x)
)

if path is None:
warnings.warn(
Expand All @@ -580,11 +601,11 @@ def predict_image(self, image: np.ndarray | None = None, path: str | None = None
"please assign results.root_dir = <directory name>",
stacklevel=2,
)
results = results.drop(columns=["image_path"], errors="ignore")
else:
root_dir = os.path.dirname(path)
result = utilities.read_file(result, root_dir=root_dir)
results = utilities.read_file(results, root_dir=os.path.dirname(path))

return result
return results

def predict_file(
self,
Expand Down Expand Up @@ -612,12 +633,15 @@ def predict_file(
dataloader = self.predict_dataloader(ds, batch_size=self.config.batch_size)
results = predict._dataloader_wrapper_(
model=self,
crop_model=crop_model,
trainer=self.trainer,
dataloader=dataloader,
crop_model=crop_model,
root_dir=root_dir,
)

if not results.empty:
results = utilities.read_file(results, root_dir)

results.root_dir = root_dir

return results
Expand Down
127 changes: 47 additions & 80 deletions src/deepforest/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,54 +9,8 @@
from shapely import affinity
from torchvision.ops import nms

from deepforest import distributed, utilities
from deepforest import distributed
from deepforest.datasets import cropmodel
from deepforest.utilities import read_file


def _predict_image_(
model,
image: np.ndarray | None = None,
path: str | None = None,
iou_threshold: float = 0.15,
nms_distance_thresh: float = 5.0,
):
"""Predict a single image with a deepforest model.

Args:
model: a deepforest.main.model object
image: a tensor of shape (channels, height, width)
path: optional path to read image from disk instead of passing image arg
iou_threshold: IoU threshold for box non-max suppression
nms_distance_thresh: Distance threshold in pixels for point NMS, see config.point.nms_distance_thresh
Returns:
df: A pandas dataframe of predictions (Default)
img: The input with predictions overlaid (Optional)
"""
image = torch.tensor(image).permute(2, 0, 1)
image = image / 255

with torch.no_grad():
prediction = model(image.unsqueeze(0))

prediction = prediction[0]
geom_type = utilities.determine_geometry_type(prediction)
df = utilities.format_geometry(prediction, geom_type=geom_type)

# return None for no predictions
if df is None:
return None

if geom_type == "box" and df.label.nunique() > 1:
df = across_class_nms(df, iou_threshold=iou_threshold)
elif geom_type == "point":
df = reduce_points(df, nms_thresh=nms_distance_thresh)

# Add image path if provided
if path is not None:
df["image_path"] = os.path.basename(path)

return df


def translate_predictions(predictions: pd.DataFrame) -> pd.DataFrame:
Expand Down Expand Up @@ -242,56 +196,69 @@ def _flatten_prediction_batches_(batched_results):
return pd.concat(flattened, ignore_index=True)


def _dataloader_wrapper_(model, trainer, dataloader, root_dir, crop_model):
def _apply_nms(image_results, config, task="box"):
"""Applies non-max suppression: across-class NMS for multi-class boxes,
distance-based suppression for points.

Args:
image_results: predictions for one image.
config: model configuration providing the NMS thresholds.
task: model task, "box" or "point".

Returns:
Reduced predictions for the image.
"""
if task == "box":
if image_results.label.nunique() > 1:
image_results = across_class_nms(
image_results, iou_threshold=config.nms_thresh
)
elif task == "point":
image_results = reduce_points(
image_results, nms_thresh=config.point.nms_distance_thresh
)
elif task == "polygon":
pass

return image_results


def _dataloader_wrapper_(model, trainer, dataloader, crop_model=None, root_dir=None):
"""Run inference over a dataloader and reduce predictions per image.

Returns a plain dataframe of image-space predictions with numeric labels.
Callers are responsible for label mapping and read_file/root_dir formatting.

Args:
model: deepforest.main object
trainer: a pytorch lightning trainer object
dataloader: pytorch dataloader object
root_dir: directory of images. If none, uses "image_dir" in config
nms_thresh: Non-max suppression threshold, see config.nms_thresh
crop_model: Optional. A list of crop models to be used for prediction.
crop_model: Optional CropModel (or list) to classify detected crops.
Requires root_dir.
root_dir: directory of images on disk
Returns:
results: pandas dataframe with bounding boxes, label and scores for each image in the csv file
results: pandas dataframe with predictions for each image in the dataloader
"""
batched_results = trainer.predict(model, dataloader)
results = distributed.gather_dataframe(_flatten_prediction_batches_(batched_results))

if results.empty:
return pd.DataFrame()

# Apply across class NMS for each image
# dropna=False keeps a null image_path when image_path is not available.
processed_results = []
for image_path in results.image_path.unique():
image_results = results[results.image_path == image_path].copy()
if image_results.label.nunique() > 1:
image_results = across_class_nms(
image_results, iou_threshold=model.config.nms_thresh
)
for _, group in results.groupby("image_path", dropna=False):
processed_results.append(_apply_nms(group, model.config, task=model.model.task))

if crop_model:
# Flag to check if only one model is passed
is_single_model = len(crop_model) == 1
results = pd.concat(processed_results, ignore_index=True)

for i, crop_model_item in enumerate(crop_model):
crop_model_results = _predict_crop_model_(
crop_model=crop_model_item,
results=image_results,
path=image_path,
trainer=trainer,
model_index=i,
is_single_model=is_single_model,
)

processed_results.append(crop_model_results)
else:
processed_results.append(image_results)

if processed_results:
results = pd.concat(processed_results, ignore_index=True)

results = read_file(results, root_dir)
if crop_model is not None:
if root_dir is None:
raise ValueError("crop_model requires a path/root_dir ")
results.root_dir = root_dir
results = _crop_models_wrapper_(
crop_models=crop_model, trainer=trainer, results=results
)

return results

Expand Down
Loading