Model Pruning in Keras

Setup

!pip install -q tensorflow-model-optimization

import tempfile
import os

import tensorflow as tf
import numpy as np

from tensorflow_model_optimization.python.core.keras.compat import keras

%load_ext tensorboard

 

Train a model for MNIST without pruning

# Load MNIST dataset
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()

# Normalize the input image so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture.
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
  train_images,
  train_labels,
  epochs=4,
  validation_split=0.1,
)

 

Evaluate baseline test accuracy and save the model for later usage

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

_, keras_file = tempfile.mkstemp('.h5')
keras.models.save_model(model, keras_file, include_optimizer=False)
print('Saved baseline model to:', keras_file)

 

Fine-tune the pre-trained model with pruning

Start the model with 50% sparsity (50% zeros in weights) and end with 80% sparsity.

Dynamic pruning: with the PolynomialDecay schedule, the sparsity target is not a fixed number but ramps up as training steps advance. The model keeps more of its connections early on, and pruning intensifies as the network adapts, which preserves performance while still reducing model size and compute. By the end of the schedule the model reaches a high final sparsity, while the gradual ramp in the early phase avoids the information loss that aggressive up-front pruning can cause; letting the network adjust step by step typically yields a better final result.
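To make the schedule concrete, below is a minimal sketch of the decay curve, assuming PolynomialDecay's default exponent of 3 and the end_step of 844 computed in the next cell; it illustrates the formula only and is not tfmot's implementation.

# Sketch: how sparsity ramps from 50% to 80% under a cubic polynomial decay.
def sketch_sparsity(step, initial=0.50, final=0.80, begin=0, end=844, power=3):
  progress = min(max((step - begin) / float(end - begin), 0.0), 1.0)
  return final + (initial - final) * (1.0 - progress) ** power

for step in (0, 211, 422, 844):
  print(step, sketch_sparsity(step))  # rises fastest early, then flattens out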

import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude  # prunes the lowest-magnitude weights
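
# Tiny illustration (not tfmot's code) of what magnitude pruning means:
# zero out the smallest-|w| half of a weight vector for 50% sparsity.
w_demo = np.array([0.3, -0.05, 0.8, 0.01, -0.6, 0.02])
threshold_demo = np.quantile(np.abs(w_demo), 0.5)  # median magnitude
print(np.where(np.abs(w_demo) < threshold_demo, 0.0, w_demo))  # 3 of 6 zeroed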

# Compute end step to finish pruning after 2 epochs.
batch_size = 128
epochs = 2
validation_split = 0.1 # 10% of the training set will be used for the validation set.

num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs
# end_step sets when the PolynomialDecay schedule reaches its final sparsity, ensuring pruning completes within the intended training duration.
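# For MNIST with validation_split=0.1: 60,000 * 0.9 = 54,000 training images,
# ceil(54,000 / 128) = 422 steps per epoch, so end_step = 422 * 2 = 844.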

# Define model for pruning.
pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                               final_sparsity=0.80,
                                                               begin_step=0,  # start pruning from the very first training step
                                                               end_step=end_step)
}

model_for_pruning = prune_low_magnitude(model, **pruning_params)  # wrap the model with pruning

# `prune_low_magnitude` requires a recompile.
model_for_pruning.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model_for_pruning.summary()
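
Here prune_low_magnitude wraps every supported layer in the model. If you only want to prune selected layers, a common pattern is to clone the model and wrap just the layers you choose. A minimal sketch reusing the same pruning_params (the helper name apply_pruning_to_dense is ours):

# Sketch: prune only the Dense layer instead of wrapping the whole model.
def apply_pruning_to_dense(layer):
  if isinstance(layer, keras.layers.Dense):
    return prune_low_magnitude(layer, **pruning_params)
  return layer

model_dense_pruned = keras.models.clone_model(
    model, clone_function=apply_pruning_to_dense)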

 

Train and evaluate the model against the baseline

Fine-tune with pruning for two epochs.

logdir = tempfile.mkdtemp()

callbacks = [
  # Required: advances the pruning step and applies the masks each batch.
  tfmot.sparsity.keras.UpdatePruningStep(),
  # Optional: writes per-layer sparsity summaries for TensorBoard.
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

model_for_pruning.fit(train_images, train_labels,
                      batch_size=batch_size, epochs=epochs,
                      validation_split=validation_split,
                      callbacks=callbacks)

# For this example, there is minimal loss in test accuracy after pruning, compared to the baseline.

_, model_for_pruning_accuracy = model_for_pruning.evaluate(
   test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy) 
print('Pruned test accuracy:', model_for_pruning_accuracy)

# The logs show the progression of sparsity on a per-layer basis.

#docs_infra: no_execute
%tensorboard --logdir={logdir}
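
If you prefer to check sparsity without TensorBoard, a quick alternative is to count the zero entries in each kernel directly; a minimal sketch (the 'kernel' name filter skips the wrapper's mask and threshold variables):

# Sketch: report the fraction of zeroed weights per kernel.
for weight in model_for_pruning.weights:
  if 'kernel' in weight.name:
    values = weight.numpy()
    print('%s: %.1f%% zeros' % (weight.name, 100.0 * np.mean(values == 0)))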

 

Create 3x smaller models

So far we have applied only standard pruning. In practice, two further steps shrink the model:

  • tfmot.sparsity.keras.strip_pruning removes every pruning-only auxiliary variable (e.g., the pruning masks); these are needed during training but not for inference, and they inflate the model size.
  • Pruning sets many weights to zero, but the serialized weight matrices keep the same shape as before pruning, just with many zero entries. A standard compression algorithm such as gzip exploits this redundancy, effectively storing only the non-zero information and further shrinking the model file.
# First, create a compressible model for TensorFlow.
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

_, pruned_keras_file = tempfile.mkstemp('.h5')
keras.models.save_model(model_for_export, pruned_keras_file, include_optimizer=False)
print('Saved pruned Keras model to:', pruned_keras_file)

# Then, create a compressible model for TFLite.
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
pruned_tflite_model = converter.convert()

_, pruned_tflite_file = tempfile.mkstemp('.tflite')

with open(pruned_tflite_file, 'wb') as f:
  f.write(pruned_tflite_model)

print('Saved pruned TFLite model to:', pruned_tflite_file)

# Define a helper function to actually compress the models via gzip and measure the zipped size.

def get_gzipped_model_size(file):
  # Returns size of gzipped model, in bytes.
  import os
  import zipfile

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)

print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size(keras_file)))
print("Size of gzipped pruned Keras model: %.2f bytes" % (get_gzipped_model_size(pruned_keras_file)))
print("Size of gzipped pruned TFlite model: %.2f bytes" % (get_gzipped_model_size(pruned_tflite_file)))

 

Create a 10x smaller model from combining pruning and quantization

Beyond pruning, we can compress the model further with post-training quantization (PTQ).
The single line converter.optimizations = [tf.lite.Optimize.DEFAULT] enables quantization; with no representative dataset supplied, the converter applies dynamic-range quantization, storing weights as 8-bit integers.

You can apply post-training quantization to the pruned model for additional benefits.

converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_and_pruned_tflite_model = converter.convert()

_, quantized_and_pruned_tflite_file = tempfile.mkstemp('.tflite')

with open(quantized_and_pruned_tflite_file, 'wb') as f:
  f.write(quantized_and_pruned_tflite_model)

print('Saved quantized and pruned TFLite model to:', quantized_and_pruned_tflite_file)

print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size(keras_file)))
print("Size of gzipped pruned and quantized TFlite model: %.2f bytes" % (get_gzipped_model_size(quantized_and_pruned_tflite_file)))

 

See persistence of accuracy from TF to TFLite

Define a helper function to evaluate the TF Lite model on the test dataset.

import numpy as np

def evaluate_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print('Evaluated on {n} results so far.'.format(n=i))
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy


# You evaluate the pruned and quantized model and see that the accuracy from TensorFlow persists to the TFLite backend.

interpreter = tf.lite.Interpreter(model_content=quantized_and_pruned_tflite_model)
interpreter.allocate_tensors()

test_accuracy = evaluate_model(interpreter)

print('Pruned and quantized TFLite test_accuracy:', test_accuracy)
print('Pruned TF test accuracy:', model_for_pruning_accuracy)
