Comments (2)
OK, so here you go. I took the training code from this repo.
train.py
:
import torch
import wandb
from models.dvae import DiscreteVAE
from utils.arch_utils import TorchMelSpectrogram
from torch.utils.data import DataLoader
from utils.dvae_dataset import DVAEDataset
from torch.optim import Adam
from torch.nn.utils import clip_grad_norm_
import pdb
from TTS.tts.datasets import load_tts_samples
from TTS.config.shared_configs import BaseDatasetConfig
dvae_checkpoint = '/home/ubuntu/test_tts/SimpleTTS/xtts/run/training/XTTS_v2.0_original_model_files/dvae.pth'
mel_norm_file = '/home/ubuntu/test_tts/SimpleTTS/xtts/run/training/XTTS_v2.0_original_model_files/mel_stats.pth'
# Sample rate shared by the mel extractor and both datasets; keep them in sync.
SAMPLE_RATE = 22050
config_dataset = BaseDatasetConfig(
    formatter="ljspeech",
    dataset_name="ljspeech",
    path="/home/ubuntu/test_tts/sapien-formatted-english-22050",
    meta_file_train="/home/ubuntu/test_tts/sapien-formatted-english-22050/metadata_norm.txt",
    language="en",
)
# Add here the configs of the datasets
DATASETS_CONFIG_LIST = [config_dataset]
GRAD_CLIP_NORM = 0.5
LEARNING_RATE = 5e-05

dvae = DiscreteVAE(
    channels=80,
    normalization=None,
    positional_dims=1,
    num_tokens=1024,
    codebook_dim=512,
    hidden_dim=512,
    num_resnet_blocks=3,
    kernel_size=3,
    num_layers=2,
    use_transposed_convs=False,
)
# strict=False lets the checkpoint load even when some parameter names differ,
# but any parameter that is not matched silently keeps its random initialisation.
# Report mismatches so a renamed module does not go unnoticed. Loading onto the
# CPU first avoids depending on the device the checkpoint was saved from.
incompatible_keys = dvae.load_state_dict(
    torch.load(dvae_checkpoint, map_location="cpu"), strict=False
)
if incompatible_keys.missing_keys or incompatible_keys.unexpected_keys:
    print(" > DVAE checkpoint does not match the model exactly!")
    print("   missing keys:", incompatible_keys.missing_keys)
    print("   unexpected keys:", incompatible_keys.unexpected_keys)
dvae.cuda()
opt = Adam(dvae.parameters(), lr=LEARNING_RATE)

torch_mel_spectrogram_dvae = TorchMelSpectrogram(
    mel_norm_file=mel_norm_file, sampling_rate=SAMPLE_RATE
).cuda()

train_samples, eval_samples = load_tts_samples(
    DATASETS_CONFIG_LIST,
    eval_split=True,
    eval_split_max_size=256,
    eval_split_size=0.01,
)
eval_dataset = DVAEDataset(eval_samples, SAMPLE_RATE, True)
train_dataset = DVAEDataset(train_samples, SAMPLE_RATE, False)
epochs = 20

# NOTE: the eval loader is built here but is not consumed by the training loop.
eval_data_loader = DataLoader(
    eval_dataset,
    batch_size=3,
    shuffle=False,
    drop_last=False,
    collate_fn=eval_dataset.collate_fn,
    num_workers=0,
    pin_memory=False,
)
# shuffle=False is fine here: in training mode DVAEDataset ignores the index
# it is given and draws a random sample itself.
train_data_loader = DataLoader(
    train_dataset,
    batch_size=3,
    shuffle=False,
    drop_last=False,
    collate_fn=train_dataset.collate_fn,
    num_workers=4,
    pin_memory=False,
)

torch.set_grad_enabled(True)
dvae.train()
wandb.init(project='train_dvae')
wandb.watch(dvae)
def to_cuda(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` ready for the GPU, passing non-tensors through untouched.

    ``None`` and any value that is not a tensor come back as-is. A tensor is
    made contiguous and, when CUDA is available, copied to the current CUDA
    device with ``non_blocking=True``; otherwise the contiguous CPU tensor is
    returned.
    """
    if x is None or not torch.is_tensor(x):
        return x
    packed = x.contiguous()
    return packed.cuda(non_blocking=True) if torch.cuda.is_available() else packed
@torch.no_grad()
def format_batch(batch, mel_transform=None, frame_multiple=4):
    """Move a collated batch to the GPU and attach its DVAE mel spectrogram.

    Args:
        batch: a dict (as produced by ``DVAEDataset.collate_fn``, with a
            ``"wav"`` entry) or a list of values. Every value is passed
            through ``to_cuda``.
        mel_transform: callable mapping ``batch["wav"]`` to a mel spectrogram.
            Defaults to the module-level ``torch_mel_spectrogram_dvae``.
        frame_multiple: positive int; the mel's last (time) axis is cropped to
            a multiple of this. With the DVAE configured in this script, a mel
            whose length is not divisible by 4 reconstructs to a different
            shape than its input.

    Returns:
        The batch, in the same container type. A dict batch that contains
        ``"wav"`` gains a ``"mel"`` entry, unless ``mel_transform`` raises
        ``NotImplementedError`` (kept from the original best-effort handling).
    """
    if isinstance(batch, dict):
        for key, value in batch.items():
            batch[key] = to_cuda(value)
    elif isinstance(batch, list):
        batch = [to_cuda(value) for value in batch]

    # Only dict batches carry a "wav" entry; indexing a list with a string key
    # would raise TypeError.
    if not isinstance(batch, dict) or "wav" not in batch:
        return batch

    if mel_transform is None:
        mel_transform = torch_mel_spectrogram_dvae
    try:
        mel = mel_transform(batch["wav"])
    except NotImplementedError:
        return batch

    # if the mel spectrogram length is not divisible by `frame_multiple`,
    # the DVAE output shape != input shape
    remainder = mel.shape[-1] % frame_multiple
    if remainder:
        mel = mel[..., :-remainder]
    batch["mel"] = mel
    return batch
for i in range(epochs):
    for cur_step, batch in enumerate(train_data_loader):
        opt.zero_grad()
        batch = format_batch(batch)
        recon_loss, commitment_loss, out = dvae(batch['mel'])
        # The upstream dvae.py computes the reconstruction loss with
        # reduction="none" (one value per element). Such a loss cannot be
        # back-propagated or logged with .item() directly, so reduce both
        # terms to scalars. .mean() leaves an already-scalar loss unchanged.
        recon_loss = recon_loss.mean()
        commitment_loss = commitment_loss.mean()
        total_loss = recon_loss + commitment_loss
        total_loss.backward()
        clip_grad_norm_(dvae.parameters(), GRAD_CLIP_NORM)
        opt.step()
        log = {
            'epoch': i,
            'cur_step': cur_step,
            'loss': total_loss.item(),
            'recon_loss': recon_loss.item(),
            'commit_loss': commitment_loss.item(),
        }
        # The original nested a print() call inside print(), which printed the
        # step first and then "None" in the middle of the epoch line.
        print(
            f"epoch: {i}",
            f"step: {cur_step}",
            f'loss - {log["loss"]}',
            f'recon_loss - {log["recon_loss"]}',
            f'commit_loss - {log["commit_loss"]}',
        )
        wandb.log(log)
    # Checkpoint after every epoch so the fine-tuned weights are not lost.
    torch.save(dvae.state_dict(), './dvae.pth')
wandb.finish()
I wrote a custom `DVAEDataset`,
which is imported in the `train.py`
file above:
import torch
import random
from utils.dataset import key_samples_by_col
from TTS.tts.models.xtts import load_audio
# Restrict PyTorch intra-op CPU parallelism to a single thread in this process.
# NOTE(review): presumably to avoid CPU oversubscription when several DataLoader
# workers each load this module; confirm.
torch.set_num_threads(1)
class DVAEDataset(torch.utils.data.Dataset):
    """Raw-waveform dataset for training and evaluating the DVAE.

    Training mode: the samples are shuffled once with a fixed seed and grouped
    by their ``"language"`` column (via ``key_samples_by_col``). Each
    ``__getitem__`` call ignores ``index`` and draws a random language, then a
    random sample within it.

    Eval mode: unusable samples are filtered out once by
    ``check_eval_samples``, and items are returned by index.

    Each item is a dict with ``"wav"`` (as returned by ``load_audio``),
    ``"wav_lengths"`` (number of audio samples, a long tensor) and
    ``"filenames"`` (the audio path).

    Args:
        samples: list of sample dicts. Each needs ``"audio_file"`` and, in
            training mode, ``"language"``. The list is copied, not mutated.
        sample_rate: target rate passed to ``load_audio``.
        is_eval: selects eval or training behaviour (see above).
        max_wav_len: clips longer than this many audio samples are skipped;
            ``None`` disables the limit.
        max_retries: training-mode cap on the number of random draws one
            ``__getitem__`` call makes before giving up.
    """

    def __init__(self, samples, sample_rate, is_eval, max_wav_len=255995, max_retries=1000):
        self.sample_rate = sample_rate
        self.is_eval = is_eval
        self.max_wav_len = max_wav_len
        self.max_retries = max_retries
        # Copy so the shuffle below does not reorder the caller's list.
        self.samples = list(samples)
        self.training_seed = 1
        # Ids of samples found to be unloadable or out of limits. With
        # DataLoader workers (num_workers > 0), each worker holds its own copy.
        self.failed_samples = set()
        if not is_eval:
            random.seed(self.training_seed)
            random.shuffle(self.samples)
            # order by language
            self.samples = key_samples_by_col(self.samples, "language")
            print(" > Sampling by language:", self.samples.keys())
        else:
            # Drop corrupted or out-of-limit eval samples up front so that
            # evaluation is reproducible.
            self.check_eval_samples()

    def check_eval_samples(self):
        """Keep only the eval samples that load and fit within ``max_wav_len``."""
        print(" > Filtering invalid eval samples!!")
        self.samples = [s for s in self.samples if self._try_load(s) is not None]
        print(" > Total eval samples after filtering:", len(self.samples))

    def load_item(self, sample):
        """Load ``sample["audio_file"]`` and return ``(audiopath, wav)``.

        Raises:
            ValueError: if no audio is returned or the clip is shorter than
                half a second.
        """
        audiopath = sample["audio_file"]
        wav = load_audio(audiopath, self.sample_rate)
        if wav is None or wav.shape[-1] < (0.5 * self.sample_rate):
            # Ultra short clips are also useless (and can cause problems within some models).
            raise ValueError(f"audio missing or shorter than 0.5 s: {audiopath}")
        return audiopath, wav

    def _try_load(self, sample):
        """Return ``(audiopath, wav)`` for a usable sample, or ``None``."""
        try:
            audiopath, wav = self.load_item(sample)
        except Exception:
            # Any load failure just marks the sample unusable. A bare
            # ``except:`` here would also swallow KeyboardInterrupt.
            return None
        # Basically, this audio file is nonexistent or too long to be supported by the dataset.
        if wav is None or (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len):
            return None
        return audiopath, wav

    def __getitem__(self, index):
        """Return one usable item, skipping samples that fail to load.

        The original recursed into ``self[1]`` on every failure. That always
        retried index 1 and could recurse without bound. This version retries
        in a bounded loop instead.

        Raises:
            IndexError: in eval mode, if ``index`` is out of range.
            RuntimeError: if no usable sample is found within the retry budget.
        """
        if self.is_eval:
            n_samples = len(self.samples)
            if not -n_samples <= index < n_samples:
                raise IndexError(f"eval index {index} out of range for {n_samples} samples")
            index %= n_samples
            # Eval: start at the requested index and move forward on failure.
            max_attempts = n_samples
        else:
            max_attempts = max(self.max_retries, 1)

        for attempt in range(max_attempts):
            if self.is_eval:
                pos = (index + attempt) % n_samples
                sample = self.samples[pos]
                sample_id = str(pos)
            else:
                # select a random language, then a random sample within it
                lang = random.choice(list(self.samples.keys()))
                pos = random.randint(0, len(self.samples[lang]) - 1)
                sample = self.samples[lang][pos]
                # a unique id for each sample to remember failures
                sample_id = lang + "_" + str(pos)

            # ignore samples already known to be invalid
            if sample_id in self.failed_samples:
                continue
            loaded = self._try_load(sample)
            if loaded is None:
                self.failed_samples.add(sample_id)
                continue

            audiopath, wav = loaded
            return {
                "wav": wav,
                "wav_lengths": torch.tensor(wav.shape[-1], dtype=torch.long),
                "filenames": audiopath,
            }

        raise RuntimeError(
            f"could not load a valid sample after {max_attempts} attempts (index {index})"
        )

    def __len__(self):
        if self.is_eval:
            return len(self.samples)
        return sum(len(group) for group in self.samples.values())

    def collate_fn(self, batch):
        """Collate a list of item dicts into a dict of batched values.

        ``"wav"`` becomes a zero-padded float32 tensor of shape
        ``(B, 1, max_len)``, ``"wav_lengths"`` a stacked long tensor of shape
        ``(B,)``, and every other key (e.g. ``"filenames"``) a plain list.
        """
        # list of dicts -> dict of lists
        batch = {key: [item[key] for item in batch] for key in batch[0]}
        lengths = torch.stack(batch["wav_lengths"])
        wav_padded = torch.zeros(len(lengths), 1, int(lengths.max()))
        for i, (wav, length) in enumerate(zip(batch["wav"], lengths)):
            wav_padded[i, :, : int(length)] = torch.as_tensor(wav, dtype=torch.float32)
        batch["wav_lengths"] = lengths
        batch["wav"] = wav_padded
        return batch
This trains the DVAE to encode and decode mel spectrograms.
A few things:
- You can see my import paths are not standard. That is because I have changed the structure of the repo a bit in my personal fork. You can follow the standard import paths as per TTS.
- There is a loss called
`DiscretizationLoss`
here, but I am not sure where or how it is used, so I am not using it currently.
- For some reason, in
`dvae.py`
on line 378,
the author has added `self.loss_fn(img, out, reduction="none")`.
I am not sure what the purpose of `reduction='none'` is,
so I have summed it up in my code and just added it to calculate the loss.
- I am not sure what recipe to use for training (gradient accumulation steps, LR changes, etc.); I am just doing basic fine-tuning for now.
- For a small batch, my training loss seems to be very low initially and also converges quickly:
The next step would be to fine-tune on a larger dataset. @erogol @eginhard, if this is in the right direction, I can convert this into a training recipe
and add it to the repo.
PS: The code is a bit dirty, since I have just reused whatever was available as long as it doesn't harm my training.
from tts.
I also now understand that the DVAE decoder is not used; instead, an LM head on the GPT-2 model is used to recompute the mel from the audio codes. I need to understand this a bit better before writing the next-stage training code.
from tts.
Related Issues (20)
- [Bug] Time taken to run TTS command far greater than actual processing time HOT 8
- [Bug] Unable to use xtts_v2 with mps device on Apple Silicon HOT 1
- [Bug] Cannot use Docker image HOT 1
- [Bug] very longinstallation that ends up with error HOT 2
- [Feature request] Language Support ("Hindi") missing in XTTS on local machine. HOT 2
- [Bug] bug in tts_to_file HOT 1
- PermissionError: [WinError 32] The process cannot access the file because it is being used by another process. HOT 8
- [Bug] Install bug Failed to download the model file to tts_models--multilingual--multi-dataset--xtts_v2 HOT 1
- [Bug] Unable to install Coqui TTS HOT 5
- [Bug] compute_statistics.py isn't working.
- [Bug] Training xtts v2 with original dataset which is multilingual and multispeaker HOT 8
- [Bug] Voice lag and pronounce punctuation
- [Feature request] update doc for convert model to hugginface
- Realtime voice conversion support
- [Bug] VITS gpu utilization
- [Bug] ModuleNotFoundError: No module named 'TTS' (From inside the TTS folder) HOT 1
- [Feature request] Run in the browser?
- Can't start training due to recursion depth error
- Can't install TTL on Windows 11: Could not build wheels for TTS HOT 4
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Some thing interesting about web. New door for the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Some thing interesting about visualization, use data art
-
Game
Some thing interesting about game, make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from tts.