TensorFlow中ckpt、frozen model、keras、onnx之间涉及的转换¶
tf下的三种存储格式¶
TensorFlow中有三种模型和数据存储格式
- saved model格式。在tf2.x版本中经常使用,tf2下原生api和keras封装以后的api都支持导出该类格式。saved model格式对应一个文件夹,里面分别存储模型的结构信息和参数信息。
- frozen graph格式。在tf1.x版本中经常使用,tf1下定义的graph可以直接导出。frozen graph格式仅包含单个pb文件,里面有模型的所有信息,一般为了方便推理会把training相关信息去除。
- ckpt格式。用于保存模型的训练时的计算图和权重信息,keras训练时默认会导出这种格式。
一般的推理框架过去都会对frozen model格式进行支持,新的版本会逐渐向onnx这种通用格式迁移。 这时,ckpt格式的存储方法仅仅当做checkpoint来用,不再具备部署的属性。
1 keras 转 onnx 或 frozen graph¶
1.1 keras model 转 onnx¶
对于keras模型,可以使用以下工具转成onnx格式
- tf2onnx 链接:https://github.com/onnx/tensorflow-onnx
这里可能涉及到算子对齐的问题,建议仔细看下onnx的官方文档,注意onnx的opset version。重点关注resize算子的参数是否与keras中的对齐。如果是tf1版本需要重点关注resize算子参数中是否包含了针对tf1旧版本的特殊设置。
1.2 keras model 转 frozen graph¶
keras model 转 frozen graph有两种方法。第一个保存为saved model或者ckpt格式,然后加载通过tf2onnx转为onnx。
第二种方法更为直接,代码如下,这里直接给出多输出多输出下的转换代码,在tf2.3
版本测试通过。主要原理是通过tf.function
再封装一层
# convert model to pb
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
import tensorflow as tf
import os
from tensorflow.python.framework import importer
def freeze_keras_model2pb(keras_model, pb_filepath, input_vaiablename_list=None, output_vaiablename_list=None):
"""
karas 模型转pb
:param keras_model:
:param pb_filepath:
:param input_vaiablename_list:
:param output_vaiablename_list:
:return:
"""
assert hasattr(keras_model,'inputs'), "the keras model must be built with functional api or sequential"
# save pb
if input_vaiablename_list is None:
input_vaiablename_list = list()
if output_vaiablename_list is None:
output_vaiablename_list = list()
if len(input_vaiablename_list) == len(keras_model.inputs):
input_vaiable_list = input_vaiablename_list
else:
input_vaiable_list = ['x%d' % i for i in range(len(keras_model.inputs))]
input_funcsignature_list = [tf.TensorSpec(item.shape, dtype=item.dtype, name=name) for name, item in
zip(input_vaiable_list, keras_model.inputs)]
full_model = tf.function(lambda *x: keras_model(x,training=False))
# To obtain an individual graph, use the get_concrete_function method of the callable created by tf.function.
# It can be called with the same arguments as func and returns a special tf.Graph object
concrete_func = full_model.get_concrete_function(input_funcsignature_list)
# Get frozen ConcreteFunction
frozen_graph = convert_variables_to_constants_v2(concrete_func)
graph_def = frozen_graph.graph.as_graph_def()
out_idx = 0
# and len(output_vaiablename_list) > out_idx
# node.name = output_vaiablename_list[out_idx]
for node in graph_def.node:
node.device = ""
if node.name.startswith('Identity'):
out_idx += 1
if len(output_vaiablename_list) == out_idx:
ouput_vaiable_list = output_vaiablename_list
else:
ouput_vaiable_list = ['y%d' % i for i in range(out_idx)]
out_idx = 0
for node in graph_def.node:
node.device = ""
if node.name.startswith('Identity'):
node.name = ouput_vaiable_list[out_idx]
out_idx += 1
new_graph = tf.Graph()
with new_graph.as_default():
importer.import_graph_def(graph_def, name="")
# output_graph_def = tf.compat.v1.graph_util.remove_training_nodes(new_graph.as_graph_def())
return tf.io.write_graph(graph_or_graph_def=new_graph,
logdir=os.path.dirname(pb_filepath),
name=os.path.basename(pb_filepath),
as_text=False), input_vaiable_list, ouput_vaiable_list
# test pb file
def wrap_frozen_graph(graph_def, inputs, outputs, print_graph=False):
graph = tf.Graph()
def _imports_graph_def():
tf.graph_util.import_graph_def(graph_def, name="")
with graph.as_default():
wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
import_graph = wrapped_import.graph
if print_graph:
print("-" * 50)
print("Frozen model layers: ")
layers = [op.name for op in import_graph.get_operations()]
for layer in layers:
print(layer)
print("-" * 50)
return wrapped_import.prune(tf.nest.map_structure(import_graph.as_graph_element, inputs),tf.nest.map_structure(import_graph.as_graph_element, outputs))
def pbfile2concrete_function(pbfile,inputs,outputs,print_graph = False):
"""
pbfile 转 concrete function
:param pbfile:
:param inputs:
:param outputs:
:param print_graph:
:return:
"""
with tf.io.gfile.GFile(pbfile, "rb") as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())
frozen_func = wrap_frozen_graph(graph_def=graph_def,
inputs=inputs,
outputs=outputs,
print_graph=print_graph)
return graph_def, frozen_func
if __name__ == '__main__':
model1 = tf.keras.Sequential([
tf.keras.Input([128]),
tf.keras.layers.Dense(256,activation="relu"),
tf.keras.layers.Dense(256,activation="relu"),
tf.keras.layers.Dense(10,activation="softmax"),
])
model2 = tf.keras.Sequential([
tf.keras.Input([128]),
tf.keras.layers.Dense(256, activation="relu"),
tf.keras.layers.Dense(256, activation="relu"),
tf.keras.layers.Dense(10, activation="softmax"),
])
x1 = tf.keras.layers.Input([128])
x2 = tf.keras.layers.Input([128])
model = tf.keras.Model(inputs = [x1,x2], outputs=[model1(x1), model2(x2)])
_, input_vaiable_list, ouput_vaiable_list = freeze_keras_model2pb(model,"test.pb")
input_vaiable_list = [x+':0' for x in input_vaiable_list]
ouput_vaiable_list = [x+':0' for x in ouput_vaiable_list]
concrete_func_frompb = pbfile2concrete_function("test.pb",input_vaiable_list,ouput_vaiable_list)
x0 = tf.random.normal((2,128), 0, 1)
x1 = tf.random.normal((2,128), 0, 1)
predictions = concrete_func_frompb(x0,x1)
outputs = model((x0,x1))
for item1,item2 in zip(predictions, outputs):
print(item1, item2)
2 ckpt转onnx或frozen model¶
2.1 ckpt转onnx¶
直接使用tf2onnx可以转换
2.2 ckpt转frozen model¶
import tensorflow as tf
from tensorflow import keras
from tensorflow.python.tools import freeze_graph
def freeze_graph(input_checkpoint, output_graph, output_node_names = "Tanh"):
'''
:param input_checkpoint:
:param output_graph: PB模型保存路径
:output_node_names: 多输出使用使用逗号隔开
:return:
'''
# checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
# input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
# 指定输出的节点名称,该节点名称必须是原模型中存在的节点
g= tf.Graph() # 获得默认的图
with g.as_default():
with tf.Session(graph=g) as sess:
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
saver.restore(sess, input_checkpoint) #恢复图并得到数据
# input_graph_def = g.as_graph_def() # 返回一个序列化的图代表当前的图
output_graph_def = tf.graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
sess=sess,
input_graph_def=sess.graph_def,# 等于:sess.graph_def
output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
output_graph_def = tf.graph_util.remove_training_nodes(output_graph_def)
with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
f.write(output_graph_def.SerializeToString()) #序列化输出
print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
for i,n in enumerate(output_graph_def.node):
print("Name of the node - %s" % n.name)
调用示例,注意这里ckpt只写到ckpt,不写后面的.meta
上面方式生成的pb文件可能会有些问题,比如bn的training
不应该作为输入而是一个constant
,下面给出替换bn算子的解决方案。
import sys
from tensorflow.core.framework import graph_pb2
import copy
# load our graph
def load_graph(grapf_filepath):
tf.reset_default_graph()
graph_def = tf.GraphDef()
with tf.gfile.FastGFile(grapf_filepath, 'rb') as f:
graph_def.ParseFromString(f.read())
return graph_def
def replace_placeholder_with_constant(input_grapf_filepath, output_grapf_filepath,
target_node_name, replace_const = None):
if replace_const is None:
replace_const = tf.constant(False, dtype=bool, shape=[], name=target_node_name)
graph_def = load_graph(input_grapf_filepath)
# Create new graph, and rebuild it from original one
# replacing phase train node def with constant
new_graph_def = graph_pb2.GraphDef()
for node in graph_def.node:
if node.name == target_node_name:
new_graph_def.node.extend([replace_const.op.node_def])
else:
new_graph_def.node.extend([copy.deepcopy(node)])
# save new graph
with tf.gfile.GFile(output_grapf_filepath, "wb") as f:
f.write(new_graph_def.SerializeToString())
调用示例,实现把名字Placeholder_2
的输入去除,改为值为False
的constant
3 总结¶
在实际环境下,往往有兼容旧设备或者旧版本的需求,frozen graph尽管在新的设备上不怎么常用但还是需要照顾一下。 此外,对于pb的处理这种hack方式可能也是解决兼容问题的一种途径,比如当你使用新的版本对之前版本重构以后会出现输入和输出名字不对齐的问题,这个时候就需要对pb文件“做手术” onnx作为现在较为通用的部署格式,尽管有些大佬不喜欢,但依然阻挡不住其被各大推理框架支持的脚步。但是onnx中的算子和推理框架中的算子只能说有交集,针对推理框架中不支持的算子,依然需要写plugin或者拆成两个onnx,中间不支持部分采用原生方式的方式来实现。 总而言之,碰到问题解决问题就是了,至于怎么解决,如何解决才是更快更好的方案,这就是经验所在了。