训练数据集基类¶
- class darts.utils.data.training_dataset.DualCovariatesTrainingDataset[source]¶
基础类:
TrainingDataset
,ABC
DualCovariatesTorchModel 训练数据集的抽象类。它包含由 (past_target, historic_future_covariates, future_covariates, static_covariates, future_target) 组成的 np.ndarray。协变量是可选的,可以为 None。
- class darts.utils.data.training_dataset.FutureCovariatesTrainingDataset[source]¶
基础类:
TrainingDataset
,ABC
FutureCovariatesTorchModel 训练数据集的抽象类。它包含由 (past_target, future_covariate, static_covariates, future_target) 组成的 np.ndarray。协变量是可选的,可以为 None。
- class darts.utils.data.training_dataset.MixedCovariatesTrainingDataset[source]¶
基础类:
TrainingDataset
,ABC
MixedCovariatesTorchModel 训练数据集的抽象类。它包含由 (past_target, past_covariates, historic_future_covariates, future_covariates, static_covariates, future_target) 组成的 np.ndarray。协变量是可选的,可以为 None。
- class darts.utils.data.training_dataset.PastCovariatesTrainingDataset[source]¶
基础类:
TrainingDataset
,ABC
PastCovariatesTorchModel 训练数据集的抽象类。它包含由 (past_target, past_covariate, static_covariates, future_target) 组成的 np.ndarray。协变量是可选的,可以为 None。
- class darts.utils.data.training_dataset.SplitCovariatesTrainingDataset[source]¶
基础类:
TrainingDataset
,ABC
SplitCovariatesTorchModel 训练数据集的抽象类。它包含由 (past_target, past_covariates, future_covariates, static_covariates, future_target) 组成的 np.ndarray。协变量是可选的,可以为 None。
- class darts.utils.data.training_dataset.TrainingDataset[source]¶
基础类:
ABC
,Dataset
Darts 中所有 torch 模型训练数据集的超类。这些包括
- “PastCovariates” 数据集(用于 PastCovariatesTorchModel):包含 (past_target,
past_covariates, static_covariates, future_target)
- “FutureCovariates” 数据集(用于 FutureCovariatesTorchModel):包含 (past_target,
future_covariates, static_covariates, future_target)
- “DualCovariates” 数据集(用于 DualCovariatesTorchModel):包含 (past_target,
historic_future_covariates, future_covariates, static_covariates, future_target)
- “MixedCovariates” 数据集(用于 MixedCovariatesTorchModel):包含 (past_target,
past_covariates, historic_future_covariates, future_covariates, static_covariates, future_target)
- “SplitCovariates” 数据集(用于 SplitCovariatesTorchModel):包含 (past_target,
past_covariates, future_covariates, static_covariates, future_target)
协变量是可选的,可以为 None。
这旨在用于训练(或验证),除了 future_target 之外的所有数据都表示模型输入(future_target 是模型训练预测的输出)。
Darts 的 TorchForecastingModel 可以使用 fit_from_dataset() 方法从正确类型的 TrainingDataset 实例中拟合。
TrainingDataset 继承了 torch Dataset;这意味着实现必须提供 __getitem__() 方法。
它包含 np.ndarray(而不是 TimeSeries),因为训练只需要值,因此在切片时,通过仅返回 TimeSeries 底层数据的 numpy 视图,我们可以获得显著的性能提升。