Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@

include src/deepforest/conf/config.yaml
include src/deepforest/conf/point.yaml
include src/deepforest/conf/point_pretrain.yaml
include src/deepforest/data/testfile_deepforest.csv
include src/deepforest/data/testfile_multi.csv
include src/deepforest/data/classes.csv
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ dependencies = [
"rasterio",
"safetensors<0.6.0",
"scikit-image>=0.25.2",
"scipy>=1.15.3",
"setuptools",
"shapely>2.0.0",
"slidingwindow",
Expand Down
63 changes: 63 additions & 0 deletions src/deepforest/conf/point_pretrain.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# NEON pretraining config for point model.

defaults:
- config
- _self_

architecture: treeformer

# Train from ImageNet PvT weights, not the released point checkpoint.
model:
name: null
revision: main

num_classes: 1
label_dict:
Tree: 0

batch_size: 32
workers: 8
accelerator: auto
precision: bf16-mixed
matmul_precision: high
score_thresh: 0.15
patch_size: 512

# Replace the MISSING csv_file/root_dir values below with your training and validation data, e.g. the full NEON pretraining split (all_annotations_train_fold_0.csv).
train:
csv_file: MISSING
root_dir: MISSING
lr: 2e-4
optimizer:
type: AdamW
weight_decay: 1e-5
epochs: 40
scheduler:
type: cosine
params:
T_max: 40
eta_min: 1e-6
augmentations:
- HorizontalFlip: {p: 0.5}
- VerticalFlip: {p: 0.5}
- Rotate: {degrees: 180, p: 0.5}

validation:
csv_file: MISSING
root_dir: MISSING
val_accuracy_interval: 1

point:
backbone: OpenGVLab/pvt_v2_b3
score_integration_radius: 2
distance_threshold: 10.0
density_sigma: 3.0
mae_weight: 0.025
ot_weight: 0.1
density_l1_weight: 0.05
count_cls_weight: 0.1
sinkhorn_reg: 1.0
num_of_iter_in_ot: 100
losses: [count, ot, density_l1]
norm_cood: true
enforce_count: false
23 changes: 22 additions & 1 deletion src/deepforest/conf/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,34 @@ class CropModelConfig:

@dataclass
class PointConfig:
"""Configuration for point models."""
"""Configuration for point models.

The loss fields configure training for density/point models such as
TreeFormer; defaults mirror the ``TreeFormerModel`` constructor.
``losses`` selects the active terms (``None`` enables all of
``count``, ``ot``, ``density_l1``, ``count_cls``). ``norm_cood``
normalises OT coordinates to [-1, 1] (global transport) and affects
only the OT loss, not inference; ``enforce_count`` rescales the
density map to a predicted count and does affect inference.
"""

backbone: str = "pvt_v2_b3"
score_integration_radius: int = 5
nms_distance_thresh: float = 5.0
distance_threshold: float = 10.0

# Training loss hyperparameters (used by TreeFormer / density models).
density_sigma: float = 5.0
mae_weight: float = 1.0
ot_weight: float = 0.1
density_l1_weight: float = 0.01
count_cls_weight: float = 1.0
sinkhorn_reg: float = 1.0
num_of_iter_in_ot: int = 100
losses: list[str] | None = None
norm_cood: bool = False
enforce_count: bool = True


@dataclass
class Config:
Expand Down
Empty file.
240 changes: 240 additions & 0 deletions src/deepforest/losses/ot_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
"""PyTorch implementation of Sinkhorn-Knopp for optimal transport.

Based on `ot.bregman.sinkhorn` from the Python Optimal Transport library,
https://pythonot.github.io, rewritten in PyTorch and adapted from the
DM-Count/TreeFormer repositories.

Some small changes have been made to the original code for stability and
logging.

Original implementation:
https://github.com/cvlab-stonybrook/DM-Count/blob/master/losses/ot_loss.py

Reference: M. Cuturi, "Sinkhorn Distances: Lightspeed Computation of Optimal
Transport", NeurIPS 2013.
"""

from typing import cast

import torch
from torch.nn import Module

M_EPS = 1e-16


def sinkhorn_knopp(
a: torch.Tensor,
b: torch.Tensor,
C: torch.Tensor,
reg: float = 1e-1,
maxIter: int = 1000,
stopThr: float = 1e-9,
log: bool = False,
eval_freq: int = 10,
warm_start: dict | None = None,
) -> tuple[torch.Tensor, dict] | torch.Tensor:
"""Solve entropic-regularized OT via Sinkhorn-Knopp matrix scaling.

Minimises ``<gamma, C>_F + reg * sum(gamma * log(gamma))``
subject to ``gamma @ 1 = a`` and ``gamma.T @ 1 = b``.

Args:
a: (na,) source measure (sums to 1).
b: (nb,) target measure (sums to 1).
C: (na, nb) cost matrix.
reg: Entropic regularization strength (> 0).
maxIter: Maximum number of Sinkhorn iterations.
stopThr: Early-stop threshold on marginal error.
log: If True, return ``(P, log_dict)``; otherwise return ``P``.
eval_freq: Check convergence every this many iterations.
warm_start: Optional dict with keys ``"u"`` and ``"v"`` to resume
from a previous solve.

Returns:
P (na, nb) optimal transport plan, and optionally a log dict
containing ``"u"``, ``"v"``, ``"alpha"``, ``"beta"`` (dual variables)
and ``"err"`` (list of marginal errors).
"""
device = a.device
na, nb = C.shape

assert na == a.shape[0] and nb == b.shape[0], "Shape of a or b doesn't match C"
assert reg > 0, "reg must be > 0"

log_dict: dict = {"err": []} if log else {}

if warm_start is not None:
u = warm_start["u"]
v = warm_start["v"]
else:
u = torch.ones(na, dtype=a.dtype, device=device) / na
v = torch.ones(nb, dtype=b.dtype, device=device) / nb

K = torch.exp(C / -reg)

it = 1
err = 1.0
while err > stopThr and it <= maxIter:
upre, vpre = u, v
v = b / (torch.mv(K.t(), u) + M_EPS)
u = a / (torch.mv(K, v) + M_EPS)

if torch.any(torch.isnan(u) | torch.isinf(u)) or torch.any(
torch.isnan(v) | torch.isinf(v)
):
u, v = upre, vpre
break

if log and it % eval_freq == 0:
b_hat = torch.mv(K.t(), u) * v
err = (b - b_hat).pow(2).sum().item()
log_dict["err"].append(err)

it += 1

if log:
log_dict["u"] = u
log_dict["v"] = v
log_dict["alpha"] = reg * torch.log(u + M_EPS)
log_dict["beta"] = reg * torch.log(v + M_EPS)
log_dict["its"] = it
log_dict["K_min"] = K.min().item()
# Final marginal error (recompute if not already stored from last eval).
b_hat = torch.mv(K.t(), u) * v
log_dict["sinkhorn_err"] = (b - b_hat).pow(2).sum().item()

P = u.unsqueeze(1) * K * v.unsqueeze(0)
return (P, log_dict) if log else P


class OT_Loss(Module):
    """Optimal-transport loss between a predicted density map and GT points.

    For every image with at least one point, a Sinkhorn solve transports the
    uniform measure over the ground-truth points onto the normalised
    predicted density; the resulting dual variable ``beta`` is used to build
    a surrogate loss whose gradient w.r.t. the un-normalised density equals
    the OT gradient.

    Args:
        norm_coord: If True, grid and point coordinates are mapped to
            [-1, 1] before computing squared distances.
        device: Device on which the coordinate grid and target measure are
            created.
        num_of_iter_in_ot: Maximum Sinkhorn iterations per image.
        reg: Entropic regularisation strength passed to the solver.
    """

    def __init__(self, norm_coord, device, num_of_iter_in_ot=100, reg=1.0):
        super().__init__()
        self.device = device
        self.norm_coord = norm_coord
        self.num_of_iter_in_ot = num_of_iter_in_ot
        self.reg = reg

    @torch.autocast("cuda", enabled=False)
    @torch.autocast("cpu", enabled=False)
    def forward(self, normed_density, unnormed_density, points):
        """Compute the OT loss and Sinkhorn diagnostics for one batch.

        Args:
            normed_density: 4-D tensor indexed as (batch, channel, H, W);
                channel 0 is used as the source measure.
                NOTE(review): presumably each map sums to 1 -- confirm
                against the caller.
            unnormed_density: Tensor indexed like ``normed_density``; the
                loss gradient flows through this tensor only.
            points: One tensor per image with x in column 0 and y in
                column 1. NOTE(review): coordinates appear to be in
                density-map pixel units -- confirm against the caller.

        Returns:
            Tuple ``(loss, wd, ot_obj_values, avg_its, avg_K_min,
            beta_abs_max, avg_sinkhorn_err)``. ``loss`` and
            ``ot_obj_values`` are averaged over images that have points;
            ``wd`` is the summed transport cost as a Python number;
            ``beta_abs_max`` is the largest ``|beta|`` seen in the batch;
            the remaining entries are per-image averages of solver
            diagnostics. All diagnostics are 0 when no image has points.
        """
        # Autocast is disabled and inputs are forced to float32 because
        # K = exp(C / -reg) inside Sinkhorn can underflow to zero in reduced
        # precision, which collapses the transport plan.
        normed_density = normed_density.float()
        unnormed_density = unnormed_density.float()
        points = [p.float() for p in points]

        batch_size = normed_density.size(0)
        assert len(points) == batch_size
        output_h = normed_density.size(2)
        output_w = normed_density.size(3)

        # Grid coordinates at cell centres, shape (1, W) and (1, H).
        x_cood = (
            torch.arange(output_w, dtype=torch.float32, device=self.device) + 0.5
        ).unsqueeze(0)
        y_cood = (
            torch.arange(output_h, dtype=torch.float32, device=self.device) + 0.5
        ).unsqueeze(0)

        # Optionally normalize coordinates to [-1, 1]. We generally
        # recommend this.
        if self.norm_coord:
            x_cood = x_cood / output_w * 2 - 1
            y_cood = y_cood / output_h * 2 - 1

        loss_terms = []
        ot_obj_values = torch.zeros([1]).to(self.device)
        wd = 0  # Summed transport cost (Wasserstein distance)
        n_active = 0  # Number of images that contain at least one point
        total_its = 0  # Accumulated Sinkhorn iteration counters
        total_K_min = 0.0  # Accumulated K_min for diagnostics
        beta_abs_max = 0.0  # Running max |beta| over the batch (divergence canary)
        total_sinkhorn_err = 0.0  # Accumulated final marginal error

        for idx, im_points in enumerate(points):
            if len(im_points) == 0:
                continue
            n_active += 1
            n_points = len(im_points)

            # Squared L2 distance from every GT point to every grid cell,
            # built separably per axis: [#gt, #cood] each.
            if self.norm_coord:
                x = im_points[:, 0].unsqueeze(1) / output_w * 2 - 1
                y = im_points[:, 1].unsqueeze(1) / output_h * 2 - 1
            else:
                x = im_points[:, 0].unsqueeze(1)
                y = im_points[:, 1].unsqueeze(1)

            x_dis = -2 * torch.matmul(x, x_cood) + x * x + x_cood * x_cood
            y_dis = -2 * torch.matmul(y, y_cood) + y * y + y_cood * y_cood
            y_dis.unsqueeze_(2)
            x_dis.unsqueeze_(1)
            dis = y_dis + x_dis
            dis = dis.view((dis.size(0), -1))  # [#gt, #cood * #cood]

            source_prob = normed_density[idx][0].view([-1]).detach()
            # Uniform measure over the GT points, created directly on device.
            target_prob = torch.ones([n_points], device=self.device) / n_points

            # Use Sinkhorn to solve OT and obtain the dual variable beta.
            P, info = sinkhorn_knopp(
                target_prob,
                source_prob,
                dis,
                self.reg,
                maxIter=self.num_of_iter_in_ot,
                log=True,
            )
            wd += torch.sum(dis * P).item()

            info = cast(dict[str, torch.Tensor], info)
            total_its += info["its"]
            total_K_min += info["K_min"]
            total_sinkhorn_err += info["sinkhorn_err"]
            # beta (dual variable = reg * log(v + eps)), [#cood * #cood]
            beta = info["beta"]
            beta_abs_max = max(beta_abs_max, beta.abs().max().item())
            ot_obj_values = ot_obj_values + torch.sum(
                normed_density[idx] * beta.view([1, output_h, output_w])
            )

            # Gradient of the OT loss w.r.t. the predicted (un-normalised)
            # density:
            # im_grad = beta / source_count - <beta, source_density> / source_count^2
            source_density = unnormed_density[idx][0].view([-1]).detach()
            source_count = source_density.sum()
            denom = source_count * source_count + 1e-8
            im_grad_1 = (source_count) / denom * beta  # [#cood * #cood]
            im_grad_2 = (source_density * beta).sum() / denom  # scalar
            im_grad = im_grad_1 - im_grad_2
            im_grad = im_grad.detach().view([1, output_h, output_w])
            # Define loss = <im_grad, predicted density>, so the gradient of
            # the loss w.r.t. the predicted density is exactly im_grad.
            loss_terms.append(torch.sum(unnormed_density[idx] * im_grad))

        if n_active > 0:
            loss = torch.stack(loss_terms).sum() / n_active
            ot_obj_values = ot_obj_values / n_active
        else:
            # All images in this batch have zero points. Keep loss connected
            # to unnormed_density so DDP gradient buckets fire on every rank;
            # the 0.0 multiplier means no actual gradient flows.
            loss = 0.0 * unnormed_density.sum()

        avg_its = total_its / n_active if n_active > 0 else 0
        avg_K_min = total_K_min / n_active if n_active > 0 else 0.0
        avg_sinkhorn_err = total_sinkhorn_err / n_active if n_active > 0 else 0.0
        # beta_abs_max is a maximum, not a sum: it is reported as-is rather
        # than divided by n_active, which would shrink the canary with batch
        # size.
        return (
            loss,
            wd,
            ot_obj_values,
            avg_its,
            avg_K_min,
            beta_abs_max,
            avg_sinkhorn_err,
        )
25 changes: 21 additions & 4 deletions src/deepforest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,8 +805,25 @@ def training_step(self, batch, batch_idx):
images, targets, image_names = batch
loss_dict = self.model.forward(images, targets)

# sum of regression and classification loss
losses = sum(loss_dict.values())
# Models that compute a combined total (e.g. TreeFormer) return it under
# "loss" alongside diagnostics. The diagnostics stay in the dict so they can
# be logged via Lightning, but only "loss" is used as the training loss.
losses = loss_dict["loss"] if "loss" in loss_dict else sum(loss_dict.values())

# Guard against non-finite losses
if not torch.isfinite(losses):
non_finite = {
k: v.item()
for k, v in loss_dict.items()
if torch.is_tensor(v) and not torch.isfinite(v)
}
warnings.warn(
f"Non-finite loss {losses.item()}; non-finite components: "
f"{non_finite}. Skipping batch.",
stacklevel=2,
)
trainable = [p for p in self.model.parameters() if p.requires_grad]
return 0.0 * sum(p.sum() for p in trainable) if trainable else None

# Log loss
for key, value in loss_dict.items():
Expand Down Expand Up @@ -839,8 +856,8 @@ def validation_step(self, batch, batch_idx):
with torch.no_grad():
loss_dict = self.model.forward(images, targets)

# sum of regression and classification loss
losses = sum(loss_dict.values())
# As in training_step, prefer the model's combined "loss" entry over summing every value in the dict.
losses = loss_dict["loss"] if "loss" in loss_dict else sum(loss_dict.values())

# Log losses
for key, value in loss_dict.items():
Expand Down
Loading
Loading