
two_step_mask_learning's People

Contributors

etzinis

two_step_mask_learning's Issues

The preprocessing results in fewer than 20000 mixtures in the training set of wsj0-mix2

Thanks for sharing the code! It's really great work on audio source separation!
I have a question about preprocess_wsj0mix.py:
Because some audio files in wsj0-mix2 are shorter than 4 s, running lines 139-140 discards some of them during preprocessing. As a result, only 17075 mixtures remain in the training set when using the "min" folder (the number is 19885 when using the "max" folder). This does not match the number (20000) mentioned in Section 3.2.1 of the paper. So I was wondering: how many samples were finally used in the speech separation experiments in this paper?
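To check how many mixtures survive a minimum-duration filter in a given folder, here is a small stdlib-only sketch. It assumes the filter is a plain cutoff that keeps mixtures of at least 4 s (the threshold described above) and that the mixtures are plain PCM .wav files readable by Python's `wave` module. The helper name `count_long_enough` is hypothetical and is not taken from preprocess_wsj0mix.py.

```python
import os
import wave


def count_long_enough(wav_dir, min_seconds=4.0):
    """Count .wav files in wav_dir whose duration is at least min_seconds.

    Returns (kept, total). Hypothetical helper: the 4 s default mirrors the
    threshold described in the issue, not a value read from the repository.
    """
    kept = total = 0
    for name in sorted(os.listdir(wav_dir)):
        if not name.endswith(".wav"):
            continue  # skip non-audio files in the folder
        total += 1
        with wave.open(os.path.join(wav_dir, name), "rb") as w:
            duration = w.getnframes() / w.getframerate()
        if duration >= min_seconds:
            kept += 1
    return kept, total
```

Calling it on the training mixture folder of the "min" version and then of the "max" version (the exact paths depend on your local layout) gives a per-folder count that can be compared with the numbers reported above, provided the preprocessing really is a simple duration cutoff.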

About SDR

I tried to test my code, which calculates SDR, on your separated samples (ex_18).

With my code the result is about 6.47, while yours is 19.37.

Can you help me find out what is wrong with my code? Thanks.

The code is as follows:

```python
#!/usr/bin/env python

import soundfile as sf
from mir_eval.separation import bss_eval_sources
import numpy as np

import torch

from itertools import permutations


def cal_SDRi(src_ref, src_est, mix):
    # Calculate Source-to-Distortion Ratio improvement (SDRi).
    # NOTE: bss_eval_sources is very slow.
    # Args:
    #     src_ref: numpy.ndarray, [C, T]
    #     src_est: numpy.ndarray, [C, T], reordered by best PIT permutation
    #     mix: numpy.ndarray, [T]
    # Returns:
    #     average_SDRi

    src_anchor = np.stack([mix, mix], axis=0)
    sdr, sir, sar, popt = bss_eval_sources(src_ref, src_est)
    sdr0, sir0, sar0, popt0 = bss_eval_sources(src_ref, src_anchor)
    avg_SDRi = ((sdr[0] - sdr0[0]) + (sdr[1] - sdr0[1])) / 2
    return avg_SDRi


def cal_SISNRi(src_ref, src_est, mix):
    # Calculate Scale-Invariant Source-to-Noise Ratio improvement (SI-SNRi).
    # Args:
    #     src_ref: numpy.ndarray, [C, T]
    #     src_est: numpy.ndarray, [C, T], reordered by best PIT permutation
    #     mix: numpy.ndarray, [T]
    # Returns:
    #     average_SISNRi
    sisnr1 = cal_SISNR(src_ref[0], src_est[0])
    sisnr2 = cal_SISNR(src_ref[1], src_est[1])
    sisnr1b = cal_SISNR(src_ref[0], mix)
    sisnr2b = cal_SISNR(src_ref[1], mix)
    avg_SISNRi = ((sisnr1 - sisnr1b) + (sisnr2 - sisnr2b)) / 2
    return avg_SISNRi


def cal_SISNR(ref_sig, out_sig, eps=1e-8):
    # Calculate Scale-Invariant Source-to-Noise Ratio (SI-SNR).
    # Args:
    #     ref_sig: numpy.ndarray, [T]
    #     out_sig: numpy.ndarray, [T]
    # Returns:
    #     SISNR

    assert len(ref_sig) == len(out_sig)
    ref_sig = ref_sig - np.mean(ref_sig)
    out_sig = out_sig - np.mean(out_sig)
    ref_energy = np.sum(ref_sig ** 2) + eps
    proj = np.sum(ref_sig * out_sig) * ref_sig / ref_energy
    noise = out_sig - proj
    ratio = np.sum(proj ** 2) / (np.sum(noise ** 2) + eps)
    sisnr = 10 * np.log(ratio + eps) / np.log(10.0)
    return sisnr


def calc_sdr(estimation, origin):
    # Batch-wise SDR calculation for one audio file.
    # estimation: (batch, nsample)
    # origin: (batch, nsample)

    origin_power = np.sum(origin**2, 1, keepdims=True) + 1e-8  # (batch, 1)

    scale = np.sum(origin*estimation, 1, keepdims=True) / origin_power  # (batch, 1)

    est_true = scale * origin  # (batch, nsample)
    est_res = estimation - est_true  # (batch, nsample)

    true_power = np.sum(est_true**2, 1)
    res_power = np.sum(est_res**2, 1)

    return 10*np.log10(true_power) - 10*np.log10(res_power)  # (batch,)


def compute_measures(se, s, j):
    Rss = s.transpose().dot(s)
    this_s = s[:, j]

    a = this_s.transpose().dot(se) / Rss[j, j]
    e_true = a * this_s
    e_res = se - a * this_s
    Sss = np.sum((e_true)**2)
    Snn = np.sum((e_res)**2)

    SDR = 10*np.log10(Sss/Snn)

    Rsr = s.transpose().dot(e_res)
    b = np.linalg.inv(Rss).dot(Rsr)

    e_interf = s.dot(b)
    e_artif = e_res - e_interf

    SIR = 10*np.log10(Sss/np.sum((e_interf)**2))
    SAR = 10*np.log10(Sss/np.sum((e_artif)**2))
    return SDR, SIR, SAR


def GetSDR(se, s):
    se = se.transpose()
    s = s.transpose()

    se = se - np.mean(se, axis=0)
    s = s - np.mean(s, axis=0)
    nsampl, nsrc = se.shape
    nsampl2, nsrc2 = s.shape

    assert nsrc2 == nsrc
    assert nsampl2 == nsampl

    SDR = np.zeros((nsrc, nsrc))
    SIR = SDR.copy()
    SAR = SDR.copy()

    for jest in range(nsrc):
        for jtrue in range(nsrc):
            SDR[jest, jtrue], SIR[jest, jtrue], SAR[jest, jtrue] = compute_measures(se[:, jest], s, jtrue)

    perm = list(permutations(np.arange(nsrc)))
    nperm = len(perm)
    meanSIR = np.zeros((nperm,))
    for p in range(nperm):
        tp = SIR.transpose().reshape(nsrc*nsrc)
        idx = np.arange(nsrc)*nsrc + list(perm[p])
        meanSIR[p] = np.mean(tp[idx])
    popt = np.argmax(meanSIR)
    per = list(perm[popt])
    idx = np.arange(nsrc)*nsrc + per
    SDR = SDR.transpose().reshape(nsrc*nsrc)[idx]
    SIR = SIR.transpose().reshape(nsrc*nsrc)[idx]
    SAR = SAR.transpose().reshape(nsrc*nsrc)[idx]
    return SDR, SIR, SAR, per


EPS = 1e-8


def cal_si_snr_with_pit(source, estimate_source, source_lengths):
    # Calculate SI-SNR with PIT training.
    # Args:
    #     source: [B, C, T], B is batch size
    #     estimate_source: [B, C, T]
    #     source_lengths: [B], each item is between [0, T]

    assert source.size() == estimate_source.size()
    B, C, T = source.size()

    # Step 1. Zero-mean norm
    num_samples = source_lengths.view(-1, 1, 1).float()  # [B, 1, 1]
    mean_target = torch.sum(source, dim=2, keepdim=True) / num_samples
    mean_estimate = torch.sum(estimate_source, dim=2, keepdim=True) / num_samples
    zero_mean_target = source - mean_target
    zero_mean_estimate = estimate_source - mean_estimate

    # Step 2. SI-SNR with PIT
    # reshape to use broadcast
    s_target = torch.unsqueeze(zero_mean_target, dim=1)  # [B, 1, C, T]
    s_estimate = torch.unsqueeze(zero_mean_estimate, dim=2)  # [B, C, 1, T]
    # s_target = <s', s>s / ||s||^2
    pair_wise_dot = torch.sum(s_estimate * s_target, dim=3, keepdim=True)  # [B, C, C, 1]
    s_target_energy = torch.sum(s_target ** 2, dim=3, keepdim=True) + EPS  # [B, 1, C, 1]
    pair_wise_proj = pair_wise_dot * s_target / s_target_energy  # [B, C, C, T]
    # e_noise = s' - s_target
    e_noise = s_estimate - pair_wise_proj  # [B, C, C, T]
    # SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2)
    pair_wise_si_snr = torch.sum(pair_wise_proj ** 2, dim=3) / (torch.sum(e_noise ** 2, dim=3) + EPS)
    pair_wise_si_snr = 10 * torch.log10(pair_wise_si_snr + EPS)  # [B, C, C]
    print('sisnr:', pair_wise_si_snr)

    # Get max_snr of each utterance
    # permutations, [C!, C]
    perms = source.new_tensor(list(permutations(range(C))), dtype=torch.long)
    # one-hot, [C!, C, C]
    index = torch.unsqueeze(perms, 2)
    perms_one_hot = source.new_zeros((*perms.size(), C)).scatter_(2, index, 1)
    # [B, C!] <- [B, C, C] einsum [C!, C, C], SI-SNR sum of each permutation
    snr_set = torch.einsum('bij,pij->bp', [pair_wise_si_snr, perms_one_hot])
    max_snr_idx = torch.argmax(snr_set, dim=1)  # [B]
    # max_snr = torch.gather(snr_set, 1, max_snr_idx.view(-1, 1))  # [B, 1]
    max_snr, _ = torch.max(snr_set, dim=1, keepdim=True)
    max_snr /= C
    return max_snr, perms, max_snr_idx


def _sdr(y, z, SI=False):
    # y: reference, z: estimate (last dim is time)
    if SI:
        a = ((z*y).mean(-1) / (y*y).mean(-1)).unsqueeze(-1) * y
        return 10*torch.log10((a**2).mean(-1) / ((a-z)**2).mean(-1))
    else:
        return 10*torch.log10((y*y).mean(-1) / ((y-z)**2).mean(-1))


def test():
    mix = sf.read('./ex_18/mixture.wav')[0]
    source = np.stack([sf.read('./ex_18/s1.wav')[0], sf.read('./ex_18/s2.wav')[0]], axis=0)
    estimate_source = np.stack([sf.read('./ex_18/s1_estimate.wav')[0], sf.read('./ex_18/s2_estimate.wav')[0]], axis=0)

    SDRi = cal_SDRi(source, estimate_source, mix)
    SISNRi = cal_SISNRi(source, estimate_source, mix)
    print('SDRi:{}'.format(SDRi))
    print('SISNRi:{}\n'.format(SISNRi))

    # calc_sdr's signature is (estimation, origin), so the estimate goes first
    sdr1 = calc_sdr(estimate_source, source)
    sdr2 = calc_sdr(np.stack([mix, mix], axis=0), source)
    sdri = np.mean(sdr1 - sdr2)
    print('sdr1:{}'.format(sdr1))
    print('sdr2:{}'.format(sdr2))
    print('sdri:{}\n'.format(sdri))

    SDR, SIR, SAR, per = GetSDR(estimate_source, source)
    print('SDR:{}\nSIR:{}\nSAR:{}\nper:{}\n'.format(SDR, SIR, SAR, per))

    source_lengths = torch.from_numpy(np.array([mix.shape]))
    max_snr, _, _ = cal_si_snr_with_pit(torch.from_numpy(np.array([source])).float(), torch.from_numpy(np.array([estimate_source])).float(), source_lengths)
    print('max_snr:{}\n'.format(max_snr))

    SISDR = _sdr(torch.from_numpy(np.array([source])).float(), torch.from_numpy(np.array([estimate_source])).float(), SI=True)
    print('SISDR: ', SISDR)


if __name__ == '__main__':
    test()
```

And the output:

```text
SDRi:7.892910056532607
SISNRi:7.151290758024819

sdr1:[6.68677879 6.25589817]
sdr2:[-1.18234941 -0.17591568]
sdri:7.150471030776565

SDR:[6.68712455 6.25720926]
SIR:[34.92616321 34.92229867]
SAR:[6.69364393 6.26311903]
per:[0, 1]

sisnr: tensor([[[  6.6871, -34.0962],
         [-34.1721,   6.2572]]])
max_snr:  tensor([[6.4722]])

SISDR:  tensor([[6.6868, 6.2559]])
```

ex_18/metrics.json:

```json
{
"input_si_sdr": 0.028149127960205078,
"input_sdr": 0.15109104033014964,
"input_sir": 0.1510910403301708,
"input_sar": 144.89122580687916,
"input_stoi": 0.7178163832006375,
"input_pesq": 1.599277138710022,
"si_sdr": 19.083293914794922,
"sdr": 19.376235432506704,
"sir": 30.187015165321924,
"sar": 19.759935974444744,
"stoi": 0.9568062920227058,
"pesq": 3.562618613243103,
"mix_path": "/mnt/data/wham/wav8k/min/tt/mix_clean/050a050c_0.050237_442c020j_-0.050237.wav"
}
```
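As an independent cross-check of the several metric functions above, here is a minimal NumPy sketch of SI-SDR improvement with an exhaustive permutation search. It uses SI-SDR(ŝ, s) = 10·log10(‖αs‖² / ‖ŝ − αs‖²) with α = ⟨ŝ, s⟩ / ‖s‖² on zero-mean signals, and subtracts the mixture's SI-SDR against each reference. The function names `si_sdr` and `si_sdri_pit` are hypothetical, not part of the repository or of mir_eval. The sketch assumes references, estimates, and mixture are time-aligned, equally long, and at the same sample rate. Note that it computes SI-SDR, which is not the same quantity as the BSS Eval SDR reported in the `sdr` field of metrics.json.

```python
from itertools import permutations

import numpy as np


def si_sdr(reference, estimate, eps=1e-8):
    """Scale-invariant SDR in dB for two 1-D signals (zero-mean, projection-based)."""
    reference = reference - reference.mean()
    estimate = estimate - estimate.mean()
    # Optimal scaling of the reference toward the estimate.
    alpha = np.dot(estimate, reference) / (np.dot(reference, reference) + eps)
    target = alpha * reference
    noise = estimate - target
    return 10 * np.log10((np.dot(target, target) + eps) / (np.dot(noise, noise) + eps))


def si_sdri_pit(sources, estimates, mixture):
    """Mean SI-SDR improvement over the mixture under the best permutation.

    sources, estimates: arrays of shape [C, T]; mixture: shape [T].
    Returns (mean_si_sdri, best_permutation), where best_permutation[i] is the
    index of the estimate matched to source i.
    """
    sources = np.asarray(sources, dtype=np.float64)
    estimates = np.asarray(estimates, dtype=np.float64)
    mixture = np.asarray(mixture, dtype=np.float64)
    assert sources.shape == estimates.shape and sources.shape[1] == mixture.shape[0]
    n_src = sources.shape[0]
    best_perm, best_score = None, -np.inf
    for perm in permutations(range(n_src)):
        score = np.mean([si_sdr(sources[i], estimates[j]) for i, j in enumerate(perm)])
        if score > best_score:
            best_perm, best_score = perm, score
    # Baseline: the unprocessed mixture used as the estimate for every source.
    baseline = np.mean([si_sdr(sources[i], mixture) for i in range(n_src)])
    return best_score - baseline, best_perm
```

Because the improvement is the difference between the estimate's and the mixture's scores, it corresponds to `si_sdr` minus `input_si_sdr` in metrics.json, not to either absolute value alone.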

About README

Excuse me, I want to study your code, but I can't figure out the correct steps to run it. Could you add step-by-step run instructions to the README?
