ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

使用微调后的Bert模型做编码器进行文本特征向量抽取

2021-04-14 18:34:07  阅读:312  来源: 互联网

标签:Bert 编码器 特征向量 bert graph tf input model config


      通常,我们使用bert做文本分类,泛化性好、表现优秀。在进行文本相似性计算任务时,往往是对语料训练词向量,再聚合文本向量embedding数据,计算相似度;但是,word2vec是静态词向量,表征能力有限,此时,可以用已进行特定环境下训练的bert模型,抽取出cls向量作为整个句子的表征向量以供下游任务使用,可以说是一个附加产物;主要流程如下:

1)加载ckpt模型
2)确定输出tensor名称,在bert中,cls的名称为:bert/pooler/dense/Tanh(而不是SoftMax)
3)存储为pb model

主代码:

def extract_bert_vector():
    """ 抽取bert 768 特征向量
    :return:
    """
    OUTPUT_GRAPH = 'pb_model/bert_encoder.pb'
    output_node = ["bert/pooler/dense/Tanh"]
    ckpt_model = r'output'
    bert_config_file = r'chinese_L-12_H-768_A-12/bert_config.json'
    max_seq_length = 200

    gpu_config = tf.ConfigProto()
    gpu_config.gpu_options.allow_growth = True
    sess = tf.Session(config=gpu_config)
    graph = tf.get_default_graph()
    with open(r'data/file_dict.json', 'r') as fr:
        label_list = json.load(fr)
    with graph.as_default():
        print("going to restore checkpoint")
        input_ids_p = tf.placeholder(tf.int32, [None, max_seq_length], name="input_ids")
        input_mask_p = tf.placeholder(tf.int32, [None, max_seq_length], name="input_mask")
        bert_config = modeling.BertConfig.from_json_file(bert_config_file)
        (loss, per_example_loss, logits, probabilities) = create_model(
            bert_config=bert_config, is_training=False, input_ids=input_ids_p, input_mask=input_mask_p,
            segment_ids=None, labels=None, num_labels=len(label_list), use_one_hot_embeddings=False)
        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(ckpt_model))
        graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node)
        with tf.gfile.GFile(OUTPUT_GRAPH, "wb") as f:
            f.write(graph.SerializeToString())
        print('extract vector pb model saved!')

 

标签:Bert,编码器,特征向量,bert,graph,tf,input,model,config
来源: https://www.cnblogs.com/demo-deng/p/14659357.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有