Plotting NBEATS output

An example of training an interpretable autopycoin.models.NBEATS model and plotting its backcast and forecast.
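The model in this example is created with nbeats.create_interpretable_nbeats. Its arguments (p_degree, forecast_periods, backcast_periods and the *_fourier_order lists) suggest the interpretable configuration of the N-BEATS paper. In that configuration, trend blocks project their predicted coefficients onto a low-degree polynomial basis and seasonality blocks project theirs onto Fourier terms. The NumPy sketch below builds both kinds of basis for a 40-step forecast. How autopycoin combines periods and Fourier orders internally is an assumption here, and trend_basis and seasonality_basis are hypothetical helpers, not library functions.

```python
import numpy as np

def trend_basis(horizon, p_degree):
    """Rows are t**0 ... t**p_degree with t = h / horizon (polynomial trend basis)."""
    t = np.arange(horizon) / horizon
    return np.stack([t**p for p in range(p_degree + 1)])          # (p_degree + 1, horizon)

def seasonality_basis(horizon, period, fourier_order):
    """Rows are cos/sin harmonics of `period` steps (assumed form, not autopycoin's code)."""
    h = np.arange(horizon)
    k = np.arange(1, fourier_order + 1)[:, None]                   # harmonic indices
    angles = 2 * np.pi * k * h[None, :] / period
    return np.concatenate([np.cos(angles), np.sin(angles)])        # (2 * fourier_order, horizon)

# Same sizes as the example: 40-step horizon, p_degree=1, period 10, Fourier order 10.
T = trend_basis(horizon=40, p_degree=1)
S = seasonality_basis(horizon=40, period=10, fourier_order=10)

# A block's fully connected layers predict coefficients theta; its output is theta @ basis.
rng = np.random.default_rng(0)
trend_forecast = rng.normal(size=T.shape[0]) @ T       # p_degree=1 gives a straight line
seasonal_forecast = rng.normal(size=S.shape[0]) @ S    # repeats every `period` steps
print(T.shape, S.shape, trend_forecast.shape)
```

Restricting each block to such a basis is what makes the outputs readable as separate trend and seasonality components.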

[Figure: plot nbeats, showing the inputs, backcast, labels and forecast of one training window]

Out:

Epoch 1/20

      1/Unknown - 3s 3s/step - loss: 236.4323 - output_1_loss: 125.5018 - output_2_loss: 110.9305 - output_1_mae: 8.1811 - output_2_mae: 8.4217
      7/Unknown - 3s 9ms/step - loss: 130.1454 - output_1_loss: 60.4274 - output_2_loss: 69.7181 - output_1_mae: 5.3354 - output_2_mae: 6.2513 
9/9 [==============================] - 4s 116ms/step - loss: 107.7559 - output_1_loss: 48.8488 - output_2_loss: 58.9071 - output_1_mae: 4.6330 - output_2_mae: 5.6064 - val_loss: 28.2664 - val_output_1_loss: 6.3439 - val_output_2_loss: 21.9225 - val_output_1_mae: 2.0192 - val_output_2_mae: 3.4448
Epoch 2/20

1/9 [==>...........................] - ETA: 0s - loss: 26.5913 - output_1_loss: 5.7781 - output_2_loss: 20.8133 - output_1_mae: 1.9102 - output_2_mae: 3.3815
7/9 [======================>.......] - ETA: 0s - loss: 17.8943 - output_1_loss: 5.2912 - output_2_loss: 12.6031 - output_1_mae: 1.8171 - output_2_mae: 2.6154
9/9 [==============================] - 0s 19ms/step - loss: 17.0341 - output_1_loss: 5.0195 - output_2_loss: 12.0146 - output_1_mae: 1.7654 - output_2_mae: 2.6293 - val_loss: 9.0470 - val_output_1_loss: 3.2642 - val_output_2_loss: 5.7828 - val_output_1_mae: 1.4367 - val_output_2_mae: 1.9238
Epoch 3/20

1/9 [==>...........................] - ETA: 0s - loss: 9.6009 - output_1_loss: 3.0030 - output_2_loss: 6.5978 - output_1_mae: 1.4071 - output_2_mae: 2.0672
6/9 [===================>..........] - ETA: 0s - loss: 8.7424 - output_1_loss: 2.7895 - output_2_loss: 5.9529 - output_1_mae: 1.3381 - output_2_mae: 1.9028
9/9 [==============================] - 0s 20ms/step - loss: 7.7298 - output_1_loss: 2.6528 - output_2_loss: 5.0770 - output_1_mae: 1.2994 - output_2_mae: 1.7538 - val_loss: 4.8394 - val_output_1_loss: 1.9176 - val_output_2_loss: 2.9218 - val_output_1_mae: 1.1013 - val_output_2_mae: 1.3557
Epoch 4/20

1/9 [==>...........................] - ETA: 0s - loss: 4.9613 - output_1_loss: 1.8909 - output_2_loss: 3.0704 - output_1_mae: 1.1051 - output_2_mae: 1.3992
6/9 [===================>..........] - ETA: 0s - loss: 4.9043 - output_1_loss: 1.9614 - output_2_loss: 2.9428 - output_1_mae: 1.1248 - output_2_mae: 1.3552
9/9 [==============================] - 0s 20ms/step - loss: 4.5169 - output_1_loss: 1.9123 - output_2_loss: 2.6046 - output_1_mae: 1.1041 - output_2_mae: 1.2679 - val_loss: 4.1807 - val_output_1_loss: 1.5798 - val_output_2_loss: 2.6009 - val_output_1_mae: 0.9889 - val_output_2_mae: 1.3132
Epoch 5/20

1/9 [==>...........................] - ETA: 0s - loss: 3.4573 - output_1_loss: 1.4071 - output_2_loss: 2.0502 - output_1_mae: 0.9480 - output_2_mae: 1.1720
7/9 [======================>.......] - ETA: 0s - loss: 3.4069 - output_1_loss: 1.5083 - output_2_loss: 1.8985 - output_1_mae: 0.9860 - output_2_mae: 1.1099
9/9 [==============================] - 0s 20ms/step - loss: 3.3444 - output_1_loss: 1.5030 - output_2_loss: 1.8413 - output_1_mae: 0.9812 - output_2_mae: 1.0946 - val_loss: 3.5574 - val_output_1_loss: 1.3821 - val_output_2_loss: 2.1753 - val_output_1_mae: 0.9314 - val_output_2_mae: 1.1719
Epoch 6/20

1/9 [==>...........................] - ETA: 0s - loss: 2.6653 - output_1_loss: 1.2244 - output_2_loss: 1.4409 - output_1_mae: 0.8953 - output_2_mae: 0.9681
7/9 [======================>.......] - ETA: 0s - loss: 2.8864 - output_1_loss: 1.3194 - output_2_loss: 1.5670 - output_1_mae: 0.9260 - output_2_mae: 0.9970
9/9 [==============================] - 0s 19ms/step - loss: 2.8291 - output_1_loss: 1.3181 - output_2_loss: 1.5110 - output_1_mae: 0.9233 - output_2_mae: 0.9787 - val_loss: 3.0871 - val_output_1_loss: 1.1480 - val_output_2_loss: 1.9391 - val_output_1_mae: 0.8554 - val_output_2_mae: 1.1069
Epoch 7/20

1/9 [==>...........................] - ETA: 0s - loss: 2.4455 - output_1_loss: 1.1130 - output_2_loss: 1.3325 - output_1_mae: 0.8555 - output_2_mae: 0.9570
7/9 [======================>.......] - ETA: 0s - loss: 2.5805 - output_1_loss: 1.2191 - output_2_loss: 1.3614 - output_1_mae: 0.8931 - output_2_mae: 0.9413
9/9 [==============================] - 0s 19ms/step - loss: 2.5510 - output_1_loss: 1.2294 - output_2_loss: 1.3216 - output_1_mae: 0.8938 - output_2_mae: 0.9291 - val_loss: 2.9352 - val_output_1_loss: 1.0850 - val_output_2_loss: 1.8501 - val_output_1_mae: 0.8318 - val_output_2_mae: 1.0866
Epoch 8/20

1/9 [==>...........................] - ETA: 0s - loss: 2.2622 - output_1_loss: 1.0602 - output_2_loss: 1.2020 - output_1_mae: 0.8342 - output_2_mae: 0.9010
7/9 [======================>.......] - ETA: 0s - loss: 2.4388 - output_1_loss: 1.1645 - output_2_loss: 1.2743 - output_1_mae: 0.8744 - output_2_mae: 0.9063
9/9 [==============================] - 0s 20ms/step - loss: 2.4114 - output_1_loss: 1.1762 - output_2_loss: 1.2352 - output_1_mae: 0.8761 - output_2_mae: 0.8928 - val_loss: 2.8318 - val_output_1_loss: 1.0023 - val_output_2_loss: 1.8295 - val_output_1_mae: 0.8007 - val_output_2_mae: 1.0875
Epoch 9/20

1/9 [==>...........................] - ETA: 0s - loss: 2.1524 - output_1_loss: 1.0083 - output_2_loss: 1.1441 - output_1_mae: 0.8115 - output_2_mae: 0.8893
6/9 [===================>..........] - ETA: 0s - loss: 2.3350 - output_1_loss: 1.0946 - output_2_loss: 1.2404 - output_1_mae: 0.8481 - output_2_mae: 0.9033
9/9 [==============================] - 0s 21ms/step - loss: 2.2890 - output_1_loss: 1.1396 - output_2_loss: 1.1494 - output_1_mae: 0.8614 - output_2_mae: 0.8637 - val_loss: 2.7938 - val_output_1_loss: 0.9954 - val_output_2_loss: 1.7984 - val_output_1_mae: 0.7990 - val_output_2_mae: 1.0817
Epoch 10/20

1/9 [==>...........................] - ETA: 0s - loss: 2.0338 - output_1_loss: 0.9802 - output_2_loss: 1.0536 - output_1_mae: 0.7989 - output_2_mae: 0.8550
7/9 [======================>.......] - ETA: 0s - loss: 2.2304 - output_1_loss: 1.0998 - output_2_loss: 1.1306 - output_1_mae: 0.8490 - output_2_mae: 0.8577
9/9 [==============================] - 0s 21ms/step - loss: 2.2126 - output_1_loss: 1.1165 - output_2_loss: 1.0962 - output_1_mae: 0.8530 - output_2_mae: 0.8450 - val_loss: 2.7527 - val_output_1_loss: 0.9700 - val_output_2_loss: 1.7827 - val_output_1_mae: 0.7883 - val_output_2_mae: 1.0751
Epoch 11/20

1/9 [==>...........................] - ETA: 0s - loss: 1.9510 - output_1_loss: 0.9578 - output_2_loss: 0.9932 - output_1_mae: 0.7897 - output_2_mae: 0.8276
7/9 [======================>.......] - ETA: 0s - loss: 2.1663 - output_1_loss: 1.0815 - output_2_loss: 1.0848 - output_1_mae: 0.8420 - output_2_mae: 0.8396
9/9 [==============================] - 0s 20ms/step - loss: 2.1485 - output_1_loss: 1.0978 - output_2_loss: 1.0507 - output_1_mae: 0.8458 - output_2_mae: 0.8261 - val_loss: 2.7066 - val_output_1_loss: 0.9501 - val_output_2_loss: 1.7565 - val_output_1_mae: 0.7825 - val_output_2_mae: 1.0676
Epoch 12/20

1/9 [==>...........................] - ETA: 0s - loss: 1.9013 - output_1_loss: 0.9397 - output_2_loss: 0.9616 - output_1_mae: 0.7824 - output_2_mae: 0.8190
6/9 [===================>..........] - ETA: 0s - loss: 2.1356 - output_1_loss: 1.0369 - output_2_loss: 1.0987 - output_1_mae: 0.8264 - output_2_mae: 0.8511
9/9 [==============================] - 0s 20ms/step - loss: 2.0962 - output_1_loss: 1.0819 - output_2_loss: 1.0143 - output_1_mae: 0.8401 - output_2_mae: 0.8134 - val_loss: 2.6761 - val_output_1_loss: 0.9454 - val_output_2_loss: 1.7307 - val_output_1_mae: 0.7805 - val_output_2_mae: 1.0581
Epoch 13/20

1/9 [==>...........................] - ETA: 0s - loss: 1.8536 - output_1_loss: 0.9302 - output_2_loss: 0.9234 - output_1_mae: 0.7792 - output_2_mae: 0.8030
7/9 [======================>.......] - ETA: 0s - loss: 2.0736 - output_1_loss: 1.0537 - output_2_loss: 1.0199 - output_1_mae: 0.8323 - output_2_mae: 0.8151
9/9 [==============================] - 0s 20ms/step - loss: 2.0555 - output_1_loss: 1.0691 - output_2_loss: 0.9865 - output_1_mae: 0.8354 - output_2_mae: 0.8018 - val_loss: 2.6687 - val_output_1_loss: 0.9350 - val_output_2_loss: 1.7338 - val_output_1_mae: 0.7782 - val_output_2_mae: 1.0593
Epoch 14/20

1/9 [==>...........................] - ETA: 0s - loss: 1.8157 - output_1_loss: 0.9162 - output_2_loss: 0.8995 - output_1_mae: 0.7737 - output_2_mae: 0.7939
7/9 [======================>.......] - ETA: 0s - loss: 2.0363 - output_1_loss: 1.0423 - output_2_loss: 0.9940 - output_1_mae: 0.8282 - output_2_mae: 0.8038
9/9 [==============================] - 0s 20ms/step - loss: 2.0185 - output_1_loss: 1.0574 - output_2_loss: 0.9611 - output_1_mae: 0.8312 - output_2_mae: 0.7903 - val_loss: 2.6452 - val_output_1_loss: 0.9316 - val_output_2_loss: 1.7136 - val_output_1_mae: 0.7762 - val_output_2_mae: 1.0524
Epoch 15/20

1/9 [==>...........................] - ETA: 0s - loss: 1.7882 - output_1_loss: 0.9070 - output_2_loss: 0.8812 - output_1_mae: 0.7705 - output_2_mae: 0.7859
7/9 [======================>.......] - ETA: 0s - loss: 2.0051 - output_1_loss: 1.0328 - output_2_loss: 0.9722 - output_1_mae: 0.8247 - output_2_mae: 0.7945
9/9 [==============================] - 0s 20ms/step - loss: 1.9877 - output_1_loss: 1.0473 - output_2_loss: 0.9404 - output_1_mae: 0.8274 - output_2_mae: 0.7819 - val_loss: 2.6406 - val_output_1_loss: 0.9205 - val_output_2_loss: 1.7201 - val_output_1_mae: 0.7710 - val_output_2_mae: 1.0518
Epoch 16/20

1/9 [==>...........................] - ETA: 0s - loss: 1.7616 - output_1_loss: 0.8973 - output_2_loss: 0.8644 - output_1_mae: 0.7654 - output_2_mae: 0.7809
7/9 [======================>.......] - ETA: 0s - loss: 1.9762 - output_1_loss: 1.0231 - output_2_loss: 0.9531 - output_1_mae: 0.8211 - output_2_mae: 0.7865
9/9 [==============================] - 0s 20ms/step - loss: 1.9593 - output_1_loss: 1.0381 - output_2_loss: 0.9212 - output_1_mae: 0.8239 - output_2_mae: 0.7734 - val_loss: 2.6215 - val_output_1_loss: 0.9179 - val_output_2_loss: 1.7036 - val_output_1_mae: 0.7713 - val_output_2_mae: 1.0484
Epoch 17/20

1/9 [==>...........................] - ETA: 0s - loss: 1.7307 - output_1_loss: 0.8897 - output_2_loss: 0.8410 - output_1_mae: 0.7627 - output_2_mae: 0.7725
7/9 [======================>.......] - ETA: 0s - loss: 1.9501 - output_1_loss: 1.0150 - output_2_loss: 0.9352 - output_1_mae: 0.8180 - output_2_mae: 0.7802
9/9 [==============================] - 0s 19ms/step - loss: 1.9335 - output_1_loss: 1.0297 - output_2_loss: 0.9039 - output_1_mae: 0.8207 - output_2_mae: 0.7667 - val_loss: 2.6158 - val_output_1_loss: 0.9161 - val_output_2_loss: 1.6997 - val_output_1_mae: 0.7690 - val_output_2_mae: 1.0471
Epoch 18/20

1/9 [==>...........................] - ETA: 0s - loss: 1.7072 - output_1_loss: 0.8813 - output_2_loss: 0.8259 - output_1_mae: 0.7589 - output_2_mae: 0.7648
7/9 [======================>.......] - ETA: 0s - loss: 1.9249 - output_1_loss: 1.0064 - output_2_loss: 0.9185 - output_1_mae: 0.8150 - output_2_mae: 0.7724
9/9 [==============================] - 0s 20ms/step - loss: 1.9093 - output_1_loss: 1.0211 - output_2_loss: 0.8882 - output_1_mae: 0.8177 - output_2_mae: 0.7597 - val_loss: 2.6201 - val_output_1_loss: 0.9097 - val_output_2_loss: 1.7104 - val_output_1_mae: 0.7664 - val_output_2_mae: 1.0478
Epoch 19/20

1/9 [==>...........................] - ETA: 0s - loss: 1.6886 - output_1_loss: 0.8717 - output_2_loss: 0.8170 - output_1_mae: 0.7550 - output_2_mae: 0.7615
7/9 [======================>.......] - ETA: 0s - loss: 1.9044 - output_1_loss: 0.9983 - output_2_loss: 0.9062 - output_1_mae: 0.8114 - output_2_mae: 0.7670
9/9 [==============================] - 0s 20ms/step - loss: 1.8892 - output_1_loss: 1.0132 - output_2_loss: 0.8760 - output_1_mae: 0.8144 - output_2_mae: 0.7545 - val_loss: 2.6042 - val_output_1_loss: 0.9108 - val_output_2_loss: 1.6934 - val_output_1_mae: 0.7671 - val_output_2_mae: 1.0434
Epoch 20/20

1/9 [==>...........................] - ETA: 0s - loss: 1.6668 - output_1_loss: 0.8637 - output_2_loss: 0.8031 - output_1_mae: 0.7508 - output_2_mae: 0.7570
7/9 [======================>.......] - ETA: 0s - loss: 1.8849 - output_1_loss: 0.9920 - output_2_loss: 0.8929 - output_1_mae: 0.8091 - output_2_mae: 0.7617
9/9 [==============================] - 0s 20ms/step - loss: 1.8701 - output_1_loss: 1.0067 - output_2_loss: 0.8634 - output_1_mae: 0.8120 - output_2_mae: 0.7497 - val_loss: 2.5898 - val_output_1_loss: 0.9036 - val_output_2_loss: 1.6862 - val_output_1_mae: 0.7644 - val_output_2_mae: 1.0412

<matplotlib.legend.Legend object at 0x7fe660459d90>
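Keras prints the total loss next to each output's loss. Here output_1 is the backcast and output_2 the forecast, following the target order built by preprocessing=lambda x, y: (x, (x, y)). Because model1 is compiled with loss_weights=[1, 1], the reported loss should equal output_1_loss + output_2_loss, up to rounding. The sketch below parses the last line of epoch 20 and recomputes that total. It is not part of autopycoin or Keras: parse_keras_log_line and weighted_total are hypothetical helpers. It also shows that with the paper's weights [0, 1] the total reduces to the forecast loss alone.

```python
import re

def parse_keras_log_line(line):
    """Return {metric_name: value} for every 'name: value' pair in a Keras progress line."""
    return {name: float(value) for name, value in re.findall(r"(\w+): ([0-9.]+)", line)}

def weighted_total(metrics, loss_weights, prefix=""):
    """Recompute the total loss as the weighted sum of output_1_loss and output_2_loss."""
    w1, w2 = loss_weights
    return w1 * metrics[prefix + "output_1_loss"] + w2 * metrics[prefix + "output_2_loss"]

# Last line of epoch 20, copied from the training output above.
line = ("9/9 [==============================] - 0s 20ms/step - loss: 1.8701 "
        "- output_1_loss: 1.0067 - output_2_loss: 0.8634 - output_1_mae: 0.8120 "
        "- output_2_mae: 0.7497 - val_loss: 2.5898 - val_output_1_loss: 0.9036 "
        "- val_output_2_loss: 1.6862 - val_output_1_mae: 0.7644 - val_output_2_mae: 1.0412")
metrics = parse_keras_log_line(line)

for prefix in ("", "val_"):
    reported = metrics[prefix + "loss"]
    recomputed = weighted_total(metrics, [1, 1], prefix)  # weights used by model1.compile
    print(f"{prefix}loss: reported {reported}, recomputed {recomputed:.4f}")

# With weights [0, 1] the total reduces to the forecast (output_2) loss alone.
print("forecast-only objective:", weighted_total(metrics, [0, 1]))
```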

import tensorflow as tf
from matplotlib import pyplot as plt
import pandas as pd

from autopycoin.data import random_ts
from autopycoin.dataset import WindowGenerator
from autopycoin.models import nbeats


tf.random.set_seed(0)

data = random_ts(n_steps=400, # Number of steps (second dimension)
                 trend_degree=2,
                 periods=[10], # Several periods can be combined; a period is the number of steps after which the seasonal pattern repeats
                 fourier_orders=[10], # The higher this number, the more complex the seasonal pattern
                 trend_mean=0,
                 trend_std=1,
                 seasonality_mean=0,
                 seasonality_std=1,
                 batch_size=1, # Generate a batch of data (first dimension)
                 n_variables=1, # Number of variables (last dimension)
                 noise=True, # Add zero-mean Gaussian noise
                 seed=42)


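w = WindowGenerator(
    input_width=80,   # Number of input (backcast) steps
    label_width=40,   # Number of label (forecast) steps
    shift=40,
    test_size=50,
    valid_size=10,
    flat=True,
    batch_size=16,
    # NBEATS outputs (backcast, forecast), so the targets are (inputs, labels)
    preprocessing=lambda x, y: (x, (x, y)),
)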

# Drop the batch dimension: shape (1, 400, 1) -> (400, 1)
data = pd.DataFrame(data.numpy()[0], columns=['test'])

w = w.from_array(data=data,  # Has to be a 2D array (time steps x variables)
                 input_columns=['test'],
                 known_columns=[],
                 label_columns=['test'],
                 date_columns=[])

model1 = nbeats.create_interpretable_nbeats(
            label_width=40,
            forecast_periods=[10],
            backcast_periods=[10],
            forecast_fourier_order=[10],
            backcast_fourier_order=[10],
            p_degree=1,
            trend_n_neurons=200,
            seasonality_n_neurons=200,
            drop_rate=0.,
            share=True)

model1.compile(
    tf.keras.optimizers.Adam(
        learning_rate=0.0015, beta_1=0.9, beta_2=0.999, epsilon=1e-07, amsgrad=True,
        name='Adam'),
    loss='mse',
    loss_weights=[1, 1],  # (backcast, forecast); the N-BEATS paper uses [0, 1], i.e. forecast loss only
    metrics=["mae"])

model1.fit(w.train, validation_data=w.valid, epochs=20)

# Take the first batch of the training set
iterator = iter(w.train)
x, y = iterator.get_next()

input_width = 80
label_width = 40

# NBEATS returns (backcast, forecast); predict once and reuse both outputs
backcast, forecast = model1.predict(x)

plt.plot(range(input_width, input_width + label_width), forecast.values[0], label='forecast')
# Useful only if stack = True
plt.plot(range(input_width), backcast.values[0], label='backcast')
plt.plot(range(input_width, input_width + label_width), y[1][0], label='labels')
plt.plot(range(input_width), x[0], label='inputs')
plt.legend()
plt.show()
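In the plot, the inputs and backcast share steps 0 to 79, and the labels and forecast share steps 80 to 119. This is because each training example is an 80-step input window immediately followed by a 40-step label window. The NumPy sketch below reproduces that slicing on a toy series. It assumes that shift is counted from the end of the input window to the end of the label window, as in the TensorFlow time-series tutorial's WindowGenerator. It also ignores test_size, valid_size and batching. make_windows is a hypothetical helper, not autopycoin's implementation.

```python
import numpy as np

def make_windows(series, input_width=80, label_width=40, shift=40):
    """Cut a 1-D series into (inputs, labels) windows with a stride of one step."""
    total_width = input_width + shift          # steps spanned by one example
    label_start = total_width - label_width    # labels are the last label_width steps
    n_windows = len(series) - total_width + 1
    inputs = np.stack([series[i:i + input_width] for i in range(n_windows)])
    labels = np.stack([series[i + label_start:i + total_width] for i in range(n_windows)])
    return inputs, labels

# Toy stand-in for random_ts output, shape (batch, n_steps, n_variables); drop batch and variable axes.
raw = np.arange(400, dtype=float).reshape(1, 400, 1)
series = raw[0, :, 0]

inputs, labels = make_windows(series)
# Equivalent of preprocessing=lambda x, y: (x, (x, y)): the inputs are also the backcast target.
features, (backcast_target, forecast_target) = (lambda x, y: (x, (x, y)))(inputs, labels)

print(inputs.shape, labels.shape)        # flat layout: one row per window
print(features[0][[0, -1]], forecast_target[0][[0, -1]])
```

Because the toy series is just the step index, the printed endpoints show directly which steps each window covers.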

Total running time of the script: (0 minutes 8.579 seconds)

Gallery generated by Sphinx-Gallery