Сериализация моделей

Привет! Сегодня я хочу поделиться с вами своим опытом сохранения и загрузки моделей машинного обучения с помощью сериализации.

Сериализация моделей
Краткое содержание

Это очень важная тема, ведь часто нам нужно сохранить уже обученную модель, чтобы потом быстро её загрузить и использовать для предсказаний, не тратя время на повторное обучение.

Сериализация моделей позволяет сохранить всю информацию о модели (архитектуру, веса, состояние оптимизатора и т.д.) в файл, который потом можно легко загрузить обратно. Давайте разберём этот процесс на примере популярных фреймворков - PyTorch и Keras. Начнём с PyTorch. Допустим, у нас есть обученная модель model. Чтобы сохранить её, достаточно вызвать метод torch.save():

import torch

# Сохраняем модель
torch.save(model.state_dict(), 'model.pth')

Здесь model.state_dict() возвращает словарь с состоянием модели, который мы сохраняем в файл 'model.pth'. Теперь, чтобы загрузить модель, создадим экземпляр класса модели и загрузим в него сохранённое состояние:

# Создаём экземпляр модели
model = MyModel(*args, **kwargs)

# Загружаем состояние
model.load_state_dict(torch.load('model.pth'))
model.eval()

Вызов model.eval() переводит модель в режим инференса. Всё, модель готова к использованию! Теперь посмотрим, как сериализация работает в Keras. Здесь всё ещё проще:

from tensorflow import keras

# Сохраняем модель
model.save('model.h5')

# Загружаем модель
model = keras.models.load_model('model.h5')

Метод save() сохраняет модель в формате HDF5. При загрузке через load_model() восстанавливается полностью готовая к использованию модель. Keras также поддерживает сохранение только архитектуры или только весов модели:

# Сохраняем только архитектуру
with open('model_architecture.json', 'w') as f:
    f.write(model.to_json())

# Сохраняем только веса 
model.save_weights('model_weights.h5')

При загрузке нужно сначала создать модель с той же архитектурой и потом загрузить веса:

# Загружаем архитектуру
with open('model_architecture.json') as f:
    model = keras.models.model_from_json(f.read())

# Загружаем веса
model.load_weights('model_weights.h5')

Разделение архитектуры и весов бывает полезно, когда нужно переиспользовать архитектуру с другим набором весов. Надеюсь, эта статья помогла вам разобраться, как сохранять и загружать модели с помощью сериализации в PyTorch и Keras. Это действительно удобный и важный инструмент в арсенале любого специалиста по машинному обучению. Пробуйте применять его в своих проектах и экспериментируйте! Спасибо за внимание! Если у вас остались вопросы - не стесняйтесь задавать их в комментариях. Желаю вам успехов в освоении сериализации моделей!

Подписаться на новости Nerd IT

Не пропустите последние выпуски. Зарегистрируйтесь сейчас, чтобы получить полный доступ к статьям.
jamie@example.com
Подписаться