birkhoffg / relax
Recourse Explanation Library in JAX
Home Page: https://birkhoffg.github.io/ReLax/
License: Apache License 2.0
Plugin to #50
We might want to provide a hook to the `data_module` for accessing `cat_idx` and `cat_array`.
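A minimal sketch of such a hook, assuming the data module keeps its column metadata internally (the attribute names and layout here are illustrative, not the library's actual API):

```python
class DataModule:
    """Illustrative stand-in for the library's data module."""

    def __init__(self, cont_cols, cat_cols, cat_array):
        self._cont_cols = cont_cols    # continuous feature names
        self._cat_cols = cat_cols      # categorical feature names
        self._cat_array = cat_array    # category values per categorical feature

    @property
    def cat_idx(self) -> int:
        # column index where the one-hot (categorical) block starts,
        # assuming continuous columns come first
        return len(self._cont_cols)

    @property
    def cat_array(self):
        return self._cat_array


dm = DataModule(["age", "hours_per_week"], ["sex"], [["Male", "Female"]])
assert dm.cat_idx == 2
```

Exposing these as read-only properties keeps the internal storage free to change without breaking callers.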
Create a Jupyter Notebook to benchmark methods in Carla
PyTorch is only needed for loading data. Our library mainly handles tabular data, so data loading is not a bottleneck in most scenarios; the PyTorch `DataLoader` is overkill for most of our use cases.
Write a drop-in `NumpyLoader`.
Delete the PyTorch dependency:
https://github.com/BirkhoffG/cfnet/blob/24783713dc787cd5b13e70aa483e455c4856198f/settings.ini#L18
Next, modify the following code so that it does not inherit from the PyTorch `Dataset` and `DataLoader`:
`NumpyDataset` should contain all the input data.

```python
# x, y are jax.numpy arrays, such that len(x) == len(y)
dataset = NumpyDataset(x, y)
x, y = dataset[:]       # access all the data of x, y
x_5, y_5 = dataset[:5]  # access the first five rows of x, y
```
`NumpyLoader` iterates over the `NumpyDataset`. See the PyTorch `DataLoader` docs.
```python
batch_size = 128
dataloader = NumpyLoader(
    dataset,                # a `NumpyDataset`
    batch_size=batch_size,
    shuffle=True,           # if True, shuffle the data; else, return the data in order
    drop_last=False,        # if True, discard the last batch when len(dataset) % batch_size != 0; else, return it
)
for x, y in dataloader:
    # with drop_last=False, the last batch may be smaller than batch_size
    assert len(x) <= batch_size
    assert len(y) <= batch_size
    ...
```
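A self-contained sketch of what the two classes above could look like without any PyTorch dependency (the constructor details are assumptions; only the interface comes from the issue):

```python
import numpy as np

class NumpyDataset:
    """Holds aligned arrays; indexing returns the same slice of each array."""

    def __init__(self, *arrays):
        assert all(len(a) == len(arrays[0]) for a in arrays)
        self.arrays = arrays

    def __len__(self):
        return len(self.arrays[0])

    def __getitem__(self, idx):
        return tuple(a[idx] for a in self.arrays)


class NumpyLoader:
    """Iterates a NumpyDataset in mini-batches, pure NumPy only."""

    def __init__(self, dataset, batch_size=128, shuffle=True, drop_last=False):
        self.dataset, self.batch_size = dataset, batch_size
        self.shuffle, self.drop_last = shuffle, drop_last

    def __iter__(self):
        indices = np.arange(len(self.dataset))
        if self.shuffle:
            np.random.shuffle(indices)
        # when drop_last, stop before the trailing partial batch
        n = len(indices)
        stop = n - (n % self.batch_size if self.drop_last else 0)
        for start in range(0, stop, self.batch_size):
            yield self.dataset[indices[start:start + self.batch_size]]
```

Since JAX arrays support NumPy-style fancy indexing, the same dataset works for `jax.numpy` inputs as well.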
Move `binary_cross_entropy` in `cfnet.methods.vanilla`, and `grad_update`, `cat_normalize` in `cfnet.training_module`, into `cfnet.utils`.
You can write something like:

```python
batch_data = self.dataset[self.indices]
```

This line seems to slow down the entire training loop.
```python
data_module = TabularDataModule('adult')
```

As such, `TabularDataModule` will automatically load the data configs of the adult dataset. We should also allow `TabularDataModule` to accept user-defined configs (i.e., keep the current argument `data_configs: str | dict`).
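One way to support both a dataset name and a user-defined config is a small dispatch on the argument type. A sketch, assuming hypothetical config keys and a built-in registry (none of these names are the library's actual API):

```python
import json
from pathlib import Path

# Hypothetical registry of built-in configs, keyed by dataset name.
_BUILTIN_CONFIGS = {
    "adult": {
        "data_name": "adult",
        "continuous_cols": ["age", "hours_per_week"],
        "discrete_cols": ["sex", "workclass"],
    },
}

def resolve_data_configs(data_configs):
    """Accept a dataset name (str), a path to a JSON file (str), or a dict."""
    if isinstance(data_configs, dict):
        return data_configs
    if isinstance(data_configs, str):
        if data_configs in _BUILTIN_CONFIGS:
            return _BUILTIN_CONFIGS[data_configs]
        # otherwise treat the string as a path to a user-provided JSON config
        return json.loads(Path(data_configs).read_text())
    raise TypeError(f"data_configs must be str or dict, got {type(data_configs)}")
```

`TabularDataModule.__init__` could then call `resolve_data_configs(data_configs)` once and work with a plain dict internally.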
We don't support `PredictiveTrainingModuleConfigs`:
Line 154 in ec1a411
There are some outdated links (e.g., some links still point to `cfnet`). Please fix ALL of them to the correct links. For example:
https://github.com/BirkhoffG/ReLax/blob/v0.1/relax/data/module.py#L415-L416
Some unnecessary tests (e.g., training models) seem to run during testing.
Check `monitor_metrics` before actually looking up the metric in the logs. Before this line:
https://github.com/BirkhoffG/cfnet/blob/2ee1a3203a9935e89b2ed8adf175ee1150fd9960/cfnet/_ckpt_manager.py#L54
check `monitor_metrics` and `raise ValueError(...)` if it is not appropriately configured.
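The check could look like the following sketch (the function name and the shape of `logs` are assumptions; the issue only asks for the early `ValueError`):

```python
def validate_monitor_metrics(monitor_metrics, logs: dict):
    """Fail fast with a clear error instead of a KeyError deep inside training."""
    if monitor_metrics is None:
        return None  # checkpointing disabled; nothing to monitor
    if monitor_metrics not in logs:
        raise ValueError(
            f"monitor_metrics={monitor_metrics!r} not found in logged metrics "
            f"{sorted(logs)}; set `monitor_metrics` to one of these keys."
        )
    return logs[monitor_metrics]
```

Raising at configuration time gives the user the list of valid keys, rather than an opaque lookup failure mid-training.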
We use `cat_normalize` for encoding features and clip continuous features to [0, 1]. This is because we use one-hot encoding for categorical features and a min-max scaler for continuous features.
If a user wants to use other encoding methods (e.g., standardized continuous features), our current way of handling normalized data is not applicable.
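For concreteness, a sketch of the behavior described above, assuming continuous columns come first and `cat_idx` marks where the one-hot block starts (the signature is illustrative):

```python
import jax
import jax.numpy as jnp

def cat_normalize(cf: jnp.ndarray, cat_array, cat_idx: int, hard: bool = False):
    """Clip continuous dims to [0, 1] and project categorical dims
    back onto valid one-hot groups."""
    cont = jnp.clip(cf[:, :cat_idx], 0.0, 1.0)
    groups = []
    start = cat_idx
    for choices in cat_array:  # one group of one-hot columns per cat feature
        group = cf[:, start:start + len(choices)]
        if hard:
            # snap to the arg-max category -> a valid one-hot vector
            group = jax.nn.one_hot(jnp.argmax(group, axis=-1), len(choices))
        else:
            # soft projection: a distribution over the categories
            group = jax.nn.softmax(group, axis=-1)
        groups.append(group)
        start += len(choices)
    return jnp.concatenate([cont] + groups, axis=-1)
```

This only makes sense under the min-max / one-hot assumption; with standardized continuous features, the `[0, 1]` clipping would silently corrupt the data, which is exactly the limitation raised above.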
Proposed features:

- Pass `seed` and `batch_size` to `TabularDataModule.train_dataloader`, `TabularDataModule.val_dataloader`, and `TabularDataModule.test_dataloader`.
- `batch_size` should also be an argument in `TrainingConfigs`:
  https://github.com/BirkhoffG/cfnet/blob/2ee1a3203a9935e89b2ed8adf175ee1150fd9960/cfnet/train.py#L15
- Deprecate `batch_size` in `DataConfigs`.
- Finally, pass the appropriate arguments:
  https://github.com/BirkhoffG/cfnet/blob/2ee1a3203a9935e89b2ed8adf175ee1150fd9960/cfnet/train.py#L58-L59
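The proposed signatures could be sketched as follows (the names follow the issue; the defaults and bodies are assumptions):

```python
class TabularDataModule:
    def __init__(self, data_configs):
        self.data_configs = data_configs  # note: no batch_size stored here anymore

    def train_dataloader(self, batch_size: int, seed: int = 42):
        # seed controls shuffling so runs are reproducible
        ...

    def val_dataloader(self, batch_size: int, seed: int = 42):
        ...

    def test_dataloader(self, batch_size: int, seed: int = 42):
        ...


class TrainingConfigs:
    def __init__(self, n_epochs: int, batch_size: int = 128, seed: int = 42):
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.seed = seed
```

The trainer would then call `data_module.train_dataloader(t_configs.batch_size, t_configs.seed)`, so batching is a training concern rather than a data concern.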
For example, remove this function:
Line 156 in 729cacf
Something like:

```python
if monitor_metrics is None:
    # no checkpoint manager
    ...
```
Currently, we assume `pred_fn` is a function of only one input `x`, e.g., something like:

```python
pred_fn = lambda x: 2 * x + 1
```

However, a user-defined `pred_fn` may take other arguments.
Hence, I propose
def generate_cf_explanations(
cf_module: BaseCFModule,
datamodule: TabularDataModule,
pred_fn: Callable[[jnp.DeviceArray], jnp.DeviceArray] = None,
*,
t_configs=None,
pred_fn_args: dict=None
)
where inside, we call pred_fn
as
pred_fn(x, **pred_fn_args)
This offers additional flexibility for models that are not implemented using our framework.
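The dispatch itself is small; a sketch of how the call site could behave (the helper name is hypothetical):

```python
def call_pred_fn(pred_fn, x, pred_fn_args=None):
    """How generate_cf_explanations could invoke a user-supplied pred_fn."""
    pred_fn_args = pred_fn_args or {}
    return pred_fn(x, **pred_fn_args)

# A pred_fn of one argument still works unchanged:
assert call_pred_fn(lambda x: 2 * x + 1, 3) == 7

# A pred_fn with extra arguments, e.g. externally held model parameters:
pred_fn = lambda x, w, b: w * x + b
assert call_pred_fn(pred_fn, 3, {"w": 2, "b": 1}) == 7
```

Keeping `pred_fn_args` keyword-only and defaulting to an empty dict makes the change backward-compatible with existing single-argument predictors.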
Supporting hyper-parameter search enables us to properly benchmark the algorithms. This issue is a thread for discussing how to support hyper-parameter search in CF explanation methods.
In essence, this is a multi-objective problem (i.e., minimizing both invalidity and cost).
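One common way to handle the two objectives during search is to scalarize them with a weight; a hedged sketch (the function and field names are illustrative, and the weight itself becomes a search choice):

```python
def recourse_score(invalidity: float, cost: float, lam: float = 0.5) -> float:
    """Weighted scalarization of the two objectives; lower is better.
    lam trades off validity against cost."""
    return lam * invalidity + (1 - lam) * cost

# Pick the hyper-parameter setting with the best scalarized score:
trials = [
    {"lr": 0.1,  "invalidity": 0.05, "cost": 0.9},
    {"lr": 0.01, "invalidity": 0.20, "cost": 0.3},
]
best = min(trials, key=lambda t: recourse_score(t["invalidity"], t["cost"]))
```

Libraries such as Optuna also support true multi-objective search (Pareto fronts), which avoids committing to a single `lam` up front.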
Some open-source libraries for hyper-parameter search:
Lines 71 to 73 in 729cacf
Reference:
https://stackoverflow.com/a/68293931