from abc import ABC
from abc import abstractmethod
from typing import Optional
import pandas as pd
from scipy.sparse import csr_matrix
from etna.core import BaseMixin
from etna.datasets import TSDataset
from etna.datasets.utils import get_level_dataframe
class BaseReconciliator(ABC, BaseMixin):
    """Base class to hold reconciliation methods.

    Subclasses implement :meth:`fit`; ``reconcile`` requires ``mapping_matrix``
    to be set (it is ``None`` after construction). The matrix is then used to
    map forecasts from ``source_level`` to ``target_level``.
    """

    def __init__(self, target_level: str, source_level: str):
        """Init BaseReconciliator.

        Parameters
        ----------
        target_level:
            Level to be reconciled from the forecasts.
        source_level:
            Level to be forecasted.
        """
        self.target_level = target_level
        self.source_level = source_level
        # Stays ``None`` until set; ``reconcile`` treats ``None`` as "not fitted".
        self.mapping_matrix: Optional[csr_matrix] = None

    @abstractmethod
    def fit(self, ts: TSDataset) -> "BaseReconciliator":
        """Fit the reconciliator parameters.

        Parameters
        ----------
        ts:
            TSDataset on the level which is lower or equal to ``target_level``, ``source_level``.

        Returns
        -------
        :
            Fitted instance of reconciliator.
        """
        pass

    def aggregate(self, ts: TSDataset) -> TSDataset:
        """Aggregate the dataset to the ``source_level``.

        Parameters
        ----------
        ts:
            TSDataset on the level which is lower or equal to ``source_level``.

        Returns
        -------
        :
            TSDataset on the ``source_level``.
        """
        ts_aggregated = ts.get_level_dataset(target_level=self.source_level)
        return ts_aggregated

    def reconcile(self, ts: TSDataset) -> TSDataset:
        """Reconcile the forecasts in the dataset.

        Parameters
        ----------
        ts:
            TSDataset on the ``source_level``.

        Returns
        -------
        :
            TSDataset on the ``target_level``.

        Raises
        ------
        ValueError:
            If ``mapping_matrix`` is ``None`` (reconciliator is not fitted).
        ValueError:
            If the passed dataset has no hierarchical structure.
        ValueError:
            If the passed dataset is not on the ``source_level``.
        """
        if self.mapping_matrix is None:
            raise ValueError("Reconciliator is not fitted!")

        if ts.hierarchical_structure is None:
            raise ValueError("Passed dataset has no hierarchical structure!")

        if ts.current_df_level != self.source_level:
            raise ValueError(f"Dataset should be on the {self.source_level} level!")

        current_level_segments = ts.hierarchical_structure.get_level_segments(level_name=self.source_level)
        target_level_segments = ts.hierarchical_structure.get_level_segments(level_name=self.target_level)

        components_names = ts.target_components_names
        # Quantiles and components are mapped together with the target itself.
        target_names = ts.target_quantiles_names + components_names + ("target",)

        df_reconciled = get_level_dataframe(
            df=ts.to_pandas(features=target_names),
            mapping_matrix=self.mapping_matrix,
            source_level_segments=current_level_segments,
            target_level_segments=target_level_segments,
        )

        # Take the component columns out before they are dropped below, so they
        # can be attached to the new dataset via ``add_target_components``.
        target_components_df = df_reconciled.loc[:, pd.IndexSlice[:, components_names]]
        # Guard carried over from the original code, which notes it is needed
        # for pandas >=1.1, <1.2 (presumably ``drop`` with an empty list fails there).
        if len(components_names) > 0:
            df_reconciled = df_reconciled.drop(columns=list(components_names), level="feature")

        ts_reconciled = TSDataset(
            df=df_reconciled,
            freq=ts.freq,
            df_exog=ts.df_exog,
            known_future=ts.known_future,
            hierarchical_structure=ts.hierarchical_structure,
        )

        if len(components_names) > 0:
            ts_reconciled.add_target_components(target_components_df=target_components_df)
        return ts_reconciled