# example4 - train.py
1 import os, sys, shutil
2 import keras
3 import time
4 import tensorflow as tf
5 from keras import backend as kb
6 from keras.models import *
7 from keras.layers import *
8 from keras.optimizers import *
9 from keras.callbacks import ModelCheckpoint
10 import glob
11 from PIL import Image
12 import numpy as np
13 import cv2
14 from skimage import img_as_uint
#============================================ T R A I N M A I N ========================================
# Folder layout of the training run; every path hangs off `root`.
root = r"./unet_flow"
images_path = "{}/images".format(root)
models_path = "{}/models".format(root)
logs_path = "{}/logs".format(root)
train_images = "{}/train".format(images_path)
train_cropped_images_path = "{}/train_cropped".format(images_path)
paths = [root, images_path, models_path, logs_path, train_images, train_cropped_images_path]

print("Creating folders for training process ..")
# Existence is checked lazily, one folder at a time, right before it would be created.
for missing in (folder for folder in paths if not os.path.exists(folder)):
    print("Creating ", missing)
    os.makedirs(missing)
#============================================ T R A I N ===============================================
# other configuration
channels = 2                      # network input/output channels: cropped image + its IR image
img_width, img_height = 128, 128  # size of the training crops
unet_epochs = 1


def _enable_gpu_memory_growth(devices):
    """Enable memory growth on every GPU in *devices* and report the GPU counts.

    Memory growth needs to be the same across GPUs and must be set before the
    GPUs have been initialized; a RuntimeError raised for that reason is printed
    instead of propagated.
    """
    try:
        for device in devices:
            tf.config.experimental.set_memory_growth(device, True)
        logical = tf.config.experimental.list_logical_devices('GPU')
        print(len(devices), "Physical GPUs,", len(logical), "Logical GPUs")
    except RuntimeError as err:
        print(err)


# Reset any Keras graph state, then configure the visible GPUs.
kb.clear_session()
gpus = tf.config.experimental.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpus))
if gpus:
    _enable_gpu_memory_growth(gpus)
print("Creating and compiling Unet network .. ")

# save output to logs: from here until training finishes, print() writes into a
# per-model log file instead of the console.
old_stdout = sys.stdout
model_name = "DEPTH_{}.model".format(time.strftime("%Y%m%d-%H%M%S"))
name = "{}/loss_output_{}.log".format(logs_path, model_name)
log_file = open(name, "w")
sys.stdout = log_file
print('Loss function output of model :', model_name, '..')
# U-Net: four down-sampling stages (two 3x3 convs + 2x2 max-pool each), a
# 1024-filter bottleneck, and four up-sampling stages that concatenate the
# matching encoder features, ending in a 1x1 sigmoid over `channels` outputs.


def _double_conv(tensor, filters):
    """Apply two 3x3 ReLU convolutions ('same' padding, He-normal init)."""
    for _ in range(2):
        tensor = Conv2D(filters, 3, activation='relu', padding='same', kernel_initializer='he_normal')(tensor)
    return tensor


def _up_conv(tensor, filters):
    """Upsample *tensor* by 2 and apply a 2x2 ReLU convolution (decoder stage entry)."""
    upsampled = UpSampling2D(size=(2, 2))(tensor)
    return Conv2D(filters, 2, activation='relu', padding='same', kernel_initializer='he_normal')(upsampled)


# Spatial dimensions are left open: the network is fully convolutional and is
# trained on img_width x img_height crops but later applied to larger test tiles,
# so it must accept any height/width divisible by 16 (four 2x2 poolings).
inputs = Input((None, None, channels))

conv1 = _double_conv(inputs, 64)
pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
conv2 = _double_conv(pool1, 128)
pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
conv3 = _double_conv(pool2, 256)
pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
conv4 = _double_conv(pool3, 512)
drop4 = Dropout(0.5)(conv4)
pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

conv5 = _double_conv(pool4, 1024)
drop5 = Dropout(0.5)(conv5)

up6 = _up_conv(drop5, 512)
conv6 = _double_conv(concatenate([drop4, up6], axis=3), 512)
up7 = _up_conv(conv6, 256)
conv7 = _double_conv(concatenate([conv3, up7], axis=3), 256)
up8 = _up_conv(conv7, 128)
conv8 = _double_conv(concatenate([conv2, up8], axis=3), 128)
up9 = _up_conv(conv8, 64)
conv9 = _double_conv(concatenate([conv1, up9], axis=3), 64)
conv9 = Conv2D(2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)
conv10 = Conv2D(channels, 1, activation='sigmoid')(conv9)

model = Model(inputs=inputs, outputs=conv10)
# `learning_rate` replaces the deprecated `lr` keyword (rejected by newer Keras).
# NOTE(review): binary_crossentropy expects targets in [0, 1], but the training
# loop feeds mean-subtracted targets that can be negative — confirm this loss is intended.
model.compile(optimizer=Adam(learning_rate=1e-4), loss='binary_crossentropy', metrics=['accuracy'])
model.summary()
compiled_model = model
print("Unet network is successfully compiled !")
print('Preparing training data for CNN ..')
print('Cropping training data ..')


def _purge_path(target):
    """Delete *target*: unlink files and symlinks, remove directories recursively."""
    if os.path.isfile(target) or os.path.islink(target):
        os.unlink(target)
    elif os.path.isdir(target):
        shutil.rmtree(target)


# clean dir: empty the cropped-images folder left over from a previous run.
# A failure on one entry is printed and the remaining entries are still processed.
print("Clean cropped images directory ..")
for filename in os.listdir(train_cropped_images_path):
    file_path = os.path.join(train_cropped_images_path, filename)
    try:
        _purge_path(file_path)
    except Exception as e:
        print('Failed to delete %s. Reason: %s' % (file_path, e))
cropped_w, cropped_h = img_width, img_height

# NOTE(review): IMAGE_EXTENSION is not defined anywhere in this listing (original
# lines 15-16 are missing); presumably it is set there (e.g. ".png") — confirm.


def _find_train_images(prefix):
    """Return sorted paths under `train_images` (searched recursively) whose file
    name starts with *prefix* and ends with IMAGE_EXTENSION.

    `**` must be its own path component to recurse; glued to the folder name
    ("train**") it only matched direct children of any sibling folder whose name
    starts with "train", e.g. train_cropped.
    """
    pattern = os.path.join(train_images, "**", prefix + "*" + IMAGE_EXTENSION)
    return sorted(glob.glob(pattern, recursive=True))


def _load_uint16_image(path, is_ir):
    """Load *path* as a single-channel 16-bit PIL image.

    IR frames are read with OpenCV as BGR and converted to grayscale; other frames
    are read with Pillow, whose file handle is closed before returning.

    Raises:
        IOError: if OpenCV cannot read an IR frame.
    """
    if is_ir:
        bgr = cv2.imread(path)
        if bgr is None:
            raise IOError("Could not read image %s" % path)
        gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
        return Image.fromarray(gray.astype("uint16"))
    with Image.open(path) as source:
        return Image.fromarray(np.array(source).astype("uint16"))


noisy_images = _find_train_images("res")
pure_images = _find_train_images("gt")
ir_images = _find_train_images("left")
config_list = [(noisy_images, False), (pure_images, False), (ir_images, True)]

print("Cropping training images to size ", cropped_w, cropped_h)
w, h = cropped_w, cropped_h
for filelist, is_ir in config_list:
    for file in filelist:
        # A dedicated name keeps the global `name` (training log path) intact.
        base_name = os.path.splitext(os.path.basename(file))[0]
        img = _load_uint16_image(file, is_ir)
        width, height = img.size
        frame_num = 0
        for col_i in range(0, width, w):
            for row_i in range(0, height, h):
                crop = img.crop((col_i, row_i, col_i + w, row_i + h))
                # Tile name encodes frame index, top-left offset and tile size.
                # Formatting only the file name (not the whole path) keeps braces
                # in directory names from breaking str.format.
                tile_name = "{}_{:03}_row_{}_col_{}_width{}_height{}{}".format(
                    base_name, frame_num, row_i, col_i, w, h, IMAGE_EXTENSION)
                crop.save(os.path.join(train_cropped_images_path, tile_name))
                frame_num += 1
save_model_name = models_path + '/' + model_name
images_num_to_process = 1000
# The cropped folder holds the pure, noisy and IR crops side by side, so a third
# of its entries is the number of training samples.
all_cropped_num = len(os.listdir(train_cropped_images_path)) // 3
# Number of chunks of at most `images_num_to_process` samples (ceiling division).
iterations = (all_cropped_num + images_num_to_process - 1) // images_num_to_process
# Samples per gradient step. The former `steps_per_epoch=len(cropped_noisy_images)`
# counted every crop rather than the chunk actually passed to fit(). For NumPy
# inputs, TF2 Keras then derives the batch size as ceil(samples / steps), i.e. 1
# here. That batch size is now explicit; raise it if GPU memory allows.
train_batch_size = 1


def _sorted_crops(prefix):
    """Return sorted paths of cropped training tiles whose name starts with *prefix*."""
    pattern = os.path.join(train_cropped_images_path, prefix + "*" + IMAGE_EXTENSION)
    return sorted(p for p in glob.glob(pattern) if os.path.isfile(p))


def _read_tiles(paths):
    """Read single-channel tiles into an (N, img_height, img_width, 1) array."""
    tiles = [cv2.imread(p, cv2.IMREAD_UNCHANGED) for p in paths]
    return np.array(tiles).reshape(len(tiles), img_height, img_width, 1)


def _network_input(image_paths, ir_paths):
    """Build the normalised float32 network tensor for one chunk of tiles.

    Channel 0 is the cropped image; when `channels > 1`, channel 1 is the
    matching IR tile.
    """
    images = _read_tiles(image_paths)
    if channels > 1:
        ir_tiles = _read_tiles(ir_paths)
        images = np.stack((images[..., 0], ir_tiles[..., 0]), axis=-1)
    images = images.astype('float32')
    # NOTE(review): the original comment says "divide by standard deviation" but
    # the code divides by the variance. Kept as-is because the test-time
    # preprocessing performs the same division; switch both to np.std together
    # (and retrain) if the comment reflects the intent.
    return (images - np.average(images)) / np.var(images)


# glob already returns full paths (joining them onto the folder again doubled a
# relative path and made every cv2.imread return None). The folder does not
# change during training, so list it once.
cropped_noisy_images = _sorted_crops("res")
cropped_pure_images = _sorted_crops("gt")
cropped_ir_images = _sorted_crops("left")

try:
    print('Starting a training process ..')
    print("Create a 2-channel image from each cropped image and its corresponding IR image :")
    print("Channel 0 : pure or noisy cropped image")
    print("Channel 1 : corresponding IR image")
    print("The new images are the input for Unet network.")
    for chunk in range(iterations):
        print('*************** Iteration : ', chunk, '****************')
        first_image = chunk * images_num_to_process
        last_image = min(first_image + images_num_to_process, all_cropped_num)
        noisy_chunk = cropped_noisy_images[first_image:last_image]
        pure_chunk = cropped_pure_images[first_image:last_image]
        ir_chunk = cropped_ir_images[first_image:last_image]
        if len(pure_chunk) != len(noisy_chunk) or (channels > 1 and len(ir_chunk) != len(noisy_chunk)):
            raise ValueError(
                "Chunk %d has mismatched tile counts: %d noisy, %d pure, %d IR"
                % (chunk, len(noisy_chunk), len(pure_chunk), len(ir_chunk)))

        noisy_input_train = _network_input(noisy_chunk, ir_chunk)
        pure_input_train = _network_input(pure_chunk, ir_chunk)

        # Start training Unet network
        model_checkpoint = ModelCheckpoint(models_path + r"/unet_membrane.hdf5", monitor='loss', verbose=1, save_best_only=True)
        model.fit(noisy_input_train, pure_input_train,
                  batch_size=train_batch_size,
                  epochs=unet_epochs,
                  callbacks=[model_checkpoint])

        # Save `model` itself (the instance being trained) after every chunk;
        # saving a previously reloaded copy would persist stale weights.
        model.save(save_model_name)
finally:
    # Always give the console back and flush the log, even if training fails.
    sys.stdout = old_stdout
    log_file.close()

compiled_model = model
print("Training process is done successfully !")
# `name` may have been reused while cropping; the log file knows its own path.
print("Check log {} for more details".format(log_file.name))
#============================================ T E S T M A I N =========================================
# Test-time configuration: source images are split into 480x480 tiles and the
# model saved by the training run is evaluated on them.
origin_files_index_size_path_test = {}
test_img_width, test_img_height = 480, 480
img_width, img_height = test_img_width, test_img_height
test_model_name = save_model_name

test_images, test_cropped_images_path, denoised_dir = (
    images_path + suffix for suffix in (r"/test", r"/test_cropped", r"/denoised"))
paths = [test_images, test_cropped_images_path, denoised_dir]

print("Creating folders for testing process ..")
for folder in paths:
    if os.path.exists(folder):
        continue
    print("Creating ", folder)
    os.makedirs(folder)
#================================= S T A R T T E S T I N G ==========================================
old_stdout = sys.stdout
try:
    model = keras.models.load_model(test_model_name)
except Exception as e:
    # Best effort: if the saved model cannot be loaded, the `model` object that is
    # already in memory stays in use.
    print('Failed to load model %s. Reason: %s' % (test_model_name, e))

# Model file name without directory and extension, e.g. "DEPTH_<timestamp>".
model_stem = os.path.splitext(os.path.basename(test_model_name))[0]
print('Testing model', model_stem, '..')
# One log per model. Taking the text after the last '.' (the extension) made every
# run write to the same "output_model.log".
name = logs_path + '/output_' + model_stem + '.log'
print("Check log {} for more details".format(name))

log_file = open(name, "w")
sys.stdout = log_file
print('prediction time : ')
# clean directory before processing: drop every file, symlink and sub-folder left
# over from a previous run; a failure is printed and the next entry is processed.
for leftover in os.listdir(test_cropped_images_path):
    file_path = os.path.join(test_cropped_images_path, leftover)
    is_plain_file = os.path.isfile(file_path) or os.path.islink(file_path)
    try:
        if is_plain_file:
            os.unlink(file_path)
        elif os.path.isdir(file_path):
            shutil.rmtree(file_path)
    except Exception as e:
        print('Failed to delete %s. Reason: %s' % (file_path, e))
# `**` has to be its own path component for recursive=True to recurse; sorting
# makes the per-image indices below deterministic.
noisy_images = sorted(glob.glob(os.path.join(test_images, "**", "res*" + IMAGE_EXTENSION), recursive=True))
ir_images = sorted(glob.glob(os.path.join(test_images, "**", "left*" + IMAGE_EXTENSION), recursive=True))

# Number of tiles produced per source image, indexed like the matching file list.
total_cropped_images = [0] * len(noisy_images)
ir_total_cropped_images = [0] * len(ir_images)

########### SPLIT IMAGES ##################

print("Crop testing images to sizes of ", test_img_width, test_img_height)
ir_config = (ir_images, ir_total_cropped_images, True, {})
noisy_config = (noisy_images, total_cropped_images, False, origin_files_index_size_path_test)
config_list = [ir_config, noisy_config]

w, h = test_img_width, test_img_height
for filelist, tile_counts, is_ir, origin_files_index_size_path in config_list:
    for idx, file in enumerate(filelist):
        base_name = os.path.splitext(os.path.basename(file))[0]
        # Each source image gets its own sub-folder of tiles, named after the image.
        new_test_cropped_images_path = os.path.join(test_cropped_images_path, base_name)
        os.makedirs(new_test_cropped_images_path, exist_ok=True)

        if is_ir:
            # ir images has 3 similar channels, we need only 1 channel
            bgr = cv2.imread(file)
            if bgr is None:
                raise IOError("Could not read image %s" % file)
            img = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY).astype("uint16"))
        else:
            with Image.open(file) as source:
                img = Image.fromarray(np.array(source).astype("uint16"))

        # Use the real image size (previously hard-coded to 848x480), as the
        # training crop does.
        width, height = img.size
        frame_num = 0
        for col_i in range(0, width, w):
            for row_i in range(0, height, h):
                crop = img.crop((col_i, row_i, col_i + w, row_i + h))
                # Only the file name is formatted, so braces in folder names are safe.
                tile_name = "{}_{:03}_row_{}_col_{}_width{}_height{}{}".format(
                    base_name, frame_num, row_i, col_i, w, h, IMAGE_EXTENSION)
                crop.save(os.path.join(new_test_cropped_images_path, tile_name))
                frame_num += 1
        # (start frame offset — always 0, full width, full height, source path)
        origin_files_index_size_path[idx] = (0, width, height, file)
        tile_counts[idx] = frame_num
########### IMAGE TO ARRAY ##################
# Index the (offset, width, height, path) records by source-image base name so
# each tile folder is matched to its own image regardless of list order.
origin_by_name = {
    os.path.splitext(os.path.basename(info[3]))[0]: info
    for info in origin_files_index_size_path_test.values()
}
# One sub-folder of tiles per noisy source image.
cropped_noisy_images = sorted(
    p for p in glob.glob(os.path.join(test_cropped_images_path, "res*")) if os.path.isdir(p))
tile_w, tile_h = test_img_width, test_img_height

try:
    for directory in cropped_noisy_images:
        denoised_name = os.path.basename(directory)
        ir_cropped_images_file = os.path.join(
            test_cropped_images_path, 'left-' + denoised_name.split('-')[-1])

        # Look only inside this folder; "<dir>**/res*" also matched sibling
        # folders whose name merely starts with this one (res-1 vs res-10).
        im_files = sorted(glob.glob(os.path.join(directory, "res*" + IMAGE_EXTENSION)))
        ir_im_files = sorted(glob.glob(os.path.join(ir_cropped_images_file, "left*" + IMAGE_EXTENSION)))
        if not im_files:
            print('No tiles found in', directory)
            continue
        if channels > 1 and len(ir_im_files) != len(im_files):
            raise ValueError("Tile count mismatch for %s: %d image tiles, %d IR tiles"
                             % (denoised_name, len(im_files), len(ir_im_files)))

        # Tile names end in "_row_<r>_col_<c>_width<w>_height<h><ext>"; parse from
        # the end so underscores in the source name do not shift the fields.
        cropped_image_offsets = []
        for tile_path in im_files:
            fields = os.path.basename(tile_path).split('_')
            cropped_image_offsets.append((int(fields[-5]), int(fields[-3])))

        images_plt = np.array([cv2.imread(f, cv2.IMREAD_UNCHANGED) for f in im_files])
        im_and_ir = images_plt.reshape(len(im_files), tile_h, tile_w, 1)
        if channels > 1:
            ir_images_plt = np.array([cv2.imread(f, cv2.IMREAD_UNCHANGED) for f in ir_im_files])
            ir_images_plt = ir_images_plt.reshape(len(ir_im_files), tile_h, tile_w, 1)
            im_and_ir = np.stack((im_and_ir[..., 0], ir_images_plt[..., 0]), axis=-1)

        samples = im_and_ir.astype('float32')
        # NOTE(review): the comment in the training code says "divide by standard
        # deviation" but both stages divide by the variance; kept identical to the
        # training preprocessing so the model sees the same input distribution.
        samples = (samples - np.average(samples)) / np.var(samples)

        _, width, height, _ = origin_by_name[denoised_name]
        whole_image = np.zeros((height, width, channels), dtype="float32")

        t1 = time.perf_counter()
        # One predict() over all tiles of this image; batch_size=1 keeps the
        # per-step memory of the former tile-by-tile loop.
        denoised_tiles = model.predict(samples, batch_size=1)
        for tile, (row, col) in zip(denoised_tiles, cropped_image_offsets):
            # Edge tiles overhang the image: keep only the part inside it. Clamp to
            # height/width themselves, otherwise the last row/column stays empty.
            row_end = min(row + tile_h, height)
            col_end = min(col + tile_w, width)
            whole_image[row:row_end, col:col_end] = tile[:row_end - row, :col_end - col]
        t2 = time.perf_counter()
        print('test: ', denoised_name, ': ', t2 - t1, 'seconds')

        outfile = denoised_dir + '/' + denoised_name.split('-')[0] + '_denoised-' + denoised_name.split('-')[1] + IMAGE_EXTENSION
        whole_image = img_as_uint(whole_image)
        cv2.imwrite(outfile, whole_image[:, :, 0])
finally:
    # Restore the console and close the test log even if inference fails.
    sys.stdout = old_stdout
    log_file.close()
print("Testing process is done successfully !")
