Сериализация моделей
Это очень важная тема, ведь часто нам нужно сохранить уже обученную модель, чтобы потом быстро её загрузить и использовать для предсказаний, не тратя время на повторное обучение.
Сериализация моделей позволяет сохранить всю информацию о модели (архитектуру, веса, состояние оптимизатора и т.д.) в файл, который потом можно легко загрузить обратно. Давайте разберём этот процесс на примере популярных фреймворков - PyTorch и Keras. Начнём с PyTorch. Допустим, у нас есть обученная модель model. Чтобы сохранить её, достаточно вызвать метод torch.save():
import torch
# Сохраняем модель
torch.save(model.state_dict(), 'model.pth')
Здесь model.state_dict() возвращает словарь с состоянием модели, который мы сохраняем в файл 'model.pth'. Теперь, чтобы загрузить модель, создадим экземпляр класса модели и загрузим в него сохранённое состояние:
# Создаём экземпляр модели
model = MyModel(*args, **kwargs)
# Загружаем состояние
model.load_state_dict(torch.load('model.pth'))
model.eval()
Вызов model.eval() переводит модель в режим инференса. Всё, модель готова к использованию! Теперь посмотрим, как сериализация работает в Keras. Здесь всё ещё проще:
from tensorflow import keras
# Сохраняем модель
model.save('model.h5')
# Загружаем модель
model = keras.models.load_model('model.h5')
Метод save() сохраняет модель в формате HDF5. При загрузке через load_model() восстанавливается полностью готовая к использованию модель. Keras также поддерживает сохранение только архитектуры или только весов модели:
# Сохраняем только архитектуру
with open('model_architecture.json', 'w') as f:
f.write(model.to_json())
# Сохраняем только веса
model.save_weights('model_weights.h5')
При загрузке нужно сначала создать модель с той же архитектурой и потом загрузить веса:
# Загружаем архитектуру
with open('model_architecture.json') as f:
model = keras.models.model_from_json(f.read())
# Загружаем веса
model.load_weights('model_weights.h5')
Разделение архитектуры и весов бывает полезно, когда нужно переиспользовать архитектуру с другим набором весов. Надеюсь, эта статья помогла вам разобраться, как сохранять и загружать модели с помощью сериализации в PyTorch и Keras. Это действительно удобный и важный инструмент в арсенале любого специалиста по машинному обучению. Пробуйте применять его в своих проектах и экспериментируйте! Спасибо за внимание! Если у вас остались вопросы - не стесняйтесь задавать их в комментариях. Желаю вам успехов в освоении сериализации моделей!