Comments (8)
@csala I think for this we should have a verbose parameter that turns the printing on/off. However, in either case I think it'd be helpful for fit() to return data of the training history, so users can inspect/plot it afterwards. Maybe a wrapper around a pandas data frame but you'd probably have a better idea on what the most Pythonic approach is. Let me know your thoughts on this and I'd be happy to whip something together.
from ctgan.
I'll be honest, the only reason I added GPU status was because I liked watching the temperature go up with more epochs 😏
Any updates on this issue?
Playing around in Colab I put this together: https://colab.research.google.com/drive/1JA_Ap1bQDmlhm_tC1k8RL0MNYKBluJNa
and added some new arguments to the fit() method in synthesizer.py. However, I'm fairly sure my implementation is off in places. Any feedback greatly appreciated.
Args:
    train_data (numpy.ndarray or pandas.DataFrame):
        Training Data. It must be a 2-dimensional numpy array or a
        pandas.DataFrame.
    discrete_columns (list-like):
        List of discrete columns to be used to generate the Conditional
        Vector. If ``train_data`` is a Numpy array, this list should
        contain the integer indices of the columns. Otherwise, if it is
        a ``pandas.DataFrame``, this list should contain the column names.
    verbosity (boolean):
        Choose to display epochs during the run. Defaults to ``True``.
    epochs (int):
        Number of training epochs. Defaults to 300.
    log_frequency (boolean):
        Whether to use log frequency of categorical levels in conditional
        sampling. Defaults to ``True``.
    gpu_stats (boolean):
        Whether to display gpu stats for each epoch. Fitting may be slowed down
        with this option turned on. Only supports nvidia GPUs at this time.
        Defaults to ``False``.
    early_stopping (boolean):
        Whether to stop fitting early if loss function has not improved for
        specified number 'patience' of epochs. Defaults to ``False``.
    patience (int):
        Number of epochs to monitor to see if loss function improves.
        Defaults to ``10`` if early_stopping turned on.
    logging (boolean):
        Whether to store the generator loss and discriminator loss into a csv
        log file with timestamp. Defaults to ``False``.
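The early_stopping and patience arguments described above could work roughly as follows. This is only a minimal sketch, not the code from the Colab notebook: the EarlyStopper class, its min_delta parameter, and the choice of which loss to monitor are all assumptions made for illustration.

```python
class EarlyStopper:
    """Hypothetical helper (not part of CTGAN): stop when a monitored
    loss has not improved for `patience` consecutive epochs."""

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta      # smallest decrease that counts as improvement
        self.best = float("inf")
        self.stale_epochs = 0

    def should_stop(self, loss):
        if loss < self.best - self.min_delta:
            # The loss improved: remember it and reset the counter.
            self.best = loss
            self.stale_epochs = 0
        else:
            self.stale_epochs += 1
        return self.stale_epochs >= self.patience


def first_stop_epoch(losses, patience):
    """Return the index of the epoch at which training would stop, or None."""
    stopper = EarlyStopper(patience=patience)
    for epoch, loss in enumerate(losses):
        if stopper.should_stop(loss):
            return epoch
    return None
```

One caveat worth keeping in mind: GAN losses do not necessarily decrease as sample quality improves, so a patience rule on the raw generator or discriminator loss may stop too early or never trigger.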
@csala it would be very helpful.
IMHO, something similar to the Keras model.fit output may be considered:
ctgan = CTGANSynthesizer()
hist = ctgan.fit(data, discrete_columns)
where hist is a dictionary containing the generator and discriminator loss per epoch, and may be extended to other metrics in the future.
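A per-epoch history of this kind can be sketched as a small container that the training loop appends to. The TrainingHistory class and its record/as_dict methods are hypothetical names for illustration, not an existing CTGAN API, and the loss values below are placeholders.

```python
from collections import defaultdict


class TrainingHistory:
    """Hypothetical container: one list per metric, one entry per epoch."""

    def __init__(self):
        self._history = defaultdict(list)

    def record(self, epoch, **metrics):
        # Store the epoch index alongside every metric passed in.
        self._history["epoch"].append(epoch)
        for name, value in metrics.items():
            self._history[name].append(float(value))

    def as_dict(self):
        # Return a plain dict so callers do not depend on defaultdict.
        return dict(self._history)


# Placeholder values standing in for what a real training loop would compute.
hist = TrainingHistory()
hist.record(0, loss_g=1.25, loss_d=-0.5)
hist.record(1, loss_g=0.75, loss_d=-0.25)
print(hist.as_dict())
```

Because every list has one entry per epoch, the resulting dictionary can be passed directly to pandas.DataFrame(...) for inspection or plotting, which also covers the data frame idea raised earlier in the thread.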
In my own implementation I added loops using tqdm (progress bars) for both the epochs and steps. You can add logging information like loss there as well.
Related to how this information should be logged and the proposal @oregonpillow made, I think the following:
- The information that you're logging is really good and I like it a lot! The GPU stats are also a nice added bonus.
- The history should not be returned by fit. To me at least, this does not feel intuitive. I think this information can be stored as an attribute, like ctgan.hist or ctgan.logs or something.
- Writing directly to files seems a bit much for an implementation in CTGAN.
- I think an option to facilitate many of these things is using a callback system, similar to FastAI. We call on_epoch_end, on_epoch_start and other methods on the objects in ctgan.callbacks. These callbacks can be anything, ranging from logging objects to early stopping.
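The callback idea could look something like the sketch below. Everything here is an assumption for illustration: the Callback base class, the convention that on_epoch_end returns True to request a stop, and the run_epochs driver are hypothetical and not part of CTGAN or FastAI.

```python
class Callback:
    """Hypothetical base class: subclasses override the hooks they need."""

    def on_epoch_start(self, epoch):
        pass

    def on_epoch_end(self, epoch, logs):
        # Return True to ask the training loop to stop.
        return False


class LossLogger(Callback):
    """Keeps the per-epoch logs in memory instead of writing to a file."""

    def __init__(self):
        self.logs = []

    def on_epoch_end(self, epoch, logs):
        self.logs.append({"epoch": epoch, **logs})
        return False


class StopAfter(Callback):
    """Requests a stop once `max_epochs` epochs have completed."""

    def __init__(self, max_epochs):
        self.max_epochs = max_epochs

    def on_epoch_end(self, epoch, logs):
        return epoch + 1 >= self.max_epochs


def run_epochs(epochs, train_one_epoch, callbacks=()):
    """Drive a training loop; `train_one_epoch(epoch)` returns a dict of metrics."""
    completed = 0
    for epoch in range(epochs):
        for cb in callbacks:
            cb.on_epoch_start(epoch)
        logs = train_one_epoch(epoch)
        completed += 1
        # Build a list first so every callback runs even if one requests a stop.
        stop_requests = [cb.on_epoch_end(epoch, logs) for cb in callbacks]
        if any(stop_requests):
            break
    return completed
```

With this shape, logging, progress reporting, and early stopping are all separate objects, so fit() itself would only need to accept a list of callbacks.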
Can I please know which metric is used in the loss calculation?
Epoch 105, Loss G: -7.7396, Loss D: -0.3223 is what I get when I try to fit the model on the training data.
@NadeemNicoR The metric is the raw logit output iirc. The loss of G is just the average error of the samples produced by G. The loss of D is the loss of G minus the loss of the real samples. I'm doing this from memory, so let me know if this is incorrect.
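To make the arithmetic concrete, here is a minimal sketch assuming a WGAN-style setup in which the discriminator (critic) outputs unbounded raw scores. It is an illustration under that assumption, not the library's code; any extra terms such as a gradient penalty or a conditional cross-entropy term are left out, and the function names are hypothetical.

```python
from statistics import fmean


def critic_loss(real_scores, fake_scores):
    """WGAN-style critic loss: mean score on fakes minus mean score on reals.

    Minimising this pushes real scores up and fake scores down.
    """
    return fmean(fake_scores) - fmean(real_scores)


def generator_loss(fake_scores):
    """WGAN-style generator loss: negative mean critic score on fakes."""
    return -fmean(fake_scores)


# Toy raw scores; there is no sigmoid, so they are not probabilities.
real = [1.0, 3.0]
fake = [0.5, 1.5]
print(critic_loss(real, fake))   # -1.0
print(generator_loss(fake))      # -1.0
```

Under this formulation the scores are unbounded, so both losses can be negative, which would be consistent with values like the ones in the log line quoted above.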
#147 addressed this issue so I'm closing it off. For further discussion about the verbosity parameter, let's use the overall SDV GitHub.