2017-06-01 3 views
0

J'ai une question concernant le problème suivant:python: valeur non valide rencontrée dans true_divide - mais où?

Je veux tracer la fonction facile:

f (x) = x 1 * x 2/(x 1^2 + x 2^2)

Si x & y sont zéro, vous divisez par zéro, alors j'ai ajouté une exception pour éviter ce cas:

import numpy as np 
import matplotlib.pyplot as plt 
from mpl_toolkits.mplot3d import Axes3D 

def f(x1, x2): 
    return np.where(np.logical_and(x1==0,x2==0), 
        0, 
        x1*x2/(x1*x1+x2*x2)) 

n = 3 
x = y = np.linspace(-5,5,n) 
xv, yv = np.meshgrid(x, y) 
z = f(xv,yv) 

fig = plt.figure() 
ax = fig.add_subplot(111, projection='3d') 
ax.plot_surface(xv,yv,z) 
plt.show() 

ma figure est intrigue et si je ma solution, il inspecte semble aussi être correct. Toutefois, si je lance le code que je reçois une erreur de division:

RuntimeWarning: invalid value encountered in true_divide 

Je l'ai testé déjà la fonction np.where manuellement et il renvoie le = x 1 x 2 = 0 valeur réelle. Cela semble fonctionner.

Est-ce que quelqu'un sait d'où vient cet avertissement?

+0

Je ne peux pas le reproduire. Votre code fonctionne bien pour moi et trace un graphique – MaxU

+0

Les arguments de 'np.where()' sont tous * évalués *, donc l'utiliser comme ça n'éliminera pas l'erreur. –

+0

@WarrenWeckesser si je vous comprends bien 'x1 * x2/(x1 * x1 + x2 * x2)' est également évalué pour x1 = x2 = 0. Connaissez-vous un meilleur moyen que np.where() pour résoudre ce problème? –

Répondre

0

Comme il a été souligné, vous allez évaluer chaque cas en utilisant np.where(). Pour éviter l'erreur, il suffit de la coder dans un niveau inférieur tel que

import numpy as np 
import matplotlib.pyplot as plt 
from mpl_toolkits.mplot3d import Axes3D 

def f(x1, x2): 
    shape = np.shape(x1) 
    y = np.zeros(shape) 
    for i in range(0,shape[0]): 
     for j in range(0,shape[1]): 
      if x1[i,j]!=0 and x2[i,j]!=0: 
       y[i,j] = x1[i,j]*x2[i,j]/(x1[i,j]*x1[i,j]+x2[i,j]*x2[i,j]) 
    return y 

n = 3 
x = y = np.linspace(-5,5,n) 
xv, yv = np.meshgrid(x, y) 
z = f(xv,yv) 

fig = plt.figure() 
ax = fig.add_subplot(111, projection='3d') 
ax.plot_surface(xv,yv,z) 
plt.show()