# -*- coding: utf-8 -*-
"""
Visualize-Multiple-Linear-Regression

Created on Tue Jun 18 12:38:44 2019

@author: krish.naik

Fit an ordinary-least-squares model ``Sales ~ Price + AdSpends`` on a small
hard-coded dataset, then draw the observed points and the fitted regression
plane in a single 3-D matplotlib figure.
"""

## Visualization for Multiple Linear Regression

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

# Matplotlib releases older than 3.2 only register the '3d' projection as a
# side effect of importing mplot3d; without this import `projection='3d'`
# fails there with "Unknown projection '3d'". Newer releases register it
# automatically, so the import is harmless on them.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

# Each row of X is one observation, [Price, AdSpends]; Y holds the matching
# Sales value at the same position.
X = [
    [150, 100], [159, 200], [170, 350], [175, 400], [179, 500],
    [180, 180], [189, 159], [199, 110], [199, 400], [199, 230],
    [235, 120], [239, 340], [239, 360], [249, 145], [249, 400],
]
Y = [
    0.73, 1.39, 2.03, 1.45, 1.82,
    1.32, 0.83, 0.53, 1.95, 1.27,
    0.49, 1.03, 1.24, 0.55, 1.3,
]


## Prepare the Dataset

df2 = pd.DataFrame(X, columns=['Price', 'AdSpends'])
df2['Sales'] = pd.Series(Y)
# A bare `df2` expression only echoes in an interactive console; print it so
# the table is also shown when the file is run as a script.
print(df2)


## Apply multiple Linear Regression

model = smf.ols(formula='Sales ~ Price + AdSpends', data=df2)
results_formula = model.fit()
# Same reasoning as above: make the fitted coefficients visible in script mode.
print(results_formula.params)


## Prepare the data for Visualization

# Build a 100 x 100 grid spanning the observed Price and AdSpends ranges; the
# model is evaluated at every grid node to obtain the regression surface.
x_surf, y_surf = np.meshgrid(
    np.linspace(df2.Price.min(), df2.Price.max(), 100),
    np.linspace(df2.AdSpends.min(), df2.AdSpends.max(), 100),
)
# The formula interface looks predictors up by column name, so the flattened
# grid is wrapped in a DataFrame that reuses the training column names.
onlyX = pd.DataFrame({'Price': x_surf.ravel(), 'AdSpends': y_surf.ravel()})
fittedY = results_formula.predict(exog=onlyX)


## convert the predicted result in an array
fittedY = np.array(fittedY)


# Visualize the Data for Multiple Linear Regression

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# Observed data points.
ax.scatter(df2['Price'], df2['AdSpends'], df2['Sales'],
           c='red', marker='o', alpha=0.5)
# Fitted plane: predictions are reshaped back to the grid's 2-D shape, which
# plot_surface requires for its Z argument.
ax.plot_surface(x_surf, y_surf, fittedY.reshape(x_surf.shape),
                color='b', alpha=0.3)
ax.set_xlabel('Price')
ax.set_ylabel('AdSpends')
ax.set_zlabel('Sales')
plt.show()