Skip to content
18 changes: 5 additions & 13 deletions python/oneflow/nn/utils/clip_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,31 +100,23 @@ def clip_grad_norm_(
param0_placement = parameters[0].placement
if norm_type == float("inf"):
norms = [
p.grad.detach()
.to_global(sbp=sbp_broadcast)
.abs()
.max()
.to_global(placement=param0_placement)
p.grad.detach().abs().max().to_global(placement=param0_placement)
for p in parameters
]
total_norm = norms[0] if len(norms) == 1 else flow.max(flow.stack(norms))
elif norm_type == float("-inf"):
norms = [
p.grad.detach()
.to_global(sbp=sbp_broadcast)
.abs()
.min()
.to_global(placement=param0_placement)
p.grad.detach().abs().min().to_global(placement=param0_placement)
for p in parameters
]
total_norm = norms[0] if len(norms) == 1 else flow.min(flow.stack(norms))
else:
total_norm = flow.linalg.vector_norm(
flow.stack(
[
flow.linalg.vector_norm(
p.grad.detach().to_global(sbp=sbp_broadcast), norm_type
).to_global(placement=param0_placement)
flow.linalg.vector_norm(p.grad.detach(), norm_type).to_global(
placement=param0_placement
)
for p in parameters
]
),
Expand Down