etzinis / two_step_mask_learning Goto Github PK
View Code? Open in Web Editor NEW — A two step optimization for sound source separation on the adaptive front-end domain
A two step optimization for sound source separation on the adaptive front-end domain
Thanks for sharing the code! It's really great work on audio source separation!
I have a question about preprocess_wsj0mix.py:
Because some utterances in wsj0-2mix are shorter than 4 s, running lines 139–140 discards them during preprocessing. As a result, the training set contains only 17075 mixtures when using the "min" folder (and 19885 when using the "max" folder), which does not match the 20000 mixtures mentioned in Section 3.2.1 of the paper. How many samples were finally used in the speech separation experiments in the paper?
I tried to test my code, which calculates SDR, on your separated samples (ex_18).
With my SDR code the result is about 6.47, while yours is 19.37.
Can you help me find out what is wrong with my code? Thanks.
The code is as follows:
#!/usr/bin/env python
import soundfile as sf
from mir_eval.separation import bss_eval_sources
import numpy as np
import torch
from itertools import permutations
def cal_SDRi(src_ref, src_est, mix):
    """Calculate the average Source-to-Distortion Ratio improvement (SDRi).

    The baseline SDR is obtained by scoring the unprocessed mixture, repeated
    once per reference source, with ``bss_eval_sources``.
    NOTE: bss_eval_sources is very slow.

    Args:
        src_ref: numpy.ndarray, [C, T], reference sources.
        src_est: numpy.ndarray, [C, T], estimated sources (expected to be
            reordered by the best PIT permutation).
        mix: numpy.ndarray, [T], the input mixture.

    Returns:
        The SDR improvement averaged over all C sources.
    """
    # One copy of the mixture per reference source: the "do nothing" estimate.
    src_anchor = np.stack([mix] * len(src_ref), axis=0)
    sdr, _, _, _ = bss_eval_sources(src_ref, src_est)
    sdr0, _, _, _ = bss_eval_sources(src_ref, src_anchor)
    return np.mean(sdr - sdr0)
def cal_SISNRi(src_ref, src_est, mix):
    """Calculate the average Scale-Invariant SNR improvement (SI-SNRi).

    For each source, the SI-SNR of its estimate is compared with the SI-SNR
    obtained by using the raw mixture as the estimate.

    Args:
        src_ref: numpy.ndarray, [C, T], reference sources.
        src_est: numpy.ndarray, [C, T], estimated sources, expected to be
            reordered by the best PIT permutation (not checked here).
        mix: numpy.ndarray, [T], the input mixture.

    Returns:
        The SI-SNR improvement averaged over all C sources.
    """
    improvements = [
        cal_SISNR(ref, est) - cal_SISNR(ref, mix)
        for ref, est in zip(src_ref, src_est)
    ]
    return sum(improvements) / len(improvements)
def cal_SISNR(ref_sig, out_sig, eps=1e-8):
    """Return the Scale-Invariant Source-to-Noise Ratio (SI-SNR) in dB.

    Both signals are mean-removed. The estimate is then split into its
    projection onto the reference (the target part) and the remainder (the
    noise part).

    Args:
        ref_sig: numpy.ndarray, [T], reference signal.
        out_sig: numpy.ndarray, [T], estimated signal of the same length.
        eps: small constant guarding the divisions and the logarithm.

    Returns:
        SI-SNR in dB.
    """
    assert len(ref_sig) == len(out_sig)
    centered_ref = ref_sig - np.mean(ref_sig)
    centered_out = out_sig - np.mean(out_sig)
    ref_power = np.sum(centered_ref ** 2) + eps
    # Orthogonal projection of the estimate onto the reference direction.
    target = np.sum(centered_ref * centered_out) * centered_ref / ref_power
    residual = centered_out - target
    power_ratio = np.sum(target ** 2) / (np.sum(residual ** 2) + eps)
    return 10 * np.log(power_ratio + eps) / np.log(10.0)
def calc_sdr(estimation, origin, eps=1e-8):
    """Compute a scale-invariant SDR in dB for every row of a batch.

    The estimate is projected onto the reference (no mean removal). The
    projection is the "true" part; whatever remains is the residual.

    Args:
        estimation: numpy.ndarray, (batch, nsample), estimated signals.
        origin: numpy.ndarray, (batch, nsample), reference signals.
        eps: small constant keeping the projection and both logarithms
            finite for silent rows.

    Returns:
        numpy.ndarray of shape (batch,), one SDR value per row.
    """
    origin_power = np.sum(origin ** 2, axis=1, keepdims=True) + eps  # (batch, 1)
    scale = np.sum(origin * estimation, axis=1, keepdims=True) / origin_power  # (batch, 1)
    est_true = scale * origin  # (batch, nsample)
    est_res = estimation - est_true  # (batch, nsample)
    true_power = np.sum(est_true ** 2, axis=1)  # (batch,)
    res_power = np.sum(est_res ** 2, axis=1)  # (batch,)
    # Without eps, a silent estimate gives log10(0) = -inf on both sides,
    # which turns into nan.
    return 10 * np.log10(true_power + eps) - 10 * np.log10(res_power + eps)  # (batch,)
def compute_measures(se, s, j):
    """Compute SDR, SIR and SAR (dB) of one estimate against reference ``j``.

    The estimate is decomposed in three steps:
    1. Its projection onto reference ``j`` is the target.
    2. The least-squares fit of the remainder onto all references is the
       interference.
    3. Whatever is left is the artifacts.

    Args:
        se: numpy.ndarray, [nsampl], one estimated source.
        s: numpy.ndarray, [nsampl, nsrc], reference sources as columns.
        j: column index of the reference to score against.

    Returns:
        (SDR, SIR, SAR) in dB.
    """
    Rss = s.transpose().dot(s)  # Gram matrix of the references, [nsrc, nsrc]
    this_s = s[:, j]
    a = this_s.dot(se) / Rss[j, j]
    e_true = a * this_s
    e_res = se - e_true
    Sss = np.sum(e_true ** 2)
    Snn = np.sum(e_res ** 2)
    SDR = 10 * np.log10(Sss / Snn)
    Rsr = s.transpose().dot(e_res)
    # Solve Rss b = Rsr directly. This is cheaper and numerically more stable
    # than forming inv(Rss) explicitly.
    b = np.linalg.solve(Rss, Rsr)
    e_interf = s.dot(b)
    e_artif = e_res - e_interf
    SIR = 10 * np.log10(Sss / np.sum(e_interf ** 2))
    SAR = 10 * np.log10(Sss / np.sum(e_artif ** 2))
    return SDR, SIR, SAR
def GetSDR(se, s):
    """Score estimates against references, picking the best-SIR assignment.

    Args:
        se: numpy.ndarray, [nsrc, nsampl], estimated sources (one per row).
        s: numpy.ndarray, [nsrc, nsampl], reference sources (one per row).

    Returns:
        (SDR, SIR, SAR, per). The first three are [nsrc] arrays whose entry i
        scores estimate ``per[i]`` against reference i. ``per`` is the
        permutation (as a list) that maximizes the mean SIR.
    """
    # From here on, sources are columns ([nsampl, nsrc]) with the mean removed.
    est_cols = se.transpose()
    ref_cols = s.transpose()
    est_cols = est_cols - np.mean(est_cols, axis=0)
    ref_cols = ref_cols - np.mean(ref_cols, axis=0)
    nsampl, nsrc = est_cols.shape
    nsampl2, nsrc2 = ref_cols.shape
    assert nsrc2 == nsrc
    assert nsampl2 == nsampl

    # Score matrices indexed [estimate, reference].
    sdr_mat = np.zeros((nsrc, nsrc))
    sir_mat = np.zeros((nsrc, nsrc))
    sar_mat = np.zeros((nsrc, nsrc))
    for jest in range(nsrc):
        for jtrue in range(nsrc):
            scores = compute_measures(est_cols[:, jest], ref_cols, jtrue)
            sdr_mat[jest, jtrue], sir_mat[jest, jtrue], sar_mat[jest, jtrue] = scores

    # Candidate p pairs reference i with estimate p[i]; keep the candidate
    # with the highest mean SIR (the first one on ties).
    ref_idx = np.arange(nsrc)
    candidates = list(permutations(np.arange(nsrc)))
    mean_sir = np.array([np.mean(sir_mat[list(p), ref_idx]) for p in candidates])
    per = list(candidates[np.argmax(mean_sir)])
    return sdr_mat[per, ref_idx], sir_mat[per, ref_idx], sar_mat[per, ref_idx], per
EPS = 1e-8


def cal_si_snr_with_pit(source, estimate_source, source_lengths, verbose=False):
    """Calculate SI-SNR under the best source permutation (PIT).

    Args:
        source: torch.Tensor, [B, C, T], reference sources (B is batch size).
        estimate_source: torch.Tensor, [B, C, T], estimated sources.
        source_lengths: torch.Tensor with B elements, the number of valid
            samples per utterance, each in [0, T]. Samples beyond this length
            are treated as padding and masked out of the metric.
        verbose: if True, print the pairwise SI-SNR matrix.

    Returns:
        max_snr: [B, 1], SI-SNR of the best permutation, averaged over C.
        perms: [C!, C], every permutation of the source indices.
        max_snr_idx: [B], index into ``perms`` of the best permutation.
    """
    assert source.size() == estimate_source.size()
    B, C, T = source.size()
    lengths = source_lengths.view(-1, 1, 1).to(source.device)  # [B, 1, 1]
    num_samples = lengths.float()
    # Step 0. Zero out samples past each utterance's length ([B, 1, T]) so
    # padding affects neither the means nor the energies below.
    time_index = torch.arange(T, device=source.device).view(1, 1, T)
    mask = (time_index < lengths).to(source.dtype)
    source = source * mask
    estimate_source = estimate_source * mask
    # Step 1. Zero-mean norm, with means taken over the valid samples only.
    mean_target = torch.sum(source, dim=2, keepdim=True) / num_samples
    mean_estimate = torch.sum(estimate_source, dim=2, keepdim=True) / num_samples
    # Re-mask: subtracting the mean would otherwise refill the padding.
    zero_mean_target = (source - mean_target) * mask
    zero_mean_estimate = (estimate_source - mean_estimate) * mask
    # Step 2. SI-SNR with PIT; reshape so broadcasting forms all pairs.
    s_target = torch.unsqueeze(zero_mean_target, dim=1)  # [B, 1, C, T]
    s_estimate = torch.unsqueeze(zero_mean_estimate, dim=2)  # [B, C, 1, T]
    # s_target = <s', s>s / ||s||^2
    pair_wise_dot = torch.sum(s_estimate * s_target, dim=3, keepdim=True)  # [B, C, C, 1]
    s_target_energy = torch.sum(s_target ** 2, dim=3, keepdim=True) + EPS  # [B, 1, C, 1]
    pair_wise_proj = pair_wise_dot * s_target / s_target_energy  # [B, C, C, T]
    # e_noise = s' - s_target
    e_noise = s_estimate - pair_wise_proj  # [B, C, C, T]
    # SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2)
    pair_wise_si_snr = torch.sum(pair_wise_proj ** 2, dim=3) / (torch.sum(e_noise ** 2, dim=3) + EPS)
    pair_wise_si_snr = 10 * torch.log10(pair_wise_si_snr + EPS)  # [B, C, C]
    if verbose:
        print('sisnr:', pair_wise_si_snr)
    # All permutations [C!, C] and their one-hot form [C!, C, C].
    perms = source.new_tensor(list(permutations(range(C))), dtype=torch.long)
    index = torch.unsqueeze(perms, 2)
    perms_one_hot = source.new_zeros((*perms.size(), C)).scatter_(2, index, 1)
    # [B, C!] <- [B, C, C] einsum [C!, C, C]: summed SI-SNR of each permutation
    snr_set = torch.einsum('bij,pij->bp', [pair_wise_si_snr, perms_one_hot])
    max_snr_idx = torch.argmax(snr_set, dim=1)  # [B]
    max_snr, _ = torch.max(snr_set, dim=1, keepdim=True)
    max_snr /= C  # average over the C sources
    return max_snr, perms, max_snr_idx
def _sdr( y, z, SI=False):
if SI:
a = ((z*y).mean(-1) / (y*y).mean(-1)).unsqueeze(-1) * y
return 10*torch.log10( (a**2).mean(-1) / ((a-z)**2).mean(-1))
else:
return 10*torch.log10( (y*y).mean(-1) / ((y-z)**2).mean(-1))
def test(example_dir='./ex_18'):
    """Score one separated example with every metric defined in this script.

    Args:
        example_dir: directory holding mixture.wav, s1.wav, s2.wav,
            s1_estimate.wav and s2_estimate.wav (defaults to './ex_18').
    """
    def _read(name):
        # sf.read returns (data, samplerate); only the samples are used.
        return sf.read('{}/{}'.format(example_dir, name))[0]

    mix = _read('mixture.wav')
    source = np.stack([_read('s1.wav'), _read('s2.wav')], axis=0)
    estimate_source = np.stack([_read('s1_estimate.wav'), _read('s2_estimate.wav')], axis=0)
    # Baseline "estimate": the unprocessed mixture, once per source.
    mix_anchor = np.stack([mix] * len(source), axis=0)

    SDRi = cal_SDRi(source, estimate_source, mix)
    SISNRi = cal_SISNRi(source, estimate_source, mix)
    print('SDRi:{}'.format(SDRi))
    print('SISNRi:{}\n'.format(SISNRi))

    # calc_sdr's signature is (estimation, origin), so the references go second.
    sdr1 = calc_sdr(estimate_source, source)
    sdr2 = calc_sdr(mix_anchor, source)
    sdri = np.mean(sdr1 - sdr2)
    print('sdr1:{}'.format(sdr1))
    print('sdr2:{}'.format(sdr2))
    print('sdri:{}\n'.format(sdri))

    SDR, SIR, SAR, per = GetSDR(estimate_source, source)
    print('SDR:{}\nSIR:{}\nSAR:{}\nper:{}\n'.format(SDR, SIR, SAR, per))

    # Batch of one for the torch-based metrics: [1, C, T].
    source_t = torch.from_numpy(source[np.newaxis]).float()
    estimate_t = torch.from_numpy(estimate_source[np.newaxis]).float()
    source_lengths = torch.tensor([mix.shape[0]])
    max_snr, _, _ = cal_si_snr_with_pit(source_t, estimate_t, source_lengths)
    print('max_snr:{}\n'.format(max_snr))
    SISDR = _sdr(source_t, estimate_t, SI=True)
    print('SISDR: ', SISDR)


if __name__ == '__main__':
    test()
And the output:
# SDRi:7.892910056532607
# SISNRi:7.151290758024819
# sdr1:[6.68677879 6.25589817]
# sdr2:[-1.18234941 -0.17591568]
# sdri:7.150471030776565
# SDR:[6.68712455 6.25720926]
# SIR:[34.92616321 34.92229867]
# SAR:[6.69364393 6.26311903]
# per:[0, 1]
# sisnr: tensor([[[ 6.6871, -34.0962],
# [-34.1721, 6.2572]]])
# max_snr: tensor([[6.4722]])
# SISDR: tensor([[6.6868, 6.2559]])
# ex_18/metrics.json
# {
# "input_si_sdr": 0.028149127960205078,
# "input_sdr": 0.15109104033014964,
# "input_sir": 0.1510910403301708,
# "input_sar": 144.89122580687916,
# "input_stoi": 0.7178163832006375,
# "input_pesq": 1.599277138710022,
# "si_sdr": 19.083293914794922,
# "sdr": 19.376235432506704,
# "sir": 30.187015165321924,
# "sar": 19.759935974444744,
# "stoi": 0.9568062920227058,
# "pesq": 3.562618613243103,
# "mix_path": "/mnt/data/wham/wav8k/min/tt/mix_clean/050a050c_0.050237_442c020j_-0.050237.wav"
# }
Excuse me, I want to study your code, but I can't figure out the correct steps to run it. Could you add an explanation of the run steps to the README?
A declarative, efficient, and flexible JavaScript library for building user interfaces.
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
An Open Source Machine Learning Framework for Everyone
The Web framework for perfectionists with deadlines.
A PHP framework for web artisans
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
Something interesting about web. New door for the world.
A server is a program made to process requests and deliver data to clients.
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
Something interesting about visualization, use data art
Something interesting about game, make everyone happy.
We are working to build community through open source technology. NB: members must have two-factor auth.
Open source projects and samples from Microsoft.
Google ❤️ Open Source for everyone.
Alibaba Open Source for everyone
Data-Driven Documents codes.
China tencent open source team.