craft.py
Go to the documentation of this file.
1 # -*- coding: utf-8 -*-
2 
3 """
4 Copyright (c) 2019-present NAVER Corp.
5 MIT License
6 """
7 
8 from __future__ import absolute_import
9 
10 import torch
11 import torch.nn as nn
12 import torch.nn.functional as F
13 
14 from craft.basenet.vgg16_bn import init_weights
15 from craft.basenet.vgg16_bn import vgg16_bn
16 
17 
class double_conv(nn.Module):
    """Bottleneck convolution stage of the CRAFT decoder.

    The input must carry ``in_ch + mid_ch`` channels. A 1x1 convolution
    reduces it to ``mid_ch`` channels. A 3x3 convolution with padding 1
    (spatial size unchanged) then maps it to ``out_ch`` channels. Each
    convolution is followed by BatchNorm and an in-place ReLU.

    Args:
        in_ch: first summand of the expected input channel count.
        mid_ch: second summand of the input channel count; also the
            width of the 1x1 bottleneck.
        out_ch: number of output channels.
    """

    def __init__(self, in_ch, mid_ch, out_ch):
        super(double_conv, self).__init__()
        reduce_stage = [
            nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
        ]
        expand_stage = [
            nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        # A single Sequential named ``conv`` keeps the parameter names in
        # state_dict (conv.0.weight, conv.1.running_mean, ...) stable.
        self.conv = nn.Sequential(*(reduce_stage + expand_stage))

    def forward(self, x):
        """Apply the 1x1 -> 3x3 convolution stack to ``x`` and return the result."""
        return self.conv(x)
33 
34 
class CRAFT(nn.Module):
    """CRAFT detector: a ``vgg16_bn`` backbone followed by a U-shaped decoder.

    ``forward`` returns a ``num_class`` (= 2) channel prediction map with the
    channel axis moved last, together with the 32-channel decoder feature
    map it was computed from.

    Args:
        pretrained: forwarded to ``vgg16_bn`` as its first argument.
        freeze: forwarded to ``vgg16_bn`` as its second argument.
    """

    def __init__(self, pretrained=False, freeze=False):
        super(CRAFT, self).__init__()

        # Base network.
        self.basenet = vgg16_bn(pretrained, freeze)

        # U network. Each double_conv(in_ch, mid_ch, out_ch) consumes
        # in_ch + mid_ch channels and produces out_ch channels.
        # NOTE(review): these widths must match the channel counts of the
        # backbone outputs; verify against vgg16_bn if it changes.
        self.upconv1 = double_conv(1024, 512, 256)
        self.upconv2 = double_conv(512, 256, 128)
        self.upconv3 = double_conv(256, 128, 64)
        self.upconv4 = double_conv(128, 64, 32)

        num_class = 2
        self.conv_cls = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True),
            nn.Conv2d(16, num_class, kernel_size=1),
        )

        # Initialise only the decoder and head. The order is kept fixed so
        # seeded initialisation stays reproducible.
        init_weights(self.upconv1.modules())
        init_weights(self.upconv2.modules())
        init_weights(self.upconv3.modules())
        init_weights(self.upconv4.modules())
        init_weights(self.conv_cls.modules())

    def forward(self, x):
        """Run the detector on ``x``.

        Args:
            x: input batch passed unchanged to the backbone. Presumably an
                (N, 3, H, W) image tensor, as in the ``__main__`` demo;
                confirm with callers.

        Returns:
            tuple: ``(y, feature)``.
            ``y`` is the ``conv_cls`` output with dim 1 (channels) moved to
            the last axis. ``feature`` is the ``upconv4`` output that
            ``conv_cls`` was applied to.
        """
        # Base network: an indexable sequence of feature maps.
        sources = self.basenet(x)

        # U network. sources[0] and sources[1] are concatenated without
        # resizing, so they are assumed to share the same spatial size.
        y = torch.cat([sources[0], sources[1]], dim=1)
        y = self.upconv1(y)

        # Each later stage upsamples to the next skip map before fusing.
        y = self.upconv2(self._upsample_cat(y, sources[2]))
        y = self.upconv3(self._upsample_cat(y, sources[3]))
        feature = self.upconv4(self._upsample_cat(y, sources[4]))

        y = self.conv_cls(feature)

        # (N, C, H, W) -> (N, H, W, C): channels last for the caller.
        return y.permute(0, 2, 3, 1), feature

    @staticmethod
    def _upsample_cat(y, skip):
        """Resize ``y`` bilinearly to ``skip``'s spatial size, then concatenate.

        The resized ``y`` is placed before ``skip`` along the channel axis
        (dim 1).

        Args:
            y: decoder tensor to resize.
            skip: skip-connection tensor whose ``size()[2:]`` sets the target size.

        Returns:
            Tensor with ``y.size(1) + skip.size(1)`` channels.
        """
        y = F.interpolate(
            y,
            size=skip.size()[2:],
            mode='bilinear',
            align_corners=False)
        return torch.cat([y, skip], dim=1)
103 
104 
if __name__ == '__main__':
    # Smoke test: build the model and print the shape of the prediction map.
    # Fall back to CPU so the demo also runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = CRAFT(pretrained=True).to(device)
    # Inference only: no autograd graph is needed just to check shapes.
    model.eval()
    with torch.no_grad():
        output, _ = model(torch.randn(1, 3, 768, 768, device=device))
    print(output.shape)
node_scripts.craft.craft.double_conv.forward
def forward(self, x)
Definition: craft.py:30
node_scripts.craft.craft.CRAFT.upconv2
upconv2
Definition: craft.py:44
node_scripts.craft.craft.CRAFT.__init__
def __init__(self, pretrained=False, freeze=False)
Definition: craft.py:36
node_scripts.craft.craft.CRAFT.upconv3
upconv3
Definition: craft.py:45
node_scripts.craft.craft.CRAFT
Definition: craft.py:35
node_scripts.craft.craft.CRAFT.forward
def forward(self, x)
Definition: craft.py:63
node_scripts.craft.craft.model
model
Definition: craft.py:106
node_scripts.craft.craft.double_conv.conv
conv
Definition: craft.py:21
node_scripts.craft.craft.double_conv.__init__
def __init__(self, in_ch, mid_ch, out_ch)
Definition: craft.py:19
node_scripts.craft.craft.CRAFT.conv_cls
conv_cls
Definition: craft.py:49
node_scripts.craft.basenet.vgg16_bn.vgg16_bn
Definition: vgg16_bn.py:24
node_scripts.craft.craft.CRAFT.upconv1
upconv1
Definition: craft.py:43
node_scripts.craft.craft.CRAFT.upconv4
upconv4
Definition: craft.py:46
node_scripts.craft.craft.CRAFT.basenet
basenet
Definition: craft.py:40
node_scripts.craft.basenet.vgg16_bn.init_weights
def init_weights(modules)
Definition: vgg16_bn.py:10
node_scripts.craft.craft.double_conv
Definition: craft.py:18


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Fri May 16 2025 03:11:16