@@ -364,6 +364,7 @@ def predict_dist(self,
364
364
- "quantile" calculates the quantiles from the predicted distribution.
365
365
- "parameters" returns the predicted distributional parameters.
366
366
- "expectiles" returns the predicted expectiles.
367
+ - "properties" returns the mean, variance, and (if implemented) mode.
367
368
n_samples : int
368
369
Number of samples to draw from the predicted distribution.
369
370
quantiles : List[float]
@@ -402,18 +403,35 @@ def predict_dist(self,
402
403
dist_params_predt = pd .DataFrame (dist_params_predt )
403
404
dist_params_predt .columns = self .param_dict .keys ()
404
405
405
- # Draw samples from predicted response distribution
406
- pred_samples_df = self .draw_samples (predt_params = dist_params_predt ,
407
- n_samples = n_samples ,
408
- seed = seed )
409
-
410
406
if pred_type == "parameters" :
411
407
return dist_params_predt
412
408
413
409
elif pred_type == "expectiles" :
414
410
return dist_params_predt
411
+
412
+ elif pred_type == "properties" :
413
+ if self .tau is None :
414
+ pred_params = torch .tensor (dist_params_predt .values )
415
+ dist_kwargs = {arg_name : param for arg_name , param in zip (self .distribution_arg_names , pred_params .T )}
416
+ dist_pred = self .distribution (** dist_kwargs )
417
+ pred_props = pd .DataFrame ({"mean" : dist_pred .mean .detach ().numpy (),
418
+ "variance" : dist_pred .variance .detach ().numpy ()})
419
+ try :
420
+ dist_pred .mode
421
+ except NotImplementedError :
422
+ pass
423
+ else :
424
+ pred_props ["mode" ] = dist_pred .mode .detach ().numpy ()
425
+ return pred_props
426
+ else :
427
+ raise ValueError ("Invalid prediction type." )
428
+
429
+ # Draw samples from predicted response distribution
430
+ pred_samples_df = self .draw_samples (predt_params = dist_params_predt ,
431
+ n_samples = n_samples ,
432
+ seed = seed )
415
433
416
- elif pred_type == "samples" :
434
+ if pred_type == "samples" :
417
435
return pred_samples_df
418
436
419
437
elif pred_type == "quantiles" :
0 commit comments