7
7
import numpy as np
8
8
import pandas as pd
9
9
from tqdm import tqdm
10
+ from scipy import stats
10
11
11
12
from typing import Any , Dict , Optional , List , Tuple
12
13
import matplotlib .pyplot as plt
@@ -334,12 +335,89 @@ def draw_samples(self,
334
335
dist_samples = dist_samples .astype (int )
335
336
336
337
return dist_samples
338
+
339
+ def get_moments (self ,
340
+ predt_params : pd .DataFrame ,
341
+ inference : str = "none" ,
342
+ n_samples : int = 1000 ,
343
+ seed : int = 123
344
+ ) -> pd .DataFrame :
345
+ """
346
+ Function that returns moments (mean, variance, mode) of a predicted distribution.
347
+
348
+ Arguments
349
+ ---------
350
+ predt_params: pd.DataFrame
351
+ pd.DataFrame with predicted distributional parameters.
352
+ inference: str
353
+ Type of inference from drawn samples:
354
+ - "none" (default) Will return only the exact, implemented moments.
355
+ - "missing" Will infer moments for missing implementations by drawing samples.
356
+ - "all" Will infer all moments by drawing samples.
357
+ n_samples: int
358
+ Number of sample to draw from predicted response distribution.
359
+ seed: int
360
+ Manual seed.
361
+
362
+ Returns
363
+ -------
364
+ pred_dist: pd.DataFrame
365
+ DataFrame with mean, variance, and mode of predicted response distribution.
366
+
367
+ """
368
+ if self .tau is None :
369
+ pred_params = torch .tensor (predt_params .values )
370
+ dist_kwargs = {arg_name : param for arg_name , param in zip (self .distribution_arg_names , pred_params .T )}
371
+ dist_pred = self .distribution (** dist_kwargs )
372
+ pred_moments = pd .DataFrame ()
373
+
374
+ if inference != "none" :
375
+ torch .manual_seed (seed )
376
+ dist_samples = dist_pred .sample ((n_samples ,)).squeeze ().detach ().numpy ().T
377
+
378
+ if inference == "all" :
379
+ pred_moments ["mean" ] = np .mean (dist_samples , axis = 1 )
380
+ pred_moments ["variance" ] = np .var (dist_samples , axis = 1 )
381
+ pred_moments ["mode" ], _ = stats .mode (dist_samples , axis = 1 , keepdims = True )
382
+ return pred_moments
383
+
384
+ try :
385
+ mean = dist_pred .mean
386
+ except NotImplementedError :
387
+ if inference == "missing" :
388
+ pred_moments ["mean" ] = np .mean (dist_samples , axis = 1 )
389
+ else :
390
+ pred_moments ["mean" ] = mean .detach ().numpy ()
391
+
392
+ try :
393
+ variance = dist_pred .variance
394
+ except NotImplementedError :
395
+ if inference == "missing" :
396
+ pred_moments ["variance" ] = np .var (dist_samples , axis = 1 )
397
+ pass
398
+ else :
399
+ pred_moments ["variance" ] = variance .detach ().numpy ()
400
+ try :
401
+ mode = dist_pred .mode
402
+ except NotImplementedError :
403
+ if inference == "missing" :
404
+ pred_moments ["mode" ], _ = stats .mode (dist_samples , axis = 1 )
405
+ else :
406
+ pred_moments ["mode" ] = mode .detach ().numpy ()
407
+
408
+ if pred_moments .shape [1 ] == 0 :
409
+ return None
410
+ else :
411
+ return pred_moments
412
+ else :
413
+ return None
337
414
338
415
def predict_dist (self ,
339
416
booster : lgb .Booster ,
340
417
data : pd .DataFrame ,
341
418
start_values : np .ndarray ,
342
419
pred_type : str = "parameters" ,
420
+ moments_inference : str = "none" ,
343
421
n_samples : int = 1000 ,
344
422
quantiles : list = [0.1 , 0.5 , 0.9 ],
345
423
seed : str = 123
@@ -361,6 +439,12 @@ def predict_dist(self,
361
439
- "quantiles" calculates the quantiles from the predicted distribution.
362
440
- "parameters" returns the predicted distributional parameters.
363
441
- "expectiles" returns the predicted expectiles.
442
+ - "moments" returns the mean, variance, and (if implemented) mode.
443
+ moments_inference: str
444
+ Type of inference to use if the prediction type is "moments":
445
+ - "none" (default) Will return only the exact, implemented moments.
446
+ - "missing" Will infer moments for missing implementations by drawing samples.
447
+ - "all" Will infer all moments by drawing samples.
364
448
n_samples : int
365
449
Number of samples to draw from the predicted distribution.
366
450
quantiles : List[float]
@@ -398,18 +482,24 @@ def predict_dist(self,
398
482
dist_params_predt = pd .DataFrame (dist_params_predt )
399
483
dist_params_predt .columns = self .param_dict .keys ()
400
484
401
- # Draw samples from predicted response distribution
402
- pred_samples_df = self .draw_samples (predt_params = dist_params_predt ,
403
- n_samples = n_samples ,
404
- seed = seed )
405
-
406
485
if pred_type == "parameters" :
407
486
return dist_params_predt
408
487
409
488
elif pred_type == "expectiles" :
410
489
return dist_params_predt
490
+
491
+ elif pred_type == "moments" :
492
+ return self .get_moments (predt_params = dist_params_predt ,
493
+ inference = moments_inference ,
494
+ n_samples = n_samples ,
495
+ seed = seed )
496
+
497
+ # Draw samples from predicted response distribution
498
+ pred_samples_df = self .draw_samples (predt_params = dist_params_predt ,
499
+ n_samples = n_samples ,
500
+ seed = seed )
411
501
412
- elif pred_type == "samples" :
502
+ if pred_type == "samples" :
413
503
return pred_samples_df
414
504
415
505
elif pred_type == "quantiles" :
0 commit comments