The easiest solution is to wrap your model in `nn.DataParallel`, like so:

```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model(input_size, output_size)
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)
model.to(device)
```
Here is what the above code is doing:
1. We pick the GPU as our device if CUDA is available; otherwise we fall back to the CPU.
2. We create the model and, if more than one GPU is available, wrap it in nn.DataParallel.
3. We move the model (or its DataParallel wrapper) to the device.
Now, when we call model(input), DataParallel splits the input along the batch dimension, runs a replica of the model on each GPU in parallel, and gathers (concatenates) the per-GPU outputs back on the default device, so the output looks exactly as if a single model had processed the whole batch.
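To see this end to end, here is a minimal runnable sketch. The toy `Model` class and the concrete sizes (5 input features, 2 outputs, a batch of 30) are illustrative assumptions, not from the original snippet; on a single-GPU or CPU-only machine the DataParallel branch is simply skipped and the code still runs:

```python
import torch
import torch.nn as nn

# Toy model standing in for the Model(input_size, output_size) in the
# snippet above -- a single linear layer is enough to show the shapes.
class Model(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, x):
        return self.fc(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model(5, 2)
if torch.cuda.device_count() > 1:
    # Each forward pass will scatter the batch across all visible GPUs.
    model = nn.DataParallel(model)
model.to(device)

# A batch of 30 samples: with multiple GPUs, DataParallel splits this
# along dim 0, runs the replicas in parallel, and gathers the outputs
# back into one tensor -- one output row per input row, never summed.
inputs = torch.randn(30, 5).to(device)
outputs = model(inputs)
print(outputs.shape)
```

The key observation is the output shape: it is `(30, 2)` regardless of how many GPUs the batch was scattered across, confirming that the results are concatenated along the batch dimension rather than reduced.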