|
def my_plot_feature_importance_by_feature_plotly(
    model: ModelBridge = None,
    feature_importances: dict = None,
    error_x: dict = None,
    metric_names: Iterable[str] = None,
    relative: bool = True,
    caption: str = "",
) -> go.Figure:
    """One plot per metric, showing importances by feature.

    Args:
        model: A model with a ``feature_importances`` method and a
            ``metric_names`` attribute. Only used when ``metric_names`` or
            ``feature_importances`` is not supplied.
        feature_importances: Mapping of feature name to importance. When
            given, the same importances are plotted for every metric and
            ``model.feature_importances`` is never called.
        error_x: Mapping of feature name to the standard deviation of its
            importance (required). Drives the asymmetric error bars.
        metric_names: Metrics to plot, one dropdown entry each. Defaults to
            ``model.metric_names``. Any iterable is accepted.
        relative: whether to normalize feature importances so that they add
            to 1 (standard deviations are divided by the same total).
        caption: an HTML-formatted string to place at the bottom of the plot.

    Returns:
        A go.Figure of feature importances with a dropdown that switches
        between metrics.

    Raises:
        NotImplementedError: if no metric produced feature importances.

    Notes:
        Copyright (c) Facebook, Inc. and its affiliates.

        This source code is licensed under the MIT license. Modified by @sgbaird.
    """
    if metric_names is None:
        assert model is not None, "specify model or metric_names"
        metric_names = model.metric_names
    assert metric_names is not None, "specify model or metric_names"
    assert error_x is not None, "specify error_x"
    # Materialize once: ``metric_names`` may be a one-shot iterator.
    sorted_metric_names = sorted(metric_names)

    factor_col = "Factor"
    importance_col = "Importance"
    std_col = "StdDev"
    low_col = "err_minus"

    traces = []
    plotted_metric_names = []  # parallel to ``traces``
    for metric_name in sorted_metric_names:
        if feature_importances is not None:
            importances = feature_importances
        else:
            assert model is not None, "model is None"
            try:
                importances = model.feature_importances(metric_name)
            except NotImplementedError:
                logger.warning(
                    f"Model for {metric_name} does not support feature importances."
                )
                continue

        # Explicit columns keep the frame well-formed even when empty.
        df = pd.DataFrame(
            list(importances.items()), columns=[factor_col, importance_col]
        )
        err_df = pd.Series(error_x).to_frame(name=std_col)
        err_df.index.names = [factor_col]
        df = pd.concat((df.set_index(factor_col), err_df), axis=1).reset_index()

        if relative:
            totals = df[importance_col].sum()
            df[importance_col] = df[importance_col].div(totals)
            df[std_col] = df[std_col].div(totals)

        # Clip only the lower error bar so it never extends past zero. Build a
        # new Series instead of assigning into ``df[std_col]``: in-place
        # masked assignment on that column also clipped the upper bar when
        # pandas copy-on-write is off.
        df[low_col] = df[std_col].mask(
            df[std_col] > df[importance_col], df[importance_col]
        )

        df = df.sort_values(importance_col)
        traces.append(
            go.Bar(
                name=importance_col,
                orientation="h",
                # Only the first successfully plotted metric starts visible.
                visible=not traces,
                x=df[importance_col],
                y=df[factor_col],
                error_x=dict(
                    type="data",
                    symmetric=False,
                    array=df[std_col].to_list(),
                    arrayminus=df[low_col].to_list(),
                ),
            )
        )
        plotted_metric_names.append(metric_name)

    if not traces:
        raise NotImplementedError("No traces found for metric")

    # Visibility masks are indexed by trace, so build them from the metrics
    # that were actually plotted; skipped metrics must not shift them.
    dropdown = [
        {
            "args": ["visible", [j == k for j in range(len(traces))]],
            "label": metric_name,
            "method": "restyle",
        }
        for k, metric_name in enumerate(plotted_metric_names)
    ]

    updatemenus = [
        {
            "x": 0,
            "y": 1,
            "yanchor": "top",
            "xanchor": "left",
            "buttons": dropdown,
            # hack to put dropdown below title regardless of number of features
            "pad": {"t": -40},
        }
    ]

    features = traces[0].y
    title = (
        "Relative Feature Importances" if relative else "Absolute Feature Importances"
    )
    # Left margin grows with the longest feature label (capped at 75 chars);
    # str() tolerates non-string feature names, default=0 an empty list.
    longest_label = max((len(str(feature)) for feature in features), default=0)
    layout = go.Layout(
        height=200 + len(features) * 20,
        hovermode="closest",
        margin=go.layout.Margin(l=8 * min(longest_label, 75)),  # noqa E741
        showlegend=False,
        title=title,
        updatemenus=updatemenus,
        annotations=compose_annotation(caption=caption),
    )

    if relative:
        layout.update({"xaxis": {"tickformat": ".0%"}})

    return go.Figure(data=traces, layout=layout)
|
|
|
|
|
def add_r2(fig, xref=0.05, yref=0.95, trace_index=1):
    """Annotate ``fig`` with the R² score of one of its traces.

    The trace's ``x`` values are passed to ``r2_score`` as the actual values
    and its ``y`` values as the predictions (presumably a parity plot —
    confirm against callers).

    Args:
        fig: A plotly figure; ``fig["data"][trace_index]`` must expose ``x``
            and ``y``.
        xref: Horizontal position of the annotation in paper coordinates
            (despite the name, this is the ``x`` position, not an axis ref).
        yref: Vertical position of the annotation in paper coordinates.
        trace_index: Index into ``fig["data"]`` of the trace to score.
            Defaults to 1, the previously hard-coded trace.

    Returns:
        The same figure, annotated in place.
    """
    trace = fig["data"][trace_index]
    y_act = trace.x
    y_pred = trace.y
    r2 = r2_score(y_act, y_pred)
    fig.add_annotation(
        text=f"r2={r2:.2f}",
        xref="paper",
        yref="paper",
        x=xref,
        y=yref,
        showarrow=False,
    )
    return fig