NBeatsGenericNet
- class NBeatsGenericNet(input_size: int, output_size: int, loss: torch.nn.modules.module.Module, stacks: int, layers: int, layer_size: int, lr: float, optimizer_params: Optional[Dict[str, Any]] = None)
Bases: etna.models.nn.nbeats.nets.NBeatsBaseNet
N-BEATS generic model.
Initialize N-BEATS model.
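The generic variant's structure can be sketched in plain Python. This is a simplified illustration of the doubly residual stacking described in the N-BEATS paper, not ETNA's implementation. `make_generic_block` and `nbeats_generic_forward` are hypothetical helpers. The sketch assumes each of the `stacks` units is a single block, uses random untrained weights, and reduces the generic basis to splitting the block output `theta` into a backcast of length `input_size` and a forecast of length `output_size`.

```python
import random
from typing import Callable, List, Tuple

Vector = List[float]
Block = Callable[[Vector], Tuple[Vector, Vector]]


def linear(x: Vector, weights: List[Vector], bias: Vector) -> Vector:
    """Dense layer: one output per row of `weights`."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]


def relu(x: Vector) -> Vector:
    return [max(0.0, v) for v in x]


def random_layer(d_in: int, d_out: int, rng: random.Random) -> Tuple[List[Vector], Vector]:
    weights = [[rng.gauss(0.0, 0.1) for _ in range(d_in)] for _ in range(d_out)]
    return weights, [0.0] * d_out


def make_generic_block(input_size: int, output_size: int, layers: int,
                       layer_size: int, rng: random.Random) -> Block:
    # Hypothetical helper: `layers` fully connected layers of width `layer_size`.
    dims = [input_size] + [layer_size] * layers
    hidden = [random_layer(d_in, d_out, rng) for d_in, d_out in zip(dims[:-1], dims[1:])]
    # Projection to the expansion coefficients theta (backcast + forecast).
    head = random_layer(dims[-1], input_size + output_size, rng)

    def block(x: Vector) -> Tuple[Vector, Vector]:
        h = x
        for weights, bias in hidden:
            h = relu(linear(h, weights, bias))
        theta = linear(h, *head)
        # Simplified generic basis: split theta into backcast and forecast.
        return theta[:input_size], theta[input_size:]

    return block


def nbeats_generic_forward(x: Vector, blocks: List[Block], output_size: int) -> Vector:
    residual = list(x)
    forecast = [0.0] * output_size
    for block in blocks:
        backcast, block_forecast = block(residual)
        # Each block removes the part of its input it can explain...
        residual = [r - b for r, b in zip(residual, backcast)]
        # ...and adds its partial forecast to the running total.
        forecast = [f + bf for f, bf in zip(forecast, block_forecast)]
    return forecast


# Usage: three single-block "stacks", each with 2 inner layers of width 8.
rng = random.Random(0)
blocks = [make_generic_block(input_size=4, output_size=2, layers=2, layer_size=8, rng=rng)
          for _ in range(3)]
forecast = nbeats_generic_forward([1.0, 2.0, 3.0, 4.0], blocks, output_size=2)
```

Each block sees only the residual left unexplained by the blocks before it, and the final forecast is the sum of the blocks' partial forecasts. `layers` and `layer_size` set the depth and width of each block's fully connected part.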
- Parameters
input_size (int) – Input data size.
output_size (int) – Forecast size.
loss (nn.Module) – Optimisation objective. The loss function should accept three arguments: y_true, y_pred and mask. The last parameter is a binary mask that denotes which points are valid forecasts.
stacks (int) – Number of block stacks in the model.
layers (int) – Number of inner layers in each block.
layer_size (int) – Size of the inner layers in each block.
lr (float) – Optimizer learning rate.
optimizer_params (Optional[Dict[str, Any]]) – Additional parameters for the optimizer.
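To illustrate what the `mask` argument of `loss` means, here is a framework-free sketch of a masked mean squared error. `masked_mse` is a hypothetical helper, not part of ETNA, and it works on 1-D sequences rather than batched tensors. A loss actually passed to `NBeatsGenericNet` must be an `nn.Module` that accepts `y_true`, `y_pred` and `mask`; this sketch shows only the masking arithmetic such a module might perform.

```python
from typing import Sequence


def masked_mse(y_true: Sequence[float], y_pred: Sequence[float],
               mask: Sequence[float]) -> float:
    """Mean squared error over the points where mask == 1 (hypothetical helper)."""
    if not (len(y_true) == len(y_pred) == len(mask)):
        raise ValueError("y_true, y_pred and mask must have the same length")
    # Zero out errors at invalid points...
    squared = [m * (t - p) ** 2 for t, p, m in zip(y_true, y_pred, mask)]
    # ...then average over the valid points only.
    n_valid = sum(mask)
    return sum(squared) / n_valid if n_valid > 0 else 0.0


# The third point is not a valid forecast, so its large error is ignored.
loss = masked_mse([1.0, 2.0, 3.0], [1.0, 0.0, 100.0], [1, 1, 0])
```

Dividing by the number of valid points, rather than by the full length, keeps padded or invalid positions from diluting the average.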
Attributes