| import pytest |
| import torch |
|
|
| from mmseg.core import OHEMPixelSampler |
| from mmseg.models.decode_heads import FCNHead |
|
|
|
|
def _context_for_ohem():
    """Return a small FCNHead used as the ``context`` of an OHEMPixelSampler.

    The head is configured with 32 input channels, 16 intermediate channels
    and 19 classes.
    """
    head_kwargs = dict(in_channels=32, channels=16, num_classes=19)
    return FCNHead(**head_kwargs)
|
|
|
|
def test_ohem_sampler():
    """Exercise OHEMPixelSampler input validation and its kept-pixel count.

    Scenarios:

    1. Logits (45x45) and labels (89x89) with mismatched spatial sizes are
       expected to make ``sample`` raise ``AssertionError``.
    2. With ``thresh=0.7`` and ``min_kept=200`` on random 19-class logits,
       the returned weight map is expected to keep more than 200 pixels.
    3. With only ``min_kept=200``, the returned weight map is expected to
       sum to exactly 200.

    In scenarios 2 and 3 the weight map must have the logits' batch size
    and spatial size.
    """
    # Use a dedicated seeded generator so the random inputs, and therefore
    # the count-based assertions below, are reproducible without mutating
    # torch's global RNG state.
    gen = torch.Generator().manual_seed(0)

    # 1. Spatial size mismatch between logits and labels must be rejected.
    # Build the sampler and inputs outside ``pytest.raises`` so that only
    # ``sample`` can satisfy the expected AssertionError.
    sampler = OHEMPixelSampler(context=_context_for_ohem())
    seg_logit = torch.randn(1, 19, 45, 45, generator=gen)
    seg_label = torch.randint(0, 19, size=(1, 1, 89, 89), generator=gen)
    with pytest.raises(AssertionError):
        sampler.sample(seg_logit, seg_label)

    # 2. With a probability threshold, more than ``min_kept`` pixels are
    # expected to be kept for these random inputs.
    sampler = OHEMPixelSampler(
        context=_context_for_ohem(), thresh=0.7, min_kept=200)
    seg_logit = torch.randn(1, 19, 45, 45, generator=gen)
    seg_label = torch.randint(0, 19, size=(1, 1, 45, 45), generator=gen)
    seg_weight = sampler.sample(seg_logit, seg_label)
    assert seg_weight.shape[0] == seg_logit.shape[0]
    assert seg_weight.shape[1:] == seg_logit.shape[2:]
    assert seg_weight.sum() > 200

    # 3. Without a threshold, exactly ``min_kept`` pixels are expected to
    # be kept.
    sampler = OHEMPixelSampler(context=_context_for_ohem(), min_kept=200)
    seg_logit = torch.randn(1, 19, 45, 45, generator=gen)
    seg_label = torch.randint(0, 19, size=(1, 1, 45, 45), generator=gen)
    seg_weight = sampler.sample(seg_logit, seg_label)
    assert seg_weight.shape[0] == seg_logit.shape[0]
    assert seg_weight.shape[1:] == seg_logit.shape[2:]
    assert seg_weight.sum() == 200
|
|