Why is PyMC3 ADVI worse than MCMC in this logistic regression example?


I am aware of the mathematical differences between ADVI and MCMC, but I am trying to understand the practical implications of using one or the other. I am running a very simple logistic regression example on data I created like this:

import pandas as pd
import pymc3 as pm
import matplotlib.pyplot as plt
import numpy as np

def logistic(x, b, noise=None):
    L = x.T.dot(b)
    if noise is not None:
        L = L+noise
    return 1/(1+np.exp(-L))

x1 = np.linspace(-10., 10, 10000)
x2 = np.linspace(0., 20, 10000)
bias = np.ones(len(x1))
X = np.vstack([x1,x2,bias]) # Add intercept
B =  [-10., 2., 1.] # Sigmoid params for X + intercept

# Noisy mean (note: scale=0. means no noise is actually added here)
pnoisy = logistic(X, B, noise=np.random.normal(loc=0., scale=0., size=len(x1)))
# Dichotomize pnoisy: sample 0/1 with probability pnoisy
y = np.random.binomial(1, pnoisy)

Then I run ADVI like this:

with pm.Model() as model: 
    # Define priors
    intercept = pm.Normal('Intercept', 0, sd=10)
    x1_coef = pm.Normal('x1', 0, sd=10)
    x2_coef = pm.Normal('x2', 0, sd=10)

    # Define likelihood
    likelihood = pm.Bernoulli('y',
                              pm.math.sigmoid(intercept + x1_coef*X[0] + x2_coef*X[1]),
                              observed=y)
    approx = pm.fit(90000, method='advi')

Unfortunately, no matter how much I increase the number of iterations, ADVI does not seem to recover the original betas I defined ([-10., 2., 1.]), while MCMC works fine (as shown below).

[Figure: MCMC results]

Thanks for the help!

Apr 5, 2022 in Machine Learning by Nandini


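Good question! PyMC3's default method='advi' is mean-field variational inference, which fits an independent Gaussian to each parameter and therefore does a poor job of capturing correlations between parameters. Your model happens to have a strongly correlated posterior: since x2 = x1 + 10 exactly, the columns of X are collinear, and the likelihood only depends on x1_coef + x2_coef and Intercept + 10*x2_coef. You can see this in a pair plot of the NUTS trace: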

import arviz as az
# `trace` is the NUTS trace, e.g. from `with model: trace = pm.sample()`
az.plot_pair(trace, figsize=(5, 5))

[Figure: correlated samples]

PyMC3 also includes a convergence checker. Running the optimization for too long or too short can give odd results, so it helps to stop once the parameters have stopped changing:
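To see where the correlation comes from, here is a small NumPy check using the same X and B as in the question. Because x2 = x1 + 10, X has rank 2 rather than 3, and adding any multiple of the vector v = (1, -1, 10) to B (ordered x1, x2, intercept) leaves the linear predictor, and therefore the likelihood, unchanged. The vector `v` is my own derivation for illustration, not something PyMC3 computes; along that direction only the priors constrain the coefficients, which stretches the posterior into a ridge.

```python
import numpy as np

# Same design matrix as in the question: rows are x1, x2 and the bias term.
x1 = np.linspace(-10., 10, 10000)
x2 = np.linspace(0., 20, 10000)
X = np.vstack([x1, x2, np.ones(len(x1))])
B = np.array([-10., 2., 1.])  # coefficients for x1, x2, intercept

# x2 equals x1 + 10 (up to rounding), so the rows of X are linearly dependent.
rank = np.linalg.matrix_rank(X)
print("rank of X:", rank)

# Null-space direction (illustrative name `v`): 1*x1 - 1*x2 + 10*1 == 0 for every sample.
v = np.array([1., -1., 10.])

# Shifting B along v leaves the linear predictor (and hence the Bernoulli
# likelihood) unchanged, so the data alone cannot tell these vectors apart.
max_diffs = []
for t in [0., 5., -20.]:
    B_shifted = B + t * v
    diff = np.abs(X.T.dot(B_shifted) - X.T.dot(B)).max()
    max_diffs.append(diff)
    print(f"t={t:+6.1f}  coefs={B_shifted}  max change in linear predictor={diff:.1e}")
```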

from pymc3.variational.callbacks import CheckParametersConvergence

with model:
    fit = pm.fit(100_000, method='advi', callbacks=[CheckParametersConvergence()])

draws = fit.sample(2_000)
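As a rough illustration of the idea behind such a check, here is a NumPy sketch of a relative-change stopping rule applied to plain gradient descent on a toy quadratic. The helper `has_converged`, the toy objective and the `tolerance`/`every` values are assumptions made for illustration; this is not PyMC3's implementation, although as I understand it CheckParametersConvergence similarly compares parameter snapshots taken a fixed number of iterations apart.

```python
import numpy as np

def has_converged(prev, current, tolerance=1e-3, eps=1e-6):
    """Hypothetical helper: True when the largest relative parameter change is below tolerance."""
    rel_change = np.abs(current - prev) / (np.abs(prev) + eps)
    return np.max(rel_change) < tolerance

# Toy objective: 0.5 * (theta - target)^T A (theta - target), minimized by gradient descent.
target = np.array([-10., 2., 1.])
A = np.diag([1.0, 0.5, 2.0])
theta = np.zeros(3)
lr, every, max_iter = 0.1, 100, 100_000

prev = theta.copy()
stopped_at = None
for i in range(1, max_iter + 1):
    theta -= lr * A.dot(theta - target)   # one gradient step
    if i % every == 0:                    # compare snapshots `every` steps apart
        if has_converged(prev, theta):
            stopped_at = i
            break
        prev = theta.copy()

print("stopped at iteration", stopped_at, "theta =", theta.round(4))
```

Stopping on parameter change rather than running a fixed, very large number of iterations avoids both cutting the optimization short and wasting time once it has settled.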

For me, the ADVI fit stops after roughly 60,000 iterations. A pair plot of the ADVI draws shows axis-aligned Gaussians, as expected from a mean-field approximation:

az.plot_pair(draws, figsize=(5, 5))

[Figure: pair plot of the ADVI draws]
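Why axis-aligned, and why too narrow? A mean-field Gaussian has a diagonal covariance, so it cannot tilt to follow a correlated posterior. For a Gaussian target with precision matrix Λ = Σ⁻¹, the mean-field q that minimizes KL(q‖p) (the direction ADVI optimizes) is known in closed form: each factor keeps the true mean but gets variance 1/Λᵢᵢ. The NumPy sketch below evaluates this for a hypothetical 2-D posterior with correlation `rho = 0.99`; the numbers are illustrative and not taken from the logistic regression above.

```python
import numpy as np

rho = 0.99                        # hypothetical posterior correlation
cov = np.array([[1.0, rho],
                [rho, 1.0]])      # true (full) posterior covariance
precision = np.linalg.inv(cov)    # Lambda

# Mean-field optimum of KL(q || p) for a Gaussian target:
# same mean, variance 1 / Lambda_ii per coordinate, zero correlation.
mf_var = 1.0 / np.diag(precision)

print("true marginal std:      ", np.sqrt(np.diag(cov)))
print("mean-field std:         ", np.sqrt(mf_var))
print("ratio (mean-field/true):", np.sqrt(mf_var / np.diag(cov)))
```

For this covariance, 1/Λᵢᵢ works out to 1 - ρ², so the stronger the correlation, the more the mean-field fit underestimates each marginal variance while still centering on the right mean.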

Finally, the NUTS and (mean field) ADVI fits can be compared:

az.plot_forest([draws, trace], model_names=['ADVI', 'NUTS'])

[Figure: forest plot comparing ADVI and NUTS]

Note that ADVI underestimates the variance, but its estimate of each parameter's mean is reasonably close. You can also use method='fullrank_advi', which fits a Gaussian with a full covariance matrix, to better capture the correlations you're seeing.
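Full-rank ADVI replaces the diagonal covariance with a full one, typically parameterized through a lower-triangular (Cholesky-style) factor L so that draws are z = μ + Lε with ε ~ N(0, I). The NumPy-only sketch below, using a hypothetical correlated target, contrasts draws from a full-rank Gaussian with draws from a diagonal one; it does not call PyMC3, and the covariance and sample size are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
rho = 0.99
cov = np.array([[1.0, rho],
                [rho, 1.0]])          # hypothetical correlated posterior
mu = np.zeros(2)

eps = rng.standard_normal((20_000, 2))  # eps ~ N(0, I), one row per draw

# Full-rank family: z = mu + L @ eps, with L the Cholesky factor of cov.
L = np.linalg.cholesky(cov)
z_full = mu + eps @ L.T

# Diagonal (mean-field) family: per-coordinate scales only, here set to the
# true marginal stds just for comparison. No off-diagonal terms are possible.
z_diag = mu + eps * np.sqrt(np.diag(cov))

corr_full = np.corrcoef(z_full.T)[0, 1]
corr_mf = np.corrcoef(z_diag.T)[0, 1]
print("full-rank sample correlation:", round(corr_full, 3))
print("diagonal sample correlation: ", round(corr_mf, 3))
```

The extra off-diagonal parameters in L are what let the full-rank family follow the ridge; the cost is that the number of variational parameters grows quadratically with the number of model parameters.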

(Note: ArviZ will soon become PyMC3's plotting library.)

answered Apr 5, 2022 by Dev
