Temporal Fusion Transformer (TFTModel) 的 TFT Explainer

The TFTExplainer 使用训练好的 TFTModel,并从模型中提取可解释性信息。

  • plot_variable_selection() 绘制每个输入特征的变量选择权重。 - 编码器重要性:目标的历史部分、过去协变量和未来协变量的历史部分 - 解码器重要性:未来协变量的未来部分 - 静态协变量重要性:数值和分类静态协变量的重要性

  • plot_attention() 绘制 TFTModel 应用于给定过去和未来输入的 transformer 注意力。注意力在所有注意力头中聚合。

注意力和特征重要性值可以通过 explain() 返回的 TFTExplainabilityResult 提取。方法的描述中提供了示例。

我们还在 TFTModel 的示例 notebook 中展示了如何使用 TFTExplainer,请参见此处

class darts.explainability.tft_explainer.TFTExplainer(model, background_series=None, background_past_covariates=None, background_future_covariates=None)[source]

基类:_ForecastingModelExplainer

TFTModel 的解释器类。

定义

  • 背景时序是用于生成可解释性结果的默认 TimeSeries (如果没有将 foreground 传递给 explain())。

  • 前景时序是可以传递给 explain()TimeSeries,用于替代背景生成可解释性结果。

参数
  • model (darts.models.forecasting.tft_model.TFTModel) – 待解释的拟合好的 TFTModel

  • background_series (Union[TimeSeries, Sequence[TimeSeries], None]) – 可选的,用于解释的默认目标时序或时序列表。如果 model 是在单个目标时序上训练的,则此参数是可选的。默认情况下,它是拟合时使用的 series。如果 model 是在多个(时序序列)目标时序上训练的,则此参数是强制的。

  • background_past_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – 可选的,用于解释的默认过去协变量时序或时序列表。要求与 background_series 相同。

  • background_future_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – 可选的,用于解释的默认未来协变量时序或时序列表。要求与 background_series 相同。

示例

>>> from darts.datasets import AirPassengersDataset
>>> from darts.explainability.tft_explainer import TFTExplainer
>>> from darts.models import TFTModel
>>> series = AirPassengersDataset().load()
>>> model = TFTModel(
>>>     input_chunk_length=12,
>>>     output_chunk_length=6,
>>>     add_encoders={"cyclic": {"future": ["hour"]}}
>>> )
>>> model.fit(series)
>>> # create the explainer and generate explanations
>>> explainer = TFTExplainer(model)
>>> results = explainer.explain()
>>> # plot the results
>>> explainer.plot_attention(results, plot_type="all")
>>> explainer.plot_variable_selection(results)

方法

explain([foreground_series, ...])

返回 foreground_series 中所有时序的 TFTExplainabilityResult 结果。

plot_attention(expl_result[, plot_type, ...])

绘制 TFTModel 的注意力头。

plot_variable_selection(expl_result[, ...])

根据输入绘制 TFTModel 的变量选择/特征重要性。

explain(foreground_series=None, foreground_past_covariates=None, foreground_future_covariates=None, horizons=None, target_components=None)[source]

返回 foreground_series 中所有时序的 TFTExplainabilityResult 结果。如果 foreground_seriesNone,将使用创建 TFTExplainer 时的 background 输入(可以是创建时传递的 background,也可以是如果 TFTModel 只在单个时序上训练时存储在 TFTModel 中的时序)。对于每个时序,结果包含注意力头、编码器变量重要性、解码器变量重要性、和静态协变量重要性。

参数
  • foreground_series (Union[TimeSeries, Sequence[TimeSeries], None]) – 可选的,一个或一个待解释的目标 TimeSeries 序列。可以是多变量的。如果未提供,将解释背景 TimeSeries

  • foreground_past_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – 可选的,一个或一个过去协变量 TimeSeries 序列,如果预测模型需要。

  • foreground_future_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – 可选的,一个或一个未来协变量 TimeSeries 序列,如果预测模型需要。

  • horizons (Optional[Sequence[int], None]) – TFTExplainer 不使用此参数。

  • target_components (Optional[Sequence[str], None]) – TFTExplainer 不使用此参数。

返回值

可解释性结果,包含注意力头、编码器变量重要性、解码器变量重要性、和静态协变量重要性。

返回类型

TFTExplainabilityResult

示例

>>> explainer = TFTExplainer(model)  # requires `background` if model was trained on multiple series

可选地,提供前景输入以在新输入上生成解释。否则,留空以计算创建 TFTExplainer 时的背景输入上的解释。

>>> explain_results = explainer.explain(
>>>     foreground_series=foreground_series,
>>>     foreground_past_covariates=foreground_past_covariates,
>>>     foreground_future_covariates=foreground_future_covariates,
>>> )
>>> attn = explain_results.get_attention()
>>> importances = explain_results.get_feature_importances()
model: TFTModel
plot_attention(expl_result, plot_type='all', show_index_as='relative', ax=None, max_nr_series=5, show_plot=True)[source]

绘制 TFTModel 的注意力头。

参数
  • expl_result (TFTExplainabilityResult) – 一个 TFTExplainabilityResult 对象。对应于 explain() 的输出。

  • plot_type (Optional[Literal[‘all’, ‘time’, ‘heatmap’], None]) – 注意力头图的类型。可选值为 (“all”, “time”, “heatmap”) 之一。如果为 “all”,将绘制每个预测范围的注意力(根据 TFTExplainabilityResult 中的预测范围)。最大预测范围对应于训练好的 TFTModeloutput_chunk_length。如果为 “time”,将绘制所有预测范围的平均注意力。如果为 “heatmap”,将在热力图上绘制每个预测范围的注意力。y 轴表示预测范围,x 轴表示时间/相对索引。

  • show_index_as (Literal[‘relative’, ‘time’]) – 要显示的索引类型。可选值为 (“relative”, “time”) 之一。如果为 “relative”,x 轴将从 (-input_chunk_length, output_chunk_length - 1) 绘制。0 对应于第一个预测点。如果为 “time”,x 轴将显示对应 TFTExplainabilityResult 的实际时间索引(或范围索引)。

  • ax (Optional[Axes, None]) – 可选地,指定一个绘制轴。仅对单个 expl_result 有效。

  • max_nr_series (int) – 如果 expl_result 是在多个时序上计算的,则显示的最大图数。

  • show_plot (bool) – 是否显示图。

返回类型

坐标轴

plot_variable_selection(expl_result, fig_size=None, max_nr_series=5)[source]

根据输入绘制 TFTModel 的变量选择/特征重要性。图中包含三个子图

  • 编码器重要性:包含过去目标、过去协变量和历史未来协变量在编码器(输入块)上的重要性

  • 解码器重要性:包含未来协变量在解码器(输出块)上的重要性

  • 静态协变量重要性:包含数值和/或分类静态协变量的重要性

参数
  • expl_result (TFTExplainabilityResult) – 一个 TFTExplainabilityResult 对象。对应于 explain() 的输出。

  • fig_size – 要绘制的图的大小。

  • max_nr_series (int) – 如果 expl_result 是在多个时序上计算的,则显示的最大图数。