scikit-learn: Guardar y restaurar modelos

s

En muchas ocasiones, mientras se trabaja con el scikit-learn biblioteca, deberá guardar sus modelos de predicción en un archivo y luego restaurarlos para poder reutilizar su trabajo anterior para: probar su modelo en nuevos datos, comparar múltiples modelos o cualquier otra cosa. Este procedimiento de guardado también se conoce como serialización de objetos: representa un objeto con un flujo de bytes para almacenarlo en el disco, enviarlo a través de una red o guardarlo en una base de datos, mientras que el procedimiento de restauración se conoce como deserialización. En este artículo, analizamos tres formas posibles de hacer esto en Python y scikit-learn, cada una presentada con sus pros y sus contras.

Herramientas para guardar y restaurar modelos

La primera herramienta que describimos es Pepinillo, la herramienta estándar de Python para la (des) serialización de objetos. Luego, miramos el Joblib biblioteca que ofrece (des) serialización fácil de objetos que contienen grandes matrices de datos, y finalmente presentamos un enfoque manual para guardar y restaurar objetos hacia / desde JSON (Notación de objetos JavaScript). Ninguno de estos enfoques representa una solución óptima, pero se debe elegir el ajuste correcto de acuerdo con las necesidades de su proyecto.

Inicialización del modelo

Inicialmente, creemos un modelo de scikit-learn. En nuestro ejemplo usaremos un Regresión logística modelo y el Conjunto de datos de iris. Vamos a importar las bibliotecas necesarias, cargar los datos y dividirlos en conjuntos de prueba y entrenamiento.

from sklearn.linear_model import LogisticRegression
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Load and split data
data = load_iris()
Xtrain, Xtest, Ytrain, Ytest = train_test_split(data.data, data.target, test_size=0.3, random_state=4)

Ahora creemos el modelo con algunos parámetros no predeterminados y ajustémoslo a los datos de entrenamiento. Suponemos que ha encontrado previamente los parámetros óptimos del modelo, es decir, los que producen la mayor precisión estimada.

# Create a model
model = LogisticRegression(C=0.1, 
                           max_iter=20, 
                           fit_intercept=True, 
                           n_jobs=3, 
                           solver="liblinear")
model.fit(Xtrain, Ytrain)

Y nuestro modelo resultante:

LogisticRegression(C=0.1, class_weight=None, dual=False, fit_intercept=True,
    intercept_scaling=1, max_iter=20, multi_class="ovr", n_jobs=3,
    penalty='l2', random_state=None, solver="liblinear", tol=0.0001,
    verbose=0, warm_start=False)

Utilizando la fit método, el modelo ha aprendido sus coeficientes que se almacenan en model.coef_. El objetivo es guardar los parámetros y coeficientes del modelo en un archivo, por lo que no es necesario repetir los pasos de entrenamiento del modelo y optimización de parámetros nuevamente con datos nuevos.

Módulo de pepinillos

En las siguientes líneas de código, el modelo que creamos en el paso anterior se guarda en un archivo y luego se carga como un nuevo objeto llamado pickled_model. A continuación, el modelo cargado se utiliza para calcular la puntuación de precisión y predecir los resultados de los nuevos datos no vistos (de prueba).

import pickle

#
# Create your model here (same as above)
#

# Save to file in the current working directory
pkl_filename = "pickle_model.pkl"
with open(pkl_filename, 'wb') as file:
    pickle.dump(model, file)

# Load from file
with open(pkl_filename, 'rb') as file:
    pickle_model = pickle.load(file)
    
# Calculate the accuracy score and predict target values
score = pickle_model.score(Xtest, Ytest)
print("Test score: {0:.2f} %".format(100 * score))
Ypredict = pickle_model.predict(Xtest)

La ejecución de este código debería generar su puntuación y guardar el modelo a través de Pickle:

$ python save_model_pickle.py
Test score: 91.11 %

Lo mejor de usar Pickle para guardar y restaurar nuestros modelos de aprendizaje es que es rápido: puede hacerlo en dos líneas de código. Es útil si ha optimizado los parámetros del modelo en los datos de entrenamiento, por lo que no necesita repetir este paso nuevamente. De todos modos, no guarda los resultados de la prueba ni ningún dato. Aún así, puede hacer esto guardando una tupla o una lista de varios objetos (y recuerde qué objeto va a dónde), de la siguiente manera:

tuple_objects = (model, Xtrain, Ytrain, score)

# Save tuple
pickle.dump(tuple_objects, open("tuple_model.pkl", 'wb'))

# Restore tuple
pickled_model, pickled_Xtrain, pickled_Ytrain, pickled_score = pickle.load(open("tuple_model.pkl", 'rb'))

Módulo Joblib

La biblioteca Joblib está destinada a ser un reemplazo de Pickle, para objetos que contienen datos grandes. Repetiremos el procedimiento de guardar y restaurar como con Pickle.

from sklearn.externals import joblib

# Save to file in the current working directory
joblib_file = "joblib_model.pkl"
joblib.dump(model, joblib_file)

# Load from file
joblib_model = joblib.load(joblib_file)

# Calculate the accuracy and predictions
score = joblib_model.score(Xtest, Ytest)
print("Test score: {0:.2f} %".format(100 * score))
Ypredict = pickle_model.predict(Xtest)
$ python save_model_joblib.py
Test score: 91.11 %

Como se ve en el ejemplo, la biblioteca Joblib ofrece un flujo de trabajo un poco más simple en comparación con Pickle. Si bien Pickle requiere que se pase un objeto de archivo como argumento, Joblib funciona tanto con objetos de archivo como con nombres de archivo de cadena. En caso de que su modelo contenga grandes conjuntos de datos, cada conjunto se almacenará en un archivo separado, pero el procedimiento de guardar y restaurar seguirá siendo el mismo. Joblib también permite diferentes métodos de compresión, como ‘zlib’, ‘gzip’, ‘bz2’ y diferentes niveles de compresión.

Guardar y restaurar manualmente a JSON

Dependiendo de su proyecto, muchas veces encontrará Pickle y Joblib como soluciones inadecuadas. Algunas de estas razones se analizan más adelante en la sección Problemas de compatibilidad. De todos modos, siempre que desee tener un control total sobre el proceso de guardar y restaurar, la mejor manera es crear sus propias funciones manualmente.

A continuación, se muestra un ejemplo de cómo guardar y restaurar objetos manualmente mediante JSON. Este enfoque nos permite seleccionar los datos que deben guardarse, como los parámetros del modelo, los coeficientes, los datos de entrenamiento y cualquier otra cosa que necesitemos.

Como queremos guardar todos estos datos en un solo objeto, una forma posible de hacerlo es crear una nueva clase que herede de la clase modelo, que en nuestro ejemplo es LogisticRegression. La nueva clase, llamada MyLogReg, luego implementa los métodos save_json y load_json para guardar y restaurar a / desde un archivo JSON, respectivamente.

Para simplificar, guardaremos solo tres parámetros del modelo y los datos de entrenamiento. Algunos datos adicionales que podríamos almacenar con este enfoque son, por ejemplo, una puntuación de validación cruzada en el conjunto de entrenamiento, datos de prueba, puntuación de precisión en los datos de prueba, etc.

import json
import numpy as np

class MyLogReg(LogisticRegression):
    
    # Override the class constructor
    def __init__(self, C=1.0, solver="liblinear", max_iter=100, X_train=None, Y_train=None):
        LogisticRegression.__init__(self, C=C, solver=solver, max_iter=max_iter)
        self.X_train = X_train
        self.Y_train = Y_train
        
    # A method for saving object data to JSON file
    def save_json(self, filepath):
        dict_ = {}
        dict_['C'] = self.C
        dict_['max_iter'] = self.max_iter
        dict_['solver'] = self.solver
        dict_['X_train'] = self.X_train.tolist() if self.X_train is not None else 'None'
        dict_['Y_train'] = self.Y_train.tolist() if self.Y_train is not None else 'None'
        
        # Creat json and save to file
        json_txt = json.dumps(dict_, indent=4)
        with open(filepath, 'w') as file:
            file.write(json_txt)
    
    # A method for loading data from JSON file
    def load_json(self, filepath):
        with open(filepath, 'r') as file:
            dict_ = json.load(file)
            
        self.C = dict_['C']
        self.max_iter = dict_['max_iter']
        self.solver = dict_['solver']
        self.X_train = np.asarray(dict_['X_train']) if dict_['X_train'] != 'None' else None
        self.Y_train = np.asarray(dict_['Y_train']) if dict_['Y_train'] != 'None' else None
        

Ahora probemos el MyLogReg clase. Primero creamos un objeto mylogreg, pasarle los datos de entrenamiento y guardarlo en un archivo. Luego creamos un nuevo objeto json_mylogreg y llama al load_json método para cargar los datos del archivo.

filepath = "mylogreg.json"

# Create a model and train it
mylogreg = MyLogReg(X_train=Xtrain, Y_train=Ytrain)
mylogreg.save_json(filepath)

# Create a new object and load its data from JSON file
json_mylogreg = MyLogReg()
json_mylogreg.load_json(filepath)
json_mylogreg

Al imprimir el nuevo objeto, podemos ver nuestros parámetros y datos de entrenamiento según sea necesario.

MyLogReg(C=1.0,
     X_train=array([[ 4.3,  3. ,  1.1,  0.1],
       [ 5.7,  4.4,  1.5,  0.4],
       ...,
       [ 7.2,  3. ,  5.8,  1.6],
       [ 7.7,  2.8,  6.7,  2. ]]),
     Y_train=array([0, 0, ..., 2, 2]), class_weight=None, dual=False,
     fit_intercept=True, intercept_scaling=1, max_iter=100,
     multi_class="ovr", n_jobs=1, penalty='l2', random_state=None,
     solver="liblinear", tol=0.0001, verbose=0, warm_start=False)

Dado que la serialización de datos usando JSON realmente guarda el objeto en un formato de cadena, en lugar de un flujo de bytes, el archivo ‘mylogreg.json’ podría abrirse y modificarse con un editor de texto. Aunque este enfoque sería conveniente para el desarrollador, es menos seguro ya que un intruso puede ver y modificar el contenido del archivo JSON. Además, este enfoque es más adecuado para objetos con una pequeña cantidad de variables de instancia, como los modelos de scikit-learn, porque cualquier adición de nuevas variables requiere cambios en los métodos de guardar y restaurar.

Problemas de compatibilidad

Si bien algunos de los pros y los contras de cada herramienta se cubrieron en el texto hasta ahora, probablemente el mayor inconveniente de las herramientas Pickle y Joblib es su compatibilidad con diferentes modelos y versiones de Python.

Compatibilidad con la versión de Python: la documentación de ambas herramientas indica que no se recomienda (des) serializar objetos en diferentes versiones de Python, aunque podría funcionar con cambios menores de versión.

Compatibilidad del modelo: uno de los errores más frecuentes es guardar su modelo con Pickle o Joblib y luego cambiar el modelo antes de intentar restaurar desde un archivo. La estructura interna del modelo debe permanecer sin cambios entre guardar y recargar.

Un último problema con Pickle y Joblib está relacionado con la seguridad. Ambas herramientas pueden contener código malicioso, por lo que no se recomienda restaurar datos de fuentes no confiables o no autenticadas.

Conclusiones

En esta publicación describimos tres herramientas para guardar y restaurar modelos de scikit-learn. Las bibliotecas Pickle y Joblib son rápidas y fáciles de usar, pero tienen problemas de compatibilidad en diferentes versiones de Python y cambios en el modelo de aprendizaje. Por otro lado, el enfoque manual es más difícil de implementar y debe modificarse con cualquier cambio en la estructura del modelo, pero en el lado positivo podría adaptarse fácilmente a varias necesidades y no tiene problemas de compatibilidad.

 

About the author

Ramiro de la Vega

Bienvenido a Pharos.sh

Soy Ramiro de la Vega, Estadounidense con raíces Españolas. Empecé a programar hace casi 20 años cuando era muy jovencito.

Espero que en mi web encuentres la inspiración y ayuda que necesitas para adentrarte en el fantástico mundo de la programación y conseguir tus objetivos por difíciles que sean.

Add comment

Sobre mi

Últimos Post

Etiquetas

Esta web utiliza cookies propias para su correcto funcionamiento. Al hacer clic en el botón Aceptar, aceptas el uso de estas tecnologías y el procesamiento de tus datos para estos propósitos. Más información
Privacidad