使用 Informer 模型进行多元概率时间序列预测

发布于 2023 年 3 月 10 日
在 GitHub 上更新
Open In Colab

引言

几个月前,我们介绍了 Time Series Transformer,它是应用于预测任务的原始 Transformer 模型(Vaswani et al., 2017),并展示了一个 单变量 概率预测任务的示例(即单独预测每个时间序列的一维分布)。在本文中,我们将介绍 Informer 模型(Zhou, Haoyi, et al., 2021),这篇 AAAI 2021 的最佳论文模型现已 加入 🤗 Transformers。我们将演示如何使用 Informer 模型进行 多元 概率预测任务,即预测未来时间序列目标值 向量 的分布。请注意,本文介绍的方法也适用于原始的 Time Series Transformer 模型。

多元概率时间序列预测

就概率预测的建模方面而言,Transformer/Informer 在处理多元时间序列时无需改变模型结构。在单变量和多变量设置中,模型都接收一个向量序列,因此唯一的变化在于输出或发射(emission)端。

对高维数据的完整联合条件分布进行建模,计算成本可能会非常高。因此,各种方法会采用一些近似分布,最简单的是将数据建模为来自同一分布族的独立分布,或对完整协方差矩阵进行某种低秩近似等。在这里,我们将采用独立(或对角)发射,我们所实现的分布族(见 这里)均支持这种方法。

Informer 模型深入解析

Informer 基于原始的 Transformer(Vaswani et al., 2017),并引入了两项主要改进。为了理解这些改进,我们先回顾一下原始 Transformer 的缺点:

  1. 标准自注意力机制的平方级计算复杂度: 原始 Transformer 的计算复杂度为 O(T2D)O(T^2 D),其中 TT 是时间序列的长度,DD 是隐藏状态的维度。对于长序列时间序列预测(也称为 LSTF 问题),这种计算成本可能非常高。为了解决这个问题,Informer 采用了一种新的自注意力机制,称为 ProbSparse 注意力,其时间和空间复杂度为 O(TlogT)O(T \log T)
  2. 堆叠层时的内存瓶颈: 当堆叠 NN 个编码器/解码器层时,原始 Transformer 的内存使用量为 O(NT2)O(N T^2),这限制了模型处理长序列的能力。Informer 使用了一种 蒸馏 (Distilling) 操作,将层与层之间的输入大小缩减为其一半。通过这种方式,它将总内存使用量降低到 O(NTlogT)O(N\cdot T \log T)

正如你所见,Informer 模型的动机与 Longformer (Beltagy et el., 2020)、Sparse Transformer (Child et al., 2019) 以及其他 NLP 论文类似,都是为了在 输入序列很长时 降低自注意力机制的平方级复杂度。现在,让我们通过代码示例深入了解 ProbSparse 注意力和 蒸馏 操作。

ProbSparse 注意力机制

ProbSparse 的核心思想是,标准自注意力的得分呈长尾分布,其中“活跃”的查询 (query) 位于得分的“头部”,而“懒惰”的查询则位于“尾部”。所谓“活跃”查询,指的是查询 qiq_i,其点积 qi,ki\langle q_i,k_i \rangle 贡献了 主要的注意力权重,而“懒惰”查询的点积则生成 微不足道 的注意力权重。这里,qiq_ikik_i 分别是注意力矩阵 QQKK 的第 ii 行。

informer_full_vs_sparse_attention
标准自注意力 vs ProbSparse 注意力,图片来自 Autoformer (Wu, Haixu, et al., 2021)

基于“活跃”和“懒惰”查询的理念,ProbSparse 注意力机制会选出“活跃”的查询,并创建一个简化的查询矩阵 QreducedQ_{reduced},用于以 O(TlogT)O(T \log T) 的复杂度计算注意力权重。让我们通过一个代码示例来更详细地了解这一点。

回顾一下标准自注意力的公式:

Attention(Q,K,V)=softmax(QKTdk)V \textrm{Attention}(Q, K, V) = \textrm{softmax}(\frac{QK^T}{\sqrt{d_k}} )V

其中 QRLQ×dQ\in \mathbb{R}^{L_Q \times d}, KRLK×dK\in \mathbb{R}^{L_K \times d}, VRLV×dV\in \mathbb{R}^{L_V \times d}。注意,在实践中,自注意力计算中的查询和键的输入长度通常是相等的,即 LQ=LK=TL_Q = L_K = T,其中 TT 是时间序列的长度。因此,QKTQK^T 乘法的计算复杂度为 O(T2d)O(T^2 \cdot d)。在 ProbSparse 注意力中,我们的目标是创建一个新的 QreduceQ_{reduce} 矩阵并定义:

ProbSparseAttention(Q,K,V)=softmax(QreduceKTdk)V \textrm{ProbSparseAttention}(Q, K, V) = \textrm{softmax}(\frac{Q_{reduce}K^T}{\sqrt{d_k}} )V

其中 QreduceQ_{reduce} 矩阵仅选择前 uu 个“活跃”查询。这里,u=clogLQu = c \cdot \log L_Qcc 是 ProbSparse 注意力的超参数,称为 采样因子 (sampling factor)。由于 QreduceQ_{reduce} 仅选择前 uu 个查询,其大小为 clogLQ×dc\cdot \log L_Q \times d,因此乘法 QreduceKTQ_{reduce}K^T 仅需 O(LKlogLQ)=O(TlogT)O(L_K \log L_Q) = O(T \log T) 的计算量。

这很好!但是我们如何选择这 uu 个“活跃”查询来创建 QreduceQ_{reduce} 呢?让我们来定义 查询稀疏性度量 (Query Sparsity Measurement)

查询稀疏性度量

查询稀疏性度量 M(qi,K)M(q_i, K) 用于从 QQ 中选出 uu 个“活跃”查询 qiq_i 来创建 QreduceQ_{reduce}。理论上,占主导地位的 qi,ki\langle q_i,k_i \rangle 点积对会促使“活跃”查询 qiq_i 的概率分布 偏离 均匀分布,如下图所示。因此,实际查询分布与均匀分布之间的 KL 散度 (Kullback–Leibler divergence) 被用来定义稀疏性度量。

informer_probsparse
ProbSparse 注意力机制图解,来自官方 代码库

在实践中,该度量定义为:

M(qi,K)=maxjqikjTd1Lkj=1LkqikjTd M(q_i, K) = \max_j \frac{q_ik_j^T}{\sqrt{d}}-\frac{1}{L_k} \sum_{j=1}^{L_k}\frac{q_ik_j^T}{\sqrt{d}}

这里需要理解的重点是,当 M(qi,K)M(q_i, K) 越大时,查询 qiq_i 就应该被包含在 QreduceQ_{reduce} 中,反之亦然。

但是我们如何以非平方级的时间复杂度计算 qikjTq_ik_j^T 这一项呢?回想一下,大多数点积 qi,ki\langle q_i,k_i \rangle 无论如何都会生成微不足道的注意力(即长尾分布特性),因此从 KK 中随机采样一个键 (key) 的子集就足够了,这个子集在代码中称为 K_sample

现在,我们可以查看 `probsparse_attention` 的代码了。

from torch import nn
import math


def probsparse_attention(query_states, key_states, value_states, sampling_factor=5):
    """
    Compute the probsparse self-attention.
    Input shape: Batch x Time x Channel

    Note the additional `sampling_factor` input.
    """
    # get input sizes with logs
    L_K = key_states.size(1)
    L_Q = query_states.size(1)
    log_L_K = np.ceil(np.log1p(L_K)).astype("int").item()
    log_L_Q = np.ceil(np.log1p(L_Q)).astype("int").item()

    # calculate a subset of samples to slice from K and create Q_K_sample
    U_part = min(sampling_factor * L_Q * log_L_K, L_K)

    # create Q_K_sample (the q_i * k_j^T term in the sparsity measurement)
    index_sample = torch.randint(0, L_K, (U_part,))
    K_sample = key_states[:, index_sample, :]
    Q_K_sample = torch.bmm(query_states, K_sample.transpose(1, 2))

    # calculate the query sparsity measurement with Q_K_sample
    M = Q_K_sample.max(dim=-1)[0] - torch.div(Q_K_sample.sum(dim=-1), L_K)

    # calculate u to find the Top-u queries under the sparsity measurement
    u = min(sampling_factor * log_L_Q, L_Q)
    M_top = M.topk(u, sorted=False)[1]

    # calculate Q_reduce as query_states[:, M_top]
    dim_for_slice = torch.arange(query_states.size(0)).unsqueeze(-1)
    Q_reduce = query_states[dim_for_slice, M_top]  # size: c*log_L_Q x channel

    # and now, same as the canonical
    d_k = query_states.size(-1)
    attn_scores = torch.bmm(Q_reduce, key_states.transpose(-2, -1))  # Q_reduce x K^T
    attn_scores = attn_scores / math.sqrt(d_k)
    attn_probs = nn.functional.softmax(attn_scores, dim=-1)
    attn_output = torch.bmm(attn_probs, value_states)

    return attn_output, attn_scores

注意,在实现中,为了稳定性,UpartU_{part} 在计算中包含了 LQL_Q(更多信息请参阅 此讨论)。

我们成功了!请注意,这只是 probsparse_attention 的部分实现,完整实现可在 🤗 Transformers 中找到。

蒸馏 (Distilling)

由于 ProbSparse 自注意力机制的存在,编码器的特征图存在一些可以被移除的冗余。因此,蒸馏操作被用来将编码器层之间的输入大小减少一半,从而在理论上去除这种冗余。在实践中,Informer 的“蒸馏”操作只是在每个编码器层之间添加了带最大池化的一维卷积层。设 XnX_n 为第 nn 个编码器层的输出,则蒸馏操作定义为

Xn+1=MaxPool(ELU(Conv1d(Xn)) X_{n+1} = \textrm{MaxPool} ( \textrm{ELU}(\textrm{Conv1d}(X_n))

让我们在代码中看看这个操作。

from torch import nn

# ConvLayer is a class with forward pass applying ELU and MaxPool1d
def informer_encoder_forward(x_input, num_encoder_layers=3, distil=True):
    # Initialize the convolution layers
    if distil:
        conv_layers = nn.ModuleList([ConvLayer() for _ in range(num_encoder_layers - 1)])
        conv_layers.append(None)
    else:
        conv_layers = [None] * num_encoder_layers
    
    # Apply conv_layer between each encoder_layer
    for encoder_layer, conv_layer in zip(encoder_layers, conv_layers):
        output = encoder_layer(x_input)
        if conv_layer is not None:
            output = conv_layer(loutput)
    
    return output

通过将每层的输入减少一半,我们的内存使用量从 O(NT2)O(N\cdot T^2) 降至 O(NTlogT)O(N\cdot T \log T),其中 NN 是编码器/解码器层的数量。这正是我们想要的!

Informer 模型现已加入 🤗 Transformers 库,其名称就是 InformerModel。在下面的章节中,我们将展示如何在一个自定义的多元时间序列数据集上训练这个模型。

配置环境

首先,让我们安装必要的库:🤗 Transformers、🤗 Datasets、🤗 Evaluate、🤗 Accelerate 和 GluonTS

正如我们将要展示的,GluonTS 将用于转换数据以创建特征,以及创建适当的训练、验证和测试批次。

!pip install -q transformers datasets evaluate accelerate gluonts ujson

加载数据集

在这篇博文中,我们将使用 traffic_hourly 数据集,该数据集可在 Hugging Face Hub 上找到。这个数据集包含了 Lai 等人 (2017) 使用的旧金山交通数据集。它包含 862 个小时时间序列,显示了 2015 年至 2016 年旧金山湾区高速公路的道路占用率,范围在 [0,1][0, 1] 之间。

该数据集是 Monash 时间序列预测 存储库的一部分,这是一个汇集了多个领域时间序列数据集的集合。它可以被看作是时间序列预测领域的 GLUE 基准

from datasets import load_dataset

dataset = load_dataset("monash_tsf", "traffic_hourly")

可以看到,该数据集包含 3 个划分:训练集、验证集和测试集。

dataset

>>> DatasetDict({
        train: Dataset({
            features: ['start', 'target', 'feat_static_cat', 'feat_dynamic_real', 'item_id'],
            num_rows: 862
        })
        test: Dataset({
            features: ['start', 'target', 'feat_static_cat', 'feat_dynamic_real', 'item_id'],
            num_rows: 862
        })
        validation: Dataset({
            features: ['start', 'target', 'feat_static_cat', 'feat_dynamic_real', 'item_id'],
            num_rows: 862
        })
    })

每个样本都包含几个键,其中 starttarget 是最重要的。让我们看一下数据集中的第一个时间序列。

train_example = dataset["train"][0]
train_example.keys()

>>> dict_keys(['start', 'target', 'feat_static_cat', 'feat_dynamic_real', 'item_id'])

start 简单地表示时间序列的开始时间(以 datetime 格式),而 target 包含时间序列的实际值。

start 对于向时间序列值添加时间相关特征(例如“一年中的月份”)作为模型的额外输入非常有用。由于我们知道数据的频率是 hourly(每小时),我们知道例如第二个值的时间戳是 2015-01-01 01:00:012015-01-01 02:00:01,依此类推。

print(train_example["start"])
print(len(train_example["target"]))

>>> 2015-01-01 00:00:01
    17448

验证集包含与训练集相同的数据,只是时间上延长了 prediction_length。这使我们能够根据真实值来验证模型的预测。

测试集的数据又比验证集长一个 prediction_length (或者相对于训练集长了 prediction_length 的某个倍数,用于在多个滚动窗口上进行测试)。

validation_example = dataset["validation"][0]
validation_example.keys()

>>> dict_keys(['start', 'target', 'feat_static_cat', 'feat_dynamic_real', 'item_id'])

初始值与相应的训练样本完全相同。但是,这个样本比训练样本多了 prediction_length=48(48 小时或 2 天)个额外的值。让我们来验证一下。

freq = "1H"
prediction_length = 48

assert len(train_example["target"]) + prediction_length == len(
    dataset["validation"][0]["target"]
)

让我们将其可视化。

import matplotlib.pyplot as plt

num_of_samples = 150

figure, axes = plt.subplots()
axes.plot(train_example["target"][-num_of_samples:], color="blue")
axes.plot(
    validation_example["target"][-num_of_samples - prediction_length :],
    color="red",
    alpha=0.5,
)

plt.show()

png

让我们对数据进行划分。

train_dataset = dataset["train"]
test_dataset = dataset["test"]

将 `start` 更新为 `pd.Period`

我们要做的第一件事是使用数据的 freq 将每个时间序列的 start 特征转换为 pandas 的 Period 索引。

from functools import lru_cache

import pandas as pd
import numpy as np


@lru_cache(10_000)
def convert_to_pandas_period(date, freq):
    return pd.Period(date, freq)


def transform_start_field(batch, freq):
    batch["start"] = [convert_to_pandas_period(date, freq) for date in batch["start"]]
    return batch

我们现在使用 datasetsset_transform 功能来即时地、原地完成这个转换。

from functools import partial

train_dataset.set_transform(partial(transform_start_field, freq=freq))
test_dataset.set_transform(partial(transform_start_field, freq=freq))

现在,让我们使用 GluonTS 中的 MultivariateGrouper 将数据集转换为多元时间序列。这个 grouper 会将独立的 1 维时间序列转换成一个 2D 矩阵。

from gluonts.dataset.multivariate_grouper import MultivariateGrouper

num_of_variates = len(train_dataset)

train_grouper = MultivariateGrouper(max_target_dim=num_of_variates)
test_grouper = MultivariateGrouper(
    max_target_dim=num_of_variates,
    num_test_dates=len(test_dataset) // num_of_variates, # number of rolling test windows
)

multi_variate_train_dataset = train_grouper(train_dataset)
multi_variate_test_dataset = test_grouper(test_dataset)

注意,目标现在是 2 维的,其中第一维是变量的数量(时间序列的数量),第二维是时间序列的值(时间维度)。

multi_variate_train_example = multi_variate_train_dataset[0]
print("multi_variate_train_example["target"].shape =", multi_variate_train_example["target"].shape)

>>> multi_variate_train_example["target"].shape = (862, 17448)

定义模型

接下来,让我们实例化一个模型。该模型将从头开始训练,因此我们不会使用 from_pretrained 方法,而是根据 config 来随机初始化模型。

我们为模型指定了几个额外的参数。

  • prediction_length(在我们的例子中是 48 小时):这是 Informer 解码器将学习预测的时间范围;
  • context_length:如果未指定 context_length,模型会将 context_length(编码器的输入)设置为与 prediction_length 相等;
  • 给定频率的 lags:这些参数指定了一种高效的“回看”机制,我们将过去的值与当前值连接起来作为额外的特征。例如,对于 Daily 频率,我们可能会考虑 [1, 7, 30, ...] 的回看,而对于 Minute 数据,我们可能会考虑 [1, 30, 60, 60*24, ...] 等;
  • 时间特征的数量:在我们的例子中,这个值将是 5,因为我们将添加 HourOfDayDayOfWeek、... 和 Age 特征(详见下文)。

让我们检查一下 GluonTS 为给定频率(“hourly”)提供的默认 lags。

from gluonts.time_feature import get_lags_for_frequency

lags_sequence = get_lags_for_frequency(freq)
print(lags_sequence)

>>> [1, 2, 3, 4, 5, 6, 7, 23, 24, 25, 47, 48, 49, 71, 72, 73, 95, 96, 97, 119, 120, 
     121, 143, 144, 145, 167, 168, 169, 335, 336, 337, 503, 504, 505, 671, 672, 673, 719, 720, 721]

这意味着对于每个时间步,它会回看最多 721 小时(约 30 天)作为额外特征。然而,由此产生的特征向量大小将是 len(lags_sequence)*num_of_variates,在我们的情况下将是 34480!这是行不通的,所以我们将使用我们自己设定的合理 lags。

我们再检查一下 GluonTS 提供的默认时间特征。

from gluonts.time_feature import time_features_from_frequency_str

time_features = time_features_from_frequency_str(freq)
print(time_features)

>>> [<function hour_of_day at 0x7f3809539240>, <function day_of_week at 0x7f3809539360>, <function day_of_month at 0x7f3809539480>, <function day_of_year at 0x7f38095395a0>]

在这种情况下,有四个额外的特征,即“小时”、“星期几”、“月中的天”和“年中的天”。这意味着对于每个时间步,我们将把这些特征作为标量值添加进去。例如,考虑时间戳 2015-01-01 01:00:01,这四个额外的特征将会是

from pandas.core.arrays.period import period_array

timestamp = pd.Period("2015-01-01 01:00:01", freq=freq)
timestamp_as_index = pd.PeriodIndex(data=period_array([timestamp]))
additional_features = [
    (time_feature.__name__, time_feature(timestamp_as_index))
    for time_feature in time_features
]
print(dict(additional_features))

>>> {'hour_of_day': array([-0.45652174]), 'day_of_week': array([0.]), 'day_of_month': array([-0.5]), 'day_of_year': array([-0.5])}

请注意,GluonTS 将小时和天编码为 [-0.5, 0.5] 之间的值。有关 time_features 的更多信息,请参阅此链接。除了这 4 个特征外,我们还将在数据转换中添加一个“age”特征,具体如下。

我们现在已经具备了定义模型所需的一切。

from transformers import InformerConfig, InformerForPrediction

config = InformerConfig(
    # in the multivariate setting, input_size is the number of variates in the time series per time step
    input_size=num_of_variates,
    # prediction length:
    prediction_length=prediction_length,
    # context length:
    context_length=prediction_length * 2,
    # lags value copied from 1 week before:
    lags_sequence=[1, 24 * 7],
    # we'll add 5 time features ("hour_of_day", ..., and "age"):
    num_time_features=len(time_features) + 1,
    
    # informer params:
    dropout=0.1,
    encoder_layers=6,
    decoder_layers=4,
    # project input from num_of_variates*len(lags_sequence)+num_time_features to:
    d_model=64,
)

model = InformerForPrediction(config)

默认情况下,模型使用对角学生 t 分布(但这是可配置的)。

model.config.distribution_output

>>> 'student_t'

定义转换

接下来,我们定义数据的转换,特别是用于创建时间特征(基于数据集或通用特征)的转换。

我们再次使用 GluonTS 库来完成此操作。我们定义了一个转换链(Chain),这有点类似于图像处理中的 torchvision.transforms.Compose。它允许我们将多个转换组合成一个单一的流水线。

from gluonts.time_feature import TimeFeature
from gluonts.dataset.field_names import FieldName
from gluonts.transform import (
    AddAgeFeature,
    AddObservedValuesIndicator,
    AddTimeFeatures,
    AsNumpyArray,
    Chain,
    ExpectedNumInstanceSampler,
    InstanceSplitter,
    RemoveFields,
    SelectFields,
    SetField,
    TestSplitSampler,
    Transformation,
    ValidationSplitSampler,
    VstackFeatures,
    RenameFields,
)

下面的转换都附有注释,以解释它们的作用。总的来说,我们将遍历数据集中的各个时间序列,并添加/删除字段或特征。

from transformers import PretrainedConfig


def create_transformation(freq: str, config: PretrainedConfig) -> Transformation:
    # create list of fields to remove later
    remove_field_names = []
    if config.num_static_real_features == 0:
        remove_field_names.append(FieldName.FEAT_STATIC_REAL)
    if config.num_dynamic_real_features == 0:
        remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL)
    if config.num_static_categorical_features == 0:
        remove_field_names.append(FieldName.FEAT_STATIC_CAT)

    return Chain(
        # step 1: remove static/dynamic fields if not specified
        [RemoveFields(field_names=remove_field_names)]
        # step 2: convert the data to NumPy (potentially not needed)
        + (
            [
                AsNumpyArray(
                    field=FieldName.FEAT_STATIC_CAT,
                    expected_ndim=1,
                    dtype=int,
                )
            ]
            if config.num_static_categorical_features > 0
            else []
        )
        + (
            [
                AsNumpyArray(
                    field=FieldName.FEAT_STATIC_REAL,
                    expected_ndim=1,
                )
            ]
            if config.num_static_real_features > 0
            else []
        )
        + [
            AsNumpyArray(
                field=FieldName.TARGET,
                # we expect an extra dim for the multivariate case:
                expected_ndim=1 if config.input_size == 1 else 2,
            ),
            # step 3: handle the NaN's by filling in the target with zero
            # and return the mask (which is in the observed values)
            # true for observed values, false for nan's
            # the decoder uses this mask (no loss is incurred for unobserved values)
            # see loss_weights inside the xxxForPrediction model
            AddObservedValuesIndicator(
                target_field=FieldName.TARGET,
                output_field=FieldName.OBSERVED_VALUES,
            ),
            # step 4: add temporal features based on freq of the dataset
            # these serve as positional encodings
            AddTimeFeatures(
                start_field=FieldName.START,
                target_field=FieldName.TARGET,
                output_field=FieldName.FEAT_TIME,
                time_features=time_features_from_frequency_str(freq),
                pred_length=config.prediction_length,
            ),
            # step 5: add another temporal feature (just a single number)
            # tells the model where in the life the value of the time series is
            # sort of running counter
            AddAgeFeature(
                target_field=FieldName.TARGET,
                output_field=FieldName.FEAT_AGE,
                pred_length=config.prediction_length,
                log_scale=True,
            ),
            # step 6: vertically stack all the temporal features into the key FEAT_TIME
            VstackFeatures(
                output_field=FieldName.FEAT_TIME,
                input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE]
                + (
                    [FieldName.FEAT_DYNAMIC_REAL]
                    if config.num_dynamic_real_features > 0
                    else []
                ),
            ),
            # step 7: rename to match HuggingFace names
            RenameFields(
                mapping={
                    FieldName.FEAT_STATIC_CAT: "static_categorical_features",
                    FieldName.FEAT_STATIC_REAL: "static_real_features",
                    FieldName.FEAT_TIME: "time_features",
                    FieldName.TARGET: "values",
                    FieldName.OBSERVED_VALUES: "observed_mask",
                }
            ),
        ]
    )

定义 `InstanceSplitter`

接下来,为了进行训练/验证/测试,我们创建一个 InstanceSplitter,它用于从数据集中采样窗口(记住,由于时间和内存的限制,我们不能将整个历史值传递给模型)。

instance splitter 从数据中随机采样大小为 context_length 的窗口以及紧随其后的 prediction_length 大小的窗口,并为相应窗口中 time_series_fields 里的任何时间键添加 past_future_ 前缀。instance splitter 可以配置为三种不同的模式。

  1. mode="train":在这种模式下,我们从给定数据集(训练数据集)中随机采样上下文和预测长度窗口。
  2. mode="validation":在这种模式下,我们从给定数据集(用于回溯测试或验证似然计算)中采样最后一个上下文长度窗口和预测窗口。
  3. mode="test":在这种模式下,我们仅采样最后一个上下文长度窗口(用于预测用例)。
from gluonts.transform.sampler import InstanceSampler
from typing import Optional


def create_instance_splitter(
    config: PretrainedConfig,
    mode: str,
    train_sampler: Optional[InstanceSampler] = None,
    validation_sampler: Optional[InstanceSampler] = None,
) -> Transformation:
    assert mode in ["train", "validation", "test"]

    instance_sampler = {
        "train": train_sampler
        or ExpectedNumInstanceSampler(
            num_instances=1.0, min_future=config.prediction_length
        ),
        "validation": validation_sampler
        or ValidationSplitSampler(min_future=config.prediction_length),
        "test": TestSplitSampler(),
    }[mode]

    return InstanceSplitter(
        target_field="values",
        is_pad_field=FieldName.IS_PAD,
        start_field=FieldName.START,
        forecast_start_field=FieldName.FORECAST_START,
        instance_sampler=instance_sampler,
        past_length=config.context_length + max(config.lags_sequence),
        future_length=config.prediction_length,
        time_series_fields=["time_features", "observed_mask"],
    )

创建 DataLoaders

接下来,是时候创建 DataLoaders了,它让我们能够获取成批的 (输入, 输出) 对——换句话说就是 (past_values, future_values)。

from typing import Iterable

import torch
from gluonts.itertools import Cached, Cyclic
from gluonts.dataset.loader import as_stacked_batches


def create_train_dataloader(
    config: PretrainedConfig,
    freq,
    data,
    batch_size: int,
    num_batches_per_epoch: int,
    shuffle_buffer_length: Optional[int] = None,
    cache_data: bool = True,
    **kwargs,
) -> Iterable:
    PREDICTION_INPUT_NAMES = [
        "past_time_features",
        "past_values",
        "past_observed_mask",
        "future_time_features",
    ]
    if config.num_static_categorical_features > 0:
        PREDICTION_INPUT_NAMES.append("static_categorical_features")

    if config.num_static_real_features > 0:
        PREDICTION_INPUT_NAMES.append("static_real_features")

    TRAINING_INPUT_NAMES = PREDICTION_INPUT_NAMES + [
        "future_values",
        "future_observed_mask",
    ]

    transformation = create_transformation(freq, config)
    transformed_data = transformation.apply(data, is_train=True)
    if cache_data:
        transformed_data = Cached(transformed_data)

    # we initialize a Training instance
    instance_splitter = create_instance_splitter(config, "train")

    # the instance splitter will sample a window of
    # context length + lags + prediction length (from all the possible transformed time series, 1 in our case)
    # randomly from within the target time series and return an iterator.
    stream = Cyclic(transformed_data).stream()
    training_instances = instance_splitter.apply(stream)
    
    return as_stacked_batches(
        training_instances,
        batch_size=batch_size,
        shuffle_buffer_length=shuffle_buffer_length,
        field_names=TRAINING_INPUT_NAMES,
        output_type=torch.tensor,
        num_batches_per_epoch=num_batches_per_epoch,
    )
def create_backtest_dataloader(
    config: PretrainedConfig,
    freq,
    data,
    batch_size: int,
    **kwargs,
):
    PREDICTION_INPUT_NAMES = [
        "past_time_features",
        "past_values",
        "past_observed_mask",
        "future_time_features",
    ]
    if config.num_static_categorical_features > 0:
        PREDICTION_INPUT_NAMES.append("static_categorical_features")

    if config.num_static_real_features > 0:
        PREDICTION_INPUT_NAMES.append("static_real_features")

    transformation = create_transformation(freq, config)
    transformed_data = transformation.apply(data)

    # we create a Validation Instance splitter which will sample the very last
    # context window seen during training only for the encoder.
    instance_sampler = create_instance_splitter(config, "validation")

    # we apply the transformations in train mode
    testing_instances = instance_sampler.apply(transformed_data, is_train=True)
    
    return as_stacked_batches(
        testing_instances,
        batch_size=batch_size,
        output_type=torch.tensor,
        field_names=PREDICTION_INPUT_NAMES,
    )

def create_test_dataloader(
    config: PretrainedConfig,
    freq,
    data,
    batch_size: int,
    **kwargs,
):
    PREDICTION_INPUT_NAMES = [
        "past_time_features",
        "past_values",
        "past_observed_mask",
        "future_time_features",
    ]
    if config.num_static_categorical_features > 0:
        PREDICTION_INPUT_NAMES.append("static_categorical_features")

    if config.num_static_real_features > 0:
        PREDICTION_INPUT_NAMES.append("static_real_features")

    transformation = create_transformation(freq, config)
    transformed_data = transformation.apply(data, is_train=False)

    # We create a test Instance splitter to sample the very last
    # context window from the dataset provided.
    instance_sampler = create_instance_splitter(config, "test")

    # We apply the transformations in test mode
    testing_instances = instance_sampler.apply(transformed_data, is_train=False)
    
    return as_stacked_batches(
        testing_instances,
        batch_size=batch_size,
        output_type=torch.tensor,
        field_names=PREDICTION_INPUT_NAMES,
    )
train_dataloader = create_train_dataloader(
    config=config,
    freq=freq,
    data=multi_variate_train_dataset,
    batch_size=256,
    num_batches_per_epoch=100,
    num_workers=2,
)

test_dataloader = create_backtest_dataloader(
    config=config,
    freq=freq,
    data=multi_variate_test_dataset,
    batch_size=32,
)

让我们检查第一个批次。

batch = next(iter(train_dataloader))
for k, v in batch.items():
    print(k, v.shape, v.type())

>>> past_time_features torch.Size([256, 264, 5]) torch.FloatTensor
    past_values torch.Size([256, 264, 862]) torch.FloatTensor
    past_observed_mask torch.Size([256, 264, 862]) torch.FloatTensor
    future_time_features torch.Size([256, 48, 5]) torch.FloatTensor
    future_values torch.Size([256, 48, 862]) torch.FloatTensor
    future_observed_mask torch.Size([256, 48, 862]) torch.FloatTensor

可以看到,我们没有像 NLP 模型那样向编码器提供 input_idsattention_mask,而是提供了 past_values,以及 past_observed_maskpast_time_featuresstatic_real_features

解码器的输入包括 future_valuesfuture_observed_maskfuture_time_featuresfuture_values 可以看作是 NLP 中 decoder_input_ids 的等价物。

关于每个参数的详细解释,请参阅文档

前向传播

让我们用刚刚创建的批次进行一次前向传播。

# perform forward pass
outputs = model(
    past_values=batch["past_values"],
    past_time_features=batch["past_time_features"],
    past_observed_mask=batch["past_observed_mask"],
    static_categorical_features=batch["static_categorical_features"]
    if config.num_static_categorical_features > 0
    else None,
    static_real_features=batch["static_real_features"]
    if config.num_static_real_features > 0
    else None,
    future_values=batch["future_values"],
    future_time_features=batch["future_time_features"],
    future_observed_mask=batch["future_observed_mask"],
    output_hidden_states=True,
)
print("Loss:", outputs.loss.item())

>>> Loss: -1071.5718994140625

请注意,模型返回了一个损失值。这是因为解码器自动将 future_values 向右移动一个位置以获得标签。这使得我们可以在预测值和标签之间计算损失。损失是预测分布相对于真实值的负对数似然,它会趋向于负无穷大。

另请注意,解码器使用因果掩码来避免看到未来,因为它需要预测的值位于 future_values 张量中。

训练模型

是时候训练模型了!我们将使用一个标准的 PyTorch 训练循环。

我们将在这里使用 🤗 Accelerate 库,它会自动将模型、优化器和数据加载器放置在适当的 device 上。

from accelerate import Accelerator
from torch.optim import AdamW

epochs = 25
loss_history = []

accelerator = Accelerator()
device = accelerator.device

model.to(device)
optimizer = AdamW(model.parameters(), lr=6e-4, betas=(0.9, 0.95), weight_decay=1e-1)

model, optimizer, train_dataloader = accelerator.prepare(
    model,
    optimizer,
    train_dataloader,
)

model.train()
for epoch in range(epochs):
    for idx, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        outputs = model(
            static_categorical_features=batch["static_categorical_features"].to(device)
            if config.num_static_categorical_features > 0
            else None,
            static_real_features=batch["static_real_features"].to(device)
            if config.num_static_real_features > 0
            else None,
            past_time_features=batch["past_time_features"].to(device),
            past_values=batch["past_values"].to(device),
            future_time_features=batch["future_time_features"].to(device),
            future_values=batch["future_values"].to(device),
            past_observed_mask=batch["past_observed_mask"].to(device),
            future_observed_mask=batch["future_observed_mask"].to(device),
        )
        loss = outputs.loss

        # Backpropagation
        accelerator.backward(loss)
        optimizer.step()

        loss_history.append(loss.item())
        if idx % 100 == 0:
            print(loss.item())

>>> -1081.978515625
    ...
    -2877.723876953125
# view training
loss_history = np.array(loss_history).reshape(-1)
x = range(loss_history.shape[0])
plt.figure(figsize=(10, 5))
plt.plot(x, loss_history, label="train")
plt.title("Loss", fontsize=15)
plt.legend(loc="upper right")
plt.xlabel("iteration")
plt.ylabel("nll")
plt.show()

png

推理

在推理时,建议使用 generate() 方法进行自回归生成,这与 NLP 模型类似。

预测过程涉及从测试实例采样器中获取数据,该采样器将从数据集中每个时间序列的最后一个 context_length 大小的窗口中采样值,并将其传递给模型。请注意,我们将 future_time_features(这些是预先知道的)传递给解码器。

模型将从预测的分布中自回归地采样一定数量的值,并将它们传回解码器以返回预测输出。

model.eval()

forecasts_ = []

for batch in test_dataloader:
    outputs = model.generate(
        static_categorical_features=batch["static_categorical_features"].to(device)
        if config.num_static_categorical_features > 0
        else None,
        static_real_features=batch["static_real_features"].to(device)
        if config.num_static_real_features > 0
        else None,
        past_time_features=batch["past_time_features"].to(device),
        past_values=batch["past_values"].to(device),
        future_time_features=batch["future_time_features"].to(device),
        past_observed_mask=batch["past_observed_mask"].to(device),
    )
    forecasts_.append(outputs.sequences.cpu().numpy())

模型输出一个形状为 (batch_size, number of samples, prediction length, input_size) 的张量。

在这种情况下,我们为 862 个时间序列中的每一个,都得到了未来 48 小时的 100 个可能值(对于批次中的每个样本,批次大小为 1,因为我们只有一个多元时间序列)。

forecasts_[0].shape

>>> (1, 100, 48, 862)

我们将它们垂直堆叠,以获得测试数据集中所有时间序列的预测(以防测试集中有更多的时间序列)。

forecasts = np.vstack(forecasts_)
print(forecasts.shape)

>>> (1, 100, 48, 862)

我们可以将得到的预测与测试集中存在的样本外真实值进行评估。为此,我们将使用 🤗 Evaluate 库,其中包括 MASEsMAPE 指标。

我们为数据集中的每个时间序列变量计算这两个指标。

from evaluate import load
from gluonts.time_feature import get_seasonality

mase_metric = load("evaluate-metric/mase")
smape_metric = load("evaluate-metric/smape")

forecast_median = np.median(forecasts, 1).squeeze(0).T

mase_metrics = []
smape_metrics = []

for item_id, ts in enumerate(test_dataset):
    training_data = ts["target"][:-prediction_length]
    ground_truth = ts["target"][-prediction_length:]
    mase = mase_metric.compute(
        predictions=forecast_median[item_id],
        references=np.array(ground_truth),
        training=np.array(training_data),
        periodicity=get_seasonality(freq),
    )
    mase_metrics.append(mase["mase"])

    smape = smape_metric.compute(
        predictions=forecast_median[item_id],
        references=np.array(ground_truth),
    )
    smape_metrics.append(smape["smape"])
print(f"MASE: {np.mean(mase_metrics)}")

>>> MASE: 1.1913437728068093

print(f"sMAPE: {np.mean(smape_metrics)}")

>>> sMAPE: 0.5322665081607634
plt.scatter(mase_metrics, smape_metrics, alpha=0.2)
plt.xlabel("MASE")
plt.ylabel("sMAPE")
plt.show()

png

为了绘制任何时间序列变量相对于真实测试数据的预测图,我们定义了以下辅助函数。

import matplotlib.dates as mdates


def plot(ts_index, mv_index):
    fig, ax = plt.subplots()

    index = pd.period_range(
        start=multi_variate_test_dataset[ts_index][FieldName.START],
        periods=len(multi_variate_test_dataset[ts_index][FieldName.TARGET]),
        freq=multi_variate_test_dataset[ts_index][FieldName.START].freq,
    ).to_timestamp()

    ax.xaxis.set_minor_locator(mdates.HourLocator())

    ax.plot(
        index[-2 * prediction_length :],
        multi_variate_test_dataset[ts_index]["target"][mv_index, -2 * prediction_length :],
        label="actual",
    )

    ax.plot(
        index[-prediction_length:],
        forecasts[ts_index, ..., mv_index].mean(axis=0),
        label="mean",
    )
    ax.fill_between(
        index[-prediction_length:],
        forecasts[ts_index, ..., mv_index].mean(0)
        - forecasts[ts_index, ..., mv_index].std(axis=0),
        forecasts[ts_index, ..., mv_index].mean(0)
        + forecasts[ts_index, ..., mv_index].std(axis=0),
        alpha=0.2,
        interpolate=True,
        label="+/- 1-std",
    )
    ax.legend()
    fig.autofmt_xdate()

例如:

plot(0, 344)

png

结论

我们如何与其他模型进行比较?Monash 时间序列存储库有一个测试集 MASE 指标的比较表,我们可以将我们的结果添加到其中。

数据集 SES Theta TBATS ETS (DHR-)ARIMA PR CatBoost FFNN DeepAR N-BEATS WaveNet Transformer (uni.) Informer (mv. our)
Traffic Hourly 1.922 1.922 2.482 2.294 2.535 1.281 1.571 0.892 0.825 1.100 1.066 0.821 1.191

可以看出,也许令一些人惊讶的是,多元预测通常比单变量预测要*差*,原因在于估计跨序列相关性/关系的难度。估计值增加的额外方差常常损害最终的预测,或者模型学习到虚假的关联。我们推荐阅读这篇论文以获取更多信息。多元模型在大量数据上训练时往往表现良好。

所以,原始的 Transformer 在这里仍然表现最好!未来,我们希望在一个集中的地方更好地对这些模型进行基准测试,以便更容易地复现多篇论文的结果。敬请期待更多内容!

资源

我们建议查看 Informer 文档以及本博文顶部链接的示例 notebook

社区

你好,

感谢这篇非常有用的博客。
当我尝试从 Monash-University/monash_tsf 加载任何小时数据集时,出现 "DatasetGenerationError: An error occurred while generating the dataset" 这个错误。
我能够加载除小时数据集之外的其他数据集。
您知道为什么吗?
如何解决这个问题?
我想尝试多元概率时间序列预测。我可以使用哪些数据集?

谢谢,
Nouda

·
文章作者

导致错误的确切命令是什么?

注册登录 以发表评论