2016-08-30 1 views
1

Je trouve une fonction comme type() pour identifier quelle variable est CudaTensor ou Normal.Torch, comment vérifier une variable est CUDA ou non?

require('cutorch') 

x = torch.Tensor(3,3) 
x = x:cuda() 

if type(x) == 'CudaTensor' then -- What function should be used? 
    print('x is CUDA tensor') 
else 
    print('x is normal tensor') 
end 

Répondre

1

Utilisez :type() méthode de tenseur:

cutorch = require('cutorch') 

x = torch.Tensor(3,3) 
x = x:cuda() 

if x:type() == 'torch.CudaTensor' then 
    print('x is CUDA tensor') 
else 
    print('x is normal tensor') 
end