| | import torch |
| | import json |
| | from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline |
| | from datasets import load_dataset |
| | from tqdm import tqdm |
| |
|
# --- Model / tokenizer / pipeline setup -------------------------------------
# Shard the fp16 Llama-3 8B checkpoint automatically across available devices.
device_map = "auto"
model = AutoModelForCausalLM.from_pretrained(
    "/path/to/meta-llama3-8b/",
    return_dict=True,
    torch_dtype=torch.float16,
    device_map=device_map)

# add_eos_token=True makes the tokenizer append EOS when encoding inputs.
tokenizer = AutoTokenizer.from_pretrained("/path/to/meta-llama3-8b/",add_eos_token=True)

# Pad with the id immediately after EOS — presumably one of Llama-3's reserved
# special tokens rather than a real vocabulary item. NOTE(review): confirm this
# id is inside the tokenizer's vocab; the more common choice is pad = eos.
tokenizer.pad_token_id = tokenizer.eos_token_id + 1
tokenizer.padding_side = "right"

# Generation pipeline used in the main loop below; caps each turn's output
# at 100 new tokens.
pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, pad_token_id=tokenizer.pad_token_id, max_new_tokens=100)

# JSONL test split; each record is accessed later via test_dataset['sample'].
test_dataset = load_dataset("json", data_files={'test':'/path/to/parser_test_moves_15.jsonl'})["test"]
| |
|
def is_first_moves(sample):
    """Return 1 when *sample* is the opening excerpt of a mission, else 0.

    A sample counts as "first moves" when its first line is the
    mission-start context marker AND its 'Structure:' line carries no
    relations yet (i.e. nothing follows the colon).
    """
    lines = sample.split('\n')
    if not lines[0].startswith('Context: 0 <Buil> Mission has started.'):
        return 0
    structure_lines = [ln for ln in lines if ln.startswith('Structure:')]
    # Empty text after 'Structure:' means no relations have been built yet.
    relations = structure_lines[0].split(':')[1].strip()
    return 1 if not relations else 0
| |
|
| |
|
def check_endpoints(struct, head):
    """Filter a structure string, keeping only relations whose source >= head.

    *struct* is a space-separated string of relations like 'COM(3,4)';
    the source is the first integer inside the parentheses. Returns the
    surviving relations re-joined with spaces, or None when *struct* is
    falsy or nothing survives the filter.
    """
    if not struct:
        return None
    kept = []
    for token in struct.split(' '):
        if not token:
            continue
        # Source endpoint is the integer before the comma: LABEL(src,tgt)
        source = int(token.split('(')[1].split(',')[0].strip())
        if source >= head:
            kept.append(token)
    return ' '.join(kept) if kept else None
| |
|
def add_previous(sample, previous, predictions):
    """Splice carried-over and newly predicted relations into *sample*.

    Reads the head index from the 'Context:' line, filters *previous*
    down to relations whose source is >= that head (via check_endpoints),
    and rewrites the sample's 'Structure:' line to hold the surviving
    relations followed by *predictions*. Returns (keep_str, new_sample)
    where keep_str is the structure text without the 'Structure: ' prefix.
    """
    lines = sample.split('\n')
    # Head index sits between 'Context:' and the first '<' on line one.
    head = int(lines[0].split('Context:')[1].split('<')[0].strip())

    keep_str = None
    rebuilt = []
    for line in lines:
        if line.startswith('Structure:'):
            carried = check_endpoints(previous, head)
            keep_str = (carried + ' ' + predictions) if carried else predictions
            line = 'Structure: ' + keep_str
        rebuilt.append(line)
    return keep_str, '\n'.join(rebuilt)
| |
|
def format_gen(preds):
    """Sanitise a generated discourse-structure string.

    Splits *preds* on spaces and keeps only tokens that parse as
    LABEL(int,int) with LABEL in the known relation inventory; everything
    else is dropped (with a diagnostic print, matching the original
    messages). Returns the surviving relations re-joined with spaces
    ('' when none survive).
    """
    # Set membership is the idiomatic (and O(1)) test for a fixed inventory.
    labels = {'COM', 'CONTR', 'CORR', 'QAP', 'ACK', 'ELAB', 'CLARIFQ', 'COND',
              'CONTIN', 'RES', 'EXPL', 'QELAB', 'ALT', 'NARR', 'CONFQ', 'SEQ'}
    clean_list = []
    for token in (st.strip() for st in preds.split(' ')):
        try:
            args = token.split('(')[1].split(')')[0].split(',')
            label = token.split('(')[0].strip()
        except IndexError:
            # Token has no '(...)' part at all.
            print('split error one')
            continue
        try:
            endpoints = (int(args[0]), int(args[1]))
        except IndexError:
            # Fewer than two comma-separated arguments.
            print('split error two')
            continue
        except ValueError:
            # Arguments are not integers.
            print('value error three')
            continue
        if label in labels:
            clean_list.append(f'{label}({endpoints[0]},{endpoints[1]})')
    return ' '.join(clean_list)
| |
|
| |
|
def formatting_prompts_func(example):
    """Wrap a raw excerpt in the prompt template used for DS parsing."""
    prefix = '<|begin_of_text|>Identify the discourse structure (DS) for the new turn in the following excerpt :\n'
    suffix = '\n ### DS:'
    return prefix + example + suffix
| |
|
| |
|
# --- Main generation loop ----------------------------------------------------
# Iterate over test samples, threading the structure predicted for earlier
# turns of the same mission into each subsequent prompt.
new_generations = None
previous_generations = None

# `with` guarantees the output file is closed (and buffers flushed) even if
# generation raises part-way through — the original `f = open(...)` leaked
# the handle on any exception before f.close().
with open("/path/to/val-output-file.txt", "w") as f:
    for datum in tqdm(test_dataset['sample']):
        if is_first_moves(datum):
            # Fresh mission: nothing to carry over.
            text = formatting_prompts_func(datum)
            previous_generations = None
        else:
            # Mid-mission turn: rewrite the Structure line with carried-over
            # relations plus the previous turn's cleaned predictions.
            # NOTE(review): assumes each mission's first sample passes
            # is_first_moves — otherwise new_generations may still be None here.
            update_prev, amended_text = add_previous(datum, previous_generations, new_generations)
            previous_generations = update_prev
            text = formatting_prompts_func(amended_text)
        generated = pipe(text)[0]['generated_text']
        print(generated, file=f)
        # Keep only well-formed, known-label relations for the next turn.
        new_generations = format_gen(generated.split('### DS:')[1])
| |
|
| |
|