I am trying to reimplement MAML myself, but I found that implementing batch norm is tricky. Since the statistics of the data can change across inner-loop iterations as SGD updates the weights, it is hard to decide how beta and gamma (gamma is not used in this case because of the ReLU nonlinearity) should be updated. Furthermore, it is unclear which means and variances should be tracked for validation. Should the overall mean and variance be tracked across tasks and inner iterations? Or should the statistics be measured before the model is tuned to each task?
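To make the two candidate tracking strategies concrete, here is a minimal NumPy sketch (function name, momentum value, and shapes are my own, not from the MAML code):

```python
import numpy as np

def update_moving_stats(moving_mean, moving_var, x, momentum=0.9):
    """One exponential-moving-average update of batch-norm statistics
    from a batch x of shape (batch, features)."""
    moving_mean = momentum * moving_mean + (1 - momentum) * x.mean(axis=0)
    moving_var = momentum * moving_var + (1 - momentum) * x.var(axis=0)
    return moving_mean, moving_var

# Strategy A: call this only on the forward pass of the meta-initialization,
# i.e. track statistics *before* the model is tuned to each task.
# Strategy B: also call it once per inner-loop step on every task,
# i.e. track overall statistics across tasks and inner iterations.
mm, mv = update_moving_stats(np.zeros(3), np.ones(3), np.ones((4, 3)))
```

Either placement changes what "validation statistics" means, which is exactly the ambiguity above.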
I tried some of the possible strategies, but the results did not match the reported performance. The best performance was observed without batch norm, and even that was still only about 95% on the Omniglot dataset for the 5-way, 1-shot problem.
With this in mind, I looked into the code and found that the moving means and variances are not saved and restored. Inspecting the restored variables confirms this:
Restoring model weights from logs/omniglot5way//cls_5.mbs_32.ubs_1.numstep1.updatelr0.4batchnorm/model2000
[array([-0.01856588, -0.07139841, 0.00564138, 0.01990231, -0.00762643,
-0.05611767, 0.07031356, -0.03621985, 0.05920702, 0.12788662,
0.04555263, -0.06217157, -0.07977205, 0.01632672, -0.03578645,
0.10676555, -0.04455299, -0.0573478 , 0.52247691, -0.05695038,
0.14302482, -0.07892933, -0.02123305, 0.01870824, 0.01471483,
-0.06067625, 0.097821 , -0.05786318, 0.03801388, -0.04843186,
-0.01786073, 0.0293963 , 0.56441385, 0.07509601, 0.11491237,
0.01052142, 0.23142786, 0.03433308, -0.05783347, -0.0444839 ,
0.02227049, -0.02804896, -0.04594825, 0.05347209, -0.0399643 ,
0.02923759, 0.1299762 , -0.02817831, -0.0735756 , -0.0284342 ,
-0.04498725, -0.05203079, -0.04267518, -0.03341504, -0.05648317,
-0.02747083, -0.03525382, 0.34740165, -0.00822794, 0.03952603,
0.03410957, 0.29954502, -0.01362322, -0.04790628], dtype=float32), array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32), array([ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)]
It looks like only the beta variables are updated (the moving mean is all zeros and the moving variance is all ones), so it seems the batch norm layers effectively only shift the activations.
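For reference, the behaviour consistent with those untouched moving averages is batch norm that always normalizes with the current batch's statistics (the moving mean and variance never leave their 0/1 initialization). A minimal NumPy sketch, with names and shapes of my own choosing:

```python
import numpy as np

def batch_norm_batch_stats(x, beta, eps=1e-5):
    """Batch norm using only the current batch's statistics (no moving
    averages), with gamma unused: normalize to zero mean / unit variance,
    then shift by the learned beta. x: (batch, features); beta: (features,)."""
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    x_hat = (x - mean) / np.sqrt(var + eps)
    return x_hat + beta

# After normalization, the per-feature mean equals beta and the variance is ~1.
x = np.random.randn(8, 4) * 3.0 + 2.0
beta = np.array([0.5, -0.5, 0.0, 1.0])
out = batch_norm_batch_stats(x, beta)
```

Under this scheme the shifting-only observation makes sense: with ReLU networks gamma is dropped, so beta is the only learnable batch norm parameter left.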