# Convert a Keras model to a frozen TensorFlow GraphDef (.pb) file and load it back for testing.
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
import tensorflow as tf
import os
from tensorflow.python.framework import importer
def _resolve_node_names(requested_names, count, prefix):
    """Return ``requested_names`` if it has exactly ``count`` entries, else ``[prefix0, prefix1, ...]``."""
    if requested_names is not None and len(requested_names) == count:
        return requested_names
    return ['%s%d' % (prefix, i) for i in range(count)]


def freeze_keras_model2pb(keras_model, pb_filepath, input_vaiablename_list=None, output_vaiablename_list=None):
    """Freeze a Keras model into a binary TensorFlow GraphDef (.pb) file.

    The model is traced with ``training=False``. Its variables are converted to
    constants, device placements are cleared, the input placeholders get the
    names in ``input_vaiablename_list`` and the output nodes are renamed to
    ``output_vaiablename_list``.

    :param keras_model: a Keras model built with the functional or Sequential
        API, i.e. one exposing a non-None ``inputs`` attribute.
    :param pb_filepath: destination path of the binary ``.pb`` file.
    :param input_vaiablename_list: optional names for the input placeholders.
        Used only when its length equals the number of model inputs;
        otherwise the defaults ``x0, x1, ...`` are used.
    :param output_vaiablename_list: optional names for the output nodes.
        Used only when its length equals the number of function outputs;
        otherwise the defaults ``y0, y1, ...`` are used.
    :return: tuple ``(written_path, input_names, output_names)``.
        ``written_path`` is the value returned by ``tf.io.write_graph``. The
        two name lists are the node names used in the saved graph; append
        ``':0'`` to address the tensors.
    :raises TypeError: if ``keras_model`` has no usable ``inputs`` attribute.
    """
    if getattr(keras_model, 'inputs', None) is None:
        raise TypeError("the keras model must be built with functional api or sequential")
    model_inputs = keras_model.inputs

    input_names = _resolve_node_names(input_vaiablename_list, len(model_inputs), 'x')
    # The TensorSpec names become the placeholder node names in the frozen graph.
    input_signature = [tf.TensorSpec(item.shape, dtype=item.dtype, name=name)
                       for name, item in zip(input_names, model_inputs)]
    full_model = tf.function(lambda *x: keras_model(x, training=False))
    # get_concrete_function traces the tf.function into a single graph.
    concrete_func = full_model.get_concrete_function(input_signature)
    # Replace variables with constants so the graph is self-contained.
    frozen_func = convert_variables_to_constants_v2(concrete_func)
    graph_def = frozen_func.graph.as_graph_def()

    # Take the output op names from the function itself, in return order.
    # Matching a name prefix such as 'Identity' could also catch unrelated
    # nodes, and it relies on the node order inside the GraphDef.
    output_op_names = [tensor.op.name for tensor in frozen_func.outputs]
    output_names = _resolve_node_names(output_vaiablename_list, len(output_op_names), 'y')
    rename_map = dict(zip(output_op_names, output_names))
    for node in graph_def.node:
        # Drop device placement so the graph can be loaded on any device.
        node.device = ""
        if node.name in rename_map:
            node.name = rename_map[node.name]

    # Re-import into a fresh graph. import_graph_def raises if the edited
    # GraphDef is inconsistent, so a broken file is never written.
    new_graph = tf.Graph()
    with new_graph.as_default():
        importer.import_graph_def(graph_def, name="")
    written_path = tf.io.write_graph(graph_or_graph_def=new_graph,
                                     logdir=os.path.dirname(pb_filepath),
                                     name=os.path.basename(pb_filepath),
                                     as_text=False)
    return written_path, input_names, output_names
# Helpers for loading a frozen .pb file back as a callable, used to test the exported graph.
def wrap_frozen_graph(graph_def, inputs, outputs, print_graph=False):
    """Wrap a frozen GraphDef as a callable pruned to the given inputs and outputs.

    :param graph_def: the GraphDef to import (imported with no name prefix).
    :param inputs: tensor name(s) to feed, e.g. ``'x0:0'``, possibly nested.
    :param outputs: tensor name(s) to fetch, in the same format as ``inputs``.
    :param print_graph: when True, print the name of every imported operation.
    :return: the function produced by ``prune``, mapping ``inputs`` to ``outputs``.
    """
    def _import_into_current_graph():
        tf.graph_util.import_graph_def(graph_def, name="")

    scratch_graph = tf.Graph()
    with scratch_graph.as_default():
        wrapped = tf.compat.v1.wrap_function(_import_into_current_graph, [])
    imported = wrapped.graph

    if print_graph:
        separator = "-" * 50
        print(separator)
        print("Frozen model layers: ")
        for op_name in [op.name for op in imported.get_operations()]:
            print(op_name)
        print(separator)

    # Resolve tensor names into tensors of the imported graph.
    feed_tensors = tf.nest.map_structure(imported.as_graph_element, inputs)
    fetch_tensors = tf.nest.map_structure(imported.as_graph_element, outputs)
    return wrapped.prune(feed_tensors, fetch_tensors)
def pbfile2concrete_function(pbfile, inputs, outputs, print_graph=False):
    """Load a frozen .pb file and wrap it as a callable function.

    :param pbfile: path of the binary GraphDef file.
    :param inputs: input tensor name(s), e.g. ``['x0:0', 'x1:0']``.
    :param outputs: output tensor name(s), e.g. ``['y0:0', 'y1:0']``.
    :param print_graph: passed to ``wrap_frozen_graph``; prints op names when True.
    :return: tuple ``(graph_def, frozen_func)``: the parsed GraphDef and the
        function pruned from ``inputs`` to ``outputs``.
    """
    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(pbfile, "rb") as pb_stream:
        serialized = pb_stream.read()
    graph_def.ParseFromString(serialized)
    frozen_func = wrap_frozen_graph(graph_def=graph_def,
                                    inputs=inputs,
                                    outputs=outputs,
                                    print_graph=print_graph)
    return graph_def, frozen_func
if __name__ == '__main__':
    # Round-trip check: build a two-input / two-output model, freeze it to a
    # .pb file, load it back and print its predictions next to the Keras ones.
    model1 = tf.keras.Sequential([
        tf.keras.Input([128]),
        tf.keras.layers.Dense(256, activation="relu"),
        tf.keras.layers.Dense(256, activation="relu"),
        tf.keras.layers.Dense(10, activation="softmax"),
    ])
    model2 = tf.keras.Sequential([
        tf.keras.Input([128]),
        tf.keras.layers.Dense(256, activation="relu"),
        tf.keras.layers.Dense(256, activation="relu"),
        tf.keras.layers.Dense(10, activation="softmax"),
    ])
    # Distinct names for the symbolic inputs, so they are not confused with
    # the data tensors x0/x1 created below.
    keras_in1 = tf.keras.layers.Input([128])
    keras_in2 = tf.keras.layers.Input([128])
    model = tf.keras.Model(inputs=[keras_in1, keras_in2],
                           outputs=[model1(keras_in1), model2(keras_in2)])

    _, input_node_names, output_node_names = freeze_keras_model2pb(model, "test.pb")
    # The returned names are node names; tensors are addressed as '<node>:0'.
    input_tensor_names = [name + ':0' for name in input_node_names]
    output_tensor_names = [name + ':0' for name in output_node_names]
    # pbfile2concrete_function returns (graph_def, frozen_func); only the
    # function is needed here. Calling the tuple itself would raise TypeError.
    _, concrete_func_frompb = pbfile2concrete_function("test.pb", input_tensor_names, output_tensor_names)

    x0 = tf.random.normal((2, 128), 0, 1)
    x1 = tf.random.normal((2, 128), 0, 1)
    predictions = concrete_func_frompb(x0, x1)
    outputs = model((x0, x1))
    for item1, item2 in zip(predictions, outputs):
        print(item1, item2)