A simple animation of linear regression fitted by gradient descent

In [1]:
#linear regression example
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import time
In [2]:
from matplotlib import rc

# Make Animation objects render as an HTML5 <video> element when they are the
# value of a cell (same effect as rc('animation', html='html5')).
plt.rcParams['animation.html'] = 'html5'
In [3]:
# A handful of points on the plane to fit with a straight line y = a*x + b.
x = np.array([1,2,3,4,5,6,7,8,9,10])
y = np.array([14,12,13,15,11,9,8,4,2,1])

def grad(a, b, xs=None, ys=None):
    """Gradient of the quadratic loss of the line y = a*x + b.

    The loss being differentiated is half the sum of squared residuals,

        L(a, b) = 0.5 * sum(r**2),   r = ys - (a*xs + b),

    so that

        dL/da = -sum(r * xs)
        dL/db = -sum(r)

    (The factor 0.5 only rescales the effective learning rate.)

    Parameters
    ----------
    a, b : float
        Current slope and intercept.
    xs, ys : array-like, optional
        Data points. When omitted, the module-level ``x`` and ``y`` are used
        (looked up at call time, as before).

    Returns
    -------
    (da, db) : tuple
        Partial derivatives of L with respect to ``a`` and ``b``.
    """
    xs = np.asarray(x if xs is None else xs)
    ys = np.asarray(y if ys is None else ys)
    residual = ys - (a * xs + b)    # r: signed error of each point
    da = -np.sum(residual * xs)     # dL/da
    db = -np.sum(residual)          # dL/db
    return (da, db)
In [39]:
# Hyper-parameters of full-batch gradient descent.
lr = 0.001      # learning rate (step size)
epochs = 2000   # number of gradient steps; also the number of animation frames

# Step 1: random initial slope/intercept. Seeded so that the animation and the
# loss printout are reproducible on Restart & Run All.
rng = np.random.default_rng(0)
a = rng.random()
b = rng.random()
params = [a, b]   # current [slope, intercept]; updated in place during the animation

fig, ax = plt.subplots()
ax.plot(x, y, 'ro', label='data')
line, = ax.plot([], [], lw=2, label='current fit')
ax.set(title='Linear regression by gradient descent', xlabel='x', ylabel='y')
ax.legend()

def init():
    """Draw the starting line from the initial guess (a, b).

    Used as FuncAnimation's ``init_func``. The segment spans the full x-range
    of the data instead of assuming exactly ten points.

    Returns
    -------
    tuple
        The artists that were modified (required when blitting).
    """
    x_lo, x_hi = x.min(), x.max()
    line.set_data([x_lo, x_hi], [a * x_lo + b, a * x_hi + b])
    return line,

def step(i):
    """Perform one gradient-descent update and redraw the fitted line.

    Called by FuncAnimation once per frame.

    Parameters
    ----------
    i : int
        Frame index. Every 100th frame the current loss is printed.

    Returns
    -------
    tuple
        The artists that were modified (required when blitting).
    """
    if i == 0:
        # Frame 0 begins a fresh pass over the frames: restart from the initial
        # guess (a, b). Without this, rendering the animation a second time
        # would continue from the parameters left behind by the previous run,
        # because ``params`` is mutated in place below.
        params[0], params[1] = a, b
    slope, intercept = params
    if i % 100 == 0:
        # Full sum of squared residuals, evaluated before this frame's update.
        print("current loss = {}".format(np.sum((y - slope * x - intercept) ** 2)))
    da, db = grad(slope, intercept)
    slope = slope - lr * da
    intercept = intercept - lr * db
    params[0], params[1] = slope, intercept
    # Draw the *updated* parameters so the last frame shows the final fit
    # (drawing the pre-update values made the line lag one step behind).
    x_lo, x_hi = x.min(), x.max()
    line.set_data([x_lo, x_hi], [slope * x_lo + intercept, slope * x_hi + intercept])
    return line,

# Wire the animation together: `init` draws the starting line, then `step`
# runs once per frame for `epochs` frames. With blit=True only the returned
# artists are redrawn; repeat=False means a single pass.
# NOTE(review): interval=1 ms per frame is very short; if the rendered video
# plays too fast, consider a larger interval.
anim = animation.FuncAnimation(
    fig,
    step,
    frames=epochs,
    init_func=init,
    interval=1,
    blit=True,
    repeat=False,
)

# Save a static snapshot of the figure, then display it.
plt.savefig("regr_exe.jpg")
plt.show()
No description has been provided for this image
In [40]:
# Display the animation (as an HTML5 video, per the rc setting above).
# Rendering it runs the per-frame callback, hence the loss printout below.
anim
current loss = 669.1416045793965
current loss = 449.5447461163315
current loss = 306.18805350416426
current loss = 212.03323655041143
current loss = 150.1935735806599
current loss = 109.57808475133304
current loss = 82.90235830612147
current loss = 65.3820870527089
current loss = 53.87500006020983
current loss = 46.317294523707446
current loss = 41.35349152007033
current loss = 38.0933300676411
current loss = 35.952098296735215
current loss = 34.54576503268145
current loss = 33.62210357928028
current loss = 33.015454712823065
current loss = 32.617015603105294
current loss = 32.35532596197924
current loss = 32.18345159873916
current loss = 32.07056674547535
Out[40]:
Your browser does not support the video tag.