N-BEATS
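In this notebook, we show an example of how N-BEATS can be used with darts. If you are new to darts, we recommend that you first follow the quick start notebook.
N-BEATS is a model that demonstrates the potential of pure deep learning (DL) architectures for time series forecasting. According to the original paper, it outperforms well-established statistical approaches on the M3 and M4 competitions. For more details on the model, see: https://arxiv.org/pdf/1905.10437.pdf.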
[1]:
# fix python path if working locally
from utils import fix_pythonpath_if_working_locally
fix_pythonpath_if_working_locally()
%matplotlib inline
[2]:
import warnings
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from darts import TimeSeries, concatenate
from darts.dataprocessing.transformers import MissingValuesFiller, Scaler
from darts.datasets import EnergyDataset
from darts.metrics import r2_score
from darts.models import NBEATSModel
from darts.utils.callbacks import TFMProgressBar
warnings.filterwarnings("ignore")
import logging
logging.disable(logging.CRITICAL)
def generate_torch_kwargs():
    # run torch models on CPU, and disable progress bars for all model stages except training.
    return {
        "pl_trainer_kwargs": {
            "accelerator": "cpu",
            "callbacks": [TFMProgressBar(enable_train_bar_only=True)],
        }
    }
[3]:
def display_forecast(pred_series, ts_transformed, forecast_type, start_date=None):
    plt.figure(figsize=(8, 5))
    if start_date:
        ts_transformed = ts_transformed.drop_before(start_date)
    ts_transformed.univariate_component(0).plot(label="actual")
    pred_series.plot(label=("historic " + forecast_type + " forecasts"))
    plt.title(f"R2: {r2_score(ts_transformed.univariate_component(0), pred_series)}")
    plt.legend()
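The plot title above reports the R2 score of the forecast. As a reminder of what that number measures, here is a minimal sketch assuming the standard coefficient-of-determination formula. `r2_sketch` is a hypothetical helper working on plain arrays; darts' `r2_score` operates on `TimeSeries` objects and may differ in details.

```python
import numpy as np


def r2_sketch(actual, predicted):
    """Coefficient of determination: 1 - SS_res / SS_tot (hypothetical helper)."""
    actual = np.asarray(actual, dtype=float)
    predicted = np.asarray(predicted, dtype=float)
    ss_res = np.sum((actual - predicted) ** 2)      # residual sum of squares
    ss_tot = np.sum((actual - actual.mean()) ** 2)  # spread around the mean
    return 1.0 - ss_res / ss_tot


actual = [1.0, 2.0, 3.0, 4.0]
perfect = r2_sketch(actual, actual)                  # exact forecast -> 1.0
mean_only = r2_sketch(actual, [2.5, 2.5, 2.5, 2.5])  # constant mean forecast -> 0.0
```

A score of 1 means a perfect fit, 0 means the forecast is no better than predicting the mean of the actual values, and negative values mean it is worse than that.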
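Daily energy generation example
We test N-BEATS on a daily energy generation dataset from a run-of-river power plant, as it exhibits several levels of seasonality.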
[4]:
df = EnergyDataset().load().to_dataframe()
df["generation hydro run-of-river and poundage"].plot()
plt.title("Hourly generation hydro run-of-river and poundage")
[4]:
Text(0.5, 1.0, 'Hourly generation hydro run-of-river and poundage')

To keep things simple, we work with the daily generation and fill the missing values present in the data using the MissingValuesFiller:
[5]:
df_day_avg = df.groupby(df.index.astype(str).str.split(" ").str[0]).mean().reset_index()
filler = MissingValuesFiller()
scaler = Scaler()
series = filler.transform(
    TimeSeries.from_dataframe(
        df_day_avg, "time", ["generation hydro run-of-river and poundage"]
    )
).astype(np.float32)
train, val = series.split_after(pd.Timestamp("20170901"))
train_scaled = scaler.fit_transform(train)
val_scaled = scaler.transform(val)
series_scaled = scaler.transform(series)
train_scaled.plot(label="training")
val_scaled.plot(label="val")
plt.title("Daily generation hydro run-of-river and poundage")
[5]:
Text(0.5, 1.0, 'Daily generation hydro run-of-river and poundage')
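The cell above derives daily means by grouping on the date part of the timestamp string and then fills gaps with `MissingValuesFiller`. For comparison, the same idea can be sketched with plain pandas. This is an alternative illustration on made-up data, not what the notebook runs: the series name and values are hypothetical, and linear interpolation is used here as a stand-in for the filler.

```python
import numpy as np
import pandas as pd

# Hypothetical hourly data covering three days; the second day is entirely missing.
idx = pd.date_range("2017-01-01", periods=72, freq=pd.Timedelta(hours=1))
values = np.arange(72, dtype=float)
values[24:48] = np.nan
hourly = pd.Series(values, index=idx, name="generation")

# Daily mean: resample on the DatetimeIndex instead of splitting date strings.
daily = hourly.resample("D").mean()

# Fill the remaining gap by linear interpolation between neighbouring days.
daily_filled = daily.interpolate()
```

Resampling keeps a proper `DatetimeIndex`, whereas the string-based grouping in the notebook yields a text column that is parsed again when the `TimeSeries` is built.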

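We split the data into a training set and a validation set. Normally we would also hold out a test set to check the model on unseen data, but we skip that step in this example.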
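Note that the scaler in the cell above is fitted on the training split only and then applied to the validation split and to the full series. The sketch below illustrates why that matters, assuming min-max scaling to [0, 1] (which, to our knowledge, is what darts' `Scaler` does by default). The numbers and helper functions are hypothetical.

```python
import numpy as np

# Hypothetical values: the validation range exceeds the training range.
train = np.array([10.0, 20.0, 30.0])
val = np.array([25.0, 40.0])

# "Fit": learn the minimum and maximum from the training data only.
lo, hi = train.min(), train.max()


def scale(x):
    """Map values to the [0, 1] range defined by the training data."""
    return (x - lo) / (hi - lo)


def unscale(x):
    """Invert scale(), mapping values back to original units."""
    return x * (hi - lo) + lo


train_scaled = scale(train)  # [0.0, 0.5, 1.0]
val_scaled = scale(val)      # [0.75, 1.5]; values above 1 are expected
```

Fitting on the training split alone keeps information about the validation period from leaking into the preprocessing, at the cost of scaled validation values that may fall outside [0, 1].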
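Generic architecture
N-BEATS is a univariate model architecture that offers two configurations: a generic one and an interpretable one. The generic architecture uses as little prior knowledge as possible: no feature engineering, no scaling, and no internal architectural components that could be considered time-series-specific.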
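To make the generic configuration more concrete, here is a heavily simplified sketch of the doubly residual stacking described in the N-BEATS paper, using nothing but fully connected layers. It is an illustration under stated assumptions, not the darts implementation: weights are random and untrained, each block has a single hidden layer without biases, and all names and sizes are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
input_len, horizon, width, n_blocks = 30, 7, 16, 3


def make_block():
    """Random (untrained) weights: one hidden layer plus two linear heads."""
    return {
        "hidden": rng.normal(scale=0.1, size=(input_len, width)),
        "backcast": rng.normal(scale=0.1, size=(width, input_len)),
        "forecast": rng.normal(scale=0.1, size=(width, horizon)),
    }


def run_blocks(x, blocks):
    """Pass the input through the blocks with doubly residual connections."""
    residual = x.copy()
    forecast = np.zeros(horizon)
    for block in blocks:
        h = np.maximum(residual @ block["hidden"], 0.0)  # ReLU hidden layer
        residual = residual - h @ block["backcast"]       # subtract what this block explained
        forecast = forecast + h @ block["forecast"]       # add this block's partial forecast
    return forecast, residual


x = rng.normal(size=input_len)
blocks = [make_block() for _ in range(n_blocks)]
forecast, residual = run_blocks(x, blocks)
```

Each block sees only what the previous blocks could not explain, and the final forecast is the sum of all partial forecasts. Nothing in this structure encodes trend or seasonality explicitly.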
First, we use a model with the generic N-BEATS architecture:
[6]:
model_name = "nbeats_run"
model_nbeats = NBEATSModel(
    input_chunk_length=30,
    output_chunk_length=7,
    generic_architecture=True,
    num_stacks=10,
    num_blocks=1,
    num_layers=4,
    layer_widths=512,
    n_epochs=100,
    nr_epochs_val_period=1,
    batch_size=800,
    random_state=42,
    model_name=model_name,
    save_checkpoints=True,
    force_reset=True,
    **generate_torch_kwargs(),
)
[7]:
model_nbeats.fit(train_scaled, val_series=val_scaled)
[7]:
NBEATSModel(generic_architecture=True, num_stacks=10, num_blocks=1, num_layers=4, layer_widths=512, expansion_coefficient_dim=5, trend_polynomial_degree=2, dropout=0.0, activation=ReLU, input_chunk_length=30, output_chunk_length=7, n_epochs=100, nr_epochs_val_period=1, batch_size=800, random_state=42, model_name=nbeats_run, save_checkpoints=True, force_reset=True, pl_trainer_kwargs={'accelerator': 'cpu', 'callbacks': [<darts.utils.callbacks.TFMProgressBar object at 0x2b3d98fd0>]})
Let's load the model from the checkpoint that performed best on the validation set.
[8]:
model_nbeats = NBEATSModel.load_from_checkpoint(model_name=model_name, best=True)
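Let's look at the historical forecasts the model would have produced over the validation period with a forecast horizon of 7. Because retrain=False, the loaded model is reused as-is for every forecast; only the history available to it grows, and the model is not refit on an expanding training window.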
[9]:
pred_series = model_nbeats.historical_forecasts(
    series_scaled,
    start=val.start_time(),
    forecast_horizon=7,
    stride=7,
    last_points_only=False,
    retrain=False,
    verbose=True,
)
pred_series = concatenate(pred_series)
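With `forecast_horizon=7`, `stride=7` and `last_points_only=False`, each forecast covers seven days and the next one starts right where the previous one ended, which is why the pieces can be concatenated into one continuous series. The index arithmetic can be sketched as follows. `forecast_windows` is a hypothetical helper, and the exact handling of the series boundaries in darts may differ.

```python
def forecast_windows(n_val_points, forecast_horizon, stride):
    """Return (start, end) index pairs, end exclusive, for every forecast
    window that fits entirely inside a span of n_val_points."""
    windows = []
    start = 0
    while start + forecast_horizon <= n_val_points:
        windows.append((start, start + forecast_horizon))
        start += stride
    return windows


windows = forecast_windows(n_val_points=30, forecast_horizon=7, stride=7)
# → [(0, 7), (7, 14), (14, 21), (21, 28)]

# When stride equals the horizon, consecutive windows touch without overlapping.
contiguous = all(a[1] == b[0] for a, b in zip(windows, windows[1:]))
```

A stride smaller than the horizon would produce overlapping forecasts, and a larger one would leave gaps between them.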
[10]:
display_forecast(
    pred_series,
    series_scaled,
    "7 day",
    start_date=val.start_time(),
)

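Interpretable model
N-BEATS also offers an interpretable architecture consisting of two stacks: a trend stack and a seasonality stack. The architecture is designed so that:
The trend component is removed from the input before it is fed into the seasonality stack
The partial forecasts of trend and seasonality are available as separate interpretable outputs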
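The two stacks can be pictured as projections onto fixed basis functions: low-degree polynomials for the trend and Fourier terms for the seasonality, as described in the N-BEATS paper. The sketch below builds both bases and combines them with made-up coefficients. In the real model the coefficients are predicted by the network; the polynomial degree, the number of harmonics and the coefficient values here are assumptions chosen for illustration, and this is not the darts implementation.

```python
import numpy as np

horizon = 7
t = np.arange(horizon) / horizon  # normalized time grid in [0, 1)

# Trend basis: polynomials 1, t, t^2 (degree 2 is an assumption for this sketch).
degree = 2
trend_basis = np.stack([t**p for p in range(degree + 1)])  # shape (3, 7)

# Seasonality basis: cosine and sine terms (3 harmonics is an assumption).
n_harmonics = 3
seasonality_basis = np.concatenate(
    [
        np.stack([np.cos(2 * np.pi * k * t) for k in range(n_harmonics)]),
        np.stack([np.sin(2 * np.pi * k * t) for k in range(n_harmonics)]),
    ]
)  # shape (6, 7)

# Hypothetical coefficients standing in for the network's outputs.
theta_trend = np.array([0.5, 1.0, -0.3])
theta_seasonality = np.array([0.0, 0.2, 0.0, 0.0, 0.1, 0.0])

trend_forecast = theta_trend @ trend_basis
seasonality_forecast = theta_seasonality @ seasonality_basis
total_forecast = trend_forecast + seasonality_forecast
```

Because the total forecast is the sum of the two partial forecasts, each part can be plotted and inspected on its own, which is what makes this configuration interpretable.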
[11]:
model_name = "nbeats_interpretable_run"
model_nbeats = NBEATSModel(
    input_chunk_length=30,
    output_chunk_length=7,
    generic_architecture=False,
    num_blocks=3,
    num_layers=4,
    layer_widths=512,
    n_epochs=100,
    nr_epochs_val_period=1,
    batch_size=800,
    random_state=42,
    model_name=model_name,
    save_checkpoints=True,
    force_reset=True,
    **generate_torch_kwargs(),
)
[12]:
model_nbeats.fit(series=train_scaled, val_series=val_scaled)
[12]:
NBEATSModel(generic_architecture=False, num_stacks=30, num_blocks=3, num_layers=4, layer_widths=512, expansion_coefficient_dim=5, trend_polynomial_degree=2, dropout=0.0, activation=ReLU, input_chunk_length=30, output_chunk_length=7, n_epochs=100, nr_epochs_val_period=1, batch_size=800, random_state=42, model_name=nbeats_interpretable_run, save_checkpoints=True, force_reset=True, pl_trainer_kwargs={'accelerator': 'cpu', 'callbacks': [<darts.utils.callbacks.TFMProgressBar object at 0x2b3fc0790>]})
[13]:
model_nbeats = NBEATSModel.load_from_checkpoint(model_name=model_name, best=True)
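As before, let's look at the historical forecasts over the validation period with a forecast horizon of 7, reusing the loaded model without retraining (retrain=False).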
[14]:
pred_series = model_nbeats.historical_forecasts(
    series_scaled,
    start=val_scaled.start_time(),
    forecast_horizon=7,
    stride=7,
    last_points_only=False,
    retrain=False,
    verbose=True,
)
pred_series = concatenate(pred_series)
[15]:
display_forecast(
    pred_series, series_scaled, "7 day", start_date=val_scaled.start_time()
)
