Skip to content

Commit 7d2fe63

Browse files
committed
Add prediction type to return the mean, variance, and (if implemented) mode
Additionally, this change avoids unnecessary sampling if the prediction type doesn't need it.
1 parent a99ae90 commit 7d2fe63

File tree

2 files changed

+25
-6
lines changed

2 files changed

+25
-6
lines changed

lightgbmlss/distributions/distribution_utils.py

+24-6
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,7 @@ def predict_dist(self,
364364
- "quantile" calculates the quantiles from the predicted distribution.
365365
- "parameters" returns the predicted distributional parameters.
366366
- "expectiles" returns the predicted expectiles.
367+
- "properties" returns the mean, variance, and (if implemented) mode.
367368
n_samples : int
368369
Number of samples to draw from the predicted distribution.
369370
quantiles : List[float]
@@ -402,18 +403,35 @@ def predict_dist(self,
402403
dist_params_predt = pd.DataFrame(dist_params_predt)
403404
dist_params_predt.columns = self.param_dict.keys()
404405

405-
# Draw samples from predicted response distribution
406-
pred_samples_df = self.draw_samples(predt_params=dist_params_predt,
407-
n_samples=n_samples,
408-
seed=seed)
409-
410406
if pred_type == "parameters":
411407
return dist_params_predt
412408

413409
elif pred_type == "expectiles":
414410
return dist_params_predt
411+
412+
elif pred_type == "properties":
413+
if self.tau is None:
414+
pred_params = torch.tensor(dist_params_predt.values)
415+
dist_kwargs = {arg_name: param for arg_name, param in zip(self.distribution_arg_names, pred_params.T)}
416+
dist_pred = self.distribution(**dist_kwargs)
417+
pred_props = pd.DataFrame({"mean": dist_pred.mean.detach().numpy(),
418+
"variance": dist_pred.variance.detach().numpy()})
419+
try:
420+
dist_pred.mode
421+
except NotImplementedError:
422+
pass
423+
else:
424+
pred_props["mode"] = dist_pred.mode.detach().numpy()
425+
return pred_props
426+
else:
427+
raise ValueError("Invalid prediction type.")
428+
429+
# Draw samples from predicted response distribution
430+
pred_samples_df = self.draw_samples(predt_params=dist_params_predt,
431+
n_samples=n_samples,
432+
seed=seed)
415433

416-
elif pred_type == "samples":
434+
if pred_type == "samples":
417435
return pred_samples_df
418436

419437
elif pred_type == "quantiles":

lightgbmlss/model.py

+1
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,7 @@ def predict(self,
439439
- "quantile" calculates the quantiles from the predicted distribution.
440440
- "parameters" returns the predicted distributional parameters.
441441
- "expectiles" returns the predicted expectiles.
442+
- "properties" returns the mean, variance, and (if implemented) mode.
442443
n_samples : int
443444
Number of samples to draw from the predicted distribution.
444445
quantiles : List[float]

0 commit comments

Comments
 (0)