2012-03-26 1 views
2

edit: Je ne cherche pas à déboguer ce code. Si vous connaissez cet algorithme bien connu, vous pourrez peut-être aider. Veuillez noter que l'algorithme produit correctement les coefficients.Cubic Spline Python code produisant des splines linéaires

Ce code pour l'interpolation spline cubique produit des splines linéaires et je n'arrive pas à comprendre pourquoi (pour l'instant). L'algorithme provient de Burden's Numerical Analysis, qui est à peu près identique au pseudo code here, ou vous pouvez trouver ce livre à partir d'un lien dans les commentaires (voir le chapitre 3, ça vaut le coup de le faire quand même). Le code produit les coefficients corrects; Je crois que je me méprends sur la mise en œuvre. Tout commentaire est grandement apprécié. De plus, je suis novice en programmation, donc tout commentaire sur la qualité de mon codage est également le bienvenu. J'ai essayé de télécharger des photos du système linéaire en termes de h, a et c, mais en tant que nouvel utilisateur, je ne peux pas. Si vous voulez un visuel du système linéaire tridiagonal que l'algorithme résout, et qui est mis en place par la var alpha, voir le lien dans les commentaires pour le livre, voir le chapitre 3. Le système est strictement diagonalement dominant, nous savons donc existe une solution unique c0, ..., cn. Une fois que nous connaissons les valeurs ci, les autres coefficients suivent.

import matplotlib.pyplot as plt 

# need some zero vectors... 
def zeroV(m): 
    z = [0]*m 
    return(z) 

#INPUT: n; x0, x1, ... ,xn; a0 = f(x0), a1 =f(x1), ... , an = f(xn). 
def cubic_spline(n, xn, a, xd): 
    """function cubic_spline(n,xn, a, xd) interpolates between the knots 
     specified by lists xn and a. The function computes the coefficients 
     and outputs the ranges of the piecewise cubic splines."""   

    h = zeroV(n-1) 

    # alpha will be values in a system of eq's that will allow us to solve for c 
    # and then from there we can find b, d through substitution. 
    alpha = zeroV(n-1) 

    # l, u, z are used in the method for solving the linear system 
    l = zeroV(n+1) 
    u = zeroV(n) 
    z = zeroV(n+1) 

    # b, c, d will be the coefficients along with a. 
    b = zeroV(n)  
    c = zeroV(n+1) 
    d = zeroV(n)  

for i in range(n-1): 
    # h[i] is used to satisfy the condition that 
    # Si+1(xi+l) = Si(xi+l) for each i = 0,..,n-1 
    # i.e., the values at the knots are "doubled up" 
    h[i] = xn[i+1]-xn[i] 

for i in range(1, n-1): 
    # Sets up the linear system and allows us to find c. Once we have 
    # c then b and d follow in terms of it. 
    alpha[i] = (3./h[i])*(a[i+1]-a[i])-(3./h[i-1])*(a[i] - a[i-1]) 

# I, II, (part of) III Sets up and solves tridiagonal linear system... 
# I 
l[0] = 1  
u[0] = 0  
z[0] = 0 

# II 
for i in range(1, n-1): 
    l[i] = 2*(xn[i+1] - xn[i-1]) - h[i-1]*u[i-1] 
    u[i] = h[i]/l[i] 
    z[i] = (alpha[i] - h[i-1]*z[i-1])/l[i] 

l[n] = 1 
z[n] = 0 
c[n] = 0 

# III... also find b, d in terms of c. 
for j in range(n-2, -1, -1):  
    c[j] = z[j] - u[j]*c[j+1] 
    b[j] = (a[j+1] - a[j])/h[j] - h[j]*(c[j+1] + 2*c[j])/3. 
    d[j] = (c[j+1] - c[j])/(3*h[j]) 

# This is my only addition, which is returning values for Sj(x). The issue I'm having 
# is related to this implemention, i suspect. 
for j in range(n-1): 
    #OUTPUT:S(x)=Sj(x)= aj + bj(x - xj) + cj(x - xj)^2 + dj(x - xj)^3; xj <= x <= xj+1) 
    return(a[j] + b[j]*(xd - xn[j]) + c[j]*((xd - xn[j])**2) + d[j]*((xd - xn[j])**3)) 

Pour l'ennuyer ou overachieving ...

Voici le code pour le test, l'intervalle est x: [1, 9], y: [0, 19,7750212]. La fonction de test est XLN (x), donc nous commençons 1 et augmenter de 0,1 à 9.

ln = [] 
ln_dom = [] 
cub = [] 
step = 1. 
X=[1., 9.] 
FX=[0, 19.7750212] 
while step <= 9.: 
    ln.append(step*log(step)) 
    ln_dom.append(step) 
    cub.append(cubic_spline(2, x, fx, step)) 
    step += 0.1 

... et pour le traçage:

plt.plot(ln_dom, cub, color='blue') 
plt.plot(ln_dom, ln, color='red') 
plt.axis([1., 9., 0, 20], 'equal') 
plt.axhline(y=0, color='black') 
plt.axvline(x=0, color='black') 
plt.show() 
+1

Avez-vous essayé de déboguer le programme? Passer à travers avec un débogueur? Insérer des instructions d'impression de débogage? Si vous ne montrez aucun effort, personne ne va lire tout ce code et le déboguer pour vous. –

+2

Aucun effort? Cela a été des heures et des heures de travail. J'ai essayé de rendre le message simple et précis. J'ai fait des tonnes de débogage. Im cherchant l'aide de quelqu'un avec l'expérience avec l'interpolation cubique de spline, b/c n'importe qui qui l'a fait connaît déjà ce code. C'est le même algorithme que tout le monde utilise. Je ne cherche pas quelqu'un pour franchir ligne par ligne, comme je l'ai dit, les coefficients sont calculés correctement. – daniel

+1

Il me semble que vous interpolez entre deux points seulement: 'X = [1., 9.]'. Si c'est le cas, pourquoi attendriez-vous autre chose qu'une ligne droite? –

Répondre

3

Ok, ça fonctionne. Le problème était dans ma mise en œuvre. Je l'ai fait travailler avec une approche différente, où les splines sont construites individuellement au lieu de continuellement. Il s'agit d'une interpolation spline cubique fonctionnant entièrement par la méthode de la première construction des coefficients des polynômes splines (qui est 99% du travail), puis mise en œuvre. Évidemment, ce n'est pas la seule façon de le faire. Je peux travailler sur une approche différente et l'afficher s'il y a un intérêt. Une chose qui permettrait de clarifier le code serait une image du système linéaire qui est résolu, mais je ne peux pas poster de photos jusqu'à ce que mon représentant se lève jusqu'à 10. Si vous voulez approfondir l'algorithme, voir le lien du livre dans les commentaires ci-dessus.

import matplotlib.pyplot as plt 
from pylab import arange 
from math import e 
from math import pi 
from math import sin 
from math import cos 
from numpy import poly1d 

# need some zero vectors... 
def zeroV(m): 
    z = [0]*m 
    return(z) 

#INPUT: n; x0, x1, ... ,xn; a0 = f(x0), a1 =f(x1), ... , an = f(xn). 
def cubic_spline(n, xn, a): 
"""function cubic_spline(n,xn, a, xd) interpolates between the knots 
    specified by lists xn and a. The function computes the coefficients 
    and outputs the ranges of the piecewise cubic splines."""   

    h = zeroV(n-1) 

    # alpha will be values in a system of eq's that will allow us to solve for c 
    # and then from there we can find b, d through substitution. 
    alpha = zeroV(n-1) 

    # l, u, z are used in the method for solving the linear system 
    l = zeroV(n+1) 
    u = zeroV(n) 
    z = zeroV(n+1) 

    # b, c, d will be the coefficients along with a. 
    b = zeroV(n)  
    c = zeroV(n+1) 
    d = zeroV(n)  

    for i in range(n-1): 
     # h[i] is used to satisfy the condition that 
     # Si+1(xi+l) = Si(xi+l) for each i = 0,..,n-1 
     # i.e., the values at the knots are "doubled up" 
     h[i] = xn[i+1]-xn[i] 

    for i in range(1, n-1): 
     # Sets up the linear system and allows us to find c. Once we have 
     # c then b and d follow in terms of it. 
     alpha[i] = (3./h[i])*(a[i+1]-a[i])-(3./h[i-1])*(a[i] - a[i-1]) 

    # I, II, (part of) III Sets up and solves tridiagonal linear system... 
    # I 
    l[0] = 1  
    u[0] = 0  
    z[0] = 0 

    # II 
    for i in range(1, n-1): 
     l[i] = 2*(xn[i+1] - xn[i-1]) - h[i-1]*u[i-1] 
     u[i] = h[i]/l[i] 
     z[i] = (alpha[i] - h[i-1]*z[i-1])/l[i] 

    l[n] = 1 
    z[n] = 0 
    c[n] = 0 

    # III... also find b, d in terms of c. 
    for j in range(n-2, -1, -1):  
     c[j] = z[j] - u[j]*c[j+1] 
     b[j] = (a[j+1] - a[j])/h[j] - h[j]*(c[j+1] + 2*c[j])/3. 
     d[j] = (c[j+1] - c[j])/(3*h[j]) 

    # Now that we have the coefficients it's just a matter of constructing 
    # the appropriate polynomials and graphing. 
    for j in range(n-1): 
     cub_graph(a[j],b[j],c[j],d[j],xn[j],xn[j+1]) 

    plt.show() 

def cub_graph(a,b,c,d, x_i, x_i_1): 
    """cub_graph takes the i'th coefficient set along with the x[i] and x[i+1]'th 
     data pts, and constructs the polynomial spline between the two data pts using 
     the poly1d python object (which simply returns a polynomial with a given root.""" 

    # notice here that we are just building the cubic polynomial piece by piece 
    root = poly1d(x_i,True) 
    poly = 0 
    poly = d*(root)**3 
    poly = poly + c*(root)**2 
    poly = poly + b*root 
    poly = poly + a 

    # Set up our domain between data points, and plot the function 
    pts = arange(x_i,x_i_1, 0.001) 
    plt.plot(pts, poly(pts), '-') 
    return 

Si vous voulez tester, voici quelques données que vous pouvez utiliser pour commencer, qui vient de la fonction 1.6e^(- 2 x) sin (3 * pi * x) entre 0 et 1:

# These are our data points 
x_vals = [0, 1./6, 1./3, 1./2, 7./12, 2./3, 3./4, 5./6, 11./12, 1] 

# Set up the domain 
x_domain = arange(0,2, 1e-2) 

fx = zeroV(10) 

# Defines the function so we can get our fx values 
def sine_func(x): 
    return(1.6*e**(-2*x)*sin(3*pi*x)) 

for i in range(len(x_vals)): 
    fx[i] = sine_func(x_vals[i]) 

# Run cubic_spline interpolant. 
cubic_spline(10,x_vals,fx) 
+0

Merci pour la mise à jour, c'est très utile pour moi! – gleerman

1

Commentaires sur votre style de codage:


  • Où sont vos commentaires et votre documentation? À tout le moins, fournissez la documentation de la fonction afin que les gens puissent dire comment votre fonction est censée être utilisée.

Au lieu de:

def cubic_spline(xx,yy): 

S'il vous plaît écrire quelque chose comme:

def cubic_spline(xx, yy): 
    """function cubic_spline(xx,yy) interpolates between the knots 
    specified by lists xx and yy. The function returns the coefficients 
    and ranges of the piecewise cubic splines.""" 

  • Vous pouvez faire des listes d'éléments répétés en utilisant l'opérateur * sur une liste.

Comme ceci:

>>> [0] * 10 
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 

Pour que votre fonction zeroV peut être remplacé par [0] * m.

Ne faites pas cela avec des types mutables! (en particulier les listes).

>>> inner_list = [1,2,3] 
>>> outer_list = [inner_list] * 3 
>>> outer_list 
[[1, 2, 3], [1, 2, 3], [1, 2, 3]] 
>>> inner_list[0] = 999 
>>> outer_list 
[[999, 2, 3], [999, 2, 3], [999, 2, 3]] # wut 

  • Math devrait probablement être fait en utilisant numpy ou scipy.

En dehors de cela, you should read Idiomatic Python by David Goodger.

+1

Merci pour les commentaires, je vais certainement lire ce lien. tyvm. – daniel

Questions connexes