TransWikia.com

Quantile Regression with Tensorflow Probability

Asked on Stack Overflow on December 1, 2020

I am trying out TensorFlow Probability (TFP). I have a simple quantile regression in R, and I would like to reproduce the same results with TFP.

R Quantile Regression

library("quantreg")
mtcars_url <- "https://gist.githubusercontent.com/seankross/a412dfbd88b3db70b74b/raw/5f23f993cd87c283ce766e7ac6b329ee7cc2e1d1/mtcars.csv"
mtcars <- readr::read_csv(mtcars_url)
rqfit <- rq(mpg ~ wt, data = mtcars, tau = seq(.1, .9, by = 0.1))
predict(rqfit, mtcars, interval = c("confidence"),
          level = .95)

The output is a 32 x 9 matrix: one predicted mpg value per quantile (tau = 0.1, 0.2, ..., 0.9) for each of the 32 cars:

    tau= 0.1 tau= 0.2  tau= 0.3  tau= 0.4  tau= 0.5  tau= 0.6  tau= 0.7 tau= 0.8 tau= 0.9
1  20.493299 20.92800 21.238356 22.176471 22.338816 23.592283 24.475462 25.12302 27.97207
2  18.837113 19.33680 19.910959 20.938971 21.181250 22.313183 23.110731 23.78409 26.37404
3  22.441753 22.80000 22.800000 23.632353 23.700658 25.097106 26.081028 26.69824 29.85210
4  16.628866 17.21520 18.141096 19.288971 19.637829 20.607717 21.291089 21.99884 24.24333
5  15.167526 15.81120 16.969863 18.197059 18.616447 19.479100 20.086915 20.81743 22.83330
6  15.037629 15.68640 16.865753 18.100000 18.525658 19.378778 19.979877 20.71242 22.70797
7  14.323196 15.00000 16.293151 17.566176 18.026316 18.827010 19.391169 20.13484 22.01862
8  16.791237 17.37120 18.271233 19.410294 19.751316 20.733119 21.424886 22.13011 24.40000
9  17.051031 17.62080 18.479452 19.604412 19.932895 20.933762 21.638962 22.34014 24.65067
10 15.167526 15.81120 16.969863 18.197059 18.616447 19.479100 20.086915 20.81743 22.83330
11 15.167526 15.81120 16.969863 18.197059 18.616447 19.479100 20.086915 20.81743 22.83330
12 11.075773 11.88000 13.690411 15.139706 15.756579 16.318971 16.715226 17.50948 18.88523
13 13.284021 14.00160 15.460274 16.789706 17.300000 18.024437 18.534868 19.29472 21.01594
14 12.959278 13.68960 15.200000 16.547059 17.073026 17.773633 18.267273 19.03219 20.70260
15  3.411856  4.51680  7.547945  9.413235 10.400000 10.400000 10.400000 11.31363 11.49042
16  2.281753  3.43104  6.642192  8.568824  9.610132  9.527203  9.468772 10.40000 10.40000
17  2.794845  3.92400  7.053425  8.952206  9.968750  9.923473  9.891571 10.81481 10.89508
18 23.221134 23.54880 23.424658 24.214706 24.245395 25.699035 26.723254 27.32833 30.60412
19 27.020619 27.19920 26.469863 27.053676 26.900987 28.633441 29.854108 30.40000 34.27019
20 25.591753 25.82640 25.324658 25.986029 25.902303 27.529904 28.676693 29.24484 32.89150
21 21.500000 21.89520 22.045205 22.928676 23.042434 24.369775 25.305004 25.93689 28.94342
22 14.647938 15.31200 16.553425 17.808824 18.253289 19.077814 19.658764 20.39737 22.33196
23 15.200000 15.84240 16.995890 18.221324 18.639145 19.504180 20.113674 20.84369 22.86464
24 12.569588 13.31520 14.887671 16.255882 16.800658 17.472669 17.946160 18.71714 20.32659
25 12.537113 13.28400 14.861644 16.231618 16.777961 17.447588 17.919401 18.69089 20.29526
26 24.942268 25.20240 24.804110 25.500735 25.448355 27.028296 28.141504 28.71977 32.26482
27 23.610825 23.92320 23.736986 24.505882 24.517763 26.000000 27.044367 27.64337 30.98013
28 27.683093 27.83568 27.000822 27.548676 27.364013 29.145080 30.400000 30.93557 34.90940
29 16.921134 17.49600 18.375342 19.507353 19.842105 20.833441 21.531924 22.23513 24.52534
30 19.519072 19.99200 20.457534 21.448529 21.657895 22.839871 23.672679 24.33542 27.03205
31 14.323196 15.00000 16.293151 17.566176 18.026316 18.827010 19.391169 20.13484 22.01862
32 19.454124 19.92960 20.405479 21.400000 21.612500 22.789711 23.619160 24.28291 26.96938
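For comparison, rq() fits each tau separately by minimizing the check (pinball) loss rho_tau(r) = max(tau * r, (tau - 1) * r) of the residuals r = y - (b0 + b1 * x), with no distributional assumption about y; quantreg solves this exactly as a linear program. The NumPy sketch below is an approximation rather than quantreg's algorithm: it minimizes the same loss with a simple majorize-minimize (MM) iteration. The helpers `pinball_loss` and `quantile_regression` are hypothetical names, and the demo uses synthetic data, not mtcars.

```python
import numpy as np


def pinball_loss(y, y_pred, tau):
    """Mean check (pinball) loss: rho_tau(r) = max(tau * r, (tau - 1) * r)."""
    r = np.asarray(y, dtype=float) - np.asarray(y_pred, dtype=float)
    return float(np.mean(np.maximum(tau * r, (tau - 1.0) * r)))


def quantile_regression(x, y, tau, n_iter=500, eps=1e-6):
    """Approximately fit y ~ b0 + b1 * x at quantile tau with an MM iteration.

    Uses rho_tau(r) = |r| / 2 + (tau - 0.5) * r and the quadratic upper bound
    |r| / 2 <= r**2 / (4 * a) + a / 4 with a = |r_old| + eps, so each step is a
    weighted least-squares solve with a linear correction term.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    X = np.column_stack([np.ones_like(x), x])     # design matrix [1, x]
    beta = np.linalg.lstsq(X, y, rcond=None)[0]   # start from ordinary least squares
    for _ in range(n_iter):
        r = y - X @ beta
        w = 1.0 / (4.0 * (np.abs(r) + eps))       # weights from current residuals
        lhs = X.T @ (w[:, None] * X)
        rhs = X.T @ (w * y) + 0.5 * (tau - 0.5) * X.sum(axis=0)
        beta = np.linalg.solve(lhs, rhs)
    return beta  # [intercept, slope]


# Synthetic heteroscedastic data (not mtcars): the spread of y grows with x,
# so the fitted quantile lines fan out as x increases.
rng = np.random.default_rng(0)
x = rng.uniform(1.5, 5.5, size=500)
y = 37.0 - 5.0 * x + rng.normal(scale=0.5 + 0.5 * x)
for tau in (0.1, 0.5, 0.9):
    b0, b1 = quantile_regression(x, y, tau)
    print(f"tau={tau}: intercept={b0:.2f}, slope={b1:.2f}")
```

Called as quantile_regression(df_cars.wt, df_cars.mpg, 0.5) with the data frame from the question, it would be expected to land close to, but not necessarily exactly on, rq(mpg ~ wt, tau = 0.5): the MM iteration is not an exact LP solver, and quantile fits on a sample as small as 32 cars can be non-unique.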

TensorFlow Probability Attempt

# Requires: pip3 install tensorflow_probability
import numpy as np
import pandas as pd
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
np.set_printoptions(formatter={'float': lambda x: "{0:0.1f}".format(x)})

tf.enable_v2_behavior()
tfd = tfp.distributions
mtcars_url = "https://gist.githubusercontent.com/seankross/a412dfbd88b3db70b74b/raw/5f23f993cd87c283ce766e7ac6b329ee7cc2e1d1/mtcars.csv"
df_cars = pd.read_csv(mtcars_url)

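y = df_cars.mpg.to_numpy()  # response: miles per gallon
x = df_cars.wt.to_numpy()   # predictor: weight (1000 lbs)
# Loss: negative log-likelihood of the observed y under the predicted distribution.
negloglik = lambda y, rv_y: -rv_y.log_prob(y)
model = tf.keras.Sequential([
    # Two linear outputs per input: one for the mean, one for the (pre-softplus) scale.
    tf.keras.layers.Dense(1 + 1),
    tfp.layers.DistributionLambda(
        lambda t: tfd.Normal(loc=t[..., :1],
                             scale=1e-3 + tf.math.softplus(0.05 * t[..., 1:]))),
])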

# Fit the model by minimizing the negative log-likelihood.
model.compile(optimizer=tf.optimizers.Adam(learning_rate=0.01), loss=negloglik)
model.fit(x, y, epochs=1000, verbose=False)
# Generate 40000 samples from the predicted distribution for each car.
n = 40000
yhat = model(x)
# yhat.sample(n) has shape (n, 32, 1): samples first, then cars.
a = yhat.sample(n).numpy().reshape(32, n, 1)

# Check that the quantiles of the 40000 samples are similar to those from the R script.
quantiles = np.linspace(start=0.1, stop=0.9, num=9)
array_quants = np.quantile(a, q=quantiles, axis=1)
print(array_quants.reshape(32, 9))

The output here is:

[[13.2 13.3 13.2 13.2 13.2 13.2 13.2 13.2 13.2]
 [13.2 13.2 13.2 13.2 13.2 13.2 13.2 13.2 13.2]
 [13.2 13.2 13.2 13.2 13.2 13.2 13.2 13.2 13.2]
 [13.2 13.2 13.2 13.2 13.2 15.1 15.1 15.1 15.2]
 [15.1 15.1 15.2 15.2 15.2 15.2 15.1 15.2 15.2]
 [15.2 15.2 15.2 15.2 15.1 15.2 15.2 15.2 15.1]
 [15.2 15.2 15.2 15.2 15.2 15.2 15.1 15.1 15.2]
 [15.1 16.9 16.9 16.9 16.9 16.9 16.9 16.9 16.9]
 [16.9 16.9 16.9 16.9 16.9 16.9 16.9 16.9 16.9]
 [16.9 16.9 16.9 16.9 16.9 16.9 16.9 16.9 16.9]
 [16.9 16.9 16.9 16.9 16.9 16.9 18.3 18.3 18.3]
 [18.3 18.3 18.3 18.3 18.3 18.3 18.3 18.3 18.3]
 [18.3 18.3 18.3 18.3 18.3 18.3 18.3 18.3 18.3]
 [18.3 18.3 18.3 18.3 18.3 18.3 18.3 18.3 18.3]
 [18.3 18.3 19.4 19.3 19.4 19.4 19.4 19.4 19.3]
 [19.4 19.4 19.3 19.4 19.3 19.4 19.4 19.4 19.3]
 [19.4 19.4 19.4 19.3 19.4 19.4 19.4 19.3 19.4]
 [19.4 19.3 19.4 19.4 19.4 19.4 19.3 20.2 20.3]
 [20.2 20.2 20.2 20.3 20.2 20.2 20.2 20.2 20.2]
 [20.3 20.2 20.2 20.3 20.2 20.3 20.3 20.3 20.2]
 [20.3 20.3 20.3 20.2 20.2 20.3 20.2 20.3 20.3]
 [20.3 20.3 20.2 21.1 21.1 21.1 21.1 21.1 21.2]
 [21.1 21.1 21.1 21.1 21.1 21.1 21.1 21.1 21.1]
 [21.1 21.2 21.2 21.1 21.1 21.1 21.1 21.2 21.1]
 [21.1 21.2 21.2 21.1 21.1 21.1 21.2 21.1 22.2]
 [22.2 22.2 22.2 22.2 22.2 22.2 22.2 22.2 22.2]
 [22.2 22.2 22.2 22.2 22.2 22.2 22.2 22.2 22.2]
 [22.2 22.2 22.2 22.2 22.2 22.2 22.2 22.3 22.2]
 [22.2 22.2 22.2 22.2 24.9 24.9 24.9 24.9 24.9]
 [24.8 24.9 24.9 24.9 24.9 24.8 24.9 24.8 24.9]
 [24.9 24.9 24.9 24.9 24.9 24.9 24.9 24.9 24.9]
 [24.9 24.9 24.9 24.9 24.9 24.8 24.9 24.8 24.9]]

I am still learning TFP, so I suspect I made a simple mistake in the model specification, but I cannot see what it is.
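Because the model above outputs a Normal distribution per car, its quantiles can be computed in closed form as loc + scale * Phi^{-1}(tau) instead of from 40,000 samples; with TFP itself, the fitted distribution's quantile method (for example yhat.quantile(0.9)) should give the same numbers. The sketch below uses only NumPy and the standard library, and the loc/scale values are hypothetical stand-ins for the model's outputs, not results from the code above. It also shows the axis handling for sampled quantiles: take them along the sample axis, then transpose.

```python
from statistics import NormalDist

import numpy as np


def normal_quantiles(loc, scale, taus):
    """Closed-form Normal quantiles loc + scale * Phi^{-1}(tau).

    Returns an array of shape (len(loc), len(taus)): one row per car.
    """
    loc = np.asarray(loc, dtype=float).reshape(-1, 1)
    scale = np.asarray(scale, dtype=float).reshape(-1, 1)
    z = np.array([NormalDist().inv_cdf(t) for t in taus])  # standard Normal quantiles
    return loc + scale * z  # (n_cars, 1) broadcast against (n_taus,)


# Hypothetical loc/scale values for three cars, standing in for the model's outputs.
loc = np.array([22.0, 19.0, 10.0])
scale = np.array([3.0, 3.5, 5.0])
taus = np.linspace(0.1, 0.9, 9)
analytic = normal_quantiles(loc, scale, taus)

# Monte Carlo version of the same quantiles, as in the question:
# draws have shape (n_samples, n_cars), so take quantiles over axis 0 and
# transpose the (n_taus, n_cars) result to get one row per car.
rng = np.random.default_rng(0)
draws = rng.normal(loc, scale, size=(40000, loc.size))
sampled = np.quantile(draws, taus, axis=0).T

print(np.round(analytic, 1))
print(np.round(sampled, 1))
```

Every row of analytic is loc + scale * z_tau, so all nine quantiles for a car share one fitted mean and one fitted scale, whereas rq() fits a separate line for each tau. That structural difference, not only the optimization, is why the two approaches need not agree exactly.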

One Answer

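Try adjusting the learning rate and check whether the model converges. Also, where you call reshape you probably intended a transpose: yhat.sample(n) has shape (n, 32, 1), and reshape(32, n, 1) only reinterprets the memory order instead of swapping axes, so each of your 32 rows mixes samples from all cars. The second call, array_quants.reshape(32, 9), has the same problem and spreads the nine pooled quantiles across the rows, which is why your output looks like a staircase. With a learning rate of 1.0 and these fixes, this is what I got: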

>>> n = 1000
>>> a = yhat.sample(n)[:,:,0].numpy().T
>>> quantiles = np.linspace(start=0.1, stop=0.9, num=9)
>>> array_quants = np.quantile(a, q = quantiles, axis=1)
>>> np.round(array_quants.T,1)
array([[18.1, 19.8, 21.1, 22. , 23.1, 24.2, 25.4, 26.5, 28. ],
       [16.6, 18.4, 19.8, 21. , 21.8, 23. , 24.3, 25.7, 27.4],
       [19.9, 21.5, 22.9, 23.8, 24.7, 25.6, 26.6, 28. , 29.4],
       [14.2, 16.1, 17.7, 19.1, 20.3, 21.5, 22.7, 24.4, 26.4],
       [12.8, 15. , 16.4, 17.7, 19. , 20.3, 21.8, 23.4, 25.4],
       [12.3, 14.6, 16.3, 17.7, 19. , 20.2, 21.7, 23.2, 25.2],
       [12.1, 14.5, 16.2, 17.3, 18.6, 19.7, 21.3, 23. , 25.1],
       [14.6, 16.3, 17.5, 19.1, 20.4, 21.6, 22.9, 24.2, 26.6],
       [14.9, 17. , 18.4, 19.5, 20.5, 21.7, 22.9, 24.4, 26.5],
       [13.1, 15.2, 17. , 18.2, 19.3, 20.6, 21.9, 23.4, 25.4],
       [12.9, 15. , 16.7, 18.2, 19.4, 20.7, 22. , 23.6, 25.7],
       [ 8.7, 11.6, 13.3, 14.5, 16.1, 17.5, 19. , 20.8, 23. ],
       [10.7, 13.3, 15.1, 16.5, 17.7, 19.1, 20.6, 22.3, 24.3],
       [11.1, 13.3, 14.9, 16.3, 17.4, 18.9, 20.4, 21.9, 24.1],
       [ 2. ,  5. ,  7.1,  8.6, 10.2, 11.9, 13.4, 15.9, 18.6],
       [-0.1,  3.9,  6.1,  8. ,  9.7, 11.6, 13.5, 15.5, 17.9],
       [ 0.3,  3.9,  6.3,  8.1,  9.8, 11.5, 13.1, 15.4, 18.4],
       [20.6, 22.3, 23.3, 24.3, 25.2, 26.2, 27.2, 28.3, 30. ],
       [24.2, 25.6, 26.6, 27.6, 28.3, 29.2, 30. , 30.8, 32. ],
       [22.7, 24.1, 25.1, 26.1, 26.9, 27.9, 28.8, 29.9, 31.3],
       [19. , 20.9, 22. , 23.3, 24. , 25. , 26. , 27.4, 29.1],
       [12.4, 14.6, 16.3, 17.7, 18.9, 20. , 21.5, 23.1, 25.7],
       [13.2, 15.3, 16.8, 18.2, 19.2, 20.5, 22. , 23.4, 25.4],
       [10.5, 13.1, 14.9, 16.2, 17.4, 18.6, 20.1, 21.9, 24.3],
       [10.3, 12.7, 14.4, 15.8, 17.4, 18.7, 20. , 21.9, 24.3],
       [22.4, 23.9, 24.9, 25.8, 26.6, 27.5, 28.3, 29.3, 30.8],
       [21.1, 22.6, 23.6, 24.5, 25.5, 26.3, 27.2, 28. , 29.6],
       [25.1, 26.3, 27.2, 28. , 28.7, 29.4, 30.1, 31. , 32.3],
       [14.4, 16.3, 17.7, 18.9, 20. , 21.3, 22.7, 24.2, 26.5],
       [16.9, 18.8, 20.2, 21.2, 22.5, 23.5, 24.6, 25.8, 27.5],
       [12.4, 14.7, 16.1, 17.5, 18.7, 20. , 21.5, 23.1, 25.3],
       [16.8, 18.7, 20. , 21.2, 22.5, 23.5, 24.6, 26.3, 28.1]])

Correct answer by Jongmmm on December 1, 2020
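The reshape-versus-transpose point in the answer can be seen on a toy array. This sketch uses a hypothetical stand-in for yhat.sample(n) with 3 cars instead of 32, where every draw for car i equals i, so any mixing between cars is obvious.

```python
import numpy as np

# Hypothetical stand-in for yhat.sample(n): n draws for each of 3 cars,
# shape (n, 3, 1), where every draw for car i equals i.
n = 4
samples = np.tile(np.arange(3.0).reshape(1, 3, 1), (n, 1, 1))

# reshape keeps the flat (row-major) order [0, 1, 2, 0, 1, 2, ...] and just
# re-chunks it, so each row mixes draws from all three cars.
wrong = samples.reshape(3, n, 1)
# rows of wrong[:, :, 0]: [0, 1, 2, 0], [1, 2, 0, 1], [2, 0, 1, 2]

# Dropping the trailing axis and transposing swaps the axes instead:
# shape (3, n), and row i holds only car i's draws.
right = samples[:, :, 0].T
# rows of right: [0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]

# Per-car quantiles: np.quantile puts the quantile axis first, so transpose again.
quants = np.quantile(right, [0.1, 0.5, 0.9], axis=1).T  # shape (3 cars, 3 taus)
print(wrong[:, :, 0])
print(right)
print(quants)
```

The same reasoning covers the final step: because np.quantile with axis=1 returns the quantile axis first, its result needs a transpose (as in array_quants.T above) rather than .reshape(32, 9).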
