
pytorch_stoi's Introduction

PyTorch implementation of STOI

Implementation of the classical and extended Short-Time Objective Intelligibility (STOI) measures in PyTorch. See also Cees Taal's website and the Python implementation.

Install

pip install torch_stoi

Important warning

This implementation is intended to be used as a loss function only.
It doesn't replicate the exact behavior of the original metrics, but the results should be close enough for it to be used as a loss function. See the Notes in the NegSTOILoss class.

Quantitative comparison coming soon, hopefully 🚀

Usage

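import torch
from torch import nn
from torch_stoi import NegSTOILoss

sample_rate = 16000
loss_func = NegSTOILoss(sample_rate=sample_rate)

# Your nnet and optimizer definition here. A bare nn.Module() has no forward(),
# so this tiny trainable placeholder stands in for a real enhancement network.
class TinyEnhancer(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(1, 1, kernel_size=3, padding=1)

    def forward(self, x):
        # (batch, time) -> (batch, 1, time) -> conv -> (batch, time)
        return self.conv(x.unsqueeze(1)).squeeze(1)

nnet = TinyEnhancer()

noisy_speech = torch.randn(2, 16000)
clean_speech = torch.randn(2, 16000)
# Estimate clean speech
est_speech = nnet(noisy_speech)
# Compute loss and backward (then step etc...)
loss_batch = loss_func(est_speech, clean_speech)
loss_batch.mean().backward()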

Comparing NumPy and PyTorch versions: the static test

Values obtained with the NumPy version (commit 84b1bd8) are compared to the PyTorch version in the following graphs.

[Graphs: NumPy vs. PyTorch values at 8 kHz and 16 kHz, for the classic and the extended STOI measure]

The 16 kHz signals used to compare both versions contain a lot of silence, which explains why the match is very poor without VAD (voice activity detection).

Comparing NumPy and PyTorch versions: training a DNN

Coming in the near future

References

  • [1] C. H. Taal, R. C. Hendriks, R. Heusdens, J. Jensen, 'A Short-Time Objective Intelligibility Measure for Time-Frequency Weighted Noisy Speech', ICASSP 2010, Dallas, Texas.
  • [2] C. H. Taal, R. C. Hendriks, R. Heusdens, J. Jensen, 'An Algorithm for Intelligibility Prediction of Time-Frequency Weighted Noisy Speech', IEEE Transactions on Audio, Speech, and Language Processing, 2011.
  • [3] J. Jensen, C. H. Taal, 'An Algorithm for Predicting the Intelligibility of Speech Masked by Modulated Noise Maskers', IEEE Transactions on Audio, Speech, and Language Processing, 2016.

pytorch_stoi's People

Contributors

iver56, jonashaag, mpariente, philgzl, tuzehai

pytorch_stoi's Issues

NaN during back-propagation

Hi,

I ran into an issue when using torch_stoi during training. The gradients sometimes become NaN; I fixed this by adding +EPS on lines 127 and 128. I think the square root of 0 might cause the NaN.

Thanks
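The suspected mechanism can be reproduced in isolation. The sketch below is not the library's code: it differentiates sqrt(x**2 + eps) at x = 0, where the derivative of the square root is infinite and the chain rule multiplies it by zero. The EPS value (1e-8) is an assumption, since the issue does not say which value was used.

```python
import torch

EPS = 1e-8  # assumed value; the issue does not state the EPS actually used


def grad_of_sqrt_of_square(eps: float) -> torch.Tensor:
    """Return d/dx sqrt(x**2 + eps) evaluated at x = 0."""
    x = torch.zeros(1, requires_grad=True)
    y = torch.sqrt(x ** 2 + eps)
    y.sum().backward()
    return x.grad


# Without EPS: d sqrt(u)/du = 1 / (2 * sqrt(0)) = inf, and du/dx = 2x = 0,
# so the chain rule yields inf * 0 = NaN.
print(grad_of_sqrt_of_square(0.0))  # tensor([nan])
# With EPS: 1 / (2 * sqrt(EPS)) is finite, so the gradient is finite * 0 = 0.
print(grad_of_sqrt_of_square(EPS))  # tensor([0.])
```

One NaN gradient is enough to corrupt every weight after the optimizer step, which is why such a small offset matters in a loss function.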

cuBLAS error when OBM is not on the GPU in PyTorch nightly

I use PyTorch nightly for their STFT implementation. pytorch_stoi crashes with

RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`

at the OBM matrix multiplication:

x_tob = torch.matmul(self.OBM, torch.norm(x_spec, 2, -1) ** 2).pow(0.5)

Moving self.OBM to the GPU first fixes the issue:

self.OBM = nn.Parameter(torch.from_numpy(obm_mat).float(),
                        requires_grad=False).cuda()  # <==

Do you think that's a bug in PyTorch nightly?

My version:

>>> torch.__version__
'1.7.0.dev20200818'
>>> torch.version.cuda
'10.2'
>>> torch.version.git_version
'14b1e2392c2b91aea922a33b916cee5cf079c5b8'
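A common alternative to calling .cuda() on the matrix by hand is to register it as a buffer, which PyTorch moves together with the module on .to(device) or .cuda(). Note also that .cuda() on an nn.Parameter returns a plain tensor, so the snippet above stores OBM as an ordinary attribute. The sketch below is hypothetical: the class name OctaveBandProjection, the matrix sizes, and the input layout are illustrative, and it is not necessarily how pytorch_stoi resolved the issue.

```python
import numpy as np
import torch
from torch import nn


class OctaveBandProjection(nn.Module):
    """Hypothetical module: projects STFT magnitudes onto fixed bands.

    OBM is registered as a buffer, so module.to(device) / module.cuda()
    moves it with the module. Buffers appear in state_dict() but not in
    parameters(), so optimizers never update them.
    """

    def __init__(self, obm_mat: np.ndarray):
        super().__init__()
        self.register_buffer("OBM", torch.from_numpy(obm_mat).float())

    def forward(self, x_spec: torch.Tensor) -> torch.Tensor:
        # x_spec: (..., freq, frames, 2) real/imaginary pairs, as in the issue.
        return torch.matmul(self.OBM, torch.norm(x_spec, 2, -1) ** 2).pow(0.5)


# Placeholder shapes: 15 bands x 257 frequency bins, 10 frames.
proj = OctaveBandProjection(np.random.rand(15, 257))
x_spec = torch.randn(2, 257, 10, 2)
device = "cuda" if torch.cuda.is_available() else "cpu"
proj = proj.to(device)
out = proj(x_spec.to(device))
print(out.shape)  # torch.Size([2, 15, 10])
```

With a buffer, the only requirement is that the loss module itself is moved to the same device as its inputs.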

Support for `lengths` argument

Would you be interested in adding support for a `lengths` argument containing the original lengths of the waveforms before batching? Users might pad waveforms to a common length to build the batch, but this can change the results. If use_vad=True and the padding is zeros, the difference is small but not zero, because the last frame overlaps with the padded zeros.
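Until such an argument exists, one workaround is to evaluate the loss separately on each item's unpadded samples. The sketch below is hypothetical (per_item_loss is not part of torch_stoi): it loops over the batch, which is slower than one batched call, and it assumes the loss maps a (batch, time) tensor to one value per item, as the usage example suggests. A simple MSE stands in for NegSTOILoss so the effect of the padding is easy to check.

```python
from typing import Callable

import torch

LossFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


def per_item_loss(loss_func: LossFn, est: torch.Tensor, clean: torch.Tensor,
                  lengths: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: apply a (batch, time) -> (batch,) loss to each
    item's unpadded samples only, so padding never reaches the loss."""
    losses = []
    for i, n in enumerate(lengths.tolist()):
        # Slice with i:i+1 to keep a batch dimension of 1, then drop it with [0].
        losses.append(loss_func(est[i:i + 1, :n], clean[i:i + 1, :n])[0])
    return torch.stack(losses)


def mse(e: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    # Stand-in loss: mean squared error per batch item.
    return ((e - c) ** 2).mean(dim=-1)


clean = torch.zeros(2, 8)
est = torch.zeros(2, 8)
est[0, 5:] = 1.0               # non-zero values in the padded tail of item 0
lengths = torch.tensor([5, 8])
print(per_item_loss(mse, est, clean, lengths))  # tensor([0., 0.])
```

A NegSTOILoss instance would be passed in place of mse in the same way.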

the range of pytorch_stoi's results

Hi Pariente,

Thanks for sharing such a convenient tool. I have some questions about your test results. My guess is that you varied the SNR to obtain the different sample points; as the SNR becomes large, the pystoi value approaches 1.

Based on the formula in the original paper:
[image: formula from the original paper]
I think the range of the ideal STOI result is [-1, 1].

However, your results (16 kHz, with VAD) show that the PyTorch version can return values higher than 1.
[image: the 16 kHz results with VAD]

Could you please provide more details about how you ran the tests?

Best,
Chi-Chang Lee.
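For reference, the intermediate measure in [2] is a sample correlation coefficient between short-time envelope segments of the clean and processed signals, averaged over bands and segments. By the Cauchy-Schwarz inequality each coefficient, and therefore the average, lies in [-1, 1]. The sketch below only illustrates that bound: it omits the one-third octave band analysis and the normalization and clipping of the processed envelopes, and it uses random data with illustrative shapes (15 bands, 40 segments, 30 frames).

```python
import torch


def segment_correlation(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Sample correlation coefficient between x and y along the last dimension."""
    x = x - x.mean(dim=-1, keepdim=True)
    y = y - y.mean(dim=-1, keepdim=True)
    return (x * y).sum(dim=-1) / (x.norm(dim=-1) * y.norm(dim=-1))


torch.manual_seed(0)
# Random stand-ins for envelopes: 15 bands x 40 segments x 30 frames.
clean_env = torch.rand(15, 40, 30)
proc_env = clean_env + 0.5 * torch.rand(15, 40, 30)

d_jm = segment_correlation(clean_env, proc_env)  # one coefficient per (band, segment)
d = d_jm.mean()                                  # average over bands and segments
print(d.item())
```

Any implementation that stays within this definition cannot exceed 1, apart from floating-point rounding.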

STOI loss: the output value is negative

When I use torch_stoi as the loss function for my network, the STOI output is always negative. Is this normal, or am I using it wrong?
My usage is just like the code below: I pass the enhanced speech as the first argument and the clean speech as the second.

Please correct me and tell me the right way!

noisy_speech = torch.randn(2, 16000)
clean_speech = torch.randn(2, 16000)
# Estimate clean speech
est_speech = nnet(noisy_speech)
# Compute loss and backward (then step etc...)
loss_batch = loss_func(est_speech, clean_speech)
loss_batch.mean().backward()