Tried freezing the no-top QuickNet models and training a linear classifier on top of them, in order to classify images from the Imagenette dataset (10 easy classes from ImageNet).
Because the pretrained zoo models are trained on ImageNet, a superset of this dataset, I expected the pretrained embedders to perform very well, but they failed to reach even 50% accuracy.
However, when I manually cut the full models instead, the embedders work as expected and easily reach 95%, hinting that the problem lies in the no-top pretrained weights.
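To make the suspicion concrete, here is a quick diagnostic sketch (not part of the repro script below) that builds the embedder both ways and compares the weights pairwise. It assumes the no-top graph ends at the same "activation" layer, so the two weight lists line up; if the no-top checkpoint were correct, every pair should match.

import numpy as np
from tensorflow import keras
from larq_zoo.sota import QuickNet

# Path 1: the pretrained no-top model as shipped in the zoo.
no_top = QuickNet(input_shape=(224, 224, 3), include_top=False, weights="imagenet")
# Path 2: the full pretrained model, cut at the last ReLU before global pooling.
full = QuickNet(input_shape=(224, 224, 3), include_top=True, weights="imagenet")
cut = keras.Model(inputs=full.inputs, outputs=full.get_layer("activation").output)

for w_no_top, w_cut in zip(no_top.get_weights(), cut.get_weights()):
    same = w_no_top.shape == w_cut.shape and np.allclose(w_no_top, w_cut)
    print(w_no_top.shape, w_cut.shape, same)

The full reproduction script follows.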
import tensorflow as tf
from tensorflow import keras
import tensorflow_datasets as tfds
from larq_zoo.training.data import preprocess_image_bytes
from larq_zoo.sota import QuickNet, QuickNetLarge, QuickNetXL
from zookeeper import cli, task, Field
from typing import Callable, Tuple, Optional
class EmbedderWrapperModel(keras.Model):
    """Wraps a (frozen) zoo basenet with a linear classifier: BN + global pool + dense."""

    def __init__(self, zoo_class: Callable[..., keras.Model],
                 input_shape: int, num_classes: int, dynamic=False,
                 finetune_basenet=True, pretrained_basenet=True,
                 cut_layer_name: Optional[str] = None):
        super().__init__(dynamic=dynamic)
        self.basenet = self._get_basenet(zoo_class, input_shape, finetune_basenet,
                                         pretrained_basenet, cut_layer_name)
        global_pool_shape = self.basenet.output_shape[1], self.basenet.output_shape[2]
        self.batch_norm = keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5)
        self.global_pool = keras.layers.AveragePooling2D(pool_size=global_pool_shape)
        self.flatten = keras.layers.Flatten()
        self.dense_softmax = keras.layers.Dense(num_classes, activation=tf.nn.softmax)

    def _get_basenet(self, zoo_class: Callable[..., keras.Model], input_shape: int,
                     finetune_basenet: bool, pretrained_basenet: bool,
                     cut_layer_name: Optional[str]) -> keras.Model:
        weights = "imagenet" if pretrained_basenet else None
        if not cut_layer_name:
            # Path 1: the pretrained no-top model as shipped in the zoo.
            basenet = zoo_class(input_shape=(input_shape, input_shape, 3),
                                include_top=False, weights=weights)
        else:
            # Path 2: load the full pretrained model and cut it manually.
            full_zoo_model = zoo_class(input_shape=(input_shape, input_shape, 3),
                                       include_top=True, weights=weights)
            inputs = full_zoo_model.inputs
            outputs = full_zoo_model.get_layer(cut_layer_name).output
            basenet = keras.Model(inputs=inputs, outputs=outputs)
        basenet.trainable = finetune_basenet
        return basenet

    def call(self, inputs, training=False, mask=None):
        x = self.basenet(inputs, training=training)
        x = self.batch_norm(x, training=training)
        x = self.global_pool(x)
        x = self.flatten(x)
        return self.dense_softmax(x)
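For reference, the wrapper can be instantiated and inspected like this (a minimal sketch; shapes assume QuickNet at 224×224, matching the constants used below):

embedder = EmbedderWrapperModel(QuickNet, 224, 10,
                                finetune_basenet=False, pretrained_basenet=True)
embedder.build((None, 224, 224, 3))
embedder.summary()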
def wrap_preprocessing(preprocessing: Callable, training=False) -> Callable:
    # Keep the (image, label) structure while applying the zoo preprocessing to the image.
    return lambda x, y: (preprocessing(x, training), y)


def get_imagenette_dataset(batch_size: int, preprocessing: Callable,
                           parallel=True) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
    # Skip TFDS image decoding so the zoo preprocessing can decode the raw bytes itself.
    decoders = {"image": tfds.decode.SkipDecoding()}
    total_dataset = tfds.load("imagenette", split=None, shuffle_files=True,
                              as_supervised=True, decoders=decoders)
    train, test = total_dataset["train"], total_dataset["validation"]
    parallelism = tf.data.experimental.AUTOTUNE if parallel else None
    train_dataset = (
        train.cache()
        .shuffle(10 * batch_size, reshuffle_each_iteration=True)
        .map(wrap_preprocessing(preprocessing, training=True), num_parallel_calls=parallelism)
        .batch(batch_size)
    )
    test_dataset = (
        test.cache()
        .map(wrap_preprocessing(preprocessing), num_parallel_calls=parallelism)
        .batch(batch_size)
    )
    return train_dataset, test_dataset
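A quick shape sanity check for the pipeline (a sketch; I assume preprocess_image_bytes decodes and resizes to 224×224, matching INPUT_SHAPE below):

train_ds, _ = get_imagenette_dataset(batch_size=8, preprocessing=preprocess_image_bytes)
images, labels = next(iter(train_ds))
print(images.shape, labels.shape)  # expect (8, 224, 224, 3) and (8,)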
class BugTest:
    INPUT_SHAPE = 224
    CLASSES_NUM = 10
    EPOCHS = 3
    BATCH_SIZE = 256
    LEARNING_RATE = 1e-2
    LR_DECAY = 0.1  # unused in this minimal repro
    DECAY_EVERY = 30  # unused in this minimal repro
    FINETUNE_BASENET = False
    PRETRAINED_BASENET = True

    def test_bug(self, cut_full_model: bool, model_class: Callable):
        # `activation` is the last ReLU layer before global pooling
        cut_layer_name = "activation" if cut_full_model else None
        model = EmbedderWrapperModel(model_class, self.INPUT_SHAPE, self.CLASSES_NUM,
                                     finetune_basenet=self.FINETUNE_BASENET,
                                     pretrained_basenet=self.PRETRAINED_BASENET,
                                     cut_layer_name=cut_layer_name)
        train_dataset, test_dataset = get_imagenette_dataset(self.BATCH_SIZE, preprocess_image_bytes)
        optimizer = keras.optimizers.Adam(self.LEARNING_RATE)
        loss = keras.losses.SparseCategoricalCrossentropy()
        metrics = [keras.metrics.SparseCategoricalAccuracy()]
        model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
        model.fit(train_dataset, epochs=self.EPOCHS, validation_data=test_dataset)
@task
class QuickNetBugTest(BugTest):
    cut_full_model: bool = Field(False)

    def run(self):
        self.test_bug(self.cut_full_model, QuickNet)


@task
class QuickNetLargeBugTest(BugTest):
    cut_full_model: bool = Field(False)

    def run(self):
        self.test_bug(self.cut_full_model, QuickNetLarge)


@task
class QuickNetXLBugTest(BugTest):
    cut_full_model: bool = Field(False)

    def run(self):
        self.test_bug(self.cut_full_model, QuickNetXL)


if __name__ == "__main__":
    cli()
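Each task can then be run through zookeeper's CLI, assuming the script is saved as bug_test.py (a hypothetical filename) and the usual name=value field overrides:

python bug_test.py QuickNetBugTest                       # no-top weights: stalls below 50%
python bug_test.py QuickNetBugTest cut_full_model=True   # manually cut model: reaches ~95%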
Expected the pretrained no-top models and the cut pretrained full models to perform the same; instead, got the following discrepancy: