- class darts.utils.callbacks.TFMProgressBar(enable_sanity_check_bar=True, enable_train_bar=True, enable_validation_bar=True, enable_prediction_bar=True, enable_train_bar_only=False, **kwargs)[source]¶
基类:
TQDMProgressBar
Darts 用于 TorchForecastingModels 的进度条。
允许自定义在哪些模型阶段(健全性检查、训练、验证、预测)显示进度条。
此类是 PyTorch Lightning 的一个 Callback,可以通过 pl_trainer_kwargs 参数传递给 TorchForecastingModel 构造函数。
示例
>>> from darts.models import NBEATSModel >>> from darts.utils.callbacks import TFMProgressBar >>> # only display the training bar and not the validation, prediction, and sanity check bars >>> prog_bar = TFMProgressBar(enable_train_bar_only=True) >>> model = NBEATSModel(1, 1, pl_trainer_kwargs={"callbacks": [prog_bar]})
- 参数
enable_sanity_check_bar (
bool
) – 是否启用健全性检查的进度条。enable_train_bar (
bool
) – 是否启用训练的进度条。enable_validation_bar (
bool
) – 是否启用验证的进度条。enable_prediction_bar (
bool
) – 是否启用预测的进度条。enable_train_bar_only (
bool
) – 是否禁用除训练进度条外的所有进度条。**kwargs – 传递给 PyTorch Lightning 的 TQDMProgressBar 的参数。
属性
回调状态的标识符。
预测的总批次数,对于当前数据加载器,每次 epoch 可能会有所不同。
测试的总批次数,对于当前数据加载器,每次 epoch 可能会有所不同。
训练的总批次数,每次 epoch 可能会有所不同。
验证的总批次数,对于所有验证数据加载器,每次 epoch 可能会有所不同。
验证的总批次数,对于当前数据加载器,每次 epoch 可能会有所不同。
is_disabled
is_enabled
predict_description
predict_progress_bar
process_position
refresh_rate
sanity_check_description
test_description
test_progress_bar
train_description
train_progress_bar
trainer
val_progress_bar
validation_description
方法
disable
()您应该提供一种禁用进度条的方法。
enable
()您应该提供一种启用进度条的方式。
get_metrics
(trainer, pl_module)将从 trainer 收集的进度条指标与 get_standard_metrics 中的标准指标相结合。实现此方法以覆盖进度条中显示的项目。
dict
[str
,Union
[int
,str
,float
,dict
[str
,float
]]]返回
()包含要在进度条中显示项的字典。
重写此方法以自定义预测的 tqdm 进度条。
Tqdm
()init_sanity_tqdm
init_test_tqdm
重写此方法以自定义测试的 tqdm 进度条。
init_train_tqdm
重写此方法以自定义训练的 tqdm 进度条。
init_validation_tqdm
重写此方法以自定义验证的 tqdm 进度条。
property is_disabled: bool¶
property is_enabled: bool¶
load_state_dict(state_dict)¶
加载检查点时调用,实现此方法以根据回调的
state_dict
重新加载回调状态。state_dict (
dict
[str
,Any
]) – 由state_dict
返回的回调状态。on_after_backward(trainer, pl_module)¶
在
loss.backward()
之后、优化器步进之前调用。on_before_backward(trainer, pl_module, loss)¶
在
loss.backward()
之前调用。on_before_optimizer_step(trainer, pl_module, optimizer)¶
在
optimizer.step()
之前调用。on_before_zero_grad(trainer, pl_module, optimizer)¶
在
optimizer.zero_grad()
之前调用。on_exception(trainer, pl_module, exception)¶
当任何 trainer 执行被异常中断时调用。
on_fit_end(trainer, pl_module)¶
拟合结束时调用。
on_fit_start(trainer, pl_module)¶
拟合开始时调用。
on_load_checkpoint(trainer, pl_module, checkpoint)¶
加载模型检查点时调用,用于重新加载状态。
trainer (
Trainer
) – 当前的Trainer
实例。pl_module (
LightningModule
) – 当前的LightningModule
实例。checkpoint (
dict
[str
,Any
]) – 由 Trainer 加载的完整检查点字典。on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)¶
预测批次结束时调用。
(*_)on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)¶
预测批次开始时调用。
(*_)on_predict_end(trainer, pl_module)¶
预测结束时调用。
on_predict_epoch_end(trainer, pl_module)¶
预测 epoch 结束时调用。
on_predict_epoch_start(trainer, pl_module)¶
预测 epoch 开始时调用。
on_predict_start(trainer, pl_module)¶
预测开始时调用。
on_sanity_check_end
验证健全性检查结束时调用。
on_sanity_check_start
验证健全性检查开始时调用。
on_save_checkpoint(trainer, pl_module, checkpoint)¶
保存检查点时调用,让您有机会存储任何其他您可能想保存的内容。
checkpoint (
dict
[str
,Any
]) – 将要保存的检查点字典。on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)¶
测试批次结束时调用。
on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)¶
测试批次开始时调用。
测试结束时调用。
on_test_epoch_end(trainer, pl_module)¶
测试 epoch 结束时调用。
on_test_epoch_start(trainer, pl_module)¶
测试 epoch 开始时调用。
测试开始时调用。
on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)¶
训练批次结束时调用。
注意
此处
outputs["loss"]
的值将是相对于accumulate_grad_batches
从training_step
返回的损失的归一化值。on_train_batch_start(trainer, pl_module, batch, batch_idx)¶
训练批次开始时调用。
on_train_end
训练结束时调用。
on_train_epoch_end(trainer, pl_module)¶
训练 epoch 结束时调用。
要在 epoch 结束时访问所有批次输出,您可以将步进输出缓存为
pytorch_lightning.core.LightningModule
的属性,并在该钩子中访问它们。on_train_epoch_start(trainer, *_)¶
训练 epoch 开始时调用。
on_train_start
训练开始时调用。
on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)¶
on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)¶
验证批次开始时调用。
on_validation_end(trainer, pl_module)¶
验证循环结束时调用。
on_validation_epoch_end(trainer, pl_module)¶
- 验证 epoch 结束时调用。
- on_validation_epoch_start(trainer, pl_module)¶
您应该提供一种禁用进度条的方法。
- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- 验证循环开始时调用。
您应该提供一种启用进度条的方式。
print(*args, sep=' ', **kwargs)¶
- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- 您应该提供一种在不破坏进度条的情况下进行打印的方法。
setup(trainer, pl_module, stage)¶
当 fit、validate、test、predict 或 tune 开始时调用。
def get_metrics(self, trainer, model): # don't show the version number items = super().get_metrics(trainer, model) items.pop("v_num", None) return items
- 验证 epoch 开始时调用。
state_dict
- 保存检查点时调用,实现此方法以生成回调的
state_dict
。 dict
[str
,Any
]
- 包含回调状态的字典。
- 验证 epoch 开始时调用。
property state_key: str¶
-
用于通过
checkpoint["callbacks"][state_key]
从检查点字典中存储和检索回调的状态。如果 1) 回调具有状态,并且 2) 希望保持该回调多个实例的状态,则回调的实现需要提供唯一的 state key。 dict
[str
,Union
[int
,str
,float
,dict
[str
,float
]]]- 验证 epoch 开始时调用。
teardown(trainer, pl_module, stage)¶
- 当 fit、validate、test、predict 或 tune 结束时调用。
包含要在进度条中显示项的字典。
- 验证 epoch 开始时调用。
teardown(trainer, pl_module, stage)¶
- property test_description: str¶
重写此方法以自定义预测的 tqdm 进度条。
- 验证 epoch 开始时调用。
teardown(trainer, pl_module, stage)¶
- property test_progress_bar: tqdm_asyncio¶
init_sanity_tqdm
- 验证 epoch 开始时调用。
teardown(trainer, pl_module, stage)¶
- property total_predict_batches_current_dataloader: Union[int, float]¶
init_test_tqdm
- 验证 epoch 开始时调用。
teardown(trainer, pl_module, stage)¶
-
使用此属性设置进度条中的总迭代次数。如果预测数据加载器大小无限,则可以返回
inf
。 - 验证 epoch 开始时调用。
property state_key: str¶
-
Union
[int
,float
] - 验证 epoch 开始时调用。
property state_key: str¶
- property total_test_batches_current_dataloader: Union[int, float]¶
使用此属性设置进度条中的总迭代次数。如果测试数据加载器大小无限,则可以返回
inf
。
-
使用此属性设置进度条中的总迭代次数。如果训练数据加载器大小无限,则可以返回
inf
。 init_validation_tqdm
- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- property total_val_batches: Union[int, float]¶
property is_disabled: bool¶
- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- property total_val_batches_current_dataloader: Union[int, float]¶
load_state_dict(state_dict)¶
- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
-
使用此属性设置进度条中的总迭代次数。如果验证数据加载器大小无限,则可以返回
inf
。 state_dict (
dict
[str
,Any
]) – 由state_dict
返回的回调状态。- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- property train_description: str¶
在
loss.backward()
之后、优化器步进之前调用。- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- property train_progress_bar: tqdm_asyncio¶
在
loss.backward()
之前调用。- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- property trainer: Trainer¶
在
optimizer.step()
之前调用。- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- Trainer
在
optimizer.zero_grad()
之前调用。
- © 版权所有 2020 - 2025, Unit8 SA (Apache 2.0 许可证)。
当任何 trainer 执行被异常中断时调用。
- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- on_predict_end(trainer, pl_module)¶
拟合开始时调用。
- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- on_predict_epoch_end(trainer, pl_module)¶
加载模型检查点时调用,用于重新加载状态。
- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- on_predict_epoch_start(trainer, pl_module)¶
pl_module (
LightningModule
) – 当前的LightningModule
实例。- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- on_predict_start(trainer, pl_module)¶
on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)¶
- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- on_sanity_check_end(*_)¶
on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)¶
- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- on_sanity_check_start(*_)¶
on_predict_end(trainer, pl_module)¶
- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)¶
on_predict_epoch_start(trainer, pl_module)¶
- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)¶
on_predict_start(trainer, pl_module)¶
- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- on_test_end(trainer, pl_module)¶
on_sanity_check_end
- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- on_test_epoch_end(trainer, pl_module)¶
on_sanity_check_start
- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- on_test_epoch_start(trainer, pl_module)¶
on_save_checkpoint(trainer, pl_module, checkpoint)¶
- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- on_test_start(trainer, pl_module)¶
checkpoint (
dict
[str
,Any
]) – 将要保存的检查点字典。- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)¶
测试批次结束时调用。
注意
此处
outputs["loss"]
的值是相对于accumulate_grad_batches
进行归一化的损失值,该损失值由training_step
返回。- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- on_train_batch_start(trainer, pl_module, batch, batch_idx)¶
测试批次开始时调用。
- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- on_train_epoch_end(trainer, pl_module)¶
测试 epoch 结束时调用。
要在周期结束时访问所有批次输出,您可以将步骤输出缓存为
pytorch_lightning.core.LightningModule
的一个属性,并在此钩子中访问它们class MyLightningModule(L.LightningModule): def __init__(self): super().__init__() self.training_step_outputs = [] def training_step(self): loss = ... self.training_step_outputs.append(loss) return loss class MyCallback(L.Callback): def on_train_epoch_end(self, trainer, pl_module): # do something with all training_step outputs, for example: epoch_mean = torch.stack(pl_module.training_step_outputs).mean() pl_module.log("training_epoch_mean", epoch_mean) # free up the memory pl_module.training_step_outputs.clear()
- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- on_train_epoch_start(trainer, *_)¶
测试 epoch 开始时调用。
- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)¶
训练批次结束时调用。
- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)¶
此处
outputs["loss"]
的值将是相对于accumulate_grad_batches
从training_step
返回的损失的归一化值。- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- on_validation_end(trainer, pl_module)¶
训练批次开始时调用。
- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- on_validation_epoch_end(trainer, pl_module)¶
训练结束时调用。
- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- on_validation_epoch_start(trainer, pl_module)¶
训练 epoch 结束时调用。
- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- on_validation_start(trainer, pl_module)¶
on_train_epoch_start(trainer, *_)¶
- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- property predict_description: str¶
- 验证 epoch 开始时调用。
str
- property predict_progress_bar: tqdm_asyncio¶
- 验证 epoch 开始时调用。
tqdm_asyncio
- print(*args, sep=' ', **kwargs)¶
on_train_start
- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- property process_position: int¶
- 验证 epoch 开始时调用。
int
- property refresh_rate: int¶
- 验证 epoch 开始时调用。
int
- property sanity_check_description: str¶
- 验证 epoch 开始时调用。
str
- setup(trainer, pl_module, stage)¶
on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)¶
- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- state_dict()¶
在保存检查点时调用,实现此方法以生成回调的
state_dict
。- 验证 epoch 开始时调用。
dict
[str
,Any
]- 保存检查点时调用,实现此方法以生成回调的
state_dict
。 包含回调状态的字典。
- property state_key: str¶
回调状态的标识符。
用于通过
checkpoint["callbacks"][state_key]
从检查点字典中存储和检索回调的状态。如果满足以下条件,回调的实现需要提供一个唯一的 state key:1)回调具有状态;2)希望维护该回调的多个实例的状态。- 验证 epoch 开始时调用。
str
- teardown(trainer, pl_module, stage)¶
on_validation_end(trainer, pl_module)¶
- 验证 epoch 开始时调用。
on_validation_start(trainer, pl_module)¶
- property test_description: str¶
- 验证 epoch 开始时调用。
str
- property test_progress_bar: tqdm_asyncio¶
- 验证 epoch 开始时调用。
tqdm_asyncio
- property total_predict_batches_current_dataloader: Union[int, float]¶
预测的总批次数,对于当前数据加载器,每次 epoch 可能会有所不同。
使用此属性设置进度条中的总迭代次数。如果预测数据加载器大小无限,可以返回
inf
。- 验证 epoch 开始时调用。
Union
[int
,float
]
- property total_test_batches_current_dataloader: Union[int, float]¶
测试的总批次数,对于当前数据加载器,每次 epoch 可能会有所不同。
使用此属性设置进度条中的总迭代次数。如果测试数据加载器大小无限,可以返回
inf
。- 验证 epoch 开始时调用。
Union
[int
,float
]
- property total_train_batches: Union[int, float]¶
训练的总批次数,每次 epoch 可能会有所不同。
使用此属性设置进度条中的总迭代次数。如果训练数据加载器大小无限,可以返回
inf
。- 验证 epoch 开始时调用。
Union
[int
,float
]
- property total_val_batches: Union[int, float]¶
验证的总批次数,对于所有验证数据加载器,每次 epoch 可能会有所不同。
使用此属性设置进度条中的总迭代次数。如果预测数据加载器大小无限,可以返回
inf
。- 验证 epoch 开始时调用。
Union
[int
,float
]
- property total_val_batches_current_dataloader: Union[int, float]¶
验证的总批次数,对于当前数据加载器,每次 epoch 可能会有所不同。
使用此属性设置进度条中的总迭代次数。如果验证数据加载器大小无限,可以返回
inf
。- 验证 epoch 开始时调用。
Union
[int
,float
]
- property train_description: str¶
- 验证 epoch 开始时调用。
str
- property train_progress_bar: tqdm_asyncio¶
- 验证 epoch 开始时调用。
tqdm_asyncio
- property trainer: Trainer¶
- 验证 epoch 开始时调用。
Trainer
- property val_progress_bar: tqdm_asyncio¶
- 验证 epoch 开始时调用。
tqdm_asyncio
- property validation_description: str¶
- 验证 epoch 开始时调用。
str