Go to the documentation of this file.00001
00002
00003 from __future__ import absolute_import
00004 from __future__ import division
00005 from __future__ import print_function
00006
00007 import chainer.serializers as S
00008 from jsk_recognition_msgs.msg import ClassificationResult
00009 from jsk_recognition_utils.chainermodels import AlexNet
00010 from jsk_recognition_utils.chainermodels import AlexNetBatchNormalization
00011 import rospy
00012 from sensor_msgs.msg import Image
00013 from vgg16_object_recognition import VGG16ObjectRecognition
00014
00015
class AlexNetObjectRecognition(VGG16ObjectRecognition):
    """Object recognition node that runs an AlexNet-family Chainer model.

    Subclasses ``VGG16ObjectRecognition`` to reuse its inherited behavior
    (presumably the image subscription / classification callback -- not
    visible in this file, confirm in vgg16_object_recognition.py) while
    constructing an AlexNet variant instead of a VGG16 model.

    ROS parameters read in ``__init__``:
        ~gpu (int, default -1): device id passed to ``model.to_gpu``; any
            negative value keeps the model on the CPU.
        ~target_names (list): class names; ``len(target_names)`` is used as
            the model's ``n_class``.
        ~model_name (str): ``'alexnet'`` or ``'alexnet_batch_normalization'``.
        ~model_file (str): HDF5 weight file loaded with
            ``chainer.serializers.load_hdf5``.

    Published topics:
        ~output (jsk_recognition_msgs/ClassificationResult), queue_size=1
        ~debug/net_input (sensor_msgs/Image), queue_size=1
    """

    def __init__(self):
        """Read ROS parameters, build and load the model, advertise topics.

        Raises:
            ValueError: if ``~model_name`` is not a supported AlexNet variant.
        """
        # super(VGG16ObjectRecognition, self) starts the MRO lookup *after*
        # VGG16ObjectRecognition, so the parent's own __init__ (which would
        # set up a VGG16 model) is intentionally skipped and the next base
        # class initializer runs instead.
        super(VGG16ObjectRecognition, self).__init__()
        # Side length of the network input; 227 presumably matches AlexNet's
        # 227x227 input and is read by the inherited preprocessing -- confirm.
        self.insize = 227
        self.gpu = rospy.get_param('~gpu', -1)
        self.target_names = rospy.get_param('~target_names')
        self.model_name = rospy.get_param('~model_name')

        model_classes = {
            'alexnet': AlexNet,
            'alexnet_batch_normalization': AlexNetBatchNormalization,
        }
        model_class = model_classes.get(self.model_name)
        if model_class is None:
            rospy.logerr('Unsupported ~model_name: {0}'
                         .format(self.model_name))
            # Without a model there is nothing to load weights into; stop
            # here rather than failing later with an AttributeError on
            # ``self.model`` inside load_hdf5.
            raise ValueError('Unsupported ~model_name: {0}'
                             .format(self.model_name))
        self.model = model_class(n_class=len(self.target_names))

        model_file = rospy.get_param('~model_file')
        S.load_hdf5(model_file, self.model)
        # Chainer convention: a negative device id means CPU.
        if self.gpu >= 0:
            self.model.to_gpu(self.gpu)

        # ``advertise`` is inherited (presumably jsk_topic_tools'
        # ConnectionBasedTransport -- confirm in the base class).
        self.pub = self.advertise('~output', ClassificationResult,
                                  queue_size=1)
        self.pub_input = self.advertise(
            '~debug/net_input', Image, queue_size=1)
00040
00041
if __name__ == '__main__':
    # The node must be registered before any parameter lookups or
    # advertisements made by the recognizer's constructor.
    rospy.init_node('alexnet_object_recognition')
    # Hold a reference to the recognizer for the lifetime of the process.
    recognizer = AlexNetObjectRecognition()
    # Block and process callbacks until ROS shuts down.
    rospy.spin()