Source code for etna.transforms.timestamp.holiday

import datetime
from typing import List
from typing import Optional

import holidays
import numpy as np
import pandas as pd

from etna.transforms.base import FutureMixin
from etna.transforms.base import IrreversibleTransform


class HolidayTransform(IrreversibleTransform, FutureMixin):
    """HolidayTransform generates series that indicates holidays in given dataframe."""

    def __init__(self, iso_code: str = "RUS", out_column: Optional[str] = None):
        """
        Create instance of HolidayTransform.

        Parameters
        ----------
        iso_code:
            internationally recognised codes, designated to country for which we want to find the holidays
        out_column:
            name of added column. Use ``self.__repr__()`` if not given.
        """
        super().__init__(required_features=["target"])
        self.iso_code = iso_code
        self.holidays = holidays.CountryHoliday(iso_code)
        self.out_column = out_column

    def _get_column_name(self) -> str:
        """Return ``out_column`` when it is set (non-empty), otherwise ``self.__repr__()``."""
        if self.out_column:
            return self.out_column
        else:
            return self.__repr__()

    def _fit(self, df: pd.DataFrame) -> "HolidayTransform":
        """
        Fit HolidayTransform with data from df. Does nothing in this case.

        Parameters
        ----------
        df: pd.DataFrame
            value series with index column in timestamp format

        Returns
        -------
        :
            the fitted transform itself
        """
        return self

    def _transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Transform data from df with HolidayTransform and generate a column of holidays flags.

        Parameters
        ----------
        df: pd.DataFrame
            value series with index column in timestamp format

        Returns
        -------
        :
            pd.DataFrame with added holidays; for every segment a categorical column
            ``out_column`` holding 1 on holiday timestamps and 0 otherwise

        Raises
        ------
        ValueError:
            if the step between the first two timestamps is longer than one day
        """
        # The frequency is estimated from the first two timestamps only. With fewer than
        # two rows there is no step to check, and ``df.index[1]`` would raise IndexError.
        if len(df.index) >= 2 and (df.index[1] - df.index[0]) > datetime.timedelta(days=1):
            raise ValueError("Frequency of data should be no more than daily.")

        cols = df.columns.get_level_values("segment").unique()
        out_column = self._get_column_name()

        # One flag per timestamp (1 if it is in the holiday calendar for ``iso_code``),
        # repeated as a column for every segment since holidays do not depend on the segment.
        encoded_matrix = np.array([int(x in self.holidays) for x in df.index])
        encoded_matrix = encoded_matrix.reshape(-1, 1).repeat(len(cols), axis=1)
        encoded_df = pd.DataFrame(
            encoded_matrix,
            columns=pd.MultiIndex.from_product([cols, [out_column]], names=("segment", "feature")),
            index=df.index,
        )
        encoded_df = encoded_df.astype("category")
        df = df.join(encoded_df)
        df = df.sort_index(axis=1)
        return df

    def get_regressors_info(self) -> List[str]:
        """Return the list with regressors created by the transform.

        Returns
        -------
        :
            List with regressors created by the transform.
        """
        return [self._get_column_name()]