PatchTSNet

class PatchTSNet(encoder_length: int, patch_len: int, stride: int, num_layers: int, hidden_size: int, feedforward_size: int, nhead: int, lr: float, loss: torch.nn.modules.module.Module, optimizer_params: Optional[dict])[source]

Bases: etna.models.base.DeepBaseNet

PatchTS-based Lightning module.

Initialize PatchTS.

Parameters
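  • encoder_length (int) – encoder length, i.e. the number of past time steps fed to the model

  • patch_len (int) – patch length, i.e. the number of time steps in each patch

  • stride (int) – step between the starts of consecutive patches

  • num_layers (int) – number of transformer layers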

  • hidden_size (int) – size of the hidden state

  • feedforward_size (int) – size of feedforward layers in transformer

  • nhead (int) – number of attention heads in the transformer

  • lr (float) – learning rate

  • loss (torch.nn.Module) – loss function

  • optimizer_params (Optional[dict]) – parameters for the Adam optimizer (see the torch.optim.Adam API reference)

Return type

None
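How encoder_length, patch_len and stride interact can be illustrated with a small sketch. It assumes the encoder window is cut into overlapping patches without any padding, which may differ from what PatchTSNet does internally; the helper names patch_starts and split_into_patches are hypothetical and not part of etna.

```python
from typing import List, Sequence


def patch_starts(encoder_length: int, patch_len: int, stride: int) -> List[int]:
    """Start offsets of the patches in a window of `encoder_length` points.

    Assumption: no padding, so a patch is kept only if it fits entirely
    inside the window.
    """
    if stride <= 0:
        raise ValueError("stride must be positive")
    if patch_len > encoder_length:
        raise ValueError("patch_len must not exceed encoder_length")
    return list(range(0, encoder_length - patch_len + 1, stride))


def split_into_patches(window: Sequence[float], patch_len: int, stride: int) -> List[List[float]]:
    """Cut one encoder window into (possibly overlapping) patches."""
    starts = patch_starts(len(window), patch_len, stride)
    return [list(window[s:s + patch_len]) for s in starts]


# A window of 10 points, patches of 4 points, a new patch every 2 points.
window = list(range(10))
patches = split_into_patches(window, patch_len=4, stride=2)
num_patches = (len(window) - 4) // 2 + 1  # closed-form count under the same assumption
```

Under this assumption the number of patches is (encoder_length - patch_len) // stride + 1, so a smaller stride produces more patches, with more overlap between them.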

Methods

configure_optimizers()

Optimizer configuration.

forward(x, *args, **kwargs)

Forward pass.

make_samples(df, encoder_length, decoder_length)

Make samples from segment DataFrame.

step(batch, *args, **kwargs)

Step for loss computation for training or validation.

configure_optimizers() → torch.optim.optimizer.Optimizer[source]

Optimizer configuration.

Return type

torch.optim.optimizer.Optimizer
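As a rough illustration of how lr and optimizer_params might be combined, the sketch below builds a keyword dictionary for an Adam optimizer. The helper build_adam_kwargs is hypothetical, and the merge rule (entries in optimizer_params take precedence over lr) is an assumption, not the documented behavior of configure_optimizers.

```python
from typing import Any, Dict, Optional


def build_adam_kwargs(lr: float, optimizer_params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Merge the learning rate with optional extra Adam arguments.

    Assumption: `optimizer_params` may be None, and its entries override `lr`
    on a key clash. The real module may handle both cases differently.
    """
    kwargs: Dict[str, Any] = {"lr": lr}
    kwargs.update(optimizer_params or {})
    return kwargs


# weight_decay and betas are genuine torch.optim.Adam arguments.
adam_kwargs = build_adam_kwargs(1e-3, {"weight_decay": 0.01, "betas": (0.9, 0.99)})
```

A dictionary built this way could then be unpacked into the optimizer, for example torch.optim.Adam(model.parameters(), **adam_kwargs), where model stands for any torch module.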

forward(x: etna.models.nn.patchts.PatchTSBatch, *args, **kwargs)[source]

Forward pass.

Parameters

x (etna.models.nn.patchts.PatchTSBatch) – batch of data

Returns

forecast with shape (batch_size, decoder_length, 1)

make_samples(df: pandas.core.frame.DataFrame, encoder_length: int, decoder_length: int) → Iterator[dict][source]

Make samples from segment DataFrame.

Parameters
  • df (pandas.core.frame.DataFrame) – DataFrame with the data of a single segment

  • encoder_length (int) – encoder length

  • decoder_length (int) – decoder length

Return type

Iterator[dict]
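The general idea of turning one segment into training samples can be sketched with a sliding window. This is not the library's implementation: a plain list stands in for the DataFrame's target column, and the function name make_samples_sketch and the dictionary keys encoder_target and decoder_target are assumptions chosen for illustration.

```python
from typing import Dict, Iterator, List


def make_samples_sketch(
    target: List[float], encoder_length: int, decoder_length: int
) -> Iterator[Dict[str, List[float]]]:
    """Yield one sample per window position.

    Each sample pairs `encoder_length` past values with the `decoder_length`
    values that follow them. Key names are hypothetical.
    """
    total_length = encoder_length + decoder_length
    for start in range(len(target) - total_length + 1):
        split = start + encoder_length
        yield {
            "encoder_target": target[start:split],
            "decoder_target": target[split:start + total_length],
        }


# Six observations, 3 encoder steps and 2 decoder steps give two windows.
samples = list(make_samples_sketch([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], encoder_length=3, decoder_length=2))
```

A series shorter than encoder_length + decoder_length yields no samples in this sketch, since no complete window fits.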

step(batch: etna.models.nn.patchts.PatchTSBatch, *args, **kwargs)[source]

Step for loss computation for training or validation.

Parameters

batch (etna.models.nn.patchts.PatchTSBatch) – batch of data

Returns

loss, true_target, prediction_target