
proxyda

KPLA is a Python implementation of the paper "Proxy Methods for Domain Adaptation".

Installation

Dependencies

KPLA requires:

  • cosde (separate download)
  • latent_shift_adaptation (separate download)
  • cvxpy (>= 1.4.1)
  • cvxopt (>= 1.3.2)
  • jax (>= 0.4.25)
  • pandas (>= 2.0.3)
  • matplotlib (>= 3.7.2)
  • numpy (>= 1.24.3)
  • scikit-image (>= 0.22.0)
  • scikit-learn (>= 1.3.0)
  • scipy (>= 1.11.2)
  • tqdm (>= 4.66.1)
  • tensorflow (>= 2.11.0)


Install KPLA:

python setup.py install

Usage

Steps to run an adaptation experiment with concepts and proxies:

  1. Prepare data in the following dictionary form:
  data = {
    "X": X,  # jax.numpy.ndarray, shape (n_samples, n_features) or (n_samples,)
    "Y": Y,  # jax.numpy.ndarray, shape (n_samples, n_features) or (n_samples,)
    "W": W,  # jax.numpy.ndarray, shape (n_samples, n_features) or (n_samples,)
    "C": C,  # jax.numpy.ndarray, shape (n_samples, n_features) or (n_samples,)
  }
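A minimal sketch of building such a dictionary from synthetic data, assuming NumPy arrays are an acceptable stand-in for illustration (convert them with jax.numpy.asarray if KPLA requires JAX arrays). The helper make_toy_split and its data-generating process are hypothetical and only illustrate the expected keys and shapes.

```python
import numpy as np

def make_toy_split(n_samples, n_features=3, seed=0):
    """Hypothetical helper: random data in the {"X", "Y", "W", "C"} format.

    The generating process is arbitrary; only the keys and shapes matter here.
    """
    rng = np.random.default_rng(seed)
    u = rng.normal(size=n_samples)                                # unobserved latent variable
    x = u[:, None] + rng.normal(size=(n_samples, n_features))     # shape (n_samples, n_features)
    w = u + rng.normal(size=n_samples)                            # 1-D: shape (n_samples,)
    c = (u > 0).astype(float)                                     # 1-D: shape (n_samples,)
    y = x.sum(axis=1) + u + 0.1 * rng.normal(size=n_samples)      # 1-D: shape (n_samples,)
    return {"X": x, "Y": y, "W": w, "C": c}

source_train = make_toy_split(200, seed=0)
for key, value in source_train.items():
    print(key, value.shape)
```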
  2. Specify the kernel function of each variable by its string index (listed in the table below):
  kernel_dict = {}

  kernel_dict["cme_w_xc"] = {"X": KERNEL_X,
                             "C": KERNEL_C,
                             "Y": KERNEL_W}  # Y is W
  kernel_dict["cme_wc_x"] = {"X": KERNEL_X,
                             "Y": [{"kernel": KERNEL_W, "dim": DIM_W},
                                   {"kernel": KERNEL_C, "dim": DIM_C}]}  # Y is (W, C)
  kernel_dict["h0"]       = {"C": KERNEL_C}

Currently implemented kernel functions:

Kernel function                  Index
radial basis function (RBF)      "rbf"
columnwise RBF                   "rbf_column"
binary                           "binary"
columnwise binary                "binary_column"
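For intuition, a NumPy sketch of the Gram matrices for two of the listed kernels. It assumes the RBF kernel has the form exp(-||x - y||^2 / (2 * scale^2)) and that the binary kernel returns 1 when two inputs are identical and 0 otherwise; KPLA's own parameterization (for example, how scale enters) may differ, and the columnwise variants are not shown. The helpers rbf_gram and binary_gram are hypothetical.

```python
import numpy as np

def rbf_gram(x, y, scale=1.0):
    """Assumed RBF Gram matrix: K[i, j] = exp(-||x_i - y_j||^2 / (2 * scale**2))."""
    x = np.asarray(x, dtype=float).reshape(len(x), -1)  # (n_samples,) -> (n_samples, 1)
    y = np.asarray(y, dtype=float).reshape(len(y), -1)
    sq_dist = (
        np.sum(x**2, axis=1)[:, None]
        + np.sum(y**2, axis=1)[None, :]
        - 2.0 * x @ y.T
    )
    # Clip tiny negative distances caused by floating-point round-off.
    return np.exp(-np.maximum(sq_dist, 0.0) / (2.0 * scale**2))

def binary_gram(x, y):
    """Assumed binary Gram matrix: K[i, j] = 1 if x_i == y_j in every feature, else 0."""
    x = np.asarray(x).reshape(len(x), -1)
    y = np.asarray(y).reshape(len(y), -1)
    return np.all(x[:, None, :] == y[None, :, :], axis=-1).astype(float)

x = np.array([0.0, 1.0, 3.0])
print(rbf_gram(x, x, scale=1.0).round(3))
print(binary_gram(np.array([0, 1, 0]), np.array([0, 1])))
```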
  3. Prepare the method set and the lambda (regularization) set:
  method_set = {"cme": "original", "h0": "original"}
  lam_set = {"cme":    LAM_1,   # L2 penalty for the conditional mean embedding
            "h0":      LAM_2,   # L2 penalty for the bridge function
            "lam_min": LAM_MIN, 
            "lam_max": LAM_MAX}
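The "cme" entry penalizes a conditional mean embedding. As a hedged illustration only, the standard kernel ridge estimator of a conditional mean embedding of W given X computes weights beta(x) = (K_X + n * lam * I)^{-1} k_X(x) and approximates E[f(W) | X = x] by sum_i beta_i(x) f(w_i). This is a generic sketch, not KPLA's code; the RBF form, the n * lam scaling, and the helper names are assumptions. Forming and solving with the full n x n Gram matrix, as the "original" method does, costs O(n^2) memory and O(n^3) time.

```python
import numpy as np

def rbf_gram(x, y, scale=1.0):
    # Assumed RBF form: exp(-||x - y||^2 / (2 * scale^2)).
    x = np.asarray(x, dtype=float).reshape(len(x), -1)
    y = np.asarray(y, dtype=float).reshape(len(y), -1)
    sq_dist = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq_dist / (2.0 * scale**2))

def cme_weights(x_train, x_new, lam, scale=1.0):
    """Hypothetical helper: column j is (K_X + n * lam * I)^{-1} k_X(x_new[j])."""
    n = len(x_train)
    k_xx = rbf_gram(x_train, x_train, scale)   # full n x n Gram matrix
    k_xn = rbf_gram(x_train, x_new, scale)     # n x m cross-kernel matrix
    return np.linalg.solve(k_xx + n * lam * np.eye(n), k_xn)

# Toy check: W = 2 * X, so E[W | X = 1.5] should be close to 3.
x_train = np.linspace(0.0, 3.0, 30)
w_train = 2.0 * x_train
beta = cme_weights(x_train, np.array([1.5]), lam=1e-6)
estimate = w_train @ beta   # sum_i beta_i(x) * w_i, i.e. f is the identity
print(estimate)
```

A larger lam smooths the embedding more strongly; that trade-off is what the tuning step selects.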

Note: the current version implements only the "original" method, which computes the full Gram matrix. We plan to implement Nyström or other approximation methods in the future.

  4. Train the model:

from KPLA.models.plain_kernel.adaptation import FullAdapt

estimator_full = FullAdapt(source_train,
                           target_train,
                           source_test,
                           target_test,
                           split,       # split the training data or not, Boolean
                           scale,       # kernel length-scale, float
                           lam_set,
                           method_set,
                           kernel_dict)

estimator_full.fit(task = TASK) # task="c" for classification, task="r" for regression
estimator_full.evaluation(task = TASK)
  5. Select the model by cross-validation or on a validation set. The function tune_adapt_model_cv tunes the entries of lam_set and the kernel length-scale scale.
from KPLA.models.plain_kernel.model_selection import tune_adapt_model_cv

b_estimator, b_params = tune_adapt_model_cv(source_train,
                                            target_train,
                                            source_test,
                                            target_test,
                                            method_set,
                                            kernel_dict,
                                            use_validation=USE_VAL,  # True: use the extra validation set val_data; False: cross-validation
                                            val_data=val_data,
                                            model=FullAdapt,
                                            task=TASK,
                                            fit_task=TASK,
                                            n_params=N_PARAMS,
                                            n_fold=N_FOLD,
                                            min_log=MIN_VAL,  # minimum value of the grid search, log10 scale
                                            max_log=MAX_VAL,  # maximum value of the grid search, log10 scale
                                            )
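A hedged sketch of the search space such a tuner explores. It assumes, from the comments in the call, that min_log and max_log bound a log10-spaced grid of n_params values and that n_fold sets the number of cross-validation folds. The helpers log_grid and kfold_indices, and the choice of tuning (lam_cme, lam_h0, scale) jointly, are hypothetical and not taken from KPLA.

```python
import itertools
import numpy as np

def log_grid(min_log, max_log, n_params):
    """n_params values evenly spaced in log10 between 10**min_log and 10**max_log."""
    return np.logspace(min_log, max_log, n_params)

def kfold_indices(n_samples, n_fold, seed=0):
    """Yield (train_idx, val_idx) pairs for shuffled k-fold cross-validation."""
    perm = np.random.default_rng(seed).permutation(n_samples)
    folds = np.array_split(perm, n_fold)
    for k in range(n_fold):
        train_idx = np.concatenate([folds[j] for j in range(n_fold) if j != k])
        yield train_idx, folds[k]

# Hypothetical joint grid over (lam_cme, lam_h0, scale): 5 * 5 * 3 = 75 candidates.
candidates = list(itertools.product(log_grid(-4, 0, 5),
                                    log_grid(-4, 0, 5),
                                    log_grid(-1, 1, 3)))
print(len(candidates))  # 75
```

In a typical cross-validation loop, each candidate is scored by its validation loss averaged over the folds, and the best one is refit on all training data.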

Steps to run a multi-source adaptation experiment:

  1. Prepare a list of data dictionaries, one per domain, in the following form:
  data_domain_i = {
    "X": X_i,  # jax.numpy.ndarray, shape (n_samples, n_features) or (n_samples,)
    "Y": Y_i,  # jax.numpy.ndarray, shape (n_samples, n_features) or (n_samples,)
    "W": W_i,  # jax.numpy.ndarray, shape (n_samples, n_features) or (n_samples,)
    "Z": Z_i,  # jax.numpy.ndarray, shape (n_samples, n_domains) or (n_samples,); domain index, identical for every sample in the domain
  }
  data = [data_domain_1, ..., data_domain_n]
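A minimal sketch of the multi-source data list, assuming NumPy arrays stand in for JAX arrays and that Z uses the (n_samples, n_domains) layout: the same one-hot row repeated for every sample of a domain. make_domain is a hypothetical helper with an arbitrary data-generating process in which each domain shifts the latent mean.

```python
import numpy as np

def make_domain(domain_idx, n_domains, n_samples, seed=0):
    """Hypothetical helper: one domain's {"X", "Y", "W", "Z"} dictionary."""
    rng = np.random.default_rng(seed + domain_idx)
    u = rng.normal(loc=float(domain_idx), size=n_samples)      # latent mean differs per domain
    x = u[:, None] + rng.normal(size=(n_samples, 2))
    w = u + rng.normal(size=n_samples)
    y = x.sum(axis=1) + u + 0.1 * rng.normal(size=n_samples)
    # Same one-hot domain row for every sample in this domain.
    z = np.tile(np.eye(n_domains)[domain_idx], (n_samples, 1))
    return {"X": x, "Y": y, "W": w, "Z": z}

n_domains = 3
data = [make_domain(i, n_domains, n_samples=100) for i in range(n_domains)]
print([d["Z"][0] for d in data])
```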
  2. Specify the kernel function of each variable by its string index. There are two versions of multi-source adaptation: MultiEnvAdaptCAT and MultiEnvAdapt.

To use MultiEnvAdaptCAT, specify the kernels as follows:

kernel_dict = {}

kernel_dict['cme_w_xz'] = {'X': KERNEL_X, 'Y': KERNEL_W} # Y is W
kernel_dict['cme_w_x']  = {'X': KERNEL_X, 'Y': KERNEL_W} # Y is W
kernel_dict['m0']       = {'X': KERNEL_X}

To use MultiEnvAdapt, specify the kernels as follows:

kernel_dict = {}

kernel_dict['cme_w_xz'] = {'X': KERNEL_X, 'Y': KERNEL_W, 'Z': KERNEL_Z} # Y is W
kernel_dict['cme_w_x']  = {'X': KERNEL_X, 'Y': KERNEL_W}                # Y is W
kernel_dict['m0']       = {'X': KERNEL_X}
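When Z is a one-hot domain index, a binary kernel on Z reduces to a same-domain indicator. Assuming the binary kernel returns 1 when two inputs are identical and 0 otherwise, for one-hot vectors it coincides with the dot product e_i · e_j, which is 1 exactly when two samples come from the same domain. A small NumPy check; binary_gram is a hypothetical helper, not KPLA's implementation.

```python
import numpy as np

def binary_gram(z1, z2):
    """K[i, j] = 1 if rows z1[i] and z2[j] are identical, else 0 (assumed binary kernel)."""
    z1 = np.asarray(z1).reshape(len(z1), -1)
    z2 = np.asarray(z2).reshape(len(z2), -1)
    return np.all(z1[:, None, :] == z2[None, :, :], axis=-1).astype(float)

domain_idx = np.array([0, 0, 1, 2, 1])
z_onehot = np.eye(3)[domain_idx]               # shape (5, 3), one row per sample

k_binary = binary_gram(z_onehot, z_onehot)
k_linear = z_onehot @ z_onehot.T               # dot products of one-hot rows
k_same_domain = (domain_idx[:, None] == domain_idx[None, :]).astype(float)

print(np.array_equal(k_binary, k_linear) and np.array_equal(k_linear, k_same_domain))  # True
```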

Here $Z$ is the domain index, one-hot encoded. Setting KERNEL_Z='binary' in MultiEnvAdapt is equivalent to using MultiEnvAdaptCAT. MultiEnvAdapt is the more flexible version: it accepts continuous values of $Z$ and lets the user specify the underlying kernel function.

  3. Prepare the method set and the lambda (regularization) set:

  method_set = {"cme": "original", "m0": "original"}
  lam_set = {"cme":     LAM_1,   # L2 penalty for the conditional mean embedding
             "m0":      LAM_2,   # L2 penalty for the bridge function
             "lam_min": LAM_MIN, 
             "lam_max": LAM_MAX}
  4. Train the model:
from KPLA.models.plain_kernel.multienv_adaptation import MultiEnvAdapt, MultiEnvAdaptCAT

estimator_multi_a = MultiEnvAdapt(source_train,
                                  target_train,
                                  source_test,
                                  target_test,
                                  split,
                                  scale,
                                  lam_set,
                                  method_set,
                                  kernel_dict)
  5. Select the model by cross-validation or on a validation set.

Run Experiments

The examples are in the ./tests directory.

First run the model-selection programs under ./model_selection, then run the programs under ./experiments.

For the simulated regression tasks sim_multisource_bin and sim_multisource_cont, first run the hyperparameter-tuning program test_proposed_onehot.py for each task. Then run the experiments for the baseline (sweep_baselines.py) and proposed (sweep_proposed.py) methods.

Contributors

nicchiou, tsai-kailin, olawalesalaudeen
