Pour les modèles pytorch, j'ai trouvé this tutorial expliquant comment classer une image. J'ai essayé d'appliquer la même procédure pour un modèle de création. Cependant, le modèle échoue pour chaque image que je charge dansLe modèle de création de pytorch génère une mauvaise étiquette pour chaque image d'entrée
code:
# some people need these three lines to make it work
#from torchvision.models.inception import model_urls
#name = 'inception_v3_google'
#model_urls[name] = model_urls[name].replace('https://', 'http://')
from torch.autograd import Variable
import torchvision
import requests
from torchvision import models, transforms
from PIL import Image
import io
from PIL import Image
LABELS_URL = 'https://s3.amazonaws.com/outcome-blog/imagenet/labels.json'
# cat
IMG_URL1 = 'http://farm2.static.flickr.com/1029/762542019_4f197a0de5.jpg'
# dog
IMG_URL2 = 'http://farm3.static.flickr.com/2314/2518519714_98b01968ee.jpg'
# lion
IMG_URL3 = 'http://farm1.static.flickr.com/62/218998565_62930f10fc.jpg'
labels = {int(key):value for (key, value)
in requests.get(LABELS_URL).json().items()}
model = torchvision.models.inception_v3(pretrained=True)
model.training = False
model.transform_input = False
def predict_url_img(url):
response = requests.get(url)
img_pil = Image.open(io.BytesIO(response.content))
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
preprocess = transforms.Compose([
transforms.Scale(256),
transforms.CenterCrop(299),
transforms.ToTensor(),
normalize
])
img_tensor = preprocess(img_pil)
img_tensor.unsqueeze_(0)
img_variable = Variable(img_tensor)
fc_out = model(img_variable)
print("prediction:", labels[fc_out.data.numpy().argmax()])
predict_url_img(IMG_URL1)
predict_url_img(IMG_URL2)
predict_url_img(IMG_URL3)
En sortie, j'obtenir ceci:
('prédiction:', u "piston, aide-plombier")
('prédiction:', sac u'plastic ')
(' prédiction: », u "plongeur, assistant de plombier")
est ainsi la couche de décrochage, ces deux couches se comporte différemment dans la phase de formation et de test, voir [ici] (http://pytorch.org/docs/master/nn.html#torch.nn.Module.eval) – jdhao