test_CQT_1992_v2
I've been trying to track down which function might be the one breaking these tests, and I'm quite confused. I rewrote the test a bit so that only the complex CQT ground truth is loaded, and the magnitude and phase ground truths are derived from it.
Now I'm even more confused:
import os

import numpy as np
import pytest
import torch
from scipy.signal import chirp

from nnAudio.features import CQT1992v2

# dir_path (folder containing ground-truths/) and device_args (devices to test)
# are defined at module level elsewhere in the test file.


@pytest.mark.parametrize("device", [*device_args])
def test_cqt_1992_v2_linear(device):
    # Linear sweep case
    fs = 44100
    t = 1
    f0 = 55
    f1 = 22050
    s = np.linspace(0, t, fs * t)
    x = chirp(s, f0, 1, f1, method="linear")
    x = x.astype(dtype=np.float32)

    # Get the complex ground truth for CQT1992v2; magnitude and phase can be derived from it
    complex_ground_truth = np.load(
        os.path.join(
            dir_path, "ground-truths/linear-sweep-cqt-1992-complex-ground-truth.npy"
        )
    )
    ground_truth_complex_real = complex_ground_truth[..., 0]
    ground_truth_complex_img = complex_ground_truth[..., 1]
    magnitude_ground_truth = np.log(
        np.sqrt(
            np.power(ground_truth_complex_real, 2) + np.power(ground_truth_complex_img, 2)
        )
        + 1e-5
    )
    phase_ground_truth_atan2 = np.arctan2(ground_truth_complex_img, ground_truth_complex_real)
    phase_ground_truth_real = np.cos(phase_ground_truth_atan2)
    phase_ground_truth_img = np.sin(phase_ground_truth_atan2)
    phase_ground_truth = np.stack([phase_ground_truth_real, phase_ground_truth_img], axis=-1)

    # Magnitude
    cqt = CQT1992v2(
        sr=fs, fmin=55, output_format="Magnitude", n_bins=207, bins_per_octave=24
    ).to(device)
    X = cqt(torch.tensor(x, device=device).unsqueeze(0))
    X = torch.log(X + 1e-5)
    assert np.allclose(X.cpu(), magnitude_ground_truth, rtol=1e-3, atol=1e-3)

    # Complex
    cqt = CQT1992v2(
        sr=fs, fmin=55, output_format="Complex", n_bins=207, bins_per_octave=24
    ).to(device)
    X = cqt(torch.tensor(x, device=device).unsqueeze(0))
    assert np.allclose(X.cpu(), complex_ground_truth, rtol=1e-3, atol=1e-3)
    # Phase recomputed from the complex output, once with NumPy and once with torch
    assert np.allclose(
        np.arctan2(X[..., 1].cpu(), X[..., 0].cpu()),
        phase_ground_truth_atan2, rtol=1e-3, atol=1e-3,
    )
    assert np.allclose(
        torch.atan2(X[..., 1].cpu(), X[..., 0].cpu()),
        phase_ground_truth_atan2, rtol=1e-3, atol=1e-3,
    )

    # Phase
    cqt = CQT1992v2(
        sr=fs, fmin=55, output_format="Phase", n_bins=207, bins_per_octave=24
    ).to(device)
    X = cqt(torch.tensor(x, device=device).unsqueeze(0))
    assert np.allclose(X.cpu(), phase_ground_truth, rtol=1e-3, atol=1e-3)
This is quite surprising, as the complex CQT seems to be computed correctly. Am I doing something wrong when calculating the phase?
How is it possible that not even np.arctan2 gets the same result? (See the second assertion in the # Complex section.) What surprises me most is that it fails only on the CPU: I'm testing this on a MacBook with MPS, and on MPS the test passes perfectly.
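One plausible reason (a hypothesis, not something confirmed in this thread) why np.arctan2 disagrees even though the complex values pass np.allclose is the branch cut of atan2 at ±π. A complex value lying close to the negative real axis gets an angle near +π if its tiny imaginary part is positive and near -π if it is negative, so two backends that round slightly differently can produce angles about 2π apart. A minimal NumPy sketch with made-up values:

```python
import numpy as np

# Two made-up complex values, stored as (real, imag) like the CQT "Complex" output.
# They differ by only 2e-7, but lie on opposite sides of the negative real axis.
a = np.array([-1.0, 1e-7])
b = np.array([-1.0, -1e-7])

phase_a = np.arctan2(a[1], a[0])  # just below +pi
phase_b = np.arctan2(b[1], b[0])  # just above -pi

print(np.allclose(a, b, rtol=1e-3, atol=1e-3))              # True
print(np.allclose(phase_a, phase_b, rtol=1e-3, atol=1e-3))  # False: the angles differ by ~2*pi

# Wrapping the difference back into [-pi, pi] removes the artificial 2*pi jump
wrapped = np.angle(np.exp(1j * (phase_a - phase_b)))
print(abs(wrapped) < 1e-3)  # True
```

This only affects raw angle comparisons like the two atan2 assertions in the # Complex section; the (cos, sin) pair used for the Phase output is continuous across the cut, so its failures need a different explanation.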
from nnaudio.
Thank you @migperfer for the investigation.
May I know which PyTorch and NumPy versions you are using?
I tried torch 1.8.1 and numpy 1.19.5, and the CQT results all seem fine now, so I believe this is related to the notorious floating-point error. Can you try loosening the tolerances a bit, for example rtol=1e-2 and atol=1e-2?
I have attached the full unit test report below:
unit_test_report.txt
There are still problems with CFP, STFT, and VQT. I am investigating the issues now.
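Before loosening tolerances, it can help to see how far off the failing elements actually are. np.allclose(actual, expected) passes when, element-wise, |actual - expected| <= atol + rtol * |expected|, so the reference array is the second argument. The helper below is a hypothetical debugging aid (not part of nnAudio or the test suite) that applies the same rule and reports how many elements fail, using made-up values:

```python
import numpy as np


def report_mismatch(actual, expected, rtol=1e-3, atol=1e-3):
    """Hypothetical debugging helper: summarize where np.allclose would fail.

    np.allclose(actual, expected) passes when, element-wise,
    |actual - expected| <= atol + rtol * |expected| (for finite values).
    """
    actual = np.asarray(actual, dtype=np.float64)
    expected = np.asarray(expected, dtype=np.float64)
    diff = np.abs(actual - expected)
    failing = diff > atol + rtol * np.abs(expected)
    return {
        "max_abs_diff": float(diff.max()),
        "n_failing": int(failing.sum()),
        "fraction_failing": float(failing.mean()),
    }


# Made-up values: one element is off by 0.02
expected = np.array([1.0, 2.0, 3.0])
actual = np.array([1.0005, 2.0, 3.02])

print(report_mismatch(actual, expected)["n_failing"])              # 1
print(report_mismatch(actual, expected, 1e-2, 1e-2)["n_failing"])  # 0
```

In the test, something like report_mismatch(X.cpu(), phase_ground_truth) would show whether a failure comes from a handful of outlier bins or from a small drift everywhere; only the latter is really a tolerance problem.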
I managed to resolve all the failing test cases. It turns out they were all due to floating-point error: I made the conditions less strict, and all tests pass now.
They are working indeed!
All tests are working for me except the CQT ones.
I think there could be an issue specific to MacBooks: I'm using one, and the phase is still wrong even with tolerances of 1e-1.
I see that #126 doesn't change the CQT at all and still passes on GitHub Actions (I imagine those run on Linux), so I'd say the MacBook case is a corner case.
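If the remaining Phase failures are concentrated in CQT bins with almost no energy (an assumption worth checking, not something established in this thread), no tolerance will fix them: the angle of a near-zero complex value can rotate arbitrarily under tiny rounding differences. The sketch below is a hypothetical helper that compares (cos, sin) phase pairs only where the linear magnitude is at least min_mag; the threshold 1e-3 is an arbitrary assumption, and the linear magnitude could come from np.hypot(ground_truth_complex_real, ground_truth_complex_img):

```python
import numpy as np


def phase_close(phase_actual, phase_expected, magnitude, min_mag=1e-3, atol=1e-3):
    """Hypothetical helper: compare (cos, sin) phase pairs only where magnitude >= min_mag.

    phase_actual and phase_expected have shape (..., 2) holding (cos(phi), sin(phi)),
    like the Phase output in the test; magnitude has shape (...) and is linear, not log.
    """
    mask = magnitude >= min_mag
    ca, sa = phase_actual[..., 0], phase_actual[..., 1]
    ce, se = phase_expected[..., 0], phase_expected[..., 1]
    # atan2(sin(a - e), cos(a - e)) is the angle difference wrapped to [-pi, pi]
    diff = np.arctan2(sa * ce - ca * se, ca * ce + sa * se)
    return bool(np.all(np.abs(diff[mask]) <= atol))


def to_pair(phi):
    # Build the (cos, sin) representation used by the Phase output
    return np.stack([np.cos(phi), np.sin(phi)], axis=-1)


# Made-up bins: the second one has almost no energy and its phase flips completely
magnitude = np.array([1.0, 1e-8])
phi_expected = np.array([0.3, 0.0])
phi_actual = np.array([0.3 + 1e-5, np.pi])

print(np.allclose(to_pair(phi_actual), to_pair(phi_expected), atol=1e-1))  # False
print(phase_close(to_pair(phi_actual), to_pair(phi_expected), magnitude))  # True
```

Masking low-energy bins is a design choice for the test, not a fix for the backend: it asserts that the phase is correct wherever it is numerically meaningful.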
Are you using a MacBook with an M1/M2 processor? I'm not sure whether that introduces extra problems.
Well, I would say floating-point error is a real headache.
An M2 indeed. And yes, it is a headache 😂