# Source code for etna.transforms.outliers.base

from abc import ABC
from abc import abstractmethod
from typing import Dict
from typing import List
from typing import Optional

import numpy as np
import pandas as pd

from etna.datasets import TSDataset
from etna.transforms.base import ReversibleTransform
from etna.transforms.utils import check_new_segments


class OutliersTransform(ReversibleTransform, ABC):
    """Finds outliers in specific columns of DataFrame and replaces it with NaNs."""

    def __init__(self, in_column: str):
        """
        Create instance of OutliersTransform.

        Parameters
        ----------
        in_column:
            name of processed column
        """
        super().__init__(required_features=[in_column])
        self.in_column = in_column
        # segment -> timestamps flagged by ``detect_outliers``; filled in ``_fit``.
        self.outliers_timestamps: Optional[Dict[str, List[pd.Timestamp]]] = None
        # segment -> series of ``in_column`` values at the outlier timestamps, captured in ``_fit``
        # so that ``_inverse_transform`` can put them back.
        self.original_values: Optional[Dict[str, pd.Series]] = None
        # Segments seen during ``_fit``; used to reject unseen segments in (inverse) transform.
        self._fit_segments: Optional[List[str]] = None

    def get_regressors_info(self) -> List[str]:
        """Return the list with regressors created by the transform.

        Returns
        -------
        :
            List with regressors created by the transform.
        """
        return []

    def _save_original_values(self, ts: TSDataset):
        """
        Save values to be replaced with NaNs.

        Parameters
        ----------
        ts:
            original TSDataset

        Raises
        ------
        ValueError:
            If outliers were not detected yet (``outliers_timestamps`` is ``None``).
        """
        if self.outliers_timestamps is None:
            raise ValueError("Something went wrong during outliers detection stage! Check the transform parameters.")
        self.original_values = dict()
        for segment, timestamps in self.outliers_timestamps.items():
            segment_ts = ts[:, segment, :]
            # Keep only the rows at outlier timestamps and only ``in_column`` -> a timestamp-indexed series.
            segment_values = segment_ts[segment_ts.index.isin(timestamps)].droplevel("segment", axis=1)[self.in_column]
            self.original_values[segment] = segment_values

    def _fit(self, df: pd.DataFrame) -> "OutliersTransform":
        """
        Find outliers using detection method.

        Parameters
        ----------
        df:
            dataframe with series to find outliers

        Returns
        -------
        result:
            OutliersTransform instance with saved outliers
        """
        ts = TSDataset(df, freq=pd.infer_freq(df.index))
        self.outliers_timestamps = self.detect_outliers(ts)
        self._save_original_values(ts)
        self._fit_segments = ts.segments
        return self

    def _transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Replace found outliers with NaNs.

        Parameters
        ----------
        df:
            transform ``in_column`` series of given dataframe

        Returns
        -------
        result:
            dataframe with ``in_column`` series with outliers replaced by NaNs

        Raises
        ------
        ValueError:
            If transform isn't fitted.
        NotImplementedError:
            If there are segments that weren't present during training.
        """
        if self.outliers_timestamps is None:
            raise ValueError("Transform is not fitted! Fit the Transform before calling transform method.")
        segments = df.columns.get_level_values("segment").unique().tolist()
        check_new_segments(transform_segments=segments, fit_segments=self._fit_segments)

        for segment in segments:
            # Locate only the outlier timestamps that are actually present in ``df``;
            # a segment without detected outliers is a no-op.
            segment_outliers_timestamps = df.index.intersection(self.outliers_timestamps.get(segment, []))
            # ``np.nan`` instead of the ``np.NaN`` alias, which was removed in NumPy 2.0.
            df.loc[segment_outliers_timestamps, pd.IndexSlice[segment, self.in_column]] = np.nan
        return df

    def _inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Inverse transformation. Returns back deleted values.

        Parameters
        ----------
        df:
            data to transform

        Returns
        -------
        result:
            data with reconstructed values

        Raises
        ------
        ValueError:
            If transform isn't fitted.
        NotImplementedError:
            If there are segments that weren't present during training.
        """
        if self.original_values is None or self.outliers_timestamps is None:
            raise ValueError("Transform is not fitted! Fit the Transform before calling inverse_transform method.")
        segments = df.columns.get_level_values("segment").unique().tolist()
        check_new_segments(transform_segments=segments, fit_segments=self._fit_segments)

        for segment in segments:
            segment_original_values = self.original_values.get(segment)
            if segment_original_values is None:
                # Nothing was saved for this segment, so there is nothing to restore.
                continue
            # Restore only timestamps that were saved during fit and are present in ``df``.
            timestamps_to_restore = df.index.intersection(segment_original_values.index)
            # Write through ``df.loc`` rather than mutating ``df[segment, self.in_column]``:
            # that column selection may be a copy (always under pandas Copy-on-Write), in which
            # case a chained assignment would silently leave ``df`` unchanged.
            df.loc[timestamps_to_restore, pd.IndexSlice[segment, self.in_column]] = segment_original_values.loc[
                timestamps_to_restore
            ].values
        return df

    @abstractmethod
    def detect_outliers(self, ts: TSDataset) -> Dict[str, List[pd.Timestamp]]:
        """Call function for detection outliers with self parameters.

        Parameters
        ----------
        ts:
            dataset to process

        Returns
        -------
        :
            dict of outliers in format {segment: [outliers_timestamps]}
        """
        pass