Comments (66)

JBloodless commented on June 27, 2024

@atabakp @eagomez2 I found out that my losses were completely wrong, so I'd like to ask you about the outputs of this model. From the paper it's not at all obvious in which order the masks appear in the second set (channels 6-10). In Fig. 2 of the paper it's non-noise then noise, but at the end of Section 3.2 it's noise then non-noise. Did you figure it out?

The order doesn't matter; just follow one, and the network will adapt to assign the corresponding output correctly, irrespective of the initial order. My suggestion is to skip the PHM for now and make sure the rest of the code is OK.

[Screenshot 2024-02-15 at 16:39:08]

Seems fine to me. Maybe the problem really is the PHM computation.

For now I settled on:

# channels 0-4: PHM inputs for the direct path
mask_direct = calculate_PHM(x16[:, :5, :])
result_direct = torch.view_as_complex(x) * mask_direct.squeeze(1)

# channels 5-9: PHM inputs for the non-noise (direct + reverb) signal
mask_nonnoise = calculate_PHM(x16[:, 5:, :])
result_nonnoise = torch.view_as_complex(x) * mask_nonnoise.squeeze(1)
result_noise = torch.view_as_complex(x) - result_nonnoise

# reverberant-path mask derived as in Fig. 2 of the paper
mask_revpath = mask_nonnoise - mask_direct
result_revpath = torch.view_as_complex(x) * mask_revpath.squeeze(1)

Since you mentioned that the order doesn't matter, I assumed that in the second pair non-noise comes first, so I'm directly calculating masks for the direct path and the non-noise signal, and then obtaining the reverberation mask as in Fig. 2 of the paper.

atabakp commented on June 27, 2024

Hi,

I had the same question. Has anyone been able to successfully train this network? I think that, as @atabakp mentioned, the input has to have shape (time_frames, features, fft_size // 2 + 1), so when a batch is used, the time_frames axis will grow. Since this is assumed to be the N input of a nn.Conv1d, the processing will still be frame-independent, so bigger batch sizes would simply mean a bigger stack of frames. Could someone confirm this?

Thanks, Esteban

Hi Esteban, I am able to train this model. Yes, you are right.

atabakp commented on June 27, 2024

Thanks once again @atabakp! I was thinking something similar:

  1. Use log magnitude (as in the paper)
  2. Use PCEN output (as in the paper)

For 3. and 4., "real/imaginary of the demodulated phase" didn't make much sense to me as a term initially, since the phase itself would be real, so I was thinking of using the normalized real/imag STFT as well, since it would somehow put emphasis on the phase information.

One last question: how are you using the outputs, @atabakp ? I think it has 5 channels initially, but there is no explicit mention of what they exactly are. I was assuming two of them are magnitude masks (target and residual), two others are phase terms, and the last one is used to estimate the phase's sign, but I was not sure.

Yes, you are right.
The output is 10 channels: 2 sets of 5 channels. One set is for predicting the direct signal, the 2nd set is for predicting the noise, and you can derive the reverberation by having the direct and noise estimates.
1. z^(k)_{t,f}
2. z^(¬k)_{t,f}
3. φ
For the next 2 channels, refer to Eq. (3) of https://arxiv.org/pdf/2006.00687.pdf:
4. γ^(0)(q_{t,f})
5. γ^(1)(q_{t,f})
If my assumption about the channels is correct, then we don't need 2 separate channels for the magnitude (channels 1, 2); one is the complement of the other.

eagomez2 commented on June 27, 2024

Thanks once again @atabakp! I was thinking something similar:

  1. Use log magnitude (as in the paper)
  2. Use PCEN output (as in the paper)

For 3. and 4., "real/imaginary of the demodulated phase" didn't make much sense to me as a term initially, since the phase itself would be real, so I was thinking of using the normalized real/imag STFT as well, since it would somehow put emphasis on the phase information.

One last question: how are you using the outputs, @atabakp ? I think it has 5 channels initially, but there is no explicit mention of what they exactly are. I was assuming two of them are magnitude masks (target and residual), two others are phase terms, and the last one is used to estimate the phase's sign, but I was not sure.

Yes, you are right. The output is 10 channels: 2 sets of 5 channels. One set is for predicting the direct signal, the 2nd set is for predicting the noise, and you can derive the reverberation by having the direct and noise estimates. 1. z^(k)_{t,f} 2. z^(¬k)_{t,f} 3. φ For the next 2 channels, refer to Eq. (3) of https://arxiv.org/pdf/2006.00687.pdf: 4. γ^(0)(q_{t,f}) 5. γ^(1)(q_{t,f}) If my assumption about the channels is correct, then we don't need 2 separate channels for the magnitude (channels 1, 2); one is the complement of the other.

Thanks a lot once again, @atabakp ! I'll report back my progress as I manage to allocate time to work on it.

atabakp commented on June 27, 2024

I also have a question about the TGRU along the same lines. According to the paper:

The decoder is composed of a Time-axis Gated Recurrent Unit (TGRU) block and 1D Transposed Convolutional Neural Network (1D-TrCNN) blocks. The output of the encoder is passed into a unidirectional GRU layer to aggregate the information along the time axis.

But then, the input to this layer is a (1, 16, 64) tensor, and according to PyTorch's GRU documentation, when batch_first=True the 2nd dimension is the sequence length, which is the case here because batch_first defaults to True and is not changed when the TGRU layer is defined: https://github.com/YangangCao/TRUNet/blob/main/TRUNet.py#LL131C26-L131C26

To my understanding (please correct me if I'm wrong), the TGRU layer will not really aggregate information along the time axis, but will instead play a similar role to the FGRU, only with a unidirectional layer. I first assumed that batch_first should be set to False in order to apply the nn.GRU along the first dimension, which is the original time dimension.

#4 (comment)

atabakp commented on June 27, 2024

Hi @atabakp ,

Not sure if my interpretation of the outputs is correct, but I'm trying to follow the paper, and even when the model trains, it may become unstable after some epochs. I believe the cos_phase is causing this, because due to the cosine law I sometimes get values marginally outside the expected range. How are you dealing with this, and how are you obtaining the respective sin_phase? I believe I'm missing something somewhere. I already tried clamping values that could potentially explode, with no luck.

    # Control random seed
    rand_seed = torch.manual_seed(0)

    # Let's assume it has shape (1, 5, 257) (the expected output for a single source)
    # Since the activation function is ReLU, values can be equal or greater
    # than 0
    x_features = torch.rand((1, 5, 257), dtype=torch.float32)
    
    # Extract z_tf for target and residual
    z_tf = x_features[:, 0:1, :]
    z_tf_residual = x_features[:, 1:2, :]

    # Extract phi
    phi = x_features[:, 2:3, :]

    # Estimate beta (due to softplus it will be one or greater)
    beta = 1.0 + F.softplus(phi)

    # Estimate sigmoid of target and residual
    sigmoid_tf = F.sigmoid(z_tf - z_tf_residual)
    sigmoid_tf_residual = F.sigmoid(z_tf_residual - z_tf)

    # Estimate upper bound for beta
    beta_upper_bound = 1.0 / torch.abs(sigmoid_tf - sigmoid_tf_residual)

    # Because of the absolute value in the denominator, the same upper bound
    # can be applied to both betas
    beta = torch.clip(beta, max=beta_upper_bound)

    # Compute both target and residual masks using eq. (1)
    mask_tf = beta * sigmoid_tf
    mask_tf_residual = beta * sigmoid_tf_residual

    # Now that we have both masks, let's compute the triangle cosine law
    cos_phase = (
        (1.0 + mask_tf.square() - mask_tf_residual.square())
        / (2.0 * mask_tf))
    
    # Use trigonometric identity to obtain the sine
    sin_phase = torch.sqrt(1.0 - cos_phase.square())

    # Now estimate the sign
    q0 = x_features[:, 3:4, :]
    q1 = x_features[:, 4:5, :]
    gamma0 = F.gumbel_softmax(q0, tau=1.0)
    gamma1 = F.gumbel_softmax(q1, tau=1.0)
    sign = torch.where(gamma0 > gamma1, -1.0, 1.0)

    # Finally, estimate the complex mask
    complex_mask = mask_tf * (cos_phase + sign * 1j * sin_phase)

    # Then it should be applied to the stft and inverted using the istft
    ...

1- I didn't apply ReLU to the last layer (x_features).
2- sigmoid_tf_residual = 1 - sigmoid_tf is always true, so I think we don't need two outputs for this; but based on the paper, your code is correct.
3- You should always handle division, log, sqrt, etc. in your code; for example, in cos_phase there is a chance that mask_tf becomes zero.
4- For obtaining the sine, I am using acos: torch.sin(torch.acos(torch.clamp(cos_phase, min=-1 + eps, max=1 - eps)))

5- Estimating the sign is a bit confusing, and I believe there is a typo in the formula in the paper. (I believe the sign does not matter much for performance.)
This is how I implemented it:

gamma = torch.nn.functional.gumbel_softmax(
    torch.stack([q0, q1], dim=-1),
    tau=0.5,
    hard=False,
)
gamma_0 = gamma[..., 0]
gamma_1 = gamma[..., 1]

sign = torch.where(gamma_0 > gamma_1, -1.0, 1.0)

atabakp commented on June 27, 2024

Hi again @atabakp ,

When training the model, are you using 2s audio as the paper claims or are you using gradient accumulation or something like that to pass more data between steps?

I'm currently trying to train the model for dereverberation only, but 2 s per audio clip in all cases is very slow to train. So far I haven't reached the point of evaluating how successful the model is at the task, but it doesn't seem to be learning quickly.

I am using random-length sequences, a single sequence per iteration (batch size = 1).

atabakp commented on June 27, 2024

softmax with temperature,

Nope, still doesn't work. The only thing that "worked" is skipping the PHM and multiplying one channel of the last output with the input, but I haven't waited for it to converge yet. I'll try these fixes, thanks.

One more question for @atabakp : I mentioned that my losses were wrong. By that I meant that in the paper the loss is the sum of the losses for the direct source, the noise, and the reverberant path (last equation of Section 3.3). How did you calculate them: did you do this sum of 3, or something else? I ask because I don't see a good way to compute the target reverberant path; I just subtract the clean signal from the reverberated signal, and use a tensor of 1e-6 as the noise target (since I only train for dereverberation). Also for @eagomez2 : I gather that you used a different loss; did you calculate it with only the direct source?

I tried different variations, but I found that using the loss on the direct path only is good enough.
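For reference, a minimal sketch of what such a direct-only objective could look like; the FFT sizes, hop lengths, and equal term weighting here are assumptions, not the paper's exact settings, and est/target are (batch, samples) waveforms:

    import torch
    import torch.nn.functional as F

    def multires_spectral_loss(est, target, fft_sizes=(512, 1024, 2048)):
        # Multi-resolution spectrum MSE + cosine-similarity loss (sketch).
        loss = est.new_zeros(())
        for n_fft in fft_sizes:
            window = torch.hann_window(n_fft, device=est.device)
            E = torch.stft(est, n_fft, hop_length=n_fft // 4,
                           window=window, return_complex=True)
            T = torch.stft(target, n_fft, hop_length=n_fft // 4,
                           window=window, return_complex=True)
            mag_mse = torch.mean((E.abs() - T.abs()) ** 2)  # spectrum MSE term
            cos = F.cosine_similarity(torch.view_as_real(E).flatten(1),
                                      torch.view_as_real(T).flatten(1), dim=-1)
            loss = loss + mag_mse + (1.0 - cos.mean())      # cosine term
        return loss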

atabakp commented on June 27, 2024

So is it working now, @JBloodless ?
I double-checked the code I used and it is very similar to my initial post. The only changes I see are that I got rid of the randomness of gumbel softmax and simply replaced it with softmax with temperature, and that I added some eps to stabilize the cos_phase term, and it worked. Also, sigmoid_tf_residual can be simplified to 1.0 - sigmoid_tf_target.
All in all, I'm inclined to think that even though the sign-prediction math in the paper makes sense, in practice it is not as crucial for the network's performance.

I totally agree; even the PHM is not very crucial, the network can directly output the clean speech mask.

@atabakp thanks for bringing this up. Have you tried training without the PHM? I was also curious about doing this, but I haven't found the time so far.

Yes, I tried; I ended up using a single channel for masking the magnitude.

AmosCch commented on June 27, 2024

I think the input tensor in the sample code is a one-frame feature. If you want to feed a wav into the model, the input dimension might be (B, 4, frames, 257), but I'm not sure. Please email me ([email protected]) if you have any insight.

amirpashamobinitehrani commented on June 27, 2024

@yugeshav Hey! Any progress on this? I am also confused about the input shape.

atabakp commented on June 27, 2024

@amirpashamobinitehrani The input shape for the 1D conv is (T, C, F): (time frames, channels (the 4 features), frequency bins).
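For illustration, a sketch of how such a (time_frames, 4, freq_bins) tensor might be assembled from a waveform; the particular four features here (log magnitude, plain magnitude standing in for PCEN, normalized real/imag) are an assumption based on this thread, not the paper's exact recipe:

    import torch

    def make_input_features(wave, n_fft=512, hop=128, eps=1e-8):
        window = torch.hann_window(n_fft)
        spec = torch.stft(wave, n_fft, hop_length=hop, window=window,
                          return_complex=True)           # (freq, time)
        mag = spec.abs()
        feats = torch.stack([torch.log(mag + eps),       # log magnitude
                             mag,                        # stand-in for PCEN
                             spec.real / (mag + eps),    # normalized real
                             spec.imag / (mag + eps)],   # normalized imag
                            dim=0)                       # (4, freq, time)
        return feats.permute(2, 0, 1)                    # (time, 4, freq)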

amirpashamobinitehrani commented on June 27, 2024

Thanks for your reply. Interesting! Yes, I had some presumptions. What still remains a mystery to me is how to bring the batch dimension into play:

(Batch, time frames, channels (4 features), frequency bins)

which I assume we should refrain from, right? We are simply processing 4 different features of one audio file in time-frame steps, so the time-frame dimension fulfills the batch dimension's role.

atabakp commented on June 27, 2024

Correct! Each frame is a data sample here. If you want to use (Batch, Time, Features, Frequency), you should use 2D convolutions and set the filters' dimensions to (n, 1), as sketched below.
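A minimal sketch of that batched variant, assuming a (batch, channels, time_frames, freq_bins) layout, so the kernel spans n bins along frequency and 1 frame along time and each frame is still processed independently:

    import torch
    import torch.nn as nn

    # Frame-independent convolution over frequency, with a real batch axis.
    conv = nn.Conv2d(in_channels=4, out_channels=64,
                     kernel_size=(1, 5), stride=(1, 2), padding=(0, 2))

    x = torch.randn(8, 4, 100, 257)   # batch of 8 clips, 100 frames each
    y = conv(x)                       # (8, 64, 100, 129): time axis untouched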

eagomez2 commented on June 27, 2024

Hi,

I had the same question. Has anyone been able to successfully train this network? I think that, as @atabakp mentioned, the input has to have shape (time_frames, features, fft_size // 2 + 1), so when a batch is used, the time_frames axis will grow. Since this is assumed to be the N input of a nn.Conv1d, the processing will still be frame-independent, so bigger batch sizes would simply mean a bigger stack of frames. Could someone confirm this?

Thanks,
Esteban

eagomez2 commented on June 27, 2024

Thanks @atabakp !

As a follow-up question: How are you obtaining the "demodulated phase"?

atabakp commented on June 27, 2024

There are a few methods to do this, but I don't know what the authors exactly mean; see for example https://arxiv.org/pdf/1608.01953.pdf.

But for my training, I used log magnitude and normalized real/imag as inputs.
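One common interpretation of phase demodulation (an assumption here, not necessarily what the TRUNet authors mean) is to subtract each bin's deterministic per-frame phase advance 2π·hop·k/n_fft and wrap the result back to (-π, π]; the real/imag input features would then be the cosine and sine of this residual phase:

    import torch

    def demodulated_phase(spec, n_fft, hop):
        # spec: complex STFT with shape (freq, time)
        n_freq, n_frames = spec.shape
        k = torch.arange(n_freq).unsqueeze(1)           # bin index (freq, 1)
        t = torch.arange(n_frames).unsqueeze(0)         # frame index (1, time)
        carrier = 2.0 * torch.pi * hop * k * t / n_fft  # expected phase ramp
        demod = spec.angle() - carrier
        return torch.atan2(torch.sin(demod), torch.cos(demod))  # wrap to (-pi, pi]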

amirpashamobinitehrani commented on June 27, 2024

I managed to implement the demodulated phase, using (log_magnitude, demod_real, demod_imag) as inputs to train the model. For some reason, I am not seeing the model successfully do anything useful. It would be nice to get some insights into the implementation if anyone has made promising progress on this!

eagomez2 commented on June 27, 2024

Thanks once again @atabakp! I was thinking something similar:

  1. Use log magnitude (as in the paper)
  2. Use PCEN output (as in the paper)

For 3. and 4., "real/imaginary of the demodulated phase" didn't make much sense to me as a term initially, since the phase itself would be real, so I was thinking of using the normalized real/imag STFT as well, since it would somehow put emphasis on the phase information.

One last question: how are you using the outputs, @atabakp ? I think it has 5 channels initially, but there is no explicit mention of what they exactly are. I was assuming two of them are magnitude masks (target and residual), two others are phase terms, and the last one is used to estimate the phase's sign, but I was not sure.

atabakp commented on June 27, 2024

Section 3 of this paper also has some information about phase demodulation: https://www.isca-speech.org/archive_v0/Interspeech_2018/pdfs/1773.pdf

eagomez2 commented on June 27, 2024

Hi again @atabakp ,

How are you getting the 10 channels? I looked again into the model's code and I'm getting only 5 channels. Here I'm attaching the I/O of each layer:

module type input_shape output_shape
root TRUNet (1, 4, 257) (1, 5, 257)
down1 StandardConv1d (1, 4, 257) (1, 64, 128)
down1.StandardConv1d Sequential (1, 4, 257) (1, 64, 128)
down1.StandardConv1d.0 Conv1d (1, 4, 257) (1, 64, 128)
down1.StandardConv1d.1 ReLU (1, 64, 128) (1, 64, 128)
down2 DepthwiseSeparableConv1d (1, 64, 128) (1, 128, 128)
down2.DepthwiseSeparableConv1d Sequential (1, 64, 128) (1, 128, 128)
down2.DepthwiseSeparableConv1d.0 Conv1d (1, 64, 128) (1, 128, 128)
down2.DepthwiseSeparableConv1d.1 BatchNorm1d (1, 128, 128) (1, 128, 128)
down2.DepthwiseSeparableConv1d.2 ReLU (1, 128, 128) (1, 128, 128)
down2.DepthwiseSeparableConv1d.3 Conv1d (1, 128, 128) (1, 128, 128)
down2.DepthwiseSeparableConv1d.4 BatchNorm1d (1, 128, 128) (1, 128, 128)
down2.DepthwiseSeparableConv1d.5 ReLU (1, 128, 128) (1, 128, 128)
down3 DepthwiseSeparableConv1d (1, 128, 128) (1, 128, 64)
down3.DepthwiseSeparableConv1d Sequential (1, 128, 128) (1, 128, 64)
down3.DepthwiseSeparableConv1d.0 Conv1d (1, 128, 128) (1, 128, 128)
down3.DepthwiseSeparableConv1d.1 BatchNorm1d (1, 128, 128) (1, 128, 128)
down3.DepthwiseSeparableConv1d.2 ReLU (1, 128, 128) (1, 128, 128)
down3.DepthwiseSeparableConv1d.3 Conv1d (1, 128, 128) (1, 128, 64)
down3.DepthwiseSeparableConv1d.4 BatchNorm1d (1, 128, 64) (1, 128, 64)
down3.DepthwiseSeparableConv1d.5 ReLU (1, 128, 64) (1, 128, 64)
down4 DepthwiseSeparableConv1d (1, 128, 64) (1, 128, 64)
down4.DepthwiseSeparableConv1d Sequential (1, 128, 64) (1, 128, 64)
down4.DepthwiseSeparableConv1d.0 Conv1d (1, 128, 64) (1, 128, 64)
down4.DepthwiseSeparableConv1d.1 BatchNorm1d (1, 128, 64) (1, 128, 64)
down4.DepthwiseSeparableConv1d.2 ReLU (1, 128, 64) (1, 128, 64)
down4.DepthwiseSeparableConv1d.3 Conv1d (1, 128, 64) (1, 128, 64)
down4.DepthwiseSeparableConv1d.4 BatchNorm1d (1, 128, 64) (1, 128, 64)
down4.DepthwiseSeparableConv1d.5 ReLU (1, 128, 64) (1, 128, 64)
down5 DepthwiseSeparableConv1d (1, 128, 64) (1, 128, 32)
down5.DepthwiseSeparableConv1d Sequential (1, 128, 64) (1, 128, 32)
down5.DepthwiseSeparableConv1d.0 Conv1d (1, 128, 64) (1, 128, 64)
down5.DepthwiseSeparableConv1d.1 BatchNorm1d (1, 128, 64) (1, 128, 64)
down5.DepthwiseSeparableConv1d.2 ReLU (1, 128, 64) (1, 128, 64)
down5.DepthwiseSeparableConv1d.3 Conv1d (1, 128, 64) (1, 128, 32)
down5.DepthwiseSeparableConv1d.4 BatchNorm1d (1, 128, 32) (1, 128, 32)
down5.DepthwiseSeparableConv1d.5 ReLU (1, 128, 32) (1, 128, 32)
down6 DepthwiseSeparableConv1d (1, 128, 32) (1, 128, 16)
down6.DepthwiseSeparableConv1d Sequential (1, 128, 32) (1, 128, 16)
down6.DepthwiseSeparableConv1d.0 Conv1d (1, 128, 32) (1, 128, 32)
down6.DepthwiseSeparableConv1d.1 BatchNorm1d (1, 128, 32) (1, 128, 32)
down6.DepthwiseSeparableConv1d.2 ReLU (1, 128, 32) (1, 128, 32)
down6.DepthwiseSeparableConv1d.3 Conv1d (1, 128, 32) (1, 128, 16)
down6.DepthwiseSeparableConv1d.4 BatchNorm1d (1, 128, 16) (1, 128, 16)
down6.DepthwiseSeparableConv1d.5 ReLU (1, 128, 16) (1, 128, 16)
FGRU GRUBlock (1, 16, 128) (1, 64, 16)
FGRU.GRU GRU (1, 16, 128) ((1, 16, 128), (2, 1, 64))
FGRU.conv Sequential (1, 128, 16) (1, 64, 16)
FGRU.conv.0 Conv1d (1, 128, 16) (1, 64, 16)
FGRU.conv.1 BatchNorm1d (1, 64, 16) (1, 64, 16)
FGRU.conv.2 ReLU (1, 64, 16) (1, 64, 16)
TGRU GRUBlock (1, 16, 64) (1, 64, 16)
TGRU.GRU GRU (1, 16, 64) ((1, 16, 128), (1, 1, 128))
TGRU.conv Sequential (1, 128, 16) (1, 64, 16)
TGRU.conv.0 Conv1d (1, 128, 16) (1, 64, 16)
TGRU.conv.1 BatchNorm1d (1, 64, 16) (1, 64, 16)
TGRU.conv.2 ReLU (1, 64, 16) (1, 64, 16)
up1 FirstTrCNN (1, 64, 16) (1, 64, 31)
up1.FirstTrCNN Sequential (1, 64, 16) (1, 64, 31)
up1.FirstTrCNN.0 Conv1d (1, 64, 16) (1, 64, 16)
up1.FirstTrCNN.1 BatchNorm1d (1, 64, 16) (1, 64, 16)
up1.FirstTrCNN.2 ReLU (1, 64, 16) (1, 64, 16)
up1.FirstTrCNN.3 ConvTranspose1d (1, 64, 16) (1, 64, 31)
up1.FirstTrCNN.4 BatchNorm1d (1, 64, 31) (1, 64, 31)
up1.FirstTrCNN.5 ReLU (1, 64, 31) (1, 64, 31)
up2 TrCNN (1, 64, 31) (1, 64, 65)
up2.TrCNN Sequential (1, 192, 32) (1, 64, 65)
up2.TrCNN.0 Conv1d (1, 192, 32) (1, 64, 32)
up2.TrCNN.1 BatchNorm1d (1, 64, 32) (1, 64, 32)
up2.TrCNN.2 ReLU (1, 64, 32) (1, 64, 32)
up2.TrCNN.3 ConvTranspose1d (1, 64, 32) (1, 64, 65)
up2.TrCNN.4 BatchNorm1d (1, 64, 65) (1, 64, 65)
up2.TrCNN.5 ReLU (1, 64, 65) (1, 64, 65)
up3 TrCNN (1, 64, 65) (1, 64, 66)
up3.TrCNN Sequential (1, 192, 64) (1, 64, 66)
up3.TrCNN.0 Conv1d (1, 192, 64) (1, 64, 64)
up3.TrCNN.1 BatchNorm1d (1, 64, 64) (1, 64, 64)
up3.TrCNN.2 ReLU (1, 64, 64) (1, 64, 64)
up3.TrCNN.3 ConvTranspose1d (1, 64, 64) (1, 64, 66)
up3.TrCNN.4 BatchNorm1d (1, 64, 66) (1, 64, 66)
up3.TrCNN.5 ReLU (1, 64, 66) (1, 64, 66)
up4 TrCNN (1, 64, 66) (1, 64, 129)
up4.TrCNN Sequential (1, 192, 64) (1, 64, 129)
up4.TrCNN.0 Conv1d (1, 192, 64) (1, 64, 64)
up4.TrCNN.1 BatchNorm1d (1, 64, 64) (1, 64, 64)
up4.TrCNN.2 ReLU (1, 64, 64) (1, 64, 64)
up4.TrCNN.3 ConvTranspose1d (1, 64, 64) (1, 64, 129)
up4.TrCNN.4 BatchNorm1d (1, 64, 129) (1, 64, 129)
up4.TrCNN.5 ReLU (1, 64, 129) (1, 64, 129)
up5 TrCNN (1, 64, 129) (1, 64, 130)
up5.TrCNN Sequential (1, 192, 128) (1, 64, 130)
up5.TrCNN.0 Conv1d (1, 192, 128) (1, 64, 128)
up5.TrCNN.1 BatchNorm1d (1, 64, 128) (1, 64, 128)
up5.TrCNN.2 ReLU (1, 64, 128) (1, 64, 128)
up5.TrCNN.3 ConvTranspose1d (1, 64, 128) (1, 64, 130)
up5.TrCNN.4 BatchNorm1d (1, 64, 130) (1, 64, 130)
up5.TrCNN.5 ReLU (1, 64, 130) (1, 64, 130)
up6 LastTrCNN (1, 64, 130) (1, 5, 257)
up6.LastTrCNN Sequential (1, 128, 128) (1, 5, 257)
up6.LastTrCNN.0 Conv1d (1, 128, 128) (1, 5, 128)
up6.LastTrCNN.1 BatchNorm1d (1, 5, 128) (1, 5, 128)
up6.LastTrCNN.2 ReLU (1, 5, 128) (1, 5, 128)
up6.LastTrCNN.3 ConvTranspose1d (1, 5, 128) (1, 5, 257)

eagomez2 commented on June 27, 2024

I also have a question about the TGRU along the same lines. According to the paper:

The decoder is composed of a Time-axis Gated Recurrent Unit (TGRU) block and 1D Transposed Convolutional Neural Network (1D-TrCNN) blocks. The output of the encoder is passed into a unidirectional GRU layer to aggregate the information along the time axis.

But then, the input to this layer is a (1, 16, 64) tensor, and according to PyTorch's GRU documentation, when batch_first=True the 2nd dimension is the sequence length, which is the case here because batch_first defaults to True and is not changed when the TGRU layer is defined: https://github.com/YangangCao/TRUNet/blob/main/TRUNet.py#LL131C26-L131C26

To my understanding (please correct me if I'm wrong), the TGRU layer will not really aggregate information along the time axis, but will instead play a similar role to the FGRU, only with a unidirectional layer. I first assumed that batch_first should be set to False in order to apply the nn.GRU along the first dimension, which is the original time dimension.
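A sketch of that reading, under the assumption (not confirmed by the authors) that the leading axis of the encoder output is the frame axis the Conv1d stack treated as batch:

    import torch
    import torch.nn as nn

    x = torch.randn(100, 16, 64)  # (time_frames, positions, features)

    # With batch_first=True the GRU would scan across the 16 positions instead
    # of time. With batch_first=False the expected layout is (seq, batch,
    # feature), so the same tensor is scanned along the 100 frames for each
    # of the 16 positions independently.
    tgru = nn.GRU(input_size=64, hidden_size=64, batch_first=False)
    out, h = tgru(x)   # out: (100, 16, 64), aggregated along time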

eagomez2 commented on June 27, 2024

Hi again @atabakp ,

How are you getting the 10 channels? I looked again into the model's code and I'm getting only 5 channels. Here I'm attaching the I/O of each layer:

module type input_shape output_shape
root TRUNet (1, 4, 257) (1, 5, 257)
down1 StandardConv1d (1, 4, 257) (1, 64, 128)
down1.StandardConv1d Sequential (1, 4, 257) (1, 64, 128)
down1.StandardConv1d.0 Conv1d (1, 4, 257) (1, 64, 128)
down1.StandardConv1d.1 ReLU (1, 64, 128) (1, 64, 128)
down2 DepthwiseSeparableConv1d (1, 64, 128) (1, 128, 128)
down2.DepthwiseSeparableConv1d Sequential (1, 64, 128) (1, 128, 128)
down2.DepthwiseSeparableConv1d.0 Conv1d (1, 64, 128) (1, 128, 128)
down2.DepthwiseSeparableConv1d.1 BatchNorm1d (1, 128, 128) (1, 128, 128)
down2.DepthwiseSeparableConv1d.2 ReLU (1, 128, 128) (1, 128, 128)
down2.DepthwiseSeparableConv1d.3 Conv1d (1, 128, 128) (1, 128, 128)
down2.DepthwiseSeparableConv1d.4 BatchNorm1d (1, 128, 128) (1, 128, 128)
down2.DepthwiseSeparableConv1d.5 ReLU (1, 128, 128) (1, 128, 128)
down3 DepthwiseSeparableConv1d (1, 128, 128) (1, 128, 64)
down3.DepthwiseSeparableConv1d Sequential (1, 128, 128) (1, 128, 64)
down3.DepthwiseSeparableConv1d.0 Conv1d (1, 128, 128) (1, 128, 128)
down3.DepthwiseSeparableConv1d.1 BatchNorm1d (1, 128, 128) (1, 128, 128)
down3.DepthwiseSeparableConv1d.2 ReLU (1, 128, 128) (1, 128, 128)
down3.DepthwiseSeparableConv1d.3 Conv1d (1, 128, 128) (1, 128, 64)
down3.DepthwiseSeparableConv1d.4 BatchNorm1d (1, 128, 64) (1, 128, 64)
down3.DepthwiseSeparableConv1d.5 ReLU (1, 128, 64) (1, 128, 64)
down4 DepthwiseSeparableConv1d (1, 128, 64) (1, 128, 64)
down4.DepthwiseSeparableConv1d Sequential (1, 128, 64) (1, 128, 64)
down4.DepthwiseSeparableConv1d.0 Conv1d (1, 128, 64) (1, 128, 64)
down4.DepthwiseSeparableConv1d.1 BatchNorm1d (1, 128, 64) (1, 128, 64)
down4.DepthwiseSeparableConv1d.2 ReLU (1, 128, 64) (1, 128, 64)
down4.DepthwiseSeparableConv1d.3 Conv1d (1, 128, 64) (1, 128, 64)
down4.DepthwiseSeparableConv1d.4 BatchNorm1d (1, 128, 64) (1, 128, 64)
down4.DepthwiseSeparableConv1d.5 ReLU (1, 128, 64) (1, 128, 64)
down5 DepthwiseSeparableConv1d (1, 128, 64) (1, 128, 32)
down5.DepthwiseSeparableConv1d Sequential (1, 128, 64) (1, 128, 32)
down5.DepthwiseSeparableConv1d.0 Conv1d (1, 128, 64) (1, 128, 64)
down5.DepthwiseSeparableConv1d.1 BatchNorm1d (1, 128, 64) (1, 128, 64)
down5.DepthwiseSeparableConv1d.2 ReLU (1, 128, 64) (1, 128, 64)
down5.DepthwiseSeparableConv1d.3 Conv1d (1, 128, 64) (1, 128, 32)
down5.DepthwiseSeparableConv1d.4 BatchNorm1d (1, 128, 32) (1, 128, 32)
down5.DepthwiseSeparableConv1d.5 ReLU (1, 128, 32) (1, 128, 32)
down6 DepthwiseSeparableConv1d (1, 128, 32) (1, 128, 16)
down6.DepthwiseSeparableConv1d Sequential (1, 128, 32) (1, 128, 16)
down6.DepthwiseSeparableConv1d.0 Conv1d (1, 128, 32) (1, 128, 32)
down6.DepthwiseSeparableConv1d.1 BatchNorm1d (1, 128, 32) (1, 128, 32)
down6.DepthwiseSeparableConv1d.2 ReLU (1, 128, 32) (1, 128, 32)
down6.DepthwiseSeparableConv1d.3 Conv1d (1, 128, 32) (1, 128, 16)
down6.DepthwiseSeparableConv1d.4 BatchNorm1d (1, 128, 16) (1, 128, 16)
down6.DepthwiseSeparableConv1d.5 ReLU (1, 128, 16) (1, 128, 16)
FGRU GRUBlock (1, 16, 128) (1, 64, 16)
FGRU.GRU GRU (1, 16, 128) ((1, 16, 128), (2, 1, 64))
FGRU.conv Sequential (1, 128, 16) (1, 64, 16)
FGRU.conv.0 Conv1d (1, 128, 16) (1, 64, 16)
FGRU.conv.1 BatchNorm1d (1, 64, 16) (1, 64, 16)
FGRU.conv.2 ReLU (1, 64, 16) (1, 64, 16)
TGRU GRUBlock (1, 16, 64) (1, 64, 16)
TGRU.GRU GRU (1, 16, 64) ((1, 16, 128), (1, 1, 128))
TGRU.conv Sequential (1, 128, 16) (1, 64, 16)
TGRU.conv.0 Conv1d (1, 128, 16) (1, 64, 16)
TGRU.conv.1 BatchNorm1d (1, 64, 16) (1, 64, 16)
TGRU.conv.2 ReLU (1, 64, 16) (1, 64, 16)
up1 FirstTrCNN (1, 64, 16) (1, 64, 31)
up1.FirstTrCNN Sequential (1, 64, 16) (1, 64, 31)
up1.FirstTrCNN.0 Conv1d (1, 64, 16) (1, 64, 16)
up1.FirstTrCNN.1 BatchNorm1d (1, 64, 16) (1, 64, 16)
up1.FirstTrCNN.2 ReLU (1, 64, 16) (1, 64, 16)
up1.FirstTrCNN.3 ConvTranspose1d (1, 64, 16) (1, 64, 31)
up1.FirstTrCNN.4 BatchNorm1d (1, 64, 31) (1, 64, 31)
up1.FirstTrCNN.5 ReLU (1, 64, 31) (1, 64, 31)
up2 TrCNN (1, 64, 31) (1, 64, 65)
up2.TrCNN Sequential (1, 192, 32) (1, 64, 65)
up2.TrCNN.0 Conv1d (1, 192, 32) (1, 64, 32)
up2.TrCNN.1 BatchNorm1d (1, 64, 32) (1, 64, 32)
up2.TrCNN.2 ReLU (1, 64, 32) (1, 64, 32)
up2.TrCNN.3 ConvTranspose1d (1, 64, 32) (1, 64, 65)
up2.TrCNN.4 BatchNorm1d (1, 64, 65) (1, 64, 65)
up2.TrCNN.5 ReLU (1, 64, 65) (1, 64, 65)
up3 TrCNN (1, 64, 65) (1, 64, 66)
up3.TrCNN Sequential (1, 192, 64) (1, 64, 66)
up3.TrCNN.0 Conv1d (1, 192, 64) (1, 64, 64)
up3.TrCNN.1 BatchNorm1d (1, 64, 64) (1, 64, 64)
up3.TrCNN.2 ReLU (1, 64, 64) (1, 64, 64)
up3.TrCNN.3 ConvTranspose1d (1, 64, 64) (1, 64, 66)
up3.TrCNN.4 BatchNorm1d (1, 64, 66) (1, 64, 66)
up3.TrCNN.5 ReLU (1, 64, 66) (1, 64, 66)
up4 TrCNN (1, 64, 66) (1, 64, 129)
up4.TrCNN Sequential (1, 192, 64) (1, 64, 129)
up4.TrCNN.0 Conv1d (1, 192, 64) (1, 64, 64)
up4.TrCNN.1 BatchNorm1d (1, 64, 64) (1, 64, 64)
up4.TrCNN.2 ReLU (1, 64, 64) (1, 64, 64)
up4.TrCNN.3 ConvTranspose1d (1, 64, 64) (1, 64, 129)
up4.TrCNN.4 BatchNorm1d (1, 64, 129) (1, 64, 129)
up4.TrCNN.5 ReLU (1, 64, 129) (1, 64, 129)
up5 TrCNN (1, 64, 129) (1, 64, 130)
up5.TrCNN Sequential (1, 192, 128) (1, 64, 130)
up5.TrCNN.0 Conv1d (1, 192, 128) (1, 64, 128)
up5.TrCNN.1 BatchNorm1d (1, 64, 128) (1, 64, 128)
up5.TrCNN.2 ReLU (1, 64, 128) (1, 64, 128)
up5.TrCNN.3 ConvTranspose1d (1, 64, 128) (1, 64, 130)
up5.TrCNN.4 BatchNorm1d (1, 64, 130) (1, 64, 130)
up5.TrCNN.5 ReLU (1, 64, 130) (1, 64, 130)
up6 LastTrCNN (1, 64, 130) (1, 5, 257)
up6.LastTrCNN Sequential (1, 128, 128) (1, 5, 257)
up6.LastTrCNN.0 Conv1d (1, 128, 128) (1, 5, 128)
up6.LastTrCNN.1 BatchNorm1d (1, 5, 128) (1, 5, 128)
up6.LastTrCNN.2 ReLU (1, 5, 128) (1, 5, 128)
up6.LastTrCNN.3 ConvTranspose1d (1, 5, 128) (1, 5, 257)

I'll answer myself on this one. The paper's config listing for the decoder says:

DecoderConfig = {1-th: (3,2,64), 2-th: (5,2,64), 3-th: (3,1,64), 4-th: (5,2,64), 5-th: (3,1,64), 6-th: (5,2,10)}

where the last number is the number of output channels; therefore you're right, it should be 10 instead.
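With that change, the 10-channel decoder output can then be split into the two 5-channel sets discussed above; a sketch consistent with the snippets later in this thread (x16 being the final decoder output):

    x16 = self.up6(x15, x1)        # (time_frames, 10, 257) after the change
    feats_direct = x16[:, :5, :]   # channels 1-5: PHM inputs, direct path
    feats_second = x16[:, 5:, :]   # channels 6-10: PHM inputs, second source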

eagomez2 commented on June 27, 2024

Hi @atabakp ,

Not sure if my interpretation of the outputs is correct, but I'm trying to follow the paper, and even when the model trains, it may become unstable after some epochs. I believe the cos_phase is causing this, because due to the cosine law I sometimes get values marginally outside the expected range. How are you dealing with this, and how are you obtaining the respective sin_phase? I believe I'm missing something somewhere. I already tried clamping values that could potentially explode, with no luck.

    # Control random seed
    rand_seed = torch.manual_seed(0)

    # Let's assume it has shape (1, 5, 257) (the expected output for a single source)
    # Since the activation function is ReLU, values can be equal or greater
    # than 0
    x_features = torch.rand((1, 5, 257), dtype=torch.float32)
    
    # Extract z_tf for target and residual
    z_tf = x_features[:, 0:1, :]
    z_tf_residual = x_features[:, 1:2, :]

    # Extract phi
    phi = x_features[:, 2:3, :]

    # Estimate beta (due to softplus it will be one or greater)
    beta = 1.0 + F.softplus(phi)

    # Estimate sigmoid of target and residual
    sigmoid_tf = F.sigmoid(z_tf - z_tf_residual)
    sigmoid_tf_residual = F.sigmoid(z_tf_residual - z_tf)

    # Estimate upper bound for beta
    beta_upper_bound = 1.0 / torch.abs(sigmoid_tf - sigmoid_tf_residual)

    # Because of the absolute value in the denominator, the same upper bound
    # can be applied to both betas
    beta = torch.clip(beta, max=beta_upper_bound)

    # Compute both target and residual masks using eq. (1)
    mask_tf = beta * sigmoid_tf
    mask_tf_residual = beta * sigmoid_tf_residual

    # Now that we have both masks, let's compute the triangle cosine law
    cos_phase = (
        (1.0 + mask_tf.square() - mask_tf_residual.square())
        / (2.0 * mask_tf))
    
    # Use trigonometric identity to obtain the sine
    sin_phase = torch.sqrt(1.0 - cos_phase.square())

    # Now estimate the sign
    q0 = x_features[:, 3:4, :]
    q1 = x_features[:, 4:5, :]
    gamma0 = F.gumbel_softmax(q0, tau=1.0)
    gamma1 = F.gumbel_softmax(q1, tau=1.0)
    sign = torch.where(gamma0 > gamma1, -1.0, 1.0)

    # Finally, estimate the complex mask
    complex_mask = mask_tf * (cos_phase + sign * 1j * sin_phase)

    # Then it should be applied to the stft and inverted using the istft
    ...

eagomez2 commented on June 27, 2024

Hi @atabakp ,
Not sure if my interpretation of the outputs is correct, but I'm trying to follow the paper, and even when the model trains, it may become unstable after some epochs. I believe the cos_phase is causing this, because due to the cosine law I sometimes get values marginally outside the expected range. How are you dealing with this, and how are you obtaining the respective sin_phase? I believe I'm missing something somewhere. I already tried clamping values that could potentially explode, with no luck.

    # Control random seed
    rand_seed = torch.manual_seed(0)

    # Let's assume it has shape (1, 5, 257) (the expected output for a single source)
    # Since the activation function is ReLU, values can be equal or greater
    # than 0
    x_features = torch.rand((1, 5, 257), dtype=torch.float32)
    
    # Extract z_tf for target and residual
    z_tf = x_features[:, 0:1, :]
    z_tf_residual = x_features[:, 1:2, :]

    # Extract phi
    phi = x_features[:, 2:3, :]

    # Estimate beta (due to softplus it will be one or greater)
    beta = 1.0 + F.softplus(phi)

    # Estimate sigmoid of target and residual
    sigmoid_tf = F.sigmoid(z_tf - z_tf_residual)
    sigmoid_tf_residual = F.sigmoid(z_tf_residual - z_tf)

    # Estimate upper bound for beta
    beta_upper_bound = 1.0 / torch.abs(sigmoid_tf - sigmoid_tf_residual)

    # Because of the absolute value in the denominator, the same upper bound
    # can be applied to both betas
    beta = torch.clip(beta, max=beta_upper_bound)

    # Compute both target and residual masks using eq. (1)
    mask_tf = beta * sigmoid_tf
    mask_tf_residual = beta * sigmoid_tf_residual

    # Now that we have both masks, let's compute the triangle cosine law
    cos_phase = (
        (1.0 + mask_tf.square() - mask_tf_residual.square())
        / (2.0 * mask_tf))
    
    # Use trigonometric identity to obtain the sine
    sin_phase = torch.sqrt(1.0 - cos_phase.square())

    # Now estimate the sign
    q0 = x_features[:, 3:4, :]
    q1 = x_features[:, 4:5, :]
    gamma0 = F.gumbel_softmax(q0, tau=1.0)
    gamma1 = F.gumbel_softmax(q1, tau=1.0)
    sign = torch.where(gamma0 > gamma1, -1.0, 1.0)

    # Finally, estimate the complex mask
    complex_mask = mask_tf * (cos_phase + sign * 1j * sin_phase)

    # Then it should be applied to the stft and inverted using the istft
    ...

1- I didn't apply ReLU to the last layer (x_features).
2- sigmoid_tf_residual = 1 - sigmoid_tf is always true, so I think we don't need two outputs for this; but based on the paper, your code is correct.
3- You should always handle division, log, sqrt, etc. in your code; for example, in cos_phase there is a chance that mask_tf becomes zero.
4- For obtaining the sine, I am using acos: torch.sin(torch.acos(torch.clamp(cos_phase, min=-1 + eps, max=1 - eps)))
5- Estimating the sign is a bit confusing, and I believe there is a typo in the formula in the paper. (I believe the sign does not matter much for performance.) This is how I implemented it:

gamma = torch.nn.functional.gumbel_softmax(
    torch.stack([q0, q1], dim=-1),
    tau=0.5,
    hard=False,
)
gamma_0 = gamma[..., 0]
gamma_1 = gamma[..., 1]

sign = torch.where(gamma_0 > gamma_1, -1.0, 1.0)

Thank you very much @atabakp !

  • I tried clamping before with no luck, but I ended up discovering that the problem was apparently in one of my losses, since now it seems to work as expected.
  • I can also corroborate what you mention about simplifying sigmoid_tf_residual; it works with the simpler version.
  • I was also hesitant about the sign prediction, since the phase should already fit the triangle inequality, but I haven't yet done a comparison with and without it. I'll update my conclusions as soon as I can check them.

eagomez2 commented on June 27, 2024

Hi again @atabakp ,

When training the model, are you using 2s audio as the paper claims or are you using gradient accumulation or something like that to pass more data between steps?

I'm currently trying to train the model for dereverberation only, but 2 s per audio clip in all cases is very slow to train. So far I haven't reached the point of evaluating how successful the model is at the task, but it doesn't seem to be learning quickly.

JBloodless commented on June 27, 2024

Sorry for necroposting here, but I'm trying to train this model, with no luck yet. I managed to add trainable PCEN (as described in the paper) and train on spectrograms. I construct the input feature from PCEN (the output of the trainable layer), log magnitude, and the real and imag parts of the STFT, and feed it to the rest of the model described here. I also implemented 2D convs since I wanted to train on batches. The losses are the same as in the paper: multi-resolution cosine similarity + multi-resolution spectrum MSE.
The model trains very weirdly (the loss decreases for the first couple of hundred steps, then increases, then decreases again). @atabakp @eagomez2 I assume you managed to train this model; can you say what you changed compared to the paper? Or maybe share your pipeline. Thanks in advance.
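For reference, a minimal sketch of a trainable PCEN layer in the spirit of Wang et al. (2017); the initialization values, per-bin parameterization, and frame loop are illustrative assumptions, not this commenter's actual code:

    import torch
    import torch.nn as nn

    class TrainablePCEN(nn.Module):
        def __init__(self, n_freq, s=0.025, eps=1e-6):
            super().__init__()
            self.s, self.eps = s, eps
            self.alpha = nn.Parameter(torch.full((n_freq,), 0.98))
            self.delta = nn.Parameter(torch.full((n_freq,), 2.0))
            self.r = nn.Parameter(torch.full((n_freq,), 0.5))

        def forward(self, mag):
            # mag: (freq, time) magnitude spectrogram
            m, out = mag[:, 0], []
            for t in range(mag.shape[1]):
                m = (1.0 - self.s) * m + self.s * mag[:, t]  # temporal smoother
                out.append((mag[:, t] / (self.eps + m) ** self.alpha
                            + self.delta) ** self.r - self.delta ** self.r)
            return torch.stack(out, dim=1)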

eagomez2 commented on June 27, 2024

Hi @JBloodless ,

In my case I am using Conv1d, but I decided to change the original losses for a GAN, since they weren't quite working for me (the model was converging, but not with the expected quality, and sometimes exploding after that).

JBloodless commented on June 27, 2024

Hi @JBloodless ,

In my case I am using Conv1d, but I decided to change the original losses for a GAN, since they weren't quite working for me (the model was converging, but not with the expected quality, and sometimes exploding after that).

Can you elaborate a bit on what you mean by a GAN loss?

eagomez2 commented on June 27, 2024

Hi @JBloodless ,
In my case I am using Conv1d, but I decided to change the original losses for a GAN, since they weren't quite working for me (the model was converging, but not with the expected quality, and sometimes exploding after that).

Can you elaborate a bit on what you mean by a GAN loss?

Check section 2 of this paper: https://arxiv.org/pdf/2010.10677.pdf
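As a generic illustration of the adversarial part of such a setup (a sketch only; the linked paper's exact discriminator and loss details are not reproduced here), a least-squares GAN objective pair might look like:

    import torch

    def lsgan_d_loss(d_real, d_fake):
        # Discriminator: push scores on real audio to 1, on enhanced audio to 0.
        return torch.mean((d_real - 1.0) ** 2) + torch.mean(d_fake ** 2)

    def lsgan_g_loss(d_fake):
        # Generator (the enhancement model): make enhanced audio score 1.
        return torch.mean((d_fake - 1.0) ** 2)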

JBloodless commented on June 27, 2024

@eagomez2 I'm assuming that you're using your code above to calculate the complex mask and then multiplying it with the complex input to produce the output waveform, is this correct?

eagomez2 commented on June 27, 2024

@eagomez2 I'm assuming that you're using your code above to calculate the complex mask and then multiplying it with the complex input to produce the output waveform, is this correct?

Yes. You will need to repeat the process to obtain both the direct speech waveform and the residual waveform.

JBloodless commented on June 27, 2024

@eagomez2 I'm assuming that you're using your code above to calculate the complex mask and then multiplying it with the complex input to produce the output waveform, is this correct?

Yes. You will need to repeat the process to obtain both the direct speech waveform and the residual waveform.

What do you mean by repeating? I thought that the network (in this implementation) returns one set of features for the PHM (time, 5, bins), and the corresponding PHM will be the mask for the direct source. Since I need to obtain only the direct source (clean speech), I just multiply this PHM with the input spectrum and get the clean output. What did I assume wrong?

JBloodless commented on June 27, 2024

@eagomez2 I'm assuming that you're using your code above to calculate the complex mask and then multiplying it with the complex input to produce the output waveform, is this correct?

Yes. You will need to repeat the process to obtain both the direct speech waveform and the residual waveform.

Reading the conversation above, I assume that you changed the output layer to 10 channels (as in the paper). How should I apply this PHM function then? Channels 1-5 will be the mask for direct speech, and 6-10 the residual?

eagomez2 commented on June 27, 2024

@eagomez2 I'm assuming that you're using your code above to calculate the complex mask and then multiplying it with the complex input to produce the output waveform, is this correct?

Yes. You will need to repeat the process to obtain both the direct speech waveform and the residual waveform.

Reading the conversation above, I assume that you changed the output layer to 10 channels (as in the paper). How should I apply this PHM function then? Channels 1-5 will be the mask for direct speech, and 6-10 the residual?

Yes, that's correct.

JBloodless commented on June 27, 2024

@eagomez2 I'm assuming that you're using your code above to calculate the complex mask and then multiplying it with the complex input to produce the output waveform, is this correct?

Yes. You will need to repeat the process to obtain both the direct speech waveform and the residual waveform.

Reading the conversation above, I assume that you changed the output layer to 10 channels (as in the paper). How should I apply this PHM function then? Channels 1-5 will be the mask for direct speech, and 6-10 the residual?

Yes, that's correct.

I think I got it wrong. It seems that your function technically calculates a pair of masks

# Compute both target and residual masks using eq. (1)
mask_tf = beta * sigmoid_tf
mask_tf_residual = beta * sigmoid_tf_residual

and features 6-10 are for reverberant and noise separation. Which means that from the first 5 features we can calculate only the direct source mask (which is exactly what your function is doing) to perform dereverberation. Which leads to even more frustration for me, since I'm doing everything "right", but the model doesn't train at all.

eagomez2 commented on June 27, 2024

@eagomez2 I'm assuming that you're using your code above to calculate the complex mask and then multiplying it with the complex input to produce the output waveform, is this correct?

Yes. You will need to repeat the process to obtain both the direct speech waveform and the residual waveform.

Reading the conversation above, I assume that you changed the output layer to 10 channels (as in the paper). How should I apply this PHM function then? Channels 1-5 will be the mask for direct speech, and 6-10 the residual?

Yes, that's correct.

I think I got it wrong. It seems that your function technically calculates a pair of masks

# Compute both target and residual masks using eq. (1)
mask_tf = beta * sigmoid_tf
mask_tf_residual = beta * sigmoid_tf_residual

and features 6-10 are for reverberant and noise separation. Which means that from the first 5 features we can calculate only the direct source mask (which is exactly what your function is doing) to perform dereverberation. Which leads to even more frustration for me, since I'm doing everything "right", but the model doesn't train at all.

Are you using the paper's losses or are you trying a different one? Is the model training but with the loss values "all over the place", or is it exploding or so?

If I remember correctly (this was months ago), I double- and triple-checked that every function that could potentially explode through things like division by zero had an epsilon (eps) value or something to prevent such issues before managing to get meaningful results.

JBloodless commented on June 27, 2024

@eagomez2 I'm assuming that you're using your code above to calculate the complex mask and then multiplying it with the complex input to produce the output waveform, is this correct?

Yes. You will need to repeat the process to obtain both the direct speech waveform and the residual waveform.

Reading the conversation above, I assume that you changed the output layer to 10 channels (as in the paper). How should I apply this PHM function then? Channels 1-5 will be the mask for direct speech, and 6-10 the residual?

Yes, that's correct.

I think I got it wrong. It seems that your function technically calculates a pair of masks

# Compute both target and residual masks using eq. (1)
mask_tf = beta * sigmoid_tf
mask_tf_residual = beta * sigmoid_tf_residual

and features 6-10 are for reverberant and noise separation. Which means that from the first 5 features we can calculate only the direct source mask (which is exactly what your function is doing) to perform dereverberation. Which leads to even more frustration for me, since I'm doing everything "right", but the model doesn't train at all.

Are you using the paper's losses or are you trying a different one? Is the model training but with the loss values "all over the place", or is it exploding or so?

If I remember correctly (this was months ago), I double- and triple-checked that every function that could potentially explode through things like division by zero had an epsilon (eps) value or something to prevent such issues before managing to get meaningful results.

It's not exploding, the loss is just stable around some value. Yes, I'm trying to use the paper's loss (since this loss looks right to me and I don't see why it shouldn't work).

[Screenshot 2024-02-13 at 13:11:58]

The purple one is the latest try (with the 10-channel output).

I already fixed a bunch of NaNs, so I think that zero handling is not the problem here. Maybe I'm interpreting the outputs wrong; I'll try to log what's going on.

JBloodless commented on June 27, 2024

@eagomez2 maybe you'll help me double-check my output interpretation.

x16 = self.up6(x15, x1)  # x16 is the output of this implementation, except self.up6 has 10 channels instead of 5
mask_direct = calculate_PHM(x16[:, :5, :])  # calculate_PHM is your function from the comments above (with fixes from @atabakp), which returns complex_mask
result = x * mask_direct  # x is the input: the complex spectrogram
out_wave = torch.istft(result,
                       n_fft=self.nfft,
                       hop_length=self.hop,
                       onesided=True,
                       window=self.window.to(wave.device),
                       center=True)

Is this the same-ish as yours? My main concern for now is mask_direct = calculate_PHM(x16[:, :5, :])

atabakp commented on June 27, 2024

@eagomez2 maybe you'll help me double-check my output interpretation.

x16 = self.up6(x15, x1)  # x16 is the output of this implementation, except self.up6 has 10 channels instead of 5
mask_direct = calculate_PHM(x16[:, :5, :])  # calculate_PHM is your function from the comments above (with fixes from @atabakp), which returns complex_mask
result = x * mask_direct  # x is the input: the complex spectrogram
out_wave = torch.istft(result,
                       n_fft=self.nfft,
                       hop_length=self.hop,
                       onesided=True,
                       window=self.window.to(wave.device),
                       center=True)

Is this the same-ish as yours? My main concern for now is mask_direct = calculate_PHM(x16[:, :5, :])

To verify the correct operation of the code, skip the calculate_PHM function and multiply one channel of x16 with x.

JBloodless commented on June 27, 2024

@eagomez2 maybe you'll help me double-check my output interpretation.

x16 = self.up6(x15, x1)  # x16 is the output of this implementation, except self.up6 has 10 channels instead of 5
mask_direct = calculate_PHM(x16[:, :5, :])  # calculate_PHM is your function from the comments above (with fixes from @atabakp), which returns complex_mask
result = x * mask_direct  # x is the input: the complex spectrogram
out_wave = torch.istft(result,
                       n_fft=self.nfft,
                       hop_length=self.hop,
                       onesided=True,
                       window=self.window.to(wave.device),
                       center=True)

Is this the same-ish as yours? My main concern for now is mask_direct = calculate_PHM(x16[:, :5, :])

To verify the correct operation of the code, skip the calculate_PHM function and multiply one channel of x16 with x.

By "correct" I meant that it produces the expected output (clean speech). This function is "correct" in terms of data and shape, if that's what you mean. If I'm not mistaken, one channel of x16 is just one feature needed to calculate the mask, and multiplying the input by this feature won't produce a clean spectrum.

JBloodless commented on June 27, 2024

@atabakp @eagomez2 I found out that my losses were completely wrong, so I'd like to ask you about the outputs of this model. From the paper it's not at all obvious in which order the masks appear in the second set (channels 6-10). In Fig. 2 of the paper it's non-noise then noise, but at the end of Section 3.2 it's noise then non-noise. Did you figure it out?

atabakp commented on June 27, 2024

@eagomez2 maybe you'll help me double-check my output interpretation.

x16 = self.up6(x15, x1)  # x16 is the output of this implementation, except self.up6 has 10 channels instead of 5
mask_direct = calculate_PHM(x16[:, :5, :])  # calculate_PHM is your function from the comments above (with fixes from @atabakp), which returns complex_mask
result = x * mask_direct  # x is the input: the complex spectrogram
out_wave = torch.istft(result,
                       n_fft=self.nfft,
                       hop_length=self.hop,
                       onesided=True,
                       window=self.window.to(wave.device),
                       center=True)

Is this the same-ish as yours? My main concern for now is mask_direct = calculate_PHM(x16[:, :5, :])

To verify the correct operation of the code, skip the calculate_PHM function and multiply one channel of x16 with x.

By "correct" I meant that it produces the expected output (clean speech). This function is "correct" in terms of data and shape, if that's what you mean. If I'm not mistaken, one channel of x16 is just one feature needed to calculate the mask, and multiplying the input by this feature won't produce a clean spectrum.

A single channel can also produce the clean signal; you only need to multiply the output mask (bounded to (0, 1) with a sigmoid) with the magnitude of the noisy signal (x) and use the noisy phase (x.angle()) to construct out_wave.
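A sketch of that sanity check (the STFT settings are illustrative; this assumes the selected channel has been reshaped to x's (freq, time) layout):

    import torch

    mask = torch.sigmoid(x16[:, 0, :].transpose(0, 1))  # (freq, time), in (0, 1)
    est_spec = torch.polar(mask * x.abs(), x.angle())   # masked magnitude + noisy phase
    out_wave = torch.istft(est_spec, n_fft=512, hop_length=128,
                           window=torch.hann_window(512))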

atabakp commented on June 27, 2024

@atabakp @eagomez2 I found out that my losses were completely wrong, so I'd like to ask you about the outputs of this model. From the paper it's not at all obvious in which order the masks appear in the second set (channels 6-10). In Fig. 2 of the paper it's non-noise then noise, but at the end of Section 3.2 it's noise then non-noise. Did you figure it out?

The order doesn't matter; just follow one, and the network will adapt to assign the corresponding output correctly, irrespective of the initial order. My suggestion is to skip the PHM for now and make sure the rest of the code is OK.

eagomez2 commented on June 27, 2024

So is it working now, @JBloodless ?

I double-checked the code I used and it is very similar to my initial post. The only changes I see are that I got rid of the randomness of gumbel softmax and simply replaced it with softmax with temperature, and that I added some eps to stabilize the cos_phase term, and it worked. Also, sigmoid_tf_residual can be simplified to 1.0 - sigmoid_tf_target.

All in all, I'm inclined to think that even though the sign-prediction math in the paper makes sense, in practice it is not as crucial for the network's performance.
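A sketch of that deterministic replacement (tau here is an illustrative value), reusing q0 and q1 from the earlier snippet:

    import torch
    import torch.nn.functional as F

    tau = 0.5
    gamma = F.softmax(torch.stack([q0, q1], dim=-1) / tau, dim=-1)  # no Gumbel noise
    sign = torch.where(gamma[..., 0] > gamma[..., 1], -1.0, 1.0)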

atabakp commented on June 27, 2024

So is it working now, @JBloodless ?

I double-checked the code I used and it is very similar to my initial post. The only changes I see are that I got rid of the randomness of gumbel softmax and simply replaced it with softmax with temperature, and that I added some eps to stabilize the cos_phase term, and it worked. Also, sigmoid_tf_residual can be simplified to 1.0 - sigmoid_tf_target.

All in all, I'm inclined to think that even though the sign-prediction math in the paper makes sense, in practice it is not as crucial for the network's performance.

I totally agree; even the PHM is not very crucial, the network can directly output the clean speech mask.

eagomez2 commented on June 27, 2024

So is it working now, @JBloodless ?
I double-checked the code I used and it is very similar to my initial post. The only changes I see are that I got rid of the randomness of gumbel softmax and simply replaced it with softmax with temperature, and that I added some eps to stabilize the cos_phase term, and it worked. Also, sigmoid_tf_residual can be simplified to 1.0 - sigmoid_tf_target.
All in all, I'm inclined to think that even though the sign-prediction math in the paper makes sense, in practice it is not as crucial for the network's performance.

I totally agree; even the PHM is not very crucial, the network can directly output the clean speech mask.

@atabakp thanks for bringing this up. Have you tried training without the PHM? I was also curious about doing this, but I haven't found the time so far.

JBloodless commented on June 27, 2024

softmax with temperature,

Nope, still doesn't work. The only thing that "worked" is skipping the PHM and multiplying one channel of the last output with the input, but I haven't waited for it to converge yet. I'll try these fixes, thanks.

One more question for @atabakp : I mentioned that my losses were wrong. By that I meant that in the paper the loss is the sum of the losses for the direct source, the noise, and the reverberant path (last equation of Section 3.3). How did you calculate them: did you do this sum of 3, or something else? I ask because I don't see a good way to compute the target reverberant path; I just subtract the clean signal from the reverberated signal, and use a tensor of 1e-6 as the noise target (since I only train for dereverberation).
Also for @eagomez2 : I gather that you used a different loss; did you calculate it with only the direct source?

JBloodless commented on June 27, 2024

So is it working now, @JBloodless ?
I double-checked the code I used and it is very similar to my initial post. The only changes I see are that I got rid of the randomness of gumbel softmax and simply replaced it with softmax with temperature, and that I added some eps to stabilize the cos_phase term, and it worked. Also, sigmoid_tf_residual can be simplified to 1.0 - sigmoid_tf_target.
All in all, I'm inclined to think that even though the sign-prediction math in the paper makes sense, in practice it is not as crucial for the network's performance.

I totally agree; even the PHM is not very crucial, the network can directly output the clean speech mask.

@atabakp thanks for bringing this up. Have you tried training without the PHM? I was also curious about doing this, but I haven't found the time so far.

Yes, I tried; I ended up using a single channel for masking the magnitude.

I've managed to train the model without the PHM and with a single loss on the direct path (same as the paper: multi-resolution cosine similarity + multi-resolution spectrum MSE). The model converges:

[Screenshot 2024-02-19 at 17:18:27]

but the result is strange:

[Screenshot 2024-02-19 at 17:19:06]

Maybe it's because of the PCEN feature (my implementation may not be ideal), but the voice harmonics in the spectrum seem to be "dereverbed", so I'll try to locate the source of the noisiness.

JBloodless commented on June 27, 2024

@atabakp did you extract the magnitude of the input spectrum with torch.abs() and not torch.real?

atabakp commented on June 27, 2024

@atabakp did you extract the magnitude of the input spectrum with torch.abs() and not torch.real?

Yes, with abs().

eagomez2 commented on June 27, 2024

softmax with temperature,

Nope, still doesn't work. The only thing that "worked" is skipping the PHM and multiplying one channel of the last output with the input, but I haven't waited for it to converge yet. I'll try these fixes, thanks.

One more question for @atabakp : I mentioned that my losses were wrong. By that I meant that in the paper the loss is the sum of the losses for the direct source, the noise, and the reverberant path (last equation of Section 3.3). How did you calculate them: did you do this sum of 3, or something else? I ask because I don't see a good way to compute the target reverberant path; I just subtract the clean signal from the reverberated signal, and use a tensor of 1e-6 as the noise target (since I only train for dereverberation). Also for @eagomez2 : I gather that you used a different loss; did you calculate it with only the direct source?

I calculated it for both direct and residual, but as mentioned by @atabakp , I'm inclined to think that direct alone should be good enough, although I haven't tried this.

JBloodless commented on June 27, 2024

@atabakp @eagomez2 sorry for bothering again.

I trained a couple of experiments with different parameters. I mostly can achieve dereverberation, but in all of my experiments there is a weird artifact in the lower frequencies. It looks like this strange line:

[Screenshot 2024-02-29 at 14:00:01]

Did you ever encounter this artifact?

eagomez2 commented on June 27, 2024

Hi @JBloodless ,

In my case, at least, I didn't observe such problems using the GAN setup previously described.

atabakp commented on June 27, 2024

@atabakp @eagomez2 sorry for bothering again.

I trained a couple of experiments with different parameters. I can mostly achieve dereverberation, but in all of my experiments there is a weird artifact in the lower frequencies. It looks like this strange line:

[Screenshot 2024-02-29 at 14:00:01] Did you ever encounter this artifact?

In the process, I discard the DC bin, substituting it with zeros. This could also arise from the use of padding in transpose convolutions, which can lead to a constant output value in the low-frequency bins.

from trunet.

JBloodless avatar JBloodless commented on June 27, 2024

@atabakp @eagomez2 sorry for bothering again.
I trained a couple of experiments with different parameters. I can mostly achieve dereverberation, but in all of my experiments there is a weird artifact in the lower frequencies. It looks like this strange line:
[Screenshot 2024-02-29 at 14:00:01]
Did you ever encounter this artifact?

In the process, I discard the DC bin, substituting it with zeros. This could also arise from the use of padding in transpose convolutions, which can lead to a constant output value in the low-frequency bins.

Do you discard it in the input only or also in the target?

from trunet.

atabakp avatar atabakp commented on June 27, 2024

@atabakp @eagomez2 sorry for bothering again.
I trained a couple of experiments with different parameters. I can mostly achieve dereverberation, but in all of my experiments there is a weird artifact in the lower frequencies. It looks like this strange line:
[Screenshot 2024-02-29 at 14:00:01]
Did you ever encounter this artifact?

In the process, I discard the DC bin, substituting it with zeros. This could also arise from the use of padding in transpose convolutions, which can lead to a constant output value in the low-frequency bins.

Do you discard it in the input only or also in the target?

I discard it in the input and replace it with zero in the output. This means the model is not predicting a mask for the DC bin.

from trunet.

JBloodless avatar JBloodless commented on June 27, 2024

@atabakp @eagomez2 sorry for bothering again.
I trained a couple of experiments with different parameters. I can mostly achieve dereverberation, but in all of my experiments there is a weird artifact in the lower frequencies. It looks like this strange line:
[Screenshot 2024-02-29 at 14:00:01]
Did you ever encounter this artifact?

In the process, I discard the DC bin, substituting it with zeros. This could also arise from the use of padding in transpose convolutions, which can lead to a constant output value in the low-frequency bins.

Do you discard it in the input only or also in the target?

I discard it in the input and replace it with zero in the output. This means the model is not predicting a mask for the DC bin.

In my case the model stopped dereverbing the lower frequencies at all, and overall it sounds as if it weren't dereverbed at all (psychoacoustics and all). Am I getting it right that you zeroed out the lowest bin in the input and in the model output, but not in the target (for loss calculation)? And if so, which n_fft did you use?

For context: the first is the reverbed input, the second is the output of the model without zeros in the DC bin, and the last is the output of the model with zeros in the DC bin.
[Screenshot 2024-03-05 at 13:59:35]

from trunet.

atabakp avatar atabakp commented on June 27, 2024

@atabakp @eagomez2 sorry for bothering again.

I trained a couple of experiments with different parameters. I can mostly achieve dereverberation, but in all of my experiments there is a weird artifact in the lower frequencies. It looks like this strange line:

[Screenshot 2024-02-29 at 14:00:01]

Did you ever encounter this artifact?

In the process, I discard the DC bin, substituting it with zeros. This could also arise from the use of padding in transpose convolutions, which can lead to a constant output value in the low-frequency bins.

Do you discard it in the input only or also in the target?

I discard it in the input and replace it with zero in the output. This means the model is not predicting a mask for the DC bin.

In my case the model stopped dereverbing the lower frequencies at all, and overall it sounds as if it weren't dereverbed at all (psychoacoustics and all). Am I getting it right that you zeroed out the lowest bin in the input and in the model output, but not in the target (for loss calculation)? And if so, which n_fft did you use?

For context: the first is the reverbed input, the second is the output of the model without zeros in the DC bin, and the last is the output of the model with zeros in the DC bin.

[Screenshot 2024-03-05 at 13:59:35]

I am not considering the lowest bin in any calculation, and when reconstructing the signal (iFFT) I append a 0 as the lowest bin.
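Roughly like this (a sketch; n_fft and hop length are placeholders, not the values I actually used):

import torch

def reconstruct(y_no_dc: torch.Tensor, n_fft: int = 512, hop: int = 128) -> torch.Tensor:
    # y_no_dc: complex spectrum without the DC bin, shape (batch, n_fft // 2, time)
    dc = torch.zeros_like(y_no_dc[:, :1, :])
    y = torch.cat([dc, y_no_dc], dim=1)  # append 0 as the lowest (DC) bin
    return torch.istft(y, n_fft=n_fft, hop_length=hop,
                       window=torch.hann_window(n_fft))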

from trunet.

JBloodless avatar JBloodless commented on June 27, 2024

@atabakp @eagomez2 Did you try to train this model for 48kHz?

from trunet.

eagomez2 avatar eagomez2 commented on June 27, 2024

@JBloodless yes, I've trained it at 48kHz

from trunet.

atabakp avatar atabakp commented on June 27, 2024

@atabakp @eagomez2 Did you try to train this model for 48kHz?

I tried with 16 kHz.

from trunet.

JBloodless avatar JBloodless commented on June 27, 2024

@JBloodless yes, I've trained it at 48kHz

Can you share what modifications to the architecture you made? For now I've just made the convolutions and the GRU bigger, but it seems that's not enough.

from trunet.

eagomez2 avatar eagomez2 commented on June 27, 2024

@JBloodless yes, I've trained it at 48kHz

Can you share what modifications to the architecture you made? For now I've just made the convolutions and the GRU bigger, but it seems that's not enough.

I didn't make any changes. I just disabled any resampling algorithms (my data was originally at 48kHz) and trained it normally.

from trunet.

JBloodless avatar JBloodless commented on June 27, 2024

@JBloodless yes, I've trained it at 48kHz

Can you share what modifications to the architecture you made? For now I've just made the convolutions and the GRU bigger, but it seems that's not enough.

I didn't make any changes. I just disabled any resampling algorithms (my data was originally at 48kHz) and trained it normally.

You mentioned the paper on GAN losses that you used. Did you use the same setup as in that paper (a discriminator on the waveform, with adversarial + reconstruction losses)?

from trunet.

eagomez2 avatar eagomez2 commented on June 27, 2024

@JBloodless yes, it's the same setup
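In broad strokes, the generator and discriminator objectives combine like this (an illustrative sketch with made-up weights and a least-squares adversarial term, not the paper's exact losses):

import torch
import torch.nn.functional as F

def generator_loss(disc, est_wave, target_wave, recon_weight: float = 100.0):
    # Adversarial term on the enhanced waveform plus an L1 reconstruction term;
    # the weighting here is illustrative only.
    adv = torch.mean((disc(est_wave) - 1.0) ** 2)
    recon = F.l1_loss(est_wave, target_wave)
    return adv + recon_weight * recon

def discriminator_loss(disc, est_wave, target_wave):
    real_term = torch.mean((disc(target_wave) - 1.0) ** 2)
    fake_term = torch.mean(disc(est_wave.detach()) ** 2)
    return real_term + fake_term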

from trunet.
