DeepBaseModel

class DeepBaseModel(*, net: etna.models.base.DeepBaseNet, encoder_length: int, decoder_length: int, train_batch_size: int, test_batch_size: int, trainer_params: Optional[dict], train_dataloader_params: Optional[dict], test_dataloader_params: Optional[dict], val_dataloader_params: Optional[dict], split_params: Optional[dict])

Bases: etna.models.base.DeepBaseAbstractModel, etna.models.mixins.SaveNNMixin, etna.models.base.NonPredictionIntervalContextRequiredAbstractModel

Base class that partially implements the interface for holding deep models.

Init DeepBaseModel.

Parameters
  • net (etna.models.base.DeepBaseNet) – network to train

  • encoder_length (int) – encoder length

  • decoder_length (int) – decoder length

  • train_batch_size (int) – batch size for training

  • test_batch_size (int) – batch size for testing

  • trainer_params (Optional[dict]) – PyTorch Lightning trainer parameters (API reference: pytorch_lightning.trainer.trainer.Trainer)

  • train_dataloader_params (Optional[dict]) – parameters for the train dataloader, for example a sampler (API reference: torch.utils.data.DataLoader)

  • test_dataloader_params (Optional[dict]) – parameters for test dataloader

  • val_dataloader_params (Optional[dict]) – parameters for validation dataloader

  • split_params (Optional[dict]) –

    dictionary with parameters for torch.utils.data.random_split() for train-test splitting
    • train_size: (float) value from 0 to 1 - fraction of samples to use for training

    • generator: (Optional[torch.Generator]) - generator for reproducible train-test splitting

    • torch_dataset_size: (Optional[int]) - number of samples in the dataset, needed when the dataset does not implement __len__
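The split_params entries describe a random train-test split of the torch dataset. A minimal stdlib sketch, assuming the train part receives int(dataset_size * train_size) samples and the remainder is held out; split_indices is a hypothetical helper (not part of etna), and a seeded random.Random stands in for the optional torch.Generator:

```python
import random
from typing import List, Optional, Tuple


def split_indices(
    dataset_size: int, train_size: float, seed: Optional[int] = None
) -> Tuple[List[int], List[int]]:
    """Shuffle sample indices and cut them into train and held-out parts.

    Hypothetical helper: assumes the train part gets int(dataset_size * train_size)
    samples and the remainder is held out, mimicking a random split with explicit lengths.
    """
    if not 0.0 <= train_size <= 1.0:
        raise ValueError("train_size must be between 0 and 1")
    train_len = int(dataset_size * train_size)
    indices = list(range(dataset_size))
    # A seeded RNG plays the role of the optional torch.Generator in split_params.
    random.Random(seed).shuffle(indices)
    return indices[:train_len], indices[train_len:]


train_idx, held_out_idx = split_indices(dataset_size=10, train_size=0.8, seed=42)
print(len(train_idx), len(held_out_idx))  # 8 2
```

Passing the same seed twice yields the same split, which is the point of supplying a generator.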


Methods

fit(ts)

Fit model.

forecast(ts, prediction_size[, ...])

Make predictions.

get_model()

Get model.

load(path)

Load an object.

params_to_tune()

Get grid for tuning hyperparameters.

predict(ts, prediction_size[, return_components])

Make predictions.

raw_fit(torch_dataset)

Fit model on a torch-like Dataset.

raw_predict(torch_dataset)

Make inference on a torch-like Dataset.

save(path)

Save the object.

set_params(**params)

Return new object instance with modified parameters.

to_dict()

Collect all information about etna object in dict.

Attributes

context_size

Context size of the model.

fit(ts: etna.datasets.tsdataset.TSDataset) → etna.models.base.DeepBaseModel

Fit model.

Parameters

ts (etna.datasets.tsdataset.TSDataset) – TSDataset with features

Returns

Model after fit

Return type

etna.models.base.DeepBaseModel

forecast(ts: etna.datasets.tsdataset.TSDataset, prediction_size: int, return_components: bool = False) → etna.datasets.tsdataset.TSDataset

Make predictions.

This method will make autoregressive predictions.
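In autoregressive mode each predicted value is fed back as input for the next step. A minimal sketch of that loop with a toy one-step model (mean of the last two values); mean_of_last and autoregressive_forecast are hypothetical helpers used only to illustrate the idea, not etna's network:

```python
from typing import Callable, List, Sequence

Model = Callable[[Sequence[float]], float]


def mean_of_last(values: Sequence[float], window: int = 2) -> float:
    """Toy one-step model: predict the mean of the last `window` values."""
    tail = values[-window:]
    return sum(tail) / len(tail)


def autoregressive_forecast(context: Sequence[float], horizon: int, model: Model) -> List[float]:
    """Roll the model forward, appending each prediction to its own input."""
    history = list(context)
    predictions = []
    for _ in range(horizon):
        next_value = model(history)
        predictions.append(next_value)
        # The prediction becomes context for the next step.
        history.append(next_value)
    return predictions


print(autoregressive_forecast([1.0, 3.0], horizon=3, model=mean_of_last))  # [2.0, 2.5, 2.25]
```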

Parameters
  • ts (etna.datasets.tsdataset.TSDataset) – Dataset with features and expected decoder length for context

  • prediction_size (int) – Number of last timestamps to leave after making prediction. Previous timestamps will be used as a context.

  • return_components (bool) – If True, additionally returns forecast components

Returns

Dataset with predictions

Return type

etna.datasets.tsdataset.TSDataset
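The prediction_size contract splits the timeline into a context part and a kept part made of the last prediction_size timestamps. A minimal sketch on a plain list of date strings; split_context_and_output is a hypothetical helper that only illustrates which timestamps remain in the returned dataset:

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_context_and_output(
    timestamps: Sequence[T], prediction_size: int
) -> Tuple[List[T], List[T]]:
    """Return (context, kept), where kept holds the last prediction_size timestamps."""
    if not 0 < prediction_size <= len(timestamps):
        raise ValueError("prediction_size must be in [1, len(timestamps)]")
    cut = len(timestamps) - prediction_size
    return list(timestamps[:cut]), list(timestamps[cut:])


timestamps = ["2021-01-01", "2021-01-02", "2021-01-03", "2021-01-04", "2021-01-05"]
context, kept = split_context_and_output(timestamps, prediction_size=2)
print(kept)  # ['2021-01-04', '2021-01-05']
```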

get_model() → etna.models.base.DeepBaseNet

Get model.

Returns

Torch Module

Return type

etna.models.base.DeepBaseNet

predict(ts: etna.datasets.tsdataset.TSDataset, prediction_size: int, return_components: bool = False) → etna.datasets.tsdataset.TSDataset

Make predictions.

This method makes predictions using true values instead of values predicted on the previous step. It can be useful for making in-sample forecasts.

Parameters
  • ts (etna.datasets.tsdataset.TSDataset) – Dataset with features and expected decoder length for context

  • prediction_size (int) – Number of last timestamps to leave after making prediction. Previous timestamps will be used as a context.

  • return_components (bool) – If True, additionally returns prediction components

Returns

Dataset with predictions

Return type

etna.datasets.tsdataset.TSDataset
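Using true values instead of earlier predictions is often called teacher forcing. A minimal sketch with the same kind of toy one-step model (mean of the last two values); mean_of_last and teacher_forced_predict are hypothetical helpers, not etna code:

```python
from typing import Callable, List, Sequence

Model = Callable[[Sequence[float]], float]


def mean_of_last(values: Sequence[float], window: int = 2) -> float:
    """Toy one-step model: predict the mean of the last `window` values."""
    tail = values[-window:]
    return sum(tail) / len(tail)


def teacher_forced_predict(series: Sequence[float], context_len: int, model: Model) -> List[float]:
    """Predict each step from the true history only, never from earlier predictions."""
    predictions = []
    for t in range(context_len, len(series)):
        # Inputs are the observed values series[:t]; past predictions are discarded.
        predictions.append(model(series[:t]))
    return predictions


print(teacher_forced_predict([1.0, 3.0, 5.0, 7.0], context_len=2, model=mean_of_last))  # [2.0, 4.0]
```

Because every step sees observed values, errors do not accumulate across steps, which is why this mode suits in-sample evaluation.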

raw_fit(torch_dataset: torch.utils.data.dataset.Dataset) → etna.models.base.DeepBaseModel

Fit model on a torch-like Dataset.

Parameters

torch_dataset (torch.utils.data.dataset.Dataset) – torch-like dataset for fitting the model

Returns

Model after fit

Return type

etna.models.base.DeepBaseModel

raw_predict(torch_dataset: torch.utils.data.dataset.Dataset) → Dict[Tuple[str, str], numpy.ndarray]

Make inference on a torch-like Dataset.

Parameters

torch_dataset (torch.utils.data.dataset.Dataset) – torch-like dataset for model inference

Returns

Dictionary with predictions

Return type

Dict[Tuple[str, str], numpy.ndarray]
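The returned dictionary is keyed by pairs of strings. A minimal sketch that regroups such a dictionary per segment, assuming the keys are (segment, column) pairs such as (segment, "target"); plain lists replace the numpy arrays so the example runs without dependencies, and group_by_segment is a hypothetical helper:

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# Stand-in for raw_predict output; the (segment, column) key layout is an assumption.
raw_predictions: Dict[Tuple[str, str], List[float]] = {
    ("segment_a", "target"): [1.0, 2.0],
    ("segment_b", "target"): [5.0, 6.0],
}


def group_by_segment(
    preds: Dict[Tuple[str, str], List[float]]
) -> Dict[str, Dict[str, List[float]]]:
    """Turn {(segment, column): values} into {segment: {column: values}}."""
    grouped: Dict[str, Dict[str, List[float]]] = defaultdict(dict)
    for (segment, column), values in preds.items():
        grouped[segment][column] = list(values)
    return dict(grouped)


print(group_by_segment(raw_predictions))
```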

property context_size: int

Context size of the model.
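The context size is the number of past timestamps the model needs before it can produce a forecast. A minimal sketch of how encoder_length and decoder_length carve a series into training samples, assuming the encoder window plays the role of the context and the decoder window is the target horizon; make_windows is a hypothetical helper, and etna's own sample generation may differ:

```python
from typing import List, Sequence, Tuple


def make_windows(
    series: Sequence[float], encoder_length: int, decoder_length: int
) -> List[Tuple[List[float], List[float]]]:
    """Slide a window of encoder_length + decoder_length over the series."""
    window = encoder_length + decoder_length
    samples = []
    for start in range(len(series) - window + 1):
        encoder = list(series[start:start + encoder_length])  # observed context
        decoder = list(series[start + encoder_length:start + window])  # values to predict
        samples.append((encoder, decoder))
    return samples


samples = make_windows([1, 2, 3, 4, 5], encoder_length=2, decoder_length=2)
print(samples)  # [([1, 2], [3, 4]), ([2, 3], [4, 5])]
```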