from typing import Tuple
import numpy as np
from etna import SETTINGS
if SETTINGS.torch_required:
import torch
import torch.nn as nn
class NBeatsBlock(nn.Module):
    """Base N-BEATS block which takes a basis function as an argument.

    The block is a stack of fully connected layers with ReLU activations, followed by a
    linear projection to the basis parameters ``theta``; whatever the basis function
    returns for ``theta`` is the block's output.
    """

    def __init__(self, input_size: int, theta_size: int, basis_function: "nn.Module", num_layers: int, layer_size: int):
        """N-BEATS block.

        Parameters
        ----------
        input_size:
            In-sample size.
        theta_size:
            Number of parameters for the basis function.
        basis_function:
            Basis function which takes the parameters and produces backcast and forecast.
        num_layers:
            Number of layers. The first ``input_size -> layer_size`` layer is always built,
            so values below 1 still give a single hidden layer.
        layer_size:
            Layer size.
        """
        super().__init__()
        # The creation order (Linear, ReLU, Linear, ReLU, ...) determines the ``state_dict`` keys
        # (``layers.0``, ``layers.2``, ...) and the order in which weight init consumes the RNG.
        stack = [nn.Linear(input_size, layer_size), nn.ReLU()]
        stack.extend(
            module
            for _ in range(num_layers - 1)
            for module in (nn.Linear(layer_size, layer_size), nn.ReLU())
        )
        self.layers = nn.ModuleList(stack)
        self.basis_parameters = nn.Linear(layer_size, theta_size)
        self.basis_function = basis_function

    def forward(self, x: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor"]:
        """Forward pass.

        Parameters
        ----------
        x:
            Input data; its last dimension is expected to be ``input_size``.

        Returns
        -------
        :
            Tuple with backcast and forecast, as produced by the basis function.
        """
        hidden = x
        for layer in self.layers:
            hidden = layer(hidden)
        theta = self.basis_parameters(hidden)
        return self.basis_function(theta)
class GenericBasis(nn.Module):
    """Generic basis function: slices ``theta`` directly into backcast and forecast."""

    def __init__(self, backcast_size: int, forecast_size: int):
        """Initialize generic basis function.

        Parameters
        ----------
        backcast_size:
            Number of backcast values.
        forecast_size:
            Number of forecast values.
        """
        super().__init__()
        self.backcast_size = backcast_size
        self.forecast_size = forecast_size

    def forward(self, theta: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor"]:
        """Forward pass.

        Parameters
        ----------
        theta:
            Basis function parameters; columns are indexed along dimension 1.

        Returns
        -------
        :
            Tuple with backcast (leading ``backcast_size`` columns) and forecast
            (trailing ``forecast_size`` columns).
        """
        # NOTE(review): forecast_size == 0 turns ``-0:`` into ``0:``, i.e. every column.
        head = theta[:, : self.backcast_size]
        tail = theta[:, -self.forecast_size :]
        return head, tail
class TrendBasis(nn.Module):
    """Polynomial trend basis function."""

    def __init__(self, degree: int, backcast_size: int, forecast_size: int):
        """Initialize trend basis function.

        Parameters
        ----------
        degree:
            Degree of polynomial for trend modeling.
        backcast_size:
            Number of backcast values.
        forecast_size:
            Number of forecast values.
        """
        super().__init__()
        self.num_poly_terms = degree + 1
        # Fixed (non-trainable) polynomial design matrices, stored as parameters so they
        # follow the module across devices and appear in the state dict.
        self.backcast_time = nn.Parameter(self._trend_tensor(size=backcast_size), requires_grad=False)
        self.forecast_time = nn.Parameter(self._trend_tensor(size=forecast_size), requires_grad=False)

    def _trend_tensor(self, size: int) -> "torch.Tensor":
        """Prepare trend tensor of shape ``(num_poly_terms, size)`` whose row ``p`` is ``(t / size) ** p``."""
        grid = torch.arange(size) / size
        exponents = torch.arange(self.num_poly_terms)
        # (size, 1) ** (1, num_poly_terms) -> (size, num_poly_terms), then a transposed view.
        return torch.pow(grid.unsqueeze(1), exponents.unsqueeze(0)).t()

    def forward(self, theta: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor"]:
        """Forward pass.

        Parameters
        ----------
        theta:
            Basis function parameters: the first ``num_poly_terms`` columns are backcast
            coefficients, the remaining columns are forecast coefficients.

        Returns
        -------
        :
            Tuple with backcast and forecast.
        """
        n_terms = self.num_poly_terms
        backcast_coefficients = theta[:, :n_terms]
        forecast_coefficients = theta[:, n_terms:]
        return backcast_coefficients @ self.backcast_time, forecast_coefficients @ self.forecast_time
class SeasonalityBasis(nn.Module):
    """Harmonic seasonality basis function."""

    def __init__(self, harmonics: int, backcast_size: int, forecast_size: int):
        """Initialize seasonality basis function.

        Parameters
        ----------
        harmonics:
            Harmonics range.
        backcast_size:
            Number of backcast values.
        forecast_size:
            Number of forecast values.

        Raises
        ------
        ValueError:
            If ``harmonics / 2 * forecast_size <= harmonics - 1``, which leaves no
            frequencies to build the basis from (e.g. ``harmonics=2, forecast_size=1``).
        """
        super().__init__()
        # The range deliberately starts at ``harmonics - 1`` so that, after dividing by
        # ``harmonics``, its first slot can be replaced by the zero (constant) frequency:
        # the result is [0, 1, (harmonics + 1) / harmonics, ...].
        freq = torch.arange(harmonics - 1, harmonics / 2 * forecast_size) / harmonics
        if freq.numel() == 0:
            # Previously this surfaced as an opaque IndexError on ``freq[0]``.
            raise ValueError(
                f"SeasonalityBasis has no frequencies for harmonics={harmonics} and "
                f"forecast_size={forecast_size}; increase forecast_size or decrease harmonics."
            )
        freq[0] = 0.0
        frequency = torch.unsqueeze(freq, 0)

        # Time grids are scaled by ``forecast_size``; the backcast grid uses negated time steps.
        backcast_grid = -2 * np.pi * torch.arange(backcast_size)[:, None] / forecast_size
        backcast_grid = backcast_grid * frequency
        forecast_grid = 2 * np.pi * torch.arange(forecast_size)[:, None] / forecast_size
        forecast_grid = forecast_grid * frequency

        # Templates have shape (num_frequencies, size) and are fixed, non-trainable parameters.
        self.backcast_cos_template = nn.Parameter(torch.transpose(torch.cos(backcast_grid), 0, 1), requires_grad=False)
        self.backcast_sin_template = nn.Parameter(torch.transpose(torch.sin(backcast_grid), 0, 1), requires_grad=False)
        self.forecast_cos_template = nn.Parameter(torch.transpose(torch.cos(forecast_grid), 0, 1), requires_grad=False)
        self.forecast_sin_template = nn.Parameter(torch.transpose(torch.sin(forecast_grid), 0, 1), requires_grad=False)

    def forward(self, theta: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor"]:
        """Forward pass.

        Parameters
        ----------
        theta:
            Basis function parameters, split into four equal column groups:
            backcast cos, backcast sin, forecast cos, forecast sin. For the matrix
            products to succeed each group must have ``num_frequencies`` columns.

        Returns
        -------
        :
            Tuple with backcast and forecast.
        """
        params_per_harmonic = theta.shape[1] // 4
        backcast_harmonics_cos = theta[:, :params_per_harmonic] @ self.backcast_cos_template
        backcast_harmonics_sin = theta[:, params_per_harmonic : 2 * params_per_harmonic] @ self.backcast_sin_template
        backcast = backcast_harmonics_sin + backcast_harmonics_cos
        forecast_harmonics_cos = (
            theta[:, 2 * params_per_harmonic : 3 * params_per_harmonic] @ self.forecast_cos_template
        )
        forecast_harmonics_sin = theta[:, 3 * params_per_harmonic :] @ self.forecast_sin_template
        forecast = forecast_harmonics_sin + forecast_harmonics_cos
        return backcast, forecast
class NBeats(nn.Module):
    """N-BEATS model.

    Blocks are applied in sequence: each block's backcast is subtracted from a running
    (masked) residual, and each block's forecast is added to a running forecast.
    """

    def __init__(self, blocks: "nn.ModuleList"):
        """Initialize N-BEATS model.

        Parameters
        ----------
        blocks:
            Model blocks; each is called on the residual and must return ``(backcast, forecast)``.
        """
        super().__init__()
        self.blocks = blocks

    def forward(self, x: "torch.Tensor", input_mask: "torch.Tensor") -> "torch.Tensor":
        """Forward pass.

        Parameters
        ----------
        x:
            Input data; presumably ``(batch, input_size)`` since time is taken along dimension 1 — TODO confirm.
        input_mask:
            Input mask with the same layout as ``x``; it multiplies the residual after every block.

        Returns
        -------
        :
            Forecast tensor: the last input value ``x[:, -1:]`` plus the sum of all block forecasts.
        """
        # Input and mask are reversed along the time axis before entering the blocks.
        reversed_mask = input_mask.flip(dims=(1,))
        residual = x.flip(dims=(1,))
        # The running forecast starts from the last observed value and broadcasts over the horizon.
        prediction = x[:, -1:]
        for stage in self.blocks:
            stage_backcast, stage_forecast = stage(residual)
            residual = (residual - stage_backcast) * reversed_mask
            prediction = prediction + stage_forecast
        return prediction