point model training support#1396
Conversation
8f7359f to
17bff61
Compare
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1396 +/- ##
==========================================
- Coverage 86.61% 85.75% -0.87%
==========================================
Files 26 28 +2
Lines 3736 4121 +385
==========================================
+ Hits 3236 3534 +298
- Misses 500 587 +87
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
|
This looks good enough for review. I’ll run some more extensive tests to see if there’s any other features that need to be pulled in, but the core losses are all here. |
17bff61 to
e538222
Compare
|
I had one tiny issue that isn't probably related to this PR, but just in case. If you let the trainer default to the csv logger it hits which basically means that the schema is set on trainer.start and it hits metrics it wasn't def on_train_start(self):
# Register the metric key so CSVLogger knows about it from step 1
self.log("epoch level metric", 0.0)I think the right thing to do is merge this PR and address this at the module level, its not specific to this PR. |
|
Which metric is this triggered for? I’ll have a look if I missed something when I pulled these changes from the other branch. |
|
All of the epoch level metrics, this is just the csv logger, the tensorboard and comet logger handle this gracefully. I think this is a different PR that we have an object that captures the metric names and logs them to 0 or 'nan' on train start, since this would be true for any metric in any of the workflows. |
e538222 to
cb8898b
Compare
cb8898b to
aaad58a
Compare
Description
This PR adds training functionality for the default
pointmodel. Most of the code has been distilled from my experimental branch. It should replicate the existing checkpoint, and I'm running a quick test to make sure.Changes are relatively small in scope - added the optimal transport loss code, fleshed out
compute_lossesfor the model and fixed a small bug in point visualization where the VertexAnnotator wouldn't accept a palette (we instead convert points to "detections" with a radius and plot them as circles).I've included a
point_pretrainconfig which was used to train the current default point model on NEON data, replication should be possible with:uv deepforest --config-name point_pretrain trainRelated Issue(s)
Closes #809 as we now support training, visualization of predictions, simple unit tests + should also include multi-class support in theory but we've not tested it.
AI-Assisted Development
Claude code for assistance with merging the changes from the other branch.