keras回调函数如何使用

免费教程   2024年05月09日 17:42  

这篇文章主要介绍了keras回调函数如何使用的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇keras回调函数如何使用文章都会有所收获,下面我们一起来看看吧。

回调函数

回调函数是一个对象(实现了特定方法的类实例),它在调用fit()时被传入模型,并在训练过程中的不同时间点被模型调用

可以访问关于模型状态与模型性能的所有可用数据

模型检查点(model checkpointing):在训练过程中的不同时间点保存模型的当前状态。

提前终止(early stopping):如果验证损失不再改善,则中断训练(当然,同时保存在训练过程中的最佳模型)。

在训练过程中动态调节某些参数值:比如调节优化器的学习率。

在训练过程中记录训练指标和验证指标,或者将模型学到的表示可视化(这些表示在不断更新):fit()进度条实际上就是一个回调函数。

fit()方法中使用callbacks参数#这里有两个callback函数:早停和模型检查点callbacks_list=[.callbacks.EarlyStopping(monitor="val_accuracy",#监控指标patience=2#两轮内不再改善中断训练),.callbacks.ModelCheckpoint(filepath="checkpoint_path",monitor="val_loss",save_best_only=True)]#模型获取model=get_minist_model()model.compile(optimizer="rmsprop",loss="sparse_categorical_crossentropy",metrics=["accuracy"])model.fit(train_images,train_labels,epochs=10,callbacks=callbacks_list,#该参数使用回调函数validation_data=(val_images,val_labels))test_metrics=model.evaluate(test_images,test_labels)#计算模型在新数据上的损失和指标predictions=model.predict(test_images)#计算模型在新数据上的分类概率模型的保存和加载#也可以在训练完成后手动保存模型,只需调用model.save('my_checkpoint_path')。#重新加载模型model_new=.models.load_model("checkpoint_path.")通过对Callback类子类化来创建自定义回调函数

on_epoch_begin(epoch, logs) ←----在每轮开始时被调用on_epoch_end(epoch, logs) ←----在每轮结束时被调用on_batch_begin(batch, logs) ←----在处理每个批量之前被调用on_batch_end(batch, logs) ←----在处理每个批量之后被调用on_train_begin(logs) ←----在训练开始时被调用on_train_end(logs ←----在训练结束时被调用

frommatplotlibimportpyplotasplt#实现记录每一轮中每个batch训练后的损失,并为每个epoch绘制一个图classLossHistory(keras.callbacks.Callback):defon_train_begin(self,logs):self.per_batch_losses=[]defon_batch_end(self,batch,logs):self.per_batch_losses.append(logs.get("loss"))defon_epoch_end(self,epoch,logs):plt.clf()plt.plot(range(len(self.per_batch_losses)),self.per_batch_losses,label="Traininglossforeachbatch")plt.xlabel(f"Batch(epoch{epoch})")plt.ylabel("Loss")plt.legend()plt.savefig(f"plot_at_epoch_{epoch}")self.per_batch_losses=[]#清空,方便下一轮的技术model=get_mnist_model()model.compile(optimizer="rmsprop",loss="sparse_categorical_crossentropy",metrics=["accuracy"])model.fit(train_images,train_labels,epochs=10,callbacks=[LossHistory()],validation_data=(val_images,val_labels))【其他】模型的定义 和 数据加载defget_minist_model():inputs=keras.Input(shape=(28*28,))features=layers.Dense(512,activation="relu")(inputs)features=layers.Dropout(0.5)(features)outputs=layers.Dense(10,activation="softmax")(features)model=keras.Model(inputs,outputs)returnmodel#datsetfromtensorflow.keras.datasetsimportmnist(train_images,train_labels),(test_images,test_labels)=mnist.load_data()train_images=train_images.reshape((60000,28*28)).astype("float32")/255test_images=test_images.reshape((10000,28*28)).astype("float32")/255train_images,val_images=train_images[10000:],train_images[:10000]train_labels,val_labels=train_labels[10000:],train_labels[:10000]

关于“keras回调函数如何使用”这篇文章的内容就介绍到这里,感谢各位的阅读!相信大家对“keras回调函数如何使用”知识都有一定的了解,大家如果还想学习更多知识,欢迎关注行业资讯频道。

域名注册
购买VPS主机

您或许对下面这些文章有兴趣:                    本月吐槽辛苦排行榜

看贴要回贴有N种理由!看帖不回贴的后果你懂得的!


评论内容 (*必填):
(Ctrl + Enter提交)   

部落快速搜索栏

各类专题梳理

网站导航栏

X
返回顶部