ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

代码智能:问题与解法

2021-07-06 14:33:37  阅读:252  来源: 互联网

标签:String -- 代码 vda 智能 train new 解法


代码智能:问题与解法

在基于预训练大模型引发自然语言处理革命的今天,代码智能技术也在迅速跟进发展。
那么,代码智能主要在做一些什么样的事情呢?可能很多同学会有比较科幻的想法,比如程序员要失业了之类的。
但是,其实很多工作并没有那么神秘,非常基础。那么我们用代码智能要解决什么问题呢?

  • 判断两段代码是不是实现相似的功能
  • 搜索跟当前代码段最相似的代码
  • 检测代码是否有bug
  • 自动修复代码中的bug
  • 给一段代码自动写注释
  • 根据文本推荐最相似的代码段
  • 根据文本生成代码

看了之后是不是觉得更玄幻了?这么困难的问题怎么搞得定?
诚实地讲,这其中的每个子问题都很困难,就算是人类学习起来也很困难。
不过,正像是人类也是一步一步学会的一样,机器也在不断地进步。我们需要的不一定是万能的机器神,也是和我们一样普通的机器人,它们有很大的局限,但是它们可以帮助我们减轻不少工作量。

而且,最后一节我们将揭晓,处理这么多如此复杂问题的方法,却非常简单,一把梭哈,我们只用一个模型就能搞定。

codeBert

下面我们就详细看一看这些问题的细节。

问题:克隆检测 Clone Detection

万地高楼平地起,代码智能任务首先从克隆检测做起。
所谓克隆检测,就是寻找写法和功能上相似的代码。
不要小看代码重复,它会显著地降低代码智能训练的有效性。
我们看下图,训练集中有重复,测试集中有重复,它们的交集中仍然有重复,在论文《The Adverse Effects of Code Duplication in Machine Learning Models of Code》中有详细的分析。

code duplicate

预测两段代码是否相似

以下的例子来自BigCloneBench数据集. 论文地址在:https://arxiv.org/pdf/2002.08653.pdf

下面我们举几个例子来看什么算相似:

代码1:

    private StringBuffer encoder(String arg) {
        if (arg == null) {
            arg = "";
        }
        MessageDigest md5 = null;
        try {
            md5 = MessageDigest.getInstance("MD5");
            md5.update(arg.getBytes(SysConstant.charset));
        } catch (Exception e) {
            e.printStackTrace();
        }
        return toHex(md5.digest());
    }

代码2:

    public String kodetu(String testusoila) {
        MessageDigest md = null;
        try {
            md = MessageDigest.getInstance("SHA");
            md.update(testusoila.getBytes("UTF-8"));
        } catch (NoSuchAlgorithmException e) {
            new MezuLeiho("Ez da zifraketa algoritmoa aurkitu", "Ados", "Zifraketa Arazoa", JOptionPane.ERROR_MESSAGE);
            e.printStackTrace();
        } catch (UnsupportedEncodingException e) {
            new MezuLeiho("Errorea kodetzerakoan", "Ados", "Kodeketa Errorea", JOptionPane.ERROR_MESSAGE);
            e.printStackTrace();
        }
        byte raw[] = md.digest();
        String hash = (new BASE64Encoder()).encode(raw);
        return hash;
    }

代码2的字符串是用巴斯克语写的。它们用的算法也有区别,判空和异常处理也有不同,但是我们认为它们是很类似的,属于克隆识别认为相同或高度相似的。

我们再看一对例子:

代码1:

    public static void test(String args[]) {
        int trace;
        int bytes_read = 0;
        int last_contentLenght = 0;
        try {
            BufferedReader reader;
            URL url;
            url = new URL(args[0]);
            URLConnection istream = url.openConnection();
            last_contentLenght = istream.getContentLength();
            reader = new BufferedReader(new InputStreamReader(istream.getInputStream()));
            System.out.println(url.toString());
            String line;
            trace = t2pNewTrace();
            while ((line = reader.readLine()) != null) {
                bytes_read = bytes_read + line.length() + 1;
                t2pProcessLine(trace, line);
            }
            t2pHandleEventPairs(trace);
            t2pSort(trace, 0);
            t2pExportTrace(trace, new String("pngtest2.png"), 1000, 700, (float) 0, (float) 33);
            t2pExportTrace(trace, new String("pngtest3.png"), 1000, 700, (float) 2.3, (float) 2.44);
            System.out.println("Press any key to contiune read from stream !!!");
            System.out.println(t2pGetProcessName(trace, 0));
            System.in.read();
            istream = url.openConnection();
            if (last_contentLenght != istream.getContentLength()) {
                istream = url.openConnection();
                istream.setRequestProperty("Range", "bytes=" + Integer.toString(bytes_read) + "-");
                System.out.println(Integer.toString(istream.getContentLength()));
                reader = new BufferedReader(new InputStreamReader(istream.getInputStream()));
                while ((line = reader.readLine()) != null) {
                    System.out.println(line);
                    t2pProcessLine(trace, line);
                }
            } else System.out.println("File not changed !");
            t2pDeleteTrace(trace);
        } catch (MalformedURLException e) {
            System.out.println("MalformedURLException !!!");
        } catch (IOException e) {
            System.out.println("File not found " + args[0]);
        }
        ;
    }

代码2:

    private static String loadUrlToString(String a_url) throws IOException {
        URL l_url1 = new URL(a_url);
        BufferedReader br = new BufferedReader(new InputStreamReader(l_url1.openStream()));
        String l_content = "";
        String l_ligne = null;
        l_content = br.readLine();
        while ((l_ligne = br.readLine()) != null) {
            l_content += AA.SL + l_ligne;
        }
        return l_content;
    }

这个虽然没有涉及小语种,但是明显代码长度差异巨大。不过,我们仍然认为它们是相似的。

我们看一对不相似的吧:
代码1:

    private void setNodekeyInJsonResponse(String service) throws Exception {
        String filename = this.baseDirectory + service + ".json";
        Scanner s = new Scanner(new File(filename));
        PrintWriter fw = new PrintWriter(new File(filename + ".new"));
        while (s.hasNextLine()) {
            fw.println(s.nextLine().replaceAll("NODEKEY", this.key));
        }
        s.close();
        fw.close();
        (new File(filename + ".new")).renameTo(new File(filename));
    }

代码2:

    public void transform(String style, String spec, OutputStream out) throws IOException {
        URL url = new URL(rootURL, spec);
        InputStream in = new PatchXMLSymbolsStream(new StripDoctypeStream(url.openStream()));
        transform(style, in, out);
        in.close();
    }

不相似的就不解释了。

BigCloneBench数据集,就是提供了两段代码,以及它们是否相似的人工打标的结果。

数据分为train.txt, valid.txt, test.txt三个集合,它们的格式都是同样的:

idx1 idx2 0/1

其中idx1和idx2是两段代码在data.jsonl中的索引值,最后一个是它们是否相似的人工打标的值。
代码都保存在data.jsonl中,格式为:

{"func":"代码","idx":"idx值"}

我们以训练集train.txt为例,其前两行是这样的:

13988825	8660836	0
80378	18548122	1

13988825在data.jsonl中对应的结构是这样的:

{"func": "    private void setNodekeyInJsonResponse(String service) throws Exception {\n        String filename = this.baseDirectory + service + \".json\";\n        Scanner s = new Scanner(new File(filename));\n        PrintWriter fw = new PrintWriter(new File(filename + \".new\"));\n        while (s.hasNextLine()) {\n            fw.println(s.nextLine().replaceAll(\"NODEKEY\", this.key));\n        }\n        s.close();\n        fw.close();\n        (new File(filename + \".new\")).renameTo(new File(filename));\n    }\n", "idx": "13988825"}

8660836对应的是:

{"func": "    public void transform(String style, String spec, OutputStream out) throws IOException {\n        URL url = new URL(rootURL, spec);\n        InputStream in = new PatchXMLSymbolsStream(new StripDoctypeStream(url.openStream()));\n        transform(style, in, out);\n        in.close();\n    }\n", "idx": "8660836"}

而它们的结果是不相似。大家看到,这个例子就是刚才上面我们写的第三个例子。

搜索跟当前代码段语义最相似的代码段

这个我们使用北大李戈李师团队的POJ-104数据集。

这个数据集需要到https://drive.google.com/uc?id=0B2i-vWnOu7MxVlJwQXN6eVNONUU去下载。

每个代码段用一个index来描述,然后code字段是完整的代码。我们来看个例子:

{
        "label":"1",
        "index":"0",
        "code":"
int f(int a,int x)
{
 int count=1,i;
 for(i=x;i<a;i++)
  if(a%i==0)
   count+=f(a/i,i);
 if(i==a)
  return count;
 else
  return 0;
}

void main()
{
 int n,a;
 scanf(\"%d\",&n);
 for(;n>0;n--)
 {
  scanf(\"%d\",&a);
  if(a==1||a==2)
   printf(\"1\
\");
  else
   printf(\"%d\
\",f(a,2));
 }
}
"
    }

然后,这个任务的目的就是求出针对某一段代码最相似的代码段。以取top 2为例:输出的样例如下:

{"index": "0", "answers": ["3", "2"]}
{"index": "1", "answers": ["0", "4"]}
{"index": "2", "answers": ["0", "1"]}
{"index": "4", "answers": ["1", "5"]}
{"index": "3", "answers": ["4", "2"]}
{"index": "5", "answers": ["4", "3"]}

也就是说,针对于代码index 0, 最相似的代码段是 index 3和2.

index 3是这样的:

void qut(int a,int b);                                       //????
int num=0;                                                    //?????????
int main()
{
 int i,n,g[1000];                                         //?????????
 cin>>n;
 for(i=0;i<n;i++)                                         //??????
  cin>>g[i];
 for(i=0;i<n;i++)
 {
     qut(g[i],1);                                         //????
  cout<<num<<endl;
              num=0;
 }
 return 0;
}

void qut(int a,int b)  
{
 int i;
 if (a>=b)  
 {
  num++;  
  if (b==1)                                      
   b++;
  for (i=b;i<=a;i++) 
  {
   if (a%i==0) 
   {
    qut(a/i,i);                                 //??a%i==0,??
   }
  }
 }
}

问题:缺陷检测

缺陷检测的数据集非常简单粗暴,就是一段打标的代码,标识是不是有漏洞。

我们看个有漏洞的例子:

{
        "project":"FFmpeg",
        "commit_id":"aba232cfa9b193604ed98f3fa505378d006b1b3b",
        "target":1,
        "func":"static int r3d_read_rdvo(AVFormatContext *s, Atom *atom)

{

    R3DContext *r3d = s->priv_data;

    AVStream *st = s->streams[0];

    int i;



    r3d->video_offsets_count = (atom->size - 8) / 4;

    r3d->video_offsets = av_malloc(atom->size);

    if (!r3d->video_offsets)

        return AVERROR(ENOMEM);



    for (i = 0; i < r3d->video_offsets_count; i++) {

        r3d->video_offsets[i] = avio_rb32(s->pb);

        if (!r3d->video_offsets[i]) {

            r3d->video_offsets_count = i;

            break;

        }

        av_dlog(s, \"video offset %d: %#x\
\", i, r3d->video_offsets[i]);

    }



    if (st->r_frame_rate.num)

        st->duration = av_rescale_q(r3d->video_offsets_count,

                                    (AVRational){st->r_frame_rate.den,

                                                 st->r_frame_rate.num},

                                    st->time_base);

    av_dlog(s, \"duration %\"PRId64\"\
\", st->duration);



    return 0;

}
",
        "idx":5
    }

信息就这么多,至于哪行是什么问题,训练集中没有。

当然,数据集里大部分还是没有漏洞的,比如第一条:

{"project": "FFmpeg", "commit_id": "973b1a6b9070e2bf17d17568cbaf4043ce931f51", "target": 0, "func": "static av_cold int vdadec_init(AVCodecContext *avctx)\n\n{\n\n    VDADecoderContext *ctx = avctx->priv_data;\n\n    struct vda_context *vda_ctx = &ctx->vda_ctx;\n\n    OSStatus status;\n\n    int ret;\n\n\n\n    ctx->h264_initialized = 0;\n\n\n\n    /* init pix_fmts of codec */\n\n    if (!ff_h264_vda_decoder.pix_fmts) {\n\n        if (kCFCoreFoundationVersionNumber < kCFCoreFoundationVersionNumber10_7)\n\n            ff_h264_vda_decoder.pix_fmts = vda_pixfmts_prior_10_7;\n\n        else\n\n            ff_h264_vda_decoder.pix_fmts = vda_pixfmts;\n\n    }\n\n\n\n    /* init vda */\n\n    memset(vda_ctx, 0, sizeof(struct vda_context));\n\n    vda_ctx->width = avctx->width;\n\n    vda_ctx->height = avctx->height;\n\n    vda_ctx->format = 'avc1';\n\n    vda_ctx->use_sync_decoding = 1;\n\n    vda_ctx->use_ref_buffer = 1;\n\n    ctx->pix_fmt = avctx->get_format(avctx, avctx->codec->pix_fmts);\n\n    switch (ctx->pix_fmt) {\n\n    case AV_PIX_FMT_UYVY422:\n\n        vda_ctx->cv_pix_fmt_type = '2vuy';\n\n        break;\n\n    case AV_PIX_FMT_YUYV422:\n\n        vda_ctx->cv_pix_fmt_type = 'yuvs';\n\n        break;\n\n    case AV_PIX_FMT_NV12:\n\n        vda_ctx->cv_pix_fmt_type = '420v';\n\n        break;\n\n    case AV_PIX_FMT_YUV420P:\n\n        vda_ctx->cv_pix_fmt_type = 'y420';\n\n        break;\n\n    default:\n\n        av_log(avctx, AV_LOG_ERROR, \"Unsupported pixel format: %d\\n\", avctx->pix_fmt);\n\n        goto failed;\n\n    }\n\n    status = ff_vda_create_decoder(vda_ctx,\n\n                                   avctx->extradata, avctx->extradata_size);\n\n    if (status != kVDADecoderNoErr) {\n\n        av_log(avctx, AV_LOG_ERROR,\n\n                \"Failed to init VDA decoder: %d.\\n\", status);\n\n        goto failed;\n\n    }\n\n    avctx->hwaccel_context = vda_ctx;\n\n\n\n    /* changes callback functions */\n\n    avctx->get_format = get_format;\n\n    avctx->get_buffer2 = get_buffer2;\n\n#if FF_API_GET_BUFFER\n\n    // force the old get_buffer to be empty\n\n    avctx->get_buffer = NULL;\n\n#endif\n\n\n\n    /* init H.264 decoder */\n\n    ret = ff_h264_decoder.init(avctx);\n\n    if (ret < 0) {\n\n        av_log(avctx, AV_LOG_ERROR, \"Failed to open H.264 decoder.\\n\");\n\n        goto failed;\n\n    }\n\n    ctx->h264_initialized = 1;\n\n\n\n    return 0;\n\n\n\nfailed:\n\n    vdadec_close(avctx);\n\n    return -1;\n\n}\n", "idx": 0}

推理搞起来也是十分省事了,就是对应每个index给个0或1的结果:

0	0
1	1
2	1
3	0
4	0

问题:代码自动修复

有了识别代码漏洞的,更进一步就是学习自动修复代码的了。

代码自动修复的题目也很简单,一段是有bug的代码,另一段是修复之后的代码。

我们来看一个例子:

有bug的代码是这样的:

public java.lang.String METHOD_1 ( ) { return new TYPE_1 ( STRING_1 ) . format ( VAR_1 [ ( ( VAR_1 . length ) - 1 ) ] . getTime ( ) ) ; }

修复之后是这样子的:

public java.lang.String METHOD_1 ( ) { return new TYPE_1 ( STRING_1 ) . format ( VAR_1 [ ( ( type ) - 1 ) ] . getTime ( ) ) ; }

也真难为算法了,人看起来都有点费事。

问题:代码互译

比如实现C#语言和Java语言的互译。我们只要有一系列代码的C#写法和Java写法,就可以进行学习进行互译。

我们来看一对例子。
先看C#代码:

public virtual ListSpeechSynthesisTasksResponse ListSpeechSynthesisTasks(ListSpeechSynthesisTasksRequest request){
  var options = new InvokeOptions();
  options.RequestMarshaller = ListSpeechSynthesisTasksRequestMarshaller.Instance;
  options.ResponseUnmarshaller = ListSpeechSynthesisTasksResponseUnmarshaller.Instance;
  return Invoke<ListSpeechSynthesisTasksResponse>(request, options);
}

对应的Java

public ListSpeechSynthesisTasksResult listSpeechSynthesisTasks(ListSpeechSynthesisTasksRequest request) {
  request = beforeClientExecution(request);
  return executeListSpeechSynthesisTasks(request);
}

代码互译

问题:给代码写注释

在训练素材中,有代码和注释,这个任务的目的为新代码写注释。评价指标是对于生成的注释的语言准确度。

这个我们使用CodeSearchNet数据集。

这个数据集中的每条记录的格式如下:

  • repo: 仓库名
  • path: 文件名
  • func_name: 函数或方法名
  • original_string: 未经处理的源字符串
  • language: 编程语言
  • code/function: 代码信息
  • code_tokens/function_tokens: 分词之后的代码结果
  • docstring: 注释字符串信息
  • docstring_tokens: docstring分词之后的结果
  • url: 自然语言的唯一标识号
  • idx: 代码段的唯一标识号

我们来看个例子:

{"repo": "ciena-blueplanet/bunsen-core", "path": "src/reducer.js", "func_name": "", "original_string": "function
(state, action) {\n    return _.defaults({\n      isValidating: action.isValidating,\n      lastAction: IS_VALIDA
TING\n    }, state)\n  }", "language": "javascript", "code": "function (state, action) {\n    return _.defaults({
\n      isValidating: action.isValidating,\n      lastAction: IS_VALIDATING\n    }, state)\n  }", "code_tokens":
["function", "(", "state", ",", "action", ")", "{", "return", "_", ".", "defaults", "(", "{", "isValidating", ":"
, "action", ".", "isValidating", ",", "lastAction", ":", "IS_VALIDATING", "}", ",", "state", ")", "}"], "docstrin
g": "Update is validating result\n@param {State} state - state to update\n@param {Action} action - action\n@retur
ns {State} - updated state", "docstring_tokens": ["Update", "is", "validating", "result"], "sha": "993c67e314e2b7
5003a1ff4c2f0cb667715562b2", "url": "https://github.com/ciena-blueplanet/bunsen-core/blob/993c67e314e2b75003a1ff4
c2f0cb667715562b2/src/reducer.js#L394-L399", "partition": "train"}

对于生成的自然语言,我们采用《ORANGE: a Method for Evaluating Automatic Evaluation Metrics for Machine Translation 》论文的方法进行评分。

问题:为自然语言文本匹配最合适的代码段

我们仍然使用上一节的CodeSearchNet数据集。

这个搜索的结果类似于下面这样:

{"url": "url0", "answers": [10,11,12,13,14]}
{"url": "url1", "answers": [10,12,11,13,14]}
{"url": "url2", "answers": [13,11,12,10,14]}
{"url": "url3", "answers": [10,14,12,13,11]}
{"url": "url4", "answers": [10,11,12,13,14]}

配上UI,大致实现的效果是这样的:

文本转代码

或者是这样:
webquery

问题:根据自然语言生成代码

这是终极任务,就是根据一段文本描述硬生生地生成一段代码出来。

格式非常简单,就一段代码和一段文本。

我们来看个训练样本的例子:

{"code": "void function ( Binder arg0 ) { EventBus loc0 = new EventBus ( ) ; AmbariEventPublisher loc1 = new AmbariEventPublisher ( ) ; repla
ceEventBus ( AmbariEventPublisher . class , loc1 , loc0 ) ; arg0 . bind ( AmbariEventPublisher . class ) . toInstance ( loc1 ) ; }", "nl": "force the eventb us from ambarievent publisher to be serialand synchronous . concode_field_sep PlaceHolder placeHolder concode_field_sep void registerAlertListeners concode_elem_sep EventBus synchronizeAlertEventPublisher concode_elem_sep void replaceEventBus concode_elem_sep void registerAmbariListeners"}

这NL部分有点乱啊,没办法,为了增加数据量,没有那么多人手打精确的标。

我们再看一个:

{"code": "byte [ ] function ( Class < ? > arg0 , Configuration arg1 ) { return AuthenticationTokenSerializer . serialize ( org . apache . acc
umulo . core . client . mapreduce . lib . impl . ConfiguratorBase . getAuthenticationToken ( arg0 , arg1 ) ) ; }", "nl": "do n't use this . n
o , really , do n't use this . you already have an authenticationtoken with org.apache.accumulo.core.client.mapreduce.lib.impl.configuratorba
se #getauthenticationtoken class , configuration . you do n't need to construct it yourself . gets the password from the configuration . warn
ing : the password is stored in the configuration and shared with all mapreduce tasks ; it is base64 encoded to provide a charset safe conver
sion to a string , and is not intended to be secure . concode_field_sep PlaceHolder placeHolder concode_field_sep String getPrincipal concode
_elem_sep void setLogLevel concode_elem_sep Level getLogLevel concode_elem_sep Boolean isConnectorInfoSet concode_elem_sep String getTokenCla
ss concode_elem_sep void setZooKeeperInstance concode_elem_sep void setMockInstance concode_elem_sep Instance getInstance concode_elem_sep St
ring enumToConfKey concode_elem_sep void setConnectorInfo"}

是不是质量也没好到哪儿去?这就是CONCODE数据集的样子。

解法:基于大规模预训练模型的多任务学习

402年前,当努尔哈赤面临明朝多路大军的围困的时候,采取了“凭你几路来,我只一路去”的战术赢得了萨尔浒之战的立国之战。
我们同样学习古人的智慧,任你数据集千变万化,我们的工具就只用一个 - 大规模预训练模型。

下面是预训练模型的简要发展史:

image.png

以开头我们展示的微软的codebert模型为例,我们要处理上面最复杂的代码生成任务,只要一条命令就可以搞定:

python -m torch.distributed.launch --nproc_per_node=$PER_NODE_GPU run.py \
        --data_dir=$DATADIR \
        --langs=$LANG \
        --output_dir=$OUTPUTDIR \
        --pretrain_dir=$PRETRAINDIR \
        --log_file=$LOGFILE \
        --model_type=gpt2 \
        --block_size=512 \
        --do_train \
        --node_index 0 \
        --gpu_per_node $PER_NODE_GPU \
        --learning_rate=5e-5 \
        --weight_decay=0.01 \
        --evaluate_during_training \
        --per_gpu_train_batch_size=6 \
        --per_gpu_eval_batch_size=12 \
        --gradient_accumulation_steps=2 \
        --num_train_epochs=30 \
        --logging_steps=100 \
        --save_steps=5000 \
        --overwrite_output_dir \
        --seed=42

如果使用两张2 NVIDIA P100 GPU卡的话,22小时左右就可以训练完。

推理呢,也是一条语句就搞定:

python -u run.py \
        --data_dir=$DATADIR \
        --langs=$LANG \
        --output_dir=$OUTPUTDIR \
        --pretrain_dir=$PRETRAINDIR \
        --log_file=$LOGFILE \
        --model_type=gpt2 \
        --block_size=512 \
        --do_infer \
        --logging_steps=100 \
        --seed=42

只用一张P100卡,大约40分钟就可以搞定。

有了上面的基础,我们就可以去打比赛啦。上面介绍的数据集,全都是比赛的赛题:
截屏2021-07-02 下午12.56.58.png

上面提到的数据集,可以在https://github.com/microsoft/CodeXGLUE下载到。

欢迎来到代码智能的世界!

附录:快速上手指南

放翁云:纸上得来终觉浅,绝知此事要躬行。
下面我们就落地下,将代码智能模型的训练和推理跑起来~~~

  • 第一步:安装transformers框架,因为codebert是基于这个框架的:
pip install transformers --user
  • 第二步:安装PyTorch或者Tensorflow作为Transformers的后端,以2021年7月5日这个时间点,需要的PyTorch版本至少是1.5.0以上。驱动能搞定的话,索性就安装最新的吧:
pip install torch torchvision torchtext torchaudio --user
  • 第三步,下载微软的数据集
git clone https://github.com/microsoft/CodeXGLUE
  • 第四步,我们先玩玩BigCloneBench吧

到Code-Code/Clone-detection-BigCloneBench/code目录下,运行:

python run.py     --output_dir=./saved_models     --model_type=roberta     --config_name=microsoft/codebert-base     --model_name_or_path=microsoft/codebert-base     --tokenizer_name=roberta-base     --do_train     --train_data_file=../dataset/train.txt     --eval_data_file=../dataset/valid.txt     --test_data_file=../dataset/test.txt     --epoch 2     --block_size 400     --train_batch_size 16     --eval_batch_size 32     --learning_rate 5e-5     --max_grad_norm 1.0     --evaluate_during_training     --seed 123456 2>&1| tee train.log

然后训练就运行起来了:

07/05/2021 16:29:24 - INFO - __main__ -   ***** Running training *****
07/05/2021 16:29:24 - INFO - __main__ -     Num examples = 90102
07/05/2021 16:29:24 - INFO - __main__ -     Num Epochs = 2
07/05/2021 16:29:24 - INFO - __main__ -     Instantaneous batch size per GPU = 8
07/05/2021 16:29:24 - INFO - __main__ -     Total train batch size (w. parallel, distributed & accumulation) = 16
07/05/2021 16:29:24 - INFO - __main__ -     Gradient Accumulation steps = 1
07/05/2021 16:29:24 - INFO - __main__ -     Total optimization steps = 11264

在两张V100卡大约需要训练40分钟左右。
训练之后是验证,会将目前最好的结果保存到checkpoint中以备推理时使用

07/05/2021 17:10:04 - INFO - __main__ -   ***** Running evaluation  ***** 40950/41541 [00:10<00:00, 2785.61it/s]
07/05/2021 17:10:04 - INFO - __main__ -     Num examples = 41541
07/05/2021 17:10:04 - INFO - __main__ -     Batch size = 32
07/05/2021 17:16:05 - INFO - __main__ -   ***** Eval results  *****
07/05/2021 17:16:05 - INFO - __main__ -     eval_f1 = 0.9531
07/05/2021 17:16:05 - INFO - __main__ -     eval_precision = 0.9579
07/05/2021 17:16:05 - INFO - __main__ -     eval_recall = 0.9484
07/05/2021 17:16:05 - INFO - __main__ -     eval_threshold = 0.97
07/05/2021 17:16:06 - INFO - __main__ -     ********************
07/05/2021 17:16:06 - INFO - __main__ -     Best f1:0.9531
07/05/2021 17:16:06 - INFO - __main__ -     ********************
07/05/2021 17:16:08 - INFO - __main__ -   Saving model checkpoint to ./saved_models/checkpoint-best-f1/model.bin

一次训练两轮,第二轮效果提升到0.97多:

07/05/2021 17:56:43 - INFO - __main__ -   ***** Running evaluation  ***** 40950/41541 [00:12<00:00, 3535.62it/s]
07/05/2021 17:56:43 - INFO - __main__ -     Num examples = 41541
07/05/2021 17:56:43 - INFO - __main__ -     Batch size = 32
[W pthreadpool-cpp.cc:90] Warning: Leaking Caffe2 thread-pool after fork. (function pthreadpool)
[W pthreadpool-cpp.cc:90] Warning: Leaking Caffe2 thread-pool after fork. (function pthreadpool)
[W pthreadpool-cpp.cc:90] Warning: Leaking Caffe2 thread-pool after fork. (function pthreadpool)
[W pthreadpool-cpp.cc:90] Warning: Leaking Caffe2 thread-pool after fork. (function pthreadpool)
07/05/2021 18:02:44 - INFO - __main__ -   ***** Eval results  *****
07/05/2021 18:02:44 - INFO - __main__ -     eval_f1 = 0.9701
07/05/2021 18:02:44 - INFO - __main__ -     eval_precision = 0.9772
07/05/2021 18:02:44 - INFO - __main__ -     eval_recall = 0.9633
07/05/2021 18:02:44 - INFO - __main__ -     eval_threshold = 0.97
07/05/2021 18:02:45 - INFO - __main__ -     ********************
07/05/2021 18:02:45 - INFO - __main__ -     Best f1:0.9701
07/05/2021 18:02:45 - INFO - __main__ -     ********************
07/05/2021 18:02:47 - INFO - __main__ -   Saving model checkpoint to ./saved_models/checkpoint-best-f1/model.bin

然后我们用训好的模型进行推理吧:

python run.py \
    --output_dir=./saved_models \
    --model_type=roberta \
    --config_name=microsoft/codebert-base \
    --model_name_or_path=microsoft/codebert-base \
    --tokenizer_name=roberta-base \
    --do_eval \
    --do_test \
    --train_data_file=../dataset/train.txt \
    --eval_data_file=../dataset/valid.txt \
    --test_data_file=../dataset/test.txt \
    --epoch 2 \
    --block_size 400 \
    --train_batch_size 16 \
    --eval_batch_size 32 \
    --learning_rate 5e-5 \
    --max_grad_norm 1.0 \
    --evaluate_during_training \
    --seed 123456 2>&1| tee test.log

最后我们运行evaluator.py来查看测试结果:

python ../evaluator/evaluator.py -a ../dataset/test.txt -p saved_models/predictions.txt

输出如下:

{'Recall': 0.9677421599288263, 'Prediction': 0.9557057904236594, 'F1': 0.9616080550111168}

准确率0.956, 召回率0.968,还不错~

跟CodeXGLUE的排行榜比一比:

跟榜上的CodeBert的结果基本一致

GraphCodeBert

要提升性能,我们可以用GraphCodeBert来替换CodeBert.

我们先下载GraphCodeBert的代码:

git clone https://github.com/microsoft/CodeBERT

然后转到GraphCodeBERT/clonedetection目录,解压dataset.zip:

unzip dataset.zip

然后就可以像训练codebert一样训练graphcodebert了:

mkdir saved_models
python run.py \
    --output_dir=saved_models \
    --config_name=microsoft/graphcodebert-base \
    --model_name_or_path=microsoft/graphcodebert-base \
    --tokenizer_name=microsoft/graphcodebert-base \
    --do_train \
    --train_data_file=dataset/train.txt \
    --eval_data_file=dataset/valid.txt \
    --test_data_file=dataset/test.txt \
    --epoch 1 \
    --code_length 512 \
    --data_flow_length 128 \
    --train_batch_size 16 \
    --eval_batch_size 32 \
    --learning_rate 2e-5 \
    --max_grad_norm 1.0 \
    --evaluate_during_training \
    --seed 123456 2>&1| tee saved_models/train.log

上面的参数是按4个V100 GPU来调的,如果只有两块V100,可以将–code_length改成256.
CodeBert 40分钟左右一轮,GraphCodeBert大约需要6个半小时一轮。

然后我们进行推理:

python run.py     --output_dir=saved_models     --config_name=microsoft/graphcodebert-base     --model_name_or_path=microsoft/graphcodebert-base     --tokenizer_name=microsoft/graphcodebert-base     --do_eval     --do_test     --train_data_file=dataset/train.txt     --eval_data_file=dataset/valid.txt     --test_data_file=dataset/test.txt     --epoch 1     --code_length 256     --data_flow_length 128     --train_batch_size 16     --eval_batch_size 32     --learning_rate 2e-5     --max_grad_norm 1.0     --evaluate_during_training     --seed 123456 2>&1| tee saved_models/test.log

最后我们解读一下结果吧:

python evaluator/evaluator.py -a dataset/test.txt -p saved_models/predictions.txt 2>&1| tee saved_models/score.log

结果如下:

{'Recall': 0.9589415798936043, 'Prediction': 0.962620653900429, 'F1': 0.9607703728051462}

标签:String,--,代码,vda,智能,train,new,解法
来源: https://blog.csdn.net/lusing/article/details/118519575

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有