首页 \ 问答 \ Keras替换输入层(Keras replacing input layer)

Keras替换输入层(Keras replacing input layer)

我拥有的代码(我无法更改)使用带有my_input_tensor作为input_tensor。

model1 = keras.applications.resnet50.ResNet50(input_tensor=my_input_tensor, weights='imagenet')

调查源代码 ,ResNet50函数使用my_input_tensor创建一个新的keras输入层,然后创建模型的其余部分。 这是我想用自己的模型复制的行为。 我从h5文件加载我的模型。

model2 = keras.models.load_model('my_model.h5')

由于此模型已经有一个输入层,我想用一个用my_input_tensor定义的新输入层替换它。

如何更换输入图层?


The code that I have (that I can't change) uses the Resnet with my_input_tensor as the input_tensor.

model1 = keras.applications.resnet50.ResNet50(input_tensor=my_input_tensor, weights='imagenet')

Investigating the source code, ResNet50 function creates a new keras Input Layer with my_input_tensor and then create the rest of the model. This is the behavior that I want to copy with my own model. I load my model from h5 file.

model2 = keras.models.load_model('my_model.h5')

Since this model already has an Input Layer, I want to replace it with a new Input Layer defined with my_input_tensor.

How can I replace an input layer?


原文:https://stackoverflow.com/questions/49546922
更新时间:2019-12-13 08:09

最满意答案

使用以下方法保存模型时:

old_model.save('my_model.h5')

它会节省以下内容:

  1. 模型的体系结构,允许创建模型。
  2. 模型的权重。
  3. 模型的训练配置(丢失,优化器)。
  4. 优化器的状态,允许训练从您之前离开的地方恢复。

那么,当你加载模型时:

res50_model = load_model('my_model.h5')

你应该得到相同的型号,你可以使用以下方法验证相同:

res50_model.summary()
res50_model.get_weights()

现在您可以弹出输入图层并使用以下命令添加自己的:

res50_model.layers.pop(0)
res50_model.summary()

添加新的输入图层:

newInput = Input(batch_shape=(0,299,299,3))    # let us say this new InputLayer
newOutputs = res50_model(newInput)
newModel = Model(newInput, newOutputs)

newModel.summary()
res50_model.summary()

When you saved your model using:

old_model.save('my_model.h5')

it will save following:

  1. The architecture of the model, allowing to create the model.
  2. The weights of the model.
  3. The training configuration of the model (loss, optimizer).
  4. The state of the optimizer, allowing training to resume from where you left before.

So then, when you load the model:

res50_model = load_model('my_model.h5')

you should get the same model back, you can verify the same using:

res50_model.summary()
res50_model.get_weights()

Now you can, pop the input layer and add your own using:

res50_model.layers.pop(0)
res50_model.summary()

add new input layer:

newInput = Input(batch_shape=(0,299,299,3))    # let us say this new InputLayer
newOutputs = res50_model(newInput)
newModel = Model(newInput, newOutputs)

newModel.summary()
res50_model.summary()
2018-03-29

相关文章

更多

最新问答

更多
  • 删除不适用于JSP中使用for循环的每个id(Deletion not working for every id using for loop in JSP)
  • 如何从std :: filesystem :: path中删除引号(How to remove quotation marks from std::filesystem::path)
  • 验证多个控制器方法的URL路径(Validate URL path for several controller methods)
  • 如何在datarow []中的列中找到最大值?(How to find max value in a column in a datarow[] ?)
  • 如何使用预定义文本替换来自数据库的部分结果(How do I replace part of result coming from Database with predefined text)
  • Selenium Java注入了新的Javascript函数(Selenium Java inject new Javascript function)
  • 使用.on的多个下拉菜单选择文本仅适用于第一个下拉列表(Multiple Dropdowns Menu Selection text using .on works only on first dropdown)
  • 快速将黄土曲线添加到大型数据集图中的方法(Quick way to add loess curve to large data set graph)
  • FilteringSelect in mvc(FilteringSelect in mvc)
  • 在Delphi XE2中开发Mac或iOS应用程序需要哪些硬件/软件?(What hardware/software is necessary to develop Mac or iOS apps in Delphi XE2?)
  • 在原型的构造函数中初始化属性时获取“未定义”(Getting 'undefined' when a property is initialized in the constructor of a prototype)
  • 通过越狱加载的应用程序的Documents文件夹位置(Location of Documents folder for an app loaded via jailbreak)
  • 在OpenGL中使用可编程和固定管道功能(Using both programmable and fixed pipeline functionality in OpenGL)
  • 将任何用户输入重定向到单独的底层程序(redirect any user input to a separate underlying program)
  • 编辑文本不能正常工作android(Edit texts not working properly android)
  • “user_denied”Facebook应用页面上的Facebook用户区域设置(Facebook user locale on “user_denied” facebook app page)
  • 在大图像中找到小的部分透明图像的坐标(find coordinates of small partially-transparent image within a large image)
  • 我如何在cakephp 3.1中获得完整的相对路径?(How i can get full relative path of image in cakephp 3.1?)
  • 如何保存拖动标记的新本地化?(How to save new localization of dragged marker?)
  • MySQL UPDATE vs INSERT和DELETE(MySQL UPDATE vs INSERT and DELETE)
  • 在执行查询之前,在SQLAlchemy模型中将datetime转换为unix时间戳?(Convert datetime to unix timestamp in SQLAlchemy model before executing query?)
  • OpenCL与OpenGL互操作的优势(Advantage of OpenCL interoperability with OpenGL)
  • 如何解析用点和等分隔的数据然后添加到listview(How to parsing data from delimited with dot and equal then add to listview)
  • 带调试输出的X3解析器段错误(BOOST_SPIRIT_X3_DEBUG)(X3 parser segfaults with debug output (BOOST_SPIRIT_X3_DEBUG))
  • 将文件夹名称添加到fgrep结果(Add folder name to fgrep result)
  • 在MySQL中加载一个表是非常慢的(Loading one table in MySQL is ridiculously slow)
  • 如何将JSON放入PHP变量?(How do I put JSON into a PHP Variable?)
  • 如何绕过Microsoft.Speech.Recognition中的不流畅?(How to bypass disfluencies in Microsoft.Speech.Recognition?)
  • 原点的最后一行是什么?(What is the last row of an origin for?)
  • 将字符串转换为javascript中的操作(Translate String to operation in javascript)