
relax's Issues

Get rid of the PyTorch dependencies in `TabularDataModule`

PyTorch is only needed for loading data. Our library mainly handles tabular data, so data loading is unlikely to be a bottleneck in most scenarios; PyTorch's DataLoader is overkill for our project in most use cases.

Purpose

Write a drop-in NumpyLoader.

ToDo

Delete the PyTorch dependency

https://github.com/BirkhoffG/cfnet/blob/24783713dc787cd5b13e70aa483e455c4856198f/settings.ini#L18

https://github.com/BirkhoffG/cfnet/blob/24783713dc787cd5b13e70aa483e455c4856198f/cfnet/datasets.py#L9

Next, modify the following code so that these classes no longer inherit from PyTorch's Dataset and DataLoader:

https://github.com/BirkhoffG/cfnet/blob/24783713dc787cd5b13e70aa483e455c4856198f/cfnet/datasets.py#L12-L22

https://github.com/BirkhoffG/cfnet/blob/24783713dc787cd5b13e70aa483e455c4856198f/cfnet/datasets.py#L35-L51

Expected Functionalities

NumpyDataset should contain all the input data.

# x, y are jax.numpy arrays with len(x) == len(y)
dataset = NumpyDataset(x, y)

x, y = dataset[:]       # access all the data of x, y
x_5, y_5 = dataset[:5]  # access the first five entries of x, y

NumpyLoader iterates over a NumpyDataset. See the PyTorch DataLoader docs.

batch_size = 128
dataloader = NumpyLoader(
    dataset,               # a `NumpyDataset`
    batch_size=batch_size,
    shuffle=True,          # if True, shuffle the data; otherwise, return it in order
    drop_last=False        # if True, discard the final short batch when len(dataset) % batch_size != 0; otherwise, return it
)

for x, y in dataloader:
    # every batch has batch_size rows, except possibly the last one when drop_last=False
    assert len(x) == len(y) <= batch_size
    ...
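
Below is a minimal sketch of what this drop-in pair could look like, assuming the data lives in NumPy or jax.numpy arrays. The class names follow the issue; everything else is illustrative.

import numpy as np

class NumpyDataset:
    """Holds parallel arrays; supports int, slice, and index-array access."""
    def __init__(self, *arrays):
        assert all(len(arr) == len(arrays[0]) for arr in arrays), \
            "all arrays must have the same length"
        self.arrays = arrays

    def __len__(self):
        return len(self.arrays[0])

    def __getitem__(self, idx):
        # idx can be an int, a slice, or an index array (fancy indexing)
        return tuple(arr[idx] for arr in self.arrays)

class NumpyLoader:
    """Iterates a NumpyDataset in (optionally shuffled) batches."""
    def __init__(self, dataset, batch_size=128, shuffle=True, drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

    def __iter__(self):
        indices = np.arange(len(self.dataset))
        if self.shuffle:
            np.random.shuffle(indices)
        stop = len(indices)
        if self.drop_last:
            stop -= stop % self.batch_size
        for start in range(0, stop, self.batch_size):
            yield self.dataset[indices[start:start + self.batch_size]]

Since batching only manipulates plain index arrays, jax.numpy arrays can be indexed directly and the PyTorch dependency disappears entirely.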

Refactor util functions

Move

  • binary_cross_entropy in cfnet.methods.vanilla
  • grad_update, cat_normalize in cfnet.training_module

into cfnet.utils
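
After the move, call sites would simply import from the new location:

from cfnet.utils import binary_cross_entropy, grad_update, cat_normalize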

Provide Default Data Configs for `TabularDataModule`

Proposal

data_module = TabularDataModule('adult')

With this call, TabularDataModule automatically loads the data_configs for the adult dataset.

We should also keep allowing users to pass their own configs (i.e., the current argument data_configs: str | dict).
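
A sketch of how the resolution could work; the registry name and its fields are hypothetical:

# Hypothetical registry of built-in dataset configs; field names are illustrative.
DEFAULT_DATA_CONFIGS = {
    'adult': {
        'data_dir': 'assets/data/adult.csv',
        'continuous_cols': ['age', 'hours_per_week'],
        'discrete_cols': ['workclass', 'education'],
    },
}

def resolve_data_configs(data_configs):
    """Accept a dataset name (str) or a full user-defined config (dict)."""
    if isinstance(data_configs, dict):
        return data_configs
    try:
        return DEFAULT_DATA_CONFIGS[data_configs]
    except KeyError:
        raise ValueError(
            f"Unknown dataset '{data_configs}'; available: {list(DEFAULT_DATA_CONFIGS)}"
        )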

CI/CD takes too long

The test suite seems to run some unnecessary steps during testing (e.g., training models), which slows CI/CD down.

Customize DataModule-dependent constraints

We use cat_normalize to encode categorical features, and we clip continuous features to [0, 1]. This works because we one-hot encode categorical features and apply a min-max scaler to continuous features.

If a user wants to use other encoding methods (e.g., standardized continuous features), our current way of handling normalized data no longer applies.
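
One possible direction, sketched below with hypothetical names, is to let the data module accept a user-supplied constraint function instead of hard-coding one-hot normalization plus [0, 1] clipping:

import jax.numpy as jnp

class DataModuleConstraintSketch:
    """Hypothetical: the data module owns a pluggable constraint function."""
    def __init__(self, constraint_fn=None):
        # The default mirrors current behavior: min-max-scaled features live in [0, 1].
        self.constraint_fn = constraint_fn or (lambda x: jnp.clip(x, 0.0, 1.0))

    def apply_constraints(self, x):
        return self.constraint_fn(x)

# A user with standardized (z-scored) continuous features could opt out of clipping:
dm = DataModuleConstraintSketch(constraint_fn=lambda x: x)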

Proposed features:

Pass `seed` and `batch_size` to the dataloader functions in `TabularDataModule`

  1. Pass seed and batch_size to TabularDataModule.train_dataloader, TabularDataModule.val_dataloader, and TabularDataModule.test_dataloader.

  2. batch_size should also be an argument in TrainingConfigs
    https://github.com/BirkhoffG/cfnet/blob/2ee1a3203a9935e89b2ed8adf175ee1150fd9960/cfnet/train.py#L15

  3. Deprecate batch_size in DataConfigs

  4. Finally, pass appropriate arguments (see the sketch after this list):
    https://github.com/BirkhoffG/cfnet/blob/2ee1a3203a9935e89b2ed8adf175ee1150fd9960/cfnet/train.py#L58-L59
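
A sketch of the resulting call path; the names are assumed from the issue, and the shuffling logic is illustrative:

import numpy as np

class DataModuleSketch:
    """Hypothetical: dataloaders take batch_size and seed explicitly."""
    def __init__(self, dataset):
        self.dataset = dataset  # e.g., a NumpyDataset supporting index-array access

    def train_dataloader(self, batch_size: int, seed: int = 42):
        # The seed controls shuffling, so runs are reproducible.
        rng = np.random.default_rng(seed)
        indices = rng.permutation(len(self.dataset))
        for start in range(0, len(indices), batch_size):
            yield self.dataset[indices[start:start + batch_size]]

# The trainer would then read batch_size from TrainingConfigs and pass it through:
# datamodule.train_dataloader(t_configs.batch_size, seed=t_configs.seed)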

Support aux arguments of `pred_fn` to be passed to `generate_cf_explanations`

Currently, we assume pred_fn is a function of a single input x, e.g. something like:

pred_fn = lambda x: 2 * x + 1

However, it is possible that user-defined pred_fn takes other arguments.

Hence, I propose

def generate_cf_explanations(
    cf_module: BaseCFModule,
    datamodule: TabularDataModule,
    pred_fn: Callable[[jnp.DeviceArray], jnp.DeviceArray] = None,
    *,
    t_configs=None,
    pred_fn_args: dict = None,
):
    ...

where, internally, we call pred_fn as

pred_fn(x, **(pred_fn_args or {}))

(falling back to an empty dict when pred_fn_args is None).

This offers additional flexibility for models that are not implemented using our framework.
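
For illustration, a hedged sketch: cf_module, datamodule, and params are assumed to be defined elsewhere, and the model below is made up.

import jax
import jax.numpy as jnp

# A user-defined prediction function that takes auxiliary arguments.
def pred_fn(x, params, temperature=1.0):
    return jax.nn.sigmoid(x @ params / temperature)

# The auxiliary arguments are forwarded via pred_fn_args.
cf_exps = generate_cf_explanations(
    cf_module, datamodule, pred_fn,
    pred_fn_args={'params': params, 'temperature': 0.5},
)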

Support hyper-parameter searching for CF explanation methods

Supporting hyper-parameter search enables us to benchmark the algorithms properly. This issue is a thread for discussing how to support hyper-parameter search in CF explanation methods.

In essence, this is a multi-objective optimization problem (i.e., jointly minimizing invalidity and cost).

Some open-source libraries for hyper-parameter search:
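
One such library is Optuna, which supports multi-objective studies out of the box. A minimal sketch follows; run_cf_method is a hypothetical stand-in for running a CF method and measuring (invalidity, cost):

import optuna

def run_cf_method(lr: float, lam: float) -> tuple:
    # Hypothetical stand-in: run the CF method with these hyper-parameters
    # and return (invalidity, cost) measured on a validation set.
    return lam / (1.0 + lr), lr * lam

def objective(trial):
    lr = trial.suggest_float('lr', 1e-4, 1e-1, log=True)
    lam = trial.suggest_float('lambda', 1e-2, 10.0, log=True)
    return run_cf_method(lr, lam)

# Two minimization directions: invalidity and cost.
study = optuna.create_study(directions=['minimize', 'minimize'])
study.optimize(objective, n_trials=50)
print(study.best_trials)  # Pareto-optimal trials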
