I'm training an Elman network with the Python neurolab library, and my network doesn't work correctly. How can I fix overfitting in an Elman neural network?
- training input vectors: http://pastebin.com/urQX2eEA
- training target vector: http://pastebin.com/1JQh1xZv
- sample vector for testing the network: http://pastebin.com/jprZhBHa
During training, however, it reports very large errors:
Epoch: 100; Error: 23752443150.672318;
Epoch: 200; Error: 284037904.0305649;
Epoch: 300; Error: 174736152.57367808;
Epoch: 400; Error: 3318952.136089243;
Epoch: 500; Error: 299017.4471083774;
Epoch: 600; Error: 176600.0906688521;
Epoch: 700; Error: 176599.32080188877;
Epoch: 800; Error: 185178.21132511366;
Epoch: 900; Error: 177224.2950528976;
Epoch: 1000; Error: 176632.86797784362;
The maximum number of train epochs is reached
As a result, the network fails on the test sample.
Original MICEX:
1758,97
1626,18
1688,34
1609,19
1654,55
1669
1733,17
1642,97
1711,53
1771,05
Predicted MICEX:
[ 1237.59155306]
[ 1237.59155306]
[ 1237.59155306]
[ 1237.59155306]
[ 1237.59155306]
[ 1237.59155306]
[ 1237.59155306]
[ 1237.59155306]
[ 1237.59155306]
[ 1237.59155306]
Here is my code:
import neurolab as nl
import numpy as np
# Create train samples
MICEX = [421.08,455.44,430.3,484,515.17,468.85,484.73,514.71,551.72,591.09,644.64,561.78,535.4,534.84,502.81,549.28,611.03,632.97,570.76,552.22,575.74,635.38,598.04,593.88,603.89,639.98,700.65,784.28,892.5,842.52,944.55,1011,1171.44,1320.83,1299.19,1486.85,1281.5,1331.39,1380.24,1448.72,1367.24,1426.83,1550.71,1693.47,1656.97,1655.19,1698.08,1697.28,1570.34,1665.96,1734.42,1677.02,1759.44,1874.73,1850.64,1888.86,1574.33,1660.42,1628.43,1667.35,1925.24,1753.67,1495.33,1348.92,1027.66,731.96,611.32,619.53,624.9,666.05,772.93,920.35,1123.38,971.55,1053.3,1091.98,1197.2,1237.18,1284.95,1370.01,1419.42,1332.64,1450.15,1436.04,1332.62,1309.31,1397.12,1368.9,1440.3,1523.39,1565.52,1687.99,1723.42,1777.84,1813.59,1741.84,1666.3,1666.59,1705.18,1546.05,1366.54,1498.6,1499.62,1402.02,1510.91,1594.32,1518.29,1474.14,1312.24,1386.89,1406.36,1422.38,1459.01,1423.46,1405.19,1477.87,1547.18,1487.46,1440.02,1386.69,1343.99,1331.24,1377.6,1364.54,1463.13,1509.62,1479.35,1503.39,1454.05,1444.71,1369.29,1306.01,1432.03,1476.38,1379.61,1400.71,1411.07,1488.47,1533.68,1396.61,1647.69]
Brent = [26.8,28.16,28.59,30.05,28.34,27.94,28.76,30.48,29.51,33.01,32.36,35.12,36.98,33.51,41.6,39.33,47.08,48.78,44.03,40.24,45.87,50.14,53.05,49.33,49.83,54.85,59.7,66.68,62.56,58.35,53.41,58.87,65.43,60.05,64.94,72,69,73.28,75.16,69.64,61.37,56.97,64.42,60.13,57.21,60.66,68.42,67.28,68.82,73.26,78.05,73.53,81.75,91.14,88,93.85,91.98,100.04,100.51,112.71,128.27,140.3,123.96,115.17,98.96,65.6,53.49,45.59,45.93,45.84,48.68,50.64,65.8,69.42,71.52,69.32,68.92,75.09,78.36,77.93,71.18,78.03,82.17,87.35,74.6,74.66,78.26,74.42,82.11,83.26,85.45,94.59,100.56,112.1,117.17,126.03,116.68,111.8,117.54,114.49,102.15,109.19,110.37,107.22,111.16,123.04,122.8,119.47,101.62,97.57,104.62,114.92,112.14,108.4,111.17,111.11,114.56,111,109.89,101.74,100.15,101.5,107.7,114.45,108.2,108.9,110.11,110.9,105.79,108.65,107.7,108.14,109.49,112.4,105.52,103.11,94.8,85.96,68.34,57.54,52.95]
DJIA = [8850.26,8985.44,9233.8,9415.82,9275.06,9801.12,9782.46,10453.92,10488.07,10583.92,10357.7,10225.57,10188.45,10435.48,10139.71,10173.92,10080.27,10027.47,10428.02,10783.01,10489.94,10766.23,10503.76,10192.51,10467.48,10274.97,10640.91,10481.6,10568.7,10440.07,10805.87,10717.5,10864.86,10993.41,11109.32,11367.14,11168.31,11150.22,11185.68,11381.15,11679.07,12080.73,12221.93,12463.15,12621.69,12268.63,12354.35,13062.91,13627.64,13408.62,13211.99,13357.74,13895.63,13930.01,13371.72,13264.82,12650.36,12266.39,12262.89,12820.13,12638.32,11350.01,11378.02,11543.96,10850.66,9325.01,8829.04,8776.39,8000.86,7062.93,7608.92,8168.12,8500.33,8447,9171.61,9496.28,9712.28,9712.73,10344.84,10428.05,10067.33,10325.26,10856.63,11008.61,10136.63,9774.02,10465.94,10014.72,10788.05,11118.49,11006.02,11577.51,11891.93,12226.34,12319.73,12810.54,12569.79,12414.34,12143.24,11613.53,10913.38,11955.01,12045.68,12217.56,12632.91,12952.07,13212.04,13213.63,12393.45,12880.09,13008.68,13090.84,13437.13,13096.46,13025.58,13104.14,13860.58,14054.49,14578.54,14839.8,15115.57,14909.6,15499.54,14810.31,15129.67,15545.75,16086.41,16576.66,15698.85,16321.71,16457.66,16580.84,16717.17,16826.6,16563.3,17098.45,17042.9,17390.52,17828.24,17823.07,17164.95]
CAC_40 = [2991.75,3084.1,3210.27,3311.42,3134.99,3373.2,3424.79,3557.9,3638.44,3725.44,3625.23,3674.28,3669.63,3732.99,3647.1,3594.28,3640.61,3706.82,3753.75,3821.16,3913.69,4027.16,4067.78,3908.93,4120.73,4229.35,4451.74,4399.36,4600.02,4436.45,4567.41,4715.23,4947.99,5000.45,5220.85,5188.4,4930.18,4965.96,5009.42,5165.04,5250.01,5348.73,5327.64,5541.76,5608.31,5516.32,5634.16,5930.77,6104,6054.93,5751.08,5662.7,5715.69,5841.08,5667.5,5614.08,4871.8,4790.66,4707.07,4996.54,5014.28,4425.61,4392.36,4485.64,4027.15,3487.07,3262.68,3217.97,2962.37,2693.96,2803.94,3159.85,3273.55,3138.93,3426.27,3657.72,3794.96,3601.43,3684.75,3936.33,3737.19,3708.8,3974.01,3816.99,3507.56,3442.89,3643.14,3476.18,3715.18,3833.5,3610.44,3804.78,4005.5,4110.35,3989.18,4106.92,4006.94,3980.78,3672.77,3256.76,2981.96,3242.84,3154.62,3159.81,3298.55,3447.94,3423.81,3212.8,3005.48,3196.65,3291.66,3413.07,3354.82,3429.27,3557.28,3641.07,3732.6,3723,3731.42,3856.75,3948.59,3738.91,3992.69,3933.78,4143.44,4299.89,4295.21,4295.95,4165.72,4408.08,4391.5,4487.39,4519.57,4422.84,4246.14,4381.04,4426.76,4233.09,4390.18,4263.55,4604.25]
SSEC = [1576.26,1486.02,1476.74,1421.98,1367.16,1348.3,1397.22,1497.04,1590.73,1675.07,1741.62,1595.59,1555.91,1399.16,1386.2,1342.06,1396.7,1320.54,1340.77,1266.5,1191.82,1306,1181.24,1159.15,1060.74,1080.94,1083.03,1162.8,1155.61,1092.82,1099.26,1161.06,1258.05,1299.03,1298.3,1440.22,1641.3,1672.21,1612.73,1658.64,1752.42,1837.99,2099.29,2675.47,2786.34,2881.07,3183.98,3841.27,4109.65,3820.7,4471.03,5218.82,5552.3,5954.77,4871.78,5261.56,4383.39,4348.54,3472.71,3693.11,3433.35,2736.1,2775.72,2397.37,2293.78,1728.79,1871.16,1820.81,1990.66,2082.85,2373.21,2477.57,2632.93,2959.36,3412.06,2667.74,2779.43,2995.85,3195.3,3277.14,2989.29,3051.94,3109.11,2870.61,2592.15,2398.37,2637.5,2638.8,2655.66,2978.83,2820.18,2808.08,2790.69,2905.05,2928.11,2911.51,2743.47,2762.08,2701.73,2567.34,2359.22,2468.25,2333.41,2199.42,2292.61,2428.49,2262.79,2396.32,2372.23,2225.43,2103.63,2047.52,2086.17,2068.88,1980.12,2269.13,2385.42,2365.59,2236.62,2177.91,2300.59,1979.21,1993.8,2098.38,2174.66,2141.61,2220.5,2115.98,2033.08,2056.3,2033.31,2026.36,2039.21,2048.33,2201.56,2217.2,2363.87,2420.18,2682.83,3234.68,3210.36]
Brent_sample = [62.48, 55.1, 66.8, 65.19, 63.14, 51.85, 53.12, 48.44, 49.5, 44.5]
DJIA_sample = [18132.7, 17776.12, 17840.52, 18010.68, 17619.51, 17689.86, 16528.03, 16284.7, 17663.54, 17719.92]
CAC_40_sample = [4922.99, 5031.47, 5042.84, 5084.08, 4812.24, 5081.73, 4652.34, 4453.91, 4880.18, 4951.83]
SSEC_sample = [3310.3, 3747.9, 4441.66, 4611.74, 4277.22, 3663.73, 3205.99, 3052.78, 3382.56, 3445.4]
MICEX = np.asarray(MICEX)
Brent = np.asarray(Brent)
DJIA = np.asarray(DJIA)
CAC_40 = np.asarray(CAC_40)
SSEC = np.asarray(SSEC)
Brent_sample = np.asarray(Brent_sample)
DJIA_sample = np.asarray(DJIA_sample)
CAC_40_sample = np.asarray(CAC_40_sample)
SSEC_sample = np.asarray(SSEC_sample)
size = len(MICEX)
inp = np.vstack((Brent, DJIA, CAC_40, SSEC)).T
tar = MICEX.reshape(size, 1)
smp = np.vstack((Brent_sample, DJIA_sample, CAC_40_sample, SSEC_sample)).T
# Create network with 2 layers and random initialized
net = nl.net.newelm(
    [[min(inp[:, 0]), max(inp[:, 0])],
     [min(inp[:, 1]), max(inp[:, 1])],
     [min(inp[:, 2]), max(inp[:, 2])],
     [min(inp[:, 3]), max(inp[:, 3])]],
    [46, 1],
    [nl.trans.TanSig(), nl.trans.PureLin()]  # SatLinPrm(0.00000001, 421.08, 1925.24)
)
# Set initialized functions and init
net.layers[0].initf = nl.init.InitRand([-0.1, 0.1], 'wb')
net.layers[1].initf = nl.init.InitRand([-0.1, 0.1], 'wb')
net.init()
# Changing training method
# net.trainf = nl.train.train_cg
# Train network
error = net.train(inp, tar, epochs=1000, show=100, goal=0.02)
# Simulate network
out = net.sim(smp)
print(smp)
print('MICEX predictions for the next 10 periods:\n', out)
Does anyone know a solution to this problem?
I don't see any error. What isn't working? The error gets smaller, so training seems to work. Another thing to notice is that the error eventually stops decreasing; it looks like you are [overfitting](https://en.wikipedia.org/wiki/Overfitting) your training set, and the network is not going to generalize.
@john-carpenter Thanks! And how can I fix this overfitting in my case?
There are a few methods. One is to detect that the error on your training set is no longer decreasing and stop training. The other is to use a validation set, which you do not train on but still measure the error on. When the error on the validation set stops decreasing, you stop. [Here is an SO link with more details](https://stackoverflow.com/questions/2976452/whats-is-the-difference-between-train-validation-and-test-set-in-neural-networ)
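
For illustration, here is a minimal sketch of the validation-set approach described in the comment above. It is not neurolab-specific: `chronological_split` and `train_with_early_stopping` are hypothetical helper names, and `patience` and `min_delta` are assumed tuning parameters. The toy linear model in the demo only stands in for a real network so that the snippet runs on its own.

```python
import numpy as np


def chronological_split(inputs, targets, val_fraction=0.2):
    """Hold out the last `val_fraction` of the samples as a validation set.

    The data here is a time series, so the split keeps the original order
    instead of shuffling.
    """
    n_val = max(1, int(len(inputs) * val_fraction))
    train = (inputs[:-n_val], targets[:-n_val])
    val = (inputs[-n_val:], targets[-n_val:])
    return train, val


def train_with_early_stopping(train_step, val_error, max_rounds=100,
                              patience=5, min_delta=0.0):
    """Call `train_step()` repeatedly and stop once `val_error()` has not
    improved by more than `min_delta` for `patience` consecutive rounds.

    Returns (best_round, best_error, history_of_validation_errors).
    """
    best = float("inf")
    best_round = 0
    history = []
    for round_idx in range(1, max_rounds + 1):
        train_step()          # train for a while on the training set only
        err = val_error()     # measure the error on held-out data
        history.append(err)
        if err < best - min_delta:
            best, best_round = err, round_idx
        elif round_idx - best_round >= patience:
            break             # validation error stopped decreasing
    return best_round, best, history


if __name__ == "__main__":
    # Toy stand-in for a network: a linear model fitted by gradient descent.
    rng = np.random.default_rng(0)
    x = rng.normal(size=(100, 4))
    true_w = np.array([[1.0], [-2.0], [0.5], [3.0]])
    y = x @ true_w + 0.1 * rng.normal(size=(100, 1))
    (x_tr, y_tr), (x_val, y_val) = chronological_split(x, y)
    w = np.zeros((4, 1))

    def train_step():
        # One gradient-descent step on the training mean squared error.
        grad = 2 * x_tr.T @ (x_tr @ w - y_tr) / len(x_tr)
        w[...] -= 0.1 * grad  # in-place update of the shared weights

    def val_error():
        return float(np.mean((x_val @ w - y_val) ** 2))

    best_round, best, history = train_with_early_stopping(
        train_step, val_error, max_rounds=200, patience=5)
    print("best round:", best_round, "validation MSE:", best)
```

With neurolab, `train_step` could plausibly wrap a short call such as `net.train(train_inp, train_tar, epochs=10)` and `val_error` could compare `net.sim(val_inp)` against `val_tar`. This assumes that repeated `net.train` calls continue from the current weights, which should be checked against the library before relying on it.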