Saltar al contenido

¿Cómo extraer los hiperparámetros del modelo de spark.ml en PySpark?

Pudiera darse el caso de que halles algún fallo con tu código o trabajo, recuerda probar siempre en un entorno de testing antes aplicar el código al trabajo final.

Solución:

Me encontré con este problema también. Descubrí que necesita llamar a la propiedad java por alguna razón que no sé por qué. Así que solo haz esto:

from pyspark.ml.tuning import TrainValidationSplit, ParamGridBuilder, CrossValidator
from pyspark.ml.regression import LinearRegression
from pyspark.ml.evaluation import RegressionEvaluator

evaluator = RegressionEvaluator(metricName="mae")
lr = LinearRegression()
grid = ParamGridBuilder().addGrid(lr.maxIter, [500]) 
                                .addGrid(lr.regParam, [0]) 
                                .addGrid(lr.elasticNetParam, [1]) 
                                .build()
lr_cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, 
                        evaluator=evaluator, numFolds=3)
lrModel = lr_cv.fit(your_training_set_here)
bestModel = lrModel.bestModel

Imprimiendo los parámetros que quieras:

>>> print 'Best Param (regParam): ', bestModel._java_obj.getRegParam()
0
>>> print 'Best Param (MaxIter): ', bestModel._java_obj.getMaxIter()
500
>>> print 'Best Param (elasticNetParam): ', bestModel._java_obj.getElasticNetParam()
1

Esto se aplica a otros métodos como extractParamMap() también. Deberían arreglar esto pronto.

Esto podría no ser tan bueno como la respuesta de wernerchao (porque no es conveniente almacenar hiperparámetros en variables), pero puede ver rápidamente los mejores hiperparámetros de un modelo de validación cruzada de esta manera:

cvModel.getEstimatorParamMaps()[ np.argmax(cvModel.avgMetrics) ]

Suponiendo que cvModel3Day sea el nombre de su modelo, los parámetros se pueden extraer como se muestra a continuación en Spark Scala

val params = cvModel3Day.bestModel.asInstanceOf[PipelineModel].stages(2).asInstanceOf[GBTClassificationModel].extractParamMap()

val depth = cvModel3Day.bestModel.asInstanceOf[PipelineModel].stages(2).asInstanceOf[GBTClassificationModel].getMaxDepth

val iter = cvModel3Day.bestModel.asInstanceOf[PipelineModel].stages(2).asInstanceOf[GBTClassificationModel].getMaxIter

val bins = cvModel3Day.bestModel.asInstanceOf[PipelineModel].stages(2).asInstanceOf[GBTClassificationModel].getMaxBins

val features  = cvModel3Day.bestModel.asInstanceOf[PipelineModel].stages(2).asInstanceOf[GBTClassificationModel].getFeaturesCol

val step = cvModel3Day.bestModel.asInstanceOf[PipelineModel].stages(2).asInstanceOf[GBTClassificationModel].getStepSize

val samplingRate  = cvModel3Day.bestModel.asInstanceOf[PipelineModel].stages(2).asInstanceOf[GBTClassificationModel].getSubsamplingRate

Eres capaz de añadir valor a nuestro contenido asistiendo con tu experiencia en las acotaciones.

¡Haz clic para puntuar esta entrada!
(Votos: 0 Promedio: 0)



Utiliza Nuestro Buscador

Deja una respuesta

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *