NBeatsInterpretableModel

class NBeatsInterpretableModel(input_size: int, output_size: int, loss: Union[Literal['mse'], Literal['mae'], Literal['smape'], Literal['mape'], torch.nn.modules.module.Module] = 'mse', trend_blocks: int = 3, trend_layers: int = 4, trend_layer_size: int = 256, degree_of_polynomial: int = 2, seasonality_blocks: int = 3, seasonality_layers: int = 4, seasonality_layer_size: int = 2048, num_of_harmonics: int = 1, lr: float = 0.001, window_sampling_limit: Optional[int] = None, optimizer_params: Optional[dict] = None, train_batch_size: int = 1024, test_batch_size: int = 1024, trainer_params: Optional[dict] = None, train_dataloader_params: Optional[dict] = None, test_dataloader_params: Optional[dict] = None, val_dataloader_params: Optional[dict] = None, split_params: Optional[dict] = None, random_state: Optional[int] = None)

Bases: etna.models.nn.nbeats.nbeats.NBeatsBaseModel

Interpretable N-BEATS model.

Paper: https://arxiv.org/pdf/1905.10437.pdf

Official implementation: https://github.com/ServiceNow/N-BEATS
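In the paper's interpretable configuration, a trend stack and a seasonality stack are applied in sequence using doubly residual stacking: each block subtracts its backcast from the running input and adds its partial forecast to the model output, so the summed partial forecasts of each stack can be read as trend and seasonality components. The plain-Python sketch below illustrates only that data flow. The helper `doubly_residual_forward` and the toy blocks `toy_level_block` and `toy_last_value_block` are hypothetical and are not etna classes; real blocks are fully connected networks followed by a basis projection.

```python
from typing import Callable, List, Sequence, Tuple

HORIZON = 3  # hypothetical forecast length (plays the role of output_size)

# Hypothetical block type: input window -> (backcast, forecast).
Block = Callable[[List[float]], Tuple[List[float], List[float]]]


def doubly_residual_forward(
    x: Sequence[float], blocks: Sequence[Block]
) -> Tuple[List[float], List[List[float]]]:
    """Return the summed forecast and each block's partial forecast."""
    residual = list(x)
    total = [0.0] * HORIZON
    partials: List[List[float]] = []
    for block in blocks:
        backcast, forecast = block(residual)
        # The next block only sees what this block did not explain.
        residual = [r - b for r, b in zip(residual, backcast)]
        # Partial forecasts are summed into the model output.
        total = [t + f for t, f in zip(total, forecast)]
        partials.append(forecast)
    return total, partials


def toy_level_block(x: List[float]) -> Tuple[List[float], List[float]]:
    # Toy "trend" block: explains the window by its mean level.
    level = sum(x) / len(x)
    return [level] * len(x), [level] * HORIZON


def toy_last_value_block(x: List[float]) -> Tuple[List[float], List[float]]:
    # Toy second block: repeats the last value of the residual it receives.
    last = x[-1]
    return [last] * len(x), [last] * HORIZON


total, parts = doubly_residual_forward([1.0, 2.0, 3.0, 4.0], [toy_level_block, toy_last_value_block])
print(parts)  # [[2.5, 2.5, 2.5], [1.5, 1.5, 1.5]]
print(total)  # [4.0, 4.0, 4.0]
```

The second block never sees the raw level 2.5, only the residual left by the first block, which is what makes the per-stack outputs separable.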

Init interpretable N-BEATS model.

Parameters
  • input_size (int) – Input data size.

  • output_size (int) – Forecast size.

  • loss (Union[Literal['mse'], Literal['mae'], Literal['smape'], Literal['mape'], torch.nn.Module]) – Optimisation objective: one of the names 'mse', 'mae', 'smape', 'mape', or a custom torch.nn.Module. The loss function should accept three arguments: y_true, y_pred and mask, where mask is a binary mask that denotes which points are valid forecasts. Several ready-made loss functions are available in the etna.models.nn.nbeats.metrics module.

  • trend_blocks (int) – Number of trend blocks.

  • trend_layers (int) – Number of inner layers in each trend block.

  • trend_layer_size (int) – Inner layer size in trend blocks.

  • degree_of_polynomial (int) – Polynomial degree for trend modeling.

  • seasonality_blocks (int) – Number of seasonality blocks.

  • seasonality_layers (int) – Number of inner layers in each seasonality block.

  • seasonality_layer_size (int) – Inner layer size in seasonality blocks.

  • num_of_harmonics (int) – Number of harmonics for seasonality estimation.

  • lr (float) – Optimizer learning rate.

  • window_sampling_limit (Optional[int]) – Size of the history used for sampling training data. If set to None, the full series history is used for sampling.

  • optimizer_params (Optional[dict]) – Additional parameters for the Adam optimizer (api reference torch.optim.Adam).

  • train_batch_size (int) – Batch size for training.

  • test_batch_size (int) – Batch size for testing.

  • trainer_params (Optional[dict]) – PyTorch Lightning trainer parameters (api reference pytorch_lightning.trainer.trainer.Trainer).

  • train_dataloader_params (Optional[dict]) – Parameters for the train dataloader, for example a sampler (api reference torch.utils.data.DataLoader).

  • test_dataloader_params (Optional[dict]) – Parameters for the test dataloader.

  • val_dataloader_params (Optional[dict]) – Parameters for the validation dataloader.

  • split_params (Optional[dict]) –

    Dictionary with parameters for torch.utils.data.random_split() used for train-test splitting:
    • train_size: (float) value from 0 to 1 - fraction of samples to use for training

    • generator: (Optional[torch.Generator]) - generator for reproducible train-test splitting

    • torch_dataset_size: (Optional[int]) - number of samples in the dataset, in case the dataset does not implement __len__

  • random_state (Optional[int]) – Random state for train batches generation.
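The mask convention for the loss argument can be illustrated with a masked mean squared error in plain Python. `masked_mse` is a hypothetical helper, not the implementation in etna.models.nn.nbeats.metrics: a real custom loss would be a torch.nn.Module whose forward method takes y_true, y_pred and mask tensors, and the library's normalization (averaging over valid points only, as here, or over all points) may differ.

```python
from typing import Sequence


def masked_mse(y_true: Sequence[float], y_pred: Sequence[float], mask: Sequence[float]) -> float:
    """Mean squared error over the points where mask == 1 (hypothetical helper)."""
    valid = sum(mask)
    if valid == 0:
        # No valid forecast points: return zero instead of dividing by zero.
        return 0.0
    squared = sum(m * (p - t) ** 2 for t, p, m in zip(y_true, y_pred, mask))
    return squared / valid


# The last point is padding (mask == 0), so its large error is ignored.
print(masked_mse([1.0, 2.0, 3.0], [1.5, 2.0, 100.0], [1, 1, 0]))  # 0.125
```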
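The degree_of_polynomial and num_of_harmonics parameters shape the fixed bases onto which the trend and seasonality blocks project their coefficients. Following the paper, the trend basis is t**i for i = 0..degree on the grid t = [0, 1, ..., H-1] / H, and the seasonality basis is made of cosine and sine waves on the same grid. The sketch below assumes an explicit list of frequencies because this page does not describe how etna derives the frequency grid from num_of_harmonics; the helper names `trend_basis`, `seasonality_basis` and `project` are hypothetical.

```python
import math
from typing import List, Sequence


def trend_basis(degree_of_polynomial: int, horizon: int) -> List[List[float]]:
    """Polynomial basis rows t**i, i = 0..degree, on the grid t = k / horizon."""
    grid = [k / horizon for k in range(horizon)]
    return [[t ** i for t in grid] for i in range(degree_of_polynomial + 1)]


def seasonality_basis(frequencies: Sequence[float], horizon: int) -> List[List[float]]:
    """Fourier basis rows cos(2*pi*f*t) followed by sin(2*pi*f*t) for each frequency f."""
    grid = [k / horizon for k in range(horizon)]
    cos_rows = [[math.cos(2 * math.pi * f * t) for t in grid] for f in frequencies]
    sin_rows = [[math.sin(2 * math.pi * f * t) for t in grid] for f in frequencies]
    return cos_rows + sin_rows


def project(theta: Sequence[float], basis: List[List[float]]) -> List[float]:
    """Block output: the linear combination sum_i theta[i] * basis[i]."""
    return [sum(c * row[k] for c, row in zip(theta, basis)) for k in range(len(basis[0]))]


T = trend_basis(degree_of_polynomial=2, horizon=4)
print(len(T))                       # 3 (rows 1, t, t**2)
print(project([1.0, 2.0, 0.0], T))  # [1.0, 1.5, 2.0, 2.5]

S = seasonality_basis([0.0, 1.0], horizon=4)
print(len(S))                       # 4 (two cosine rows, two sine rows)
```

With degree_of_polynomial = 2 the trend block can only produce quadratic curves over the horizon, which is what makes its output readable as a trend.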
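The effect of window_sampling_limit, together with a seeded random generator in the spirit of random_state, can be sketched in plain Python: a cut point is drawn at random, the input_size points before it form the history and the output_size points after it form the target. This is a hypothetical illustration (`sample_window` is not an etna function) modeled on the sampling idea in the reference N-BEATS code. Windows near the series edges come out shorter here; the reference code instead pads such windows and marks the padded points with a mask, which is what the loss mask refers to.

```python
import random
from typing import List, Optional, Tuple


def sample_window(
    series: List[float],
    input_size: int,
    output_size: int,
    window_sampling_limit: Optional[int],
    rng: random.Random,
) -> Tuple[List[float], List[float], int]:
    """Return (history, target, cut) for one randomly placed training window."""
    n = len(series)
    # Restrict cut points to the last `window_sampling_limit` points, or use all history.
    low = 1 if window_sampling_limit is None else max(1, n - window_sampling_limit)
    cut = rng.randint(low, n - 1)  # randint includes both endpoints
    history = series[max(0, cut - input_size):cut]  # may be shorter than input_size
    target = series[cut:cut + output_size]          # may be shorter than output_size
    return history, target, cut


rng = random.Random(42)  # a fixed seed, in the spirit of random_state
series = [float(v) for v in range(100)]
cuts = [sample_window(series, 10, 5, 20, rng)[2] for _ in range(1000)]
print(min(cuts) >= 80, max(cuts) <= 99)  # True True
```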
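The role of the split_params keys can be illustrated with a stdlib-only stand-in for a fractional random split: train_size decides how many samples go to training, and a fixed seed (standing in for the torch.Generator passed as generator) makes the split reproducible. `fractional_split` is hypothetical and only approximates what torch.utils.data.random_split does; the exact rounding used by torch and by etna may differ.

```python
import random
from typing import Iterable, List, Optional, Tuple


def fractional_split(
    items: Iterable[int], train_size: float, seed: Optional[int] = None
) -> Tuple[List[int], List[int]]:
    """Shuffle items and split them into train and validation parts by a fraction."""
    if not 0.0 <= train_size <= 1.0:
        raise ValueError("train_size must be between 0 and 1")
    indices = list(items)
    random.Random(seed).shuffle(indices)  # a fixed seed plays the role of `generator`
    n_train = int(len(indices) * train_size)
    return indices[:n_train], indices[n_train:]


train, val = fractional_split(range(10), train_size=0.8, seed=0)
print(len(train), len(val))  # 8 2
```

The split needs the number of samples up front, which is why torch_dataset_size exists for datasets that do not implement __len__.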

Methods

  • fit(ts) – Fit model.

  • forecast(ts, prediction_size[, ...]) – Make predictions.

  • get_model() – Get model.

  • load(path) – Load an object.

  • params_to_tune() – Get default grid for tuning hyperparameters.

  • predict(ts, prediction_size[, return_components]) – Make predictions.

  • raw_fit(torch_dataset) – Fit model on a torch-like Dataset.

  • raw_predict(torch_dataset) – Make inference on a torch-like Dataset.

  • save(path) – Save the object.

  • set_params(**params) – Return a new object instance with modified parameters.

  • to_dict() – Collect all information about the etna object in a dict.

Attributes

  • context_size – Context size of the model.

params_to_tune() → Dict[str, etna.distributions.distributions.BaseDistribution]

Get default grid for tuning hyperparameters.

This grid tunes parameters: trend_blocks, trend_layers, trend_layer_size, degree_of_polynomial, seasonality_blocks, seasonality_layers, seasonality_layer_size, lr. Other parameters are expected to be set by the user.

Returns

Grid to tune.

Return type

Dict[str, etna.distributions.distributions.BaseDistribution]