7. Custom Autograd

Custom autograd functions matter because not every useful operation exists as a built-in module or differentiable op. The discipline is the point: define forward and backward explicitly, then verify the gradients numerically before trusting the layer.

7.1. A custom square function

import torch
from torch.autograd import Function

torch.manual_seed(37)

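class Square(Function):
    """Elementwise y = x ** 2 with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, input_tensor):
        # Stash the input; backward needs it to evaluate dy/dx = 2x.
        ctx.save_for_backward(input_tensor)
        return input_tensor ** 2

    @staticmethod
    def backward(ctx, grad_output):
        (input_tensor,) = ctx.saved_tensors
        # Chain rule: dL/dx = dL/dy * dy/dx = grad_output * 2x.
        return grad_output * 2 * input_tensor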

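x = torch.tensor([1.5, -2.0], requires_grad=True)
# Custom Functions are called through the static .apply method, not instantiated.
y = Square.apply(x)
y.sum().backward()  # dL/dy is all ones for a sum, so x.grad should equal 2x
print(y)
print(x.grad)
assert torch.allclose(x.grad, 2 * x.detach())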
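Square saves its input for backward. A common alternative, similar to the Exp example in the PyTorch `torch.autograd.Function` documentation, is to save the output when the derivative can be written in terms of it. The sketch below assumes that pattern and also wraps the Function in a small `nn.Module` so it can be dropped into a model like a built-in layer; the names `Exp` and `ExpLayer` are hypothetical.

```python
import torch
from torch.autograd import Function


class Exp(Function):
    """Hypothetical example: y = exp(x), whose derivative is y itself."""

    @staticmethod
    def forward(ctx, input_tensor):
        result = input_tensor.exp()
        # Save the output rather than the input: d/dx exp(x) = exp(x),
        # so backward can reuse it without recomputing the exponential.
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        # Chain rule: dL/dx = dL/dy * exp(x).
        return grad_output * result


class ExpLayer(torch.nn.Module):
    """Hypothetical wrapper so the custom Function behaves like any nn.Module."""

    def forward(self, x):
        return Exp.apply(x)


x = torch.tensor([0.0, 1.0], requires_grad=True)
y = ExpLayer()(x)
y.sum().backward()
print(x.grad)  # should equal exp(x)
assert torch.allclose(x.grad, x.detach().exp())
```

Saving the output is a memory and compute trade-off: it avoids a second `exp` in backward, but only works when the derivative is expressible through the forward result.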

7.2. Gradient check

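torch.autograd.gradcheck numerically verifies that the custom backward matches finite-difference estimates of the gradient. Use double-precision inputs: in float32, rounding error is too large relative to the small perturbation eps for the comparison to be reliable.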
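Roughly, the comparison gradcheck performs can be sketched by hand: perturb each input element by ±eps, estimate the Jacobian with central differences, and compare it against the analytical Jacobian. This is a simplified illustration, not gradcheck's actual implementation; `numerical_jacobian` is a hypothetical helper, and plain `t ** 2` stands in for `Square.apply` to keep the block self-contained.

```python
import torch

torch.manual_seed(0)


def square(t):
    return t ** 2


def numerical_jacobian(fn, t, eps=1e-6):
    """Hypothetical helper: central-difference Jacobian J[i, j] = d fn(t)[i] / d t[j]."""
    n_out = fn(t).numel()
    jac = torch.zeros(n_out, t.numel(), dtype=t.dtype)
    for j in range(t.numel()):
        step = torch.zeros_like(t)
        step[j] = eps
        # Central difference: (f(t + eps) - f(t - eps)) / (2 * eps).
        jac[:, j] = (fn(t + step) - fn(t - step)) / (2 * eps)
    return jac


t = torch.randn(4, dtype=torch.double)
numeric = numerical_jacobian(square, t)
# For an elementwise square the analytical Jacobian is diagonal with entries 2 * t,
# which matches what Square.backward returns when grad_output is all ones.
analytic = torch.diag(2 * t)
print("max abs error:", (numeric - analytic).abs().max().item())
assert torch.allclose(numeric, analytic, atol=1e-4)
```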

# Double precision keeps rounding error small relative to the eps=1e-6 perturbation.
input_tensor = torch.randn(4, dtype=torch.double, requires_grad=True)
result = torch.autograd.gradcheck(Square.apply, (input_tensor,), eps=1e-6, atol=1e-4)
print('gradcheck:', result)
assert result
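A gradient check is only convincing if it can fail. The sketch below assumes a deliberately broken variant, the hypothetical `BuggySquare`, whose backward drops the factor of 2, and passes `raise_exception=False` so gradcheck returns `False` instead of raising.

```python
import torch
from torch.autograd import Function, gradcheck


class BuggySquare(Function):
    """Hypothetical Square with a deliberate bug in backward."""

    @staticmethod
    def forward(ctx, input_tensor):
        ctx.save_for_backward(input_tensor)
        return input_tensor ** 2

    @staticmethod
    def backward(ctx, grad_output):
        (input_tensor,) = ctx.saved_tensors
        return grad_output * input_tensor  # bug: the correct gradient is 2 * x


torch.manual_seed(37)
input_tensor = torch.randn(4, dtype=torch.double, requires_grad=True)
# raise_exception=False makes gradcheck report the mismatch as a return value.
result = gradcheck(BuggySquare.apply, (input_tensor,), eps=1e-6, atol=1e-4,
                   raise_exception=False)
print("gradcheck on buggy backward:", result)
assert not result
```

The analytical gradient here is x while the finite-difference estimate is 2x, so any element of noticeable magnitude exceeds the tolerance and the check fails.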