Source code for etna.transforms.encoders.segment_encoder
import reprlib
from typing import List
import numpy as np
import pandas as pd
from sklearn import preprocessing
from etna.transforms.base import FutureMixin
from etna.transforms.base import IrreversibleTransform
class SegmentEncoderTransform(IrreversibleTransform, FutureMixin):
    """Attach each segment's integer label as the categorical feature 'segment_code'."""

    # Alias for pandas' IndexSlice; none of the methods of this class use it.
    idx = pd.IndexSlice

    def __init__(self):
        super().__init__(required_features=["target"])
        # Learns the segment-name -> integer-code mapping during ``_fit``.
        self._le = preprocessing.LabelEncoder()

    def _fit(self, df: pd.DataFrame) -> "SegmentEncoderTransform":
        """
        Learn the label encoding from the segment names found in the columns.

        Parameters
        ----------
        df:
            dataframe whose columns carry a "segment" level.

        Returns
        -------
        :
            This transform, fitted
        """
        self._le.fit(df.columns.get_level_values("segment"))
        return self

    def _transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Add a constant categorical 'segment_code' column to every segment.

        Parameters
        ----------
        df:
            dataframe whose columns carry a "segment" level.

        Returns
        -------
        :
            input dataframe joined with the new columns, columns sorted

        Raises
        ------
        ValueError:
            If the encoder has not been fitted yet.
        NotImplementedError:
            If ``df`` contains segments the encoder was not fitted on.
        """
        segment_names = df.columns.get_level_values("segment").unique().tolist()

        # An unfitted LabelEncoder has no ``classes_`` attribute.
        try:
            known_segments = self._le.classes_
        except AttributeError:
            raise ValueError("The transform isn't fitted!")

        new_segments = set(segment_names) - set(known_segments)
        if new_segments:
            raise NotImplementedError(
                f"This transform can't process segments that weren't present on train data: {reprlib.repr(new_segments)}"
            )

        # One code per segment, repeated down every row of the frame.
        tiled_codes = np.tile(self._le.transform(segment_names), (len(df), 1))
        code_columns = pd.MultiIndex.from_product(
            [segment_names, ["segment_code"]], names=("segment", "feature")
        )
        codes_frame = pd.DataFrame(tiled_codes, columns=code_columns, index=df.index).astype("category")

        return df.join(codes_frame).sort_index(axis=1)

    def get_regressors_info(self) -> List[str]:
        """Name the regressor columns this transform adds."""
        created_columns = ["segment_code"]
        return created_columns