在《【学习分享6】医学图像分类处理和模型训练实践》所搭建的环境基础上,即可进行本章的实战,不过还需要额外安装如下的 Python 库:
# Install SimpleITK, pinned to version 2.0.2
pip install SimpleITK==2.0.2
# Install xvfb
sudo apt install -y xvfb
# Start xvfb on display :0 (runs in the background)
Xvfb :0 -screen 0 1280x960x24 -listen tcp -ac +extension GLX +extension RENDER &
# Set the default display
export DISPLAY=:0
# Install x11vnc
sudo apt install -y x11vnc
# Start x11vnc; replace the two placeholders (端口 = port, 密码 = password) with real values
sudo x11vnc -display :0 -forever -shared -rfbport 端口 -passwd 密码 &
import matplotlib.pyplot as plt
import numpy as np
# Generate the data
x = np.arange(0, 10, 0.1)  # x values: 0 up to (but excluding) 10, step 0.1
y = np.sin(x)  # y values: sin(x) for every x
# Draw the curve
plt.plot(x, y)
# Display the figure (this call was missing after the comment)
plt.show()
本章的测试数据集使用的是 MRI 公开数据集,可从 Medical Segmentation Decathlon 下载,然后解压:
# Unpack the downloaded archive
tar xvf Task01_BrainTumour.tar
# List the extracted contents
ls -l Task01_BrainTumour
import os
import SimpleITK as sitk
import matplotlib.pyplot as plt
import numpy as np
import json
from skimage import measure
import random
# import tensorflow as tf
import tensorflow.compat.v1 as tf
from tensorflow.keras.layers import Layer
from tensorflow import keras
from tensorflow.keras.layers import Conv3D, Conv3DTranspose, ReLU, UpSampling3D, Input, MaxPool3D, Concatenate
from tensorflow.keras import Model
def z_score_norm(img, mean=None, std=None):
    """Standardise ``img`` to zero mean and unit variance.

    Args:
        img: Array of intensities.
        mean: Value to subtract; computed from ``img`` when ``None``.
        std: Value to divide by; computed from ``img`` when ``None``.

    Returns:
        ``(img - mean) / (std + 1e-6)``. The small constant keeps the
        division finite for constant images.
    """
    if mean is None:
        mean = np.mean(img)
    if std is None:
        std = np.std(img)
    return (img - mean) / (std + 1e-6)
def show_imgs(images, cols=5):
    """Plot one row per ``(image, label)`` pair.

    The slice shown is the middle entry of the axis-0 indices of the
    non-zero label voxels; the central slice is used for an empty label.

    Args:
        images: Sequence of ``(image, label)`` pairs, with ``image`` indexed
            as ``image[slice, :, :, channel]`` and ``label`` as
            ``label[slice, :, :]``.
        cols: Panels per row: ``cols - 1`` image channels plus the label.
    """
    rows = len(images)
    titles = ['FLAIR', 'T1w', 'T1gd', 'T2w', 'Label']
    f, axes = plt.subplots(max(1, rows), cols)
    # np.ravel also copes with the bare Axes returned for a 1x1 grid.
    axes_ravel = np.ravel(axes)
    for i, (image, label) in enumerate(images):
        ds = np.where(label != 0)[0]
        # An all-background label has no non-zero slice to centre on.
        slice_idx = ds[len(ds) // 2] if len(ds) else label.shape[0] // 2
        for j in range(cols - 1):
            ax = axes_ravel[i * cols + j]
            ax.imshow(image[slice_idx, :, :, j], cmap='Greys_r')
            if i == 0 and j < len(titles) - 1:
                ax.set_title(titles[j])
        ax = axes_ravel[(i + 1) * cols - 1]
        ax.imshow(label[slice_idx, :, :], cmap='Greys_r')
        if i == 0:
            ax.set_title(titles[-1])
    plt.subplots_adjust(wspace=0.01, hspace=0)
    # Without this the figure is built but never displayed in a script.
    plt.show()
def read_img(img_path, label_path=None):
    """Read an image (and optionally its label) into NumPy arrays.

    Axis 0 of the image array returned by SimpleITK is moved to axis 3.

    Args:
        img_path: Path of the image file.
        label_path: Path of the label file, or ``None`` to read the image only.

    Returns:
        ``img_np`` when ``label_path`` is ``None``, otherwise
        ``(img_np, label_np)``.
    """
    img_np = sitk.GetArrayFromImage(sitk.ReadImage(img_path))
    img_np = np.moveaxis(img_np, 0, 3)
    if label_path is None:
        return img_np
    label_np = sitk.GetArrayFromImage(sitk.ReadImage(label_path))
    return img_np, label_np
# Load one training case and its label, then plot them.
img_path = '../data/Task01_BrainTumour/imagesTr/BRATS_001.nii.gz'
label_path = '../data/Task01_BrainTumour/labelsTr/BRATS_001.nii.gz'
img_np, label_np = read_img(img_path, label_path)
# show_imgs expects a sequence of (image, label) pairs, hence the nested list.
show_imgs([[img_np, label_np]])
import os
import SimpleITK as sitk
import matplotlib.pyplot as plt
import numpy as np
import json
from skimage import measure
import random
# import tensorflow as tf
import tensorflow.compat.v1 as tf
from tensorflow.keras.layers import Layer
from tensorflow import keras
from tensorflow.keras.layers import Conv3D, Conv3DTranspose, ReLU, UpSampling3D, Input, MaxPool3D, Concatenate
from tensorflow.keras import Model
def z_score_norm(img, mean=None, std=None):
    """Standardise ``img`` to zero mean and unit variance.

    Args:
        img: Array of intensities.
        mean: Value to subtract; computed from ``img`` when ``None``.
        std: Value to divide by; computed from ``img`` when ``None``.

    Returns:
        ``(img - mean) / (std + 1e-6)``. The small constant keeps the
        division finite for constant images.
    """
    if mean is None:
        mean = np.mean(img)
    if std is None:
        std = np.std(img)
    return (img - mean) / (std + 1e-6)
def show_imgs(images, cols=5):
    """Plot one row per ``(image, label)`` pair.

    The slice shown is the middle entry of the axis-0 indices of the
    non-zero label voxels; the central slice is used for an empty label.

    Args:
        images: Sequence of ``(image, label)`` pairs, with ``image`` indexed
            as ``image[slice, :, :, channel]`` and ``label`` as
            ``label[slice, :, :]``.
        cols: Panels per row: ``cols - 1`` image channels plus the label.
    """
    rows = len(images)
    titles = ['FLAIR', 'T1w', 'T1gd', 'T2w', 'Label']
    f, axes = plt.subplots(max(1, rows), cols)
    # np.ravel also copes with the bare Axes returned for a 1x1 grid.
    axes_ravel = np.ravel(axes)
    for i, (image, label) in enumerate(images):
        ds = np.where(label != 0)[0]
        # An all-background label has no non-zero slice to centre on.
        slice_idx = ds[len(ds) // 2] if len(ds) else label.shape[0] // 2
        for j in range(cols - 1):
            ax = axes_ravel[i * cols + j]
            ax.imshow(image[slice_idx, :, :, j], cmap='Greys_r')
            if i == 0 and j < len(titles) - 1:
                ax.set_title(titles[j])
        ax = axes_ravel[(i + 1) * cols - 1]
        ax.imshow(label[slice_idx, :, :], cmap='Greys_r')
        if i == 0:
            ax.set_title(titles[-1])
    plt.subplots_adjust(wspace=0.01, hspace=0)
    # Without this the figure is built but never displayed in a script.
    plt.show()
def read_img(img_path, label_path=None):
    """Read an image (and optionally its label) into NumPy arrays.

    Axis 0 of the image array returned by SimpleITK is moved to axis 3.

    Args:
        img_path: Path of the image file.
        label_path: Path of the label file, or ``None`` to read the image only.

    Returns:
        ``img_np`` when ``label_path`` is ``None``, otherwise
        ``(img_np, label_np)``.
    """
    img_np = sitk.GetArrayFromImage(sitk.ReadImage(img_path))
    img_np = np.moveaxis(img_np, 0, 3)
    if label_path is None:
        return img_np
    label_np = sitk.GetArrayFromImage(sitk.ReadImage(label_path))
    return img_np, label_np
def extract_ordered_overlap_patches(img, label, patch_size, s=16):
    """Cut overlapping patches from a volume on a regular grid.

    A 20-voxel border is cropped from axes 1 and 2 first. Patches are then
    taken with stride ``s`` along axis 0 and ``2 * s`` along axes 1 and 2.

    Args:
        img: 4-D array indexed ``[d, h, w, channel]``.
        label: Matching 3-D label array, or ``None`` (inference).
        patch_size: ``(patch_d, patch_h, patch_w)``.
        s: Base stride.

    Returns:
        ``(patch_list, pos_list)``. With a label, ``patch_list`` holds
        ``(patch_img, patch_label)`` tuples for the patches that contain at
        least 20% of the non-zero label voxels. Without a label it holds
        every image patch. ``pos_list`` holds the grid index ``(i, j, k)``
        of each kept patch.
    """
    img = img[:, 20:-20, 20:-20, :]
    d, h, w, _ = img.shape
    patch_d, patch_h, patch_w = patch_size
    stride_d, stride_h, stride_w = s, s * 2, s * 2
    patch_list = []
    pos_list = []
    total_fg = 0
    if label is not None:
        label = label[:, 20:-20, 20:-20]
        # Hoisted: the foreground count of the whole label is loop-invariant.
        total_fg = np.count_nonzero(label)
    for i in range((d - patch_d) // stride_d + 1):
        d0 = i * stride_d
        for j in range((h - patch_h) // stride_h + 1):
            h0 = j * stride_h
            for k in range((w - patch_w) // stride_w + 1):
                w0 = k * stride_w
                region = (slice(d0, min(d, d0 + patch_d)),
                          slice(h0, min(h, h0 + patch_h)),
                          slice(w0, min(w, w0 + patch_w)))
                patch_img = img[region]
                if label is None:
                    patch_list.append(patch_img)
                    pos_list.append((i, j, k))
                    continue
                patch_label = label[region]
                if patch_label.shape != tuple(patch_size):
                    continue
                # total_fg == 0 (empty label) keeps nothing instead of
                # raising ZeroDivisionError.
                if total_fg and np.count_nonzero(patch_label) / total_fg >= 0.2:
                    patch_list.append((patch_img, patch_label))
                    pos_list.append((i, j, k))
    return patch_list, pos_list
# pre-process images
def preprocess(image):
    """Z-score normalise every (slice, channel) plane of a 4-D volume.

    Each plane ``image[z, :, :, c]`` is standardised with its own mean and
    standard deviation. The input is left untouched and a floating-point
    array is returned, so integer volumes are not truncated.

    Args:
        image: 4-D array indexed ``[z, h, w, channel]``.

    Returns:
        A new normalised array of the same shape.
    """
    image = np.asarray(image)
    if not np.issubdtype(image.dtype, np.floating):
        image = image.astype(np.float32)
    mean = image.mean(axis=(1, 2), keepdims=True)
    std = image.std(axis=(1, 2), keepdims=True)
    return (image - mean) / (std + 1e-6)
# Load one training case and its label, then plot them.
img_path = '../data/Task01_BrainTumour/imagesTr/BRATS_001.nii.gz'
label_path = '../data/Task01_BrainTumour/labelsTr/BRATS_001.nii.gz'
img_np, label_np = read_img(img_path, label_path)
# show_imgs expects a sequence of (image, label) pairs, hence the nested list.
show_imgs([[img_np, label_np]])
# Patch extent as (depth, height, width).
patch_size = (32, 160, 160)
# Only the patches are kept here; the grid positions are discarded.
patch_list, _ = extract_ordered_overlap_patches(img_np, label_np, patch_size)
import os
import SimpleITK as sitk
import matplotlib.pyplot as plt
import numpy as np
import json
from skimage import measure
import random
# import tensorflow as tf
import tensorflow.compat.v1 as tf
from tensorflow.keras.layers import Layer
from tensorflow import keras
from tensorflow.keras.layers import Conv3D, Conv3DTranspose, ReLU, UpSampling3D, Input, MaxPool3D, Concatenate
from tensorflow.keras import Model
def z_score_norm(img, mean=None, std=None):
    """Standardise ``img`` to zero mean and unit variance.

    Args:
        img: Array of intensities.
        mean: Value to subtract; computed from ``img`` when ``None``.
        std: Value to divide by; computed from ``img`` when ``None``.

    Returns:
        ``(img - mean) / (std + 1e-6)``. The small constant keeps the
        division finite for constant images.
    """
    if mean is None:
        mean = np.mean(img)
    if std is None:
        std = np.std(img)
    return (img - mean) / (std + 1e-6)
def show_imgs(images, cols=5):
    """Plot one row per ``(image, label)`` pair.

    The slice shown is the middle entry of the axis-0 indices of the
    non-zero label voxels; the central slice is used for an empty label.

    Args:
        images: Sequence of ``(image, label)`` pairs, with ``image`` indexed
            as ``image[slice, :, :, channel]`` and ``label`` as
            ``label[slice, :, :]``.
        cols: Panels per row: ``cols - 1`` image channels plus the label.
    """
    rows = len(images)
    titles = ['FLAIR', 'T1w', 'T1gd', 'T2w', 'Label']
    f, axes = plt.subplots(max(1, rows), cols)
    # np.ravel also copes with the bare Axes returned for a 1x1 grid.
    axes_ravel = np.ravel(axes)
    for i, (image, label) in enumerate(images):
        ds = np.where(label != 0)[0]
        # An all-background label has no non-zero slice to centre on.
        slice_idx = ds[len(ds) // 2] if len(ds) else label.shape[0] // 2
        for j in range(cols - 1):
            ax = axes_ravel[i * cols + j]
            ax.imshow(image[slice_idx, :, :, j], cmap='Greys_r')
            if i == 0 and j < len(titles) - 1:
                ax.set_title(titles[j])
        ax = axes_ravel[(i + 1) * cols - 1]
        ax.imshow(label[slice_idx, :, :], cmap='Greys_r')
        if i == 0:
            ax.set_title(titles[-1])
    plt.subplots_adjust(wspace=0.01, hspace=0)
    # Without this the figure is built but never displayed in a script.
    plt.show()
def read_img(img_path, label_path=None):
    """Read an image (and optionally its label) into NumPy arrays.

    Axis 0 of the image array returned by SimpleITK is moved to axis 3.

    Args:
        img_path: Path of the image file.
        label_path: Path of the label file, or ``None`` to read the image only.

    Returns:
        ``img_np`` when ``label_path`` is ``None``, otherwise
        ``(img_np, label_np)``.
    """
    img_np = sitk.GetArrayFromImage(sitk.ReadImage(img_path))
    img_np = np.moveaxis(img_np, 0, 3)
    if label_path is None:
        return img_np
    label_np = sitk.GetArrayFromImage(sitk.ReadImage(label_path))
    return img_np, label_np
def extract_ordered_overlap_patches(img, label, patch_size, s=16):
    """Cut overlapping patches from a volume on a regular grid.

    A 20-voxel border is cropped from axes 1 and 2 first. Patches are then
    taken with stride ``s`` along axis 0 and ``2 * s`` along axes 1 and 2.

    Args:
        img: 4-D array indexed ``[d, h, w, channel]``.
        label: Matching 3-D label array, or ``None`` (inference).
        patch_size: ``(patch_d, patch_h, patch_w)``.
        s: Base stride.

    Returns:
        ``(patch_list, pos_list)``. With a label, ``patch_list`` holds
        ``(patch_img, patch_label)`` tuples for the patches that contain at
        least 20% of the non-zero label voxels. Without a label it holds
        every image patch. ``pos_list`` holds the grid index ``(i, j, k)``
        of each kept patch.
    """
    img = img[:, 20:-20, 20:-20, :]
    d, h, w, _ = img.shape
    patch_d, patch_h, patch_w = patch_size
    stride_d, stride_h, stride_w = s, s * 2, s * 2
    patch_list = []
    pos_list = []
    total_fg = 0
    if label is not None:
        label = label[:, 20:-20, 20:-20]
        # Hoisted: the foreground count of the whole label is loop-invariant.
        total_fg = np.count_nonzero(label)
    for i in range((d - patch_d) // stride_d + 1):
        d0 = i * stride_d
        for j in range((h - patch_h) // stride_h + 1):
            h0 = j * stride_h
            for k in range((w - patch_w) // stride_w + 1):
                w0 = k * stride_w
                region = (slice(d0, min(d, d0 + patch_d)),
                          slice(h0, min(h, h0 + patch_h)),
                          slice(w0, min(w, w0 + patch_w)))
                patch_img = img[region]
                if label is None:
                    patch_list.append(patch_img)
                    pos_list.append((i, j, k))
                    continue
                patch_label = label[region]
                if patch_label.shape != tuple(patch_size):
                    continue
                # total_fg == 0 (empty label) keeps nothing instead of
                # raising ZeroDivisionError.
                if total_fg and np.count_nonzero(patch_label) / total_fg >= 0.2:
                    patch_list.append((patch_img, patch_label))
                    pos_list.append((i, j, k))
    return patch_list, pos_list
# pre-process images
def preprocess(image):
    """Z-score normalise every (slice, channel) plane of a 4-D volume.

    Each plane ``image[z, :, :, c]`` is standardised with its own mean and
    standard deviation. The input is left untouched and a floating-point
    array is returned, so integer volumes are not truncated.

    Args:
        image: 4-D array indexed ``[z, h, w, channel]``.

    Returns:
        A new normalised array of the same shape.
    """
    image = np.asarray(image)
    if not np.issubdtype(image.dtype, np.floating):
        image = image.astype(np.float32)
    mean = image.mean(axis=(1, 2), keepdims=True)
    std = image.std(axis=(1, 2), keepdims=True)
    return (image - mean) / (std + 1e-6)
# data loader
def data_generator(data_dir, path_list, target_shape, batch_size, is_training, buffer_size=8):
    """Yield ``(x, y)`` mini-batches of image/label patches.

    Cases are loaded ``buffer_size`` at a time; their patches are pooled,
    shuffled when training, and emitted in full batches.

    Args:
        data_dir: Dataset root; replaces the ``'./'`` prefix of each path.
        path_list: List of dicts with ``'image'`` and ``'label'`` paths.
        target_shape: Patch size ``(d, h, w)``.
        batch_size: Patches per yielded batch.
        is_training: Shuffle the pooled patches when true; when false the
            buffer is forced to a single case.
        buffer_size: Number of cases pooled before batching.

    Yields:
        ``(x, y)`` arrays holding ``batch_size`` image and label patches.
    """
    if not path_list:
        return
    if not is_training:
        buffer_size = 1
    buffer_size = max(1, min(len(path_list), buffer_size))
    # Stepping by buffer_size also covers the trailing partial group.
    for start in range(0, len(path_list), buffer_size):
        data_list = []
        for entry in path_list[start:start + buffer_size]:
            img_path = entry['image'].replace('./', data_dir)
            label_path = entry['label'].replace('./', data_dir)
            # Skip cases whose files are missing instead of failing on read.
            if not os.path.exists(img_path) or not os.path.exists(label_path):
                continue
            img, label = read_img(img_path, label_path)
            img = preprocess(img)
            patch_list, _ = extract_ordered_overlap_patches(img, label, target_shape)
            data_list += patch_list
        if not data_list:
            continue
        X = np.array([it[0] for it in data_list])
        Y = np.array([it[1] for it in data_list])
        if is_training:
            index = np.random.permutation(len(data_list))
            X = X[index, ...]
            Y = Y[index, ...]
        # Every full batch is emitted; the original "- 1" dropped the last one.
        for step in range(X.shape[0] // batch_size):
            x = X[step * batch_size:(step + 1) * batch_size, ...]
            y = Y[step * batch_size:(step + 1) * batch_size, ...]
            yield x, y
# Load one training case and its label, then plot them.
img_path = '../data/Task01_BrainTumour/imagesTr/BRATS_001.nii.gz'
label_path = '../data/Task01_BrainTumour/labelsTr/BRATS_001.nii.gz'
img_np, label_np = read_img(img_path, label_path)
show_imgs([[img_np, label_np]])
# Patch extent as (depth, height, width).
patch_size = (32, 160, 160)
patch_list, _ = extract_ordered_overlap_patches(img_np, label_np, patch_size)

# Read the case list that ships with the dataset.
data_dir = '../data/Task01_BrainTumour/'
with open(os.path.join(data_dir, 'dataset.json'), 'r') as f:
    data = json.load(f)
train_path_list = data['training']

# Smoke-test the generator on the first three cases, one patch per batch.
test_data = data_generator(data_dir, train_path_list[:3], patch_size, 1, False, 1)
for x, y in test_data:
    print(x.shape, y.shape)
import os
import SimpleITK as sitk
import matplotlib.pyplot as plt
import numpy as np
import json
from skimage import measure
import random
# import tensorflow as tf
import tensorflow.compat.v1 as tf
from tensorflow.keras.layers import Layer
from tensorflow import keras
from tensorflow.keras.layers import Conv3D, Conv3DTranspose, ReLU, UpSampling3D, Input, MaxPool3D, Concatenate
from tensorflow.keras import Model
# Limit CUDA to devices 0-3 for this process.
# NOTE(review): this runs after the tensorflow import above — confirm it still
# takes effect before the GPUs are initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
def z_score_norm(img, mean=None, std=None):
    """Standardise ``img`` to zero mean and unit variance.

    Args:
        img: Array of intensities.
        mean: Value to subtract; computed from ``img`` when ``None``.
        std: Value to divide by; computed from ``img`` when ``None``.

    Returns:
        ``(img - mean) / (std + 1e-6)``. The small constant keeps the
        division finite for constant images.
    """
    if mean is None:
        mean = np.mean(img)
    if std is None:
        std = np.std(img)
    return (img - mean) / (std + 1e-6)
def show_imgs(images, cols=5):
    """Plot one row per ``(image, label)`` pair.

    The slice shown is the middle entry of the axis-0 indices of the
    non-zero label voxels; the central slice is used for an empty label.

    Args:
        images: Sequence of ``(image, label)`` pairs, with ``image`` indexed
            as ``image[slice, :, :, channel]`` and ``label`` as
            ``label[slice, :, :]``.
        cols: Panels per row: ``cols - 1`` image channels plus the label.
    """
    rows = len(images)
    titles = ['FLAIR', 'T1w', 'T1gd', 'T2w', 'Label']
    f, axes = plt.subplots(max(1, rows), cols)
    # np.ravel also copes with the bare Axes returned for a 1x1 grid.
    axes_ravel = np.ravel(axes)
    for i, (image, label) in enumerate(images):
        ds = np.where(label != 0)[0]
        # An all-background label has no non-zero slice to centre on.
        slice_idx = ds[len(ds) // 2] if len(ds) else label.shape[0] // 2
        for j in range(cols - 1):
            ax = axes_ravel[i * cols + j]
            ax.imshow(image[slice_idx, :, :, j], cmap='Greys_r')
            if i == 0 and j < len(titles) - 1:
                ax.set_title(titles[j])
        ax = axes_ravel[(i + 1) * cols - 1]
        ax.imshow(label[slice_idx, :, :], cmap='Greys_r')
        if i == 0:
            ax.set_title(titles[-1])
    plt.subplots_adjust(wspace=0.01, hspace=0)
    # Without this the figure is built but never displayed in a script.
    plt.show()
def read_img(img_path, label_path=None):
    """Read an image (and optionally its label) into NumPy arrays.

    Axis 0 of the image array returned by SimpleITK is moved to axis 3.

    Args:
        img_path: Path of the image file.
        label_path: Path of the label file, or ``None`` to read the image only.

    Returns:
        ``img_np`` when ``label_path`` is ``None``, otherwise
        ``(img_np, label_np)``.
    """
    img_np = sitk.GetArrayFromImage(sitk.ReadImage(img_path))
    img_np = np.moveaxis(img_np, 0, 3)
    if label_path is None:
        return img_np
    label_np = sitk.GetArrayFromImage(sitk.ReadImage(label_path))
    return img_np, label_np
def extract_ordered_overlap_patches(img, label, patch_size, s=16):
    """Cut overlapping patches from a volume on a regular grid.

    A 20-voxel border is cropped from axes 1 and 2 first. Patches are then
    taken with stride ``s`` along axis 0 and ``2 * s`` along axes 1 and 2.

    Args:
        img: 4-D array indexed ``[d, h, w, channel]``.
        label: Matching 3-D label array, or ``None`` (inference).
        patch_size: ``(patch_d, patch_h, patch_w)``.
        s: Base stride.

    Returns:
        ``(patch_list, pos_list)``. With a label, ``patch_list`` holds
        ``(patch_img, patch_label)`` tuples for the patches that contain at
        least 20% of the non-zero label voxels. Without a label it holds
        every image patch. ``pos_list`` holds the grid index ``(i, j, k)``
        of each kept patch.
    """
    img = img[:, 20:-20, 20:-20, :]
    d, h, w, _ = img.shape
    patch_d, patch_h, patch_w = patch_size
    stride_d, stride_h, stride_w = s, s * 2, s * 2
    patch_list = []
    pos_list = []
    total_fg = 0
    if label is not None:
        label = label[:, 20:-20, 20:-20]
        # Hoisted: the foreground count of the whole label is loop-invariant.
        total_fg = np.count_nonzero(label)
    for i in range((d - patch_d) // stride_d + 1):
        d0 = i * stride_d
        for j in range((h - patch_h) // stride_h + 1):
            h0 = j * stride_h
            for k in range((w - patch_w) // stride_w + 1):
                w0 = k * stride_w
                region = (slice(d0, min(d, d0 + patch_d)),
                          slice(h0, min(h, h0 + patch_h)),
                          slice(w0, min(w, w0 + patch_w)))
                patch_img = img[region]
                if label is None:
                    patch_list.append(patch_img)
                    pos_list.append((i, j, k))
                    continue
                patch_label = label[region]
                if patch_label.shape != tuple(patch_size):
                    continue
                # total_fg == 0 (empty label) keeps nothing instead of
                # raising ZeroDivisionError.
                if total_fg and np.count_nonzero(patch_label) / total_fg >= 0.2:
                    patch_list.append((patch_img, patch_label))
                    pos_list.append((i, j, k))
    return patch_list, pos_list
# pre-process images
def preprocess(image):
    """Z-score normalise every (slice, channel) plane of a 4-D volume.

    Each plane ``image[z, :, :, c]`` is standardised with its own mean and
    standard deviation. The input is left untouched and a floating-point
    array is returned, so integer volumes are not truncated.

    Args:
        image: 4-D array indexed ``[z, h, w, channel]``.

    Returns:
        A new normalised array of the same shape.
    """
    image = np.asarray(image)
    if not np.issubdtype(image.dtype, np.floating):
        image = image.astype(np.float32)
    mean = image.mean(axis=(1, 2), keepdims=True)
    std = image.std(axis=(1, 2), keepdims=True)
    return (image - mean) / (std + 1e-6)
# data loader
def data_generator(data_dir, path_list, target_shape, batch_size, is_training, buffer_size=8):
    """Yield ``(x, y)`` mini-batches of image/label patches.

    Cases are loaded ``buffer_size`` at a time; their patches are pooled,
    shuffled when training, and emitted in full batches.

    Args:
        data_dir: Dataset root; replaces the ``'./'`` prefix of each path.
        path_list: List of dicts with ``'image'`` and ``'label'`` paths.
        target_shape: Patch size ``(d, h, w)``.
        batch_size: Patches per yielded batch.
        is_training: Shuffle the pooled patches when true; when false the
            buffer is forced to a single case.
        buffer_size: Number of cases pooled before batching.

    Yields:
        ``(x, y)`` arrays holding ``batch_size`` image and label patches.
    """
    if not path_list:
        return
    if not is_training:
        buffer_size = 1
    buffer_size = max(1, min(len(path_list), buffer_size))
    # Stepping by buffer_size also covers the trailing partial group.
    for start in range(0, len(path_list), buffer_size):
        data_list = []
        for entry in path_list[start:start + buffer_size]:
            img_path = entry['image'].replace('./', data_dir)
            label_path = entry['label'].replace('./', data_dir)
            # Skip cases whose files are missing instead of failing on read.
            if not os.path.exists(img_path) or not os.path.exists(label_path):
                continue
            img, label = read_img(img_path, label_path)
            img = preprocess(img)
            patch_list, _ = extract_ordered_overlap_patches(img, label, target_shape)
            data_list += patch_list
        if not data_list:
            continue
        X = np.array([it[0] for it in data_list])
        Y = np.array([it[1] for it in data_list])
        if is_training:
            index = np.random.permutation(len(data_list))
            X = X[index, ...]
            Y = Y[index, ...]
        # Every full batch is emitted; the original "- 1" dropped the last one.
        for step in range(X.shape[0] // batch_size):
            x = X[step * batch_size:(step + 1) * batch_size, ...]
            y = Y[step * batch_size:(step + 1) * batch_size, ...]
            yield x, y
class GroupNorm(Layer):
    """Group normalisation for 5-D ``(N, D, H, W, C)`` inputs.

    Channels are split into groups. Each group is normalised with the mean
    and variance taken over the three spatial axes and the channels of the
    group, then scaled by ``gamma`` and shifted by ``beta`` per channel.

    Args:
        groups: Requested number of groups; a single group is used when the
            channel count is not divisible by it.
        **kwargs: Forwarded to ``Layer`` (e.g. ``name``).
    """

    def __init__(self, groups=4, **kwargs):
        super(GroupNorm, self).__init__(**kwargs)
        self.G = groups
        self.eps = 1e-5

    def build(self, input_shape):
        """Resolve the group layout and create ``gamma`` / ``beta``."""
        self.channel = input_shape[-1]
        self.group = self.G if self.channel % self.G == 0 else 1
        self.group = min(self.channel, self.group)
        self.split = self.channel // self.group
        self.gamma = self.add_weight(name='gamma_gn', shape=(1, 1, 1, 1, self.channel),
                                     initializer='ones', trainable=True)
        self.beta = self.add_weight(name='beta_gn', shape=(1, 1, 1, 1, self.channel),
                                    initializer='zeros', trainable=True)

    def call(self, inputs):
        """Return the group-normalised, scaled and shifted tensor."""
        N, D, H, W, C = tf.keras.backend.int_shape(inputs)
        grouped = tf.reshape(inputs, [-1, D, H, W, self.group, self.split])
        mean, var = tf.nn.moments(grouped, [1, 2, 3, 5])
        mean = tf.reshape(mean, [-1, 1, 1, 1, self.group, 1])
        var = tf.reshape(var, [-1, 1, 1, 1, self.group, 1])
        normed = (grouped - mean) / tf.sqrt(var + self.eps)
        # The original reshaped the raw input here, so no normalisation was
        # ever applied; the normalised tensor must be used.
        return tf.reshape(normed, [-1, D, H, W, C]) * self.gamma + self.beta
def conv_res_block(x, filters, activation=ReLU(), kernel_size=3, strides=1, padding='same', num_layers=1):
    """Convolution block with a residual shortcut when ``num_layers > 1``.

    With one layer the block is Conv3D -> GroupNorm -> activation. Otherwise
    a 1x1x1 Conv3D -> GroupNorm -> activation shortcut is added to the
    output of ``num_layers`` stacked Conv3D -> activation layers.

    Args:
        x: Input Keras tensor.
        filters: Output channels of every convolution.
        activation: Activation layer applied after each convolution.
        kernel_size: Kernel size of the main-path convolutions.
        strides: Stride of every convolution.
        padding: Padding mode of every convolution.
        num_layers: Number of main-path convolutions.

    Returns:
        The block's output tensor.
    """
    if num_layers == 1:
        x = Conv3D(filters, kernel_size, strides=strides, padding=padding)(x)
        x = GroupNorm()(x)
        x = activation(x)
        return x

    shortcut = Conv3D(filters, 1, strides=strides, padding=padding)(x)
    shortcut = GroupNorm()(shortcut)
    shortcut = activation(shortcut)
    for _ in range(num_layers):
        x = Conv3D(filters, kernel_size, strides=strides, padding=padding)(x)
        x = activation(x)
    # NOTE(review): the source lost its indentation; the residual add is
    # placed after the loop here — confirm against the original code.
    x = x + shortcut
    return x
def upsample_block(x, filters, activation=ReLU(), kernel_size=3, strides=1,
                   padding='same', deconv=False):
    """Double the spatial size, then apply Conv3D -> GroupNorm -> activation.

    Args:
        x: Input Keras tensor.
        filters: Output channels.
        activation: Activation layer applied last.
        kernel_size: Kernel size of the trailing convolution.
        strides: Stride of the trailing convolution.
        padding: Padding mode of the convolutions.
        deconv: Upsample with a stride-2 transposed convolution when true,
            with ``UpSampling3D`` otherwise.

    Returns:
        The upsampled tensor.
    """
    if deconv:
        x = Conv3DTranspose(filters, 2, strides=2, padding=padding)(x)
    else:
        # Without this else branch deconv=True upsampled twice (x4).
        x = UpSampling3D(size=2)(x)
    x = Conv3D(filters, kernel_size, strides=strides, padding=padding)(x)
    x = GroupNorm()(x)
    x = activation(x)
    return x
# unet3d model
def Unet3d(img_shape, n_filters, n_class):
    """Build the 3-D U-Net: four pooling stages, four upsampling stages.

    Args:
        img_shape: Input shape without the batch axis.
        n_filters: Channels of the first stage; doubled after each pooling.
        n_class: Number of output classes (softmax over the last axis).

    Returns:
        A Keras ``Model`` mapping the input volume to class probabilities.
    """
    inputs = Input(shape=img_shape, name='input')

    # Encoder
    l1 = conv_res_block(inputs, n_filters, num_layers=1)
    m1 = MaxPool3D()(l1)
    l2 = conv_res_block(m1, n_filters * 2, num_layers=2)
    m2 = MaxPool3D()(l2)
    l3 = conv_res_block(m2, n_filters * 4, num_layers=3)
    m3 = MaxPool3D()(l3)
    l4 = conv_res_block(m3, n_filters * 8, num_layers=3)
    m4 = MaxPool3D()(l4)
    l5 = conv_res_block(m4, n_filters * 16, num_layers=3)

    # Decoder: each stage is concatenated with the matching encoder output.
    up6 = upsample_block(l5, n_filters * 8)
    l6 = conv_res_block(Concatenate()([up6, l4]), n_filters * 8, num_layers=3)
    up7 = upsample_block(l6, n_filters * 4)
    l7 = conv_res_block(Concatenate()([up7, l3]), n_filters * 4, num_layers=3)
    up8 = upsample_block(l7, n_filters * 2)
    l8 = conv_res_block(Concatenate()([up8, l2]), n_filters * 2, num_layers=2)
    up9 = upsample_block(l8, n_filters)
    l9 = conv_res_block(Concatenate()([up9, l1]), n_filters, num_layers=1)

    out = Conv3D(n_class, 1, padding='same', activation=keras.activations.softmax)(l9)
    model = Model(inputs=inputs, outputs=out, name='output')
    return model
# loss function
def explog_loss(y_true, y_pred, n_class, weights=1., w_d=0.8, w_c=0.2, g_d=0.3,
                g_c=0.3, eps=1e-5):
    """Compute exp-log loss.

    Args:
        y_true: ground truth with dimension of (batch, depth, height, width)
        y_pred: prediction with dimension of (batch, depth, height, width, n_class)
        n_class: classes
        weights: weights of n classes, a float number or vector with dimension of (n,)
        w_d: weight of dice loss
        w_c: weight of cross entropy loss
        g_d: exponent of dice loss
        g_c: exponent of cross entropy loss
        eps: clipping / smoothing constant

    Returns:
        score: exp-log loss
    """
    y_pred = tf.cast(y_pred, tf.float32)
    y_true = tf.cast(tf.one_hot(y_true, n_class), tf.float32)
    y_true = tf.reshape(y_true, [-1, n_class])
    y_pred = tf.reshape(y_pred, [-1, n_class])
    # Keep probabilities away from 0 and 1 so the logs stay finite.
    y_pred = tf.clip_by_value(y_pred, eps, 1.0 - eps)
    # Explicit float tensor so integer weight lists multiply cleanly.
    weights = tf.convert_to_tensor(weights, dtype=tf.float32)

    # Per-class soft Dice term.
    intersection = tf.reduce_sum(y_true * y_pred, axis=0)
    union = tf.reduce_sum(y_true, axis=0) + tf.reduce_sum(y_pred, axis=0)
    dice = (2 * intersection + eps) / (union + eps)
    dice = tf.clip_by_value(dice, eps, 1.0 - eps)
    dice_log_loss = -tf.math.log(dice)
    Ld = tf.reduce_mean(tf.pow(dice_log_loss, g_d))

    # Class-weighted exponentiated cross-entropy term.
    wce = weights * y_true * tf.pow(-tf.math.log(y_pred), g_c)
    Lc = tf.reduce_mean(wce)

    score = w_d * Ld + w_c * Lc
    return score
# metrics
def dice_score(y_true, y_pred, n_class, exp=1e-5):
    """Compute dice score with ground truth and prediction without argmax.

    Args:
        y_true: ground truth with dimension of (batch, depth, height, width)
        y_pred: prediction with dimension of (batch, depth, height, width, n_class)
        n_class: classes number
        exp: smoothing constant added to numerator and denominator

    Returns:
        score: average dice score over classes ``1 .. n_class - 1`` (class 0
        is skipped) and over the batch
    """
    dices = []
    y_pred = np.argmax(y_pred, axis=-1)
    for i in range(1, n_class):
        pred = y_pred == i
        label = y_true == i
        intersection = 2 * np.sum(label * pred, axis=(1, 2, 3)) + exp
        union = np.sum(label, axis=(1, 2, 3)) + np.sum(pred, axis=(1, 2, 3)) + exp
        # This append was missing, so the mean of an empty list (NaN) was returned.
        dices.append(intersection / union)
    score = np.mean(dices)
    return score
# training process
def train(display_step=10):
    """Train the 3-D U-Net on Task01_BrainTumour, checkpointing every epoch.

    The ``training`` list of ``dataset.json`` is split 80/20 into training
    and validation cases. Each epoch runs one pass over both generators and
    saves a checkpoint named after the epoch and the validation dice.

    Args:
        display_step: Print running statistics every this many steps.
    """
    batch_size = 1
    epochs = 500
    input_size = [32, 160, 160, 4]
    n_class = 4
    first_channels = 8
    lr = 0.001
    save_model_dir = '../saved_models/'
    data_dir = '../data/Task01_BrainTumour/'
    os.makedirs(save_model_dir, exist_ok=True)

    with open(os.path.join(data_dir, 'dataset.json'), 'r') as f:
        data_info = json.load(f)
    path_list = data_info['training']
    n_sample = len(path_list)
    train_path_list = path_list[:int(n_sample * 0.8)]
    val_path_list = path_list[int(n_sample * 0.8):]

    # NOTE(review): tf.placeholder / tf.Session need graph mode; confirm that
    # eager execution is disabled elsewhere when running under TF 2.x.
    model = Unet3d(input_size, first_channels, n_class)
    inputs = model.input
    pred = model.output
    label = tf.placeholder(tf.int32, shape=[None] + input_size[:3])
    loss_tf = explog_loss(label, pred, n_class, weights=[1, 10, 20, 20])
    global_step = tf.Variable(0, name='global_step', trainable=False)
    # NOTE(review): the decay arguments were missing from the source; the
    # decay_steps / decay_rate values below are placeholders to be confirmed.
    lr_schedule = tf.train.exponential_decay(
        lr, global_step, decay_steps=10000, decay_rate=0.9, staircase=True)
    optimizer = tf.train.AdamOptimizer(learning_rate=lr_schedule)
    train_opt = optimizer.minimize(loss_tf, global_step=global_step)
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()

    print("start optimize")
    config = tf.ConfigProto()
    with tf.Session(config=config) as sess:
        # Variables must be initialised before the first training step.
        sess.run(init_op)
        for epoch in range(epochs):
            # training
            print('*' * 20, 'Train Epoch %d' % epoch, '*' * 20)
            steps = 0
            train_loss_avg = 0
            train_dice_avg = 0
            train_dataset = data_generator(data_dir, train_path_list, input_size[:3], batch_size, True)
            for x, y in train_dataset:
                _, loss, pred_logits = sess.run([train_opt, loss_tf, pred],
                                                feed_dict={inputs: x, label: y})
                dice = dice_score(y, pred_logits, n_class)
                train_dice_avg += dice
                train_loss_avg += loss
                steps += 1
                if steps % display_step == 0:
                    print('epoch %d, steps %d, train loss: %.4f, train dice: %.4f' % (
                        epoch, steps, train_loss_avg / steps, train_dice_avg / steps))
            # max(..., 1) avoids dividing by zero when no batch was produced.
            train_loss_avg /= max(steps, 1)
            train_dice_avg /= max(steps, 1)

            # validation
            print('*' * 20, 'Valid Epoch %d' % epoch, '*' * 20)
            steps = 0
            val_loss_avg = 0
            val_dice_avg = 0
            val_dataset = data_generator(data_dir, val_path_list, input_size[:3], batch_size, False)
            for x, y in val_dataset:
                val_loss, pred_logits = sess.run([loss_tf, pred], feed_dict={inputs: x, label: y})
                dice = dice_score(y, pred_logits, n_class)
                val_dice_avg += dice
                steps += 1
                val_loss_avg += val_loss
                if steps % display_step == 0:
                    print('Epoch {:}, valid steps {:}, loss={:.4f}'.format(epoch, steps, val_loss))
            val_loss_avg /= max(steps, 1)
            val_dice_avg /= max(steps, 1)
            print('epoch %d, steps %d, validation loss: %.4f, val dice: %4f' % (
                epoch, steps, val_loss_avg, val_dice_avg))
            print('*' * 20, 'Valid Epoch %d' % epoch, '*' * 20)

            # save model
            saver.save(sess, os.path.join(save_model_dir, "epoch_%d_%.4f_model" % (epoch, val_dice_avg)))
def save_img(img_np, save_path):
    """Write a NumPy array to ``save_path`` as a SimpleITK image.

    Args:
        img_np: Array to save.
        save_path: Destination file path.
    """
    img_itk = sitk.GetImageFromArray(img_np)
    sitk.WriteImage(img_itk, save_path)
def recover_img(patch_preds, pos_list, strides, ori_shape):
    """Paste patch predictions back into a full-size volume.

    Each patch is written at its grid position times the stride; axes 1 and
    2 are shifted by 20 to undo the border crop applied before extraction.
    Where patches overlap, later patches overwrite earlier ones.

    Args:
        patch_preds: Sequence of 3-D patch arrays of identical shape.
        pos_list: Grid index ``(i, j, k)`` of each patch.
        strides: ``(sd, sh, sw)`` strides used during extraction.
        ori_shape: Shape of the output volume.

    Returns:
        Array of shape ``ori_shape``; zero wherever no patch was written.
    """
    sd, sh, sw = strides
    img = np.zeros(ori_shape)
    # Nothing to paste: return the empty volume instead of failing on [0].
    if len(patch_preds) == 0:
        return img
    pd, ph, pw = patch_preds[0].shape
    for patch, (i, j, k) in zip(patch_preds, pos_list):
        d0, h0, w0 = i * sd, j * sh + 20, k * sw + 20
        img[d0:d0 + pd, h0:h0 + ph, w0:w0 + pw] = patch
    return img
def predict(model_path, patch_list, input_size, first_channels, n_class):
    """Run the restored network on every patch and return the label maps.

    Args:
        model_path: Checkpoint prefix passed to ``Saver.restore``.
        patch_list: Sequence of image patches, each of shape ``input_size``.
        input_size: Patch shape including the channel axis.
        first_channels: Channel count of the first U-Net stage.
        n_class: Number of classes.

    Returns:
        List with one arg-max label array per input patch.
    """
    input_shape = (1,) + tuple(input_size)
    inputs = tf.placeholder(tf.float32, shape=input_shape)
    model = Unet3d(input_size, first_channels, n_class)
    prediction = model(inputs)
    saver = tf.train.Saver()
    preds = []
    with tf.Session() as sess:
        # All variables come from the checkpoint, so no initialiser is run.
        saver.restore(sess, model_path)
        for patch in patch_list:
            pred = sess.run(prediction, feed_dict={inputs: np.expand_dims(patch, 0)})
            # The original never stored the result and always returned [].
            preds.append(np.squeeze(np.argmax(pred, -1)))
    return preds
def test(model_path, img_path, save_dir):
    """Segment one image with a saved model and write the result to disk.

    Args:
        model_path: Checkpoint prefix of the trained model.
        img_path: Path of the image to segment.
        save_dir: Output directory; created when missing. The prediction is
            saved under the input file's name.

    Returns:
        ``(img0, pred)``: the array returned by ``read_img`` and the
        reassembled prediction volume.
    """
    patch_size = [32, 160, 160, 4]
    n_class = 4
    first_channels = 8
    strides = (16, 32, 32)
    # The directory check had lost its body; create the directory if needed.
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, img_path.split('/')[-1])
    img0 = read_img(img_path)
    img1 = preprocess(img0)
    patch_list, pos_list = extract_ordered_overlap_patches(img1, None, patch_size[:3])
    preds = predict(model_path, patch_list, patch_size, first_channels, n_class)
    pred = recover_img(preds, pos_list, strides, img0.shape[:3])
    save_img(pred, save_path)
    return img0, pred
if __name__ == '__main__':
    # Switch between training (True) and single-image inference (False).
    run_training = True
    if run_training:
        train()
    else:
        model_path = '../saved_models/epoch_20_0.6411_model'
        img_path = '../data/Task01_BrainTumour/imagesTs/BRATS_485.nii.gz'
        save_dir = '../data/Task01_BrainTumour/imagesPre/'
        img, pred = test(model_path, img_path, save_dir)
        show_imgs([(img, pred)])
import os
import SimpleITK as sitk
import matplotlib.pyplot as plt
import numpy as np
import json
from skimage import measure
import random
# import tensorflow as tf
import tensorflow.compat.v1 as tf
from tensorflow.keras.layers import Layer
from tensorflow import keras
from tensorflow.keras.layers import Conv3D, Conv3DTranspose, ReLU, UpSampling3D, Input, MaxPool3D, Concatenate
from tensorflow.keras import Model
def z_score_norm(img, mean=None, std=None):
    """Standardise ``img`` to zero mean and unit variance.

    Args:
        img: Array of intensities.
        mean: Value to subtract; computed from ``img`` when ``None``.
        std: Value to divide by; computed from ``img`` when ``None``.

    Returns:
        ``(img - mean) / (std + 1e-6)``. The small constant keeps the
        division finite for constant images.
    """
    if mean is None:
        mean = np.mean(img)
    if std is None:
        std = np.std(img)
    return (img - mean) / (std + 1e-6)
def show_imgs(images, cols=5):
    """Plot one row per ``(image, label)`` pair.

    The slice shown is the middle entry of the axis-0 indices of the
    non-zero label voxels; the central slice is used for an empty label.

    Args:
        images: Sequence of ``(image, label)`` pairs, with ``image`` indexed
            as ``image[slice, :, :, channel]`` and ``label`` as
            ``label[slice, :, :]``.
        cols: Panels per row: ``cols - 1`` image channels plus the label.
    """
    rows = len(images)
    titles = ['FLAIR', 'T1w', 'T1gd', 'T2w', 'Label']
    f, axes = plt.subplots(max(1, rows), cols)
    # np.ravel also copes with the bare Axes returned for a 1x1 grid.
    axes_ravel = np.ravel(axes)
    for i, (image, label) in enumerate(images):
        ds = np.where(label != 0)[0]
        # An all-background label has no non-zero slice to centre on.
        slice_idx = ds[len(ds) // 2] if len(ds) else label.shape[0] // 2
        for j in range(cols - 1):
            ax = axes_ravel[i * cols + j]
            ax.imshow(image[slice_idx, :, :, j], cmap='Greys_r')
            if i == 0 and j < len(titles) - 1:
                ax.set_title(titles[j])
        ax = axes_ravel[(i + 1) * cols - 1]
        ax.imshow(label[slice_idx, :, :], cmap='Greys_r')
        if i == 0:
            ax.set_title(titles[-1])
    plt.subplots_adjust(wspace=0.01, hspace=0)
    # Without this the figure is built but never displayed in a script.
    plt.show()
def read_img(img_path, label_path=None):
    """Read an image (and optionally its label) into NumPy arrays.

    Axis 0 of the image array returned by SimpleITK is moved to axis 3.

    Args:
        img_path: Path of the image file.
        label_path: Path of the label file, or ``None`` to read the image only.

    Returns:
        ``img_np`` when ``label_path`` is ``None``, otherwise
        ``(img_np, label_np)``.
    """
    img_np = sitk.GetArrayFromImage(sitk.ReadImage(img_path))
    img_np = np.moveaxis(img_np, 0, 3)
    if label_path is None:
        return img_np
    label_np = sitk.GetArrayFromImage(sitk.ReadImage(label_path))
    return img_np, label_np
def extract_ordered_overlap_patches(img, label, patch_size, s=16):
    """Cut overlapping patches from a volume on a regular grid.

    A 20-voxel border is cropped from axes 1 and 2 first. Patches are then
    taken with stride ``s`` along axis 0 and ``2 * s`` along axes 1 and 2.

    Args:
        img: 4-D array indexed ``[d, h, w, channel]``.
        label: Matching 3-D label array, or ``None`` (inference).
        patch_size: ``(patch_d, patch_h, patch_w)``.
        s: Base stride.

    Returns:
        ``(patch_list, pos_list)``. With a label, ``patch_list`` holds
        ``(patch_img, patch_label)`` tuples for the patches that contain at
        least 20% of the non-zero label voxels. Without a label it holds
        every image patch. ``pos_list`` holds the grid index ``(i, j, k)``
        of each kept patch.
    """
    img = img[:, 20:-20, 20:-20, :]
    d, h, w, _ = img.shape
    patch_d, patch_h, patch_w = patch_size
    stride_d, stride_h, stride_w = s, s * 2, s * 2
    patch_list = []
    pos_list = []
    total_fg = 0
    if label is not None:
        label = label[:, 20:-20, 20:-20]
        # Hoisted: the foreground count of the whole label is loop-invariant.
        total_fg = np.count_nonzero(label)
    for i in range((d - patch_d) // stride_d + 1):
        d0 = i * stride_d
        for j in range((h - patch_h) // stride_h + 1):
            h0 = j * stride_h
            for k in range((w - patch_w) // stride_w + 1):
                w0 = k * stride_w
                region = (slice(d0, min(d, d0 + patch_d)),
                          slice(h0, min(h, h0 + patch_h)),
                          slice(w0, min(w, w0 + patch_w)))
                patch_img = img[region]
                if label is None:
                    patch_list.append(patch_img)
                    pos_list.append((i, j, k))
                    continue
                patch_label = label[region]
                if patch_label.shape != tuple(patch_size):
                    continue
                # total_fg == 0 (empty label) keeps nothing instead of
                # raising ZeroDivisionError.
                if total_fg and np.count_nonzero(patch_label) / total_fg >= 0.2:
                    patch_list.append((patch_img, patch_label))
                    pos_list.append((i, j, k))
    return patch_list, pos_list
# pre-process images
def preprocess(image):
    """Z-score normalise every (slice, channel) plane of a 4-D volume.

    Each plane ``image[z, :, :, c]`` is standardised with its own mean and
    standard deviation. The input is left untouched and a floating-point
    array is returned, so integer volumes are not truncated.

    Args:
        image: 4-D array indexed ``[z, h, w, channel]``.

    Returns:
        A new normalised array of the same shape.
    """
    image = np.asarray(image)
    if not np.issubdtype(image.dtype, np.floating):
        image = image.astype(np.float32)
    mean = image.mean(axis=(1, 2), keepdims=True)
    std = image.std(axis=(1, 2), keepdims=True)
    return (image - mean) / (std + 1e-6)
# data loader
def data_generator(data_dir, path_list, target_shape, batch_size, is_training, buffer_size=8):
    """Yield ``(x, y)`` mini-batches of image/label patches.

    Cases are loaded ``buffer_size`` at a time; their patches are pooled,
    shuffled when training, and emitted in full batches.

    Args:
        data_dir: Dataset root; replaces the ``'./'`` prefix of each path.
        path_list: List of dicts with ``'image'`` and ``'label'`` paths.
        target_shape: Patch size ``(d, h, w)``.
        batch_size: Patches per yielded batch.
        is_training: Shuffle the pooled patches when true; when false the
            buffer is forced to a single case.
        buffer_size: Number of cases pooled before batching.

    Yields:
        ``(x, y)`` arrays holding ``batch_size`` image and label patches.
    """
    if not path_list:
        return
    if not is_training:
        buffer_size = 1
    buffer_size = max(1, min(len(path_list), buffer_size))
    # Stepping by buffer_size also covers the trailing partial group.
    for start in range(0, len(path_list), buffer_size):
        data_list = []
        for entry in path_list[start:start + buffer_size]:
            img_path = entry['image'].replace('./', data_dir)
            label_path = entry['label'].replace('./', data_dir)
            # Skip cases whose files are missing instead of failing on read.
            if not os.path.exists(img_path) or not os.path.exists(label_path):
                continue
            img, label = read_img(img_path, label_path)
            img = preprocess(img)
            patch_list, _ = extract_ordered_overlap_patches(img, label, target_shape)
            data_list += patch_list
        if not data_list:
            continue
        X = np.array([it[0] for it in data_list])
        Y = np.array([it[1] for it in data_list])
        if is_training:
            index = np.random.permutation(len(data_list))
            X = X[index, ...]
            Y = Y[index, ...]
        # Every full batch is emitted; the original "- 1" dropped the last one.
        for step in range(X.shape[0] // batch_size):
            x = X[step * batch_size:(step + 1) * batch_size, ...]
            y = Y[step * batch_size:(step + 1) * batch_size, ...]
            yield x, y
class GroupNorm(Layer):
    """Group normalisation for 5-D ``(N, D, H, W, C)`` inputs.

    Channels are split into groups. Each group is normalised with the mean
    and variance taken over the three spatial axes and the channels of the
    group, then scaled by ``gamma`` and shifted by ``beta`` per channel.

    Args:
        groups: Requested number of groups; a single group is used when the
            channel count is not divisible by it.
        **kwargs: Forwarded to ``Layer`` (e.g. ``name``).
    """

    def __init__(self, groups=4, **kwargs):
        super(GroupNorm, self).__init__(**kwargs)
        self.G = groups
        self.eps = 1e-5

    def build(self, input_shape):
        """Resolve the group layout and create ``gamma`` / ``beta``."""
        self.channel = input_shape[-1]
        self.group = self.G if self.channel % self.G == 0 else 1
        self.group = min(self.channel, self.group)
        self.split = self.channel // self.group
        self.gamma = self.add_weight(name='gamma_gn', shape=(1, 1, 1, 1, self.channel),
                                     initializer='ones', trainable=True)
        self.beta = self.add_weight(name='beta_gn', shape=(1, 1, 1, 1, self.channel),
                                    initializer='zeros', trainable=True)

    def call(self, inputs):
        """Return the group-normalised, scaled and shifted tensor."""
        N, D, H, W, C = tf.keras.backend.int_shape(inputs)
        grouped = tf.reshape(inputs, [-1, D, H, W, self.group, self.split])
        mean, var = tf.nn.moments(grouped, [1, 2, 3, 5])
        mean = tf.reshape(mean, [-1, 1, 1, 1, self.group, 1])
        var = tf.reshape(var, [-1, 1, 1, 1, self.group, 1])
        normed = (grouped - mean) / tf.sqrt(var + self.eps)
        # The original reshaped the raw input here, so no normalisation was
        # ever applied; the normalised tensor must be used.
        return tf.reshape(normed, [-1, D, H, W, C]) * self.gamma + self.beta
def conv_res_block(x, filters, activation=ReLU(), kernel_size=3, strides=1, padding='same', num_layers=1):
    """Convolution block with a residual shortcut when ``num_layers > 1``.

    With one layer the block is Conv3D -> GroupNorm -> activation. Otherwise
    a 1x1x1 Conv3D -> GroupNorm -> activation shortcut is added to the
    output of ``num_layers`` stacked Conv3D -> activation layers.

    Args:
        x: Input Keras tensor.
        filters: Output channels of every convolution.
        activation: Activation layer applied after each convolution.
        kernel_size: Kernel size of the main-path convolutions.
        strides: Stride of every convolution.
        padding: Padding mode of every convolution.
        num_layers: Number of main-path convolutions.

    Returns:
        The block's output tensor.
    """
    if num_layers == 1:
        x = Conv3D(filters, kernel_size, strides=strides, padding=padding)(x)
        x = GroupNorm()(x)
        x = activation(x)
        return x

    shortcut = Conv3D(filters, 1, strides=strides, padding=padding)(x)
    shortcut = GroupNorm()(shortcut)
    shortcut = activation(shortcut)
    for _ in range(num_layers):
        x = Conv3D(filters, kernel_size, strides=strides, padding=padding)(x)
        x = activation(x)
    # NOTE(review): the source lost its indentation; the residual add is
    # placed after the loop here — confirm against the original code.
    x = x + shortcut
    return x
def upsample_block(x, filters, activation=ReLU(), kernel_size=3, strides=1,
                   padding='same', deconv=False):
    """Double the spatial size, then apply Conv3D -> GroupNorm -> activation.

    Args:
        x: Input Keras tensor.
        filters: Output channels.
        activation: Activation layer applied last.
        kernel_size: Kernel size of the trailing convolution.
        strides: Stride of the trailing convolution.
        padding: Padding mode of the convolutions.
        deconv: Upsample with a stride-2 transposed convolution when true,
            with ``UpSampling3D`` otherwise.

    Returns:
        The upsampled tensor.
    """
    if deconv:
        x = Conv3DTranspose(filters, 2, strides=2, padding=padding)(x)
    else:
        # Without this else branch deconv=True upsampled twice (x4).
        x = UpSampling3D(size=2)(x)
    x = Conv3D(filters, kernel_size, strides=strides, padding=padding)(x)
    x = GroupNorm()(x)
    x = activation(x)
    return x
# unet3d model
def Unet3d(img_shape, n_filters, n_class):
    """Build the 3-D U-Net: four pooling stages, four upsampling stages.

    Args:
        img_shape: Input shape without the batch axis.
        n_filters: Channels of the first stage; doubled after each pooling.
        n_class: Number of output classes (softmax over the last axis).

    Returns:
        A Keras ``Model`` mapping the input volume to class probabilities.
    """
    inputs = Input(shape=img_shape, name='input')

    # Encoder
    l1 = conv_res_block(inputs, n_filters, num_layers=1)
    m1 = MaxPool3D()(l1)
    l2 = conv_res_block(m1, n_filters * 2, num_layers=2)
    m2 = MaxPool3D()(l2)
    l3 = conv_res_block(m2, n_filters * 4, num_layers=3)
    m3 = MaxPool3D()(l3)
    l4 = conv_res_block(m3, n_filters * 8, num_layers=3)
    m4 = MaxPool3D()(l4)
    l5 = conv_res_block(m4, n_filters * 16, num_layers=3)

    # Decoder: each stage is concatenated with the matching encoder output.
    up6 = upsample_block(l5, n_filters * 8)
    l6 = conv_res_block(Concatenate()([up6, l4]), n_filters * 8, num_layers=3)
    up7 = upsample_block(l6, n_filters * 4)
    l7 = conv_res_block(Concatenate()([up7, l3]), n_filters * 4, num_layers=3)
    up8 = upsample_block(l7, n_filters * 2)
    l8 = conv_res_block(Concatenate()([up8, l2]), n_filters * 2, num_layers=2)
    up9 = upsample_block(l8, n_filters)
    l9 = conv_res_block(Concatenate()([up9, l1]), n_filters, num_layers=1)

    out = Conv3D(n_class, 1, padding='same', activation=keras.activations.softmax)(l9)
    model = Model(inputs=inputs, outputs=out, name='output')
    return model
# loss function
def explog_loss(y_true, y_pred, n_class, weights=1., w_d=0.8, w_c=0.2, g_d=0.3,
                g_c=0.3, eps=1e-5):
    """Compute exp-log loss.

    Args:
        y_true: ground truth with dimension of (batch, depth, height, width)
        y_pred: prediction with dimension of (batch, depth, height, width, n_class)
        n_class: classes
        weights: weights of n classes, a float number or vector with dimension of (n,)
        w_d: weight of dice loss
        w_c: weight of cross entropy loss
        g_d: exponent of dice loss
        g_c: exponent of cross entropy loss
        eps: clipping / smoothing constant

    Returns:
        score: exp-log loss
    """
    y_pred = tf.cast(y_pred, tf.float32)
    y_true = tf.cast(tf.one_hot(y_true, n_class), tf.float32)
    y_true = tf.reshape(y_true, [-1, n_class])
    y_pred = tf.reshape(y_pred, [-1, n_class])
    # Keep probabilities away from 0 and 1 so the logs stay finite.
    y_pred = tf.clip_by_value(y_pred, eps, 1.0 - eps)
    # Explicit float tensor so integer weight lists multiply cleanly.
    weights = tf.convert_to_tensor(weights, dtype=tf.float32)

    # Per-class soft Dice term.
    intersection = tf.reduce_sum(y_true * y_pred, axis=0)
    union = tf.reduce_sum(y_true, axis=0) + tf.reduce_sum(y_pred, axis=0)
    dice = (2 * intersection + eps) / (union + eps)
    dice = tf.clip_by_value(dice, eps, 1.0 - eps)
    dice_log_loss = -tf.math.log(dice)
    Ld = tf.reduce_mean(tf.pow(dice_log_loss, g_d))

    # Class-weighted exponentiated cross-entropy term.
    wce = weights * y_true * tf.pow(-tf.math.log(y_pred), g_c)
    Lc = tf.reduce_mean(wce)

    score = w_d * Ld + w_c * Lc
    return score
# metrics
def dice_score(y_true, y_pred, n_class, exp=1e-5):
    """Compute dice score with ground truth and prediction without argmax.

    Args:
        y_true: ground truth with dimension of (batch, depth, height, width)
        y_pred: prediction with dimension of (batch, depth, height, width, n_class)
        n_class: classes number
        exp: smoothing constant added to numerator and denominator

    Returns:
        score: average dice score over classes ``1 .. n_class - 1`` (class 0
        is skipped) and over the batch
    """
    dices = []
    y_pred = np.argmax(y_pred, axis=-1)
    for i in range(1, n_class):
        pred = y_pred == i
        label = y_true == i
        intersection = 2 * np.sum(label * pred, axis=(1, 2, 3)) + exp
        union = np.sum(label, axis=(1, 2, 3)) + np.sum(pred, axis=(1, 2, 3)) + exp
        # This append was missing, so the mean of an empty list (NaN) was returned.
        dices.append(intersection / union)
    score = np.mean(dices)
    return score
# training process
def train():
    """Train the 3-D U-Net on Task01_BrainTumour, checkpointing every epoch.

    The ``training`` list of ``dataset.json`` is split 80/20 into training
    and validation cases. Each epoch runs one pass over both generators and
    saves a checkpoint named after the epoch and the validation dice.
    """
    batch_size = 1
    epochs = 500
    input_size = [32, 160, 160, 4]
    n_class = 4
    first_channels = 8
    lr = 0.001
    save_model_dir = '../saved_models/'
    data_dir = '../data/Task01_BrainTumour/'
    os.makedirs(save_model_dir, exist_ok=True)

    with open(os.path.join(data_dir, 'dataset.json'), 'r') as f:
        data_info = json.load(f)
    path_list = data_info['training']
    n_sample = len(path_list)
    train_path_list = path_list[:int(n_sample * 0.8)]
    val_path_list = path_list[int(n_sample * 0.8):]

    # NOTE(review): tf.placeholder / tf.Session need graph mode; confirm that
    # eager execution is disabled elsewhere when running under TF 2.x.
    model = Unet3d(input_size, first_channels, n_class)
    inputs = model.input
    pred = model.output
    label = tf.placeholder(tf.int32, shape=[None] + input_size[:3])
    loss_tf = explog_loss(label, pred, n_class, weights=[1, 10, 20, 20])
    global_step = tf.Variable(0, name='global_step', trainable=False)
    # NOTE(review): the decay arguments were missing from the source; the
    # decay_steps / decay_rate values below are placeholders to be confirmed.
    lr_schedule = tf.train.exponential_decay(
        lr, global_step, decay_steps=10000, decay_rate=0.9, staircase=True)
    optimizer = tf.train.AdamOptimizer(learning_rate=lr_schedule)
    train_opt = optimizer.minimize(loss_tf, global_step=global_step)
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # Variables must be initialised before the first training step.
        sess.run(init_op)
        for epoch in range(epochs):
            # training
            steps = 0
            train_loss_avg = 0
            train_dice_avg = 0
            train_dataset = data_generator(data_dir, train_path_list, input_size[:3], batch_size, True)
            for x, y in train_dataset:
                _, loss, pred_logits = sess.run([train_opt, loss_tf, pred],
                                                feed_dict={inputs: x, label: y})
                dice = dice_score(y, pred_logits, n_class)
                train_dice_avg += dice
                train_loss_avg += loss
                steps += 1
                print('epoch %d, steps %d, train loss: %.4f, train dice: %.4f' % (
                    epoch, steps, train_loss_avg / steps, train_dice_avg / steps))
            # max(..., 1) avoids dividing by zero when no batch was produced.
            train_loss_avg /= max(steps, 1)
            train_dice_avg /= max(steps, 1)

            # validation
            steps = 0
            val_loss_avg = 0
            val_dice_avg = 0
            val_dataset = data_generator(data_dir, val_path_list, input_size[:3], batch_size, False)
            for x, y in val_dataset:
                val_loss, pred_logits = sess.run([loss_tf, pred], feed_dict={inputs: x, label: y})
                dice = dice_score(y, pred_logits, n_class)
                val_dice_avg += dice
                steps += 1
                val_loss_avg += val_loss
            val_loss_avg /= max(steps, 1)
            val_dice_avg /= max(steps, 1)
            print('epoch %d, steps %d, validation loss: %.4f, val dice: %4f' % (
                epoch, steps, val_loss_avg, val_dice_avg))

            # save model
            saver.save(sess, os.path.join(save_model_dir, "epoch_%d_%.4f_model" % (epoch, val_dice_avg)))
def save_img(img_np, save_path):
    """Write a NumPy array to ``save_path`` as a SimpleITK image.

    Args:
        img_np: Array to save.
        save_path: Destination file path.
    """
    img_itk = sitk.GetImageFromArray(img_np)
    sitk.WriteImage(img_itk, save_path)
def recover_img(patch_preds, pos_list, strides, ori_shape):
    """Paste patch predictions back into a full-size volume.

    Each patch is written at its grid position times the stride; axes 1 and
    2 are shifted by 20 to undo the border crop applied before extraction.
    Where patches overlap, later patches overwrite earlier ones.

    Args:
        patch_preds: Sequence of 3-D patch arrays of identical shape.
        pos_list: Grid index ``(i, j, k)`` of each patch.
        strides: ``(sd, sh, sw)`` strides used during extraction.
        ori_shape: Shape of the output volume.

    Returns:
        Array of shape ``ori_shape``; zero wherever no patch was written.
    """
    sd, sh, sw = strides
    img = np.zeros(ori_shape)
    # Nothing to paste: return the empty volume instead of failing on [0].
    if len(patch_preds) == 0:
        return img
    pd, ph, pw = patch_preds[0].shape
    for patch, (i, j, k) in zip(patch_preds, pos_list):
        d0, h0, w0 = i * sd, j * sh + 20, k * sw + 20
        img[d0:d0 + pd, h0:h0 + ph, w0:w0 + pw] = patch
    return img
def predict(model_path, patch_list, input_size, first_channels, n_class):
    """Run the restored network on every patch and return the label maps.

    Args:
        model_path: Checkpoint prefix passed to ``Saver.restore``.
        patch_list: Sequence of image patches, each of shape ``input_size``.
        input_size: Patch shape including the channel axis.
        first_channels: Channel count of the first U-Net stage.
        n_class: Number of classes.

    Returns:
        List with one arg-max label array per input patch.
    """
    input_shape = (1,) + tuple(input_size)
    inputs = tf.placeholder(tf.float32, shape=input_shape)
    model = Unet3d(input_size, first_channels, n_class)
    prediction = model(inputs)
    saver = tf.train.Saver()
    preds = []
    with tf.Session() as sess:
        # All variables come from the checkpoint, so no initialiser is run.
        saver.restore(sess, model_path)
        for patch in patch_list:
            pred = sess.run(prediction, feed_dict={inputs: np.expand_dims(patch, 0)})
            # The original never stored the result and always returned [].
            preds.append(np.squeeze(np.argmax(pred, -1)))
    return preds
def test(model_path, img_path, save_dir):
    """Segment one image with a saved model and write the result to disk.

    Args:
        model_path: Checkpoint prefix of the trained model.
        img_path: Path of the image to segment.
        save_dir: Output directory; created when missing. The prediction is
            saved under the input file's name.

    Returns:
        ``(img0, pred)``: the array returned by ``read_img`` and the
        reassembled prediction volume.
    """
    patch_size = [32, 160, 160, 4]
    n_class = 4
    first_channels = 8
    strides = (16, 32, 32)
    # The directory check had lost its body; create the directory if needed.
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, img_path.split('/')[-1])
    img0 = read_img(img_path)
    img1 = preprocess(img0)
    patch_list, pos_list = extract_ordered_overlap_patches(img1, None, patch_size[:3])
    preds = predict(model_path, patch_list, patch_size, first_channels, n_class)
    pred = recover_img(preds, pos_list, strides, img0.shape[:3])
    save_img(pred, save_path)
    return img0, pred
if __name__ == '__main__':
    # Switch between training (True) and single-image inference (False).
    run_training = False
    if run_training:
        train()
    else:
        model_path = '../saved_models/epoch_271_0.6776_model'
        img_path = '../data/Task01_BrainTumour/imagesTs/BRATS_485.nii.gz'
        save_dir = '../data/Task01_BrainTumour/imagesPre/'
        img, pred = test(model_path, img_path, save_dir)
        show_imgs([(img, pred)])