| | |
| | |
| |
|
| | |
| |
|
| |
|
| | import gradio as gr |
| | import pandas as pd |
| | import numpy as np |
| | import torch |
| | from torch import nn |
| | from torch.nn import init, MarginRankingLoss |
| | from torch.optim import Adam |
| | from distutils.version import LooseVersion |
| | from torch.utils.data import Dataset, DataLoader |
| | from torch.autograd import Variable |
| | import math |
| | from transformers import AutoConfig, AutoModel, AutoTokenizer |
| | import nltk |
| | import re |
| | import torch.optim as optim |
| | from tqdm import tqdm |
| | from transformers import AutoModelForMaskedLM |
| | import torch.nn.functional as F |
| | import random |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | |
| |
|
| |
|
def greet(X, ny):
    """Predict a replacement identifier for the ``[MASK]`` variable in a Java snippet.

    Parameters
    ----------
    X : str
        Java source code in which every occurrence of the variable to rename
        has been replaced by the literal string ``[MASK]``.
    ny : str or int
        Desired number of sub-tokens in the generated name. ``0`` means the
        count is sampled from an empirical CDF observed on training data.

    Returns
    -------
    str
        The predicted identifier, built by concatenating — for each masked
        sub-token slot — the top-ranked purely alphabetic candidate among the
        model's top-5 predictions (averaged over all occurrences of the mask).
    """
    ny = int(ny)
    if ny == 0:
        # Sample the sub-token count from an empirical CDF: keys are token
        # counts, values are cumulative probabilities. Insertion order is
        # ascending in probability, so the first bucket >= rand_no wins.
        rand_no = random.random()
        tok_map = {2: 0.4363429005892416,
                   1: 0.6672580202327398,
                   4: 0.7476060740459144,
                   3: 0.9618703668504087,
                   6: 0.9701028532809564,
                   7: 0.9729244545819342,
                   8: 0.9739508754144756,
                   5: 0.9994508859743607,
                   9: 0.9997507867114407,
                   10: 0.9999112969650892,
                   11: 0.9999788802297832,
                   0: 0.9999831041838266,
                   12: 0.9999873281378701,
                   22: 0.9999957760459568,
                   14: 1.0000000000000002}
        for key, cum_prob in tok_map.items():
            if rand_no < cum_prob:
                num_sub_tokens_label = key
                break
    else:
        num_sub_tokens_label = ny

    # NOTE(review): the tokenizer, base model, and fine-tuned weights are
    # re-loaded on every call; caching them at module level would make the
    # app far more responsive. Kept per-call to preserve current behaviour.
    tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base")
    model = AutoModelForMaskedLM.from_pretrained("microsoft/graphcodebert-base")
    model.load_state_dict(torch.load('model_26_2'))
    model.eval()

    # Expand each single [MASK] marker into num_sub_tokens_label mask tokens
    # (the first replace pads with spaces so tokenization splits cleanly).
    X_init = X.replace("[MASK]", " [MASK] ")
    X_init = X_init.replace("[MASK]", " ".join([tokenizer.mask_token] * num_sub_tokens_label))

    # Tokenize without special tokens, then chunk into 510-token pieces so
    # each chunk fits the 512-position limit after adding two specials.
    tokens = tokenizer.encode_plus(X_init, add_special_tokens=False, return_tensors='pt')
    input_id_chunks = list(tokens['input_ids'][0].split(510))
    mask_chunks = list(tokens['attention_mask'][0].split(510))

    # NOTE(review): 101/102 are BERT's [CLS]/[SEP] ids, while GraphCodeBERT
    # is RoBERTa-based (cls=0, sep=2). Left as-is because the fine-tuned
    # checkpoint 'model_26_2' was presumably trained with this preprocessing
    # — confirm before changing.
    cls_tok = torch.full((1,), fill_value=101)
    one_tok = torch.full((1,), fill_value=1)
    sep_tok = torch.full((1,), fill_value=102)
    for r in range(len(input_id_chunks)):
        input_id_chunks[r] = torch.cat([cls_tok, input_id_chunks[r], sep_tok], dim=-1)
        mask_chunks[r] = torch.cat([one_tok, mask_chunks[r], one_tok], dim=-1)

    # Right-pad every chunk to exactly 512 positions (pad id 0, attention 0).
    # Single cat with a pre-sized zero tensor instead of the original
    # one-token-at-a-time loop — same result, O(1) concatenations.
    for i in range(len(input_id_chunks)):
        pad_len = 512 - input_id_chunks[i].shape[0]
        if pad_len > 0:
            pad = torch.zeros(pad_len, dtype=input_id_chunks[i].dtype)
            input_id_chunks[i] = torch.cat([input_id_chunks[i], pad], dim=-1)
            mask_chunks[i] = torch.cat([mask_chunks[i], pad], dim=-1)

    # For each chunk, record the index of the FIRST mask token of every run
    # of consecutive masks (each run corresponds to one variable occurrence).
    maski = []
    for l in range(len(input_id_chunks)):
        masked_pos = []
        for i in range(len(input_id_chunks[l])):
            if input_id_chunks[l][i] == tokenizer.mask_token_id:
                if i != 0 and input_id_chunks[l][i - 1] == tokenizer.mask_token_id:
                    continue
                masked_pos.append(i)
        maski.append(masked_pos)

    input_ids = torch.stack(input_id_chunks)
    att_mask = torch.stack(mask_chunks)
    # Inference only: disable autograd to avoid building a graph (fix — the
    # original ran the forward pass with gradients enabled).
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=att_mask)
    last_hidden_state = outputs[0].squeeze()

    # l_o_l_sa[t] collects, for sub-token slot t, the output vector at every
    # occurrence of the masked variable.
    l_o_l_sa = [[] for _ in range(num_sub_tokens_label)]
    if len(maski) == 1:
        # Single chunk: squeeze() collapsed the batch dim, so indexing is 1-level.
        for k in maski[0]:
            for t in range(num_sub_tokens_label):
                l_o_l_sa[t].append(last_hidden_state[k + t])
    else:
        # Multiple chunks: a mask run can spill past a chunk boundary, in
        # which case the remainder is read from the start of chunk p+1.
        for p in range(len(maski)):
            for k in maski[p]:
                for t in range(num_sub_tokens_label):
                    if (k + t) >= len(last_hidden_state[p]):
                        l_o_l_sa[t].append(last_hidden_state[p + 1][k + t - len(last_hidden_state[p])])
                        continue
                    l_o_l_sa[t].append(last_hidden_state[p][k + t])

    # Average each slot's vectors over all occurrences of the variable.
    sum_state = [l_o_l_sa[t][0] for t in range(num_sub_tokens_label)]
    for i in range(1, len(l_o_l_sa[0])):
        for t in range(num_sub_tokens_label):
            sum_state[t] = sum_state[t] + l_o_l_sa[t][i]
    yip = len(l_o_l_sa[0])

    # Decode: per slot, take the top-5 candidate tokens and keep the first
    # purely alphabetic one; concatenate the slots into the final name.
    er = ""
    for t in range(num_sub_tokens_label):
        sum_state[t] /= yip
        idx = torch.topk(sum_state[t], k=5, dim=0)[1]
        wor = [tokenizer.decode(i.item()).strip() for i in idx]
        for kl in wor:
            if all(char.isalpha() for char in kl):
                er += kl
                break
    return er
# --- Gradio UI -------------------------------------------------------------
title = "Rename a variable in a Java class"
description = """This model is a fine-tuned GraphCodeBERT model fine-tuned to output higher-quality variable names for Java classes. Long classes are handled by the
model. Replace any variable name with a "[MASK]" to get an identifier renaming.
"""
ex = ["""import java.io.*;
public class x {
    public static void main(String[] args) {
        String f = "file.txt";
        BufferedReader [MASK] = null;
        String l;
        try {
            [MASK] = new BufferedReader(new FileReader(f));
            while ((l = [MASK].readLine()) != null) {
                System.out.println(l);
            }
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            try {
                if ([MASK] != null) [MASK].close();
            } catch (IOException ex) {
                ex.printStackTrace();
            }
        }
    }
}""", """import java.net.*;
import java.io.*;

public class s {
    public static void main(String[] args) throws IOException {
        ServerSocket [MASK] = new ServerSocket(8000);
        try {
            Socket s = [MASK].accept();
            PrintWriter pw = new PrintWriter(s.getOutputStream(), true);
            BufferedReader br = new BufferedReader(new InputStreamReader(s.getInputStream()));
            String i;
            while ((i = br.readLine()) != null) {
                pw.println(i);
            }
        } finally {
            if ([MASK] != null) [MASK].close();
        }
    }
}""", """import java.io.*;
import java.util.*;

public class y {
    public static void main(String[] args) {
        String [MASK] = "data.csv";
        String l = "";
        String cvsSplitBy = ",";
        try (BufferedReader br = new BufferedReader(new FileReader([MASK]))) {
            while ((l = br.readLine()) != null) {
                String[] z = l.split(cvsSplitBy);
                System.out.println("Values [field-1= " + z[0] + " , field-2=" + z[1] + "]");
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}"""]

# Fix: gr.Textbox does not accept title/description/examples — those are
# gr.Interface parameters, so the original app rendered without its title,
# description, or example gallery.
textbox = gr.Textbox(label="Type Java code snippet:",
                     placeholder="replace variable with [MASK]", lines=10)

gr.Interface(
    fn=greet,
    inputs=[
        textbox,
        gr.Textbox(type="text", label="Number of tokens in name:",
                   placeholder="0 for randomly sampled number of tokens"),
    ],
    outputs="text",
    title=title,
    description=description,
    # Interface has two inputs, so each example row must supply both values;
    # "0" selects a randomly sampled token count.
    examples=[[snippet, "0"] for snippet in ex],
).launch()
| |
|
| |
|
| | |
| |
|
| |
|
# NOTE(review): the Java snippet below was left at module level as raw text,
# which is a Python syntax error and prevented this file from being imported
# or run at all. It duplicates the first entry of `ex` above, so it is
# preserved here as a comment only; delete once confirmed unneeded.
# import java.io.*;
# public class x {
#     public static void main(String[] args) {
#         String f = "file.txt";
#         BufferedReader [MASK] = null;
#         String l;
#         try {
#             [MASK] = new BufferedReader(new FileReader(f));
#             while ((l = [MASK].readLine()) != null) {
#                 System.out.println(l);
#             }
#         } catch (IOException e) {
#             e.printStackTrace();
#         } finally {
#             try {
#                 if ([MASK] != null) [MASK].close();
#             } catch (IOException ex) {
#                 ex.printStackTrace();
#             }
#         }
#     }
# }
| |
|
| |
|