ICode9


Designing the Loss Function for a Two-Tower Model in Recommender Systems

2022-02-16 14:31:49



The design of the loss is critical to the system. Initially we optimized the cosine of the angle between the user-side and item-side vectors with binary_crossentropy, but could not get a satisfactory model: thanks to the deep model's strong fitting capacity, the feature vectors all collapsed to zero. We then reworked the system along two lines: 1. negative-sample construction; 2. loss-function design.

This post focuses on the second point.

A Distance-Based Loss Function: Hinge (Contrastive) Loss

This is also the most commonly used loss function in today's recommender systems. In my view, it works well because it follows the idea of contrastive learning and pulls the samples further apart.

The input is binary, so the label is 0 or 1. When y = 0, the larger the user-item distance, the smaller the loss; when y = 1, the larger the distance, the larger the loss. That is:

    L(y, d) = y * d^2 + (1 - y) * max(margin - d, 0)^2

where d is the Euclidean distance between the user and item vectors.

Note: because the computation is distance-based, the two towers generally share network parameters to keep the embedding space consistent; during back-propagation the shared weights are then updated through the usual gradients. Whether to share still depends on the situation, of course.
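The weight-sharing idea above can be sketched as a siamese setup in Keras. This is a minimal illustration, not the article's model: it assumes both sides have already been reduced to same-width feature vectors, and the input names and sizes are made up.

```python
import tensorflow as tf
from tensorflow.keras import layers, Model

# A single encoder instance: calling the SAME layer objects on both inputs
# shares the weights, so user and item vectors live in one embedding space.
shared_encoder = tf.keras.Sequential([
    layers.Dense(128, activation="relu"),
    layers.Dense(64, activation="relu"),
])

user_in = layers.Input(shape=(32,), name="user_feats")
item_in = layers.Input(shape=(32,), name="item_feats")
user_vec = shared_encoder(user_in)  # same weights as the item branch
item_vec = shared_encoder(item_in)  # gradients from both branches accumulate

twin = Model([user_in, item_in], [user_vec, item_vec])
```

Because only one set of Dense layers exists, the model has half the parameters of two independent towers, and both branches are updated by every training pair.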

Here is the optimized model code.

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Embedding, Lambda, Dense
from tensorflow.keras.models import Model

def cosine_similarity(x):
    # Cosine of the angle between the two tower vectors (kept from the
    # earlier dot-product variant; the final model outputs a distance).
    dot1 = K.batch_dot(x[0], x[1], axes=1)
    dot2 = K.batch_dot(x[0], x[0], axes=1)
    dot3 = K.batch_dot(x[1], x[1], axes=1)
    max_ = K.maximum(K.sqrt(dot2 * dot3), K.epsilon())
    return dot1 / max_

def contrastive_loss(y_true, y_pred):
    # y_pred is the user-item distance: positive pairs (y=1) are pulled
    # together, negative pairs (y=0) are pushed out past the margin.
    margin = 1
    return K.mean(y_true * K.square(y_pred) + (1 - y_true) * K.square(K.maximum(margin - y_pred, 0)))
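As a quick sanity check of the loss shape described above, here is the same formula in plain NumPy (contrastive_loss_np is a hypothetical helper mirroring contrastive_loss, not part of the model):

```python
import numpy as np

def contrastive_loss_np(y_true, d, margin=1.0):
    # y=1: loss = d^2, grows with distance; y=0: loss = max(margin - d, 0)^2
    return np.mean(y_true * d**2 + (1 - y_true) * np.maximum(margin - d, 0)**2)

# positive pair (y=1): the further apart, the larger the loss
print(contrastive_loss_np(1, 0.2))  # ≈ 0.04
print(contrastive_loss_np(1, 0.9))  # ≈ 0.81
# negative pair (y=0): loss shrinks with distance, zero beyond the margin
print(contrastive_loss_np(0, 0.2))  # ≈ 0.64
print(contrastive_loss_np(0, 1.5))  # 0.0
```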

def get_recall_model(geek_features, job_features):
    # NOTE: relies on a global pandas DataFrame `data` holding the raw
    # feature columns; it is used below to size the hash/embedding vocabularies.
    bert_encoder_shape = (384,)
    vocab_bias = 50
    def model_input(shape,name):
        return Input(shape=shape,name=name,dtype="string")
    def sparse_feat(feat,vocab_size,embedding_dim):
        return Embedding(vocab_size, embedding_dim)(feat)
    def dense_feat(feat):
        return Lambda(lambda x:tf.expand_dims(x, axis=2))(feat)
    def embedd_feat(shape,name):
        return Input(shape=shape, name=name)
    def hash_bucket(x, vocab_size_max):
        return Lambda(lambda x: tf.strings.to_hash_bucket_fast(x, vocab_size_max - 1) + 1)(x)
    def euclidean_distance(vects):
        x, y = vects
        # epsilon guard: the gradient of sqrt is NaN at exactly zero distance
        return K.sqrt(K.maximum(K.sum(K.square(x - y), axis=1, keepdims=True), K.epsilon()))

    geek_feats = []
    job_feats = []
    for each in geek_features:
        geek_feats.append(model_input(shape=(None,),name=each))
    for each in job_features:
        job_feats.append(model_input(shape=(None,),name=each))

    geek_hash_feats = [hash_bucket(e, len(data[feat_name].value_counts())+vocab_bias) for e,feat_name in zip(geek_feats,geek_features)]
    job_hash_feats = [hash_bucket(e, len(data[feat_name].value_counts())+vocab_bias) for e,feat_name in zip(job_feats,job_features)]

    geek_feature_inputs = [sparse_feat(e, len(data[feat_name].value_counts())+vocab_bias, 64) for e,feat_name in zip(geek_hash_feats,geek_features)]
    geek_feature_columns = [Lambda(lambda x:tf.squeeze(x,[1]))(e) for e in geek_feature_inputs]
    query_feature_columns = [embedd_feat(shape=bert_encoder_shape,name="query_embedding")]
    job_feature_inputs = [sparse_feat(e, len(data[feat_name].value_counts())+vocab_bias, 64) for e,feat_name in zip(job_hash_feats,job_features)]
    job_feature_columns = [Lambda(lambda x:tf.squeeze(x,[1]))(e) for e in job_feature_inputs]
    title_feature_columns = [embedd_feat(shape=bert_encoder_shape,name="title_embedding")]

    # geek tower
    geek_vector_tmp = Lambda(lambda x:K.concatenate(x, axis=-1))(geek_feature_columns+query_feature_columns)
    geek_vector = Dense(128, activation="relu")(geek_vector_tmp)
    geek_vector = Dense(64, activation="relu",kernel_regularizer="l2",name="geek_vector")(geek_vector)

    # job tower
    job_vector_tmp = Lambda(lambda x:K.concatenate(x, axis=-1))(job_feature_columns+title_feature_columns)
    job_vector = Dense(128, activation="relu")(job_vector_tmp)
    job_vector = Dense(64, activation="relu",kernel_regularizer="l2",name="job_vector")(job_vector)

    # dot-product score between the towers (kept from the cosine/dot variant;
    # unused by the final distance output below)
    dot_geek_job = Lambda(lambda x: tf.multiply(x[0], x[1]))([geek_vector, job_vector])
    dot_geek_job = Lambda(lambda x: tf.reduce_sum(x, axis=1))(dot_geek_job)
    dot_geek_job = Lambda(lambda x: tf.expand_dims(x, 1))(dot_geek_job)

    geek_job_distance = Lambda(euclidean_distance, name="output")([geek_vector, job_vector])
    model = Model(inputs=geek_feats+job_feats+query_feature_columns+title_feature_columns, outputs=geek_job_distance,name="merge")
    return model
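A sketch of how a model like this might be trained with contrastive_loss and split into towers for serving. The tiny towers and random data here stand in for the full feature pipeline above; in production each tower's vectors would typically be indexed for approximate-nearest-neighbor retrieval (e.g. with Faiss).

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model, backend as K

def contrastive_loss(y_true, y_pred, margin=1.0):
    return K.mean(y_true * K.square(y_pred) +
                  (1 - y_true) * K.square(K.maximum(margin - y_pred, 0)))

# toy stand-ins for the geek/job towers
u_in = layers.Input(shape=(8,), name="u")
i_in = layers.Input(shape=(8,), name="i")
u_vec = layers.Dense(16, name="geek_vector")(u_in)
i_vec = layers.Dense(16, name="job_vector")(i_in)
dist = layers.Lambda(
    lambda t: K.sqrt(K.maximum(
        K.sum(K.square(t[0] - t[1]), axis=1, keepdims=True),
        K.epsilon())))([u_vec, i_vec])

model = Model([u_in, i_in], dist)
model.compile(optimizer="adam", loss=contrastive_loss)
model.fit([np.random.rand(32, 8), np.random.rand(32, 8)],
          np.random.randint(0, 2, (32, 1)).astype("float32"),
          epochs=1, verbose=0)

# serving: export each tower separately so vectors can be precomputed/indexed
geek_tower = Model(u_in, u_vec)
job_tower = Model(i_in, i_vec)
```

Each tower shares its trained weights with the full model, so the exported sub-models produce exactly the vectors the distance was trained on.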

 

Source: https://www.cnblogs.com/demo-deng/p/15900373.html
