DirectEnsemble
- class DirectEnsemble(pipelines: List[etna.pipeline.base.BasePipeline], n_jobs: int = 1, joblib_params: Optional[Dict[str, Any]] = None)
Bases: etna.ensembles.mixins.EnsembleMixin, etna.ensembles.mixins.SaveEnsembleMixin, etna.pipeline.base.BasePipeline
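DirectEnsemble is a pipeline that forecasts future values by merging the forecasts of its base pipelines.
The ensemble expects several pipelines during initialization, and these pipelines are expected to have different forecasting horizons. For each point in the future, the ensemble's forecast is the forecast of the base pipeline with the shortest horizon that covers this point. For example, with base pipelines of horizons 3 and 5, the first three points come from the horizon-3 pipeline and the remaining two from the horizon-5 pipeline.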
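The selection rule above can be illustrated with a minimal, library-free sketch. It assumes each base forecast is a plain list of values indexed by step and keyed by its pipeline's horizon; `pick_horizon` and `merge_direct` are hypothetical helpers written for illustration, not part of the etna API.

```python
from typing import Dict, List, Sequence


def pick_horizon(horizons: Sequence[int], step: int) -> int:
    """Return the shortest horizon that still covers forecast step `step` (1-based)."""
    covering = [h for h in horizons if h >= step]
    if not covering:
        raise ValueError(f"No pipeline covers step {step}")
    return min(covering)


def merge_direct(forecasts: Dict[int, List[float]]) -> List[float]:
    """Merge per-pipeline forecasts (hypothetical format: horizon -> list of h values)."""
    horizons = sorted(forecasts)
    merged = []
    # The merged forecast is as long as the longest horizon among the pipelines.
    for step in range(1, horizons[-1] + 1):
        h = pick_horizon(horizons, step)
        # Take the value for this step from the chosen pipeline's forecast.
        merged.append(forecasts[h][step - 1])
    return merged
```

With horizons 3 and 5, `merge_direct({3: [1.0, 2.0, 3.0], 5: [10.0, 20.0, 30.0, 40.0, 50.0]})` returns `[1.0, 2.0, 3.0, 40.0, 50.0]`: steps 1 to 3 come from the horizon-3 pipeline and steps 4 and 5 from the horizon-5 pipeline.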
Examples
>>> from etna.datasets import generate_ar_df
>>> from etna.datasets import TSDataset
>>> from etna.ensembles import DirectEnsemble
>>> from etna.models import NaiveModel
>>> from etna.models import ProphetModel
>>> from etna.pipeline import Pipeline
>>> df = generate_ar_df(periods=30, start_time="2021-06-01", ar_coef=[1.2], n_segments=3)
>>> df_ts_format = TSDataset.to_dataset(df)
>>> ts = TSDataset(df_ts_format, "D")
>>> prophet_pipeline = Pipeline(model=ProphetModel(), transforms=[], horizon=3)
>>> naive_pipeline = Pipeline(model=NaiveModel(lag=10), transforms=[], horizon=5)
>>> ensemble = DirectEnsemble(pipelines=[prophet_pipeline, naive_pipeline])
>>> _ = ensemble.fit(ts=ts)
>>> forecast = ensemble.forecast()
>>> forecast
segment    segment_0 segment_1 segment_2
feature       target    target    target
timestamp
2021-07-01    -10.37   -232.60    163.16
2021-07-02    -10.59   -242.05    169.62
2021-07-03    -11.41   -253.82    177.62
2021-07-04     -5.85   -139.57     96.99
2021-07-05     -6.11   -167.69    116.59
Init DirectEnsemble.
- Parameters
pipelines (List[etna.pipeline.base.BasePipeline]) – List of pipelines that should be used in the ensemble.
n_jobs (int) – Number of jobs to run in parallel.
joblib_params (Optional[Dict[str, Any]]) – Additional parameters for joblib.Parallel.
- Raises
ValueError – If two or more pipelines have the same horizon.
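Duplicate horizons are rejected because the "shortest covering horizon" rule would then no longer identify a single pipeline. A minimal sketch of such a check, assuming the horizons are available as a plain sequence of integers; `check_unique_horizons` is a hypothetical helper, and its error message is illustrative rather than the one etna raises.

```python
from typing import Sequence


def check_unique_horizons(horizons: Sequence[int]) -> None:
    """Hypothetical validation: raise ValueError if two or more pipelines share a horizon."""
    # A set drops duplicates, so a size mismatch means some horizon repeats.
    if len(set(horizons)) != len(horizons):
        raise ValueError("Pipelines must have pairwise different horizons")
```

For instance, `check_unique_horizons([3, 5])` passes silently, while `check_unique_horizons([3, 3, 5])` raises `ValueError`.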
Methods
backtest(ts, metrics[, n_folds, mode, ...]) – Run backtest with the pipeline.
fit(ts) – Fit the pipelines in the ensemble.
forecast([ts, prediction_interval, ...]) – Make a forecast of the next points of a dataset.
load(path[, ts]) – Load an object.
params_to_tune() – Get hyperparameter grid to tune.
predict(ts[, start_timestamp, ...]) – Make in-sample predictions on dataset in a given range.
save(path) – Save the object.
set_params(**params) – Return new object instance with modified parameters.
to_dict() – Collect all information about etna object in dict.
- fit(ts: etna.datasets.tsdataset.TSDataset) → etna.ensembles.direct_ensemble.DirectEnsemble
Fit the base pipelines of the ensemble.
- Parameters
ts (etna.datasets.tsdataset.TSDataset) – TSDataset to fit the ensemble on.
- Returns
Fitted ensemble.
- Return type
self
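A rough sketch of what fitting can look like: every base pipeline is fitted independently, possibly in parallel, and the ensemble itself is returned. This is an illustration under stated assumptions, not etna's implementation: `ToyPipeline` and `ToyDirectEnsemble` are hypothetical stand-ins, the data is a plain list instead of a TSDataset, the per-pipeline deep copy is an assumption, and the standard-library `ThreadPoolExecutor` replaces `joblib.Parallel` (unlike joblib, it does not accept `n_jobs=-1`).

```python
import copy
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional


class ToyPipeline:
    """Hypothetical stand-in for an etna pipeline: 'fitting' stores the mean."""

    def __init__(self, horizon: int) -> None:
        self.horizon = horizon
        self.level: Optional[float] = None

    def fit(self, ts: List[float]) -> "ToyPipeline":
        self.level = sum(ts) / len(ts)
        return self


class ToyDirectEnsemble:
    """Hypothetical ensemble that only demonstrates the fitting step."""

    def __init__(self, pipelines: List[ToyPipeline], n_jobs: int = 1) -> None:
        self.pipelines = pipelines
        self.n_jobs = n_jobs

    def fit(self, ts: List[float]) -> "ToyDirectEnsemble":
        # Each pipeline receives its own deep copy of the data, so fitting one
        # pipeline cannot mutate the input seen by another.
        with ThreadPoolExecutor(max_workers=self.n_jobs) as pool:
            # pool.map preserves the original order of the pipelines.
            self.pipelines = list(
                pool.map(lambda p: p.fit(copy.deepcopy(ts)), self.pipelines)
            )
        # Returning self mirrors the documented return type and allows chaining.
        return self
```

For example, `ToyDirectEnsemble([ToyPipeline(3), ToyPipeline(5)], n_jobs=2).fit([1.0, 2.0, 3.0])` returns the ensemble, and both of its pipelines end up with `level == 2.0`.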
- params_to_tune() → Dict[str, etna.distributions.distributions.BaseDistribution]
Get hyperparameter grid to tune.
Not implemented for this class.
- Returns
Grid with hyperparameters.
- Return type
Dict[str, etna.distributions.distributions.BaseDistribution]