2016-02-22 5 views
1

J'essaie d'archiver tout ce qui est décrit pour R dans Python 3. Mais jusqu'à présent, je ne vais pas plus loin.Utilisation du modèle de classification d'image préformé MXNet en Python

Le tutoriel en R est décrit ici: http://mxnet.readthedocs.org/en/latest/R-package/classifyRealImageWithPretrainedModel.html

Comment puis-je faire la même chose en Python? En utilisant le modèle suivant: https://github.com/dmlc/mxnet-model-gallery/blob/master/imagenet-1k-inception-bn.md

Sincères salutations, Kevin

+0

Pour votre information, le R doc a déménagé ici: http://mxnet.io/tutorials/r /classifyRealImageWithPretrainedModel.html – Leopd

Répondre

0

En ce moment, vous pouvez faire des choses beaucoup plus en utilisant Python mxnet que d'utiliser R. J'utilise l'API de gluons, ce qui rend l'écriture de code encore plus simple , et cela permet de charger des modèles pré-montés.

Le modèle utilisé dans le didacticiel auquel vous faites référence est un Inception model. La liste de tous les modèles pré-formés disponibles peut être trouvée here.

Le reste des actions du didacticiel est la normalisation et l'augmentation des données. Vous pouvez faire la normalisation des nouvelles données similaires à la façon dont ils normalisent sur la page API:

image = image/255 
normalized = mx.image.color_normalize(image, 
             mean=mx.nd.array([0.485, 0.456, 0.406]), 
             std=mx.nd.array([0.229, 0.224, 0.225])) 

La liste de l'augmentation possible est disponible here.

Voici l'exemple exécutable pour vous. Je l'ai fait une seule augmentation, et vous pouvez ajouter plus de paramètres à mx.image.CreateAugmenter si vous voulez faire plus d'entre eux:

%matplotlib inline 
import mxnet as mx 
from mxnet.gluon.model_zoo import vision 
from matplotlib.pyplot import imshow 

def plot_mx_array(array, clip=False): 
    """ 
    Array expected to be 3 (channels) x heigh x width, and values are floats between 0 and 255. 
    """ 
    assert array.shape[2] == 3, "RGB Channel should be last" 
    if clip: 
     array = array.clip(0,255) 
    else: 
     assert array.min().asscalar() >= 0, "Value in array is less than 0: found " + str(array.min().asscalar()) 
     assert array.max().asscalar() <= 255, "Value in array is greater than 255: found " + str(array.max().asscalar()) 
    array = array/255 
    np_array = array.asnumpy() 
    imshow(np_array) 


inception_model = vision.inception_v3(pretrained=True) 

with open("/Volumes/Unix/workspace/MxNet/2018-02-20T19-43-45/types_of_data_augmentation/output_4_0.png", 'rb') as open_file: 
    encoded_image = open_file.read() 
    example_image = mx.image.imdecode(encoded_image) 
    example_image = example_image.astype("float32") 
    plot_mx_array(example_image) 


augmenters = mx.image.CreateAugmenter(data_shape=(1, 100, 100)) 

for augementer in augmenters: 
    example_image = augementer(example_image) 

plot_mx_array(example_image) 

example_image = example_image/255 
normalized_image = mx.image.color_normalize(example_image, 
             mean=mx.nd.array([0.485, 0.456, 0.406]), 
             std=mx.nd.array([0.229, 0.224, 0.225]))