| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import unittest |
|
|
| from transformers import is_torch_available |
| from transformers.testing_utils import require_torch |
|
|
|
|
| if is_torch_available(): |
| import torch |
|
|
| from transformers.generation import DisjunctiveConstraint |
|
|
|
|
@require_torch
class ConstraintTest(unittest.TestCase):
    """Tests for `DisjunctiveConstraint`: input validation and token-by-token progression."""

    def _assert_update(self, dc, token_id, *, stepped, completed, reset, current_seq):
        """Feed `token_id` to `dc`, then check the returned flags and the resulting state.

        Args:
            dc: the `DisjunctiveConstraint` under test (mutated by `update`).
            token_id: the token passed to `dc.update`.
            stepped, completed, reset: expected values of the `(stepped, completed, reset)`
                triple returned by `dc.update`. Compared by identity with `assertIs`, so a
                truthy non-bool value fails just like an `is True` check would.
            current_seq: expected value of `dc.current_seq` after the update.
        """
        got_stepped, got_completed, got_reset = dc.update(token_id)
        # One assertion per flag so a failure names the flag that was wrong.
        self.assertIs(got_stepped, stepped)
        self.assertIs(got_completed, completed)
        self.assertIs(got_reset, reset)
        # `dc.completed` is only meaningful as a truth value, so compare it as a bool.
        self.assertEqual(bool(dc.completed), completed)
        self.assertEqual(dc.current_seq, current_seq)

    def test_input_types(self):
        # Construction from nested lists of ints succeeds, and `token_ids` is exposed as a list.
        cset = [[1, 2, 4], [1, 2, 3, 4]]
        dc = DisjunctiveConstraint(cset)
        self.assertIsInstance(dc.token_ids, list)

        # A 2-D tensor is not an accepted input ...
        with self.assertRaises(ValueError):
            DisjunctiveConstraint(torch.LongTensor([[1, 2, 4], [1, 2, 3]]))

        # ... and neither is a list of 1-D tensors.
        with self.assertRaises(ValueError):
            DisjunctiveConstraint([torch.LongTensor([1, 2, 4]), torch.LongTensor([1, 2, 3, 4, 5])])

    def test_check_illegal_input(self):
        # [1, 2] is a strict prefix of [1, 2, 3, 4], and such a set must be rejected at construction.
        # Presumably this is because reaching [1, 2] would be ambiguous: the constraint could count
        # as fulfilled, or generation could still be on its way to [1, 2, 3, 4].
        cset = [[1, 2], [1, 2, 3, 4]]

        with self.assertRaises(ValueError):
            DisjunctiveConstraint(cset)

    def test_example_progression(self):
        # Two equal-length branches share the prefix [1, 2]; token 3 completes the first branch.
        dc = DisjunctiveConstraint([[1, 2, 3], [1, 2, 4]])

        self._assert_update(dc, 1, stepped=True, completed=False, reset=False, current_seq=[1])
        self._assert_update(dc, 2, stepped=True, completed=False, reset=False, current_seq=[1, 2])
        self._assert_update(dc, 3, stepped=True, completed=True, reset=False, current_seq=[1, 2, 3])

    def test_example_progression_unequal_three_mid_and_reset(self):
        # Three branches of unequal length share the prefix [1, 2].
        dc = DisjunctiveConstraint([[1, 2, 3], [1, 2, 4, 5], [1, 2, 5]])

        # First pass: complete the longest branch, [1, 2, 4, 5].
        self._assert_update(dc, 1, stepped=True, completed=False, reset=False, current_seq=[1])
        self._assert_update(dc, 2, stepped=True, completed=False, reset=False, current_seq=[1, 2])
        self._assert_update(dc, 4, stepped=True, completed=False, reset=False, current_seq=[1, 2, 4])
        self._assert_update(dc, 5, stepped=True, completed=True, reset=False, current_seq=[1, 2, 4, 5])

        # After an explicit reset, progress starts over; this time complete the short branch [1, 2, 5]
        # and also track `remaining()` along the way.
        dc.reset()

        self._assert_update(dc, 1, stepped=True, completed=False, reset=False, current_seq=[1])
        self.assertEqual(dc.remaining(), 3)

        self._assert_update(dc, 2, stepped=True, completed=False, reset=False, current_seq=[1, 2])
        self.assertEqual(dc.remaining(), 2)

        self._assert_update(dc, 5, stepped=True, completed=True, reset=False, current_seq=[1, 2, 5])
        self.assertEqual(dc.remaining(), 0)
|
|