Skip to content

Commit 0c2e7f6

Browse files
committed
Add prediction type to return the mean, variance, and mode
Additionally, this change avoids unnecessary sampling if the prediction type doesn't need it.
1 parent ed939a8 commit 0c2e7f6

File tree

2 files changed

+99
-6
lines changed

2 files changed

+99
-6
lines changed

lightgbmlss/distributions/distribution_utils.py

+96-6
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numpy as np
88
import pandas as pd
99
from tqdm import tqdm
10+
from scipy import stats
1011

1112
from typing import Any, Dict, Optional, List, Tuple
1213
import matplotlib.pyplot as plt
@@ -334,12 +335,89 @@ def draw_samples(self,
334335
dist_samples = dist_samples.astype(int)
335336

336337
return dist_samples
338+
339+
def get_moments(self,
340+
predt_params: pd.DataFrame,
341+
inference: str = "none",
342+
n_samples: int = 1000,
343+
seed: int = 123
344+
) -> pd.DataFrame:
345+
"""
346+
Function that returns moments (mean, variance, mode) of a predicted distribution.
347+
348+
Arguments
349+
---------
350+
predt_params: pd.DataFrame
351+
pd.DataFrame with predicted distributional parameters.
352+
inference: str
353+
Type of inference from drawn samples:
354+
- "none" (default) Will return only the exact, implemented moments.
355+
- "missing" Will infer moments for missing implementations by drawing samples.
356+
- "all" Will infer all moments by drawing samples.
357+
n_samples: int
358+
Number of sample to draw from predicted response distribution.
359+
seed: int
360+
Manual seed.
361+
362+
Returns
363+
-------
364+
pred_dist: pd.DataFrame
365+
DataFrame with mean, variance, and mode of predicted response distribution.
366+
367+
"""
368+
if self.tau is None:
369+
pred_params = torch.tensor(predt_params.values)
370+
dist_kwargs = {arg_name: param for arg_name, param in zip(self.distribution_arg_names, pred_params.T)}
371+
dist_pred = self.distribution(**dist_kwargs)
372+
pred_moments = pd.DataFrame()
373+
374+
if inference != "none":
375+
torch.manual_seed(seed)
376+
dist_samples = dist_pred.sample((n_samples,)).squeeze().detach().numpy().T
377+
378+
if inference == "all":
379+
pred_moments["mean"] = np.mean(dist_samples, axis=1)
380+
pred_moments["variance"] = np.var(dist_samples, axis=1)
381+
pred_moments["mode"], _ = stats.mode(dist_samples, axis=1, keepdims=True)
382+
return pred_moments
383+
384+
try:
385+
mean = dist_pred.mean
386+
except NotImplementedError:
387+
if inference == "missing":
388+
pred_moments["mean"] = np.mean(dist_samples, axis=1)
389+
else:
390+
pred_moments["mean"] = mean.detach().numpy()
391+
392+
try:
393+
variance = dist_pred.variance
394+
except NotImplementedError:
395+
if inference == "missing":
396+
pred_moments["variance"] = np.var(dist_samples, axis=1)
397+
pass
398+
else:
399+
pred_moments["variance"] = variance.detach().numpy()
400+
try:
401+
mode = dist_pred.mode
402+
except NotImplementedError:
403+
if inference == "missing":
404+
pred_moments["mode"], _ = stats.mode(dist_samples, axis=1)
405+
else:
406+
pred_moments["mode"] = mode.detach().numpy()
407+
408+
if pred_moments.shape[1] == 0:
409+
return None
410+
else:
411+
return pred_moments
412+
else:
413+
return None
337414

338415
def predict_dist(self,
339416
booster: lgb.Booster,
340417
data: pd.DataFrame,
341418
start_values: np.ndarray,
342419
pred_type: str = "parameters",
420+
moments_inference: str = "none",
343421
n_samples: int = 1000,
344422
quantiles: list = [0.1, 0.5, 0.9],
345423
seed: str = 123
@@ -361,6 +439,12 @@ def predict_dist(self,
361439
- "quantiles" calculates the quantiles from the predicted distribution.
362440
- "parameters" returns the predicted distributional parameters.
363441
- "expectiles" returns the predicted expectiles.
442+
- "moments" returns the mean, variance, and (if implemented) mode.
443+
moments_inference: str
444+
Type of inference to use if the prediction type is "moments":
445+
- "none" (default) Will return only the exact, implemented moments.
446+
- "missing" Will infer moments for missing implementations by drawing samples.
447+
- "all" Will infer all moments by drawing samples.
364448
n_samples : int
365449
Number of samples to draw from the predicted distribution.
366450
quantiles : List[float]
@@ -398,18 +482,24 @@ def predict_dist(self,
398482
dist_params_predt = pd.DataFrame(dist_params_predt)
399483
dist_params_predt.columns = self.param_dict.keys()
400484

401-
# Draw samples from predicted response distribution
402-
pred_samples_df = self.draw_samples(predt_params=dist_params_predt,
403-
n_samples=n_samples,
404-
seed=seed)
405-
406485
if pred_type == "parameters":
407486
return dist_params_predt
408487

409488
elif pred_type == "expectiles":
410489
return dist_params_predt
490+
491+
elif pred_type == "moments":
492+
return self.get_moments(predt_params=dist_params_predt,
493+
inference=moments_inference,
494+
n_samples=n_samples,
495+
seed=seed)
496+
497+
# Draw samples from predicted response distribution
498+
pred_samples_df = self.draw_samples(predt_params=dist_params_predt,
499+
n_samples=n_samples,
500+
seed=seed)
411501

412-
elif pred_type == "samples":
502+
if pred_type == "samples":
413503
return pred_samples_df
414504

415505
elif pred_type == "quantiles":

lightgbmlss/model.py

+3
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,7 @@ def predict(self,
436436
pred_type: str = "parameters",
437437
n_samples: int = 1000,
438438
quantiles: list = [0.1, 0.5, 0.9],
439+
moments_inference: str = "none",
439440
seed: str = 123):
440441
"""
441442
Function that predicts from the trained model.
@@ -450,6 +451,7 @@ def predict(self,
450451
- "quantiles" calculates the quantiles from the predicted distribution.
451452
- "parameters" returns the predicted distributional parameters.
452453
- "expectiles" returns the predicted expectiles.
454+
- "moments" returns the mean, variance, and (if implemented) mode.
453455
n_samples : int
454456
Number of samples to draw from the predicted distribution.
455457
quantiles : List[float]
@@ -470,6 +472,7 @@ def predict(self,
470472
pred_type=pred_type,
471473
n_samples=n_samples,
472474
quantiles=quantiles,
475+
moments_inference=moments_inference,
473476
seed=seed)
474477

475478
return predt_df

0 commit comments

Comments
 (0)