7. Custom Autograd
Custom autograd matters because not every useful operation exists as a built-in module. Focus on the discipline: define forward and backward clearly, then verify the gradients before trusting the layer.
7.1. A custom square function
[ ]:
import torch
from torch.autograd import Function

torch.manual_seed(37)

class Square(Function):
    @staticmethod
    def forward(ctx, input_tensor):
        # Save the input so backward can use it.
        ctx.save_for_backward(input_tensor)
        return input_tensor ** 2

    @staticmethod
    def backward(ctx, grad_output):
        (input_tensor,) = ctx.saved_tensors
        # d/dx x^2 = 2x, scaled by the incoming gradient (chain rule).
        return grad_output * 2 * input_tensor

x = torch.tensor([1.5, -2.0], requires_grad=True)
y = Square.apply(x)  # custom Functions are invoked via .apply, not ()
y.sum().backward()
print(y)
print(x.grad)
assert torch.allclose(x.grad, 2 * x.detach())
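In practice a custom Function is usually wrapped in an nn.Module so it composes with the rest of a model. Below is a minimal sketch of such a wrapper; the SquareLayer name is ours, not part of PyTorch, and it assumes the Square Function from the cell above is in scope.
[ ]:
from torch import nn

class SquareLayer(nn.Module):
    """Thin nn.Module wrapper so Square composes with Sequential, optimizers, etc."""
    def forward(self, x):
        # Delegate to the custom autograd Function defined above.
        return Square.apply(x)

layer = SquareLayer()
x = torch.tensor([3.0, -1.0], requires_grad=True)
layer(x).sum().backward()
print(x.grad)  # expect 2 * x, i.e. tensor([ 6., -2.])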
7.2. Gradient check
gradcheck numerically verifies the custom backward: it compares the analytical gradient against finite differences, which is why the input is created in double precision to keep the numerical error small.
[ ]:
# Double precision keeps the finite-difference approximation accurate.
input_tensor = torch.randn(4, dtype=torch.double, requires_grad=True)
result = torch.autograd.gradcheck(Square.apply, (input_tensor,), eps=1e-6, atol=1e-4)
print('gradcheck:', result)
assert result
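As a sanity check in the other direction, gradcheck should fail when backward is wrong. A minimal sketch, assuming a deliberately broken BadSquare Function (the name is ours): with raise_exception=False, gradcheck returns False instead of raising.
[ ]:
class BadSquare(Function):
    """Same forward as Square, but backward drops the factor of 2 on purpose."""
    @staticmethod
    def forward(ctx, input_tensor):
        ctx.save_for_backward(input_tensor)
        return input_tensor ** 2

    @staticmethod
    def backward(ctx, grad_output):
        (input_tensor,) = ctx.saved_tensors
        return grad_output * input_tensor  # missing the 2: gradient is wrong

bad_input = torch.randn(4, dtype=torch.double, requires_grad=True)
ok = torch.autograd.gradcheck(BadSquare.apply, (bad_input,), eps=1e-6, atol=1e-4,
                              raise_exception=False)
print('gradcheck on broken backward:', ok)  # False
assert not ok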