GradientCOBRA method of gradientcobra v1.0.8 package

Author

Installing and importing packages

gradientcobra can be installed from PyPI using pip:

pip install gradientcobra

Importing packages

# Error metrics
from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error

# Plotting figures
import matplotlib.pyplot as plt
from matplotlib import cm

# Import class GradientCOBRA from the gradientcobra library
from gradientcobra.gradientcobra import GradientCOBRA

import seaborn as sns
sns.set()

Simulated data

We simulate a regression dataset with \(1000\) observations and \(10\) input variables.

# For simulating dataset
from sklearn.datasets import make_regression

X1, y1 = make_regression(n_samples=1000, n_features=10, noise=1)

Now, let’s randomly split the simulated data into \(80\%\) training and \(20\%\) testing data.

from sklearn.model_selection import train_test_split

X_train1, X_test1, y_train1, y_test1 = train_test_split(X1, y1, test_size=0.2)
print('shape: x_train = {} , x_test = {} , y_train = {} , y_test = {}'.format(
    X_train1.shape,
    X_test1.shape,
    y_train1.shape,
    y_test1.shape))
shape: x_train = (800, 10) , x_test = (200, 10) , y_train = (800,) , y_test = (200,)
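Because neither make_regression nor train_test_split is given a random_state above, each run draws a different dataset and a different split, so the exact numbers reported below will vary from run to run. A minimal sketch of making both steps reproducible by fixing a seed (42 is an arbitrary choice, not something the package requires):

```python
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

# Fixing random_state makes both the simulated data and the split reproducible
# (42 is an arbitrary example seed)
X1, y1 = make_regression(n_samples=1000, n_features=10, noise=1, random_state=42)
X_train1, X_test1, y_train1, y_test1 = train_test_split(
    X1, y1, test_size=0.2, random_state=42
)

# test_size=0.2 on 1000 rows gives 200 test rows and 800 training rows
print(X_train1.shape, X_test1.shape, y_train1.shape, y_test1.shape)
```

Any fixed integer works; what matters is reusing the same seeds when comparing gc1 and gc2.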

GradientCOBRA with default parameters

We create a GradientCOBRA object called gc1 with the default parameters, then fit it to the training data.

gc1 = GradientCOBRA()
gc1_fit = gc1.fit(X_train1, y_train1)

    * Gradient descent with radial kernel is implemented...
        ~ Initial t = 0:        ~ bandwidth: 2.500  ~ gradient: 8.002   ~ threshold: 1e-05
        ~     Iteration: 1      ~ bandwidth: 2.490  ~ gradient: 8.002   ~ stopping criterion: 4.003
        ~     Iteration: 2      ~ bandwidth: 2.490  ~ gradient: 8.002   ~ stopping criterion: 4.003
        ...
        ~     Iteration: 299    ~ bandwidth: 0.920  ~ gradient: 0.035   ~ stopping criterion: 0.001
        ~     Iteration: 300    ~ bandwidth: 0.920  ~ gradient: -0.032  ~ stopping criterion: 0.017
        ~    Stopped at: 300    ~ bandwidth: 0.920  ~ gradient: -0.032  ~ stopping criterion: 0.017
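The log traces gradient descent on a single quantity: the bandwidth of the kernel used to combine the basic estimators. To make the role of that bandwidth concrete, here is a minimal NumPy sketch of COBRA-style kernel aggregation. It assumes a Gaussian-type radial kernel exp(-||z||²/h²) applied to the distance between vectors of basic-estimator predictions; the package's own kernel parametrization may differ, and radial_kernel and cobra_aggregate are hypothetical names. In COBRA-type methods the basic estimators are typically fitted on one part of the training data, and their predictions on the remaining part play the role of preds_train below.

```python
import numpy as np


def radial_kernel(dist2, h):
    """Gaussian-type radial kernel on squared distances (assumed parametrization)."""
    return np.exp(-dist2 / h ** 2)


def cobra_aggregate(preds_train, y_train, preds_query, h):
    """COBRA-style aggregate: kernel-weighted average of y_train.

    preds_train : (n, M) predictions of M basic estimators at n aggregation points
    y_train     : (n,)   observed responses at those points
    preds_query : (m, M) predictions of the same M estimators at m new points
    h           : bandwidth (> 0)
    """
    # Squared Euclidean distances between prediction vectors, shape (m, n)
    diff = preds_query[:, None, :] - preds_train[None, :, :]
    weights = radial_kernel(np.sum(diff ** 2, axis=-1), h)
    # Normalize each row; a row whose weights all underflow to 0 yields a 0 prediction
    total = weights.sum(axis=1, keepdims=True)
    return (weights / np.where(total == 0, 1.0, total)) @ y_train


# Toy check: three aggregation points, two basic estimators
preds_train = np.array([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]])
y_train = np.array([0.0, 1.0, 5.0])
query = np.array([[1.0, 1.0]])

print(cobra_aggregate(preds_train, y_train, query, h=0.1))  # ≈ [1.0]: nearest vector dominates
print(cobra_aggregate(preds_train, y_train, query, h=1e3))  # ≈ [2.0]: close to the mean of y_train
```

A small bandwidth makes the aggregate follow the closest prediction vectors; a large one pulls it toward the plain average of the responses. The optimizer is tuning this trade-off.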

The estimated optimal bandwidth is stored in gc1_fit.optimization_outputs['opt_bandwidth'].

# GradientCOBRA with default parameters
print("Estimated bandwidth :" + str(gc1_fit.optimization_outputs['opt_bandwidth']))
Estimated bandwidth :0.9197033463076068
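In the run above, the final stopping criterion (0.017) is still well above the threshold (1e-05), so the fit appears to have ended at the iteration limit rather than by convergence, although the bandwidth had already settled around 0.920. The sketch below illustrates the general idea of tuning the bandwidth by gradient descent on a cross-validated loss, with a learning rate, a stopping threshold and an iteration cap. It is a simplification, not the package's algorithm: the gradient is a central finite difference, the kernel is the assumed exp(-d²/h²) form, the stopping rule (an update smaller than threshold) is an assumption, and cv_loss and gradient_descent_bandwidth are hypothetical names.

```python
import numpy as np


def cv_loss(h, preds, y, n_cv=5, seed=0):
    """K-fold cross-validated MSE of a COBRA-style kernel aggregate at bandwidth h."""
    n = len(y)
    # Same seed on every call, so the folds stay fixed while h changes
    folds = np.array_split(np.random.default_rng(seed).permutation(n), n_cv)
    errors = []
    for test_idx in folds:
        train_idx = np.setdiff1d(np.arange(n), test_idx)
        diff = preds[test_idx][:, None, :] - preds[train_idx][None, :, :]
        w = np.exp(-np.sum(diff ** 2, axis=-1) / h ** 2)  # assumed radial kernel
        total = w.sum(axis=1, keepdims=True)
        y_hat = (w / np.where(total == 0, 1.0, total)) @ y[train_idx]
        errors.append(np.mean((y[test_idx] - y_hat) ** 2))
    return float(np.mean(errors))


def gradient_descent_bandwidth(preds, y, h0=2.5, learning_rate=0.5,
                               max_iter=300, threshold=1e-5, eps=1e-4):
    """Plain gradient descent on h with a central finite-difference gradient."""
    h, history = h0, [h0]
    for _ in range(max_iter):
        grad = (cv_loss(h + eps, preds, y) - cv_loss(h - eps, preds, y)) / (2 * eps)
        step = learning_rate * grad
        h = max(h - step, 1e-3)  # keep the bandwidth strictly positive
        history.append(h)
        if abs(step) < threshold:  # stop once the update becomes negligible
            break
    return h, history


# Toy data: two noisy "basic estimators" of sin(x)
rng = np.random.default_rng(1)
x = rng.uniform(-3, 3, size=300)
y = np.sin(x) + 0.1 * rng.normal(size=300)
preds = np.column_stack([np.sin(x) + 0.2 * rng.normal(size=300),
                         np.sin(x) + 0.2 * rng.normal(size=300)])

h_opt, history = gradient_descent_bandwidth(preds, y, max_iter=100)
print(f"bandwidth: {history[0]:.3f} -> {h_opt:.3f} after {len(history) - 1} iterations")
```

history records the bandwidth after every update and can be plotted to inspect convergence. The learning rate has to match the scale of the loss: with responses on a much larger scale, as in the simulated y1, the gradient is correspondingly larger and the same learning rate would overshoot.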

We can look at the learning curve of the algorithm using the draw_learning_curve() method.

gc1_fit.draw_learning_curve()

We evaluate the performance of the method on the testing data using MSE and MAPE.

y_pred1 = gc1_fit.predict(X_test1)
print("MAPE:", mean_absolute_percentage_error(y_test1, y_pred1))
print("MSE:", mean_squared_error(y_test1, y_pred1))
MAPE: 0.17824458759870632
MSE: 110.1637932251487
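The two metrics live on different scales. MAPE is returned as a fraction, so 0.178 means the predictions are off by about 17.8% of the true value on average, while MSE is in squared units of the response; its square root (RMSE ≈ 10.5 here) is back on the scale of y. A small NumPy sketch of both formulas, mirroring scikit-learn's definitions (including its tiny epsilon guard against division by zero in MAPE); mse and mape are illustrative names:

```python
import numpy as np


def mse(y_true, y_pred):
    """Mean squared error: average of the squared residuals."""
    y_true, y_pred = np.asarray(y_true, float), np.asarray(y_pred, float)
    return float(np.mean((y_true - y_pred) ** 2))


def mape(y_true, y_pred, eps=np.finfo(np.float64).eps):
    """Mean absolute percentage error, returned as a fraction (0.178 means 17.8%)."""
    y_true, y_pred = np.asarray(y_true, float), np.asarray(y_pred, float)
    # Divide by |y_true|, floored at a tiny epsilon to avoid division by zero
    return float(np.mean(np.abs(y_true - y_pred) / np.maximum(np.abs(y_true), eps)))


y_true = [100.0, -50.0, 20.0]
y_pred = [110.0, -45.0, 20.0]
print(mse(y_true, y_pred))   # (10**2 + 5**2 + 0**2) / 3 = 41.666...
print(mape(y_true, y_pred))  # (0.1 + 0.1 + 0.0) / 3 = 0.0666...
```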

Let’s look at the QQ-plot of the predictions against the actual response values, again using the draw_learning_curve() method, this time with fig_type='qq'.

gc1_fit.draw_learning_curve(y_test=y_test1, fig_type='qq')
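A QQ-plot pairs the empirical quantiles of two samples; points close to the line y = x indicate that predictions and actual values have similar distributions. The sketch below builds such a plot by hand with NumPy and Matplotlib. It is not the package's plotting code, and qq_points and draw_qq are hypothetical helper names:

```python
import numpy as np
import matplotlib.pyplot as plt


def qq_points(y_true, y_pred, n_quantiles=100):
    """Return matching empirical quantiles of the actual and predicted values."""
    probs = np.linspace(0.0, 1.0, n_quantiles)
    return np.quantile(y_true, probs), np.quantile(y_pred, probs)


def draw_qq(y_true, y_pred, n_quantiles=100):
    """Scatter the quantile pairs against the y = x reference line."""
    q_true, q_pred = qq_points(y_true, y_pred, n_quantiles)
    lo = min(q_true.min(), q_pred.min())
    hi = max(q_true.max(), q_pred.max())
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.scatter(q_true, q_pred, s=10)
    # Points close to this line indicate similar distributions
    ax.plot([lo, hi], [lo, hi], "r--", label="y = x")
    ax.set_xlabel("Quantiles of actual values")
    ax.set_ylabel("Quantiles of predictions")
    ax.legend()
    return fig, ax


# Toy usage; with the variables above you would call draw_qq(y_test1, y_pred1)
rng = np.random.default_rng(0)
y_true = rng.normal(size=200)
y_pred = y_true + 0.1 * rng.normal(size=200)
fig, ax = draw_qq(y_true, y_pred)
plt.show()
```

A QQ-plot compares only the two distributions: predictions that were shuffled relative to the truth would still lie on the line, so it complements rather than replaces MSE and MAPE.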

GradientCOBRA with non-default parameters

GradientCOBRA offers various options for tuning the method. You can adjust the learning rate of gradient descent, or use grid search instead to estimate the bandwidth parameter. Moreover, you can set the hyperparameters of the basic estimators to improve the aggregation. The main options are:

  • learning_rate : controls the learning rate of gradient descent when estimating the bandwidth parameter.
  • speed : the speed of the learning rate (e.g. "linear", as in the example below).
  • kernel : the kernel function used for the aggregation (e.g. 'radial').
  • opt_method : the optimization algorithm for estimating the bandwidth: gradient descent ('grad') or grid search ('grid').
  • max_iter : the maximum number of gradient descent iterations.
  • loss_function : the type of loss function used for optimizing the bandwidth (e.g. "weighted_mse").
  • opt_params : controls the optimization algorithm, e.g. by adjusting n_cv, start, precision, …
  • estimator_list : the list of basic estimators used for the aggregation.
  • estimator_params : controls the hyperparameters of the basic estimators. It must be a dictionary whose keys are the names of basic estimators and whose values are dictionaries containing their hyperparameters.
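The list above mentions grid search (opt_method='grid') as the alternative to gradient descent. The idea is to evaluate the loss at every candidate bandwidth and keep the minimizer. The sketch below assumes a radial kernel exp(-||.||^2 / h^2) and a K-fold cross-validated MSE as the loss; the grid range, the loss and the function names are illustrative assumptions, not the library's defaults.

```python
import numpy as np


def kernel_predict(P_tr, y_tr, P_q, h):
    # Radial-kernel weighted average (assumed form exp(-||.||^2 / h^2))
    d2 = ((P_q[:, None, :] - P_tr[None, :, :]) ** 2).sum(axis=2)
    w = np.exp(-d2 / h ** 2)
    s = w.sum(axis=1)
    return np.where(s > 0, (w @ y_tr) / np.where(s > 0, s, 1.0), y_tr.mean())


def cv_mse(h, P, y, n_cv=5, seed=0):
    # K-fold cross-validated MSE with folds fixed by the seed
    folds = np.array_split(np.random.default_rng(seed).permutation(len(y)), n_cv)
    errs = []
    for f in folds:
        train = np.setdiff1d(np.arange(len(y)), f)
        errs.append(np.mean((y[f] - kernel_predict(P[train], y[train], P[f], h)) ** 2))
    return float(np.mean(errs))


def grid_search_bandwidth(P, y, grid):
    # Evaluate the loss at every candidate and return the best bandwidth
    losses = np.array([cv_mse(h, P, y) for h in grid])
    return float(grid[int(np.argmin(losses))]), losses


rng = np.random.default_rng(0)
y = rng.normal(size=200)
P = np.column_stack([y + rng.normal(scale=0.3, size=200),
                     y + rng.normal(scale=0.5, size=200)])
grid = np.linspace(0.05, 3.0, 60)
h_best, losses = grid_search_bandwidth(P, y, grid)
print(f"best h = {h_best:.3f}, CV MSE = {losses.min():.4f}")
```

Grid search costs one loss evaluation per candidate and never diverges, but its precision is limited by the grid spacing; gradient descent can land between grid points at the price of tuning the learning rate.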

We create another object gc2 with non-default parameters, then fit it to the same training data as in the previous example.

gc2 = GradientCOBRA(learning_rate=0.03,
                    speed="linear",
                    kernel='radial',
                    opt_method='grad',
                    loss_function="weighted_mse",
                    max_iter=200,
                    estimator_list=['random_forest', 
                                    'adaboost', 
                                    'knn', 
                                    'lasso', 
                                    'ridge'],
                    estimator_params={
                        'random_forest' : {
                            'n_estimators' : 300,
                            'min_samples_leaf' : 10},

                        'adaboost' : {
                            'n_estimators' : 300,
                            'max_depth' : 10},
                        
                        'knn' : {
                            'n_neighbors' : 30}
                    })

gc2_fit = gc2.fit(X_train1, y_train1)

    * Gradient descent with radial kernel is implemented...
        ~ Initial t = 0:        ~ bandwidth: 2.500  ~ gradient: 0.107   ~ threshold: 1e-05
        ~     Iteration: 1  ~ bandwidth: 2.500  ~ gradient: 0.107   ~ stopping criterion: 0.053     ~     Iteration: 2  ~ bandwidth: 2.470  ~ gradient: 0.107   ~ stopping criterion: 0.059     ~     Iteration: 3  ~ bandwidth: 2.440  ~ gradient: 0.107   ~ stopping criterion: 0.065     ~     Iteration: 4  ~ bandwidth: 2.410  ~ gradient: 0.107   ~ stopping criterion: 0.071     ~     Iteration: 5  ~ bandwidth: 2.380  ~ gradient: 0.133   ~ stopping criterion: 0.053     ~     Iteration: 6  ~ bandwidth: 2.193  ~ gradient: 0.070   ~ stopping criterion: 0.067     ~     Iteration: 7  ~ bandwidth: 2.075  ~ gradient: 0.070   ~ stopping criterion: 0.035     ~     Iteration: 8  ~ bandwidth: 1.936  ~ gradient: 0.089   ~ stopping criterion: 0.035     ~     Iteration: 9  ~ bandwidth: 1.736  ~ gradient: 0.029   ~ stopping criterion: 0.045     ~     Iteration: 10     ~ bandwidth: 1.662  ~ gradient: 0.034   ~ stopping criterion: 0.015     ~     Iteration: 11     ~ bandwidth: 1.567  ~ gradient: -0.003  ~ stopping criterion: 0.017     ~     Iteration: 12     ~ bandwidth: 1.576  ~ gradient: 0.012   ~ stopping criterion: 0.002     ~     Iteration: 13     ~ bandwidth: 1.545  ~ gradient: 0.019   ~ stopping criterion: 0.006     ~     Iteration: 14     ~ bandwidth: 1.501  ~ gradient: -0.006  ~ stopping criterion: 0.009     ~     Iteration: 15     ~ bandwidth: 1.516  ~ gradient: -0.027  ~ stopping criterion: 0.003     ~     Iteration: 16     ~ bandwidth: 1.573  ~ gradient: 0.013   ~ stopping criterion: 0.013     ~     Iteration: 17     ~ bandwidth: 1.544  ~ gradient: 0.008   ~ stopping criterion: 0.006     ~     Iteration: 18     ~ bandwidth: 1.528  ~ gradient: -0.004  ~ stopping criterion: 0.004     ~     Iteration: 19     ~ bandwidth: 1.536  ~ gradient: -0.030  ~ stopping criterion: 0.002     ~     Iteration: 20     ~ bandwidth: 1.588  ~ gradient: -0.020  ~ stopping criterion: 0.015     ~     Iteration: 21     ~ bandwidth: 1.625  ~ gradient: 0.036   ~ stopping criterion: 0.010     ~     
Iteration: 22     ~ bandwidth: 1.555  ~ gradient: -0.025  ~ stopping criterion: 0.018     ~     Iteration: 23     ~ bandwidth: 1.596  ~ gradient: 0.041   ~ stopping criterion: 0.013     ~     Iteration: 24     ~ bandwidth: 1.540  ~ gradient: -0.020  ~ stopping criterion: 0.020     ~     Iteration: 25     ~ bandwidth: 1.563  ~ gradient: -0.051  ~ stopping criterion: 0.010     ~     Iteration: 26     ~ bandwidth: 1.611  ~ gradient: 0.002   ~ stopping criterion: 0.025     ~     Iteration: 27     ~ bandwidth: 1.609  ~ gradient: -0.028  ~ stopping criterion: 0.001     ~     Iteration: 28     ~ bandwidth: 1.631  ~ gradient: 0.007   ~ stopping criterion: 0.014     ~     Iteration: 29     ~ bandwidth: 1.626  ~ gradient: 0.003   ~ stopping criterion: 0.003     ~     Iteration: 30     ~ bandwidth: 1.625  ~ gradient: 0.004   ~ stopping criterion: 0.002     ~     Iteration: 31     ~ bandwidth: 1.622  ~ gradient: -0.011  ~ stopping criterion: 0.002     ~     Iteration: 32     ~ bandwidth: 1.628  ~ gradient: 0.011   ~ stopping criterion: 0.005     ~     Iteration: 33     ~ bandwidth: 1.623  ~ gradient: 0.041   ~ stopping criterion: 0.006     ~     Iteration: 34     ~ bandwidth: 1.606  ~ gradient: -0.006  ~ stopping criterion: 0.021     ~     Iteration: 35     ~ bandwidth: 1.609  ~ gradient: -0.015  ~ stopping criterion: 0.003     ~     Iteration: 36     ~ bandwidth: 1.614  ~ gradient: 0.039   ~ stopping criterion: 0.007     ~     Iteration: 37     ~ bandwidth: 1.600  ~ gradient: -0.001  ~ stopping criterion: 0.019     ~     Iteration: 38     ~ bandwidth: 1.600  ~ gradient: -0.013  ~ stopping criterion: 0.001     ~     Iteration: 39     ~ bandwidth: 1.603  ~ gradient: 0.025   ~ stopping criterion: 0.007     ~     Iteration: 40     ~ bandwidth: 1.597  ~ gradient: 0.010   ~ stopping criterion: 0.012     ~     Iteration: 41     ~ bandwidth: 1.595  ~ gradient: 0.000   ~ stopping criterion: 0.005     ~     Iteration: 42     ~ bandwidth: 1.595  ~ gradient: -0.011  ~ stopping criterion: 
0.000     ~     Iteration: 43     ~ bandwidth: 1.598  ~ gradient: -0.020  ~ stopping criterion: 0.006     ~     Iteration: 44     ~ bandwidth: 1.601  ~ gradient: -0.023  ~ stopping criterion: 0.010     ~     Iteration: 45     ~ bandwidth: 1.605  ~ gradient: 0.011   ~ stopping criterion: 0.012     ~     Iteration: 46     ~ bandwidth: 1.603  ~ gradient: -0.005  ~ stopping criterion: 0.005     ~     Iteration: 47     ~ bandwidth: 1.604  ~ gradient: -0.015  ~ stopping criterion: 0.002     ~     Iteration: 48     ~ bandwidth: 1.606  ~ gradient: -0.007  ~ stopping criterion: 0.008     ~     Iteration: 49     ~ bandwidth: 1.607  ~ gradient: 0.022   ~ stopping criterion: 0.004     ~     Iteration: 50     ~ bandwidth: 1.604  ~ gradient: -0.002  ~ stopping criterion: 0.011     ~     Iteration: 51     ~ bandwidth: 1.604  ~ gradient: -0.001  ~ stopping criterion: 0.001     ~     Iteration: 52     ~ bandwidth: 1.604  ~ gradient: 0.026   ~ stopping criterion: 0.001     ~     Iteration: 53     ~ bandwidth: 1.602  ~ gradient: 0.039   ~ stopping criterion: 0.013     ~     Iteration: 54     ~ bandwidth: 1.599  ~ gradient: 0.018   ~ stopping criterion: 0.019     ~     Iteration: 55     ~ bandwidth: 1.598  ~ gradient: 0.020   ~ stopping criterion: 0.009     ~     Iteration: 56     ~ bandwidth: 1.597  ~ gradient: 0.039   ~ stopping criterion: 0.010     ~     Iteration: 57     ~ bandwidth: 1.594  ~ gradient: 0.043   ~ stopping criterion: 0.020     ~     Iteration: 58     ~ bandwidth: 1.591  ~ gradient: 0.008   ~ stopping criterion: 0.021     ~     Iteration: 59     ~ bandwidth: 1.590  ~ gradient: -0.007  ~ stopping criterion: 0.004     ~     Iteration: 60     ~ bandwidth: 1.590  ~ gradient: 0.008   ~ stopping criterion: 0.003     ~     Iteration: 61     ~ bandwidth: 1.590  ~ gradient: 0.022   ~ stopping criterion: 0.004     ~     Iteration: 62     ~ bandwidth: 1.589  ~ gradient: -0.000  ~ stopping criterion: 0.011     ~     Iteration: 63     ~ bandwidth: 1.589  ~ gradient: -0.006  ~ 
stopping criterion: 0.000     ~     Iteration: 64     ~ bandwidth: 1.589  ~ gradient: -0.003  ~ stopping criterion: 0.003     ~     Iteration: 65     ~ bandwidth: 1.589  ~ gradient: 0.012   ~ stopping criterion: 0.001     ~     Iteration: 66     ~ bandwidth: 1.589  ~ gradient: 0.010   ~ stopping criterion: 0.006     ~     Iteration: 67     ~ bandwidth: 1.588  ~ gradient: 0.026   ~ stopping criterion: 0.005     ~     Iteration: 68     ~ bandwidth: 1.587  ~ gradient: 0.023   ~ stopping criterion: 0.013     ~     Iteration: 69     ~ bandwidth: 1.587  ~ gradient: 0.019   ~ stopping criterion: 0.012     ~     Iteration: 70     ~ bandwidth: 1.586  ~ gradient: -0.013  ~ stopping criterion: 0.010     ~     Iteration: 71     ~ bandwidth: 1.586  ~ gradient: -0.004  ~ stopping criterion: 0.006     ~     Iteration: 72     ~ bandwidth: 1.586  ~ gradient: -0.003  ~ stopping criterion: 0.002     ~     Iteration: 73     ~ bandwidth: 1.587  ~ gradient: 0.033   ~ stopping criterion: 0.002     ~     Iteration: 74     ~ bandwidth: 1.585  ~ gradient: -0.006  ~ stopping criterion: 0.016     ~     Iteration: 75     ~ bandwidth: 1.586  ~ gradient: -0.004  ~ stopping criterion: 0.003     ~     Iteration: 76     ~ bandwidth: 1.586  ~ gradient: -0.001  ~ stopping criterion: 0.002     ~     Iteration: 77     ~ bandwidth: 1.586  ~ gradient: -0.011  ~ stopping criterion: 0.001     ~     Iteration: 78     ~ bandwidth: 1.586  ~ gradient: 0.006   ~ stopping criterion: 0.005     ~     Iteration: 79     ~ bandwidth: 1.586  ~ gradient: 0.005   ~ stopping criterion: 0.003     ~     Iteration: 80     ~ bandwidth: 1.586  ~ gradient: -0.029  ~ stopping criterion: 0.002     ~     Iteration: 81     ~ bandwidth: 1.586  ~ gradient: -0.005  ~ stopping criterion: 0.014     ~     Iteration: 82     ~ bandwidth: 1.586  ~ gradient: -0.008  ~ stopping criterion: 0.003     ~     Iteration: 83     ~ bandwidth: 1.586  ~ gradient: -0.005  ~ stopping criterion: 0.004     ~     Iteration: 84     ~ bandwidth: 1.587  ~ 
gradient: -0.019  ~ stopping criterion: 0.003     ~     Iteration: 85     ~ bandwidth: 1.587  ~ gradient: 0.021   ~ stopping criterion: 0.009     ~     Iteration: 86     ~ bandwidth: 1.586  ~ gradient: 0.023   ~ stopping criterion: 0.010     ~     Iteration: 87     ~ bandwidth: 1.586  ~ gradient: 0.003   ~ stopping criterion: 0.011     ~     Iteration: 88     ~ bandwidth: 1.586  ~ gradient: 0.008   ~ stopping criterion: 0.002     ~     Iteration: 89     ~ bandwidth: 1.586  ~ gradient: 0.014   ~ stopping criterion: 0.004     ~     Iteration: 90     ~ bandwidth: 1.586  ~ gradient: -0.023  ~ stopping criterion: 0.007     ~     Iteration: 91     ~ bandwidth: 1.586  ~ gradient: -0.000  ~ stopping criterion: 0.012     ~     Iteration: 92     ~ bandwidth: 1.586  ~ gradient: 0.018   ~ stopping criterion: 0.000                                                                                                                                                                                 ~    Stopped at: 92     ~ bandwidth: 1.586  ~ gradient: 0.018   ~ stopping criterion: 0.000

Now, let’s compare the estimated bandwidth and the learning curve with those of the previous example.

print("Estimated bandwidth :" + str(gc2_fit.optimization_outputs['opt_bandwidth']))
gc2_fit.draw_learning_curve()
Estimated bandwidth :1.5859882672261734

Compare MSE and MAPE.

y_pred2 = gc2_fit.predict(X_test1)
print(mean_absolute_percentage_error(y_test1, y_pred2))
print(mean_squared_error(y_test1, y_pred2))
0.3526384948601235
43.787418972430785
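In this run, gc2 lowers the test MSE from about 110.2 to about 43.8, while its MAPE rises from about 0.178 to about 0.353. The two metrics are not contradictory: MAPE divides each absolute error by |y_i|, so a few test targets close to zero (make_regression produces targets centred at zero) can dominate it, whereas MSE is dominated by the largest absolute errors. We have not checked which test points drive the MAPE here; the toy example below only illustrates the mechanism with made-up numbers.

```python
import numpy as np


def mse(t, p):
    return float(np.mean((t - p) ** 2))


def mape(t, p):
    return float(np.mean(np.abs(t - p) / np.abs(t)))


y_true = np.array([0.5, 100.0, -100.0])
pred_a = np.array([0.6, 120.0, -120.0])  # tiny error on the small target, large elsewhere
pred_b = np.array([1.5, 101.0, -101.0])  # error of 1.0 everywhere, incl. the small target

for name, pred in (("A", pred_a), ("B", pred_b)):
    print(name, "MSE =", round(mse(y_true, pred), 3), "MAPE =", round(mape(y_true, pred), 3))
```

Prediction B has a far lower MSE (1.0 against about 266.7) but a far higher MAPE (about 0.673 against 0.2), purely because of its error on the target 0.5.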

Compare the QQ-plot.

gc2_fit.draw_learning_curve(y_test=y_test1, fig_type='qq')

Real dataset

We use the California housing dataset from the sklearn.datasets module. To illustrate the idea, we only work with the first \(1000\) observations.

from sklearn.datasets import fetch_california_housing
data = fetch_california_housing()
X_real, y_real = data['data'], data['target']

X_train_real, X_test_real, y_train_real, y_test_real = train_test_split(X_real[:1000,:], y_real[:1000], test_size=0.2)
print('shape: x_train = {} , x_test = {} , y_train = {} , y_test = {}'.format(X_train_real.shape, X_test_real.shape, y_train_real.shape, y_test_real.shape))
shape: x_train = (800, 8) , x_test = (200, 8) , y_train = (800,) , y_test = (200,)

We pass some arbitrarily chosen parameters to the method as follows.

gc_real = GradientCOBRA(opt_method="grad",
                        estimator_list=['random_forest', 'knn', 'ridge', 'lasso'],
                        estimator_params={
                                'random_forest' : {'n_estimators': 300},
                                'knn' : {'n_neighbors' : 30}
                        })
gc_real_fit = gc_real.fit(X_train_real, y_train_real)

    * Gradient descent with radial kernel is implemented...
        ~ Initial t = 0:        ~ bandwidth: 2.500  ~ gradient: -0.000  ~ threshold: 1e-05
        ~     Iteration: 1  ~ bandwidth: 2.510  ~ gradient: -0.000  ~ stopping criterion: 0.002     ~     Iteration: 2  ~ bandwidth: 2.510  ~ gradient: -0.000  ~ stopping criterion: 0.002     ~     Iteration: 3  ~ bandwidth: 2.510  ~ gradient: -0.000  ~ stopping criterion: 0.002     ~     Iteration: 4  ~ bandwidth: 2.510  ~ gradient: -0.000  ~ stopping criterion: 0.002     ~     Iteration: 5  ~ bandwidth: 2.510  ~ gradient: -0.000  ~ stopping criterion: 0.000     ~     Iteration: 6  ~ bandwidth: 2.517  ~ gradient: -0.000  ~ stopping criterion: 0.000     ~     Iteration: 7  ~ bandwidth: 2.527  ~ gradient: -0.000  ~ stopping criterion: 0.000     ~     Iteration: 8  ~ bandwidth: 2.535  ~ gradient: -0.000  ~ stopping criterion: 0.000     ~     Iteration: 9  ~ bandwidth: 2.538  ~ gradient: -0.000  ~ stopping criterion: 0.000     ~     Iteration: 10     ~ bandwidth: 2.542  ~ gradient: -0.000  ~ stopping criterion: 0.000     ~     Iteration: 11     ~ bandwidth: 2.544  ~ gradient: 0.000   ~ stopping criterion: 0.000     ~     Iteration: 12     ~ bandwidth: 2.542  ~ gradient: -0.001  ~ stopping criterion: 0.000     ~     Iteration: 13     ~ bandwidth: 2.551  ~ gradient: -0.000  ~ stopping criterion: 0.000     ~     Iteration: 14     ~ bandwidth: 2.557  ~ gradient: -0.000  ~ stopping criterion: 0.000     ~     Iteration: 15     ~ bandwidth: 2.560  ~ gradient: -0.000  ~ stopping criterion: 0.000     ~     Iteration: 16     ~ bandwidth: 2.564  ~ gradient: 0.000   ~ stopping criterion: 0.000     ~     Iteration: 17     ~ bandwidth: 2.564  ~ gradient: -0.000  ~ stopping criterion: 0.000     ~     Iteration: 18     ~ bandwidth: 2.566  ~ gradient: -0.000  ~ stopping criterion: 0.000     ~     Iteration: 19     ~ bandwidth: 2.570  ~ gradient: -0.000  ~ stopping criterion: 0.000     ~     Iteration: 20     ~ bandwidth: 2.573  ~ gradient: -0.000  ~ stopping criterion: 0.000     ~     Iteration: 21     ~ bandwidth: 2.576  ~ gradient: -0.000  ~ stopping criterion: 0.000     ~     
Iteration: 22     ~ bandwidth: 2.578  ~ gradient: -0.000  ~ stopping criterion: 0.000     ~     Iteration: 23     ~ bandwidth: 2.580  ~ gradient: -0.000  ~ stopping criterion: 0.000     ~     Iteration: 24     ~ bandwidth: 2.582  ~ gradient: -0.000  ~ stopping criterion: 0.000     ~     Iteration: 25     ~ bandwidth: 2.583  ~ gradient: -0.000  ~ stopping criterion: 0.000     ~     Iteration: 26     ~ bandwidth: 2.586  ~ gradient: -0.000  ~ stopping criterion: 0.000     ~     Iteration: 27     ~ bandwidth: 2.587  ~ gradient: -0.000  ~ stopping criterion: 0.000     ~     Iteration: 28     ~ bandwidth: 2.587  ~ gradient: -0.000  ~ stopping criterion: 0.000     ~     Iteration: 29     ~ bandwidth: 2.589  ~ gradient: -0.000  ~ stopping criterion: 0.000     ~     Iteration: 30     ~ bandwidth: 2.589  ~ gradient: -0.000  ~ stopping criterion: 0.000                                                                                                                                                                                 ~    Stopped at: 30     ~ bandwidth: 2.590  ~ gradient: -0.000  ~ stopping criterion: 0.000

Now, let’s look at the obtained bandwidth and the optimization result.

print("Optimal bandwidth: {}".format(gc_real_fit.optimization_outputs['opt_bandwidth']))
gc_real_fit.draw_learning_curve()
Optimal bandwidth: 2.5897746339888927

We look at the numerical and graphical performance.

y_pred_real = gc_real_fit.predict(X_test_real)
print(mean_absolute_percentage_error(y_test_real, y_pred_real))
print(mean_squared_error(y_test_real, y_pred_real))
gc_real_fit.draw_learning_curve(y_test=y_test_real, fig_type='qq')
0.1688280221166909
0.18541016068153693

Compare to AdaBoost

We compare the method fitted on the California data with the AdaBoost method.

from sklearn.ensemble import AdaBoostRegressor

ada = AdaBoostRegressor(n_estimators=1000)
ada_fit = ada.fit(X_train_real, y_train_real)
ada_pred = ada_fit.predict(X_test_real)
print(mean_absolute_percentage_error(y_test_real, ada_pred))
print(mean_squared_error(y_test_real, ada_pred))
0.2979434345988137
0.2985101975884548
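On this split, GradientCOBRA's test MAPE is about 43 % lower and its MSE about 38 % lower than AdaBoost's. The arithmetic below uses only the scores printed above. Both models were evaluated on a single random split without a fixed seed, so these percentages will change from run to run.

```python
def relative_reduction(baseline, new):
    """Fraction by which `new` is lower than `baseline` (positive = improvement)."""
    return (baseline - new) / baseline


# Test-set scores printed above (California housing subset)
scores = {
    "GradientCOBRA": {"MAPE": 0.1688280221166909, "MSE": 0.18541016068153693},
    "AdaBoost": {"MAPE": 0.2979434345988137, "MSE": 0.2985101975884548},
}
for metric in ("MAPE", "MSE"):
    r = relative_reduction(scores["AdaBoost"][metric], scores["GradientCOBRA"][metric])
    print(f"{metric}: {100 * r:.1f}% lower with GradientCOBRA")
```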

Pretrained basic estimators

An interesting application of consensual aggregation methods is to take estimators pretrained elsewhere and apply them to other test data (not necessarily from the same source; having the same input structure is enough). The only requirement is that the basic estimators can predict the new observations, since only their predicted features are used in the aggregation. Here, we build pretrained estimators, including an XGBoost model, and aggregate them together with some sklearn basic estimators.

We first split the training data into two parts: \(X_k\), used to build the basic estimators, and \(X_\ell\), used for the aggregation. The fitted estimators then predict both \(X_\ell\) and the test data, and only these predictions are used to build the final predictions.

import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn.neighbors import KNeighborsRegressor
from sklearn.tree import DecisionTreeRegressor

# Random split of the training data into X_k (basic estimators) and X_l (aggregation)
id_k = np.random.permutation(len(y_train_real))
k = int(.5 * len(y_train_real))
X_k, X_l = X_train_real[id_k[:k], :], X_train_real[id_k[k:], :]
y_k, y_l = y_train_real[id_k[:k]], y_train_real[id_k[k:]]

# Build the basic estimators on X_k only
rf_real = RandomForestRegressor(n_estimators=300).fit(X_k, y_k)
lm_real = LinearRegression().fit(X_k, y_k)
knn_real = KNeighborsRegressor(n_neighbors=10).fit(X_k, y_k)
tr_real = DecisionTreeRegressor(min_samples_leaf=5).fit(X_k, y_k)

# External XGBoost estimator, also fitted on X_k only
import xgboost
xgb = xgboost.XGBRegressor(n_estimators=500)
xgb_real = xgb.fit(X_k, y_k)

# All pretrained estimators
basic_estimators = (rf_real, lm_real, knn_real, tr_real, xgb_real)

# Predicted features on X_l for aggregation
pred_feature_l = np.column_stack([est.predict(X_l) for est in basic_estimators])

# Predicted features on Testing data
pred_feature_test = np.column_stack([est.predict(X_test_real) for est in basic_estimators])
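The same pipeline can be mimicked end to end with NumPy alone, which makes it easy to see that the aggregation step never looks at the original inputs, only at the columns of predicted features. In this sketch, simulated 1-D data replace the California data, polynomial fits of several degrees stand in for the random forest, XGBoost and other pretrained estimators, and a fixed bandwidth of 0.3 with an assumed radial kernel replaces the bandwidth that GradientCOBRA would estimate. All names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Simulated 1-D regression data standing in for the training/testing sets
x = rng.uniform(-3, 3, size=600)
y = np.sin(x) + 0.1 * rng.normal(size=600)
x_train, y_train, x_test, y_test = x[:500], y[:500], x[500:], y[500:]

# Split the training part into X_k (fit basic estimators) and X_l (aggregation)
perm = rng.permutation(len(y_train))
k = len(y_train) // 2
x_k, y_k = x_train[perm[:k]], y_train[perm[:k]]
x_l, y_l = x_train[perm[k:]], y_train[perm[k:]]

# "Pretrained" basic estimators: polynomial fits of degree 1, 3 and 5 on X_k only
coefs = [np.polyfit(x_k, y_k, deg) for deg in (1, 3, 5)]


def predict_all(xs):
    # One column per basic estimator, like pred_feature_l / pred_feature_test
    return np.column_stack([np.polyval(c, xs) for c in coefs])


# Only these predicted features are passed on to the aggregation step
P_l, P_test = predict_all(x_l), predict_all(x_test)


def aggregate(P_tr, y_tr, P_q, h):
    # Radial-kernel weighted average of y_tr (assumed form exp(-||.||^2 / h^2))
    d2 = ((P_q[:, None, :] - P_tr[None, :, :]) ** 2).sum(axis=2)
    w = np.exp(-d2 / h ** 2)
    s = w.sum(axis=1)
    return np.where(s > 0, (w @ y_tr) / np.where(s > 0, s, 1.0), y_tr.mean())


y_hat = aggregate(P_l, y_l, P_test, h=0.3)
mse_agg = float(np.mean((y_test - y_hat) ** 2))
mse_lin = float(np.mean((y_test - P_test[:, 0]) ** 2))
print(f"aggregate MSE = {mse_agg:.4f}, degree-1 estimator MSE = {mse_lin:.4f}")
```

Because aggregate receives only P_l, y_l and P_test, the basic estimators could have been trained anywhere, by any library, as long as they can produce predictions for the aggregation and test points.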

To fit the aggregation method on the predicted features (NOT on the input data), we set the argument as_predictions=True. This tells the fit method not to build any basic estimators on the given input (which already consists of predictions), and the bandwidth optimization is performed directly on \(X\).

gc3_fit = GradientCOBRA().fit(X=pred_feature_l,
                              y=y_l,
                              as_predictions=True)

    * Gradient descent with radial kernel is implemented...
        ~ Initial t = 0:        ~ bandwidth: 2.500  ~ gradient: 0.011   ~ threshold: 1e-05
        ~     Iteration: 1  ~ bandwidth: 2.490  ~ gradient: 0.011   ~ stopping criterion: 0.008     ~     Iteration: 2  ~ bandwidth: 2.490  ~ gradient: 0.011   ~ stopping criterion: 0.008     ~     Iteration: 3  ~ bandwidth: 2.490  ~ gradient: 0.011   ~ stopping criterion: 0.008     ~     Iteration: 4  ~ bandwidth: 2.490  ~ gradient: 0.011   ~ stopping criterion: 0.008     ~     Iteration: 5  ~ bandwidth: 2.490  ~ gradient: 0.012   ~ stopping criterion: 0.006     ~     Iteration: 6  ~ bandwidth: 2.480  ~ gradient: 0.011   ~ stopping criterion: 0.006     ~     Iteration: 7  ~ bandwidth: 2.470  ~ gradient: 0.011   ~ stopping criterion: 0.006     ~     Iteration: 8  ~ bandwidth: 2.460  ~ gradient: 0.012   ~ stopping criterion: 0.006     ~     Iteration: 9  ~ bandwidth: 2.450  ~ gradient: 0.011   ~ stopping criterion: 0.006     ~     Iteration: 10     ~ bandwidth: 2.440  ~ gradient: 0.011   ~ stopping criterion: 0.006     ~     Iteration: 11     ~ bandwidth: 2.431  ~ gradient: 0.012   ~ stopping criterion: 0.006     ~     Iteration: 12     ~ bandwidth: 2.421  ~ gradient: 0.011   ~ stopping criterion: 0.006     ~     Iteration: 13     ~ bandwidth: 2.411  ~ gradient: 0.012   ~ stopping criterion: 0.006     ~     Iteration: 14     ~ bandwidth: 2.400  ~ gradient: 0.012   ~ stopping criterion: 0.006     ~     Iteration: 15     ~ bandwidth: 2.390  ~ gradient: 0.012   ~ stopping criterion: 0.006     ~     Iteration: 16     ~ bandwidth: 2.380  ~ gradient: 0.012   ~ stopping criterion: 0.006     ~     Iteration: 17     ~ bandwidth: 2.369  ~ gradient: 0.012   ~ stopping criterion: 0.006     ~     Iteration: 18     ~ bandwidth: 2.359  ~ gradient: 0.012   ~ stopping criterion: 0.006     ~     Iteration: 19     ~ bandwidth: 2.349  ~ gradient: 0.012   ~ stopping criterion: 0.006     ~     Iteration: 20     ~ bandwidth: 2.338  ~ gradient: 0.012   ~ stopping criterion: 0.006     ~     Iteration: 21     ~ bandwidth: 2.328  ~ gradient: 0.012   ~ stopping criterion: 0.006     ~     
Iteration: 22     ~ bandwidth: 2.317  ~ gradient: 0.012   ~ stopping criterion: 0.006     ~     Iteration: 23     ~ bandwidth: 2.307  ~ gradient: 0.012   ~ stopping criterion: 0.006     ~     Iteration: 24     ~ bandwidth: 2.296  ~ gradient: 0.012   ~ stopping criterion: 0.006     ~     Iteration: 25     ~ bandwidth: 2.285  ~ gradient: 0.012   ~ stopping criterion: 0.006     ~     Iteration: 26     ~ bandwidth: 2.275  ~ gradient: 0.012   ~ stopping criterion: 0.006     ~     Iteration: 27     ~ bandwidth: 2.264  ~ gradient: 0.013   ~ stopping criterion: 0.006     ~     Iteration: 28     ~ bandwidth: 2.253  ~ gradient: 0.012   ~ stopping criterion: 0.006     ~     Iteration: 29     ~ bandwidth: 2.242  ~ gradient: 0.013   ~ stopping criterion: 0.006     ~     Iteration: 30     ~ bandwidth: 2.231  ~ gradient: 0.013   ~ stopping criterion: 0.006     ~     Iteration: 31     ~ bandwidth: 2.220  ~ gradient: 0.013   ~ stopping criterion: 0.006     ~     Iteration: 32     ~ bandwidth: 2.208  ~ gradient: 0.013   ~ stopping criterion: 0.007     ~     Iteration: 33     ~ bandwidth: 2.197  ~ gradient: 0.013   ~ stopping criterion: 0.006     ~     Iteration: 34     ~ bandwidth: 2.186  ~ gradient: 0.013   ~ stopping criterion: 0.007     ~     Iteration: 35     ~ bandwidth: 2.175  ~ gradient: 0.013   ~ stopping criterion: 0.007     ~     Iteration: 36     ~ bandwidth: 2.163  ~ gradient: 0.013   ~ stopping criterion: 0.007     ~     Iteration: 37     ~ bandwidth: 2.152  ~ gradient: 0.014   ~ stopping criterion: 0.006     ~     Iteration: 38     ~ bandwidth: 2.140  ~ gradient: 0.014   ~ stopping criterion: 0.007     ~     Iteration: 39     ~ bandwidth: 2.128  ~ gradient: 0.013   ~ stopping criterion: 0.007     ~     Iteration: 40     ~ bandwidth: 2.117  ~ gradient: 0.014   ~ stopping criterion: 0.007     ~     Iteration: 41     ~ bandwidth: 2.105  ~ gradient: 0.014   ~ stopping criterion: 0.007     ~     Iteration: 42     ~ bandwidth: 2.093  ~ gradient: 0.014   ~ stopping criterion: 
0.007     ~     Iteration: 43     ~ bandwidth: 2.081  ~ gradient: 0.014   ~ stopping criterion: 0.007     ~     Iteration: 44     ~ bandwidth: 2.069  ~ gradient: 0.014   ~ stopping criterion: 0.007     ~     Iteration: 45     ~ bandwidth: 2.056  ~ gradient: 0.014   ~ stopping criterion: 0.007     ~     Iteration: 46     ~ bandwidth: 2.044  ~ gradient: 0.014   ~ stopping criterion: 0.007     ~     Iteration: 47     ~ bandwidth: 2.032  ~ gradient: 0.014   ~ stopping criterion: 0.007     ~     Iteration: 48     ~ bandwidth: 2.020  ~ gradient: 0.014   ~ stopping criterion: 0.007     ~     Iteration: 49     ~ bandwidth: 2.007  ~ gradient: 0.014   ~ stopping criterion: 0.007     ~     Iteration: 50     ~ bandwidth: 1.995  ~ gradient: 0.014   ~ stopping criterion: 0.007     ~     Iteration: 51     ~ bandwidth: 1.983  ~ gradient: 0.015   ~ stopping criterion: 0.007     ~     Iteration: 52     ~ bandwidth: 1.970  ~ gradient: 0.015   ~ stopping criterion: 0.007     ~     Iteration: 53     ~ bandwidth: 1.957  ~ gradient: 0.015   ~ stopping criterion: 0.007     ~     Iteration: 54     ~ bandwidth: 1.944  ~ gradient: 0.015   ~ stopping criterion: 0.007     ~     Iteration: 55     ~ bandwidth: 1.931  ~ gradient: 0.015   ~ stopping criterion: 0.007     ~     Iteration: 56     ~ bandwidth: 1.918  ~ gradient: 0.015   ~ stopping criterion: 0.007     ~     Iteration: 57     ~ bandwidth: 1.905  ~ gradient: 0.015   ~ stopping criterion: 0.008     ~     Iteration: 58     ~ bandwidth: 1.892  ~ gradient: 0.015   ~ stopping criterion: 0.008     ~     Iteration: 59     ~ bandwidth: 1.879  ~ gradient: 0.016   ~ stopping criterion: 0.008     ~     Iteration: 60     ~ bandwidth: 1.865  ~ gradient: 0.015   ~ stopping criterion: 0.008     ~     Iteration: 61     ~ bandwidth: 1.852  ~ gradient: 0.015   ~ stopping criterion: 0.008     ~     Iteration: 62     ~ bandwidth: 1.838  ~ gradient: 0.016   ~ stopping criterion: 0.008     ~     Iteration: 63     ~ bandwidth: 1.825  ~ gradient: 0.016   ~ 
stopping criterion: 0.008     ~     Iteration: 64     ~ bandwidth: 1.811  ~ gradient: 0.016   ~ stopping criterion: 0.008     ~     Iteration: 65     ~ bandwidth: 1.797  ~ gradient: 0.016   ~ stopping criterion: 0.008     ~     Iteration: 66     ~ bandwidth: 1.783  ~ gradient: 0.016   ~ stopping criterion: 0.008     ~     Iteration: 67     ~ bandwidth: 1.769  ~ gradient: 0.016   ~ stopping criterion: 0.008     ~     Iteration: 68     ~ bandwidth: 1.755  ~ gradient: 0.017   ~ stopping criterion: 0.008     ~     Iteration: 69     ~ bandwidth: 1.740  ~ gradient: 0.016   ~ stopping criterion: 0.008     ~     Iteration: 70     ~ bandwidth: 1.726  ~ gradient: 0.017   ~ stopping criterion: 0.008     ~     Iteration: 71     ~ bandwidth: 1.712  ~ gradient: 0.016   ~ stopping criterion: 0.008     ~     Iteration: 72     ~ bandwidth: 1.697  ~ gradient: 0.017   ~ stopping criterion: 0.008     ~     Iteration: 73     ~ bandwidth: 1.683  ~ gradient: 0.017   ~ stopping criterion: 0.008     ~     Iteration: 74     ~ bandwidth: 1.668  ~ gradient: 0.017   ~ stopping criterion: 0.009     ~     Iteration: 75     ~ bandwidth: 1.653  ~ gradient: 0.017   ~ stopping criterion: 0.009     ~     Iteration: 76     ~ bandwidth: 1.638  ~ gradient: 0.017   ~ stopping criterion: 0.008     ~     Iteration: 77     ~ bandwidth: 1.623  ~ gradient: 0.017   ~ stopping criterion: 0.009     ~     Iteration: 78     ~ bandwidth: 1.608  ~ gradient: 0.017   ~ stopping criterion: 0.009     ~     Iteration: 79     ~ bandwidth: 1.593  ~ gradient: 0.017   ~ stopping criterion: 0.009     ~     Iteration: 80     ~ bandwidth: 1.578  ~ gradient: 0.017   ~ stopping criterion: 0.009     ~     Iteration: 81     ~ bandwidth: 1.563  ~ gradient: 0.018   ~ stopping criterion: 0.009     ~     Iteration: 82     ~ bandwidth: 1.548  ~ gradient: 0.018   ~ stopping criterion: 0.009     ~     Iteration: 83     ~ bandwidth: 1.533  ~ gradient: 0.018   ~ stopping criterion: 0.009     ~     Iteration: 84     ~ bandwidth: 1.517  ~ 
gradient: 0.018   ~ stopping criterion: 0.009     ~     Iteration: 85     ~ bandwidth: 1.502  ~ gradient: 0.017   ~ stopping criterion: 0.009     ~     Iteration: 86     ~ bandwidth: 1.487  ~ gradient: 0.018   ~ stopping criterion: 0.009     ~     Iteration: 87     ~ bandwidth: 1.472  ~ gradient: 0.017   ~ stopping criterion: 0.009     ~     Iteration: 88     ~ bandwidth: 1.456  ~ gradient: 0.018   ~ stopping criterion: 0.009     ~     Iteration: 89     ~ bandwidth: 1.441  ~ gradient: 0.018   ~ stopping criterion: 0.009     ~     Iteration: 90     ~ bandwidth: 1.425  ~ gradient: 0.017   ~ stopping criterion: 0.009     ~     Iteration: 91     ~ bandwidth: 1.410  ~ gradient: 0.017   ~ stopping criterion: 0.009     ~     Iteration: 92     ~ bandwidth: 1.395  ~ gradient: 0.018   ~ stopping criterion: 0.009     ~     Iteration: 93     ~ bandwidth: 1.380  ~ gradient: 0.017   ~ stopping criterion: 0.009     ~     Iteration: 94     ~ bandwidth: 1.365  ~ gradient: 0.017   ~ stopping criterion: 0.009     ~     Iteration: 95     ~ bandwidth: 1.350  ~ gradient: 0.017   ~ stopping criterion: 0.009     ~     Iteration: 96     ~ bandwidth: 1.335  ~ gradient: 0.017   ~ stopping criterion: 0.009     ~     Iteration: 97     ~ bandwidth: 1.320  ~ gradient: 0.017   ~ stopping criterion: 0.008     ~     Iteration: 98     ~ bandwidth: 1.305  ~ gradient: 0.017   ~ stopping criterion: 0.009     ~     Iteration: 99     ~ bandwidth: 1.291  ~ gradient: 0.016   ~ stopping criterion: 0.008     ~     Iteration: 100    ~ bandwidth: 1.276  ~ gradient: 0.016   ~ stopping criterion: 0.008     ~     Iteration: 101    ~ bandwidth: 1.262  ~ gradient: 0.016   ~ stopping criterion: 0.008     ~     Iteration: 102    ~ bandwidth: 1.248  ~ gradient: 0.016   ~ stopping criterion: 0.008     ~     Iteration: 103    ~ bandwidth: 1.234  ~ gradient: 0.016   ~ stopping criterion: 0.008     ~     Iteration: 104    ~ bandwidth: 1.220  ~ gradient: 0.015   ~ stopping criterion: 0.008     ~     Iteration: 105    ~ 
bandwidth: 1.207  ~ gradient: 0.015   ~ stopping criterion: 0.008     ~     Iteration: 106    ~ bandwidth: 1.193  ~ gradient: 0.015   ~ stopping criterion: 0.008     ~     Iteration: 107    ~ bandwidth: 1.180  ~ gradient: 0.015   ~ stopping criterion: 0.008     ~     Iteration: 108    ~ bandwidth: 1.167  ~ gradient: 0.014   ~ stopping criterion: 0.007     ~     Iteration: 109    ~ bandwidth: 1.155  ~ gradient: 0.014   ~ stopping criterion: 0.007     ~     Iteration: 110    ~ bandwidth: 1.142  ~ gradient: 0.014   ~ stopping criterion: 0.007     ~     Iteration: 111    ~ bandwidth: 1.130  ~ gradient: 0.014   ~ stopping criterion: 0.007     ~     Iteration: 112    ~ bandwidth: 1.118  ~ gradient: 0.013   ~ stopping criterion: 0.007     ~     Iteration: 113    ~ bandwidth: 1.107  ~ gradient: 0.013   ~ stopping criterion: 0.007     ~     Iteration: 114    ~ bandwidth: 1.095  ~ gradient: 0.013   ~ stopping criterion: 0.006     ~     Iteration: 115    ~ bandwidth: 1.084  ~ gradient: 0.012   ~ stopping criterion: 0.006     ~     Iteration: 116    ~ bandwidth: 1.074  ~ gradient: 0.012   ~ stopping criterion: 0.006     ~     Iteration: 117    ~ bandwidth: 1.063  ~ gradient: 0.012   ~ stopping criterion: 0.006     ~     Iteration: 118    ~ bandwidth: 1.053  ~ gradient: 0.011   ~ stopping criterion: 0.006     ~     Iteration: 119    ~ bandwidth: 1.043  ~ gradient: 0.011   ~ stopping criterion: 0.006     ~     Iteration: 120    ~ bandwidth: 1.034  ~ gradient: 0.011   ~ stopping criterion: 0.005     ~     Iteration: 121    ~ bandwidth: 1.024  ~ gradient: 0.010   ~ stopping criterion: 0.005     ~     Iteration: 122    ~ bandwidth: 1.016  ~ gradient: 0.010   ~ stopping criterion: 0.005     ~     Iteration: 123    ~ bandwidth: 1.007  ~ gradient: 0.010   ~ stopping criterion: 0.005     ~     Iteration: 124    ~ bandwidth: 0.999  ~ gradient: 0.009   ~ stopping criterion: 0.005     ~     Iteration: 125    ~ bandwidth: 0.991  ~ gradient: 0.009   ~ stopping criterion: 0.005     ~     
Iteration: 126    ~ bandwidth: 0.983  ~ gradient: 0.009   ~ stopping criterion: 0.005
~     Iteration: 127    ~ bandwidth: 0.975  ~ gradient: 0.008   ~ stopping criterion: 0.004
~     Iteration: 128    ~ bandwidth: 0.968  ~ gradient: 0.008   ~ stopping criterion: 0.004
~     Iteration: 129    ~ bandwidth: 0.961  ~ gradient: 0.008   ~ stopping criterion: 0.004
~     Iteration: 130    ~ bandwidth: 0.954  ~ gradient: 0.007   ~ stopping criterion: 0.004
~     Iteration: 131    ~ bandwidth: 0.948  ~ gradient: 0.007   ~ stopping criterion: 0.004
~     Iteration: 132    ~ bandwidth: 0.942  ~ gradient: 0.007   ~ stopping criterion: 0.004
~     Iteration: 133    ~ bandwidth: 0.936  ~ gradient: 0.006   ~ stopping criterion: 0.003
~     Iteration: 134    ~ bandwidth: 0.931  ~ gradient: 0.007   ~ stopping criterion: 0.003
~     Iteration: 135    ~ bandwidth: 0.925  ~ gradient: 0.006   ~ stopping criterion: 0.003
~     Iteration: 136    ~ bandwidth: 0.920  ~ gradient: 0.006   ~ stopping criterion: 0.003
~     Iteration: 137    ~ bandwidth: 0.915  ~ gradient: 0.006   ~ stopping criterion: 0.003
~     Iteration: 138    ~ bandwidth: 0.910  ~ gradient: 0.006   ~ stopping criterion: 0.003
~     Iteration: 139    ~ bandwidth: 0.905  ~ gradient: 0.005   ~ stopping criterion: 0.003
~     Iteration: 140    ~ bandwidth: 0.900  ~ gradient: 0.005   ~ stopping criterion: 0.003
~     Iteration: 141    ~ bandwidth: 0.896  ~ gradient: 0.005   ~ stopping criterion: 0.002
~     Iteration: 142    ~ bandwidth: 0.892  ~ gradient: 0.005   ~ stopping criterion: 0.002
~     Iteration: 143    ~ bandwidth: 0.888  ~ gradient: 0.004   ~ stopping criterion: 0.002
~     Iteration: 144    ~ bandwidth: 0.884  ~ gradient: 0.004   ~ stopping criterion: 0.002
~     Iteration: 145    ~ bandwidth: 0.881  ~ gradient: 0.004   ~ stopping criterion: 0.002
~     Iteration: 146    ~ bandwidth: 0.877  ~ gradient: 0.004   ~ stopping criterion: 0.002
~     Iteration: 147    ~ bandwidth: 0.874  ~ gradient: 0.004   ~ stopping criterion: 0.002
~     Iteration: 148    ~ bandwidth: 0.871  ~ gradient: 0.003   ~ stopping criterion: 0.002
~     Iteration: 149    ~ bandwidth: 0.868  ~ gradient: 0.004   ~ stopping criterion: 0.002
~     Iteration: 150    ~ bandwidth: 0.865  ~ gradient: 0.003   ~ stopping criterion: 0.002
~     Iteration: 151    ~ bandwidth: 0.862  ~ gradient: 0.003   ~ stopping criterion: 0.002
~     Iteration: 152    ~ bandwidth: 0.860  ~ gradient: 0.003   ~ stopping criterion: 0.001
~     Iteration: 153    ~ bandwidth: 0.857  ~ gradient: 0.003   ~ stopping criterion: 0.002
~     Iteration: 154    ~ bandwidth: 0.854  ~ gradient: 0.003   ~ stopping criterion: 0.001
~     Iteration: 155    ~ bandwidth: 0.852  ~ gradient: 0.002   ~ stopping criterion: 0.001
~     Iteration: 156    ~ bandwidth: 0.850  ~ gradient: 0.003   ~ stopping criterion: 0.001
~     Iteration: 157    ~ bandwidth: 0.848  ~ gradient: 0.002   ~ stopping criterion: 0.001
~     Iteration: 158    ~ bandwidth: 0.846  ~ gradient: 0.002   ~ stopping criterion: 0.001
~     Iteration: 159    ~ bandwidth: 0.844  ~ gradient: 0.002   ~ stopping criterion: 0.001
~     Iteration: 160    ~ bandwidth: 0.842  ~ gradient: 0.002   ~ stopping criterion: 0.001
~     Iteration: 161    ~ bandwidth: 0.841  ~ gradient: 0.002   ~ stopping criterion: 0.001
~     Iteration: 162    ~ bandwidth: 0.839  ~ gradient: 0.002   ~ stopping criterion: 0.001
~     Iteration: 163    ~ bandwidth: 0.837  ~ gradient: 0.002   ~ stopping criterion: 0.001
~     Iteration: 164    ~ bandwidth: 0.836  ~ gradient: 0.001   ~ stopping criterion: 0.001
~     Iteration: 165    ~ bandwidth: 0.834  ~ gradient: 0.001   ~ stopping criterion: 0.001
~     Iteration: 166    ~ bandwidth: 0.833  ~ gradient: 0.001   ~ stopping criterion: 0.001
~     Iteration: 167    ~ bandwidth: 0.832  ~ gradient: 0.001   ~ stopping criterion: 0.001
~     Iteration: 168    ~ bandwidth: 0.831  ~ gradient: 0.001   ~ stopping criterion: 0.001
~     Iteration: 169    ~ bandwidth: 0.830  ~ gradient: 0.001   ~ stopping criterion: 0.001
~     Iteration: 170    ~ bandwidth: 0.828  ~ gradient: 0.001   ~ stopping criterion: 0.001
~     Iteration: 171    ~ bandwidth: 0.827  ~ gradient: 0.001   ~ stopping criterion: 0.001
~     Iteration: 172    ~ bandwidth: 0.826  ~ gradient: 0.001   ~ stopping criterion: 0.001
~     Iteration: 173    ~ bandwidth: 0.825  ~ gradient: 0.001   ~ stopping criterion: 0.001
~     Iteration: 174    ~ bandwidth: 0.824  ~ gradient: 0.001   ~ stopping criterion: 0.001
~     Iteration: 175    ~ bandwidth: 0.824  ~ gradient: 0.001   ~ stopping criterion: 0.000
~     Iteration: 176    ~ bandwidth: 0.823  ~ gradient: 0.001   ~ stopping criterion: 0.000
~     Iteration: 177    ~ bandwidth: 0.822  ~ gradient: 0.001   ~ stopping criterion: 0.000
~     Iteration: 178    ~ bandwidth: 0.821  ~ gradient: 0.001   ~ stopping criterion: 0.000
~     Iteration: 179    ~ bandwidth: 0.820  ~ gradient: 0.001   ~ stopping criterion: 0.001
~     Iteration: 180    ~ bandwidth: 0.820  ~ gradient: 0.001   ~ stopping criterion: 0.000
~     Iteration: 181    ~ bandwidth: 0.819  ~ gradient: 0.001   ~ stopping criterion: 0.000
~     Iteration: 182    ~ bandwidth: 0.819  ~ gradient: 0.001   ~ stopping criterion: 0.000
~     Iteration: 183    ~ bandwidth: 0.818  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 184    ~ bandwidth: 0.818  ~ gradient: 0.001   ~ stopping criterion: 0.000
~     Iteration: 185    ~ bandwidth: 0.817  ~ gradient: 0.001   ~ stopping criterion: 0.000
~     Iteration: 186    ~ bandwidth: 0.816  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 187    ~ bandwidth: 0.816  ~ gradient: 0.001   ~ stopping criterion: 0.000
~     Iteration: 188    ~ bandwidth: 0.816  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 189    ~ bandwidth: 0.815  ~ gradient: 0.001   ~ stopping criterion: 0.000
~     Iteration: 190    ~ bandwidth: 0.815  ~ gradient: 0.001   ~ stopping criterion: 0.000
~     Iteration: 191    ~ bandwidth: 0.814  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 192    ~ bandwidth: 0.814  ~ gradient: 0.001   ~ stopping criterion: 0.000
~     Iteration: 193    ~ bandwidth: 0.813  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 194    ~ bandwidth: 0.813  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 195    ~ bandwidth: 0.813  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 196    ~ bandwidth: 0.813  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 197    ~ bandwidth: 0.813  ~ gradient: 0.001   ~ stopping criterion: 0.000
~     Iteration: 198    ~ bandwidth: 0.812  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 199    ~ bandwidth: 0.812  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 200    ~ bandwidth: 0.812  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 201    ~ bandwidth: 0.811  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 202    ~ bandwidth: 0.811  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 203    ~ bandwidth: 0.811  ~ gradient: -0.000  ~ stopping criterion: 0.000
~     Iteration: 204    ~ bandwidth: 0.811  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 205    ~ bandwidth: 0.811  ~ gradient: -0.000  ~ stopping criterion: 0.000
~     Iteration: 206    ~ bandwidth: 0.811  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 207    ~ bandwidth: 0.811  ~ gradient: 0.001   ~ stopping criterion: 0.000
~     Iteration: 208    ~ bandwidth: 0.811  ~ gradient: -0.000  ~ stopping criterion: 0.000
~     Iteration: 209    ~ bandwidth: 0.811  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 210    ~ bandwidth: 0.811  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 211    ~ bandwidth: 0.810  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 212    ~ bandwidth: 0.810  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 213    ~ bandwidth: 0.810  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 214    ~ bandwidth: 0.810  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 215    ~ bandwidth: 0.810  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 216    ~ bandwidth: 0.810  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 217    ~ bandwidth: 0.810  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 218    ~ bandwidth: 0.810  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 219    ~ bandwidth: 0.810  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 220    ~ bandwidth: 0.810  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 221    ~ bandwidth: 0.810  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 222    ~ bandwidth: 0.810  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 223    ~ bandwidth: 0.810  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 224    ~ bandwidth: 0.810  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 225    ~ bandwidth: 0.810  ~ gradient: -0.000  ~ stopping criterion: 0.000
~     Iteration: 226    ~ bandwidth: 0.810  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 227    ~ bandwidth: 0.810  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 228    ~ bandwidth: 0.810  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 229    ~ bandwidth: 0.810  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 230    ~ bandwidth: 0.810  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 231    ~ bandwidth: 0.810  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 232    ~ bandwidth: 0.810  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 233    ~ bandwidth: 0.810  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 234    ~ bandwidth: 0.809  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 235    ~ bandwidth: 0.809  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 236    ~ bandwidth: 0.809  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 237    ~ bandwidth: 0.809  ~ gradient: -0.000  ~ stopping criterion: 0.000
~     Iteration: 238    ~ bandwidth: 0.809  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 239    ~ bandwidth: 0.809  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 240    ~ bandwidth: 0.809  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 241    ~ bandwidth: 0.809  ~ gradient: 0.000   ~ stopping criterion: 0.000
~     Iteration: 242    ~ bandwidth: 0.809  ~ gradient: -0.000  ~ stopping criterion: 0.000
~     Iteration: 243    ~ bandwidth: 0.809  ~ gradient: -0.000  ~ stopping criterion: 0.000
~    Stopped at: 243    ~ bandwidth: 0.809  ~ gradient: -0.000  ~ stopping criterion: 0.000

Next, we check how the optimization performed: we print the estimated bandwidth and draw the learning curve of the gradient descent.

print("Estimated bandwidth :" + str(gc3_fit.optimization_outputs['opt_bandwidth']))
gc3_fit.draw_learning_curve()
Estimated bandwidth :0.8092867317195862
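The log above tracks a single bandwidth being tuned by gradient descent until the stopping criterion becomes negligible. To make that idea concrete, here is a minimal sketch under loudly stated assumptions: the loss is a leave-one-out squared error of a Gaussian (radial) kernel average over base-learner predictions, the gradient is a central finite difference, and the loop stops once the step `|lr * grad|` drops below `threshold`. The names `loo_kernel_loss` and `gradient_descent_1d`, and every default value, are hypothetical; this is not gradientcobra's internal implementation, whose loss, gradient and stopping rule may differ.

```python
import numpy as np

# Hypothetical sketch of bandwidth tuning by gradient descent.
# NOT gradientcobra's implementation; names and defaults are illustrative.


def loo_kernel_loss(h, preds, y):
    """Leave-one-out MSE of a radial-kernel aggregation with bandwidth h.

    preds: (n, M) array; column m holds the predictions of base learner m.
    y:     (n,) array of targets.
    """
    # Squared Euclidean distance between every pair of prediction vectors.
    diff = preds[:, None, :] - preds[None, :, :]
    d2 = np.sum(diff ** 2, axis=-1)
    # Gaussian (radial) kernel weights; zero the diagonal so each point
    # is predicted only from the other points (leave-one-out).
    w = np.exp(-d2 / h ** 2)
    np.fill_diagonal(w, 0.0)
    # Kernel-weighted average of the other targets (guard against 0 / 0).
    denom = np.maximum(w.sum(axis=1), 1e-12)
    y_hat = (w @ y) / denom
    return float(np.mean((y - y_hat) ** 2))


def gradient_descent_1d(f, h0, lr=0.1, threshold=1e-5, max_iter=300,
                        eps=1e-4, h_min=1e-3):
    """Minimize f over a positive scalar h by plain gradient descent.

    The gradient is a central finite difference; iteration stops once the
    step size |lr * grad| falls below `threshold` or after max_iter steps.
    """
    h = float(h0)
    n_iter = 0
    for n_iter in range(1, max_iter + 1):
        grad = (f(h + eps) - f(h - eps)) / (2 * eps)
        step = lr * grad
        h = max(h - step, h_min)  # keep the bandwidth strictly positive
        if abs(step) < threshold:
            break
    return h, n_iter


# Usage on synthetic data: three noisy "base learners" around the truth.
rng = np.random.default_rng(0)
y = rng.normal(size=200)
preds = y[:, None] + rng.normal(scale=0.5, size=(200, 3))
h_opt, n_iter = gradient_descent_1d(lambda h: loo_kernel_loss(h, preds, y), h0=2.5)
print(f"bandwidth: {h_opt:.3f} after {n_iter} iterations")
```

Starting from `h0=2.5` mirrors the initial bandwidth shown in the first log of this page, but the bandwidth found here depends entirely on the synthetic data and the assumed loss, so it should not be compared with the 0.809 estimated above.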

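Now, let's evaluate the predictions on the test data. The first printed value is the mean absolute percentage error (MAPE) and the second is the mean squared error (MSE).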

pred_add = gc3_fit.predict(pred_feature_test)
print(mean_absolute_percentage_error(y_test_real, pred_add))
print(mean_squared_error(y_test_real, pred_add))
0.156300234677995
0.16485466823568465
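To read these two numbers, it helps to spell out what the metrics compute. The sketch below reimplements them with NumPy on a tiny hand-checkable example; the helper names `mape` and `mse` are our own, not part of gradientcobra. Like scikit-learn's `mean_absolute_percentage_error`, the MAPE is returned as a fraction rather than a percentage, and its denominator is floored at machine epsilon to avoid division by zero.

```python
import numpy as np


def mape(y_true, y_pred):
    # Mean absolute percentage error returned as a fraction (0.25 means 25%).
    # As in scikit-learn, |y_true| is floored at machine epsilon.
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    denom = np.maximum(np.abs(y_true), np.finfo(np.float64).eps)
    return float(np.mean(np.abs(y_pred - y_true) / denom))


def mse(y_true, y_pred):
    # Mean squared error: average of the squared residuals.
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.mean((y_true - y_pred) ** 2))


y_true = [1.0, 2.0, 4.0]
y_pred = [1.5, 2.0, 3.0]
print(mape(y_true, y_pred))  # 0.25  = mean(0.5/1, 0/2, 1/4)
print(mse(y_true, y_pred))   # 1.25 / 3, about 0.4167
```

Read this way, the MAPE of 0.1563 printed above corresponds to an average relative error of about 15.6% on the test set.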