segmentation.py
Go to the documentation of this file.
00001 from __future__ import print_function
00002 
00003 import os
00004 import os.path as osp
00005 
00006 import chainer
00007 try:
00008     import chainer_mask_rcnn as cmr
00009 except ImportError:
00010     print('chainer_mask_rcnn cannot be imported.')
00011 import cv2
00012 import numpy as np
00013 
00014 
class SemanticSegmentationDataset(chainer.dataset.DatasetMixin):
    """Dataset of images paired with per-pixel class label arrays.

    The directory ``root_dir`` is read with the following layout:

    - ``class_names.txt``: one class name per line.
    - ``JPEGImages/``: the image files (every entry in this directory is
      treated as an image).
    - ``SegmentationClass/``: one ``<basename>.npy`` file per image, where
      ``<basename>`` is the image file name without its extension.

    Only the paths are collected at construction time; the files themselves
    are loaded lazily in :meth:`get_example`.
    """

    def __init__(self, root_dir):
        """Collect image/label path pairs found under ``root_dir``.

        Args:
            root_dir: Path to the dataset root directory.
        """
        self.root_dir = root_dir

        class_names_path = osp.join(root_dir, 'class_names.txt')
        with open(class_names_path, 'r') as f:
            # One class name per line; drop the trailing newline/whitespace.
            self.class_names = [name.rstrip() for name in f]

        self._images = []
        self._labels = []
        images_dir = osp.join(root_dir, 'JPEGImages')
        labels_dir = osp.join(root_dir, 'SegmentationClass')
        # Sorted so that the example order is deterministic.
        for image_ in sorted(os.listdir(images_dir)):
            # NOTE: str.rstrip('.jpg') must not be used here: it strips any
            # trailing '.', 'j', 'p' or 'g' characters ('dog.jpg' -> 'do'),
            # which yields a wrong label path.
            basename = osp.splitext(image_)[0]
            self._images.append(osp.join(images_dir, image_))
            self._labels.append(osp.join(labels_dir, basename + '.npy'))

    def __len__(self):
        """Return the number of examples (one per image file)."""
        return len(self._images)

    def get_example(self, i):
        """Load the ``i``-th image and its class label array.

        Args:
            i: Index of the example.

        Returns:
            Tuple ``(image, label)`` where ``image`` is the 3-dimensional
            uint8 array returned by ``cv2.imread`` and ``label`` is the
            2-dimensional int32 array stored in the matching ``.npy`` file.

        Raises:
            IOError: If the image file cannot be read by ``cv2.imread``.
        """
        image_path = self._images[i]
        label_path = self._labels[i]

        image = cv2.imread(image_path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise IOError('Failed to read image: {}'.format(image_path))
        assert image.dtype == np.uint8
        assert image.ndim == 3

        label = np.load(label_path)
        assert label.dtype == np.int32
        assert label.ndim == 2

        return image, label
00052 
00053 
class InstanceSegmentationDataset(chainer.dataset.DatasetMixin):
    """Dataset of images with per-instance boxes, class labels and masks.

    The directory ``root_dir`` is read with the following layout:

    - ``class_names.txt``: one class name per line; the first entry is
      treated as background and excluded from ``fg_class_names``.
    - ``JPEGImages/``: the image files (every entry in this directory is
      treated as an image).
    - ``SegmentationClass/``: per-pixel class labels, ``<basename>.npy``.
    - ``SegmentationObject/``: per-pixel instance ids, ``<basename>.npy``.

    ``<basename>`` is the image file name without its extension. Only the
    paths are collected at construction time; the files themselves are
    loaded lazily in :meth:`get_example`.
    """

    def __init__(self, root_dir):
        """Collect image and label paths found under ``root_dir``.

        Args:
            root_dir: Path to the dataset root directory.
        """
        self.root_dir = root_dir

        class_names_path = osp.join(root_dir, 'class_names.txt')
        with open(class_names_path, 'r') as f:
            class_names = [name.rstrip() for name in f]
        # instance id 0 is '_background_' and should be ignored.
        self.fg_class_names = class_names[1:]

        self._images = []
        self._class_labels = []
        self._instance_labels = []
        images_dir = osp.join(root_dir, 'JPEGImages')
        class_labels_dir = osp.join(root_dir, 'SegmentationClass')
        instance_labels_dir = osp.join(root_dir, 'SegmentationObject')
        # Sorted so that the example order is deterministic.
        for image_ in sorted(os.listdir(images_dir)):
            # NOTE: str.rstrip('.jpg') must not be used here: it strips any
            # trailing '.', 'j', 'p' or 'g' characters ('dog.jpg' -> 'do'),
            # which yields wrong label paths.
            label_name = osp.splitext(image_)[0] + '.npy'
            self._images.append(osp.join(images_dir, image_))
            self._class_labels.append(
                osp.join(class_labels_dir, label_name))
            self._instance_labels.append(
                osp.join(instance_labels_dir, label_name))

    def __len__(self):
        """Return the number of examples (one per image file)."""
        return len(self._images)

    def get_example(self, i):
        """Load the ``i``-th image and its instance annotations.

        Args:
            i: Index of the example.

        Returns:
            Tuple ``(image, bboxes, labels, masks)``: ``image`` is the
            3-dimensional uint8 array returned by ``cv2.imread``; ``bboxes``
            (float32), ``labels`` (int32, shifted down by one so that
            class 0 becomes -1) and ``masks`` (int32) are derived from the
            label arrays by ``cmr.utils.label2instance_boxes``.

        Raises:
            IOError: If the image file cannot be read by ``cv2.imread``.
        """
        image_path = self._images[i]
        class_label_path = self._class_labels[i]
        instance_label_path = self._instance_labels[i]

        image = cv2.imread(image_path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise IOError('Failed to read image: {}'.format(image_path))
        assert image.dtype == np.uint8
        assert image.ndim == 3

        class_label = np.load(class_label_path)
        assert class_label.dtype == np.int32
        assert class_label.ndim == 2

        instance_label = np.load(instance_label_path)
        # instance id 0 is '_background_' and should be ignored.
        instance_label[instance_label == 0] = -1
        assert instance_label.dtype == np.int32
        assert instance_label.ndim == 2

        # All three arrays must describe the same pixel grid.
        assert image.shape[:2] == class_label.shape == instance_label.shape

        # NOTE(review): presumably pixels with instance id -1 are skipped by
        # label2instance_boxes -- confirm against chainer_mask_rcnn.
        labels, bboxes, masks = cmr.utils.label2instance_boxes(
            label_instance=instance_label, label_class=class_label,
            return_masks=True,
        )
        masks = masks.astype(np.int32, copy=False)
        labels = labels.astype(np.int32, copy=False)
        labels -= 1  # background: 0 -> -1
        bboxes = bboxes.astype(np.float32, copy=False)

        return image, bboxes, labels, masks


jsk_recognition_utils
Author(s):
autogenerated on Tue Jul 2 2019 19:40:37