TFDeepSurv

A deep Cox proportional hazards model for survival analysis, implemented in TensorFlow.

1. Differences from DeepSurv

DeepSurv, a package implementing the deep Cox proportional hazards model, is open source on GitHub. TFDeepSurv differs from it in the following ways:

  • Evaluating variable importance in the deep neural network.
  • Identifying ties in death times in your survival data, which determines the loss function and the survival function estimator to use (Breslow or Efron approximation).
  • Providing survival function estimates from three optional algorithms.
  • Tuning the hyperparameters of the DNN with Bayesian Hyperparameters Optimization.
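
To illustrate why ties change the loss function, here is a minimal NumPy sketch of the Cox negative log partial likelihood with either the Breslow or the Efron approximation. It is not TFDeepSurv's implementation: the function name `neg_log_partial_likelihood` and its arguments are hypothetical, and the loss is left as an unnormalized sum.

```python
import numpy as np


def neg_log_partial_likelihood(risk, time, event, ties="breslow"):
    """Cox negative log partial likelihood (illustrative sketch only).

    risk  : predicted log-risk score per subject, shape (n,)
    time  : observed time per subject, shape (n,)
    event : 1 if death was observed, 0 if censored, shape (n,)
    ties  : "breslow" or "efron" (hypothetical argument of this sketch)
    """
    risk = np.asarray(risk, dtype=float)
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=bool)

    loss = 0.0
    for t in np.unique(time[event]):
        at_risk = time >= t              # risk set at time t
        died = event & (time == t)       # subjects who died exactly at t
        d = int(died.sum())
        sum_risk_set = np.exp(risk[at_risk]).sum()
        sum_tied = np.exp(risk[died]).sum()
        loss -= risk[died].sum()
        if ties == "efron":
            # Efron: remove a growing share of the tied subjects from the risk set
            k = np.arange(d)
            loss += np.log(sum_risk_set - k / d * sum_tied).sum()
        else:
            # Breslow: all tied deaths share the same risk-set denominator
            loss += d * np.log(sum_risk_set)
    return loss


# Two deaths tied at t=1, one death at t=2, all risk scores equal to 0
time, event, risk = [1, 1, 2], [1, 1, 1], [0.0, 0.0, 0.0]
print(neg_log_partial_likelihood(risk, time, event, ties="breslow"))  # 2 * log(3)
print(neg_log_partial_likelihood(risk, time, event, ties="efron"))    # log(3) + log(2)
```

When no death times are tied, both branches give the same value, which is why the choice only matters for data with ties.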

2. Statement

This project grew out of research on breast cancer. A paper describing it has been submitted to IEEE JBHI; we will update the status here once the paper is published.

3. Installation

From source

Download the TFDeepSurv package and install it from its directory (Python version: 3.x):

git clone https://github.com/liupei101/TFDeepSurv.git
cd TFDeepSurv
pip install .

4. Getting started

4.1 Running with simulated data

Import packages and prepare data

### import package
from tfdeepsurv import dsl
from tfdeepsurv.dataset import SimulatedData
### generate simulated data
# data configuration: 
#     hazard ratio = 2000
#     number of features = 10
#     number of valid features = 2
data_generator = SimulatedData(2000, num_var=2, num_features=10)
# training dataset: 
#     number of rows = 2000
#     random seed = 1
train_data = data_generator.generate_data(2000, seed=1)
# test dataset :
#     number of rows = 800
#     random seed = 1
test_data = data_generator.generate_data(800, seed=1)

Visualize survival status

import matplotlib.pyplot as plt
from lifelines import KaplanMeierFitter
from lifelines.plotting import add_at_risk_counts

### Visualize survival status
fig, ax = plt.subplots(figsize=(8, 6))

l_kmf = []
# training set
kmf = KaplanMeierFitter()
kmf.fit(train_data['t'], event_observed=train_data['e'], label='Training Set')
kmf.survival_function_.plot(ax=ax)
l_kmf.append(kmf)
# test set
kmf = KaplanMeierFitter()
kmf.fit(test_data['t'], event_observed=test_data['e'], label='Test Set')
kmf.survival_function_.plot(ax=ax)
l_kmf.append(kmf)

# axis limits, labels, legend and at-risk counts
plt.ylim(0, 1.01)
plt.xlabel("Time")
plt.ylabel("Survival rate")
plt.title("Survival Curve")
plt.legend(loc="best", title="Dataset")
add_at_risk_counts(*l_kmf, ax=ax)
plt.show()

result : [figure: Kaplan-Meier survival curves of the training set and the test set]

Initialize your neural network

input_nodes = 10
output_nodes = 1
train_X = train_data['x']
train_y = {'e': train_data['e'], 't': train_data['t']}
# the arguments of dsnn were obtained by Bayesian Hyperparameters Tuning
model = dsl.dsnn(
    train_X, train_y,
    input_nodes, [6, 3], output_nodes, 
    learning_rate=0.7,
    learning_rate_decay=1.0,
    activation='relu', 
    L1_reg=3.4e-5, 
    L2_reg=8.8e-5, 
    optimizer='adam',
    dropout_keep_prob=1.0
)
# Get the type of ties detected in the training data (three possible values):
# 'noties' (no tied death times), 'breslow' (ties occur) or 'efron' (ties occur frequently)
print(model.get_ties())
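
To check how tied your own data is before training, a small NumPy sketch like the one below can help. The helper `tied_event_fraction` is hypothetical and is not part of tfdeepsurv; the rule the package uses to choose between 'breslow' and 'efron' is not stated here, so this only measures the share of observed deaths that share a death time with another death.

```python
import numpy as np


def tied_event_fraction(time, event):
    """Fraction of observed deaths whose death time is shared with another death."""
    time = np.asarray(time)
    event = np.asarray(event, dtype=bool)
    event_times = time[event]            # censored subjects do not count as ties
    if event_times.size == 0:
        return 0.0
    _, counts = np.unique(event_times, return_counts=True)
    return counts[counts > 1].sum() / event_times.size


# Deaths at t=1, t=1 and t=2; the subject at t=3 is censored
print(tied_event_fraction([1, 1, 2, 3], [1, 1, 1, 0]))
```

With the simulated data above, it could be called as `tied_event_fraction(train_data['t'], train_data['e'])`.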

Train neural network model

# Plot curve of loss and CI on train data
model.train(num_epoch=1900, iteration=100,
            plot_train_loss=True, plot_train_ci=True)

result :

-------------------------------------------------
training steps 1:
loss = 7.07988.
CI = 0.494411.
-------------------------------------------------
training steps 101:
loss = 7.0797.
CI = 0.524628.
-------------------------------------------------
training steps 201:
loss = 7.06293.
CI = 0.569339.
...
...
...
-------------------------------------------------
training steps 1801:
loss = 6.27862.
CI = 0.823937.

Curve of loss and CI: [figures: loss value on the training data; CI on the training data]

Evaluate model performance

test_X = test_data['x']
test_y = {'e': test_data['e'], 't': test_data['t']}
print("CI on train set: %g" % model.score(train_X, train_y))
print("CI on test set: %g" % model.score(test_X, test_y))

result :

CI on train set: 0.823772
CI on test set: 0.812503
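
For readers unfamiliar with the CI (concordance index), the sketch below shows a simplified version of Harrell's C in NumPy. It is an assumption that `model.score` follows this definition; the function name and the handling of ties here (pairs with equal times are skipped, equal risk scores count as half) are choices made for illustration.

```python
import numpy as np


def concordance_index(risk, time, event):
    """Simplified Harrell's C: share of comparable pairs ranked correctly.

    A pair (i, j) is comparable when subject i died and subject j was
    still under observation after i's death time. It is concordant when
    the model gives i the higher risk score.
    """
    risk = np.asarray(risk, dtype=float)
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=bool)

    concordant, comparable = 0.0, 0
    for i in np.flatnonzero(event):
        later = time > time[i]                       # subjects who outlived i
        comparable += int(later.sum())
        concordant += (risk[i] > risk[later]).sum()
        concordant += 0.5 * (risk[i] == risk[later]).sum()
    return concordant / comparable


# Higher risk scores for earlier deaths give a perfect ranking
print(concordance_index([3.0, 2.0, 1.0], [1, 2, 3], [1, 1, 1]))
```

A value of 0.5 corresponds to random ranking and 1.0 to a perfect ranking, which is the scale on which the train and test scores above should be read. The sketch assumes at least one comparable pair exists.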

Evaluate variable importance

model.get_vip_byweights()

result:

0th feature score : 1.
1th feature score : 0.149105.
2th feature score : -0.126712.
3th feature score : 0.033377.
4th feature score : 0.123096.
5th feature score : 0.0321232.
6th feature score : 0.101529.
7th feature score : -0.0707392.
8th feature score : -0.0415884.
9th feature score : 0.0439712.
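
How `get_vip_byweights` computes these scores is not described here. One common weight-based approach, shown below purely as an assumed illustration, multiplies the layer weight matrices together and scales the result so that the largest absolute score is 1. The function name `connection_weight_importance` is hypothetical, and bias terms and activations are ignored.

```python
import numpy as np


def connection_weight_importance(weights):
    """Weight-product importance for a network with a single output node.

    weights : list of weight matrices, one per layer, each of shape
              (n_inputs_of_layer, n_outputs_of_layer)
    """
    product = np.asarray(weights[0], dtype=float)
    for w in weights[1:]:
        product = product @ np.asarray(w, dtype=float)
    scores = product.ravel()                 # one score per input feature
    return scores / np.abs(scores).max()     # largest absolute score becomes 1


# Toy network: 2 inputs -> 2 hidden nodes -> 1 output
w1 = np.array([[2.0, 0.0],
               [0.0, 1.0]])
w2 = np.array([[1.0],
               [1.0]])
for i, s in enumerate(connection_weight_importance([w1, w2])):
    print("feature %d score: %g" % (i, s))
```

Signed scores of this kind indicate the direction as well as the strength of a feature's contribution through the weights.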

Get estimation of survival function

# algo selects the algorithm for estimating the survival function: 'wwe', 'bls' or 'kp'
model.survival_function(test_X[0:3], algo="wwe")

result:

[figure: estimated survival rate over time for the selected test samples]
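
As background on how a survival function can be derived from Cox risk scores, the following NumPy sketch uses the Breslow estimator of the baseline cumulative hazard. It is not the package's code, and it is an assumption that any of the three `algo` options corresponds to it; the function name `breslow_survival` and its arguments are hypothetical.

```python
import numpy as np


def breslow_survival(risk_train, time, event, risk_new):
    """Survival curves S(t | x) = exp(-H0(t) * exp(risk(x))).

    H0 is the Breslow baseline cumulative hazard estimated from the
    training risk scores, times and event indicators.
    Returns (event_times, survival) with survival of shape
    (n_new_subjects, n_event_times).
    """
    risk_train = np.asarray(risk_train, dtype=float)
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=bool)
    risk_new = np.atleast_1d(np.asarray(risk_new, dtype=float))

    event_times = np.unique(time[event])     # sorted distinct death times
    # hazard increment at t: number of deaths at t / total exp(risk) still at risk
    increments = np.array([
        (event & (time == t)).sum() / np.exp(risk_train[time >= t]).sum()
        for t in event_times
    ])
    cum_hazard = np.cumsum(increments)       # H0 evaluated at each death time
    survival = np.exp(-np.outer(np.exp(risk_new), cum_hazard))
    return event_times, survival


times, surv = breslow_survival([0.0, 0.0, 0.0], [1, 2, 3], [1, 1, 1], [0.0, 1.0])
print(times)
print(surv)   # second row (higher risk) lies below the first row
```

Each row is a non-increasing step function over the death times seen in training, and subjects with a higher risk score get a lower curve.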

4.2 Running with real-world data

The procedure for real-world data is similar to the one described for simulated data. The one thing to watch is data preparation: this package provides functions for loading a standard dataset for training or testing.

Load real-world data

# import package
from tfdeepsurv import dsl
from tfdeepsurv.utils import load_data

# Notice: train_X and test_X returned by load_data are numpy.array objects.
# train_y and test_y returned by load_data are dicts like {'e': numpy.array, 't': numpy.array}.

# You can load training data and testing data, respectively
train_X, train_y = load_data('train.csv', excluded_col=['ID'], surv_col={'e': 'event', 't': 'time'})
test_X, test_y = load_data('test.csv', excluded_col=['ID'], surv_col={'e': 'event', 't': 'time'})
# Or load full data, then split it into training and testing set (=8:2).
train_X, train_y, test_X, test_y = load_data('full_data.csv', excluded_col=['ID'], surv_col={'e': 'event', 't': 'time'}, split_ratio=0.8)
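
If your data is already in memory in the format described above (a feature array plus a dict of 'e' and 't' arrays), an 8:2 split can be sketched with NumPy alone. The helper `split_survival_data` is hypothetical and not part of tfdeepsurv; whether `load_data` shuffles rows before splitting is not stated here, so the shuffling and the `seed` argument are assumptions of this sketch.

```python
import numpy as np


def split_survival_data(X, y, split_ratio=0.8, seed=0):
    """Shuffle rows, then split features and survival labels consistently.

    X : feature array of shape (n, n_features)
    y : dict like {'e': array of shape (n,), 't': array of shape (n,)}
    """
    X = np.asarray(X)
    n = X.shape[0]
    order = np.random.default_rng(seed).permutation(n)
    n_train = int(round(split_ratio * n))
    train_idx, test_idx = order[:n_train], order[n_train:]
    # index every label array with the same row indices as X
    train_y = {k: np.asarray(v)[train_idx] for k, v in y.items()}
    test_y = {k: np.asarray(v)[test_idx] for k, v in y.items()}
    return X[train_idx], train_y, X[test_idx], test_y


X = np.arange(20).reshape(10, 2)
y = {'e': np.ones(10, dtype=int), 't': np.arange(10)}
train_X, train_y, test_X, test_y = split_survival_data(X, y, split_ratio=0.8)
print(train_X.shape, test_X.shape)
```

Using one index permutation for both the features and the label dict keeps each row's event indicator and time aligned with its features.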

Training or testing the tfdeepsurv model

This works the same way as with simulated data.

5. More features

We provide tools for hyperparameter tuning (Bayesian Hyperparameters Optimization) in deep neural networks, which automatically search for the optimal hyperparameters of the DNN.

For more on the usage of Bayesian Hyperparameters Optimization, you can refer to here
