DeepStateModel
- class DeepStateModel(ssm: etna.models.nn.deepstate.state_space_model.CompositeSSM, input_size: int, encoder_length: int, decoder_length: int, num_layers: int = 1, n_samples: int = 5, lr: float = 0.001, train_batch_size: int = 16, test_batch_size: int = 16, optimizer_params: Optional[dict] = None, trainer_params: Optional[dict] = None, train_dataloader_params: Optional[dict] = None, test_dataloader_params: Optional[dict] = None, val_dataloader_params: Optional[dict] = None, split_params: Optional[dict] = None)
Bases: etna.models.base.DeepBaseModel
DeepState model.
Init Deep State Model.
- Parameters
ssm (etna.models.nn.deepstate.state_space_model.CompositeSSM) – state space model of the system
input_size (int) – size of the input feature space: features for RNN part.
encoder_length (int) – encoder length
decoder_length (int) – decoder length
num_layers (int) – number of layers in RNN
n_samples (int) – number of samples to use in predictions generation
lr (float) – learning rate
train_batch_size (int) – batch size for training
test_batch_size (int) – batch size for testing
optimizer_params (Optional[dict]) – parameters for the Adam optimizer (API reference: torch.optim.Adam)
trainer_params (Optional[dict]) – PyTorch Lightning trainer parameters (API reference: pytorch_lightning.trainer.trainer.Trainer)
train_dataloader_params (Optional[dict]) – parameters for the train dataloader, such as a sampler (API reference: torch.utils.data.DataLoader)
test_dataloader_params (Optional[dict]) – parameters for the test dataloader
val_dataloader_params (Optional[dict]) – parameters for the validation dataloader
split_params (Optional[dict]) – dictionary with parameters for torch.utils.data.random_split() for train-test splitting:
- train_size: (float) value from 0 to 1, the fraction of samples to use for training
- generator: (Optional[torch.Generator]) generator for reproducible train-test splitting
- torch_dataset_size: (Optional[int]) number of samples in the dataset, for datasets that do not implement __len__
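The split_params keys described above can be illustrated with a minimal pure-Python sketch of how a train_size fraction might translate into the two lengths handed to torch.utils.data.random_split(). The helper name split_lengths and the truncating conversion are assumptions made for illustration only; the library's own rounding and validation may differ.

```python
from typing import Optional, Tuple


def split_lengths(train_size: float, dataset_size: Optional[int]) -> Tuple[int, int]:
    """Convert a train fraction into (train, held-out) sample counts.

    Hypothetical helper: assumes the fraction is truncated toward zero
    and that the remainder goes to the held-out part.
    """
    if dataset_size is None:
        # Mirrors the documented need for torch_dataset_size when the
        # dataset does not implement __len__.
        raise ValueError("dataset size is required when the dataset has no __len__")
    if not 0.0 <= train_size <= 1.0:
        raise ValueError("train_size must be between 0 and 1")
    n_train = int(dataset_size * train_size)
    return n_train, dataset_size - n_train


# Illustrative split_params dictionary using the documented keys.
# The values are made up; generator=None stands for "no fixed generator".
split_params = {"train_size": 0.75, "generator": None, "torch_dataset_size": 1000}

lengths = split_lengths(split_params["train_size"], split_params["torch_dataset_size"])
print(lengths)
```

In practice a dictionary of this shape would be passed as the split_params argument of the constructor. Supplying a seeded torch.Generator under the generator key is what makes the split reproducible.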
Methods
fit(ts) – Fit model.
forecast(ts, prediction_size[, ...]) – Make predictions.
get_model() – Get model.
load(path) – Load an object.
params_to_tune() – Get grid for tuning hyperparameters.
predict(ts, prediction_size[, return_components]) – Make predictions.
raw_fit(torch_dataset) – Fit model on a torch-like Dataset.
raw_predict(torch_dataset) – Make inference on a torch-like Dataset.
save(path) – Save the object.
set_params(**params) – Return new object instance with modified parameters.
to_dict() – Collect all information about the etna object in a dict.
Attributes
context_size – Context size of the model.