NBeatsInterpretableNet

class NBeatsInterpretableNet(input_size: int, output_size: int, loss: torch.nn.modules.module.Module, trend_blocks: int, trend_layers: int, trend_layer_size: int, degree_of_polynomial: int, seasonality_blocks: int, seasonality_layers: int, seasonality_layer_size: int, num_of_harmonics: int, lr: float, optimizer_params: Optional[Dict[str, Any]] = None)

Bases: etna.models.nn.nbeats.nets.NBeatsBaseNet

Interpretable N-BEATS model.

Initialize N-BEATS model.

Parameters
  • input_size (int) – Input data size, i.e. the length of the history window fed to the model.

  • output_size (int) – Forecast size, i.e. the number of predicted time steps.

  • loss (torch.nn.Module) – Optimisation objective. The loss function should accept three arguments: y_true, y_pred and mask. The last parameter is a binary mask that denotes which points are valid forecasts.

  • trend_blocks (int) – Number of trend blocks.

  • trend_layers (int) – Number of inner layers in each trend block.

  • trend_layer_size (int) – Inner layer size in trend blocks.

  • degree_of_polynomial (int) – Polynomial degree for trend modeling.

  • seasonality_blocks (int) – Number of seasonality blocks.

  • seasonality_layers (int) – Number of inner layers in each seasonality block.

  • seasonality_layer_size (int) – Inner layer size in seasonality blocks.

  • num_of_harmonics (int) – Number of harmonics for seasonality estimation.

  • lr (float) – Optimizer learning rate.

  • optimizer_params (Optional[Dict[str, Any]]) – Additional parameters for the optimizer.
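
The loss contract described for the loss parameter (three arguments: y_true, y_pred and a binary mask of valid forecast points) can be illustrated with a masked mean squared error. This is a minimal sketch, assuming the arguments are passed positionally in that order; the class name MaskedMSELoss is hypothetical and is not the loss shipped with the library.

```python
import torch
from torch import nn


class MaskedMSELoss(nn.Module):
    """Hypothetical loss following the (y_true, y_pred, mask) contract."""

    def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # mask is 1 for valid forecast points and 0 for points to ignore
        mask = mask.to(y_pred.dtype)
        squared_error = (y_pred - y_true) ** 2 * mask
        # average over valid points only; clamp avoids division by zero for an empty mask
        return squared_error.sum() / mask.sum().clamp(min=1.0)


y_true = torch.tensor([[1.0, 2.0, 3.0]])
y_pred = torch.tensor([[1.0, 4.0, 100.0]])
mask = torch.tensor([[1.0, 1.0, 0.0]])  # the last point is not a valid forecast
loss = MaskedMSELoss()(y_true, y_pred, mask)
print(loss.item())  # 2.0
```

The large error on the masked point does not contribute, so the result is (0 + 4) / 2.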
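
The role of degree_of_polynomial can be sketched with the polynomial trend basis from the N-BEATS paper: a trend block predicts one coefficient per power of a normalized time grid t, and the forecast is a linear combination of the rows t**0, ..., t**p. This is an illustrative sketch only; the function names are hypothetical and the exact parametrization used inside this class may differ (for example, in how the backcast is handled).

```python
import numpy as np


def trend_basis(degree_of_polynomial: int, horizon: int) -> np.ndarray:
    """Polynomial basis with rows t**0, t**1, ..., t**p (paper-style formulation)."""
    # normalized time grid t = [0, 1, ..., H - 1] / H
    t = np.arange(horizon) / horizon
    # shape: (degree_of_polynomial + 1, horizon)
    return np.stack([t ** power for power in range(degree_of_polynomial + 1)])


def trend_forecast(theta: np.ndarray, degree_of_polynomial: int, horizon: int) -> np.ndarray:
    # theta holds one coefficient per power of t; the forecast is theta @ basis
    return theta @ trend_basis(degree_of_polynomial, horizon)


basis = trend_basis(degree_of_polynomial=2, horizon=4)
print(basis.shape)  # (3, 4)
forecast = trend_forecast(np.array([1.0, 2.0, 0.0]), degree_of_polynomial=2, horizon=4)
print(forecast.tolist())  # [1.0, 1.5, 2.0, 2.5]
```

Keeping the degree low restricts each trend block to smooth, slowly varying curves, which is what makes its output interpretable as a trend.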
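
Similarly, num_of_harmonics can be pictured as the size of a Fourier basis: a seasonality block predicts coefficients for a constant row plus cosine and sine rows at integer frequencies. The sketch below assumes frequencies k = 1..num_of_harmonics on a normalized time grid; this is a simplification chosen for clarity, the function name is hypothetical, and the frequency grid used by this class may be defined differently.

```python
import numpy as np


def seasonality_basis(num_of_harmonics: int, horizon: int) -> np.ndarray:
    """Fourier basis: a constant row plus cos/sin rows for k = 1..num_of_harmonics."""
    # normalized time grid t = [0, 1, ..., H - 1] / H
    t = np.arange(horizon) / horizon
    k = np.arange(1, num_of_harmonics + 1)[:, None]  # shape (K, 1), broadcasts against t
    cos_rows = np.cos(2 * np.pi * k * t)  # shape (K, horizon)
    sin_rows = np.sin(2 * np.pi * k * t)  # shape (K, horizon)
    constant = np.ones((1, horizon))
    # shape: (2 * num_of_harmonics + 1, horizon)
    return np.concatenate([constant, cos_rows, sin_rows])


S = seasonality_basis(num_of_harmonics=2, horizon=8)
print(S.shape)  # (5, 8)
theta = np.array([0.0, 1.0, 0.0, 0.0, 0.0])  # select only the cos(2*pi*t) row
season = theta @ S
```

More harmonics let a block represent sharper periodic patterns, at the cost of more coefficients to estimate.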

Attributes