diff --git a/docs/_static/metadata_prior_example.png b/docs/_static/metadata_prior_example.png new file mode 100644 index 000000000..f6eca4fd6 Binary files /dev/null and b/docs/_static/metadata_prior_example.png differ diff --git a/docs/user_guide/03_cropmodels.md b/docs/user_guide/03_cropmodels.md index be17c2f2a..e5024fd44 100644 --- a/docs/user_guide/03_cropmodels.md +++ b/docs/user_guide/03_cropmodels.md @@ -2,137 +2,214 @@ ## Classifying Objects After Object Detection -One of the most requested features since the early days of DeepForest was the ability to apply a follow-up model to predicted bounding boxes. For example, if you use the 'tree' or 'bird' backbone, you might want to classify each detection with your own model without retraining the upstream detector. +DeepForest provides a cascaded approach to object classification. The `CropModel` is a simple classification model that can be applied to each detected object from a deepforest detection model. +For example, a user might use one of the prebuilt detection models, such as the bird detector, but wish to classify each detection to a finer set of labels. -Beginning in version 1.4.0, the `CropModel` class can be used in conjunction with `predict_tile` and `predict_image` methods. The general workflow involves first applying the object detection model, extracting the prediction locations into images (which can optionally be saved to disk), and then applying a second model on each cropped image. +## Using existing models -New columns `cropmodel_label` and `cropmodel_score` will appear alongside the object detection model's label and score. +```python +from deepforest import main +from deepforest.model import CropModel -## Benefits +detector = main.deepforest() +detector.load_model("weecology/deepforest-tree") -Why would you want to apply a model directly to each crop? Why not train a multi-class object detection model? +genus_model = CropModel.load_model("weecology/cropmodel-tree-genus") +crop_results = detector.predict_tile(path=path, crop_model=crop_model) +``` -While that approach is certainly valid, there are a few key benefits to using CropModels, especially in common use cases: +A `CropModel` is a PyTorch Lightning object. -- **Flexible Labeling**: Object detection models require that all objects of a particular class be annotated within an image, which can be impossible for detailed category labels. For example, you might have bounding boxes for all 'trees' in an image, but only have species or health labels for a small portion of them based on ground surveys. Training a multi-class object detection model would mean training on only a portion of your available data. -- **Simpler and Extendable**: CropModels decouple detection and classification workflows, allowing separate handling of challenges like class imbalance and incomplete labels, without reducing the quality of the detections. Two-stage object detection models can be finicky with similar classes and often require expertise in managing learning rates. -- **New Data and Multi-sensor Learning**: In many applications, the data needed for detection and classification may differ. The CropModel concept provides an extendable piece that allows for advanced pipelines. +```python +import torch +from deepforest.model import CropModel -## Considerations +# Test forward pass +x = torch.rand(4, 3, 224, 224) +output = crop_model.forward(x) +assert output.shape == (4, 2) +``` -- **Efficiency**: Using a CropModel will be slower, as for each detection, the sensor data needs to be cropped and passed to the detector. This is less efficient than using a combined classification/detection system like multi-class detection models. While modern GPUs mitigate this to some extent, it is still something to be mindful of. -- **Lack of Spatial Awareness**: The model knows only about the pixels inside the crop and cannot use features outside the bounding box. This lack of spatial awareness can be a major limitation. It is possible, but untested, that multi-class detection models might perform better in such tasks. A box attention mechanism, like in [this paper](https://arxiv.org/abs/2111.13087), could be a better approach. +### Training -## Single Crop Model +```python +# Load data and write crops to disk -Consider a test file with tree boxes and an 'Alive/Dead' label that comes with all DeepForest installations: +df = read_file('path/to/annotations') -```python +# Initialize model +crop_model = CropModel(num_classes=2) -import pandas as pd -from deepforest import model -from deepforest import main as m -from deepforest.utilities import get_data +boxes = df[['xmin', 'ymin', 'xmax', 'ymax']].values.tolist() +labels=df.label.values -df = pd.read_csv(get_data("testfile_multi.csv")) -crop_model = model.CropModel(num_classes=2) -# Or set up the crop model or load weights model.CropModel.load_from_checkpoint() +crop_model.write_crops(boxes,labels,images,savedir) -m.create_trainer() -result = m.predict_tile(path=path, crop_model=crop_model) +# Create trainer +crop_model.create_trainer( + max_epochs=10, + accelerator="gpu", + devices=1 +) + +# Load data +crop_model.load_from_disk( + train_dir="path/to/train", + val_dir="path/to/val" +) + +# Train +crop_model.trainer.fit(crop_model) + +# Validate +crop_model.trainer.validate(crop_model) + +# Save checkpoint +crop_model.trainer.save_checkpoint("model.ckpt") ``` +### Evaluation + +The model provides several evaluation metrics: + ```python -result.head() -# Output: -# xmin ymin xmax ... image_path cropmodel_label cropmodel_score -# 0 273.0 230.0 313.0 ... SOAP_061.png 1 0.519510 -# 1 47.0 82.0 81.0 ... SOAP_061.png 1 0.506423 -# 2 0.0 72.0 34.0 ... SOAP_061.png 1 0.505258 -# 3 341.0 40.0 374.0 ... SOAP_061.png 1 0.517231 -# 4 0.0 183.0 26.0 ... SOAP_061.png 1 0.513122 +# Get validation metrics +metrics = crop_model.trainer.validate(crop_model) + +# Get confusion matrix +labels, predictions = crop_model.val_dataset_confusion() ``` -## Multiple Crop Models +(spatial-temporal-metadata)= +## Spatial-Temporal Metadata -You can also pass multiple crop models to `predict_tile`. Each model's predictions and confidence scores will be stored in separate columns. +In biodiversity monitoring, species distributions vary by location and season. The CropModel supports an optional spatial-temporal metadata embedding that provides location and date context alongside image features to improve classification. The metadata signal is by default "gentle" — it contributes only a small portion of the feature vector. This means the model still classifies primarily from visual appearance but can use location/season as a soft prior. When metadata is not provided at inference time, the model gracefully reverts to image-only classification. + +### How It Works + +When `use_metadata=True`, the CropModel: + +1. Encodes `(lat, lon, day_of_year)` +2. Projects features through a small metadata embedding layer +3. Concatenates this with the image features +4. Classifies the image using the combined features. + +### Inference with Metadata + +Pass a `metadata` dict to `predict_tile`: ```python -crop_model1 = model.CropModel(num_classes=2) -crop_model2 = model.CropModel(num_classes=3) -result = m.predict_tile(path=path, crop_model=[crop_model1, crop_model2]) +from deepforest import main +from deepforest.model import CropModel + +m = main.deepforest() +m.create_trainer() + +crop_model = CropModel(config_args={"use_metadata": True}) +crop_model.load_from_disk(train_dir="path/to/train", val_dir="path/to/val", + metadata_csv="metadata.csv") +crop_model.create_trainer(max_epochs=10) +crop_model.trainer.fit(crop_model) + +result = m.predict_tile( + path="image.tif", + crop_model=crop_model, + metadata={"lat": 35.2, "lon": -120.4, "date": "2024-06-15"} +) ``` -```python -result.head() -# Output: -# xmin ymin xmax ymax label score image_path cropmodel_label_0 cropmodel_score_0 cropmodel_label_1 cropmodel_score_1 -# 0 273.0 230.0 313.0 275.0 Tree 0.882591 SOAP_061.png 0 0.650223 1 0.383726 -# 1 47.0 82.0 81.0 120.0 Tree 0.740889 SOAP_061.png 0 0.621586 1 0.376401 -# 2 0.0 72.0 34.0 116.0 Tree 0.735777 SOAP_061.png 0 0.614928 1 0.394649 -# 3 341.0 40.0 374.0 77.0 Tree 0.668367 SOAP_061.png 0 0.598883 1 0.386490 -# 4 0.0 183.0 26.0 235.0 Tree 0.664668 SOAP_061.png 0 0.538162 1 0.439823 +All detected crops in the tile share the same metadata. If `metadata` is omitted, the model falls back to image-only classification. + +### Training with Metadata + +Training requires a CSV file that maps each crop image filename to its spatial-temporal metadata: + +```text +filename,lat,lon,date +bird_001.png,35.2,-120.4,2024-06-15 +bird_002.png,35.2,-120.4,2024-06-15 +mammal_001.png,40.1,-105.3,2024-07-20 ``` -A `CropModel` is a PyTorch Lightning object and can also be used like any other model. +- `filename` matches the image basename +- `date` is an ISO format string +- One CSV covers both train and val sets (filenames are unique) -```python +Pass the CSV when loading data: -import torch +```python from deepforest.model import CropModel -# Test forward pass -x = torch.rand(4, 3, 224, 224) -output = crop_model.forward(x) -assert output.shape == (4, 2) +crop_model = CropModel(config_args={"use_metadata": True}) +crop_model.load_from_disk( + train_dir="path/to/train", + val_dir="path/to/val", + metadata_csv="metadata.csv" +) +crop_model.create_trainer(max_epochs=10) +crop_model.trainer.fit(crop_model) ``` -## Writing Crops to Disk +### Configuration -We can either classify crops in memory or save them to disk. +The metadata embedding is controlled by three config parameters: ```python +crop_model = CropModel(config_args={ + "use_metadata": True, # Enable metadata fusion (default: False) + "metadata_dim": 32, # Embedding dimension (default: 32) + "metadata_dropout": 0.5, # Dropout on metadata path (default: 0.5) +}) +``` -import os +Or in `config.yaml`: -boxes = df[['xmin', 'ymin', 'xmax', 'ymax']].values.tolist() -image_path = os.path.join(os.path.dirname(get_data("SOAP_061.png")), df["image_path"].iloc[0]) -crop_model.write_crops(boxes=boxes, labels=df.label.values, image_path=image_path, savedir=tmpdir) +```yaml +cropmodel: + use_metadata: True + metadata_dim: 32 + metadata_dropout: 0.5 ``` -This saves each crop in labeled folders (`Alive/Dead`). +### Visualizing Metadata Priors -## Training +After training a metadata-enabled CropModel, it can be useful to inspect the +spatial-temporal branch by itself. The +{download}`metadata prior visualization script ` +loads a checkpoint, evaluates a lat/lon grid for one or more dates, and writes: -You can train a new model using PyTorch Lightning: +- A CSV with metadata-only logits, probabilities, and relative scores +- PNG maps for selected species and dates +- GeoTIFF rasters for GIS workflows -```python - -from deepforest.model import CropModel +For example: -crop_model.create_trainer(fast_dev_run=True) -# Get the data stored from the write_crops step above. -crop_model.load_from_disk(train_dir=tmpdir, val_dir=tmpdir) -crop_model.trainer.fit(crop_model) -crop_model.trainer.validate(crop_model) +```bash +uv run python docs/user_guide/examples/visualize_metadata_priors.py \ + --checkpoint path/to/metadata_cropmodel.ckpt \ + --species "Morus bassanus" \ + --dates 2024-04-15 \ + --bounds -98 18 -55 48 \ + --cell-degrees 1.0 \ + --output-dir outputs/metadata_prior_maps ``` -### Sampler +The map below shows a relative metadata prior for Northern Gannet +(`Morus bassanus`) on April 15, 2024. It reflects the learned metadata branch, +not image evidence. -Many classification tasks have imbalanced data, meaning that one class appears many more times than others. This leads to the model often choosing this class, regardless of visual appearance. To reduce this effect, a weighted_sampler randomly chooses images to show in training weighted by their inverse frequency. This means that rarer crops are shown more often to offset the common classes. This leads to better performance for rare classes, but can reduce performance on common classes by a small amount. To active the sampler, set the config: cropmodel -> sampler -> 'weighted_random'. +```{image} ../_static/metadata_prior_example.png +:alt: Metadata prior map for Morus bassanus over the western Atlantic +:width: 650px +``` -# Customizing +## Advanced Usage -The `CropModel` makes very few assumptions about the architecture and simply provides a container to make predictions at each detection. To specify a custom CropModel, use the `model` argument. +## Balance classes during training -```python -from deepforest.model import CropModel -from torchvision.models import resnet101 -backbone = resnet101(weights='DEFAULT') -crop_model = CropModel(num_classes=2, model=backbone) -``` +Many classification tasks have imbalanced data, meaning that one class appears many more times than others. This leads to the model often choosing this class, regardless of visual appearance. To reduce this effect, a weighted_sampler randomly chooses images to show in training weighted by their inverse frequency. This means that rarer crops are shown more often to offset the common classes. This leads to better performance for rare classes, but can reduce performance on common classes by a small amount. To active the sampler, set the config: cropmodel -> sampler -> 'weighted_random'. -## Configuring Image Resize +## Image Resize The CropModel can be configured to resize input images to different dimensions. By default, images are resized to 224x224 pixels, but this can be customized through the config: @@ -156,8 +233,6 @@ cropmodel: resize_interpolation: nearest # or 'bilinear' (default) ``` -The `resize_interpolation` option controls how crops are scaled to the target size. The default is `bilinear`. Use `nearest` when training on small crops where bilinear smoothing would blur important details. This is particularly useful when using custom models that expect different input sizes or when working with high-resolution imagery where preserving more detail is important. - ## Custom Transforms One detail to keep in mind is that the preprocessing transform will differ for backbones. Make sure to check the final lines: @@ -209,7 +284,20 @@ class CustomCropModel(CropModel): model = CustomCropModel() ``` -## Making Predictions Outside of predict_tile +### Reloading a Dataset + +```python +from deepforest.model import CropModel + +crop_model = CropModel.load_from_checkpoint("/path/to/ckpt") +crop_model.load_from_disk( + train_dir="/dir/train/", + val_dir="/dir/val/" +) +crop_model.create_trainer() +``` + +## Making predictions outside of predict_tile While `predict_tile` provides a convenient way to run predictions on detected objects, you can also use the CropModel directly for classification tasks. This is useful when you have pre-cropped images or want to run classification independently. @@ -224,13 +312,6 @@ import numpy as np # Load a trained model from checkpoint cropmodel = CropModel.load_from_checkpoint("path/to/checkpoint.ckpt") -# The model will automatically load: -# - The model architecture and weights -# - The label dictionary mapping class names to indices -# - The number of classes -# - Any hyperparameters saved during training -``` - ### Making Predictions on a Dataset ```python @@ -261,157 +342,16 @@ label, score = cropmodel.postprocess_predictions(crop_results) label_names = [cropmodel.numeric_to_label_dict[x] for x in label] ``` -### Reloading a Dataset from Disk for Validation - -```python -from deepforest.model import CropModel - -crop_model = CropModel.load_from_checkpoint("/blue/ewhite/b.weinstein/BOEM/UBFAI Images with Detection Data/classification/checkpoints/d7e956055e23433a8892a8928a357385.ckpt") -crop_model.load_from_disk( - train_dir="/blue/ewhite/b.weinstein/BOEM/UBFAI Images with Detection Data/classification/crops/train/d7e956055e23433a8892a8928a357385", - val_dir="/blue/ewhite/b.weinstein/BOEM/UBFAI Images with Detection Data/classification/crops/val/d7e956055e23433a8892a8928a357385" -) -crop_model.create_trainer() -true_label, predicted_label = crop_model.val_dataset_confusion() -``` - -### Making Predictions on Single Images - -You can also make predictions on individual images or batches: - -```python -import torch -from PIL import Image - -# Load and preprocess a single image -image = Image.open("path/to/image.jpg") -transform = cropmodel.get_transform(augmentations=None) -tensor = transform(image).unsqueeze(0) # Add batch dimension - -# Make prediction -with torch.no_grad(): - output = cropmodel(tensor) - # Convert to numpy for postprocessing - output = output.cpu().numpy() - # Use the same postprocessing method - label, score = cropmodel.postprocess_predictions([output]) - class_name = cropmodel.numeric_to_label_dict[label[0]] - confidence = score[0] -``` - -## Model Architecture and Training - -The CropModel uses a ResNet-50 backbone by default, but can be customized with any PyTorch model. The model includes: - -- A classification head with the specified number of classes -- Standard image preprocessing (resize to 224x224, normalization) -- Data augmentation during training (random horizontal flips) -- Accuracy and precision metrics for evaluation - -### Training - -```python -# Initialize model -crop_model = CropModel(num_classes=2) - -# Create trainer -crop_model.create_trainer( - max_epochs=10, - accelerator="gpu", - devices=1 -) - -# Load data -crop_model.load_from_disk( - train_dir="path/to/train", - val_dir="path/to/val" -) - -# Train -crop_model.trainer.fit(crop_model) - -# Validate -crop_model.trainer.validate(crop_model) - -# Save checkpoint -crop_model.trainer.save_checkpoint("model.ckpt") -``` - -### Evaluation - -The model provides several evaluation metrics: - -```python -# Get validation metrics -metrics = crop_model.trainer.validate(crop_model) - -# Get confusion matrix -images, labels, predictions = crop_model.val_dataset_confusion(return_images=True) -``` - -### Confusion Matrix Visualization +## Multiple Crop Models -You can visualize the confusion matrix in several ways: +You can also pass multiple crop models to `predict_tile`. Each model's predictions and confidence scores will be stored in separate columns. ```python -import matplotlib.pyplot as plt -from torchmetrics.classification import MulticlassConfusionMatrix -import seaborn as sns - -# Method 1: Using torchmetrics -metric = MulticlassConfusionMatrix(num_classes=crop_model.num_classes) -metric.update(preds=predictions, target=labels) -fig, ax = metric.plot() -plt.title("Confusion Matrix") -plt.show() - -# Method 2: Using seaborn with val_dataset_confusion -images, labels, predictions = crop_model.val_dataset_confusion(return_images=True) -confusion_matrix = np.zeros((crop_model.num_classes, crop_model.num_classes)) -for true, pred in zip(labels, predictions): - confusion_matrix[true][pred] += 1 - -# Plot with seaborn -plt.figure(figsize=(10, 8)) -sns.heatmap(confusion_matrix, - annot=True, - fmt='g', - xticklabels=list(crop_model.label_dict.keys()), - yticklabels=list(crop_model.label_dict.keys())) -plt.title("Confusion Matrix") -plt.xlabel("Predicted") -plt.ylabel("True") -plt.show() - -# Get per-class metrics -from torchmetrics.classification import MulticlassPrecision, MulticlassRecall, MulticlassF1Score - -precision = MulticlassPrecision(num_classes=crop_model.num_classes) -recall = MulticlassRecall(num_classes=crop_model.num_classes) -f1 = MulticlassF1Score(num_classes=crop_model.num_classes) - -precision_score = precision(torch.tensor(predictions), torch.tensor(labels)) -recall_score = recall(torch.tensor(predictions), torch.tensor(labels)) -f1_score = f1(torch.tensor(predictions), torch.tensor(labels)) - -print(f"Precision: {precision_score:.3f}") -print(f"Recall: {recall_score:.3f}") -print(f"F1 Score: {f1_score:.3f}") +crop_model1 = model.CropModel(num_classes=2) +crop_model2 = model.CropModel(num_classes=3) +result = m.predict_tile(path=path, crop_model=[crop_model1, crop_model2]) ``` -This will give you a comprehensive view of your model's performance, including: -- A visual confusion matrix showing true vs predicted classes -- Per-class precision, recall, and F1 scores -- The ability to identify which classes are most commonly confused with each other - -The confusion matrix is particularly useful for: -- Identifying class imbalance issues -- Finding classes that are frequently confused -- Understanding the model's strengths and weaknesses -- Guiding decisions about data collection and model improvement - -## Advanced Usage - ### Custom Model Architecture You can use any PyTorch model as the backbone: @@ -444,6 +384,20 @@ class CustomCropModel(CropModel): return loss ``` +## Benefits + +Why would you want to apply a model directly to each crop? Why not train a multi-class object detection model? +While that approach is certainly valid, there are a few key benefits to using CropModels, especially in common use cases: + +- **Flexible Labeling**: Object detection models require that all objects of a particular class be annotated within an image, which can be impossible for detailed category labels. For example, you might have bounding boxes for all 'trees' in an image, but only have species or health labels for a small portion of them based on ground surveys. Training a multi-class object detection model would mean training on only a portion of your available data. +- **Simpler and Extendable**: CropModels decouple detection and classification workflows, allowing separate handling of challenges like class imbalance and incomplete labels, without reducing the quality of the detections. Two-stage object detection models can be finicky with similar classes and often require expertise in managing learning rates. +- **New Data and Multi-sensor Learning**: In many applications, the data needed for detection and classification may differ. The CropModel concept provides an extendable piece that allows for advanced pipelines. + +## Considerations + +- **Efficiency**: Using a CropModel will be slower, as for each detection, the sensor data needs to be cropped and passed to the detector. This is less efficient than using a combined classification/detection system like multi-class detection models. While modern GPUs mitigate this to some extent, it is still something to be mindful of. +- **Lack of Spatial Awareness**: The model knows only about the pixels inside the crop and cannot use features outside the bounding box. This lack of spatial awareness can be a major limitation. It is possible, but untested, that multi-class detection models might perform better in such tasks. A box attention mechanism, like in [this paper](https://arxiv.org/abs/2111.13087), could be a better approach. See the {ref}`spatial-temporal-metadata` section for an optional way to incorporate location and season information. + ## Pre-2.0 compatability Before DeepForest 2.0, the CropModel object did not save num_classes or label_dict hyperparameters, making it awkward to reload the checkpoint without these data. The model weights loaded, but the label_dict needed to be supplied independently. This is fixed in 2.0, if you recieve the warning. diff --git a/docs/user_guide/09_configuration_file.md b/docs/user_guide/09_configuration_file.md index bda0de7ba..c655b5f77 100644 --- a/docs/user_guide/09_configuration_file.md +++ b/docs/user_guide/09_configuration_file.md @@ -319,3 +319,15 @@ crop_model = CropModel() # Or use custom resize dimensions crop_model = CropModel(config_args={"resize": [300, 300]}) ``` + +### use_metadata + +Boolean flag to enable spatial-temporal metadata fusion. When `True`, the model accepts `(lat, lon, date)` alongside image crops and learns a small embedding that is concatenated with image features. Default is `False`. See {ref}`spatial-temporal-metadata` for usage details. + +### metadata_dim + +Dimension of the metadata embedding vector. A smaller value makes the metadata signal more gentle relative to the 2048-dim image features. Default is `32`. + +### metadata_dropout + +Dropout rate applied to the metadata embedding path. Higher values reduce the model's reliance on location/date information. Default is `0.5`. diff --git a/docs/user_guide/examples/visualize_metadata_priors.py b/docs/user_guide/examples/visualize_metadata_priors.py new file mode 100644 index 000000000..14a6d108d --- /dev/null +++ b/docs/user_guide/examples/visualize_metadata_priors.py @@ -0,0 +1,327 @@ +"""Map metadata-only class priors from a metadata-enabled CropModel checkpoint. + +This script visualizes what the spatial-temporal embedding branch contributes +to each class, independent of image content. It evaluates a coarse lat/lon grid +for one or more dates, then writes CSV score rasters and PNG maps. +""" + +from __future__ import annotations + +import argparse +import datetime as dt +from pathlib import Path + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import rasterio +import torch +from rasterio.transform import from_origin + +from deepforest.model import CropModel + +try: + import contextily as ctx +except ImportError: # pragma: no cover - contextily is an optional visual enhancement. + ctx = None + + +SPECIES_ALIASES = { + "Northern Gannet": "Morus bassanus", + "Common Eider": "Somateria mollissima", +} + +DEFAULT_SPECIES = ["Morus bassanus", "Somateria mollissima"] +DEFAULT_DATES = ["2024-01-15", "2024-04-15", "2024-07-15", "2024-10-15"] +DEFAULT_BOUNDS = (-98.0, 18.0, -55.0, 48.0) # Gulf of Mexico + western Atlantic + + +def day_of_year(date: str) -> float: + return float(dt.datetime.strptime(date, "%Y-%m-%d").timetuple().tm_yday) + + +def resolve_species(species: list[str]) -> list[str]: + return [SPECIES_ALIASES.get(name, name) for name in species] + + +def make_grid( + bounds: tuple[float, float, float, float], cell_degrees: float +) -> pd.DataFrame: + min_lon, min_lat, max_lon, max_lat = bounds + lons = np.arange(min_lon + cell_degrees / 2, max_lon, cell_degrees) + lats = np.arange(min_lat + cell_degrees / 2, max_lat, cell_degrees) + lon_grid, lat_grid = np.meshgrid(lons, lats) + return pd.DataFrame( + { + "lon": lon_grid.ravel(), + "lat": lat_grid.ravel(), + } + ) + + +def load_metadata_model(checkpoint: str, device: str) -> CropModel: + model = CropModel.load_from_checkpoint(checkpoint, map_location=device) + model.eval() + model.to(device) + if ( + getattr(model, "metadata_encoder", None) is None + or getattr(model, "classifier", None) is None + ): + raise ValueError( + "Checkpoint is not metadata-enabled. Expected CropModel.metadata_encoder " + "and CropModel.classifier." + ) + return model + + +def metadata_prior_scores( + model: CropModel, + grid: pd.DataFrame, + date: str, + device: str, +) -> pd.DataFrame: + """Compute metadata-only logits and probabilities for every grid cell/class.""" + metadata = torch.tensor( + np.column_stack( + [ + grid["lat"].to_numpy(), + grid["lon"].to_numpy(), + np.full(len(grid), day_of_year(date)), + ] + ), + dtype=torch.float32, + device=device, + ) + + with torch.no_grad(): + meta_features = model.metadata_encoder(metadata) + meta_dim = meta_features.shape[1] + classifier = model.classifier + meta_weights = classifier.weight[:, -meta_dim:] + logits = meta_features @ meta_weights.T + if classifier.bias is not None: + logits = logits + classifier.bias + probabilities = torch.softmax(logits, dim=1) + + labels = model.numeric_to_label_dict + rows = [] + logits_np = logits.cpu().numpy() + probs_np = probabilities.cpu().numpy() + for class_idx, label in labels.items(): + class_scores = pd.DataFrame( + { + "date": date, + "class_idx": class_idx, + "species": label, + "lat": grid["lat"].to_numpy(), + "lon": grid["lon"].to_numpy(), + "metadata_logit": logits_np[:, class_idx], + "metadata_probability": probs_np[:, class_idx], + } + ) + rows.append(class_scores) + return pd.concat(rows, ignore_index=True) + + +def select_species_scores(scores: pd.DataFrame, species: list[str]) -> pd.DataFrame: + available = set(scores["species"].unique()) + missing = [name for name in species if name not in available] + if missing: + examples = sorted(available)[:20] + raise ValueError( + f"Species not found in checkpoint label_dict: {missing}. " + f"First available labels: {examples}" + ) + + selected = scores[scores["species"].isin(species)].copy() + selected["relative_score"] = selected.groupby(["date", "species"])[ + "metadata_logit" + ].transform( + lambda x: (x - x.min()) / (x.max() - x.min()) if x.max() > x.min() else 0.0 + ) + return selected + + +def _safe_name(value: str) -> str: + return value.lower().replace(" ", "_").replace("/", "_") + + +def plot_species_map( + scores: pd.DataFrame, + species: str, + date: str, + bounds: tuple[float, float, float, float], + output_path: Path, + plot_column: str, + cell_degrees: float, + cmap: str, + use_basemap: bool, +) -> None: + subset = scores[(scores["species"] == species) & (scores["date"] == date)] + pivot = subset.pivot(index="lat", columns="lon", values=plot_column).sort_index() + min_lon, min_lat, max_lon, max_lat = bounds + + fig, ax = plt.subplots(figsize=(12, 8)) + ax.set_xlim(min_lon, max_lon) + ax.set_ylim(min_lat, max_lat) + ax.set_aspect("equal") + + if use_basemap and ctx is not None: + try: + ctx.add_basemap( + ax, + crs="EPSG:4326", + source=ctx.providers.Esri.OceanBasemap, + attribution_size=5, + zorder=0, + ) + except Exception as exc: + print(f"Could not add basemap tiles: {exc}") + + image = ax.imshow( + pivot.to_numpy(), + extent=[ + pivot.columns.min() - cell_degrees / 2, + pivot.columns.max() + cell_degrees / 2, + pivot.index.min() - cell_degrees / 2, + pivot.index.max() + cell_degrees / 2, + ], + origin="lower", + cmap=cmap, + alpha=0.75, + zorder=2, + vmin=0 if plot_column == "relative_score" else None, + vmax=1 if plot_column == "relative_score" else None, + ) + fig.colorbar(image, ax=ax, label=plot_column.replace("_", " ")) + ax.set_title(f"{species} metadata prior, {date}") + ax.set_xlabel("Longitude") + ax.set_ylabel("Latitude") + ax.grid(color="white", linewidth=0.3, alpha=0.4) + fig.savefig(output_path, dpi=250, bbox_inches="tight") + plt.close(fig) + + +def write_species_geotiff( + scores: pd.DataFrame, + species: str, + date: str, + output_path: Path, + plot_column: str, + cell_degrees: float, +) -> None: + subset = scores[(scores["species"] == species) & (scores["date"] == date)] + pivot = subset.pivot(index="lat", columns="lon", values=plot_column).sort_index() + array = np.flipud(pivot.to_numpy()).astype("float32") + transform = from_origin( + pivot.columns.min() - cell_degrees / 2, + pivot.index.max() + cell_degrees / 2, + cell_degrees, + cell_degrees, + ) + with rasterio.open( + output_path, + "w", + driver="GTiff", + height=array.shape[0], + width=array.shape[1], + count=1, + dtype="float32", + crs="EPSG:4326", + transform=transform, + ) as dst: + dst.write(array, 1) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Visualize metadata-only species priors from a CropModel checkpoint." + ) + parser.add_argument( + "--checkpoint", required=True, help="Metadata-enabled CropModel checkpoint." + ) + parser.add_argument( + "--species", + nargs="+", + default=DEFAULT_SPECIES, + help="Scientific names to map. Common aliases supported: Northern Gannet, Common Eider.", + ) + parser.add_argument( + "--dates", nargs="+", default=DEFAULT_DATES, help="YYYY-MM-DD dates to map." + ) + parser.add_argument( + "--bounds", + nargs=4, + type=float, + default=DEFAULT_BOUNDS, + metavar=("MIN_LON", "MIN_LAT", "MAX_LON", "MAX_LAT"), + ) + parser.add_argument( + "--cell-degrees", type=float, default=1.0, help="Grid cell size in degrees." + ) + parser.add_argument( + "--output-dir", type=Path, default=Path("outputs/metadata_prior_maps") + ) + parser.add_argument( + "--plot-column", + default="relative_score", + choices=["relative_score", "metadata_probability", "metadata_logit"], + help="Score column used for PNG coloring. CSV always contains all score columns.", + ) + parser.add_argument("--cmap", default="viridis") + parser.add_argument("--device", default="cpu") + parser.add_argument("--no-basemap", action="store_true") + return parser.parse_args() + + +def main() -> None: + args = parse_args() + args.output_dir.mkdir(parents=True, exist_ok=True) + + species = resolve_species(args.species) + grid = make_grid(tuple(args.bounds), args.cell_degrees) + model = load_metadata_model(args.checkpoint, args.device) + + all_scores = [] + for date in args.dates: + scores = metadata_prior_scores( + model=model, grid=grid, date=date, device=args.device + ) + selected = select_species_scores(scores, species) + all_scores.append(selected) + + for species_name in species: + output_stem = args.output_dir / f"{_safe_name(species_name)}_{date}" + plot_species_map( + scores=selected, + species=species_name, + date=date, + bounds=tuple(args.bounds), + output_path=output_stem.with_suffix(".png"), + plot_column=args.plot_column, + cell_degrees=args.cell_degrees, + cmap=args.cmap, + use_basemap=not args.no_basemap, + ) + write_species_geotiff( + scores=selected, + species=species_name, + date=date, + output_path=output_stem.with_suffix(".tif"), + plot_column=args.plot_column, + cell_degrees=args.cell_degrees, + ) + print(f"Wrote {output_stem.with_suffix('.png')}") + print(f"Wrote {output_stem.with_suffix('.tif')}") + + combined = pd.concat(all_scores, ignore_index=True) + csv_path = args.output_dir / "metadata_prior_scores.csv" + combined.to_csv(csv_path, index=False) + print(f"Wrote {csv_path}") + + +if __name__ == "__main__": + main() diff --git a/src/deepforest/conf/config.yaml b/src/deepforest/conf/config.yaml index 63ef1a137..4059a2c52 100644 --- a/src/deepforest/conf/config.yaml +++ b/src/deepforest/conf/config.yaml @@ -142,6 +142,11 @@ cropmodel: normalize: # Number of pixels to expand bbox crop windows for better prediction context. expand: 0 + # Spatial-temporal metadata fusion (optional). + # When True, the model accepts (lat, lon, date) alongside image crops. + use_metadata: False + metadata_dim: 32 + metadata_dropout: 0.5 point: score_integration_radius: 5 diff --git a/src/deepforest/conf/schema.py b/src/deepforest/conf/schema.py index e9bcd133d..b39a20548 100644 --- a/src/deepforest/conf/schema.py +++ b/src/deepforest/conf/schema.py @@ -127,6 +127,9 @@ class CropModelConfig: resize_interpolation: str = "bilinear" normalize: Any = None expand: int = 0 + use_metadata: bool = False + metadata_dim: int = 32 + metadata_dropout: float = 0.5 @dataclass diff --git a/src/deepforest/datasets/cropmodel.py b/src/deepforest/datasets/cropmodel.py index 971098e3b..e80e545fa 100644 --- a/src/deepforest/datasets/cropmodel.py +++ b/src/deepforest/datasets/cropmodel.py @@ -8,6 +8,7 @@ import numpy as np import rasterio as rio +import torch from rasterio.windows import Window from torch.utils.data import Dataset from torchvision import transforms @@ -82,6 +83,7 @@ def __init__( resize_interpolation: str = "bilinear", normalize=None, expand: int = 0, + metadata=None, ): self.df = df @@ -100,6 +102,10 @@ def __init__( raise ValueError("expand must be >= 0") self.expand = int(expand) + # Optional spatial-temporal metadata per crop. + # Dict mapping crop index to (lat, lon, day_of_year). + self.metadata = metadata + unique_image = self.df["image_path"].unique() assert len(unique_image) == 1, ( "There should be only one unique image for this class object" @@ -149,4 +155,9 @@ def __getitem__(self, idx): else: image = box + if self.metadata is not None: + lat, lon, doy = self.metadata[idx] + meta_tensor = torch.tensor([lat, lon, doy], dtype=torch.float32) + return image, meta_tensor + return image diff --git a/src/deepforest/datasets/training.py b/src/deepforest/datasets/training.py index 6226ebd14..3b001744c 100644 --- a/src/deepforest/datasets/training.py +++ b/src/deepforest/datasets/training.py @@ -1,5 +1,6 @@ """Dataset model for object detection tasks.""" +import datetime import math import os from abc import abstractmethod @@ -8,6 +9,7 @@ import cv2 import kornia.augmentation as K import numpy as np +import pandas as pd import shapely import torch import torchvision @@ -877,3 +879,64 @@ def _classes_in(root): ) return train_ds, val_ds + + +class MetadataImageFolder(Dataset): + """Wrapper that adds spatial-temporal metadata to an ImageFolder dataset. + + Expects a CSV file with columns: filename, lat, lon, date. + The date column should be an ISO format string (e.g., "2024-06-15") + + Args: + image_folder: A FixedClassImageFolder (or ImageFolder) dataset. + metadata_csv: Path to CSV with columns filename, lat, lon, date. + + Returns per sample: + (image, label, metadata_tensor) where metadata_tensor is shape (3,) + containing [lat, lon, day_of_year]. + """ + + def __init__(self, image_folder, metadata_csv): + self.image_folder = image_folder + metadata_df = pd.read_csv(metadata_csv) + self._meta_lookup = {} + for _, row in metadata_df.iterrows(): + date = datetime.datetime.strptime(str(row["date"]), "%Y-%m-%d") + doy = float(date.timetuple().tm_yday) + self._meta_lookup[row["filename"]] = ( + float(row["lat"]), + float(row["lon"]), + doy, + ) + + def __len__(self): + return len(self.image_folder) + + def __getitem__(self, idx): + image, label = self.image_folder[idx] + filepath = self.image_folder.samples[idx][0] + filename = os.path.basename(filepath) + + if filename in self._meta_lookup: + lat, lon, doy = self._meta_lookup[filename] + else: + lat, lon, doy = 0.0, 0.0, 1.0 + + metadata = torch.tensor([lat, lon, doy], dtype=torch.float32) + return image, label, metadata + + @property + def targets(self): + return self.image_folder.targets + + @property + def class_to_idx(self): + return self.image_folder.class_to_idx + + @property + def samples(self): + return self.image_folder.samples + + @property + def imgs(self): + return self.image_folder.imgs diff --git a/src/deepforest/main.py b/src/deepforest/main.py index c30a9737a..9b6d5195d 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -1,4 +1,5 @@ # entry point for deepforest model +import datetime import importlib import os import warnings @@ -577,6 +578,7 @@ def predict_tile( iou_threshold=0.15, dataloader_strategy="single", crop_model=None, + metadata=None, ): """For images too large to input into the model, predict_tile cuts the image into overlapping windows, predicts trees on each window and @@ -593,6 +595,10 @@ def predict_tile( - "batch" loads the entire image into GPU memory and creates views of an image as batch, requires in the entire tile to fit into GPU memory. CPU parallelization is possible for loading images. - "window" loads only the desired window of the image from the raster dataset. Most memory efficient option, but cannot parallelize across windows. crop_model: a deepforest.model.CropModel object to predict on crops + metadata: Optional dict with keys "lat", "lon", "date" for + spatial-temporal context. "date" should be an ISO format + string (e.g., "2024-06-15"). Used by CropModel when + use_metadata=True in config. Returns: pd.DataFrame or tuple: Predictions dataframe or (predictions, crops) tuple @@ -727,6 +733,20 @@ def predict_tile( root_dir = None if crop_model is not None: + # Build per-crop metadata from image-level metadata dict + if metadata is not None: + date_str = metadata.get("date", None) + if date_str is not None: + doy = float( + datetime.datetime.strptime(str(date_str), "%Y-%m-%d") + .timetuple() + .tm_yday + ) + else: + doy = 1.0 + lat = float(metadata.get("lat", 0.0)) + lon = float(metadata.get("lon", 0.0)) + cropmodel_results = [] for path in paths: image_result = mosaic_results[ @@ -735,8 +755,19 @@ def predict_tile( if image_result.empty: continue image_result.root_dir = os.path.dirname(path) + + # Create per-crop metadata dict if metadata was provided + per_crop_metadata = None + if metadata is not None: + per_crop_metadata = dict.fromkeys( + range(len(image_result)), (lat, lon, doy) + ) + cropmodel_result = predict._crop_models_wrapper_( - crop_model, self.trainer, image_result + crop_model, + self.trainer, + image_result, + metadata=per_crop_metadata, ) cropmodel_results.append(cropmodel_result) cropmodel_results = pd.concat(cropmodel_results) diff --git a/src/deepforest/model.py b/src/deepforest/model.py index 97ebaf91c..f91a88f88 100644 --- a/src/deepforest/model.py +++ b/src/deepforest/model.py @@ -1,5 +1,6 @@ # Model - common class import json +import math import os from collections.abc import Mapping @@ -16,7 +17,7 @@ from torchvision import models, transforms from deepforest import utilities -from deepforest.datasets.training import create_aligned_image_folders +from deepforest.datasets.training import MetadataImageFolder, create_aligned_image_folders class BaseModel: @@ -104,6 +105,23 @@ def create_crop_backbone( return m +def create_crop_feature_backbone( + architecture: str = "resnet50", + pretrained: bool = True, +) -> tuple[torch.nn.Module, int]: + """Create a crop backbone that returns feature vectors.""" + if architecture not in _CROP_BACKBONES: + raise ValueError( + f"Unknown CropModel architecture '{architecture}'. " + f"Choose from {sorted(_CROP_BACKBONES)}." + ) + factory, default_weights = _CROP_BACKBONES[architecture] + m = factory(weights=default_weights if pretrained else None) + feature_dim = m.fc.in_features + m.fc = torch.nn.Identity() + return m, feature_dim + + def simple_resnet_50(num_classes: int = 2) -> torch.nn.Module: """Create a simple ResNet-50 model for classification. @@ -119,6 +137,66 @@ def simple_resnet_50(num_classes: int = 2) -> torch.nn.Module: return create_crop_backbone("resnet50", num_classes=num_classes) +def resnet50_backbone(): + """Create a ResNet-50 backbone that outputs 2048-dim feature vectors. + + Returns: + tuple: (backbone, feature_dim) where backbone is the model and + feature_dim is the output dimension (2048). + """ + return create_crop_feature_backbone("resnet50") + + +class SpatialTemporalEncoder(torch.nn.Module): + """Encode (lat, lon, day_of_year) into a fixed-size embedding. + + Uses sinusoidal features for smooth, periodic representation of + geographic coordinates and seasonality, followed by a small MLP. + + Args: + embed_dim: Output embedding dimension. Default 32. + dropout: Dropout rate on the embedding. Default 0.5. + + Input: + metadata: tensor of shape (batch, 3) with [lat, lon, day_of_year]. + lat in [-90, 90], lon in [-180, 180], day_of_year in [1, 366]. + + Output: + tensor of shape (batch, embed_dim). + """ + + def __init__(self, embed_dim: int = 32, dropout: float = 0.5): + super().__init__() + self.mlp = torch.nn.Sequential( + torch.nn.Linear(6, embed_dim), + torch.nn.ReLU(), + torch.nn.Dropout(dropout), + ) + + def forward(self, metadata): + lat = metadata[:, 0:1] + lon = metadata[:, 1:2] + doy = metadata[:, 2:3] + + lat_norm = lat / 90.0 + lon_norm = lon / 180.0 + doy_norm = (doy - 1) / 365.0 + + features = torch.cat( + [ + torch.sin(math.pi * lat_norm), + torch.cos(math.pi * lat_norm), + torch.sin(math.pi * lon_norm), + torch.cos(math.pi * lon_norm), + torch.sin(2 * math.pi * doy_norm), + torch.cos(2 * math.pi * doy_norm), + ], + dim=1, + ) + + return self.mlp(features) + + class CropModel(LightningModule, PyTorchModelHubMixin): """A PyTorch Lightning module for classifying image crops from object detection models. @@ -150,6 +228,9 @@ def __init__( super().__init__() self.model = model + self.backbone = None + self.metadata_encoder = None + self.classifier = None # Set the argument as the self.config, this way when reloading the checkpoint, self.config exists and is not overwritten. self.config = config if self.config is None: @@ -210,21 +291,44 @@ def create_model(self, num_classes: int, architecture: str | None = None): } ) - self.model = create_crop_backbone( - architecture=architecture, - num_classes=num_classes, - ) + use_metadata = self.config["cropmodel"].get("use_metadata", False) + + if use_metadata: + metadata_dim = self.config["cropmodel"].get("metadata_dim", 32) + metadata_dropout = self.config["cropmodel"].get("metadata_dropout", 0.5) + + backbone, feature_dim = create_crop_feature_backbone( + architecture=architecture, + ) + self.backbone = backbone + self.metadata_encoder = SpatialTemporalEncoder( + embed_dim=metadata_dim, dropout=metadata_dropout + ) + self.classifier = torch.nn.Linear(feature_dim + metadata_dim, num_classes) + self.model = None + else: + self.backbone = None + self.metadata_encoder = None + self.classifier = None + self.model = create_crop_backbone( + architecture=architecture, + num_classes=num_classes, + ) def create_trainer(self, **kwargs): """Create a pytorch lightning trainer object.""" self.trainer = Trainer(**kwargs) - def load_from_disk(self, train_dir, val_dir): + def load_from_disk(self, train_dir, val_dir, metadata_csv=None): """Load the training and validation datasets from disk. Args: train_dir (str): The directory containing the training dataset. val_dir (str): The directory containing the validation dataset. + metadata_csv (str, optional): Path to a CSV file mapping image + filenames to spatial-temporal metadata. The CSV should have + columns: filename, lat, lon, date. Required when + use_metadata=True in config. Defaults to None. Returns: None @@ -235,6 +339,13 @@ def load_from_disk(self, train_dir, val_dir): transform_train=self.get_transform(augmentations=["HorizontalFlip"]), transform_val=self.get_transform(augmentations=None), ) + + if metadata_csv is not None and self.config["cropmodel"].get( + "use_metadata", False + ): + self.train_ds = MetadataImageFolder(self.train_ds, metadata_csv) + self.val_ds = MetadataImageFolder(self.val_ds, metadata_csv) + self.label_dict = self.train_ds.class_to_idx # Create a reverse mapping from numeric indices to class labels @@ -242,7 +353,7 @@ def load_from_disk(self, train_dir, val_dir): self.num_classes = len(self.label_dict) - if self.model is None: + if self.model is None and self.backbone is None: self.create_model(self.num_classes) def get_transform(self, augmentations): @@ -403,14 +514,24 @@ def normalize(self): mean=list(norm_cfg["mean"]), std=list(norm_cfg["std"]) ) - def forward(self, x): - if self.model is None: + def forward(self, x, metadata=None): + if self.backbone is not None: + image_features = self.backbone(x) + if metadata is not None: + meta_features = self.metadata_encoder(metadata) + else: + meta_dim = self.classifier.in_features - image_features.shape[1] + meta_features = torch.zeros( + x.shape[0], meta_dim, device=x.device, dtype=x.dtype + ) + combined = torch.cat([image_features, meta_features], dim=1) + return self.classifier(combined) + elif self.model is not None: + return self.model(x) + else: raise AttributeError( "CropModel is not initialized. Provide 'num_classes' or load from a checkpoint." ) - output = self.model(x) - - return output def train_dataloader(self): """Train data loader.""" @@ -465,20 +586,27 @@ def val_dataloader(self): return val_loader def training_step(self, batch, batch_idx): - x, y = batch - outputs = self.forward(x) + if len(batch) == 3: + x, y, metadata = batch + else: + x, y = batch + metadata = None + outputs = self.forward(x, metadata=metadata) loss = F.cross_entropy(outputs, y) self.log("train_loss", loss) return loss def predict_step(self, batch, batch_idx): - # Check if batch is a tuple for validation_dataloader - if isinstance(batch, list): - x, y = batch + # Inference: batch may be (images, metadata), (images, labels, metadata), or a single images tensor. + if isinstance(batch, (list, tuple)) and len(batch) == 3: + images, _labels, metadata = batch + elif isinstance(batch, (list, tuple)) and len(batch) == 2: + images, metadata = batch else: - x = batch - outputs = self.forward(x) + images = batch + metadata = None + outputs = self.forward(images, metadata=metadata) yhat = F.softmax(outputs, 1) return yhat @@ -492,8 +620,12 @@ def postprocess_predictions(self, predictions): return label, score def validation_step(self, batch, batch_idx): - x, y = batch - outputs = self(x) + if len(batch) == 3: + x, y, metadata = batch + else: + x, y = batch + metadata = None + outputs = self(x, metadata=metadata) loss = F.cross_entropy(outputs, y) self.log("val_loss", loss) diff --git a/src/deepforest/predict.py b/src/deepforest/predict.py index 7396a023f..8703864b8 100644 --- a/src/deepforest/predict.py +++ b/src/deepforest/predict.py @@ -293,6 +293,7 @@ def _predict_crop_model_( augmentations=None, model_index=0, is_single_model=False, + metadata=None, ): """Predicts crop model on a raster file. @@ -340,6 +341,7 @@ def _predict_crop_model_( resize_interpolation=resize_interpolation, normalize=normalize, expand=expand, + metadata=metadata, ) # Create dataloader @@ -375,7 +377,7 @@ def _predict_crop_model_( def _crop_models_wrapper_( - crop_models, trainer, results, transform=None, augmentations=None + crop_models, trainer, results, transform=None, augmentations=None, metadata=None ): if crop_models is not None and not isinstance(crop_models, list): crop_models = [crop_models] @@ -398,6 +400,7 @@ def _crop_models_wrapper_( transform=transform, augmentations=augmentations, is_single_model=is_single_model, + metadata=metadata, ) crop_results.append(crop_result) diff --git a/tests/test_metadata_cropmodel.py b/tests/test_metadata_cropmodel.py new file mode 100644 index 000000000..04e76f4f6 --- /dev/null +++ b/tests/test_metadata_cropmodel.py @@ -0,0 +1,76 @@ +"""Tests for spatial-temporal metadata embeddings in CropModel.""" + +import os + +import numpy as np +import pandas as pd +import pytest +import torch +from PIL import Image +from torchvision.datasets import ImageFolder + +from deepforest import get_data +from deepforest.datasets.cropmodel import BoundingBoxDataset +from deepforest.datasets.training import MetadataImageFolder +from deepforest.model import CropModel, SpatialTemporalEncoder + +def test_crop_model_metadata_forward(): + cm = CropModel(config_args={"use_metadata": True, "metadata_dim": 32}) + cm.create_model(num_classes=5) + x = torch.rand(4, 3, 224, 224) + meta = torch.tensor([[35.0, -120.0, 145.0]] * 4) + out = cm.forward(x, metadata=meta) + assert out.shape == (4, 5) + +def test_crop_model_metadata_none(): + """When use_metadata=True but metadata=None, model should still predict.""" + cm = CropModel(config_args={"use_metadata": True}) + cm.create_model(num_classes=5) + x = torch.rand(4, 3, 224, 224) + out = cm.forward(x, metadata=None) + assert out.shape == (4, 5) + + +def test_crop_model_no_metadata_backward_compat(): + cm = CropModel() + cm.create_model(num_classes=2) + x = torch.rand(4, 3, 224, 224) + out = cm.forward(x) + assert out.shape == (4, 2) + assert cm.backbone is None + assert cm.metadata_encoder is None + assert cm.classifier is None + + +def test_training_step_with_metadata(): + cm = CropModel(config_args={"use_metadata": True}) + cm.create_model(num_classes=3) + x = torch.rand(4, 3, 224, 224) + y = torch.tensor([0, 1, 2, 0]) + meta = torch.rand(4, 3) + batch = (x, y, meta) + loss = cm.training_step(batch, 0) + assert isinstance(loss, torch.Tensor) + assert loss.ndim == 0 + + +@pytest.fixture() +def bbox_df(): + df = pd.read_csv(get_data("testfile_multi.csv")) + single_image = df.image_path.unique()[0] + return df[df.image_path == single_image].reset_index(drop=True) + + +def test_bounding_box_dataset_with_metadata(bbox_df): + root_dir = os.path.dirname(get_data("SOAP_061.png")) + n = len(bbox_df) + metadata = dict.fromkeys(range(n), (35.0, -120.0, 145.0)) + ds = BoundingBoxDataset(bbox_df, root_dir=root_dir, metadata=metadata) + item = ds[0] + assert isinstance(item, tuple) + assert len(item) == 2 + assert item[0].shape[0] == 3 + assert item[1].shape == (3,) + assert item[1][0] == 35.0 + assert item[1][1] == -120.0 + assert item[1][2] == 145.0