
Comments (2)

smallsudarshan commented on June 2, 2024

Ok, here you go. I picked up the training code from this repo.

train.py:

import torch
import wandb
from models.dvae import DiscreteVAE
from utils.arch_utils import TorchMelSpectrogram
from torch.utils.data import DataLoader
from utils.dvae_dataset import DVAEDataset
from torch.optim import Adam
from torch.nn.utils import clip_grad_norm_

from TTS.tts.datasets import load_tts_samples
from TTS.config.shared_configs import BaseDatasetConfig

dvae_checkpoint = '/home/ubuntu/test_tts/SimpleTTS/xtts/run/training/XTTS_v2.0_original_model_files/dvae.pth'
mel_norm_file = '/home/ubuntu/test_tts/SimpleTTS/xtts/run/training/XTTS_v2.0_original_model_files/mel_stats.pth'

config_dataset = BaseDatasetConfig(
    formatter="ljspeech",
    dataset_name="ljspeech",
    path="/home/ubuntu/test_tts/sapien-formatted-english-22050",
    meta_file_train="/home/ubuntu/test_tts/sapien-formatted-english-22050/metadata_norm.txt",
    language="en",
)

# Add here the configs of the datasets
DATASETS_CONFIG_LIST = [config_dataset]
GRAD_CLIP_NORM = 0.5
LEARNING_RATE = 5e-05

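# DVAE hyperparameters below match the released XTTS v2 dvae.pth checkpoint
# (80 mel channels in, a 1024-entry codebook), so the state dict loads cleanly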
dvae = DiscreteVAE(
            channels=80,
            normalization=None,
            positional_dims=1,
            num_tokens=1024,
            codebook_dim=512,
            hidden_dim=512,
            num_resnet_blocks=3,
            kernel_size=3,
            num_layers=2,
            use_transposed_convs=False,
        )

dvae.load_state_dict(torch.load(dvae_checkpoint), strict=False)
dvae.cuda()
opt = Adam(dvae.parameters(), lr=LEARNING_RATE)
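# mel extractor; mel_norm_file carries the per-bin normalization stats shipped with XTTS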
torch_mel_spectrogram_dvae = TorchMelSpectrogram(
            mel_norm_file=mel_norm_file, sampling_rate=22050
        ).cuda()

train_samples, eval_samples = load_tts_samples(
        DATASETS_CONFIG_LIST,
        eval_split=True,
        eval_split_max_size=256,
        eval_split_size=0.01,
    )

eval_dataset = DVAEDataset(eval_samples, 22050, True)
train_dataset = DVAEDataset(train_samples, 22050, False)
epochs = 20
eval_data_loader = DataLoader(
                    eval_dataset,
                    batch_size=3,
                    shuffle=False,
                    drop_last=False,
                    collate_fn=eval_dataset.collate_fn,
                    num_workers=0,
                    pin_memory=False,
                )

train_data_loader = DataLoader(
                    train_dataset,
                    batch_size=3,
                    shuffle=False,
                    drop_last=False,
                    collate_fn=train_dataset.collate_fn,
                    num_workers=4,
                    pin_memory=False,
                )

torch.set_grad_enabled(True)
dvae.train()

wandb.init(project = 'train_dvae')
wandb.watch(dvae)

def to_cuda(x: torch.Tensor) -> torch.Tensor:
    if x is None:
        return None
    if torch.is_tensor(x):
        x = x.contiguous()
        if torch.cuda.is_available():
            x = x.cuda(non_blocking=True)
    return x

@torch.no_grad()
def format_batch(batch):
    if isinstance(batch, dict):
        for k, v in batch.items():
            batch[k] = to_cuda(v)
    elif isinstance(batch, list):
        batch = [to_cuda(v) for v in batch]

    try:
        batch['mel'] = torch_mel_spectrogram_dvae(batch['wav'])
        # if the mel length is not divisible by 4, the DVAE's output shape will not
        # match its input (two stride-2 conv layers downsample time by 4), so trim the tail
        remainder = batch['mel'].shape[-1] % 4
        if remainder:
            batch['mel'] = batch['mel'][:, :, :-remainder]
    except NotImplementedError:
        pass
    return batch

for i in range(epochs):
    for cur_step, batch in enumerate(train_data_loader):

        opt.zero_grad()
        batch = format_batch(batch)
        recon_loss, commitment_loss, out = dvae(batch['mel'])
        recon_loss = recon_loss.mean()  # stock dvae.py returns a per-element loss (reduction="none")
        total_loss = recon_loss + commitment_loss
        total_loss.backward()
        clip_grad_norm_(dvae.parameters(), GRAD_CLIP_NORM)
        opt.step()

        log = {'epoch': i,
               'cur_step': cur_step,
               'loss': total_loss.item(),
               'recon_loss': recon_loss.item(),
               'commit_loss': commitment_loss.item()}
        print(f"epoch: {i}", print(f"step: {cur_step}"), f'loss - {total_loss.item()}', f'recon_loss - {recon_loss.item()}', f'commit_loss - {commitment_loss.item()}')
        wandb.log(log)
        torch.cuda.empty_cache()
#     if i % 10 == 0:
#         save_model('./dvae.pth')

# wandb.save('./dvae.pth')
# wandb.finish()
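
The eval_data_loader above is created but never used; a minimal eval pass could look like this (just a sketch reusing format_batch, with the same loss reduction as the train loop; I have not validated it):

@torch.no_grad()
def evaluate():
    dvae.eval()
    total, n = 0.0, 0
    for batch in eval_data_loader:
        batch = format_batch(batch)
        recon_loss, commitment_loss, out = dvae(batch['mel'])
        total += (recon_loss.mean() + commitment_loss).item()
        n += 1
    dvae.train()
    return total / max(n, 1)

# e.g. once per epoch:
# wandb.log({'val_loss': evaluate()})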

I also wrote a custom DVAEDataset (utils/dvae_dataset.py), which is imported by train.py above:


import torch
import random
from utils.dataset import key_samples_by_col
from TTS.tts.models.xtts import load_audio

torch.set_num_threads(1)

class DVAEDataset(torch.utils.data.Dataset):
    def __init__(self, samples, sample_rate, is_eval):
        self.sample_rate = sample_rate
        self.is_eval = is_eval
        self.max_wav_len = 255995
        self.samples = samples
        self.training_seed = 1
        self.failed_samples = set()
        if not is_eval:
            random.seed(self.training_seed)
            random.shuffle(self.samples)
            # order by language
            self.samples = key_samples_by_col(self.samples, "language")
            print(" > Sampling by language:", self.samples.keys())
        else:
            # for evaluation, load and check the samples up front and drop corrupted ones, to ensure reproducibility
            self.check_eval_samples()

    def check_eval_samples(self):
        print(" > Filtering invalid eval samples!!")
        new_samples = []
        for sample in self.samples:
            try:
                _, wav = self.load_item(sample)
            except Exception:
                continue
            # Basically, this audio file is nonexistent or too long to be supported by the dataset.
            if (
                wav is None
                or (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len)
            ):
                continue
            new_samples.append(sample)
        self.samples = new_samples
        print(" > Total eval samples after filtering:", len(self.samples))

    def load_item(self, sample):
        audiopath = sample["audio_file"]
        wav = load_audio(audiopath, self.sample_rate)
        if wav is None or wav.shape[-1] < (0.5 * self.sample_rate):
            # ultra-short clips (< 0.5 s) are useless and can cause problems in some models
            raise ValueError(f"Audio too short: {audiopath}")

        return audiopath, wav
    
    def __getitem__(self, index):
        if self.is_eval:
            sample = self.samples[index]
            sample_id = str(index)
        else:
            # select a random language
            lang = random.choice(list(self.samples.keys()))
            # select random sample
            index = random.randint(0, len(self.samples[lang]) - 1)
            sample = self.samples[lang][index]
            # a unique id for each sample, to deal with load failures
            sample_id = lang + "_" + str(index)

        # skip samples that we already know are invalid
        if sample_id in self.failed_samples:
            # call __getitem__ again with a fallback index to get another sample
            return self[1]

        # try to load the sample; if it fails, add it to the failed-samples set
        try:
            audiopath, wav = self.load_item(sample)
        except Exception:
            self.failed_samples.add(sample_id)
            return self[1]

        # check the audio length limit; if the clip is out of bounds, add it to failed_samples
        if (
            wav is None
            or (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len)
        ):
            # Basically, this audio file is nonexistent or too long to be supported by the dataset.
            # It's hard to handle this situation properly. The best bet is to return a random valid sample and skew the dataset somewhat as a result.
            self.failed_samples.add(sample_id)
            return self[1]

        res = {
            "wav": wav,
            "wav_lengths": torch.tensor(wav.shape[-1], dtype=torch.long),
            "filenames": audiopath,
        }
        return res
    
    def __len__(self):
        if self.is_eval:
            return len(self.samples)
        return sum([len(v) for v in self.samples.values()])

    def collate_fn(self, batch):
        # convert list of dicts to dict of lists
        B = len(batch)

        batch = {k: [dic[k] for dic in batch] for k in batch[0]}

        # stack for features that already have the same shape
        batch["wav_lengths"] = torch.stack(batch["wav_lengths"])

        max_wav_len = batch["wav_lengths"].max()

        # create a zero-initialized padding tensor for the wavs
        wav_padded = torch.zeros(B, 1, max_wav_len)
        for i in range(B):
            wav = batch["wav"][i]
            wav_padded[i, :, : batch["wav_lengths"][i]] = torch.FloatTensor(wav)

        batch["wav"] = wav_padded
        return batch

This trains the DVAE to encode and decode mel-spectrograms.

A few things:

  1. You can see my import paths are not standard. That is because I have changed the structure of the repo a bit in my personal fork. You can follow the standard import paths as per TTS.
  2. There is a loss called DiscretizationLoss here, but I am not sure where or how it is used, so I am not using it currently.
  3. For some reason, in dvae.py on line 378, the author has self.loss_fn(img, out, reduction="none"). I am not sure what the purpose of reduction='none' is, so in my code I just reduce the result to a scalar and add it to the total loss (see the sketch after this list).
  4. I am not sure what recipe to use for training (gradient accumulation steps, LR schedule, etc.); I am just doing basic fine-tuning for now.
  5. For a small batch, my train loss starts out very low and also converges quickly:
    [loss curve screenshot]
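
Regarding point 3: since stock dvae.py returns a per-element reconstruction loss (reduction="none"), backward() needs a scalar, so I reduce it first. A sketch of what I mean (assuming stock dvae.py behaviour):

recon_loss, commitment_loss, out = dvae(batch['mel'])
# reduction="none" leaves one loss value per mel element; reduce (mean or sum)
# to a scalar before combining and calling backward()
recon_loss = recon_loss.mean()
total_loss = recon_loss + commitment_loss
total_loss.backward()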

The next step would be to fine-tune on a larger dataset. @erogol @eginhard, if this is going in the right direction, I can convert it into a training recipe and add it to the repo.

PS: The code is a bit dirty, since I have just reused whatever was available as long as it doesn't harm my training.


smallsudarshan commented on June 2, 2024

I also now understand that the decoder of the DVAE is not used; instead, an LM head on the GPT-2 recomputes the mel from the audio codes. I need to understand this a bit better before writing the next-stage training code.
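
If I understand correctly, that means only the encoder side of the DVAE matters for the GPT stage. Something like the following (a sketch; get_codebook_indices is the encoder-side helper on DiscreteVAE in the tortoise-style dvae.py) would produce the discrete targets:

with torch.no_grad():
    mel = torch_mel_spectrogram_dvae(batch['wav'])
    # one discrete token per 4 mel frames; these are the targets the GPT-2 is trained to predict
    audio_codes = dvae.get_codebook_indices(mel)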

