Temporal Fusion Transformer (TFTModel) 的 TFT Explainer¶
The TFTExplainer 使用训练好的 TFTModel
,并从模型中提取可解释性信息。
plot_variable_selection()
绘制每个输入特征的变量选择权重。 - 编码器重要性:目标的历史部分、过去协变量和未来协变量的历史部分 - 解码器重要性:未来协变量的未来部分 - 静态协变量重要性:数值和分类静态协变量的重要性plot_attention()
绘制 TFTModel 应用于给定过去和未来输入的 transformer 注意力。注意力在所有注意力头中聚合。
注意力和特征重要性值可以通过 explain()
返回的 TFTExplainabilityResult
提取。方法的描述中提供了示例。
我们还在 TFTModel 的示例 notebook 中展示了如何使用 TFTExplainer,请参见此处。
- class darts.explainability.tft_explainer.TFTExplainer(model, background_series=None, background_past_covariates=None, background_future_covariates=None)[source]¶
基类:
_ForecastingModelExplainer
TFTModel 的解释器类。
定义
背景时序是用于生成可解释性结果的默认 TimeSeries (如果没有将 foreground 传递给
explain()
)。前景时序是可以传递给
explain()
的 TimeSeries,用于替代背景生成可解释性结果。
- 参数
model (darts.models.forecasting.tft_model.TFTModel) – 待解释的拟合好的 TFTModel。
background_series (
Union
[TimeSeries
,Sequence
[TimeSeries
],None
]) – 可选的,用于解释的默认目标时序或时序列表。如果 model 是在单个目标时序上训练的,则此参数是可选的。默认情况下,它是拟合时使用的 series。如果 model 是在多个(时序序列)目标时序上训练的,则此参数是强制的。background_past_covariates (
Union
[TimeSeries
,Sequence
[TimeSeries
],None
]) – 可选的,用于解释的默认过去协变量时序或时序列表。要求与 background_series 相同。background_future_covariates (
Union
[TimeSeries
,Sequence
[TimeSeries
],None
]) – 可选的,用于解释的默认未来协变量时序或时序列表。要求与 background_series 相同。
示例
>>> from darts.datasets import AirPassengersDataset >>> from darts.explainability.tft_explainer import TFTExplainer >>> from darts.models import TFTModel >>> series = AirPassengersDataset().load() >>> model = TFTModel( >>> input_chunk_length=12, >>> output_chunk_length=6, >>> add_encoders={"cyclic": {"future": ["hour"]}} >>> ) >>> model.fit(series) >>> # create the explainer and generate explanations >>> explainer = TFTExplainer(model) >>> results = explainer.explain() >>> # plot the results >>> explainer.plot_attention(results, plot_type="all") >>> explainer.plot_variable_selection(results)
方法
explain
([foreground_series, ...])返回 foreground_series 中所有时序的
TFTExplainabilityResult
结果。plot_attention
(expl_result[, plot_type, ...])绘制 TFTModel 的注意力头。
plot_variable_selection
(expl_result[, ...])根据输入绘制 TFTModel 的变量选择/特征重要性。
- explain(foreground_series=None, foreground_past_covariates=None, foreground_future_covariates=None, horizons=None, target_components=None)[source]¶
返回 foreground_series 中所有时序的
TFTExplainabilityResult
结果。如果 foreground_series 为 None,将使用创建 TFTExplainer 时的 background 输入(可以是创建时传递的 background,也可以是如果 TFTModel 只在单个时序上训练时存储在 TFTModel 中的时序)。对于每个时序,结果包含注意力头、编码器变量重要性、解码器变量重要性、和静态协变量重要性。- 参数
foreground_series (
Union
[TimeSeries
,Sequence
[TimeSeries
],None
]) – 可选的,一个或一个待解释的目标 TimeSeries 序列。可以是多变量的。如果未提供,将解释背景 TimeSeries。foreground_past_covariates (
Union
[TimeSeries
,Sequence
[TimeSeries
],None
]) – 可选的,一个或一个过去协变量 TimeSeries 序列,如果预测模型需要。foreground_future_covariates (
Union
[TimeSeries
,Sequence
[TimeSeries
],None
]) – 可选的,一个或一个未来协变量 TimeSeries 序列,如果预测模型需要。horizons (
Optional
[Sequence
[int
],None
]) – TFTExplainer 不使用此参数。target_components (
Optional
[Sequence
[str
],None
]) – TFTExplainer 不使用此参数。
- 返回值
可解释性结果,包含注意力头、编码器变量重要性、解码器变量重要性、和静态协变量重要性。
- 返回类型
示例
>>> explainer = TFTExplainer(model) # requires `background` if model was trained on multiple series
可选地,提供前景输入以在新输入上生成解释。否则,留空以计算创建 TFTExplainer 时的背景输入上的解释。
>>> explain_results = explainer.explain( >>> foreground_series=foreground_series, >>> foreground_past_covariates=foreground_past_covariates, >>> foreground_future_covariates=foreground_future_covariates, >>> ) >>> attn = explain_results.get_attention() >>> importances = explain_results.get_feature_importances()
- plot_attention(expl_result, plot_type='all', show_index_as='relative', ax=None, max_nr_series=5, show_plot=True)[source]¶
绘制 TFTModel 的注意力头。
- 参数
expl_result (
TFTExplainabilityResult
) – 一个 TFTExplainabilityResult 对象。对应于explain()
的输出。plot_type (
Optional
[Literal
[‘all’, ‘time’, ‘heatmap’],None
]) – 注意力头图的类型。可选值为 (“all”, “time”, “heatmap”) 之一。如果为 “all”,将绘制每个预测范围的注意力(根据 TFTExplainabilityResult 中的预测范围)。最大预测范围对应于训练好的 TFTModel 的 output_chunk_length。如果为 “time”,将绘制所有预测范围的平均注意力。如果为 “heatmap”,将在热力图上绘制每个预测范围的注意力。y 轴表示预测范围,x 轴表示时间/相对索引。show_index_as (
Literal
[‘relative’, ‘time’]) – 要显示的索引类型。可选值为 (“relative”, “time”) 之一。如果为 “relative”,x 轴将从 (-input_chunk_length, output_chunk_length - 1) 绘制。0 对应于第一个预测点。如果为 “time”,x 轴将显示对应 TFTExplainabilityResult 的实际时间索引(或范围索引)。ax (
Optional
[Axes
,None
]) – 可选地,指定一个绘制轴。仅对单个 expl_result 有效。max_nr_series (
int
) – 如果 expl_result 是在多个时序上计算的,则显示的最大图数。show_plot (
bool
) – 是否显示图。
- 返回类型
坐标轴
- plot_variable_selection(expl_result, fig_size=None, max_nr_series=5)[source]¶
根据输入绘制 TFTModel 的变量选择/特征重要性。图中包含三个子图
编码器重要性:包含过去目标、过去协变量和历史未来协变量在编码器(输入块)上的重要性
解码器重要性:包含未来协变量在解码器(输出块)上的重要性
静态协变量重要性:包含数值和/或分类静态协变量的重要性
- 参数
expl_result (
TFTExplainabilityResult
) – 一个 TFTExplainabilityResult 对象。对应于explain()
的输出。fig_size – 要绘制的图的大小。
max_nr_series (
int
) – 如果 expl_result 是在多个时序上计算的,则显示的最大图数。