
sparks-baird / crabnet-hyperparameter


Using Bayesian optimization via Ax platform + SAASBO model to simultaneously optimize 23 hyperparameters in 100 iterations (set a new Matbench benchmark).

Home Page: https://doi.org/10.1016/j.commatsci.2022.111505

License: MIT License

Python 0.01% HTML 99.96% Jupyter Notebook 0.03%
adaptive-design adaptive-experimentation-platform bayesian-optimization benchmark machine-learning materials-discovery materials-informatics transformer-network

crabnet-hyperparameter's People

Contributors

mliu7051, sgbaird, sgbaird-alt


Forkers

itzme-jp

crabnet-hyperparameter's Issues

How to calculate/plot feature importances?

@rekumar

Based on Ax's `plot_feature_importance_by_feature_plotly`:

```python
from logging import getLogger
from typing import Iterable

import pandas as pd
import plotly.graph_objects as go
from ax.modelbridge import ModelBridge
from ax.plot.helper import compose_annotation
from sklearn.metrics import r2_score

logger = getLogger(__name__)


def my_plot_feature_importance_by_feature_plotly(
    model: ModelBridge = None,
    feature_importances: dict = None,
    error_x: dict = None,
    metric_names: Iterable[str] = None,
    relative: bool = True,
    caption: str = "",
) -> go.Figure:
    """One plot per metric, showing importances by feature.

    Args:
        model: A model with a ``feature_importances`` method.
        relative: Whether to normalize feature importances so that they add to 1.
        caption: An HTML-formatted string to place at the bottom of the plot.

    Returns a go.Figure of feature importances.

    Notes:
        Copyright (c) Facebook, Inc. and its affiliates.
        This source code is licensed under the MIT license. Modified by @sgbaird.
    """
    traces = []
    dropdown = []
    if metric_names is None:
        assert model is not None, "specify model or metric_names"
        metric_names = model.metric_names
    assert metric_names is not None, "specify model or metric_names"
    for i, metric_name in enumerate(sorted(metric_names)):
        try:
            if feature_importances is not None:
                importances = feature_importances
            else:
                assert model is not None, "model is None"
                importances = model.feature_importances(metric_name)
        except NotImplementedError:
            logger.warning(
                f"Model for {metric_name} does not support feature importances."
            )
            continue
        factor_col = "Factor"
        importance_col = "Importance"
        std_col = "StdDev"
        low_col = "err_minus"
        assert error_x is not None, "specify error_x"
        df = pd.DataFrame(
            [
                {factor_col: factor, importance_col: importance}
                for factor, importance in importances.items()
            ]
        )
        err_df = pd.Series(error_x).to_frame(name=std_col)
        err_df.index.names = [factor_col]
        df = pd.concat((df.set_index(factor_col), err_df), axis=1).reset_index()
        if relative:
            totals = df[importance_col].sum()
            df[importance_col] = df[importance_col].div(totals)
            df[std_col] = df[std_col].div(totals)
        # .copy() so clamping the lower error bar below does not also
        # overwrite the StdDev column (they would otherwise share data)
        low_df = df[std_col].copy()
        low_df[low_df > df[importance_col]] = df[importance_col]
        df[low_col] = low_df
        df = df.sort_values(importance_col)
        traces.append(
            go.Bar(
                name=importance_col,
                orientation="h",
                visible=i == 0,
                x=df[importance_col],
                y=df[factor_col],
                error_x=dict(
                    type="data",
                    symmetric=False,
                    array=df[std_col].to_list(),
                    arrayminus=df[low_col].to_list(),
                ),
            )
        )
        is_visible = [False] * len(sorted(metric_names))
        is_visible[i] = True
        dropdown.append(
            {"args": ["visible", is_visible], "label": metric_name, "method": "restyle"}
        )
    if not traces:
        raise NotImplementedError("No traces found for metric")
    updatemenus = [
        {
            "x": 0,
            "y": 1,
            "yanchor": "top",
            "xanchor": "left",
            "buttons": dropdown,
            # hack to put dropdown below title regardless of number of features
            "pad": {"t": -40},
        }
    ]
    features = traces[0].y
    title = (
        "Relative Feature Importances" if relative else "Absolute Feature Importances"
    )
    layout = go.Layout(
        height=200 + len(features) * 20,
        hovermode="closest",
        margin=go.layout.Margin(
            l=8 * min(max(len(idx) for idx in features), 75)  # noqa: E741
        ),
        showlegend=False,
        title=title,
        updatemenus=updatemenus,
        annotations=compose_annotation(caption=caption),
    )
    if relative:
        layout.update({"xaxis": {"tickformat": ".0%"}})
    fig = go.Figure(data=traces, layout=layout)
    return fig


def add_r2(fig, xref=0.05, yref=0.95):
    """Annotate a parity plot with the r2 score of its second trace."""
    y_act = fig["data"][1].x
    y_pred = fig["data"][1].y
    r2 = r2_score(y_act, y_pred)
    fig.add_annotation(
        text=f"r2={r2:.2f}",
        xref="paper",
        yref="paper",
        x=xref,
        y=yref,
        showarrow=False,
    )
    return fig
```

The function is then called with:

```python
fig = my_plot_feature_importance_by_feature_plotly(
    model=None,
    feature_importances=avg_saas_importances,
    error_x=std_saas_importances,
    metric_names=["crabnet_mae"],
)
```

The feature importances are the inverse lengthscales of the SAASBO model, extracted from each repeat campaign in

```python
saas_feature_importances.append(saas.feature_importances(metric))
```
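For reference, the averaging across repeats can be done with a few lines of pandas. This is a minimal sketch with made-up importance values (each dict stands in for what `saas.feature_importances(metric)` returns for one repeat campaign); the results feed the `feature_importances` and `error_x` arguments above:

```python
import pandas as pd

# Hypothetical importances from three repeat campaigns
saas_feature_importances = [
    {"batch_size": 0.50, "lr": 0.30, "epochs": 0.20},
    {"batch_size": 0.40, "lr": 0.35, "epochs": 0.25},
    {"batch_size": 0.45, "lr": 0.25, "epochs": 0.30},
]

df = pd.DataFrame(saas_feature_importances)
avg_saas_importances = df.mean().to_dict()       # -> feature_importances=...
std_saas_importances = df.std(ddof=1).to_dict()  # -> error_x=...
```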

It's also probably worth mentioning that "SAASBO places strong priors on the inverse lengthscales to avoid overfitting in high-dimensional spaces" (SAASBO docs page).
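The sparsity-inducing effect of that prior is easy to see numerically. A pure-NumPy sketch, assuming a fixed global shrinkage scale `tau = 0.1` for illustration (SAASBO actually infers this from data):

```python
import numpy as np

rng = np.random.default_rng(0)
tau = 0.1  # assumed global shrinkage scale; SAASBO infers this

# SAAS prior: inverse lengthscales ~ Half-Cauchy(tau). The heavy tail lets a
# few of the 23 dimensions matter, but most mass sits near zero, so most
# hyperparameters are effectively switched off.
inv_lengthscales = np.abs(tau * rng.standard_cauchy(size=(10_000, 23)))
frac_small = (inv_lengthscales < 0.5).mean()
print(f"fraction of inverse lengthscales < 0.5: {frac_small:.2f}")
```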

Maybe restrict the parameter search to fewer parameters

Based on some of the initial results, the hyperparameter optimization doesn't seem to reduce the best model's test MAE much relative to the test MAE of the default parameters. This could be for a number of reasons:

  1. 100 is too few iterations for such a large search space
  2. the default parameters are already really good/nearly optimum
  3. there are lots of parameter choices that are equally good/nearly optimum, including the default
  4. not enough information to choose the best hyperparameters without overfitting

After it's done running (it's ~2/3 through as of now), I'm thinking maybe we choose just a few hyperparameters and run it again with 100 iterations. Thoughts?
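If we go that route, selecting the reduced set could be as simple as ranking parameters by their averaged SAAS importances. A sketch with hypothetical values (the actual importances come out of the repeat campaigns):

```python
# Hypothetical averaged SAAS importances (normalized inverse lengthscales)
avg_importances = {
    "lr": 0.31, "batch_size": 0.22, "d_model": 0.18,
    "dropout": 0.12, "weight_decay": 0.09, "epochs": 0.05, "heads": 0.03,
}

k = 3  # size of the reduced search space
top_k = sorted(avg_importances, key=avg_importances.get, reverse=True)[:k]
print(top_k)  # ['lr', 'batch_size', 'd_model']
```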
