1 from collections
import namedtuple
5 import torch.nn.init
as init
6 from torchvision
import models
7 from torchvision.models.vgg
import model_urls
12 if isinstance(m, nn.Conv2d):
13 init.xavier_uniform_(m.weight.data)
14 if m.bias
is not None:
16 elif isinstance(m, nn.BatchNorm2d):
17 m.weight.data.fill_(1)
19 elif isinstance(m, nn.Linear):
20 m.weight.data.normal_(0, 0.01)
25 def __init__(self, pretrained=True, freeze=True):
27 model_urls[
'vgg16_bn'] = model_urls[
'vgg16_bn'].replace(
28 'https://',
'http://')
29 vgg_pretrained_features = models.vgg16_bn(
30 pretrained=pretrained).features
37 self.
slice1.add_module(
str(x), vgg_pretrained_features[x])
38 for x
in range(12, 19):
39 self.
slice2.add_module(
str(x), vgg_pretrained_features[x])
40 for x
in range(19, 29):
41 self.
slice3.add_module(
str(x), vgg_pretrained_features[x])
42 for x
in range(29, 39):
43 self.
slice4.add_module(
str(x), vgg_pretrained_features[x])
46 self.
slice5 = torch.nn.Sequential(
47 nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
48 nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
49 nn.Conv2d(1024, 1024, kernel_size=1)
62 for param
in self.
slice1.parameters():
63 param.requires_grad =
False
76 vgg_outputs = namedtuple(
78 'fc7',
'relu5_3',
'relu4_3',
'relu3_2',
'relu2_2'])
79 out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2)