diff --git a/lib/models/seg_hrnet_ocr.py b/lib/models/seg_hrnet_ocr.py index da19d236..487c8ec5 100644 --- a/lib/models/seg_hrnet_ocr.py +++ b/lib/models/seg_hrnet_ocr.py @@ -61,7 +61,7 @@ def forward(self, feats, probs): probs = probs.view(batch_size, c, -1) feats = feats.view(batch_size, feats.size(1), -1) feats = feats.permute(0, 2, 1) # batch x hw x c - probs = F.softmax(self.scale * probs, dim=2)# batch x k x hw + probs = F.softmax(self.scale * probs, dim=1)# batch x k x hw ocr_context = torch.matmul(probs, feats)\ .permute(0, 2, 1).unsqueeze(3)# batch x k x c return ocr_context