@chenyaofo
2019-02-24T11:51:48.000000Z
We observed a very strange slowdown when computing the minimum of a tensor on the GPU. When the input to `f` is a random tensor, the line `res = mask_sum.min()` takes only 42.1 µs per call (case A). When the input is instead the output of a Conv2d layer, the same line takes 3946.6 µs per call, nearly 100 times longer (case B), even though `mask_sum` has shape (128,) in both cases and the other lines of `f` take about the same time. We have no idea why the running time differs so much. The code and line-profiler results for both cases are shown below. The results were obtained with PyTorch 1.0.1.post2 (py3.6_cuda10.0.130_cudnn7.4.2_2) on a TITAN Xp (driver version 415.27) running Ubuntu 16.04 (kernel 4.4.0-131-generic).
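Case A code (`a.py`, run under line_profiler, which supplies the `@profile` decorator):

```python
import torch


@profile
def f(x):
    N, _ = x.shape
    therehold = torch.rand((N, 1), device=x.device)
    mask = x.ge(therehold)
    mask_sum = mask.sum(dim=1, keepdim=False)
    res = mask_sum.min()
    return res


device = torch.device("cuda:0")
for i in range(100):
    # Case A: the input to f is a plain random tensor of shape (128, 3136).
    x = torch.rand((128, 56 * 56), device=device)
    res = f(x)
```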
Line-profiler output for case A:

```
Timer unit: 1e-06 s

Total time: 0.012572 s
File: a.py
Function: f at line 4

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     4                                           @profile
     5                                           def f(x):
     6       100        218.0      2.2      1.7      N, _ = x.shape
     7       100       1273.0     12.7     10.1      therehold = torch.rand((N, 1), device=x.device)
     8       100       2279.0     22.8     18.1      mask = x.ge(therehold)
     9       100       4495.0     45.0     35.8      mask_sum = mask.sum(dim=1, keepdim=False)
    10       100       4212.0     42.1     33.5      res = mask_sum.min()
    11       100         95.0      0.9      0.8      return res
```
Case B code (`b.py`, also run under line_profiler):

```python
import torch
import torch.nn as nn


@profile
def f(x):
    N, _ = x.shape
    therehold = torch.rand((N, 1), device=x.device)
    mask = x.ge(therehold)
    mask_sum = mask.sum(dim=1, keepdim=False)
    res = mask_sum.min()
    return res


device = torch.device("cuda:0")
conv = nn.Conv2d(128, 1, 3, 3).to(device)  # 128 -> 1 channels, 3x3 kernel, stride 3
for i in range(100):
    x = torch.rand((128, 128, 56, 56), device=device)
    # Case B: the input to f is the Conv2d output, shape (128, 1, 18, 18).
    x = conv(x)
    x = x.view(x.shape[0], -1)  # flatten to (128, 324)
    res = f(x)
```
Line-profiler output for case B:

```
Timer unit: 1e-06 s

Total time: 0.405717 s
File: b.py
Function: f at line 5

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     5                                           @profile
     6                                           def f(x):
     7       100        245.0      2.5      0.1      N, _ = x.shape
     8       100       2303.0     23.0      0.6      therehold = torch.rand((N, 1), device=x.device)
     9       100       3341.0     33.4      0.8      mask = x.ge(therehold)
    10       100       5022.0     50.2      1.2      mask_sum = mask.sum(dim=1, keepdim=False)
    11       100     394663.0   3946.6     97.3      res = mask_sum.min()
    12       100        143.0      1.4      0.0      return res
```
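One possible way to check whether the 3946.6 µs in case B really belongs to `min()` is to time the convolution and `f` separately with explicit synchronization. CUDA operations in PyTorch are asynchronous by default, so a host-side timer such as line_profiler may charge a kernel's running time to whichever later line first waits for the GPU. The sketch below is a diagnostic under that assumption, not a confirmed explanation: `sync` and `time_ms` are hypothetical helpers, the loop count of 5 is arbitrary (it only serves to skip first-call warm-up), and the script falls back to the CPU when no GPU is available.

```python
import time

import torch
import torch.nn as nn


def f(x):
    # Same computation as f() in the question, without @profile.
    N, _ = x.shape
    threshold = torch.rand((N, 1), device=x.device)
    mask = x.ge(threshold)
    mask_sum = mask.sum(dim=1, keepdim=False)
    return mask_sum.min()


def sync(device):
    # Hypothetical helper: block the host until all work queued on a CUDA
    # device has finished. CPU ops run synchronously, so nothing to wait for.
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def time_ms(fn, device, *args):
    # Hypothetical helper: wall-clock time of fn(*args) in milliseconds.
    # The first sync drains work queued before the call; the second waits
    # for every kernel fn launched, so its time is not charged elsewhere.
    sync(device)
    start = time.perf_counter()
    out = fn(*args)
    sync(device)
    return out, (time.perf_counter() - start) * 1e3


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
conv = nn.Conv2d(128, 1, 3, 3).to(device)

with torch.no_grad():  # inference only; no autograd graph needed
    for _ in range(5):  # repeat so the last timings exclude one-time startup cost
        x = torch.rand((128, 128, 56, 56), device=device)
        y, conv_ms = time_ms(conv, device, x)
        y = y.view(y.shape[0], -1)  # (128, 324), as in case B
        res, f_ms = time_ms(f, device, y)

print(f"conv: {conv_ms:.3f} ms, f: {f_ms:.3f} ms, min = {res.item()}")
```

If most of the time moves to `conv` once synchronization is added, that would suggest the line profiler was attributing queued convolution work to `min()`. Setting the environment variable `CUDA_LAUNCH_BLOCKING=1` before rerunning the original scripts is another way to make each kernel launch synchronous so that per-line timings reflect each line's own kernels.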