Abstract
This deliverable investigated and predicted superconducting critical temperatures for a variety of materials. Linear regression with regularization was the primary statistical method used. A comparison of the models and their performance resulted in a recommendation of a combined model that used L1 regularization for feature selection followed by L2 regularization to fit the model. Feature importance, highlighting the features that contribute most to the superconducting critical temperature, was then discussed.
This case study explores linear predictive models and feature importance in the superconductor domain. Superconductors are materials that provide minimal or no resistance to electrical current. These materials are used in a variety of applications, such as delivering high-speed connections between computer microchips. The goal of this study is to predict the temperature at which a material becomes a superconductor as well as to identify the important features that contribute to this critical temperature.
This dataset contains 21,263 data points describing material properties and their composition as independent variables. The dependent variable the model predicted was the critical superconducting temperature.
Linear regression with two types of regularization, L1 and L2, was used to build models for this case study. Linear regression is a widely used, efficient, and highly interpretable algorithm used by many scientists to model the relationship between dependent and independent variables. Typically, linear regression uses mean squared error (MSE) as its loss function; the loss function quantifies the prediction error that the fitting procedure iteratively attempts to minimize. This case study used negative mean squared error in order to maintain consistency with the Scikit-learn application programming interface (API), whose scorers follow a "greater is better" convention.
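As a brief illustration with hypothetical values (not from this study), negative MSE is simply the ordinary MSE with its sign flipped so that higher scores indicate better fits:
# Minimal sketch (hypothetical values): why Scikit-learn reports negative MSE
from sklearn.metrics import mean_squared_error
import numpy as np
y_true = np.array([1.0, 2.0, 3.0])   # hypothetical observed values
y_pred = np.array([1.1, 1.9, 3.2])   # hypothetical predictions
mse = mean_squared_error(y_true, y_pred)   # ordinary MSE: lower is better
neg_mse = -mse                             # the form reported when scoring='neg_mean_squared_error'
print(mse, neg_mse)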
The following is the formula used by a typical multiple regression (1), where $ m_0 $ represents the intercept and $ m_1, \dots, m_n $ represent the slopes:
$y = m_0 + m_1 \cdot x_1 + m_2 \cdot x_2 + m_3 \cdot x_3 ... + m_n \cdot x_n \qquad (1)$
The dataset contains 158 independent variables, meaning that without any feature selection there would be 158 slopes. In simple terms, regularization adds a penalty to the regression coefficients (slopes). The purpose of this penalty is to prevent overfitting, which happens when the model fits the training data too closely rather than generalizing to the underlying problem.
The first regularization type is Lasso, or L1. The penalty is the sum of the absolute values of the coefficients multiplied by lambda, which controls the strength of the penalty. L1 regularization can also be used as a feature selection tool: the penalty can be large enough to shrink a regression coefficient to exactly zero. A coefficient of zero indicates that the feature was not important for the relationship between the independent and dependent variables.
Penalty Term
$\lambda \sum\limits_{j=0}^k \mid m_j \mid $
Where $\lambda $ is the strength of the penalty. If $\lambda=0 $, no penalty is applied and the ordinary (unregularized) coefficients are returned.
Complete Formula
The penalty is added to the loss function being minimized rather than to the prediction from formula (1):
$\text{Loss} = \text{MSE} + \lambda \sum\limits_{j=0}^k \mid m_j \mid$
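The sketch below, on synthetic data unrelated to this study, illustrates the feature-selection behavior: with an L1 penalty, the coefficients of uninformative columns are typically driven to exactly zero.
# Minimal sketch on synthetic data (not the superconductor dataset): L1 can zero out coefficients
import numpy as np
from sklearn.linear_model import Lasso
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(200, 5))
y_demo = 3 * X_demo[:, 0] + 0.5 * X_demo[:, 1] + rng.normal(scale=0.1, size=200)  # only two informative columns
lasso_demo = Lasso(alpha=0.1).fit(X_demo, y_demo)
print(lasso_demo.coef_)  # coefficients of the three uninformative columns are typically exactly 0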
The second regularization type is Ridge, or L2. The penalty is the sum of the squared coefficients multiplied by lambda, which controls the strength of the penalty. Unlike L1 regularization, L2 does not provide feature selection: every coefficient is shrunk toward zero, but none is set exactly to zero. In general, L2 is the primary regularization method used to prevent overfitting the model.
Penalty Term
$\lambda \sum\limits_{j=0}^k m_j ^2 $
Where $\lambda $ is the strength of the penalty. If $\lambda=0 $, no penalty is applied and the ordinary (unregularized) coefficients are returned.
Complete Formula
As with L1, the penalty is added to the loss function being minimized rather than to the prediction from formula (1):
$\text{Loss} = \text{MSE} + \lambda \sum\limits_{j=0}^k m_j ^2 $
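A matching sketch on the same kind of synthetic data shows the contrast with L1: the L2 penalty shrinks every coefficient but leaves all of them nonzero.
# Minimal sketch on synthetic data (not the superconductor dataset): L2 shrinks but does not zero out
import numpy as np
from sklearn.linear_model import Ridge
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(200, 5))
y_demo = 3 * X_demo[:, 0] + 0.5 * X_demo[:, 1] + rng.normal(scale=0.1, size=200)
ridge_demo = Ridge(alpha=10.0).fit(X_demo, y_demo)
print(ridge_demo.coef_)  # all five coefficients remain nonzero, just smaller in magnitude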
There was no missing data in this study.
Within the dataset, 86 features represented elements. Each recorded how much of a particular element was present in the material composition of each superconductor. Nine of these element features contained only zeros (Table 1), indicating the elements were in none of the superconductors in the dataset. These elements ['He', 'Ne', 'Ar', 'Kr', 'Xe', 'Pm', 'Po', 'At', 'Rn'] were deleted from the dataset prior to analysis.
file_2.loc[:, (file_2.sum(axis=0) == 0)].describe()
| | He | Ne | Ar | Kr | Xe | Pm | Po | At | Rn |
|---|---|---|---|---|---|---|---|---|---|
| count | 21263.0 | 21263.0 | 21263.0 | 21263.0 | 21263.0 | 21263.0 | 21263.0 | 21263.0 | 21263.0 |
| mean | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| std | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| min | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 25% | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 50% | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 75% | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| max | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
The dataset included two files. The first described the physical properties of each material. The second represented the material composition of each superconductor, with each column indicating the amount of a particular element in each superconductor. The rows and indexes of both files referred to the same superconductor material, so the two files were merged by index. Next, the redundant "critical_temp" variable was dropped as it was a duplication between the files. Finally, the "material" column from the second file was also dropped because all the data within the material name was also represented in the material composition features.
High correlations can be problematic when interpreting feature importance, which was one of the main goals of this study. When two variables are very highly correlated, it is likely that one is derived from the other. An algorithm was employed to detect each pair of variables with a correlation greater than 0.95 and delete one variable of the pair. This approach is borrowed from the resource cited below the correlation plot in the code appendix. Without detailed knowledge of each variable, the decision of which variable to delete was determined by the order of the variable columns. As a result, 23 of the highly correlated variables were deleted from the dataset.
Variables Deleted: ['wtd_gmean_atomic_mass', 'std_atomic_mass', 'gmean_fie', 'wtd_gmean_fie', 'entropy_fie', 'std_fie', 'wtd_gmean_atomic_radius', 'entropy_atomic_radius', 'wtd_entropy_atomic_radius', 'std_atomic_radius', 'wtd_std_atomic_radius', 'wtd_gmean_Density', 'std_Density', 'std_ElectronAffinity', 'wtd_gmean_FusionHeat', 'std_FusionHeat', 'std_ThermalConductivity', 'wtd_std_ThermalConductivity', 'gmean_Valence', 'wtd_gmean_Valence', 'entropy_Valence', 'wtd_entropy_Valence', 'std_Valence']
An exploration of the distributions of the remaining 136 variables highlighted a lack of normality for many variables. For example, Fig. 1 shows the bimodal distribution of "range_atomic_mass", the left-skewed distribution of "range_atomic_radius", and the right-skewed distribution of "gmean_Density." This implies that a zero-mean, unit-variance transformation could produce many extreme values and make feature importance more difficult to interpret. Instead, scaling all variables to the range [0, 1] using MinMaxScaler from Scikit-learn was chosen to make feature importance comparison more manageable.
fig, ax = plt.subplots(1,3,figsize=(12,3))
sns.histplot(x = subset_plot['range_atomic_mass'],bins=100, ax=ax[0]).set_title("Distribution for variable range_atomic_mass")
sns.histplot(x = subset_plot['range_atomic_radius'],bins=100, ax=ax[1]).set_title("Distribution for variable range_atomic_radius")
sns.histplot(x = subset_plot['gmean_Density'],bins=100, ax=ax[2]).set_title("Distribution for variable gmean_Density")
plt.tight_layout()
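As a small illustration with hypothetical values, min-max scaling maps each column to [0, 1] via (x - min) / (max - min), which is what makes coefficient magnitudes comparable across features:
# Minimal sketch (hypothetical column) of the min-max transform applied inside the pipelines below
import numpy as np
from sklearn.preprocessing import MinMaxScaler
demo_col = np.array([[10.0], [55.0], [100.0]])
print(MinMaxScaler().fit_transform(demo_col).ravel())  # [0.  0.5  1. ]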
Each model was created using a pipeline. A pipeline allowed the scaling step (MinMaxScaler()) and the fitting of the linear models with either L1 or L2 regularization to be streamlined. Pipelines also prevent data leakage when grid search with 10-fold cross validation is used to narrow down the best regularization parameter (alpha) for either L1 or L2. During the initial trials, there was evidence that the data was structured or ordered. To address this, the data was shuffled before 10-fold cross validation to introduce more randomness into each cross validation split. This shuffle reduced the differences between test scores across the cross validation splits. A random state was also set for reproducibility.
A pipeline was created for LASSO under the variable 'pipe_lasso' to tune the alpha hyperparameter and assess predictions. An initial, wider grid of alpha values [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 1, 1.5, 2, 2.5, 5] was first tried using 10-fold cross validation. The best alpha value (the highest mean test score, in negative mean squared error) during the first trial was alpha = 0.1 (Fig. 2).
To further fine-tune alpha, this process was repeated with more refined grids. Grids of alpha values [0.075, 0.085, 0.095, 0.1, 0.105, 0.11, 0.115] and [0.085, 0.086, 0.087, 0.088, 0.089, 0.09] were tried. The final alpha value was 0.089, with a mean test score (negative MSE) of -379.84.
As previously mentioned, LASSO can be used as a feature selection method. To this end, the features with non-zero coefficients under the best performing alpha value of 0.089 were recorded for creating the third model introduced below.
plot_learning_curve(lasso_results,"Lasso")
A pipeline was created for Ridge (L2) under the variable 'pipe_ridge' to tune the alpha hyperparameter and assess predictions. An initial, wider grid of alpha values [100, 10, 1, 0.1, 0.01, 0.001] was first tried using 10-fold cross validation. The best alpha value (the highest mean test score, in negative mean squared error) during the first trial was alpha = 100 (Fig. 3).
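The cell defining 'pipe_ridge' is not reproduced in this section's code appendix, so the following is only a hedged sketch that mirrors the LASSO pipeline pattern shown there, using the initial grid described above:
# Hedged sketch of the 'pipe_ridge' setup, mirroring the LASSO pipeline in the appendix
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV, KFold
cv = KFold(n_splits=10, shuffle=True, random_state=101)
pipe_ridge = Pipeline(steps=[('scaling', MinMaxScaler()), ('ridge', Ridge())])
grid_ridge = {"ridge__alpha": [100, 10, 1, 0.1, 0.01, 0.001]}
search_ridge = GridSearchCV(pipe_ridge, grid_ridge, cv=cv, n_jobs=-2, scoring='neg_mean_squared_error', return_train_score=True)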
To further fine-tune alpha, this process was repeated with more refined grids. Subsequent grids of alpha values [75, 100, 125, 150, 175, 200, 225, 250], [50, 60, 70, 80, 90], and [78, 79, 80, 81, 82, 83, 84] were tried, and the best performing alpha was 82, with a mean test score (negative MSE) of -394.62.
plot_learning_curve(ridge_results,"Ridge")
The third model used a combination of LASSO and Ridge regression. LASSO provided feature selection, and the resulting subset of features was then fit using Ridge regression. Using the LASSO model with the best alpha (0.089), feature selection was conducted by identifying the features with non-zero coefficients. Using these 24 features, a sub dataset was created, as sketched below.
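A hedged sketch of this handoff (the sub dataset name is hypothetical; 'feature_importance_lasso' and 'X' are defined in the code appendix):
# Keep only the columns whose LASSO coefficient is non-zero (the 24 selected features)
selected_features = feature_importance_lasso.loc[feature_importance_lasso['coefficient'] != 0, 'feature name'].tolist()
X_subset = X[selected_features]  # hypothetical name for the sub dataset used by Model 3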
A pipeline was created for Ridge with feature selection named 'pipe_ridge_mm'. An initial, wider grid of alpha values [0.01, 0.1, 0.5, 1, 1.25, 1.5, 5, 25, 50, 75] was first tried using 10-fold cross validation. The best alpha value (the highest mean test score, in negative mean squared error) during the first trial was alpha = 0.1 (Fig. 4). This alpha value was drastically different from the alpha value obtained for the second model using Ridge regression; it appears the required penalty is much smaller once only the important features remain.
To further fine-tune alpha, this process was repeated with a more refined grid. A grid of alpha values [0.05, 0.07, 0.09, 0.1, 0.11, 0.12] was then tried. The best alpha value was alpha = 0.11, with a mean test score (negative MSE) of -335.33.
Since this model generated the best score during 10-fold cross validation, the coefficients of the variables were then extracted to be analyzed for feature importance.
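A hedged sketch of this extraction (the grid-search variable name is hypothetical), mirroring the LASSO coefficient extraction shown in the code appendix and producing the table used for the feature importance plot:
# Extract Model 3 coefficients from the best Ridge pipeline for the feature importance plot
# ('model_ridge_subset_mm' and 'X_subset' are the hypothetical names used in the sketches above)
feature_importance_laso_ridg = pd.DataFrame(zip(X_subset.columns, model_ridge_subset_mm.best_estimator_._final_estimator.coef_), columns=('feature name','coefficient'))
feature_importance_laso_ridg_sorted = feature_importance_laso_ridg.sort_values(by='coefficient', ascending=False)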
plot_learning_curve(ridge_subset_results_mm,"Ridge")
A comparison of the three models (Table 2) shows the best performing model was Model 3, which used LASSO for feature selection followed by Ridge to fit the model. Model performance was assessed using negative mean squared error. Best mean test score was calculated using the mean test score (negative mean squared error) from a 10-fold cross validation. Model 3 had the best mean test score of -335.33. The LASSO model was next with a mean test score of -379.84, and Ridge (without feature selection) performed the worst, with a mean test score of -394.62. Model 3 also had the lowest test score standard deviation, indicating the performance among the cross validation splits was consistent.
| Model | Best Alpha | Best Mean Test Score | Test Score Standard Deviation |
|---|---|---|---|
| Model 3: Ridge with Feature Selection | 0.11 | -335.33 | 15.73 |
| Model 1: LASSO | 0.089 | -379.84 | 55.15 |
| Model 2: Ridge | 82 | -394.62 | 55.15 |
The importance of the features contributing to the critical superconducting temperature is displayed in Fig. 5. The most important variable is "Ba", or the element barium. This element has a positive relationship with the outcome, meaning that as Ba increases, so too does the critical temperature. The second most important variable is the weighted mean thermal conductivity, which also has a positive relationship with the outcome. Next is the weighted geometric mean thermal conductivity, which has an inverse relationship with the outcome: as the weighted geometric mean thermal conductivity increases, the critical superconducting temperature decreases. The importance of the remaining features and the direction of their relationship to the critical superconducting temperature can be examined below.
sns.catplot(x="coefficient",
y="feature name",
kind="bar",
palette = "light:b_r",
data=feature_importance_laso_ridg_sorted).set(title="Feature Importance Plot")
For comparison, the five most important features identified by each model were:
| Model 3 | LASSO | Ridge |
|---|---|---|
| Ba | Ba | Ba |
| wtd_mean_ThermalConductivity | wtd_mean_ThermalConductivity | wtd_std_Valence |
| wtd_gmean_ThermalConductivity | wtd_gmean_ThermalConductivity | wtd_entropy_atomic_mass |
| Bi | Bi | wtd_mean_ThermalConductivity |
| wtd_std_atomic_mass | wtd_gmean_ElectronAffinity | range_atomic_mass |
Model 3 is the best model for predicting new superconductors and the temperature at which they operate. It had the most accurate estimates of critical temperature of all the models tested, and it also showed the lowest standard deviation in cross validation test scores, indicating the most consistent performance across folds and the best expected generalization to new data. Model 3 also requires fewer attributes, and therefore less data, to make predictions. It is interpretable: because all features were scaled, direct comparisons can be made among the magnitudes of the coefficients. The larger the magnitude of a coefficient, the more important the variable. If a coefficient is positive, its relationship with the outcome is positive; if it is negative, its relationship with the outcome is inverse. The most important features were barium and weighted mean thermal conductivity.
This model does have a limitation, however. Because none of the samples contained nine of the elements, these elements were excluded from the model, so it would not apply to any new superconductor containing those nine elements. Future improvements to this model could include consultation with a subject matter expert to further optimize the removal of highly correlated variables.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import Lasso
from sklearn.model_selection import GridSearchCV, KFold
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import MinMaxScaler
from sklearn.pipeline import Pipeline
from sklearn.linear_model import Ridge
from sklearn.feature_selection import RFE
from sklearn.tree import DecisionTreeRegressor
# options
pd.set_option('display.max_columns', None)
# Using fixed os structure
import os
cwd = os.getcwd()
d = os.path.dirname(cwd)
# d
path_file1 = f"{d}/casestudy1/superconduct/train.csv"
path_file2 = f"{d}/casestudy1/superconduct/unique_m.csv"
file_1 = pd.read_csv(path_file1)
file_2 = pd.read_csv(path_file2)
# Here are the elements from unique_m that have all 0 in their column:
col_to_delete = file_2.loc[:, (file_2.sum(axis=0) == 0)].columns
# delete all zero columns from file_2
file_2.drop(col_to_delete, axis=1, inplace=True)
data = file_1.merge(file_2,left_index=True, right_index=True)
data.drop(["critical_temp_x", "material"], axis=1, inplace=True)
data = data.rename(columns={"critical_temp_y": "critical_temp"})
len(data.columns)
159
data.describe()
| number_of_elements | mean_atomic_mass | wtd_mean_atomic_mass | gmean_atomic_mass | wtd_gmean_atomic_mass | entropy_atomic_mass | wtd_entropy_atomic_mass | range_atomic_mass | wtd_range_atomic_mass | std_atomic_mass | wtd_std_atomic_mass | mean_fie | wtd_mean_fie | gmean_fie | wtd_gmean_fie | entropy_fie | wtd_entropy_fie | range_fie | wtd_range_fie | std_fie | wtd_std_fie | mean_atomic_radius | wtd_mean_atomic_radius | gmean_atomic_radius | wtd_gmean_atomic_radius | entropy_atomic_radius | wtd_entropy_atomic_radius | range_atomic_radius | wtd_range_atomic_radius | std_atomic_radius | wtd_std_atomic_radius | mean_Density | wtd_mean_Density | gmean_Density | wtd_gmean_Density | entropy_Density | wtd_entropy_Density | range_Density | wtd_range_Density | std_Density | wtd_std_Density | mean_ElectronAffinity | wtd_mean_ElectronAffinity | gmean_ElectronAffinity | wtd_gmean_ElectronAffinity | entropy_ElectronAffinity | wtd_entropy_ElectronAffinity | range_ElectronAffinity | wtd_range_ElectronAffinity | std_ElectronAffinity | wtd_std_ElectronAffinity | mean_FusionHeat | wtd_mean_FusionHeat | gmean_FusionHeat | wtd_gmean_FusionHeat | entropy_FusionHeat | wtd_entropy_FusionHeat | range_FusionHeat | wtd_range_FusionHeat | std_FusionHeat | wtd_std_FusionHeat | mean_ThermalConductivity | wtd_mean_ThermalConductivity | gmean_ThermalConductivity | wtd_gmean_ThermalConductivity | entropy_ThermalConductivity | wtd_entropy_ThermalConductivity | range_ThermalConductivity | wtd_range_ThermalConductivity | std_ThermalConductivity | wtd_std_ThermalConductivity | mean_Valence | wtd_mean_Valence | gmean_Valence | wtd_gmean_Valence | entropy_Valence | wtd_entropy_Valence | range_Valence | wtd_range_Valence | std_Valence | wtd_std_Valence | H | Li | Be | B | C | N | O | F | Na | Mg | Al | Si | P | S | Cl | K | Ca | Sc | Ti | V | Cr | Mn | Fe | Co | Ni | Cu | Zn | Ga | Ge | As | Se | Br | Rb | Sr | Y | Zr | Nb | Mo | Tc | Ru | Rh | Pd | Ag | Cd | In | Sn | Sb | Te | I | Cs | Ba | La | Ce | Pr | Nd | Sm | Eu | Gd | Tb | Dy | Ho | Er | Tm | Yb | Lu | Hf | Ta | W | Re | Os | Ir | Pt | Au | Hg | Tl | Pb | Bi | critical_temp | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 |
| mean | 4.115224 | 87.557631 | 72.988310 | 71.290627 | 58.539916 | 1.165608 | 1.063884 | 115.601251 | 33.225218 | 44.391893 | 41.448045 | 769.614748 | 870.442317 | 737.474751 | 832.769649 | 1.299172 | 0.926726 | 572.222612 | 483.517264 | 215.631279 | 224.050033 | 157.983101 | 134.720039 | 144.448738 | 120.989330 | 1.267756 | 1.131138 | 139.325025 | 51.369885 | 51.601267 | 52.340465 | 6111.465214 | 5267.188547 | 3460.692235 | 3117.241110 | 1.072425 | 0.856037 | 8665.438818 | 2902.736814 | 3416.910784 | 3319.170628 | 76.879751 | 92.717486 | 54.359502 | 72.416225 | 1.070250 | 0.770757 | 120.730514 | 59.332637 | 48.912207 | 44.409356 | 14.296113 | 13.848001 | 10.136977 | 10.141161 | 1.093343 | 0.914065 | 21.138994 | 8.218528 | 8.323333 | 7.717576 | 89.706911 | 81.549080 | 29.841727 | 27.308061 | 0.727630 | 0.539991 | 250.893443 | 62.033066 | 98.943993 | 96.234051 | 3.198228 | 3.153127 | 3.056536 | 3.055885 | 1.295682 | 1.052841 | 2.041010 | 1.483007 | 0.839342 | 0.673987 | 0.017685 | 0.012125 | 0.034638 | 0.142594 | 0.384968 | 0.013284 | 3.009129 | 0.014874 | 0.008892 | 0.026772 | 0.061678 | 0.189889 | 0.028143 | 0.106246 | 0.009050 | 0.016042 | 0.258347 | 0.010919 | 0.156817 | 0.224782 | 0.006119 | 0.003191 | 0.153182 | 0.035323 | 0.090182 | 1.276751 | 0.014034 | 0.073997 | 0.082556 | 0.155197 | 0.078662 | 0.003940 | 0.007799 | 0.326909 | 0.177556 | 0.370901 | 0.442349 | 0.146367 | 0.002291 | 0.055325 | 0.068072 | 0.085034 | 0.007834 | 0.009152 | 0.049468 | 0.120994 | 0.101269 | 0.040491 | 0.004744 | 0.004129 | 0.568440 | 0.264953 | 0.030662 | 0.041494 | 0.039666 | 0.021992 | 0.017821 | 0.023959 | 0.002857 | 0.009536 | 0.008832 | 0.014217 | 0.008909 | 0.012716 | 0.026849 | 0.009168 | 0.036086 | 0.010424 | 0.038206 | 0.022512 | 0.061558 | 0.034108 | 0.020535 | 0.036663 | 0.047954 | 0.042461 | 0.201009 | 34.421219 |
| std | 1.439295 | 29.676497 | 33.490406 | 31.030272 | 36.651067 | 0.364930 | 0.401423 | 54.626887 | 26.967752 | 20.035430 | 19.983544 | 87.488694 | 143.278200 | 78.327275 | 119.772520 | 0.381935 | 0.334018 | 309.614442 | 224.042874 | 109.966774 | 127.927104 | 20.147288 | 28.801567 | 22.090958 | 35.837843 | 0.375411 | 0.407159 | 67.272228 | 35.019356 | 22.898396 | 25.294524 | 2846.785185 | 3221.314506 | 3703.256370 | 3975.122587 | 0.342356 | 0.319761 | 4097.126831 | 2398.471020 | 1673.624915 | 1611.799629 | 27.701890 | 32.276387 | 29.007425 | 31.648444 | 0.343391 | 0.285986 | 58.700327 | 28.620409 | 21.740521 | 20.429293 | 11.300188 | 14.279335 | 10.065901 | 13.134007 | 0.375932 | 0.370125 | 20.370620 | 11.414066 | 8.671651 | 7.288239 | 38.517485 | 45.519256 | 34.059581 | 40.191150 | 0.325976 | 0.318248 | 158.703557 | 43.123317 | 60.143272 | 63.710355 | 1.044611 | 1.191249 | 1.046257 | 1.174815 | 0.393155 | 0.380291 | 1.242345 | 0.978176 | 0.484676 | 0.455580 | 0.267220 | 0.129552 | 0.848541 | 1.044486 | 4.408032 | 0.150427 | 3.811649 | 0.132119 | 0.101685 | 0.271606 | 1.126254 | 2.217277 | 0.466710 | 0.760821 | 0.119717 | 0.138187 | 0.902732 | 0.185651 | 2.728139 | 3.407763 | 0.254272 | 0.129449 | 0.713075 | 0.580672 | 0.982521 | 2.079375 | 0.403316 | 1.115005 | 1.021279 | 1.076049 | 0.676294 | 0.083907 | 0.121254 | 0.763625 | 0.429953 | 4.846459 | 4.848246 | 2.084302 | 0.064728 | 0.770327 | 1.005898 | 1.554018 | 0.167831 | 0.688729 | 0.521820 | 1.886951 | 1.839020 | 0.718043 | 0.088480 | 0.077676 | 0.983288 | 2.320822 | 0.173147 | 1.282059 | 0.224657 | 0.183173 | 0.151433 | 0.155860 | 0.064737 | 0.104153 | 0.098728 | 0.131417 | 0.130455 | 0.214806 | 0.276861 | 0.208969 | 0.851380 | 0.164628 | 1.177476 | 0.282265 | 0.864859 | 0.307888 | 0.717975 | 0.205846 | 0.272298 | 0.274365 | 0.655927 | 34.254362 |
| min | 1.000000 | 6.941000 | 6.423452 | 5.320573 | 1.960849 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 375.500000 | 375.500000 | 375.500000 | 375.500000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 48.000000 | 48.000000 | 48.000000 | 48.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.429000 | 1.429000 | 1.429000 | 0.686245 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.500000 | 1.500000 | 1.500000 | 1.500000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.222000 | 0.222000 | 0.222000 | 0.222000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.026580 | 0.026580 | 0.026580 | 0.022952 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000210 |
| 25% | 3.000000 | 72.458076 | 52.143839 | 58.041225 | 35.248990 | 0.966676 | 0.775363 | 78.512902 | 16.824174 | 32.890369 | 28.539377 | 723.740000 | 738.946339 | 692.541331 | 720.108284 | 1.085871 | 0.753757 | 262.400000 | 291.088889 | 114.135763 | 92.994286 | 149.333333 | 112.127359 | 133.542493 | 89.210097 | 1.066389 | 0.852181 | 80.000000 | 28.598137 | 35.112518 | 32.016958 | 4513.500000 | 2999.158291 | 883.117278 | 66.746836 | 0.913959 | 0.688693 | 6648.000000 | 1656.847429 | 2819.497063 | 2564.342926 | 62.090000 | 73.350000 | 33.700512 | 50.772124 | 0.890589 | 0.660662 | 86.700000 | 34.036000 | 38.372410 | 33.440123 | 7.588667 | 5.033407 | 4.109978 | 1.322127 | 0.833333 | 0.672732 | 12.878000 | 2.329309 | 4.261340 | 4.603491 | 61.000000 | 54.180953 | 8.339818 | 1.087284 | 0.457810 | 0.250677 | 86.382000 | 29.349419 | 37.933172 | 31.985437 | 2.333333 | 2.116732 | 2.279705 | 2.091251 | 1.060857 | 0.775678 | 1.000000 | 0.921454 | 0.451754 | 0.306892 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 5.365000 |
| 50% | 4.000000 | 84.922750 | 60.696571 | 66.361592 | 39.918385 | 1.199541 | 1.146783 | 122.906070 | 26.636008 | 45.123500 | 44.285984 | 764.900000 | 889.966667 | 727.960610 | 856.202765 | 1.356236 | 0.916843 | 764.100000 | 510.440000 | 266.373871 | 258.449503 | 160.250000 | 125.970297 | 142.807563 | 113.181369 | 1.330735 | 1.242878 | 171.000000 | 43.000000 | 58.663106 | 59.932929 | 5329.085800 | 4303.421500 | 1339.974702 | 1515.364631 | 1.090610 | 0.882747 | 8958.571000 | 2082.956581 | 3301.890502 | 3625.631828 | 73.100000 | 102.856863 | 51.470113 | 73.173958 | 1.138284 | 0.781205 | 127.050000 | 71.156250 | 51.125720 | 48.029866 | 9.304400 | 8.330667 | 5.253498 | 4.929787 | 1.112098 | 0.994998 | 12.878000 | 3.436400 | 4.948155 | 5.500675 | 96.504430 | 73.333333 | 14.287643 | 6.096120 | 0.738694 | 0.545783 | 399.795000 | 56.556240 | 135.762089 | 113.556983 | 2.833333 | 2.618182 | 2.615321 | 2.434057 | 1.368922 | 1.166532 | 2.000000 | 1.063077 | 0.800000 | 0.500000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.900000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 20.000000 |
| 75% | 5.000000 | 100.404410 | 86.103540 | 78.116681 | 73.113234 | 1.444537 | 1.359418 | 154.119320 | 38.356908 | 59.322812 | 53.629284 | 796.320000 | 1004.117384 | 765.715174 | 937.575826 | 1.551120 | 1.061750 | 810.600000 | 690.703310 | 297.724924 | 342.656991 | 169.857143 | 158.265231 | 155.938199 | 150.988640 | 1.512348 | 1.425684 | 205.000000 | 60.224491 | 69.424491 | 73.777278 | 6728.000000 | 6416.333333 | 5794.965188 | 5766.015191 | 1.323930 | 1.080939 | 9778.571000 | 3409.026316 | 4004.273231 | 3959.191394 | 85.504167 | 110.738462 | 67.505900 | 89.975670 | 1.345894 | 0.877541 | 138.630000 | 76.706965 | 56.221787 | 53.320838 | 17.114444 | 18.514286 | 13.600037 | 16.428652 | 1.378110 | 1.157379 | 23.200000 | 10.498780 | 9.041230 | 8.017581 | 111.005316 | 99.062911 | 42.371302 | 47.308041 | 0.962218 | 0.777353 | 399.973420 | 91.869245 | 153.806272 | 162.711144 | 4.000000 | 4.026201 | 3.727919 | 3.914868 | 1.589027 | 1.330801 | 3.000000 | 1.918400 | 1.200000 | 1.020436 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 6.800000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 2.815000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.350000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 63.000000 |
| max | 9.000000 | 208.980400 | 208.980400 | 208.980400 | 208.980400 | 1.983797 | 1.958203 | 207.972460 | 205.589910 | 101.019700 | 101.019700 | 1313.100000 | 1348.028986 | 1313.100000 | 1327.593381 | 2.157777 | 2.038560 | 1304.500000 | 1251.855072 | 499.671949 | 479.162305 | 298.000000 | 298.000000 | 298.000000 | 298.000000 | 2.141961 | 1.903748 | 256.000000 | 240.164344 | 115.500000 | 97.140711 | 22590.000000 | 22590.000000 | 22590.000000 | 22590.000000 | 1.954297 | 1.703420 | 22588.571000 | 22434.160000 | 10724.374500 | 10410.932005 | 326.100000 | 326.100000 | 326.100000 | 326.100000 | 1.767732 | 1.675400 | 349.000000 | 218.696600 | 162.895331 | 169.075862 | 105.000000 | 105.000000 | 105.000000 | 105.000000 | 2.034410 | 1.747165 | 104.778000 | 102.675000 | 51.635000 | 51.680482 | 332.500000 | 406.960000 | 317.883627 | 376.032878 | 1.633977 | 1.612989 | 429.974170 | 401.440000 | 214.986150 | 213.300452 | 7.000000 | 7.000000 | 7.000000 | 7.000000 | 2.141963 | 1.949739 | 6.000000 | 6.992200 | 3.000000 | 3.000000 | 14.000000 | 3.000000 | 40.000000 | 105.000000 | 120.000000 | 12.800000 | 66.000000 | 4.000000 | 4.000000 | 12.000000 | 99.925000 | 100.000000 | 20.000000 | 15.000000 | 3.000000 | 3.300000 | 24.000000 | 5.000000 | 75.000000 | 79.500000 | 34.900000 | 14.000000 | 30.000000 | 35.380000 | 45.000000 | 98.000000 | 20.000000 | 41.000000 | 46.000000 | 18.000000 | 19.000000 | 5.000000 | 4.000000 | 16.700000 | 9.000000 | 96.710000 | 99.976000 | 99.992000 | 6.000000 | 64.000000 | 45.000000 | 50.997450 | 7.000000 | 99.995000 | 31.500000 | 99.200000 | 83.500000 | 66.700000 | 4.000000 | 3.000000 | 24.000000 | 98.000000 | 4.998000 | 185.000000 | 6.000000 | 12.000000 | 6.000000 | 4.000000 | 5.000000 | 5.000000 | 5.000000 | 5.000000 | 5.000000 | 16.000000 | 7.000000 | 25.000000 | 55.000000 | 14.000000 | 97.240000 | 10.000000 | 45.000000 | 5.800000 | 64.000000 | 8.000000 | 7.000000 | 19.000000 | 14.000000 | 185.000000 |
# Look at Correlation from File_1 train.csv
corr1 = file_1.corr().abs()
# See resource below graph
upper_tri = corr1.where(np.triu(np.ones(corr1.shape), k=1).astype(bool))
sns.set_theme(style="white")
plt.subplots(figsize=(18,18))
cmap = sns.diverging_palette(230, 20, as_cmap=True)
sns.heatmap(upper_tri>0.95, cmap=cmap, vmax=1, center=0.5,
square=True, linewidths=.5, cbar_kws={"shrink": .5})
Resource: The idea of creating the upper-triangle correlation plot, as well as dropping columns that are too highly correlated with existing columns, is from: https://www.dezyre.com/recipes/drop-out-highly-correlated-features-in-python
# Check for correlation greater than 0.95
# Columns with a correlation greater than 0.95 are too highly correlated with existing columns
to_drop = [column for column in upper_tri.columns if any(upper_tri[column] > 0.95)]
print(to_drop)
['wtd_gmean_atomic_mass', 'std_atomic_mass', 'gmean_fie', 'wtd_gmean_fie', 'entropy_fie', 'std_fie', 'wtd_gmean_atomic_radius', 'entropy_atomic_radius', 'wtd_entropy_atomic_radius', 'std_atomic_radius', 'wtd_std_atomic_radius', 'wtd_gmean_Density', 'std_Density', 'std_ElectronAffinity', 'wtd_gmean_FusionHeat', 'std_FusionHeat', 'std_ThermalConductivity', 'wtd_std_ThermalConductivity', 'gmean_Valence', 'wtd_gmean_Valence', 'entropy_Valence', 'wtd_entropy_Valence', 'std_Valence']
# Dropping variables that have correlation over 0.95 compared to other variables
data.drop(to_drop, axis=1, inplace=True)
# There are 136 columns left
len(data.columns)
136
data.describe()
| number_of_elements | mean_atomic_mass | wtd_mean_atomic_mass | gmean_atomic_mass | entropy_atomic_mass | wtd_entropy_atomic_mass | range_atomic_mass | wtd_range_atomic_mass | wtd_std_atomic_mass | mean_fie | wtd_mean_fie | wtd_entropy_fie | range_fie | wtd_range_fie | wtd_std_fie | mean_atomic_radius | wtd_mean_atomic_radius | gmean_atomic_radius | range_atomic_radius | wtd_range_atomic_radius | mean_Density | wtd_mean_Density | gmean_Density | entropy_Density | wtd_entropy_Density | range_Density | wtd_range_Density | wtd_std_Density | mean_ElectronAffinity | wtd_mean_ElectronAffinity | gmean_ElectronAffinity | wtd_gmean_ElectronAffinity | entropy_ElectronAffinity | wtd_entropy_ElectronAffinity | range_ElectronAffinity | wtd_range_ElectronAffinity | wtd_std_ElectronAffinity | mean_FusionHeat | wtd_mean_FusionHeat | gmean_FusionHeat | entropy_FusionHeat | wtd_entropy_FusionHeat | range_FusionHeat | wtd_range_FusionHeat | wtd_std_FusionHeat | mean_ThermalConductivity | wtd_mean_ThermalConductivity | gmean_ThermalConductivity | wtd_gmean_ThermalConductivity | entropy_ThermalConductivity | wtd_entropy_ThermalConductivity | range_ThermalConductivity | wtd_range_ThermalConductivity | mean_Valence | wtd_mean_Valence | range_Valence | wtd_range_Valence | wtd_std_Valence | H | Li | Be | B | C | N | O | F | Na | Mg | Al | Si | P | S | Cl | K | Ca | Sc | Ti | V | Cr | Mn | Fe | Co | Ni | Cu | Zn | Ga | Ge | As | Se | Br | Rb | Sr | Y | Zr | Nb | Mo | Tc | Ru | Rh | Pd | Ag | Cd | In | Sn | Sb | Te | I | Cs | Ba | La | Ce | Pr | Nd | Sm | Eu | Gd | Tb | Dy | Ho | Er | Tm | Yb | Lu | Hf | Ta | W | Re | Os | Ir | Pt | Au | Hg | Tl | Pb | Bi | critical_temp | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 | 21263.000000 |
| mean | 4.115224 | 87.557631 | 72.988310 | 71.290627 | 1.165608 | 1.063884 | 115.601251 | 33.225218 | 41.448045 | 769.614748 | 870.442317 | 0.926726 | 572.222612 | 483.517264 | 224.050033 | 157.983101 | 134.720039 | 144.448738 | 139.325025 | 51.369885 | 6111.465214 | 5267.188547 | 3460.692235 | 1.072425 | 0.856037 | 8665.438818 | 2902.736814 | 3319.170628 | 76.879751 | 92.717486 | 54.359502 | 72.416225 | 1.070250 | 0.770757 | 120.730514 | 59.332637 | 44.409356 | 14.296113 | 13.848001 | 10.136977 | 1.093343 | 0.914065 | 21.138994 | 8.218528 | 7.717576 | 89.706911 | 81.549080 | 29.841727 | 27.308061 | 0.727630 | 0.539991 | 250.893443 | 62.033066 | 3.198228 | 3.153127 | 2.041010 | 1.483007 | 0.673987 | 0.017685 | 0.012125 | 0.034638 | 0.142594 | 0.384968 | 0.013284 | 3.009129 | 0.014874 | 0.008892 | 0.026772 | 0.061678 | 0.189889 | 0.028143 | 0.106246 | 0.009050 | 0.016042 | 0.258347 | 0.010919 | 0.156817 | 0.224782 | 0.006119 | 0.003191 | 0.153182 | 0.035323 | 0.090182 | 1.276751 | 0.014034 | 0.073997 | 0.082556 | 0.155197 | 0.078662 | 0.003940 | 0.007799 | 0.326909 | 0.177556 | 0.370901 | 0.442349 | 0.146367 | 0.002291 | 0.055325 | 0.068072 | 0.085034 | 0.007834 | 0.009152 | 0.049468 | 0.120994 | 0.101269 | 0.040491 | 0.004744 | 0.004129 | 0.568440 | 0.264953 | 0.030662 | 0.041494 | 0.039666 | 0.021992 | 0.017821 | 0.023959 | 0.002857 | 0.009536 | 0.008832 | 0.014217 | 0.008909 | 0.012716 | 0.026849 | 0.009168 | 0.036086 | 0.010424 | 0.038206 | 0.022512 | 0.061558 | 0.034108 | 0.020535 | 0.036663 | 0.047954 | 0.042461 | 0.201009 | 34.421219 |
| std | 1.439295 | 29.676497 | 33.490406 | 31.030272 | 0.364930 | 0.401423 | 54.626887 | 26.967752 | 19.983544 | 87.488694 | 143.278200 | 0.334018 | 309.614442 | 224.042874 | 127.927104 | 20.147288 | 28.801567 | 22.090958 | 67.272228 | 35.019356 | 2846.785185 | 3221.314506 | 3703.256370 | 0.342356 | 0.319761 | 4097.126831 | 2398.471020 | 1611.799629 | 27.701890 | 32.276387 | 29.007425 | 31.648444 | 0.343391 | 0.285986 | 58.700327 | 28.620409 | 20.429293 | 11.300188 | 14.279335 | 10.065901 | 0.375932 | 0.370125 | 20.370620 | 11.414066 | 7.288239 | 38.517485 | 45.519256 | 34.059581 | 40.191150 | 0.325976 | 0.318248 | 158.703557 | 43.123317 | 1.044611 | 1.191249 | 1.242345 | 0.978176 | 0.455580 | 0.267220 | 0.129552 | 0.848541 | 1.044486 | 4.408032 | 0.150427 | 3.811649 | 0.132119 | 0.101685 | 0.271606 | 1.126254 | 2.217277 | 0.466710 | 0.760821 | 0.119717 | 0.138187 | 0.902732 | 0.185651 | 2.728139 | 3.407763 | 0.254272 | 0.129449 | 0.713075 | 0.580672 | 0.982521 | 2.079375 | 0.403316 | 1.115005 | 1.021279 | 1.076049 | 0.676294 | 0.083907 | 0.121254 | 0.763625 | 0.429953 | 4.846459 | 4.848246 | 2.084302 | 0.064728 | 0.770327 | 1.005898 | 1.554018 | 0.167831 | 0.688729 | 0.521820 | 1.886951 | 1.839020 | 0.718043 | 0.088480 | 0.077676 | 0.983288 | 2.320822 | 0.173147 | 1.282059 | 0.224657 | 0.183173 | 0.151433 | 0.155860 | 0.064737 | 0.104153 | 0.098728 | 0.131417 | 0.130455 | 0.214806 | 0.276861 | 0.208969 | 0.851380 | 0.164628 | 1.177476 | 0.282265 | 0.864859 | 0.307888 | 0.717975 | 0.205846 | 0.272298 | 0.274365 | 0.655927 | 34.254362 |
| min | 1.000000 | 6.941000 | 6.423452 | 5.320573 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 375.500000 | 375.500000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 48.000000 | 48.000000 | 48.000000 | 0.000000 | 0.000000 | 1.429000 | 1.429000 | 1.429000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.500000 | 1.500000 | 1.500000 | 1.500000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.222000 | 0.222000 | 0.222000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.026580 | 0.026580 | 0.026580 | 0.022952 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000210 |
| 25% | 3.000000 | 72.458076 | 52.143839 | 58.041225 | 0.966676 | 0.775363 | 78.512902 | 16.824174 | 28.539377 | 723.740000 | 738.946339 | 0.753757 | 262.400000 | 291.088889 | 92.994286 | 149.333333 | 112.127359 | 133.542493 | 80.000000 | 28.598137 | 4513.500000 | 2999.158291 | 883.117278 | 0.913959 | 0.688693 | 6648.000000 | 1656.847429 | 2564.342926 | 62.090000 | 73.350000 | 33.700512 | 50.772124 | 0.890589 | 0.660662 | 86.700000 | 34.036000 | 33.440123 | 7.588667 | 5.033407 | 4.109978 | 0.833333 | 0.672732 | 12.878000 | 2.329309 | 4.603491 | 61.000000 | 54.180953 | 8.339818 | 1.087284 | 0.457810 | 0.250677 | 86.382000 | 29.349419 | 2.333333 | 2.116732 | 1.000000 | 0.921454 | 0.306892 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 5.365000 |
| 50% | 4.000000 | 84.922750 | 60.696571 | 66.361592 | 1.199541 | 1.146783 | 122.906070 | 26.636008 | 44.285984 | 764.900000 | 889.966667 | 0.916843 | 764.100000 | 510.440000 | 258.449503 | 160.250000 | 125.970297 | 142.807563 | 171.000000 | 43.000000 | 5329.085800 | 4303.421500 | 1339.974702 | 1.090610 | 0.882747 | 8958.571000 | 2082.956581 | 3625.631828 | 73.100000 | 102.856863 | 51.470113 | 73.173958 | 1.138284 | 0.781205 | 127.050000 | 71.156250 | 48.029866 | 9.304400 | 8.330667 | 5.253498 | 1.112098 | 0.994998 | 12.878000 | 3.436400 | 5.500675 | 96.504430 | 73.333333 | 14.287643 | 6.096120 | 0.738694 | 0.545783 | 399.795000 | 56.556240 | 2.833333 | 2.618182 | 2.000000 | 1.063077 | 0.500000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.900000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 20.000000 |
| 75% | 5.000000 | 100.404410 | 86.103540 | 78.116681 | 1.444537 | 1.359418 | 154.119320 | 38.356908 | 53.629284 | 796.320000 | 1004.117384 | 1.061750 | 810.600000 | 690.703310 | 342.656991 | 169.857143 | 158.265231 | 155.938199 | 205.000000 | 60.224491 | 6728.000000 | 6416.333333 | 5794.965188 | 1.323930 | 1.080939 | 9778.571000 | 3409.026316 | 3959.191394 | 85.504167 | 110.738462 | 67.505900 | 89.975670 | 1.345894 | 0.877541 | 138.630000 | 76.706965 | 53.320838 | 17.114444 | 18.514286 | 13.600037 | 1.378110 | 1.157379 | 23.200000 | 10.498780 | 8.017581 | 111.005316 | 99.062911 | 42.371302 | 47.308041 | 0.962218 | 0.777353 | 399.973420 | 91.869245 | 4.000000 | 4.026201 | 3.000000 | 1.918400 | 1.020436 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 6.800000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 2.815000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.350000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 63.000000 |
| max | 9.000000 | 208.980400 | 208.980400 | 208.980400 | 1.983797 | 1.958203 | 207.972460 | 205.589910 | 101.019700 | 1313.100000 | 1348.028986 | 2.038560 | 1304.500000 | 1251.855072 | 479.162305 | 298.000000 | 298.000000 | 298.000000 | 256.000000 | 240.164344 | 22590.000000 | 22590.000000 | 22590.000000 | 1.954297 | 1.703420 | 22588.571000 | 22434.160000 | 10410.932005 | 326.100000 | 326.100000 | 326.100000 | 326.100000 | 1.767732 | 1.675400 | 349.000000 | 218.696600 | 169.075862 | 105.000000 | 105.000000 | 105.000000 | 2.034410 | 1.747165 | 104.778000 | 102.675000 | 51.680482 | 332.500000 | 406.960000 | 317.883627 | 376.032878 | 1.633977 | 1.612989 | 429.974170 | 401.440000 | 7.000000 | 7.000000 | 6.000000 | 6.992200 | 3.000000 | 14.000000 | 3.000000 | 40.000000 | 105.000000 | 120.000000 | 12.800000 | 66.000000 | 4.000000 | 4.000000 | 12.000000 | 99.925000 | 100.000000 | 20.000000 | 15.000000 | 3.000000 | 3.300000 | 24.000000 | 5.000000 | 75.000000 | 79.500000 | 34.900000 | 14.000000 | 30.000000 | 35.380000 | 45.000000 | 98.000000 | 20.000000 | 41.000000 | 46.000000 | 18.000000 | 19.000000 | 5.000000 | 4.000000 | 16.700000 | 9.000000 | 96.710000 | 99.976000 | 99.992000 | 6.000000 | 64.000000 | 45.000000 | 50.997450 | 7.000000 | 99.995000 | 31.500000 | 99.200000 | 83.500000 | 66.700000 | 4.000000 | 3.000000 | 24.000000 | 98.000000 | 4.998000 | 185.000000 | 6.000000 | 12.000000 | 6.000000 | 4.000000 | 5.000000 | 5.000000 | 5.000000 | 5.000000 | 5.000000 | 16.000000 | 7.000000 | 25.000000 | 55.000000 | 14.000000 | 97.240000 | 10.000000 | 45.000000 | 5.800000 | 64.000000 | 8.000000 | 7.000000 | 19.000000 | 14.000000 | 185.000000 |
subset_plot = data[['range_atomic_mass','range_atomic_radius','gmean_Density',]]
# Plot the distribution of each non-element column (element columns have 1- or 2-character names)
for i in data.columns:
    if len(i) > 2:
        plt.hist(data[i], bins=100)
        plt.title(i)
        plt.show()
X = data.drop(["critical_temp"], axis=1)
y = data.critical_temp
cv = KFold(n_splits=10, shuffle=True, random_state=101)
len(X.columns)
135
# Method to plot learning curve for resulting model
def plot_learning_curve(df, name):
    result_wide = pd.melt(df, id_vars=[f'param_{name.lower()}__alpha'], value_vars=['mean_test_score','mean_train_score'])
    result_wide.columns = ['Alpha Values','Train/Test','Negative Mean Squared Error']
    plt.figure(figsize=(8,6))
    sns.lineplot(data=result_wide, x='Alpha Values', y='Negative Mean Squared Error', hue='Train/Test').set_title(f'{name} Learning Curve')
%%time
# Try for better alpha value
# Lasso Uses L1 Regularization
# Meaning it can do feature selection
pipe_lasso = Pipeline(steps=[('scaling', MinMaxScaler()), ('lasso', Lasso(max_iter=10000))])
# Define Initial Grid
grid_lasso = {"lasso__alpha": [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 1, 1.5, 2, 2.5, 5]}
search_lasso = GridSearchCV(pipe_lasso, grid_lasso, cv=cv, n_jobs=-2,verbose=10, scoring='neg_mean_squared_error',
return_train_score=True)
model_lasso = search_lasso.fit(X,y)
lasso_results = pd.DataFrame(model_lasso.cv_results_)
Fitting 10 folds for each of 11 candidates, totalling 110 fits CPU times: user 2.08 s, sys: 772 ms, total: 2.85 s Wall time: 15.9 s
lasso_results
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_lasso__alpha | params | split0_test_score | split1_test_score | split2_test_score | split3_test_score | split4_test_score | split5_test_score | split6_test_score | split7_test_score | split8_test_score | split9_test_score | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | split5_train_score | split6_train_score | split7_train_score | split8_train_score | split9_train_score | mean_train_score | std_train_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 3.175188 | 0.619312 | 0.014818 | 0.002749 | 0.01 | {'lasso__alpha': 0.01} | -314.262952 | -320.247890 | -307.861977 | -293.733770 | -303.761778 | -1325.839945 | -283.620574 | -292.592124 | -884.092433 | -315.894810 | -464.190825 | 335.444442 | 3 | -305.081110 | -303.926221 | -305.718518 | -307.306390 | -306.239423 | -299.244265 | -308.526881 | -306.910187 | -305.929018 | -305.152322 | -305.403434 | 2.387600 |
| 1 | 0.585883 | 0.049234 | 0.012821 | 0.002212 | 0.1 | {'lasso__alpha': 0.1} | -379.007626 | -370.633392 | -374.514588 | -359.029290 | -372.943234 | -483.297433 | -355.331418 | -351.105544 | -368.404464 | -395.316642 | -380.958363 | 36.169098 | 1 | -372.829590 | -373.428497 | -373.383802 | -374.249473 | -374.528525 | -368.896619 | -375.931715 | -372.628705 | -373.222767 | -371.808707 | -373.090840 | 1.766758 |
| 2 | 0.357261 | 0.015721 | 0.010753 | 0.001103 | 0.2 | {'lasso__alpha': 0.2} | -445.012949 | -433.197123 | -443.145952 | -424.590248 | -442.095758 | -492.675297 | -422.218438 | -418.891281 | -431.986711 | -464.013779 | -441.782754 | 21.139299 | 2 | -441.202686 | -441.295917 | -442.071719 | -443.504751 | -441.491947 | -435.243051 | -444.265012 | -440.023425 | -443.302875 | -439.022704 | -441.142409 | 2.481121 |
| 3 | 0.322193 | 0.019779 | 0.011054 | 0.001001 | 0.3 | {'lasso__alpha': 0.3} | -468.643813 | -456.009132 | -470.247531 | -450.789481 | -467.043153 | -515.148649 | -449.327802 | -445.803009 | -459.295293 | -489.332304 | -467.164017 | 20.086232 | 4 | -466.415364 | -467.284109 | -467.350310 | -468.597327 | -467.091797 | -459.702811 | -469.391001 | -466.087266 | -468.681549 | -464.592532 | -466.519407 | 2.630929 |
| 4 | 0.281605 | 0.011852 | 0.010940 | 0.001330 | 0.4 | {'lasso__alpha': 0.4} | -490.511339 | -478.123067 | -497.310128 | -474.816787 | -489.267532 | -534.615749 | -472.383398 | -465.360520 | -483.537716 | -510.008229 | -489.593446 | 19.427503 | 5 | -489.543180 | -490.885283 | -490.915496 | -491.196358 | -489.920713 | -480.733507 | -492.246286 | -485.181934 | -491.548758 | -487.123874 | -488.929539 | 3.414783 |
| 5 | 0.287159 | 0.012232 | 0.010622 | 0.001115 | 0.5 | {'lasso__alpha': 0.5} | -516.700752 | -503.601718 | -527.290800 | -503.192441 | -516.165732 | -556.558931 | -499.202699 | -486.908772 | -511.257109 | -536.005635 | -515.688459 | 19.081345 | 6 | -516.888940 | -517.991293 | -517.392865 | -518.001032 | -517.423084 | -504.601300 | -518.850914 | -506.405491 | -518.013573 | -514.674302 | -515.024279 | 4.890083 |
| 6 | 0.239706 | 0.011696 | 0.010346 | 0.001186 | 1 | {'lasso__alpha': 1} | -571.865903 | -559.512232 | -587.324745 | -563.417194 | -571.820426 | -607.695493 | -551.875781 | -554.918380 | -568.001691 | -596.572002 | -573.300385 | 17.430722 | 7 | -572.722882 | -574.940155 | -572.027541 | -574.985405 | -573.371336 | -566.357341 | -574.658973 | -574.751362 | -573.653511 | -571.994735 | -572.946324 | 2.453968 |
| 7 | 0.222080 | 0.009106 | 0.009477 | 0.001090 | 1.5 | {'lasso__alpha': 1.5} | -591.686263 | -585.317466 | -610.422776 | -587.847144 | -592.464833 | -629.998503 | -570.694693 | -575.183095 | -589.292751 | -622.886968 | -595.579449 | 18.480898 | 8 | -595.038699 | -597.967478 | -593.777959 | -596.956469 | -595.160723 | -590.930378 | -597.073772 | -597.016505 | -595.766353 | -593.659943 | -595.334828 | 2.008607 |
| 8 | 0.217786 | 0.005892 | 0.009512 | 0.001177 | 2 | {'lasso__alpha': 2} | -614.739151 | -615.391020 | -636.038463 | -615.494643 | -616.340129 | -653.581339 | -593.132166 | -598.841793 | -613.966384 | -652.119944 | -620.964503 | 19.224707 | 9 | -620.344014 | -623.432198 | -619.118695 | -622.331202 | -620.554519 | -616.276995 | -622.502745 | -622.409430 | -621.172118 | -618.962418 | -620.710433 | 2.043285 |
| 9 | 0.212364 | 0.004052 | 0.009236 | 0.001115 | 2.5 | {'lasso__alpha': 2.5} | -644.824545 | -652.997851 | -668.795922 | -650.390873 | -647.524280 | -684.324732 | -622.989279 | -629.806642 | -645.986263 | -688.374916 | -653.601530 | 20.174026 | 10 | -652.879441 | -656.172562 | -651.699665 | -654.955903 | -653.203703 | -648.865384 | -655.197142 | -655.057490 | -653.836688 | -651.494200 | -653.336218 | 2.088789 |
| 10 | 0.164910 | 0.013356 | 0.008759 | 0.001855 | 5 | {'lasso__alpha': 5} | -788.804618 | -810.248337 | -816.641252 | -810.545588 | -790.229976 | -832.717323 | -765.477951 | -767.990551 | -787.962625 | -851.776017 | -802.239424 | 26.031256 | 11 | -802.388456 | -803.027845 | -800.198457 | -803.651220 | -802.006608 | -798.471809 | -804.928406 | -803.602133 | -801.991055 | -799.680297 | -801.994629 | 1.901337 |
lasso_results.loc[:,['param_lasso__alpha','rank_test_score','mean_test_score']]
| | param_lasso__alpha | rank_test_score | mean_test_score |
|---|---|---|---|
| 0 | 0.01 | 3 | -464.190825 |
| 1 | 0.1 | 1 | -380.958363 |
| 2 | 0.2 | 2 | -441.782754 |
| 3 | 0.3 | 4 | -467.164017 |
| 4 | 0.4 | 5 | -489.593446 |
| 5 | 0.5 | 6 | -515.688459 |
| 6 | 1 | 7 | -573.300385 |
| 7 | 1.5 | 8 | -595.579449 |
| 8 | 2 | 9 | -620.964503 |
| 9 | 2.5 | 10 | -653.601530 |
| 10 | 5 | 11 | -802.239424 |
plot_learning_curve(lasso_results,"Lasso")
%%time
grid_lasso_2 = {"lasso__alpha": [0.075,0.085,0.095, 0.1, 0.105,0.11,0.115]}
search_lasso_2 = GridSearchCV(pipe_lasso, grid_lasso_2, cv=cv, n_jobs=-2,verbose=10, scoring='neg_mean_squared_error',
return_train_score=True)
model_lasso_2 = search_lasso_2.fit(X,y)
lasso_results_2 = pd.DataFrame(model_lasso_2.cv_results_)
Fitting 10 folds for each of 7 candidates, totalling 70 fits CPU times: user 690 ms, sys: 70.1 ms, total: 760 ms Wall time: 7.43 s
lasso_results_2.loc[:,['param_lasso__alpha','rank_test_score','mean_test_score']]
| | param_lasso__alpha | rank_test_score | mean_test_score |
|---|---|---|---|
| 0 | 0.075 | 4 | -382.175857 |
| 1 | 0.085 | 1 | -380.010692 |
| 2 | 0.095 | 2 | -380.109532 |
| 3 | 0.1 | 3 | -380.958363 |
| 4 | 0.105 | 5 | -382.486461 |
| 5 | 0.11 | 6 | -384.612670 |
| 6 | 0.115 | 7 | -387.392172 |
plot_learning_curve(lasso_results_2,"Lasso")
%%time
grid_lasso_final = {"lasso__alpha": [0.085, 0.086, 0.087, 0.088, 0.089, 0.09]}
search_lasso_final = GridSearchCV(pipe_lasso, grid_lasso_final, cv=cv, n_jobs=-2,verbose=10, scoring='neg_mean_squared_error',
return_train_score=True)
model_lasso_final = search_lasso_final.fit(X,y)
lasso_results_final = pd.DataFrame(model_lasso_final.cv_results_)
Fitting 10 folds for each of 6 candidates, totalling 60 fits CPU times: user 1.53 s, sys: 580 ms, total: 2.11 s Wall time: 11.1 s
lasso_results_final.loc[:,['param_lasso__alpha','rank_test_score','mean_test_score','std_test_score']]
| | param_lasso__alpha | rank_test_score | mean_test_score | std_test_score |
|---|---|---|---|---|
| 0 | 0.085 | 6 | -380.010692 | 63.541241 |
| 1 | 0.086 | 5 | -379.929670 | 61.357413 |
| 2 | 0.087 | 4 | -379.875065 | 59.230653 |
| 3 | 0.088 | 2 | -379.846306 | 57.158920 |
| 4 | 0.089 | 1 | -379.841762 | 55.149377 |
| 5 | 0.09 | 3 | -379.858578 | 53.203206 |
plot_learning_curve(lasso_results_final,"Lasso")
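For reference, the alpha selected by the final Lasso search and its cross-validated score can be read directly off the fitted GridSearchCV object; the sketch below is illustrative and assumes the model_lasso_final object from the cell above.
# Illustrative sketch: report the alpha chosen by the final Lasso grid search.
# best_params_ and best_score_ are standard attributes of a fitted GridSearchCV.
print(model_lasso_final.best_params_)  # expected from the table above: {'lasso__alpha': 0.089}
print(model_lasso_final.best_score_)   # expected from the table above: about -379.84 (negative MSE)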
feature_importance_lasso = pd.DataFrame(zip(X.columns,model_lasso_final.best_estimator_._final_estimator.coef_),columns=('feature name','coefficient'))
feature_importance_lasso[feature_importance_lasso['coefficient']!=0].sort_values(by='coefficient',ascending=False)
| | feature name | coefficient |
|---|---|---|
| 108 | Ba | 83.356943 |
| 46 | wtd_mean_ThermalConductivity | 77.776971 |
| 134 | Bi | 35.343711 |
| 6 | range_atomic_mass | 28.569184 |
| 18 | range_atomic_radius | 22.479571 |
| 5 | wtd_entropy_atomic_mass | 21.142982 |
| 51 | range_ThermalConductivity | 11.673996 |
| 41 | wtd_entropy_FusionHeat | 8.638347 |
| 50 | wtd_entropy_ThermalConductivity | 7.816948 |
| 12 | range_fie | 7.626267 |
| 74 | Ca | 7.514675 |
| 11 | wtd_entropy_fie | 3.519679 |
| 9 | mean_fie | 0.189866 |
| 49 | entropy_ThermalConductivity | 0.162290 |
| 13 | wtd_range_fie | -0.463950 |
| 27 | wtd_std_Density | -0.767240 |
| 22 | gmean_Density | -7.036095 |
| 33 | wtd_entropy_ElectronAffinity | -11.066250 |
| 44 | wtd_std_FusionHeat | -12.751822 |
| 32 | entropy_ElectronAffinity | -19.473483 |
| 8 | wtd_std_atomic_mass | -24.498506 |
| 57 | wtd_std_Valence | -31.882476 |
| 31 | wtd_gmean_ElectronAffinity | -32.810794 |
| 48 | wtd_gmean_ThermalConductivity | -60.120680 |
# Keep the features with non-zero Lasso coefficients to use for feature selection with Ridge later
selected_features = feature_importance_lasso.loc[feature_importance_lasso['coefficient']!=0,['feature name']]
selected_features_list = selected_features['feature name'].tolist()
len(selected_features_list)
24
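As a quick check, the number of features the L1 penalty zeroed out can be counted directly from the coefficient table; this is a minimal sketch, assuming the feature_importance_lasso frame built above.
# Minimal sketch: count how many coefficients the L1 penalty drove exactly to zero.
n_total = len(feature_importance_lasso)
n_kept = int((feature_importance_lasso['coefficient'] != 0).sum())
print(f"{n_kept} of {n_total} features kept; {n_total - n_kept} zeroed out by L1")  # 24 kept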
%%time
# Ridge uses L2 regularization,
# so it shrinks coefficients toward zero but never to zero and therefore cannot do feature selection
pipe_ridge = Pipeline(steps=[('scaling', MinMaxScaler()), ('ridge', Ridge(max_iter=10000))])
grid_ridge = {"ridge__alpha": [100, 10, 1, 0.1, 0.01, 0.001]}
search_ridge = GridSearchCV(pipe_ridge, grid_ridge, cv=cv, n_jobs=-2,verbose=10, scoring='neg_mean_squared_error',
return_train_score=True)
model_ridge = search_ridge.fit(X,y)
ridge_results = pd.DataFrame(model_ridge.cv_results_)
Fitting 10 folds for each of 6 candidates, totalling 60 fits CPU times: user 408 ms, sys: 101 ms, total: 510 ms Wall time: 2.64 s
# take a look to see if we can improve alpha:
ridge_results.loc[:,['param_ridge__alpha','rank_test_score','mean_test_score']]
| | param_ridge__alpha | rank_test_score | mean_test_score |
|---|---|---|---|
| 0 | 100 | 1 | -395.498911 |
| 1 | 10 | 2 | -485.638123 |
| 2 | 1 | 5 | -576.743639 |
| 3 | 0.1 | 6 | -578.854177 |
| 4 | 0.01 | 4 | -576.547145 |
| 5 | 0.001 | 3 | -576.210867 |
plot_learning_curve(ridge_results,"Ridge")
%%time
# Try to improve alpha for Ridge
grid_ridge_2 = {"ridge__alpha": [75, 100, 125, 150, 175, 200, 225, 250]}
search_ridge_2 = GridSearchCV(pipe_ridge, grid_ridge_2, cv=cv, n_jobs=-2,verbose=10, scoring='neg_mean_squared_error',
return_train_score=True)
model_ridge_2 = search_ridge_2.fit(X,y)
ridge_results_2 = pd.DataFrame(model_ridge_2.cv_results_)
Fitting 10 folds for each of 8 candidates, totalling 80 fits CPU times: user 334 ms, sys: 69.6 ms, total: 404 ms Wall time: 3.15 s
# take a look to see if we can improve alpha:
ridge_results_2.loc[:,['param_ridge__alpha','rank_test_score','mean_test_score']]
| | param_ridge__alpha | rank_test_score | mean_test_score |
|---|---|---|---|
| 0 | 75 | 1 | -394.789944 |
| 1 | 100 | 2 | -395.498911 |
| 2 | 125 | 3 | -398.415299 |
| 3 | 150 | 4 | -402.151752 |
| 4 | 175 | 5 | -406.145938 |
| 5 | 200 | 6 | -410.151326 |
| 6 | 225 | 7 | -414.057103 |
| 7 | 250 | 8 | -417.815313 |
plot_learning_curve(ridge_results_2,"Ridge")
%%time
# Try to improve alpha for Ridge
grid_ridge_3 = {"ridge__alpha": [50, 60, 70, 80, 90]}
search_ridge_3 = GridSearchCV(pipe_ridge, grid_ridge_3, cv=cv, n_jobs=-2,verbose=10, scoring='neg_mean_squared_error',
return_train_score=True)
model_ridge_3 = search_ridge_3.fit(X,y)
ridge_results_3 = pd.DataFrame(model_ridge_3.cv_results_)
Fitting 10 folds for each of 5 candidates, totalling 50 fits CPU times: user 246 ms, sys: 48.8 ms, total: 294 ms Wall time: 2.1 s
# take a look to see if we can improve alpha:
ridge_results_3.loc[:,['param_ridge__alpha','rank_test_score','mean_test_score']]
| | param_ridge__alpha | rank_test_score | mean_test_score |
|---|---|---|---|
| 0 | 50 | 5 | -400.270095 |
| 1 | 60 | 4 | -396.834991 |
| 2 | 70 | 3 | -395.172284 |
| 3 | 80 | 1 | -394.630424 |
| 4 | 90 | 2 | -394.820510 |
plot_learning_curve(ridge_results_3,"Ridge")
%%time
# Try to improve alpha for Ridge
grid_ridge_final = {"ridge__alpha": [78, 79, 80, 81, 82, 83, 84]}
search_ridge_final = GridSearchCV(pipe_ridge, grid_ridge_final, cv=cv, n_jobs=-2,verbose=10, scoring='neg_mean_squared_error',
return_train_score=True)
model_ridge_final = search_ridge_final.fit(X,y)
ridge_results_final = pd.DataFrame(model_ridge_final.cv_results_)
Fitting 10 folds for each of 7 candidates, totalling 70 fits CPU times: user 295 ms, sys: 57.3 ms, total: 352 ms Wall time: 2.57 s
# take a look to see if we can improve alpha:
ridge_results_final.loc[:,['param_ridge__alpha','rank_test_score','mean_test_score','std_test_score']]
| | param_ridge__alpha | rank_test_score | mean_test_score | std_test_score |
|---|---|---|---|---|
| 0 | 78 | 7 | -394.670402 | 92.213836 |
| 1 | 79 | 6 | -394.646658 | 91.146970 |
| 2 | 80 | 4 | -394.630424 | 90.103995 |
| 3 | 81 | 2 | -394.621394 | 89.084177 |
| 4 | 82 | 1 | -394.619280 | 88.086810 |
| 5 | 83 | 3 | -394.623804 | 87.111214 |
| 6 | 84 | 5 | -394.634702 | 86.156738 |
plot_learning_curve(ridge_results_final,"Ridge")
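In contrast to Lasso, the tuned Ridge model keeps a non-zero weight for every feature, which can be verified directly from its coefficients; a minimal sketch, assuming the model_ridge_final object above.
# Minimal sketch: confirm that L2 regularization shrinks coefficients without zeroing them out.
ridge_coefs = model_ridge_final.best_estimator_._final_estimator.coef_
print(int((ridge_coefs == 0).sum()), "coefficients are exactly zero")  # expected: 0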
X_subset = X[selected_features_list]
%%time
# Fit Ridge on the Lasso-selected feature subset instead:
pipe_ridge_mm = Pipeline(steps = [('scaler', MinMaxScaler()), ('ridge', Ridge(max_iter=10000))])
grid_ridge_subset_mm = {'ridge__alpha':[0.01, 0.1, 0.5, 1, 1.25, 1.5, 5, 25, 50, 75]}
search_ridge_subset_mm = GridSearchCV(pipe_ridge_mm, grid_ridge_subset_mm, cv=cv, n_jobs=-2, verbose=10, scoring='neg_mean_squared_error', return_train_score=True)
model_ridge_subset_mm = search_ridge_subset_mm.fit(X_subset,y)
ridge_subset_results_mm = pd.DataFrame(model_ridge_subset_mm.cv_results_)
Fitting 10 folds for each of 10 candidates, totalling 100 fits CPU times: user 351 ms, sys: 78 ms, total: 429 ms Wall time: 965 ms
# take a look to see if we can improve alpha:
ridge_subset_results_mm.loc[:,['param_ridge__alpha','rank_test_score','mean_test_score']]
| | param_ridge__alpha | rank_test_score | mean_test_score |
|---|---|---|---|
| 0 | 0.01 | 2 | -335.334441 |
| 1 | 0.1 | 1 | -335.331890 |
| 2 | 0.5 | 3 | -335.372032 |
| 3 | 1 | 4 | -335.523519 |
| 4 | 1.25 | 5 | -335.634505 |
| 5 | 1.5 | 6 | -335.765602 |
| 6 | 5 | 7 | -338.829416 |
| 7 | 25 | 8 | -358.485648 |
| 8 | 50 | 9 | -373.661375 |
| 9 | 75 | 10 | -383.952647 |
plot_learning_curve(ridge_subset_results_mm,"Ridge")
%%time
# Fine-tune alpha around the previous best value
grid_ridge_subset_mm = {'ridge__alpha':[0.05, 0.07, 0.09, 0.1, 0.11, 0.12]}
search_ridge_subset_mm = GridSearchCV(pipe_ridge_mm, grid_ridge_subset_mm, cv=cv, n_jobs=-2, verbose=10, scoring='neg_mean_squared_error', return_train_score=True)
model_ridge_subset_mm = search_ridge_subset_mm.fit(X_subset,y)
ridge_subset_results_mm = pd.DataFrame(model_ridge_subset_mm.cv_results_)
Fitting 10 folds for each of 6 candidates, totalling 60 fits CPU times: user 211 ms, sys: 45.7 ms, total: 257 ms Wall time: 653 ms
# take a look to see if we can improve alpha:
ridge_subset_results_mm.loc[:,['param_ridge__alpha','rank_test_score','mean_test_score','std_test_score']]
| | param_ridge__alpha | rank_test_score | mean_test_score | std_test_score |
|---|---|---|---|---|
| 0 | 0.05 | 6 | -335.332744 | 15.726260 |
| 1 | 0.07 | 5 | -335.332235 | 15.727224 |
| 2 | 0.09 | 4 | -335.331949 | 15.728206 |
| 3 | 0.1 | 2 | -335.331890 | 15.728702 |
| 4 | 0.11 | 1 | -335.331886 | 15.729203 |
| 5 | 0.12 | 3 | -335.331937 | 15.729708 |
plot_learning_curve(ridge_subset_results_mm,"Ridge")
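As before, the combined model's selected alpha and cross-validated score can be read off the fitted search object; a minimal sketch, assuming the model_ridge_subset_mm object above.
# Minimal sketch: report the alpha chosen for Ridge on the Lasso-selected subset.
print(model_ridge_subset_mm.best_params_)  # expected from the table above: {'ridge__alpha': 0.11}
print(model_ridge_subset_mm.best_score_)   # expected from the table above: about -335.33 (negative MSE)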
feature_importance_laso_ridg = pd.DataFrame(zip(X_subset.columns,model_ridge_subset_mm.best_estimator_._final_estimator.coef_),columns=('feature name','coefficient'))
feature_importance_laso_ridg.sort_values(by='coefficient',ascending=False)
| | feature name | coefficient |
|---|---|---|
| 22 | Ba | 165.111484 |
| 15 | wtd_mean_ThermalConductivity | 150.939474 |
| 23 | Bi | 90.118661 |
| 1 | range_atomic_mass | 41.213109 |
| 18 | wtd_entropy_ThermalConductivity | 29.193261 |
| 21 | Ca | 28.111319 |
| 3 | mean_fie | 25.843994 |
| 13 | wtd_entropy_FusionHeat | 19.238351 |
| 7 | range_atomic_radius | 17.392559 |
| 0 | wtd_entropy_atomic_mass | 9.328552 |
| 5 | range_fie | 2.526282 |
| 6 | wtd_range_fie | 0.966458 |
| 9 | wtd_std_Density | -1.072042 |
| 17 | entropy_ThermalConductivity | -4.044376 |
| 19 | range_ThermalConductivity | -4.619915 |
| 4 | wtd_entropy_fie | -9.056061 |
| 8 | gmean_Density | -10.809142 |
| 14 | wtd_std_FusionHeat | -11.877327 |
| 11 | entropy_ElectronAffinity | -14.837596 |
| 12 | wtd_entropy_ElectronAffinity | -25.187103 |
| 10 | wtd_gmean_ElectronAffinity | -30.798639 |
| 20 | wtd_std_Valence | -33.991949 |
| 2 | wtd_std_atomic_mass | -43.456091 |
| 16 | wtd_gmean_ThermalConductivity | -140.271813 |
feature_importance_laso_ridg_sorted = feature_importance_laso_ridg.sort_values(by='coefficient', key=abs,ascending=False)
# top ten features by magnitude of coefficient
feature_importance_laso_ridg_sorted_top = feature_importance_laso_ridg_sorted[:10]
feature_importance_laso_ridg_sorted_top
| | feature name | coefficient |
|---|---|---|
| 22 | Ba | 165.111484 |
| 15 | wtd_mean_ThermalConductivity | 150.939474 |
| 16 | wtd_gmean_ThermalConductivity | -140.271813 |
| 23 | Bi | 90.118661 |
| 2 | wtd_std_atomic_mass | -43.456091 |
| 1 | range_atomic_mass | 41.213109 |
| 20 | wtd_std_Valence | -33.991949 |
| 10 | wtd_gmean_ElectronAffinity | -30.798639 |
| 18 | wtd_entropy_ThermalConductivity | 29.193261 |
| 21 | Ca | 28.111319 |
sns.catplot(x="coefficient",
y="feature name",
kind="bar",
palette = "light:b_r",
data=feature_importance_laso_ridg_sorted_top).set(title="Feature Importance Plot")
[Feature Importance Plot: bar chart of the top ten features of the combined model by coefficient]
sns.catplot(x="coefficient",
y="feature name",
kind="bar",
palette = "light:b_r",
data=feature_importance_laso_ridg_sorted).set(title="Feature Importance Plot")
[Feature Importance Plot: bar chart of all 24 selected features of the combined model by coefficient]
feature_importance_laso = pd.DataFrame(zip(X.columns,model_lasso_final.best_estimator_._final_estimator.coef_),columns=('feature name','coefficient'))
feature_importance_laso_sorted = feature_importance_laso.sort_values(by='coefficient', key=abs,ascending=False)
feature_importance_laso_sorted[:10]
| | feature name | coefficient |
|---|---|---|
| 108 | Ba | 83.356943 |
| 46 | wtd_mean_ThermalConductivity | 77.776971 |
| 48 | wtd_gmean_ThermalConductivity | -60.120680 |
| 134 | Bi | 35.343711 |
| 31 | wtd_gmean_ElectronAffinity | -32.810794 |
| 57 | wtd_std_Valence | -31.882476 |
| 6 | range_atomic_mass | 28.569184 |
| 8 | wtd_std_atomic_mass | -24.498506 |
| 18 | range_atomic_radius | 22.479571 |
| 5 | wtd_entropy_atomic_mass | 21.142982 |
feature_importance_ridge = pd.DataFrame(zip(X.columns,model_ridge_final.best_estimator_._final_estimator.coef_),columns=('feature name','coefficient'))
feature_importance_ridge_sorted = feature_importance_ridge.sort_values(by='coefficient', key=abs,ascending=False)
feature_importance_ridge_sorted[:10]
| | feature name | coefficient |
|---|---|---|
| 108 | Ba | 33.262584 |
| 57 | wtd_std_Valence | -30.083138 |
| 5 | wtd_entropy_atomic_mass | 27.203488 |
| 46 | wtd_mean_ThermalConductivity | 26.858327 |
| 6 | range_atomic_mass | 25.581244 |
| 18 | range_atomic_radius | 25.296134 |
| 48 | wtd_gmean_ThermalConductivity | -22.472541 |
| 33 | wtd_entropy_ElectronAffinity | -20.790031 |
| 51 | range_ThermalConductivity | 20.583099 |
| 134 | Bi | 19.469492 |
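To make the comparison between the three approaches explicit, the best cross-validated scores can be collected into a single summary table; this sketch is illustrative and assumes the three fitted search objects defined above.
# Illustrative sketch: summarize the best cross-validated negative MSE of each model.
summary = pd.DataFrame({
    'model': ['Lasso (all features)', 'Ridge (all features)', 'Lasso selection + Ridge'],
    'best_alpha': [model_lasso_final.best_params_['lasso__alpha'],
                   model_ridge_final.best_params_['ridge__alpha'],
                   model_ridge_subset_mm.best_params_['ridge__alpha']],
    'best_neg_mse': [model_lasso_final.best_score_,
                     model_ridge_final.best_score_,
                     model_ridge_subset_mm.best_score_],
})
summary  # the combined Lasso-then-Ridge model shows the best (least negative) score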