Simple Dropout using PyTorch

Link to Jupyter Notebook

Simple dropout to implement uncertainty estimates: Bayesian Neural Networks

Another notebook which uses PyTorch dropout: Link

In addition to predicting a value from a model, it is also important to know the confidence in that prediction. Dropout is one way of estimating this. After multiple rounds of prediction, the mean and standard deviation of the predictions can be viewed as the predicted value and the corresponding confidence in that prediction. It is important to note that this is different from the error of the prediction: the model may be wrong about the value and yet be very precise about it, much like the distinction between accuracy and precision.
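
Concretely, if the same input is pushed through the trained network T times, each time with a different dropout mask m_t, the reported prediction and its confidence can be written as (notation introduced here for illustration, with f(x; m_t) denoting the network output under mask m_t):

$$\hat{y}(x) = \frac{1}{T}\sum_{t=1}^{T} f(x;\, m_t), \qquad \hat{\sigma}(x) = \sqrt{\frac{1}{T-1}\sum_{t=1}^{T} \big(f(x;\, m_t) - \hat{y}(x)\big)^2}$$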

Types of uncertainty: aleatoric and epistemic

  • Aleatoric uncertainty captures noise inherent in the observations
  • Epistemic uncertainty accounts for uncertainty in the model itself (a standard way to relate the two is sketched after this list)
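
A sketch of that relation, not derived in this post, is the law of total variance: the overall predictive uncertainty splits into an aleatoric part, the noise an individual model expects in the data, and an epistemic part, the disagreement between plausible models. With μ_θ(x) and σ²_θ(x) denoting the mean and noise variance predicted by a model with parameters θ,

$$\operatorname{Var}[y \mid x] = \underbrace{\mathbb{E}_{\theta}\big[\sigma^{2}_{\theta}(x)\big]}_{\text{aleatoric}} + \underbrace{\operatorname{Var}_{\theta}\big[\mu_{\theta}(x)\big]}_{\text{epistemic}}$$

The dropout procedure below estimates only the epistemic term, since the network predicts a single value and not a noise level.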

The ideal way to measure epistemic uncertainty is to train many different models, each time using a different random seed and possibly varying hyperparameters, then run every model on each input and see how much the predictions vary. This is very expensive, since it involves repeating the whole training process many times. Fortunately, we can approximate the same effect more cheaply by using dropout, which effectively trains a huge ensemble of different models all at once. Each training sample is evaluated with a different dropout mask, corresponding to a different random subset of the connections in the full model. Usually dropout is only active during training and is switched off for prediction, which corresponds to averaging over the masks. But instead, let's use dropout for prediction too: compute the output for lots of different dropout masks and see how much the predictions vary. This turns out to give a reasonable estimate of the epistemic uncertainty in the outputs.
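
To see why the prediction loop later keeps the model in train() mode, here is a minimal sketch (not from the original notebook) of how nn.Dropout behaves in the two modes:

# Dropout in train() vs eval() mode
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
v = torch.ones(8)

drop.train()    # dropout active: random entries are zeroed, the rest scaled by 1/(1-p)
print(drop(v))  # one random mask
print(drop(v))  # a different random mask on every call

drop.eval()     # dropout disabled: identity mapping
print(drop(v))  # returns v unchanged, the same on every call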

# Training set
import torch

m = 100
x = (torch.rand(m) - 0.5) * 20  # torch.rand samples uniformly from [0, 1); rescaled here to [-10, 10)
y = x * torch.sin(x)            # ground-truth function y = x*sin(x)

Figure: the sampled training points and the ground-truth function y = x·sin(x)
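
As a quick check, the training points and the underlying function can be visualised with a short matplotlib sketch along these lines (my own plotting code, assuming the x and y tensors defined above):

# Visualise the training set and the ground-truth function
import matplotlib.pyplot as plt

xs = torch.linspace(-11, 11, 1000)  # dense grid for a smooth curve
plt.plot(xs.numpy(), (xs * torch.sin(xs)).numpy(), 'k', label='x*sin(x)')
plt.plot(x.numpy(), y.numpy(), 'o', label='training points')
plt.legend()
plt.show()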

We define a simple feed-forward NN to learn the function.

# Define a simple NN
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, hidden_layers=[20, 20], droprate=0.2, activation='relu'):
        super(MLP, self).__init__()
        
        self.model = nn.Sequential()
        self.model.add_module('input', nn.Linear(1, hidden_layers[0]))
        
        if activation == 'relu':
            self.model.add_module('relu0', nn.ReLU())
        
        elif activation == 'tanh':
            self.model.add_module('tanh0', nn.Tanh())
            
        for i in range(len(hidden_layers)-1):
            self.model.add_module('dropout'+str(i+1), nn.Dropout(p=droprate))
            self.model.add_module('hidden'+str(i+1), nn.Linear(hidden_layers[i], hidden_layers[i+1]))
            
            if activation == 'relu':
                self.model.add_module('relu'+str(i+1), nn.ReLU())
                
            elif activation == 'tanh':
                self.model.add_module('tanh'+str(i+1), nn.Tanh())
                
        self.model.add_module('dropout'+str(len(hidden_layers)), nn.Dropout(p=droprate))
        self.model.add_module('final', nn.Linear(hidden_layers[-1], 1))
        
    def forward(self, x):
        return self.model(x)
        
# Define the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # use the GPU if available
net = MLP(hidden_layers=[200, 100, 80], droprate=0.1).to(device)  # move the model to the chosen device
print(net)

# Objective and optimizer
import torch.optim as optim

criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=0.005, weight_decay=0.00001)

# Training loop
x_dev = x.view(-1, 1).to(device)  # move the data to the model's device once, outside the loop
y_dev = y.view(-1, 1).to(device)
for epoch in range(6000):
    y_hat = net(x_dev)
    loss = criterion(y_hat, y_dev)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch % 500 == 0:
        print('Epoch[{}] - Loss: {}'.format(epoch, loss.item()))

Once the network with dropout has been trained, it is called multiple times to predict the output for a given input. While doing so, it is important to keep the model in train() mode so that the dropout layers remain active; in eval() mode dropout is disabled and every pass would return the same value.

# Function to evaluate the mean and std dev via T stochastic forward passes
def predict_reg(model, X, T=1000):
    X = X.view(-1, 1).to(device)  # move the inputs to the model's device

    model.train()  # keep the dropout layers active
    Y_hat = list()
    with torch.no_grad():
        for t in range(T):
            Y_hat.append(model(X).squeeze())
    Y_hat = torch.stack(Y_hat)  # shape: (T, number of input points)

    model.eval()  # dropout disabled: one deterministic prediction for reference
    with torch.no_grad():
        Y_eval = model(X).squeeze()

    return Y_hat.cpu(), Y_eval.cpu()  # back to the CPU so .numpy() works for plotting

# Prediction on a new set of points
XX = torch.linspace(-11, 11, 1000)  # dense grid slightly wider than the training range
y_hat, y_eval = predict_reg(net, XX, T=1000)
mean_y_hat = y_hat.mean(dim=0)  # mean of the T stochastic predictions
std_y_hat = y_hat.std(dim=0)    # spread of the T predictions = confidence estimate

# Plotting: visualise the mean and mean ± std -> confidence range
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.plot(XX.numpy(), mean_y_hat.numpy(), 'C1', label='prediction')
ax.fill_between(XX.numpy(), (mean_y_hat + std_y_hat).numpy(), (mean_y_hat - std_y_hat).numpy(), color='C2', label='confidence')
ax.plot(x.numpy(), y.numpy(), 'oC0', label='ground truth')
ax.plot(XX.numpy(), (XX * torch.sin(XX)).numpy(), 'k', label='base function')
ax.axis('equal')
plt.legend()

Figure: the predicted mean, the ±1 std confidence band, the training points, and the base function x·sin(x)
