ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

Tensorflow静态图pb(frozen graph)模型保存与调用

2020-12-08 14:34:08  阅读:777  来源: 互联网

标签:frozen graph pb names input model def


pb模型保存

基于tf2

model = ...

# Convert Keras model to ConcreteFunction
full_model = tf.function(lambda x: model(x))
full_model = full_model.get_concrete_function(
    tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

# Get frozen ConcreteFunction
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()

# Save frozen graph from frozen ConcreteFunction to hard drive
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                  logdir="./frozen_models",
                  name="frozen_graph.pb",
                  as_text=False)

基于keras (tf1)

from tensorflow.keras import backend as K

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""

        frozen_graph = graph_util.convert_variables_to_constants(session, input_graph_def, output_names, freeze_var_names)
        if not clear_devices:
            for node in frozen_graph.node:
                node.device = "/GPU:0"
        return frozen_graph


# load model
model = keras.models.model_from_json(...)


# save pb model
out_path = 'model.pb'
input_names = [n.op.name for n in model.inputs]
output_names = [n.op.name for n in model.outputs]
print(input_names, output_names)
frozen_graph = freeze_session(K.get_session(), output_names=output_names,clear_devices=clear_devices)
with open(out_path, "wb") as f:
    f.write(frozen_graph.SerializeToString())

模型调用

这里以tf1为例:

from tensorflow.compat.v1 import Graph, GraphDef, import_graph_def, Session
from tensorflow.compat.v1.gfile import GFile

frozen_graph =  "model.pb"
# import graph
with GFile(frozen_graph, "rb") as f:
    graph_def = GraphDef()
    graph_def.ParseFromString(f.read())
with Graph().as_default() as graph:
    import_graph_def(graph_def,
                     input_map=None,
                     return_elements=None,
                     name=""
                     )

# set input output
x = graph.get_tensor_by_name("input:0")
y1 = graph.get_tensor_by_name("output1:0")
y2 = graph.get_tensor_by_name("output1:0")
sess = Session(graph=graph)

# get batch_input
batch_image = np.zeros([1, 512, 512, 3])
# get ...

# predict
feed_dict_testing = {x: batch_image}
output1, output2 = sess.run([y1, y2], feed_dict=feed_dict_testing)

 

标签:frozen,graph,pb,names,input,model,def
来源: https://blog.csdn.net/dou3516/article/details/110871797

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有