"""HuggingFace-compatible wrapper that exposes ProtoModule as a PreTrainedModel."""
import torch
from transformers import PreTrainedModel

from .configuration_proto import ProtoConfig
from .proto import ProtoModule


class ProtoForMultiLabelClassification(PreTrainedModel):
    """Multi-label classification model backed by a prototype network (ProtoModule)."""

    config_class = ProtoConfig

    def __init__(self, config: ProtoConfig):
        super().__init__(config)
        # Forward every ProtoModule hyperparameter from the HF config object.
        self.proto_module = ProtoModule(
            pretrained_model=config.pretrained_model_name_or_path,
            num_classes=config.num_classes,
            label_order_path=config.label_order_path,
            use_sigmoid=config.use_sigmoid,
            use_cuda=config.use_cuda,
            lr_prototypes=config.lr_prototypes,
            lr_features=config.lr_features,
            lr_others=config.lr_others,
            num_training_steps=config.num_training_steps,
            num_warmup_steps=config.num_warmup_steps,
            loss=config.loss,
            save_dir=config.save_dir,
            use_attention=config.use_attention,
            dot_product=config.dot_product,
            normalize=config.normalize,
            final_layer=config.final_layer,
            reduce_hidden_size=config.reduce_hidden_size,
            use_prototype_loss=config.use_prototype_loss,
            prototype_vector_path=config.prototype_vector_path,
            attention_vector_path=config.attention_vector_path,
            eval_buckets=config.eval_buckets,
            seed=config.seed,
        )
        # Standard HF initializer hook for weights not loaded from a checkpoint.
        self.init_weights()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor,
        **kwargs,
    ):
        # ProtoModule consumes a batch dict and expects the plural key
        # "attention_masks", so the standard HF argument name is remapped here.
        batch = {
            "input_ids": input_ids,
            "attention_masks": attention_mask,
            "token_type_ids": token_type_ids,
        }
        logits, metadata = self.proto_module(batch)
        return {"logits": logits, "metadata": metadata}