As far as I know, requires_grad can only be switched on or off for a whole tensor, not for individual elements of that tensor. What you can do instead is zero out the values outside the band.
First, create a mask for the band. You could use torch.ones together with torch.diagflat:
>>> import torch
>>> torch.diagflat(torch.ones(5), offset=1)
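Note that this returns a 6×6 matrix, not a 5×5 one: torch.diagflat makes the output large enough for the whole vector to fit on the offset diagonal. So for an N×N matrix with offset i, give torch.ones a length of N - |i|. This way every offset diagonal matrix has the same shape: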
>>> N = 5; i = -1
>>> torch.diagflat(torch.ones(N-abs(i)), offset=i)
tensor([[0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.]])
>>> N = 5; i = 0
>>> torch.diagflat(torch.ones(N-abs(i)), offset=i)
tensor([[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.]])
>>> N = 5; i = 1
>>> torch.diagflat(torch.ones(N-abs(i)), offset=i)
tensor([[0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0.]])
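To check the shape rule directly, here is a small sketch. The helper name `offset_diagonal` is hypothetical; it just wraps the `torch.diagflat(torch.ones(N-abs(i)), offset=i)` call used above.

```python
import torch

def offset_diagonal(n: int, i: int) -> torch.Tensor:
    """Hypothetical helper: n x n matrix with ones on diagonal i.

    torch.diagflat places a length-m vector on diagonal i of an
    (m + |i|) x (m + |i|) matrix, so m = n - |i| gives exactly n x n.
    """
    return torch.diagflat(torch.ones(n - abs(i)), offset=i)

# A full-length vector on offset 1 grows the matrix by one row and column.
print(torch.diagflat(torch.ones(5), offset=1).shape)  # torch.Size([6, 6])

# Shortening the vector keeps every offset at 5 x 5.
for i in range(-2, 3):
    print(i, tuple(offset_diagonal(5, i).shape))
```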
You get the point: summing these matrices element-wise gives us the mask. For a band of odd width b, the offsets run from -(b//2) to b//2 (the parentheses matter, since Python parses -b//2 as (-b)//2, which is -2 for b = 3):
>>> N = 5; b = 3
>>> mask = sum(torch.diagflat(torch.ones(N-abs(i)), i) for i in range(-(b//2), b//2+1))
>>> mask
tensor([[1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [0., 1., 1., 1., 0.],
        [0., 0., 1., 1., 1.],
        [0., 0., 0., 1., 1.]])
Then you can zero out the values outside the band in the weight of your nn.Linear:
>>> from torch import nn
>>> m = nn.Linear(N, N)
>>> m.weight.data = m.weight * mask
>>> m.weight
Parameter containing:
tensor([[-0.3321, -0.3377, -0.0000, -0.0000, -0.0000],
        [-0.4197,  0.1729,  0.2101,  0.0000,  0.0000],
        [ 0.0000,  0.2857, -0.3919, -0.0659,  0.0000],
        [ 0.0000, -0.0000,  0.0908,  0.0729, -0.1318],
        [ 0.0000, -0.0000, -0.0000, -0.0029, -0.1498]], requires_grad=True)
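For larger matrices, the same kind of band mask (ones on the diagonals -k..k, with k = b // 2 for an odd width b) can also be built without summing diagonals. A minimal sketch, where `band_mask` is a hypothetical helper name: an entry (r, c) is inside the band exactly when |r - c| <= k, which can be computed with broadcasting or with `Tensor.triu`/`Tensor.tril`.

```python
import torch

def band_mask(n: int, k: int) -> torch.Tensor:
    """Hypothetical helper: n x n float mask with ones on diagonals -k..k.

    Entry (r, c) lies inside the band exactly when |r - c| <= k.
    """
    idx = torch.arange(n)
    return ((idx[:, None] - idx[None, :]).abs() <= k).float()

n, b = 5, 3
k = b // 2  # diagonals on each side of the main one, for an odd band width b

# Summing offset diagonals with torch.diagflat, as above.
loop_mask = sum(torch.diagflat(torch.ones(n - abs(i)), offset=i) for i in range(-k, k + 1))

# Two loop-free constructions of the same mask.
print(torch.equal(band_mask(n, k), loop_mask))                     # True
print(torch.equal(torch.ones(n, n).triu(-k).tril(k), loop_mask))   # True
```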
Note that you need to reapply the mask during training, for example on each forward pass or after each optimizer step: the gradients of the weights outside the band are in general non-zero, so the optimizer will move those entries away from zero. The mask itself only needs to be built once and can be kept in memory.
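To see why the re-masking is needed, here is a minimal sketch with dummy data (the random inputs, the SGD optimizer and the learning rate are arbitrary assumptions for illustration). One ordinary training step moves the off-band entries away from zero, and multiplying by the mask in place under `torch.no_grad()` after the step restores the band.

```python
import torch
from torch import nn

torch.manual_seed(0)
N = 5
# Tridiagonal band mask (offsets -1..1), built once and reused.
mask = sum(torch.diagflat(torch.ones(N - abs(i)), offset=i) for i in range(-1, 2))

m = nn.Linear(N, N)
with torch.no_grad():
    m.weight.mul_(mask)  # zero out the entries outside the band

opt = torch.optim.SGD(m.parameters(), lr=0.1)
x, y = torch.randn(8, N), torch.randn(8, N)  # dummy data

# One ordinary training step.
loss = nn.functional.mse_loss(m(x), y)
opt.zero_grad()
loss.backward()
opt.step()

# The gradient outside the band is generally non-zero, so those entries moved.
drift = (m.weight.detach() * (1 - mask)).abs().sum().item()
print("off-band magnitude after step:", drift)

# Re-apply the mask after every optimizer step to restore the band structure.
with torch.no_grad():
    m.weight.mul_(mask)
print("off-band magnitude after re-masking:",
      (m.weight.detach() * (1 - mask)).abs().sum().item())
```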
It would be more convenient to wrap everything into a custom nn.Module.
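One possible shape for such a module, sketched under the assumption of a square layer with k diagonals on each side of the main one (the class name `BandedLinear` and the parameter `k` are hypothetical): instead of editing the weight after each step, apply the mask inside `forward()`. Because the masked product is what enters `F.linear`, the off-band weights receive zero gradient and never affect the output, so no separate re-masking step is needed.

```python
import torch
from torch import nn
import torch.nn.functional as F

class BandedLinear(nn.Linear):
    """Hypothetical square linear layer whose weight is restricted to a band.

    The mask is applied inside forward(), so entries outside the band
    never contribute to the output and receive zero gradient.
    """

    def __init__(self, n: int, k: int, bias: bool = True):
        super().__init__(n, n, bias=bias)
        idx = torch.arange(n)
        band = ((idx[:, None] - idx[None, :]).abs() <= k).to(self.weight.dtype)
        # A buffer follows .to(device) and is saved in state_dict, but is not trained.
        self.register_buffer("mask", band)
        with torch.no_grad():
            self.weight.mul_(self.mask)  # start from a clean banded weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight * self.mask, self.bias)

# Usage: gradients outside the band are exactly zero.
layer = BandedLinear(5, k=1)
out = layer(torch.randn(3, 5))
out.sum().backward()
print(layer.weight.grad * (1 - layer.mask))
```

Registering the mask as a buffer rather than a plain attribute means it moves to the GPU together with the module and is stored in the `state_dict`, while the optimizer never sees it.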