Hi, I'm using a deep learning (DL) model in TensorFlow to predict the daily minimum, mean, and maximum values of a target dataset. My plan is for the model to have three outputs for each day (min, mean, max).
Is there a clean way to enforce the correct order of these outputs (i.e., min < mean < max)? I could add a penalty to the loss to encourage the model to respect the order, but that seems like a bit of a workaround.
There are two techniques: a penalty and a variable transformation.
Penalty: build one model with these three outputs, then customize its loss function by adding a penalty for violating the ordering. This does not guarantee the inequalities, but it makes violations very unlikely.
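You can simply add
$$-\lambda\left[\min(0,\,y_2-y_1)+\min(0,\,y_3-y_2)\right]=\lambda\left[\operatorname{ReLU}(y_1-y_2)+\operatorname{ReLU}(y_2-y_3)\right]$$
to the loss, where $\lambda$ is a hyperparameter reflecting how strongly you want to enforce the conditions, and $y_1, y_2, y_3$ are your min, mean and max outputs. I'm using the ReLU function here, but you can use any non-negative function that is zero when the ordering holds and grows with the size of the violation.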
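A minimal sketch of this penalty as a custom Keras loss, assuming the model's last layer outputs a `(batch, 3)` tensor with columns ordered (min, mean, max) and that mean squared error is the base loss. The function name `make_ordered_penalty_loss` and its `lam` argument (standing in for $\lambda$) are hypothetical, chosen for illustration only.

```python
import tensorflow as tf


def make_ordered_penalty_loss(lam=1.0):
    """Return a Keras-compatible loss: MSE plus an order-violation penalty.

    Assumes y_true and y_pred have shape (batch, 3) with columns
    (min, mean, max). `lam` plays the role of lambda in the text.
    """
    def loss(y_true, y_pred):
        y_true = tf.cast(y_true, y_pred.dtype)
        # Ordinary per-sample mean squared error over the three outputs.
        mse = tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)
        y1, y2, y3 = y_pred[..., 0], y_pred[..., 1], y_pred[..., 2]
        # ReLU(y1 - y2) > 0 only when min > mean; ReLU(y2 - y3) > 0 only when mean > max.
        penalty = tf.nn.relu(y1 - y2) + tf.nn.relu(y2 - y3)
        return mse + lam * penalty

    return loss
```

It would be passed to Keras as `model.compile(optimizer="adam", loss=make_ordered_penalty_loss(lam=1.0))`; a larger `lam` trades some fit for fewer order violations, but never rules them out entirely.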
Variable transformation: I use this technique in similar situations. Here's how it goes. Create new variables:
$$y'_1=y_2,\qquad y'_2=\ln(y_2-y_1),\qquad y'_3=\ln(y_3-y_2)$$
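The transformation of the targets can be sketched in NumPy as below, assuming the targets are stored as an array of shape `(n, 3)` with columns (min, mean, max); `to_unconstrained` is a hypothetical helper name. Note that the logarithm requires strict inequalities in the training data, so a day where, say, the min equals the mean would need special handling.

```python
import numpy as np


def to_unconstrained(y):
    """Map targets (min, mean, max) to (mean, ln(mean - min), ln(max - mean)).

    y: array-like of shape (n, 3) with columns (min, mean, max).
    Raises ValueError unless min < mean < max holds for every row,
    because the logarithm of a non-positive gap is undefined.
    """
    y = np.asarray(y, dtype=float)
    y_min, y_mean, y_max = y[:, 0], y[:, 1], y[:, 2]
    if np.any(y_mean <= y_min) or np.any(y_max <= y_mean):
        raise ValueError("targets must satisfy min < mean < max")
    return np.stack(
        [y_mean, np.log(y_mean - y_min), np.log(y_max - y_mean)], axis=1
    )
```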
Now you can fit an unconstrained model to the new variables, then reconstruct the outputs as
$$y_1=y'_1-e^{y'_2},\qquad y_2=y'_1,\qquad y_3=y'_1+e^{y'_3}$$
The reconstructed outputs are then guaranteed to satisfy the required inequalities, because the exponential is strictly positive.
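A matching sketch of the reconstruction step, under the same column-order assumption; `from_unconstrained` is a hypothetical name, and the three columns of `z` are taken to be the model's predictions of $y'_1, y'_2, y'_3$.

```python
import numpy as np


def from_unconstrained(z):
    """Reconstruct (min, mean, max) from unconstrained predictions.

    z: array-like of shape (n, 3) with columns (y'_1, y'_2, y'_3).
    """
    z = np.asarray(z, dtype=float)
    mean = z[:, 0]
    y_min = mean - np.exp(z[:, 1])  # exp(.) > 0, so y_min < mean
    y_max = mean + np.exp(z[:, 2])  # exp(.) > 0, so y_max > mean
    return np.stack([y_min, mean, y_max], axis=1)
```

In floating point, a very negative $y'_2$ or $y'_3$ can make the exponential so small that the gap rounds away, so the strict inequality holds only up to rounding.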
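There are variations; e.g., you can transform min, mean and max into mean, range and mean/range, etc., which can be more stable. You can also replace the exponential with any other strictly positive function, such as softplus. ReLU, as noted in the comments, also works if non-strict inequalities are acceptable, since it is zero for non-positive inputs.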
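One variant, which is not exactly the approach described above, is to put the reconstruction inside the network as its last layer, so the model emits ordered (min, mean, max) values directly and can be trained with an ordinary loss on the original targets. The sketch below assumes TensorFlow 2.x with Keras and uses softplus as the strictly positive function; the layer name `OrderedMinMeanMax`, the helper `build_model`, and its hidden-layer size are hypothetical.

```python
import tensorflow as tf


class OrderedMinMeanMax(tf.keras.layers.Layer):
    """Map 3 unconstrained values to (min, mean, max) with min < mean < max.

    Uses softplus, a strictly positive function, in place of exp.
    """

    def call(self, z):
        mean = z[..., 0:1]
        # Subtract/add a positive gap so the ordering holds by construction.
        y_min = mean - tf.nn.softplus(z[..., 1:2])
        y_max = mean + tf.nn.softplus(z[..., 2:3])
        return tf.concat([y_min, mean, y_max], axis=-1)


def build_model(n_features):
    """Hypothetical small regression model ending in the ordering layer."""
    inputs = tf.keras.Input(shape=(n_features,))
    h = tf.keras.layers.Dense(32, activation="relu")(inputs)
    z = tf.keras.layers.Dense(3)(h)  # unconstrained (y'_1, y'_2, y'_3)
    outputs = OrderedMinMeanMax()(z)
    return tf.keras.Model(inputs, outputs)
```

Because the loss is then computed on the original scale, this variant avoids fitting to logarithms, at the cost of no longer fitting the transformed variables directly.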
This may look like the better technique, but it has its own issues. The main one is that fitting on a logarithmic scale can produce very wild forecasts once you exponentiate back. That is one reason not to transform the mean itself, and to transform only the min and max, as distances from the mean. This way we may at least get a reasonable mean forecast, and possibly crazy min and max forecasts, which are expected to be lousy anyway.
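To see why, write hats for predictions, so the reconstructed mean is $\hat y_2 = \hat y'_1$ and the reconstructed max is $\hat y_3 = \hat y'_1 + e^{\hat y'_3}$. If the log-scale prediction for the max gap is off by $\delta$, i.e. $\hat y'_3 = y'_3 + \delta$ with $y'_3 = \ln(y_3 - y_2)$, then

$$\hat y_3-\hat y_2 = e^{y'_3+\delta} = e^{y'_3}\,e^{\delta} = (y_3-y_2)\,e^{\delta}$$

So an additive error on the log scale becomes a multiplicative error on the gap: $\delta = 2$ inflates the max-minus-mean distance by a factor of $e^2 \approx 7.4$, whereas an error in the untransformed mean stays additive.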
Another thing to be aware of is that the mean forecast should usually be expected to have lower variance than the min and max forecasts. Therefore, you may want to adjust your loss function to allow the min and max to have larger forecast errors than the mean.
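One simple way to make that accommodation, assuming a single `(batch, 3)` output with columns (min, mean, max) as in the question, is a weighted mean squared error that gives the min and max columns smaller weights than the mean. The weights `(0.5, 1.0, 0.5)` and the name `make_weighted_mse` are illustrative assumptions, not recommended values.

```python
import tensorflow as tf


def make_weighted_mse(weights=(0.5, 1.0, 0.5)):
    """MSE with per-output weights for columns (min, mean, max).

    Smaller weights on min/max tolerate larger errors there than on the mean.
    The default weights are illustrative, not tuned values.
    """
    def loss(y_true, y_pred):
        y_true = tf.cast(y_true, y_pred.dtype)
        w = tf.constant(weights, dtype=y_pred.dtype)
        sq = tf.square(y_true - y_pred)
        # Weighted average of squared errors over the three outputs.
        return tf.reduce_sum(w * sq, axis=-1) / tf.reduce_sum(w)

    return loss
```

If min, mean and max are instead separate outputs of a Keras model, `Model.compile` accepts a `loss_weights` argument that serves the same purpose.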