2017-10-11 6 views
0

Lorsque j'exécute le code suivant, le type de variable se transforme en torch.LongTensor. Comment puis-je avoir cela créer un torch.cuda.LongTensor à la place?define torch.cuda.LongTensor au lieu de torch.LongTensor

# Turn string into list of longs 
def char_tensor(string): 
    tensor = torch.zeros(len(string)).long() 
    for c in range(len(string)):   
     tensor[c] = all_characters.index(string[c]) 
    return Variable(tensor) 

print(char_tensor('abcDEF')) 

sortie:

Variable containing: 
10 
11 
12 
39 
40 
41 
[torch.LongTensor of size 6] 

Répondre

0

La réponse correcte:

# Turn string into list of longs 
def char_tensor(string): 
    tensor = torch.zeros(len(string)).long() 
    for c in range(len(string)):   
     tensor[c] = all_characters.index(string[c]) 
    return Variable(tensor).cuda() 

print(char_tensor('abcDEF'))