
I am training an Elman network with the Python neurolab library, and my network does not work correctly. How can I fix overfitting in an Elman neural network?

But during training it reports errors that are far too large:

Epoch: 100; Error: 23752443150.672318; 
Epoch: 200; Error: 284037904.0305649; 
Epoch: 300; Error: 174736152.57367808; 
Epoch: 400; Error: 3318952.136089243; 
Epoch: 500; Error: 299017.4471083774; 
Epoch: 600; Error: 176600.0906688521; 
Epoch: 700; Error: 176599.32080188877; 
Epoch: 800; Error: 185178.21132511366; 
Epoch: 900; Error: 177224.2950528976; 
Epoch: 1000; Error: 176632.86797784362; 
The maximum number of train epochs is reached 

As a result, the network fails on the test sample. Original MICEX:

1758.97
1626.18
1688.34
1609.19
1654.55
1669
1733.17
1642.97
1711.53
1771.05

Predicted MICEX:

[ 1237.59155306] 
[ 1237.59155306] 
[ 1237.59155306] 
[ 1237.59155306] 
[ 1237.59155306] 
[ 1237.59155306] 
[ 1237.59155306] 
[ 1237.59155306] 
[ 1237.59155306] 
[ 1237.59155306] 

Here is my code:

import neurolab as nl 
import numpy as np 

# Create train samples 
MICEX = [421.08,455.44,430.3,484,515.17,468.85,484.73,514.71,551.72,591.09,644.64,561.78,535.4,534.84,502.81,549.28,611.03,632.97,570.76,552.22,575.74,635.38,598.04,593.88,603.89,639.98,700.65,784.28,892.5,842.52,944.55,1011,1171.44,1320.83,1299.19,1486.85,1281.5,1331.39,1380.24,1448.72,1367.24,1426.83,1550.71,1693.47,1656.97,1655.19,1698.08,1697.28,1570.34,1665.96,1734.42,1677.02,1759.44,1874.73,1850.64,1888.86,1574.33,1660.42,1628.43,1667.35,1925.24,1753.67,1495.33,1348.92,1027.66,731.96,611.32,619.53,624.9,666.05,772.93,920.35,1123.38,971.55,1053.3,1091.98,1197.2,1237.18,1284.95,1370.01,1419.42,1332.64,1450.15,1436.04,1332.62,1309.31,1397.12,1368.9,1440.3,1523.39,1565.52,1687.99,1723.42,1777.84,1813.59,1741.84,1666.3,1666.59,1705.18,1546.05,1366.54,1498.6,1499.62,1402.02,1510.91,1594.32,1518.29,1474.14,1312.24,1386.89,1406.36,1422.38,1459.01,1423.46,1405.19,1477.87,1547.18,1487.46,1440.02,1386.69,1343.99,1331.24,1377.6,1364.54,1463.13,1509.62,1479.35,1503.39,1454.05,1444.71,1369.29,1306.01,1432.03,1476.38,1379.61,1400.71,1411.07,1488.47,1533.68,1396.61,1647.69] 

Brent = [26.8,28.16,28.59,30.05,28.34,27.94,28.76,30.48,29.51,33.01,32.36,35.12,36.98,33.51,41.6,39.33,47.08,48.78,44.03,40.24,45.87,50.14,53.05,49.33,49.83,54.85,59.7,66.68,62.56,58.35,53.41,58.87,65.43,60.05,64.94,72,69,73.28,75.16,69.64,61.37,56.97,64.42,60.13,57.21,60.66,68.42,67.28,68.82,73.26,78.05,73.53,81.75,91.14,88,93.85,91.98,100.04,100.51,112.71,128.27,140.3,123.96,115.17,98.96,65.6,53.49,45.59,45.93,45.84,48.68,50.64,65.8,69.42,71.52,69.32,68.92,75.09,78.36,77.93,71.18,78.03,82.17,87.35,74.6,74.66,78.26,74.42,82.11,83.26,85.45,94.59,100.56,112.1,117.17,126.03,116.68,111.8,117.54,114.49,102.15,109.19,110.37,107.22,111.16,123.04,122.8,119.47,101.62,97.57,104.62,114.92,112.14,108.4,111.17,111.11,114.56,111,109.89,101.74,100.15,101.5,107.7,114.45,108.2,108.9,110.11,110.9,105.79,108.65,107.7,108.14,109.49,112.4,105.52,103.11,94.8,85.96,68.34,57.54,52.95] 
DJIA = [8850.26,8985.44,9233.8,9415.82,9275.06,9801.12,9782.46,10453.92,10488.07,10583.92,10357.7,10225.57,10188.45,10435.48,10139.71,10173.92,10080.27,10027.47,10428.02,10783.01,10489.94,10766.23,10503.76,10192.51,10467.48,10274.97,10640.91,10481.6,10568.7,10440.07,10805.87,10717.5,10864.86,10993.41,11109.32,11367.14,11168.31,11150.22,11185.68,11381.15,11679.07,12080.73,12221.93,12463.15,12621.69,12268.63,12354.35,13062.91,13627.64,13408.62,13211.99,13357.74,13895.63,13930.01,13371.72,13264.82,12650.36,12266.39,12262.89,12820.13,12638.32,11350.01,11378.02,11543.96,10850.66,9325.01,8829.04,8776.39,8000.86,7062.93,7608.92,8168.12,8500.33,8447,9171.61,9496.28,9712.28,9712.73,10344.84,10428.05,10067.33,10325.26,10856.63,11008.61,10136.63,9774.02,10465.94,10014.72,10788.05,11118.49,11006.02,11577.51,11891.93,12226.34,12319.73,12810.54,12569.79,12414.34,12143.24,11613.53,10913.38,11955.01,12045.68,12217.56,12632.91,12952.07,13212.04,13213.63,12393.45,12880.09,13008.68,13090.84,13437.13,13096.46,13025.58,13104.14,13860.58,14054.49,14578.54,14839.8,15115.57,14909.6,15499.54,14810.31,15129.67,15545.75,16086.41,16576.66,15698.85,16321.71,16457.66,16580.84,16717.17,16826.6,16563.3,17098.45,17042.9,17390.52,17828.24,17823.07,17164.95] 
CAC_40 = [2991.75,3084.1,3210.27,3311.42,3134.99,3373.2,3424.79,3557.9,3638.44,3725.44,3625.23,3674.28,3669.63,3732.99,3647.1,3594.28,3640.61,3706.82,3753.75,3821.16,3913.69,4027.16,4067.78,3908.93,4120.73,4229.35,4451.74,4399.36,4600.02,4436.45,4567.41,4715.23,4947.99,5000.45,5220.85,5188.4,4930.18,4965.96,5009.42,5165.04,5250.01,5348.73,5327.64,5541.76,5608.31,5516.32,5634.16,5930.77,6104,6054.93,5751.08,5662.7,5715.69,5841.08,5667.5,5614.08,4871.8,4790.66,4707.07,4996.54,5014.28,4425.61,4392.36,4485.64,4027.15,3487.07,3262.68,3217.97,2962.37,2693.96,2803.94,3159.85,3273.55,3138.93,3426.27,3657.72,3794.96,3601.43,3684.75,3936.33,3737.19,3708.8,3974.01,3816.99,3507.56,3442.89,3643.14,3476.18,3715.18,3833.5,3610.44,3804.78,4005.5,4110.35,3989.18,4106.92,4006.94,3980.78,3672.77,3256.76,2981.96,3242.84,3154.62,3159.81,3298.55,3447.94,3423.81,3212.8,3005.48,3196.65,3291.66,3413.07,3354.82,3429.27,3557.28,3641.07,3732.6,3723,3731.42,3856.75,3948.59,3738.91,3992.69,3933.78,4143.44,4299.89,4295.21,4295.95,4165.72,4408.08,4391.5,4487.39,4519.57,4422.84,4246.14,4381.04,4426.76,4233.09,4390.18,4263.55,4604.25] 
SSEC = [1576.26,1486.02,1476.74,1421.98,1367.16,1348.3,1397.22,1497.04,1590.73,1675.07,1741.62,1595.59,1555.91,1399.16,1386.2,1342.06,1396.7,1320.54,1340.77,1266.5,1191.82,1306,1181.24,1159.15,1060.74,1080.94,1083.03,1162.8,1155.61,1092.82,1099.26,1161.06,1258.05,1299.03,1298.3,1440.22,1641.3,1672.21,1612.73,1658.64,1752.42,1837.99,2099.29,2675.47,2786.34,2881.07,3183.98,3841.27,4109.65,3820.7,4471.03,5218.82,5552.3,5954.77,4871.78,5261.56,4383.39,4348.54,3472.71,3693.11,3433.35,2736.1,2775.72,2397.37,2293.78,1728.79,1871.16,1820.81,1990.66,2082.85,2373.21,2477.57,2632.93,2959.36,3412.06,2667.74,2779.43,2995.85,3195.3,3277.14,2989.29,3051.94,3109.11,2870.61,2592.15,2398.37,2637.5,2638.8,2655.66,2978.83,2820.18,2808.08,2790.69,2905.05,2928.11,2911.51,2743.47,2762.08,2701.73,2567.34,2359.22,2468.25,2333.41,2199.42,2292.61,2428.49,2262.79,2396.32,2372.23,2225.43,2103.63,2047.52,2086.17,2068.88,1980.12,2269.13,2385.42,2365.59,2236.62,2177.91,2300.59,1979.21,1993.8,2098.38,2174.66,2141.61,2220.5,2115.98,2033.08,2056.3,2033.31,2026.36,2039.21,2048.33,2201.56,2217.2,2363.87,2420.18,2682.83,3234.68,3210.36] 


Brent_sample = [62.48, 55.1, 66.8, 65.19, 63.14, 51.85, 53.12, 48.44, 49.5, 44.5] 
DJIA_sample = [18132.7, 17776.12, 17840.52, 18010.68, 17619.51, 17689.86, 16528.03, 16284.7, 17663.54, 17719.92] 
CAC_40_sample = [4922.99, 5031.47, 5042.84, 5084.08, 4812.24, 5081.73, 4652.34, 4453.91, 4880.18, 4951.83] 
SSEC_sample = [3310.3, 3747.9, 4441.66, 4611.74, 4277.22, 3663.73, 3205.99, 3052.78, 3382.56, 3445.4] 



MICEX = np.asarray(MICEX) 
Brent = np.asarray(Brent) 
DJIA = np.asarray(DJIA) 
CAC_40 = np.asarray(CAC_40) 
SSEC = np.asarray(SSEC) 

Brent_sample = np.asarray(Brent_sample) 
DJIA_sample = np.asarray(DJIA_sample) 
CAC_40_sample = np.asarray(CAC_40_sample) 
SSEC_sample = np.asarray(SSEC_sample) 

size = len(MICEX) 

inp = np.vstack((Brent, DJIA, CAC_40, SSEC)).T 
tar = MICEX.reshape(size, 1) 
smp = np.vstack((Brent_sample, DJIA_sample, CAC_40_sample, SSEC_sample)).T 

# Create a 2-layer Elman network (46 TanSig hidden neurons, 1 PureLin output);
# the first argument is the [min, max] range of each of the 4 inputs
net = nl.net.newelm(
    [[min(inp[:, 0]), max(inp[:, 0])],
     [min(inp[:, 1]), max(inp[:, 1])],
     [min(inp[:, 2]), max(inp[:, 2])],
     [min(inp[:, 3]), max(inp[:, 3])]],
    [46, 1],
    [nl.trans.TanSig(), nl.trans.PureLin()]  # alternative output function tried: SatLinPrm(0.00000001, 421.08, 1925.24)
)
# Set the initialization functions (uniform random weights and biases in [-0.1, 0.1]) and initialize
net.layers[0].initf = nl.init.InitRand([-0.1, 0.1], 'wb') 
net.layers[1].initf = nl.init.InitRand([-0.1, 0.1], 'wb') 
net.init() 

# Changing training method 
# net.trainf = nl.train.train_cg 

# Train network 
error = net.train(inp, tar, epochs=1000, show=100, goal=0.02) 

# Simulate network 
out = net.sim(smp) 
print(smp) 
print('MICEX predictions for the next 10 periods:\n', out) 

Does anyone know how to solve this problem?
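A likely reason for both the huge initial error and the constant prediction (an assumption, not something confirmed in the thread) is that the network is fed raw values: DJIA ranges from roughly 7,000 to 18,000, SSEC and CAC_40 are in the thousands, and the MICEX target goes up to about 1925, while a TanSig neuron saturates once its net input is more than a few units away from zero. Below is a minimal sketch of column-wise min-max scaling to [-1, 1] in plain NumPy; `fit_minmax`, `scale` and `unscale` are hypothetical helpers, not neurolab functions. The ranges are computed on the training data only, reused for `smp`, and the network output is mapped back to MICEX units at the end.

```python
import numpy as np


def fit_minmax(data):
    """Column-wise min and max of a 2-D array whose rows are samples."""
    data = np.asarray(data, dtype=float)
    return data.min(axis=0), data.max(axis=0)


def scale(data, lo, hi):
    """Map each column linearly from [lo, hi] to [-1, 1].
    Assumes hi > lo in every column; values outside the training range
    land slightly outside [-1, 1]."""
    data = np.asarray(data, dtype=float)
    return 2.0 * (data - lo) / (hi - lo) - 1.0


def unscale(scaled, lo, hi):
    """Inverse of scale(): map values from [-1, 1] back to original units."""
    scaled = np.asarray(scaled, dtype=float)
    return (scaled + 1.0) * (hi - lo) / 2.0 + lo


# Hypothetical wiring with the variables from the question (inp, tar, smp):
#   in_lo, in_hi = fit_minmax(inp)      # ranges from the training inputs only
#   t_lo, t_hi = fit_minmax(tar)
#   inp_s, smp_s = scale(inp, in_lo, in_hi), scale(smp, in_lo, in_hi)
#   tar_s = scale(tar, t_lo, t_hi)
#   net = nl.net.newelm([[-1, 1]] * 4, [46, 1],
#                       [nl.trans.TanSig(), nl.trans.PureLin()])
#   (keep the InitRand / net.init() lines from the question here)
#   net.train(inp_s, tar_s, epochs=1000, show=100, goal=0.02)
#   out = unscale(net.sim(smp_s), t_lo, t_hi)

if __name__ == "__main__":
    # A few (Brent, DJIA) values from the question, for illustration only
    demo = np.array([[26.8, 8850.26], [140.3, 17828.24], [52.95, 17164.95]])
    lo, hi = fit_minmax(demo)
    print(scale(demo, lo, hi))
```

With scaled targets, `goal=0.02` now refers to the scaled error and has to be re-chosen. Note also that some `DJIA_sample` values exceed `max(DJIA)`, so those sample rows are extrapolations beyond the training range.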


I don't see any error. What is not working? The error gets smaller and it seems to work. Another thing to notice is that the error stops decreasing; it looks like you are [overfitting](https://en.wikipedia.org/wiki/Overfitting) your training set, and it won't generalize.


@john-carpenter Thanks! And how can I fix this overfitting in my case?


There are a few methods. One is to detect that your error no longer decreases on the training set and stop training. Another is to use a validation set, which you do not train on but on which you still measure the error. When the error on the validation set stops decreasing, you stop. [Here is an SO link with more details](https://stackoverflow.com/questions/2976452/whats-is-the-difference-between-train-validation-and-test-set-in-neural-networ)
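Both methods in this comment fit one generic loop: hand it a function that returns the training error for the first method, or the error on a held-out validation set for the second. A minimal sketch follows; `chronological_split` and `train_with_early_stopping` are hypothetical names, not neurolab API. Because MICEX is a time series, the sketch assumes the most recent rows are held out rather than a random subset.

```python
def chronological_split(x, y, val_fraction=0.2):
    """Hold out the LAST val_fraction of the rows as a validation set.
    No shuffling: for a time series the validation period should come
    after the training period."""
    n_val = max(1, int(round(len(x) * val_fraction)))
    return x[:-n_val], y[:-n_val], x[-n_val:], y[-n_val:]


def train_with_early_stopping(train_step, val_error, max_rounds=1000,
                              patience=20, on_improve=None):
    """Generic early-stopping loop.

    train_step(): trains the model a little (e.g. a few epochs) in place.
    val_error():  returns the current error on the monitored set.
    on_improve(): optional; called whenever a new best error is reached,
                  e.g. to keep a copy of the weights.
    Stops once the error has not improved for `patience` consecutive
    rounds; returns (best_round, best_error).
    """
    best_err, best_round, waited = float("inf"), 0, 0
    for rnd in range(1, max_rounds + 1):
        train_step()
        err = val_error()
        if err < best_err:
            best_err, best_round, waited = err, rnd, 0
            if on_improve is not None:
                on_improve()
        else:
            waited += 1
            if waited >= patience:
                break
    return best_round, best_err


# Hypothetical wiring with neurolab (needs `import copy` and numpy as np;
# inputs/targets already scaled). That repeated net.train calls resume from
# the current weights, and that show=0 silences progress output, are
# assumptions to check against the neurolab documentation:
#   x_tr, y_tr, x_val, y_val = chronological_split(inp_s, tar_s)
#   best = {}
#   train_with_early_stopping(
#       train_step=lambda: net.train(x_tr, y_tr, epochs=10, show=0),
#       val_error=lambda: float(np.mean((net.sim(x_val) - y_val) ** 2)),
#       on_improve=lambda: best.update(net=copy.deepcopy(net)))
#   net = best["net"]
```

Keeping a copy of the best weights through `on_improve` matters because the loop deliberately keeps training for `patience` rounds past the best point before it stops.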

Answer


First of all, this is not overfitting. You are underfitting: the network does not even converge on the training cases. Can't you increase the number of epochs? Let the network converge.
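The underfitting diagnosis can be checked numerically. A minimal sketch, assuming the error neurolab prints for this network is a mean squared error over the training rows (an assumption; check which error function your network actually uses): a model that always outputs the mean of MICEX reaches an MSE equal to the variance of MICEX. If the plateau in the training log is close to that value and all predictions are identical, the network has learned little beyond the output bias. `constant_baseline_mse` and `fraction_explained` are hypothetical helper names.

```python
import numpy as np


def constant_baseline_mse(target):
    """MSE of a 'model' that always outputs the mean of `target`.
    This is simply the population variance of the targets."""
    target = np.asarray(target, dtype=float)
    return float(np.mean((target - target.mean()) ** 2))


def fraction_explained(model_mse, target):
    """1 - model_mse / baseline. A value near 0 means the network does no
    better than predicting the mean (underfitting). It says nothing about
    overfitting, which needs a separate validation error."""
    return 1.0 - model_mse / constant_baseline_mse(target)


# Usage with the question's data: compare constant_baseline_mse(MICEX)
# with the last error printed in the training log.
```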


I just tried 10000 epochs with the following configuration: 25 neurons, Elman network type. Network code: http://pastebin.com/GYM2E0ci, script log: http://pastebin.com/5AJbvXME. As you can see, the error always stops at about 1766.00, and the output is the same for all 10 rows...
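One plausible reason why all 10 outputs are identical (an assumption, not verified in the thread): with raw inputs in the thousands and weights drawn from [-0.1, 0.1], as `InitRand([-0.1, 0.1], 'wb')` does in the question, the net input of almost every TanSig hidden neuron is far from zero, so the neuron outputs about +1 or -1 for every row. Its gradient is then nearly zero, and training can mostly adjust only the output layer, which produces one constant prediction. The sketch below measures that saturation with plain NumPy; it ignores the Elman context inputs, and `saturated_fraction` is a hypothetical name.

```python
import numpy as np


def saturated_fraction(x, n_hidden=25, w_range=0.1, threshold=0.99, seed=0):
    """Fraction of tanh hidden activations with |a| > threshold.

    Weights and biases are drawn uniformly from [-w_range, w_range],
    mimicking the question's InitRand([-0.1, 0.1], 'wb'). The recurrent
    (context) inputs of the Elman layer are ignored to keep it short.
    """
    rng = np.random.default_rng(seed)
    x = np.asarray(x, dtype=float)
    w = rng.uniform(-w_range, w_range, size=(x.shape[1], n_hidden))
    b = rng.uniform(-w_range, w_range, size=n_hidden)
    a = np.tanh(x @ w + b)  # hidden activations, one row per sample
    return float(np.mean(np.abs(a) > threshold))


if __name__ == "__main__":
    # First and last rows of (Brent, DJIA, CAC_40, SSEC) from the question
    raw = np.array([[26.8, 8850.26, 2991.75, 1576.26],
                    [52.95, 17164.95, 4604.25, 3210.36]])
    lo, hi = raw.min(axis=0), raw.max(axis=0)
    scaled = 2.0 * (raw - lo) / (hi - lo) - 1.0
    print("raw inputs:   ", saturated_fraction(raw))
    print("scaled inputs:", saturated_fraction(scaled))
```

With inputs scaled to [-1, 1], the net input of each hidden neuron is bounded by 4 × 0.1 + 0.1 = 0.5 in absolute value, so none of them can reach the 0.99 threshold, whereas the raw rows push nearly all of them into saturation.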