《深度学习与医学图像处理》第6章的内容,讲的是医学图像的语义分割。
医学图像的语义分割是一种计算机视觉处理技术,用于将医学图像中的每个像素(或体素)分配到特定的类别,从而区分出图像中的不同结构和组织。语义分割对于医学图像分析尤为重要,可以帮助医生更准确地识别和理解图像中的解剖结构,从而提高疾病诊断、治疗规划和手术导航的准确性。
语义分割涉及到不少关键性的步骤,包括:
- 图像预处理:对原始图像进行去噪、增强等处理,以提高图像质量。
- 特征提取:使用各种算法(如卷积神经网络CNN)提取图像特征。
- 分割网络:设计和训练一个神经网络模型,该网络能够识别图像中的不同区域,并将其分割成不同的类别。
- 优化处理:对分割结果进行优化,如形态学操作、平滑等,以提高分割的准确性。
- 测试评估:使用各种指标(如交并比IoU、Dice系数等)评估分割结果的质量。
在本章的内容中,重点讲解了以上步骤中所涉及到的损失函数、评价指标、分割模型,并提供了实战的指导。
一、内容思维导图
本章内容部分的思维导图如下:
了解了上图中的相关内容后,就可以按照书中的指导,进行实战了。
二、环境准备
在 【学习分享6】医学图像分类处理和模型训练实践 的环境基础上,可以进行本章的实战,不过还需要安装一下如下的python库:
pip install SimpleITK==2.0.2
因为本章涉及到数据的语义分割,会进行图片的展示,而我是在远程服务器上进行训练的,所以需要做一些预备工作,以便通过vnc来进行图像数据的显示。
# 安装xvfb
sudo apt install -y xvfb
# 开启xvfb
Xvfb :0 -screen 0 1280x960x24 -listen tcp -ac +extension GLX +extension RENDER &
# 设置默认显示
export DISPLAY=:0
# 安装x11vnc
sudo apt install -y x11vnc
# 启动x11vnc
sudo x11vnc -display :0 -forever -shared -rfbport 端口 -passwd 密码 &
通过上面的步骤,开启vnc,然后在本地电脑,使用vnc客户端工具连接。
然后使用python测试画图:
import matplotlib.pyplot as plt
import numpy as np
# Build the data: x runs from 0 (inclusive) to 10 (exclusive) in steps of 0.1,
# y is sin(x) at each of those points.
x = np.arange(0, 10, 0.1)
y = np.sin(x)
# Draw the curve on a fresh figure and open the window.
fig, ax = plt.subplots()
ax.plot(x, y)
plt.show()
能够正常显示图片,就可以后续处理了:
三、测试数据集下载
本章的测试数据集,使用的是MRI公开数据集,从 Medical Segmentation Decathlon 下载,然后解压。
tar xvf Task01_BrainTumour.tar
ls -l Task01_BrainTumour
解压后的文件如下:
四、数据预处理
首先,查看一下数据对应的图像,对应的代码如下:
import os
import SimpleITK as sitk
import matplotlib.pyplot as plt
import numpy as np
import json
from skimage import measure
import random
# import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
tf.logging.set_verbosity(tf.logging.ERROR)
from tensorflow.keras.layers import Layer
from tensorflow import keras
from tensorflow.keras.layers import Conv3D, Conv3DTranspose, ReLU, UpSampling3D, Input, MaxPool3D, Concatenate
from tensorflow.keras import Model
def z_score_norm(img, mean=None, std=None):
    """Standardize ``img`` to zero mean and unit variance (z-score).

    Args:
        img: Array-like of intensities. It is converted with ``np.asarray``,
            so plain (nested) lists are accepted as well as ndarrays.
        mean: Value to subtract; computed from ``img`` when ``None``.
        std: Value to divide by; computed from ``img`` when ``None``.

    Returns:
        ``(img - mean) / (std + 1e-6)`` as an ndarray.
    """
    img = np.asarray(img)
    if mean is None:
        mean = np.mean(img)
    if std is None:
        std = np.std(img)
    # The small constant keeps a constant-valued input (std == 0) from
    # dividing by zero.
    return (img - mean) / (std + 1e-6)
def show_imgs(images, cols=5):
    """Display one row of slices per ``(image, label)`` pair.

    For each pair the slice shown is the median of the axis-0 indices at which
    ``label`` is non-zero; when the label has no non-zero voxel the central
    slice is shown instead. The first ``cols - 1`` columns show
    ``image[slice, :, :, j]`` and the last column shows ``label[slice, :, :]``.

    Args:
        images: Sequence of ``(image, label)`` pairs.
        cols: Number of columns per row (titles exist for up to 5 columns).
    """
    rows = len(images)
    titles = ['FLAIR', 'T1w', 'T1gd', 'T2w', 'Label']
    # squeeze=False always returns a 2-D array of axes, so ravel() is valid
    # even for a 1x1 grid.
    f, axes = plt.subplots(max(1, rows), cols, squeeze=False)
    axes_ravel = axes.ravel()
    for i, (image, label) in enumerate(images):
        label = np.asarray(label)
        fg_slices = np.where(label != 0)[0]
        if fg_slices.size:
            fg_slices.sort()
            slice_idx = fg_slices[len(fg_slices) // 2]
        else:
            # An all-background label (e.g. an empty prediction) has no
            # foreground slice to pick; fall back to the central slice
            # instead of raising IndexError.
            slice_idx = label.shape[0] // 2
        for j in range(cols - 1):
            ax = axes_ravel[i * cols + j]
            ax.set_axis_off()
            ax.imshow(image[slice_idx, :, :, j], cmap='Greys_r')
            ax.set_title(titles[j])
        label_ax = axes_ravel[(i + 1) * cols - 1]
        label_ax.set_axis_off()
        label_ax.imshow(label[slice_idx, :, :], cmap='Greys_r')
        label_ax.set_title(titles[-1])
    f.tight_layout()
    plt.subplots_adjust(wspace=0.01, hspace=0)
    plt.show()
def read_img(img_path, label_path=None):
    """Load an image (and optionally its label) from disk as NumPy arrays.

    The image array obtained from SimpleITK has its first axis moved to
    position 3 (presumably the modality/channel axis -- verify against the
    dataset layout). The label array is returned as SimpleITK provides it.

    Args:
        img_path: Path of the image file.
        label_path: Path of the label file, or ``None`` to load the image only.

    Returns:
        ``img_np`` when ``label_path`` is ``None``, else ``(img_np, label_np)``.
    """
    volume = np.moveaxis(sitk.GetArrayFromImage(sitk.ReadImage(img_path)), 0, 3)
    if label_path is None:
        return volume
    mask = sitk.GetArrayFromImage(sitk.ReadImage(label_path))
    return volume, mask
# Load one training case together with its label and display it.
img_path = '../data/Task01_BrainTumour/imagesTr/BRATS_001.nii.gz'
label_path = '../data/Task01_BrainTumour/labelsTr/BRATS_001.nii.gz'
img_np, label_np = read_img(img_path, label_path=label_path)
show_imgs([(img_np, label_np)])
运行后,会显示上面代码中对应的数据的图像:
因为是测试数据集,所以包含了实际的病灶部位的图像。
然后,对图像进行预处理,对应的代码如下:
import os
import SimpleITK as sitk
import matplotlib.pyplot as plt
import numpy as np
import json
from skimage import measure
import random
# import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
tf.logging.set_verbosity(tf.logging.ERROR)
from tensorflow.keras.layers import Layer
from tensorflow import keras
from tensorflow.keras.layers import Conv3D, Conv3DTranspose, ReLU, UpSampling3D, Input, MaxPool3D, Concatenate
from tensorflow.keras import Model
def z_score_norm(img, mean=None, std=None):
    """Standardize ``img`` to zero mean and unit variance (z-score).

    Args:
        img: Array-like of intensities. It is converted with ``np.asarray``,
            so plain (nested) lists are accepted as well as ndarrays.
        mean: Value to subtract; computed from ``img`` when ``None``.
        std: Value to divide by; computed from ``img`` when ``None``.

    Returns:
        ``(img - mean) / (std + 1e-6)`` as an ndarray.
    """
    img = np.asarray(img)
    if mean is None:
        mean = np.mean(img)
    if std is None:
        std = np.std(img)
    # The small constant keeps a constant-valued input (std == 0) from
    # dividing by zero.
    return (img - mean) / (std + 1e-6)
def show_imgs(images, cols=5):
    """Display one row of slices per ``(image, label)`` pair.

    For each pair the slice shown is the median of the axis-0 indices at which
    ``label`` is non-zero; when the label has no non-zero voxel the central
    slice is shown instead. The first ``cols - 1`` columns show
    ``image[slice, :, :, j]`` and the last column shows ``label[slice, :, :]``.

    Args:
        images: Sequence of ``(image, label)`` pairs.
        cols: Number of columns per row (titles exist for up to 5 columns).
    """
    rows = len(images)
    titles = ['FLAIR', 'T1w', 'T1gd', 'T2w', 'Label']
    # squeeze=False always returns a 2-D array of axes, so ravel() is valid
    # even for a 1x1 grid.
    f, axes = plt.subplots(max(1, rows), cols, squeeze=False)
    axes_ravel = axes.ravel()
    for i, (image, label) in enumerate(images):
        label = np.asarray(label)
        fg_slices = np.where(label != 0)[0]
        if fg_slices.size:
            fg_slices.sort()
            slice_idx = fg_slices[len(fg_slices) // 2]
        else:
            # An all-background label (e.g. an empty prediction) has no
            # foreground slice to pick; fall back to the central slice
            # instead of raising IndexError.
            slice_idx = label.shape[0] // 2
        for j in range(cols - 1):
            ax = axes_ravel[i * cols + j]
            ax.set_axis_off()
            ax.imshow(image[slice_idx, :, :, j], cmap='Greys_r')
            ax.set_title(titles[j])
        label_ax = axes_ravel[(i + 1) * cols - 1]
        label_ax.set_axis_off()
        label_ax.imshow(label[slice_idx, :, :], cmap='Greys_r')
        label_ax.set_title(titles[-1])
    f.tight_layout()
    plt.subplots_adjust(wspace=0.01, hspace=0)
    plt.show()
def read_img(img_path, label_path=None):
    """Load an image (and optionally its label) from disk as NumPy arrays.

    The image array obtained from SimpleITK has its first axis moved to
    position 3 (presumably the modality/channel axis -- verify against the
    dataset layout). The label array is returned as SimpleITK provides it.

    Args:
        img_path: Path of the image file.
        label_path: Path of the label file, or ``None`` to load the image only.

    Returns:
        ``img_np`` when ``label_path`` is ``None``, else ``(img_np, label_np)``.
    """
    volume = np.moveaxis(sitk.GetArrayFromImage(sitk.ReadImage(img_path)), 0, 3)
    if label_path is None:
        return volume
    mask = sitk.GetArrayFromImage(sitk.ReadImage(label_path))
    return volume, mask
def extract_ordered_overlap_patches(img, label, patch_size, s=16):
    """Cut a volume into overlapping patches laid out on a regular grid.

    ``img`` (and ``label``) are first cropped by 20 voxels at both ends of
    axes 1 and 2. Patches are then taken with stride ``s`` along axis 0 and
    ``2 * s`` along axes 1 and 2. Voxels past the last full stride are not
    covered by any patch.

    Args:
        img: 4-D array indexed ``[d, h, w, channel]``.
        label: 3-D array aligned with ``img``, or ``None``. When given, a patch
            is kept only if it holds at least 20% of all non-zero voxels of
            the cropped label, and is returned as ``(patch_img, patch_label)``.
        patch_size: ``(patch_d, patch_h, patch_w)``.
        s: Base stride.

    Returns:
        ``(patch_list, pos_list)`` where ``pos_list[n]`` is the grid index
        ``(i, j, k)`` of ``patch_list[n]``. Both are empty when ``label`` is
        given but contains no non-zero voxel after cropping.
    """
    img = img[:, 20:-20, 20:-20, :]
    d, h, w, _ = img.shape
    patch_d, patch_h, patch_w = patch_size
    stride_d, stride_h, stride_w = s, s * 2, s * 2

    patch_list = []
    pos_list = []
    total_fg = 0
    if label is not None:
        label = label[:, 20:-20, 20:-20]
        # Counted once: it is loop-invariant, and a label without any
        # foreground would otherwise cause a division by zero below.
        total_fg = np.count_nonzero(label)
        if total_fg == 0:
            return patch_list, pos_list

    for i in range((d - patch_d) // stride_d + 1):
        for j in range((h - patch_h) // stride_h + 1):
            for k in range((w - patch_w) // stride_w + 1):
                region = (
                    slice(i * stride_d, min(d, i * stride_d + patch_d)),
                    slice(j * stride_h, min(h, j * stride_h + patch_h)),
                    slice(k * stride_w, min(w, k * stride_w + patch_w)),
                )
                patch_img = img[region]
                if label is None:
                    patch_list.append(patch_img)
                    pos_list.append((i, j, k))
                    continue
                patch_label = label[region]
                if patch_label.shape != tuple(patch_size):
                    continue
                if np.count_nonzero(patch_label) / total_fg >= 0.2:
                    patch_list.append((patch_img, patch_label))
                    pos_list.append((i, j, k))
    return patch_list, pos_list
# pre-process images
def preprocess(image):
    """Z-score normalize every 2-D slice of every channel independently.

    Each ``image[z, :, :, c]`` slice is shifted by its own mean and divided by
    its own standard deviation plus ``1e-6``.

    Args:
        image: 4-D array-like indexed ``[z, y, x, channel]``.

    Returns:
        A new floating-point array of the same shape. The input is left
        untouched.
    """
    image = np.asarray(image)
    # Work on a float copy: writing normalized values back into an integer
    # array would silently truncate them, and callers may still need the raw
    # intensities (e.g. for display).
    image = image.astype(np.result_type(image.dtype, np.float32))
    # Statistics over the two in-plane axes give one mean/std per (z, channel),
    # which is the per-slice z-score applied to all slices at once.
    mean = image.mean(axis=(1, 2), keepdims=True)
    std = image.std(axis=(1, 2), keepdims=True)
    return (image - mean) / (std + 1e-6)
# Load one training case, show it, then show the first patch extracted from it.
img_path = '../data/Task01_BrainTumour/imagesTr/BRATS_001.nii.gz'
label_path = '../data/Task01_BrainTumour/labelsTr/BRATS_001.nii.gz'
img_np, label_np = read_img(img_path, label_path=label_path)
show_imgs([(img_np, label_np)])

patch_size = (32, 160, 160)
# Only the patches are needed here; the grid positions are discarded.
patch_list, _ = extract_ordered_overlap_patches(img_np, label_np, patch_size)
first_patch = patch_list[:1]
show_imgs(first_patch)
运行后,结果如下:
紧接着,就是数据生成器部分了,对应的代码如下:
import os
import SimpleITK as sitk
import matplotlib.pyplot as plt
import numpy as np
import json
from skimage import measure
import random
# import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
tf.logging.set_verbosity(tf.logging.ERROR)
from tensorflow.keras.layers import Layer
from tensorflow import keras
from tensorflow.keras.layers import Conv3D, Conv3DTranspose, ReLU, UpSampling3D, Input, MaxPool3D, Concatenate
from tensorflow.keras import Model
def z_score_norm(img, mean=None, std=None):
    """Standardize ``img`` to zero mean and unit variance (z-score).

    Args:
        img: Array-like of intensities. It is converted with ``np.asarray``,
            so plain (nested) lists are accepted as well as ndarrays.
        mean: Value to subtract; computed from ``img`` when ``None``.
        std: Value to divide by; computed from ``img`` when ``None``.

    Returns:
        ``(img - mean) / (std + 1e-6)`` as an ndarray.
    """
    img = np.asarray(img)
    if mean is None:
        mean = np.mean(img)
    if std is None:
        std = np.std(img)
    # The small constant keeps a constant-valued input (std == 0) from
    # dividing by zero.
    return (img - mean) / (std + 1e-6)
def show_imgs(images, cols=5):
    """Display one row of slices per ``(image, label)`` pair.

    For each pair the slice shown is the median of the axis-0 indices at which
    ``label`` is non-zero; when the label has no non-zero voxel the central
    slice is shown instead. The first ``cols - 1`` columns show
    ``image[slice, :, :, j]`` and the last column shows ``label[slice, :, :]``.

    Args:
        images: Sequence of ``(image, label)`` pairs.
        cols: Number of columns per row (titles exist for up to 5 columns).
    """
    rows = len(images)
    titles = ['FLAIR', 'T1w', 'T1gd', 'T2w', 'Label']
    # squeeze=False always returns a 2-D array of axes, so ravel() is valid
    # even for a 1x1 grid.
    f, axes = plt.subplots(max(1, rows), cols, squeeze=False)
    axes_ravel = axes.ravel()
    for i, (image, label) in enumerate(images):
        label = np.asarray(label)
        fg_slices = np.where(label != 0)[0]
        if fg_slices.size:
            fg_slices.sort()
            slice_idx = fg_slices[len(fg_slices) // 2]
        else:
            # An all-background label (e.g. an empty prediction) has no
            # foreground slice to pick; fall back to the central slice
            # instead of raising IndexError.
            slice_idx = label.shape[0] // 2
        for j in range(cols - 1):
            ax = axes_ravel[i * cols + j]
            ax.set_axis_off()
            ax.imshow(image[slice_idx, :, :, j], cmap='Greys_r')
            ax.set_title(titles[j])
        label_ax = axes_ravel[(i + 1) * cols - 1]
        label_ax.set_axis_off()
        label_ax.imshow(label[slice_idx, :, :], cmap='Greys_r')
        label_ax.set_title(titles[-1])
    f.tight_layout()
    plt.subplots_adjust(wspace=0.01, hspace=0)
    plt.show()
def read_img(img_path, label_path=None):
    """Load an image (and optionally its label) from disk as NumPy arrays.

    The image array obtained from SimpleITK has its first axis moved to
    position 3 (presumably the modality/channel axis -- verify against the
    dataset layout). The label array is returned as SimpleITK provides it.

    Args:
        img_path: Path of the image file.
        label_path: Path of the label file, or ``None`` to load the image only.

    Returns:
        ``img_np`` when ``label_path`` is ``None``, else ``(img_np, label_np)``.
    """
    volume = np.moveaxis(sitk.GetArrayFromImage(sitk.ReadImage(img_path)), 0, 3)
    if label_path is None:
        return volume
    mask = sitk.GetArrayFromImage(sitk.ReadImage(label_path))
    return volume, mask
def extract_ordered_overlap_patches(img, label, patch_size, s=16):
    """Cut a volume into overlapping patches laid out on a regular grid.

    ``img`` (and ``label``) are first cropped by 20 voxels at both ends of
    axes 1 and 2. Patches are then taken with stride ``s`` along axis 0 and
    ``2 * s`` along axes 1 and 2. Voxels past the last full stride are not
    covered by any patch.

    Args:
        img: 4-D array indexed ``[d, h, w, channel]``.
        label: 3-D array aligned with ``img``, or ``None``. When given, a patch
            is kept only if it holds at least 20% of all non-zero voxels of
            the cropped label, and is returned as ``(patch_img, patch_label)``.
        patch_size: ``(patch_d, patch_h, patch_w)``.
        s: Base stride.

    Returns:
        ``(patch_list, pos_list)`` where ``pos_list[n]`` is the grid index
        ``(i, j, k)`` of ``patch_list[n]``. Both are empty when ``label`` is
        given but contains no non-zero voxel after cropping.
    """
    img = img[:, 20:-20, 20:-20, :]
    d, h, w, _ = img.shape
    patch_d, patch_h, patch_w = patch_size
    stride_d, stride_h, stride_w = s, s * 2, s * 2

    patch_list = []
    pos_list = []
    total_fg = 0
    if label is not None:
        label = label[:, 20:-20, 20:-20]
        # Counted once: it is loop-invariant, and a label without any
        # foreground would otherwise cause a division by zero below.
        total_fg = np.count_nonzero(label)
        if total_fg == 0:
            return patch_list, pos_list

    for i in range((d - patch_d) // stride_d + 1):
        for j in range((h - patch_h) // stride_h + 1):
            for k in range((w - patch_w) // stride_w + 1):
                region = (
                    slice(i * stride_d, min(d, i * stride_d + patch_d)),
                    slice(j * stride_h, min(h, j * stride_h + patch_h)),
                    slice(k * stride_w, min(w, k * stride_w + patch_w)),
                )
                patch_img = img[region]
                if label is None:
                    patch_list.append(patch_img)
                    pos_list.append((i, j, k))
                    continue
                patch_label = label[region]
                if patch_label.shape != tuple(patch_size):
                    continue
                if np.count_nonzero(patch_label) / total_fg >= 0.2:
                    patch_list.append((patch_img, patch_label))
                    pos_list.append((i, j, k))
    return patch_list, pos_list
# pre-process images
def preprocess(image):
    """Z-score normalize every 2-D slice of every channel independently.

    Each ``image[z, :, :, c]`` slice is shifted by its own mean and divided by
    its own standard deviation plus ``1e-6``.

    Args:
        image: 4-D array-like indexed ``[z, y, x, channel]``.

    Returns:
        A new floating-point array of the same shape. The input is left
        untouched.
    """
    image = np.asarray(image)
    # Work on a float copy: writing normalized values back into an integer
    # array would silently truncate them, and callers may still need the raw
    # intensities (e.g. for display).
    image = image.astype(np.result_type(image.dtype, np.float32))
    # Statistics over the two in-plane axes give one mean/std per (z, channel),
    # which is the per-slice z-score applied to all slices at once.
    mean = image.mean(axis=(1, 2), keepdims=True)
    std = image.std(axis=(1, 2), keepdims=True)
    return (image - mean) / (std + 1e-6)
# data loader
def data_generator(data_dir, path_list, target_shape, batch_size, is_training, buffer_size=8):
    """Yield ``(x, y)`` mini-batches of image/label patches.

    Cases are processed ``buffer_size`` at a time: each case is read,
    normalized and cut into patches, the patches of the whole buffer are
    pooled (and shuffled when training), then emitted in batches of
    ``batch_size``. Patches left over after the last full batch of a buffer
    are dropped.

    Args:
        data_dir: Dataset root; it replaces the ``'./'`` in every entry path.
        path_list: List of dicts with ``'image'`` and ``'label'`` paths.
            The caller's list is never reordered.
        target_shape: Patch size ``(d, h, w)``.
        batch_size: Number of patches per yielded batch.
        is_training: Shuffle cases and patches when true; when false the
            buffer holds a single case and order is kept.
        buffer_size: Maximum number of cases pooled together when training.

    Yields:
        ``(x, y)`` arrays whose first dimension is ``batch_size``.
    """
    if is_training:
        # Shuffle a copy so the caller's list keeps its order.
        path_list = list(path_list)
        random.shuffle(path_list)
        buffer_size = min(len(path_list), buffer_size)
    else:
        buffer_size = 1
    if not path_list or buffer_size < 1:
        # Nothing to load; also avoids a zero step below.
        return

    # Stepping by buffer_size (instead of floor-dividing the case count)
    # keeps the trailing, partially filled buffer.
    for start in range(0, len(path_list), buffer_size):
        data_list = []
        for entry in path_list[start:start + buffer_size]:
            img_path = entry['image'].replace('./', data_dir)
            label_path = entry['label'].replace('./', data_dir)
            if not os.path.exists(img_path) or not os.path.exists(label_path):
                continue
            img, label = read_img(img_path, label_path)
            img = preprocess(img)
            patch_list, _ = extract_ordered_overlap_patches(img, label, target_shape)
            data_list += patch_list
        if not data_list:
            continue

        X = np.array([it[0] for it in data_list])
        Y = np.array([it[1] for it in data_list])
        if is_training:
            index = np.random.permutation(len(data_list))
            X = X[index, ...]
            Y = Y[index, ...]
        # Every full batch is emitted; the previous ``- 1`` skipped the last
        # one and produced nothing at all for a buffer holding one batch.
        for step in range(X.shape[0] // batch_size):
            x = X[step * batch_size:(step + 1) * batch_size, ...]
            y = Y[step * batch_size:(step + 1) * batch_size, ...]
            yield x, y
# Load one training case, show it, then show the first patch extracted from it.
img_path = '../data/Task01_BrainTumour/imagesTr/BRATS_001.nii.gz'
label_path = '../data/Task01_BrainTumour/labelsTr/BRATS_001.nii.gz'
img_np, label_np = read_img(img_path, label_path=label_path)
show_imgs([(img_np, label_np)])

patch_size = (32, 160, 160)
patch_list, _ = extract_ordered_overlap_patches(img_np, label_np, patch_size)
show_imgs(patch_list[:1])

# Run the generator in evaluation mode over the first three training cases
# and print the shape of every batch it yields.
data_dir = '../data/Task01_BrainTumour/'
with open(os.path.join(data_dir, 'dataset.json'), 'r') as f:
    data = json.load(f)
train_path_list = data['training']
# patch_size is reused from above.
test_data = data_generator(data_dir, train_path_list[:3], patch_size, 1, False, 1)
for x, y in test_data:
    print(x.shape, y.shape)
运行后,输出结果如下:
有了这个输出,说明准备工作都做好,可以开始实际的训练了。
五、模型训练
模型训练部分的代码如下:
import os
import SimpleITK as sitk
import matplotlib.pyplot as plt
import numpy as np
import json
from skimage import measure
import random
# import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
tf.logging.set_verbosity(tf.logging.ERROR)
from tensorflow.keras.layers import Layer
from tensorflow import keras
from tensorflow.keras.layers import Conv3D, Conv3DTranspose, ReLU, UpSampling3D, Input, MaxPool3D, Concatenate
from tensorflow.keras import Model
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
def z_score_norm(img, mean=None, std=None):
    """Standardize ``img`` to zero mean and unit variance (z-score).

    Args:
        img: Array-like of intensities. It is converted with ``np.asarray``,
            so plain (nested) lists are accepted as well as ndarrays.
        mean: Value to subtract; computed from ``img`` when ``None``.
        std: Value to divide by; computed from ``img`` when ``None``.

    Returns:
        ``(img - mean) / (std + 1e-6)`` as an ndarray.
    """
    img = np.asarray(img)
    if mean is None:
        mean = np.mean(img)
    if std is None:
        std = np.std(img)
    # The small constant keeps a constant-valued input (std == 0) from
    # dividing by zero.
    return (img - mean) / (std + 1e-6)
def show_imgs(images, cols=5):
    """Display one row of slices per ``(image, label)`` pair.

    For each pair the slice shown is the median of the axis-0 indices at which
    ``label`` is non-zero; when the label has no non-zero voxel the central
    slice is shown instead. The first ``cols - 1`` columns show
    ``image[slice, :, :, j]`` and the last column shows ``label[slice, :, :]``.

    Args:
        images: Sequence of ``(image, label)`` pairs.
        cols: Number of columns per row (titles exist for up to 5 columns).
    """
    rows = len(images)
    titles = ['FLAIR', 'T1w', 'T1gd', 'T2w', 'Label']
    # squeeze=False always returns a 2-D array of axes, so ravel() is valid
    # even for a 1x1 grid.
    f, axes = plt.subplots(max(1, rows), cols, squeeze=False)
    axes_ravel = axes.ravel()
    for i, (image, label) in enumerate(images):
        label = np.asarray(label)
        fg_slices = np.where(label != 0)[0]
        if fg_slices.size:
            fg_slices.sort()
            slice_idx = fg_slices[len(fg_slices) // 2]
        else:
            # An all-background label (e.g. an empty prediction) has no
            # foreground slice to pick; fall back to the central slice
            # instead of raising IndexError.
            slice_idx = label.shape[0] // 2
        for j in range(cols - 1):
            ax = axes_ravel[i * cols + j]
            ax.set_axis_off()
            ax.imshow(image[slice_idx, :, :, j], cmap='Greys_r')
            ax.set_title(titles[j])
        label_ax = axes_ravel[(i + 1) * cols - 1]
        label_ax.set_axis_off()
        label_ax.imshow(label[slice_idx, :, :], cmap='Greys_r')
        label_ax.set_title(titles[-1])
    f.tight_layout()
    plt.subplots_adjust(wspace=0.01, hspace=0)
    plt.show()
def read_img(img_path, label_path=None):
    """Load an image (and optionally its label) from disk as NumPy arrays.

    The image array obtained from SimpleITK has its first axis moved to
    position 3 (presumably the modality/channel axis -- verify against the
    dataset layout). The label array is returned as SimpleITK provides it.

    Args:
        img_path: Path of the image file.
        label_path: Path of the label file, or ``None`` to load the image only.

    Returns:
        ``img_np`` when ``label_path`` is ``None``, else ``(img_np, label_np)``.
    """
    volume = np.moveaxis(sitk.GetArrayFromImage(sitk.ReadImage(img_path)), 0, 3)
    if label_path is None:
        return volume
    mask = sitk.GetArrayFromImage(sitk.ReadImage(label_path))
    return volume, mask
def extract_ordered_overlap_patches(img, label, patch_size, s=16):
    """Cut a volume into overlapping patches laid out on a regular grid.

    ``img`` (and ``label``) are first cropped by 20 voxels at both ends of
    axes 1 and 2. Patches are then taken with stride ``s`` along axis 0 and
    ``2 * s`` along axes 1 and 2. Voxels past the last full stride are not
    covered by any patch.

    Args:
        img: 4-D array indexed ``[d, h, w, channel]``.
        label: 3-D array aligned with ``img``, or ``None``. When given, a patch
            is kept only if it holds at least 20% of all non-zero voxels of
            the cropped label, and is returned as ``(patch_img, patch_label)``.
        patch_size: ``(patch_d, patch_h, patch_w)``.
        s: Base stride.

    Returns:
        ``(patch_list, pos_list)`` where ``pos_list[n]`` is the grid index
        ``(i, j, k)`` of ``patch_list[n]``. Both are empty when ``label`` is
        given but contains no non-zero voxel after cropping.
    """
    img = img[:, 20:-20, 20:-20, :]
    d, h, w, _ = img.shape
    patch_d, patch_h, patch_w = patch_size
    stride_d, stride_h, stride_w = s, s * 2, s * 2

    patch_list = []
    pos_list = []
    total_fg = 0
    if label is not None:
        label = label[:, 20:-20, 20:-20]
        # Counted once: it is loop-invariant, and a label without any
        # foreground would otherwise cause a division by zero below.
        total_fg = np.count_nonzero(label)
        if total_fg == 0:
            return patch_list, pos_list

    for i in range((d - patch_d) // stride_d + 1):
        for j in range((h - patch_h) // stride_h + 1):
            for k in range((w - patch_w) // stride_w + 1):
                region = (
                    slice(i * stride_d, min(d, i * stride_d + patch_d)),
                    slice(j * stride_h, min(h, j * stride_h + patch_h)),
                    slice(k * stride_w, min(w, k * stride_w + patch_w)),
                )
                patch_img = img[region]
                if label is None:
                    patch_list.append(patch_img)
                    pos_list.append((i, j, k))
                    continue
                patch_label = label[region]
                if patch_label.shape != tuple(patch_size):
                    continue
                if np.count_nonzero(patch_label) / total_fg >= 0.2:
                    patch_list.append((patch_img, patch_label))
                    pos_list.append((i, j, k))
    return patch_list, pos_list
# pre-process images
def preprocess(image):
    """Z-score normalize every 2-D slice of every channel independently.

    Each ``image[z, :, :, c]`` slice is shifted by its own mean and divided by
    its own standard deviation plus ``1e-6``.

    Args:
        image: 4-D array-like indexed ``[z, y, x, channel]``.

    Returns:
        A new floating-point array of the same shape. The input is left
        untouched.
    """
    image = np.asarray(image)
    # Work on a float copy: writing normalized values back into an integer
    # array would silently truncate them, and callers may still need the raw
    # intensities (e.g. for display).
    image = image.astype(np.result_type(image.dtype, np.float32))
    # Statistics over the two in-plane axes give one mean/std per (z, channel),
    # which is the per-slice z-score applied to all slices at once.
    mean = image.mean(axis=(1, 2), keepdims=True)
    std = image.std(axis=(1, 2), keepdims=True)
    return (image - mean) / (std + 1e-6)
# data loader
def data_generator(data_dir, path_list, target_shape, batch_size, is_training, buffer_size=8):
    """Yield ``(x, y)`` mini-batches of image/label patches.

    Cases are processed ``buffer_size`` at a time: each case is read,
    normalized and cut into patches, the patches of the whole buffer are
    pooled (and shuffled when training), then emitted in batches of
    ``batch_size``. Patches left over after the last full batch of a buffer
    are dropped.

    Args:
        data_dir: Dataset root; it replaces the ``'./'`` in every entry path.
        path_list: List of dicts with ``'image'`` and ``'label'`` paths.
            The caller's list is never reordered.
        target_shape: Patch size ``(d, h, w)``.
        batch_size: Number of patches per yielded batch.
        is_training: Shuffle cases and patches when true; when false the
            buffer holds a single case and order is kept.
        buffer_size: Maximum number of cases pooled together when training.

    Yields:
        ``(x, y)`` arrays whose first dimension is ``batch_size``.
    """
    if is_training:
        # Shuffle a copy so the caller's list keeps its order.
        path_list = list(path_list)
        random.shuffle(path_list)
        buffer_size = min(len(path_list), buffer_size)
    else:
        buffer_size = 1
    if not path_list or buffer_size < 1:
        # Nothing to load; also avoids a zero step below.
        return

    # Stepping by buffer_size (instead of floor-dividing the case count)
    # keeps the trailing, partially filled buffer.
    for start in range(0, len(path_list), buffer_size):
        data_list = []
        for entry in path_list[start:start + buffer_size]:
            img_path = entry['image'].replace('./', data_dir)
            label_path = entry['label'].replace('./', data_dir)
            if not os.path.exists(img_path) or not os.path.exists(label_path):
                continue
            img, label = read_img(img_path, label_path)
            img = preprocess(img)
            patch_list, _ = extract_ordered_overlap_patches(img, label, target_shape)
            data_list += patch_list
        if not data_list:
            continue

        X = np.array([it[0] for it in data_list])
        Y = np.array([it[1] for it in data_list])
        if is_training:
            index = np.random.permutation(len(data_list))
            X = X[index, ...]
            Y = Y[index, ...]
        # Every full batch is emitted; the previous ``- 1`` skipped the last
        # one and produced nothing at all for a buffer holding one batch.
        for step in range(X.shape[0] // batch_size):
            x = X[step * batch_size:(step + 1) * batch_size, ...]
            y = Y[step * batch_size:(step + 1) * batch_size, ...]
            yield x, y
class GroupNorm(Layer):
    """Group normalization for 5-D channels-last tensors ``(N, D, H, W, C)``.

    Channels are split into groups; each group is normalized per sample with
    its own mean and variance, then scaled and shifted by the learned
    per-channel ``gamma`` and ``beta``.

    Args:
        groups: Requested number of groups. A single group is used when the
            channel count is not divisible by it.
        **kwargs: Forwarded to the Keras ``Layer`` constructor.
    """

    def __init__(self, groups=4, **kwargs):
        super(GroupNorm, self).__init__(**kwargs)
        self.G = groups
        self.eps = 1e-5

    def build(self, input_shape):
        self.channel = input_shape[-1]
        # Fall back to one group when the channels cannot be split evenly.
        self.group = self.G if self.channel % self.G == 0 else 1
        self.group = min(self.channel, self.group)
        self.split = self.channel // self.group
        self.gamma = self.add_weight(name='gamma_gn', shape=(1, 1, 1, 1, input_shape[-1]),
                                     initializer='ones', trainable=True)
        self.beta = self.add_weight(name='beta_gn', shape=(1, 1, 1, 1, input_shape[-1]),
                                    initializer='zeros', trainable=True)

    def call(self, inputs):
        _, D, H, W, C = tf.keras.backend.int_shape(inputs)
        grouped = tf.reshape(inputs, [-1, D, H, W, self.group, self.split])
        # Statistics per sample and per group: reduce over the spatial axes
        # and over the channels inside each group.
        mean, var = tf.nn.moments(grouped, [1, 2, 3, 5])
        mean = tf.reshape(mean, [-1, 1, 1, 1, self.group, 1])
        var = tf.reshape(var, [-1, 1, 1, 1, self.group, 1])
        normalized = (grouped - mean) / tf.sqrt(var + self.eps)
        # The normalized tensor must be the one reshaped and returned; the
        # previous code reshaped the raw input here, which turned the layer
        # into a plain affine transform.
        return tf.reshape(normalized, [-1, D, H, W, C]) * self.gamma + self.beta
def conv_res_block(x, filters, activation=ReLU(), kernel_size=3, strides=1, padding='same', num_layers=1):
    """Convolution block, with a residual shortcut when ``num_layers > 1``.

    With ``num_layers == 1`` the block is conv -> GroupNorm -> activation.
    Otherwise a 1x1x1 conv -> GroupNorm -> activation shortcut is added to the
    output of ``num_layers`` stacked conv -> activation layers.

    NOTE(review): the source listing lost its indentation; the loop body is
    assumed to be the convolution plus its activation, with the shortcut added
    once after the loop -- confirm against the book's code.

    Args:
        x: Input tensor.
        filters: Number of output channels of every convolution.
        activation: Callable applied after each convolution.
        kernel_size: Kernel size of the main-path convolutions.
        strides: Stride of every convolution, including the shortcut.
        padding: Padding mode of every convolution.
        num_layers: Number of main-path convolutions.

    Returns:
        The output tensor of the block.
    """
    def conv(tensor, size):
        return Conv3D(filters, size, strides=strides, padding=padding)(tensor)

    if num_layers == 1:
        features = conv(x, kernel_size)
        features = GroupNorm()(features)
        return activation(features)

    projected = conv(x, 1)
    projected = GroupNorm()(projected)
    shortcut = activation(projected)
    for _ in range(num_layers):
        x = activation(conv(x, kernel_size))
    return x + shortcut
def upsample_block(x, filters, activation=ReLU(), kernel_size=3, strides=1,
                   padding='same', deconv=False):
    """Double the spatial resolution, then apply GroupNorm and activation.

    NOTE(review): the source listing lost its indentation; the ``Conv3D`` is
    assumed to belong to the ``UpSampling3D`` branch only. Both readings are
    identical for the default ``deconv=False`` -- confirm for ``deconv=True``.

    Args:
        x: Input tensor.
        filters: Number of output channels.
        activation: Callable applied last.
        kernel_size: Kernel size of the convolution following ``UpSampling3D``.
        strides: Stride of that convolution.
        padding: Padding mode of the (transposed) convolution.
        deconv: Use a stride-2 ``Conv3DTranspose`` instead of
            ``UpSampling3D`` + ``Conv3D``.

    Returns:
        The upsampled tensor.
    """
    if not deconv:
        upsampled = UpSampling3D(size=2)(x)
        upsampled = Conv3D(filters, kernel_size, strides=strides, padding=padding)(upsampled)
    else:
        upsampled = Conv3DTranspose(filters, 2, strides=2, padding=padding)(x)
    normalized = GroupNorm()(upsampled)
    return activation(normalized)
# unet3d model
def Unet3d(img_shape, n_filters, n_class):
    """Build the 3-D U-Net: four pooled encoder stages, a bottleneck, and four
    decoder stages with skip connections, ending in a per-voxel softmax.

    Args:
        img_shape: Input shape without the batch dimension.
        n_filters: Channel count of the first stage; deeper stages use
            2x, 4x, 8x and (bottleneck) 16x this value.
        n_class: Number of output classes.

    Returns:
        The Keras ``Model`` mapping the input volume to class probabilities.
    """
    # (channel multiplier, number of conv layers) for each encoder stage;
    # the decoder mirrors it in reverse.
    stages = [(1, 1), (2, 2), (4, 3), (8, 3)]

    inputs = Input(shape=img_shape, name='input')
    x = inputs
    skips = []
    for mult, depth in stages:
        x = conv_res_block(x, n_filters * mult, num_layers=depth)
        skips.append(x)
        x = MaxPool3D()(x)

    x = conv_res_block(x, n_filters * 16, num_layers=3)

    for (mult, depth), skip in zip(reversed(stages), reversed(skips)):
        x = upsample_block(x, n_filters * mult)
        x = conv_res_block(Concatenate()([x, skip]), n_filters * mult, num_layers=depth)

    out = Conv3D(n_class, 1, padding='same', activation=keras.activations.softmax)(x)
    return Model(inputs=inputs, outputs=out, name='output')
# loss function
def explog_loss(y_true, y_pred, n_class, weights=1., w_d=0.8, w_c=0.2, g_d=0.3,
                g_c=0.3, eps=1e-5):
    """Exponential-logarithmic loss: weighted sum of a Dice term and a
    weighted cross-entropy term, each raised to its own exponent.

    Args:
        y_true: Integer ground truth of shape (batch, depth, height, width).
        y_pred: Class probabilities of shape (batch, depth, height, width, n_class).
        n_class: Number of classes.
        weights: Per-class weights of the cross-entropy term; a scalar or a
            vector of length ``n_class``.
        w_d: Weight of the Dice term.
        w_c: Weight of the cross-entropy term.
        g_d: Exponent applied to ``-log(dice)``.
        g_c: Exponent applied to ``-log(probability)``.
        eps: Smoothing / clipping constant.

    Returns:
        Scalar tensor ``w_d * Ld + w_c * Lc``.
    """
    # Flatten both tensors to (voxels, n_class).
    probs = tf.reshape(tf.cast(y_pred, tf.float32), [-1, n_class])
    onehot = tf.reshape(tf.cast(tf.one_hot(y_true, n_class), tf.float32), [-1, n_class])
    # Keep probabilities away from 0 and 1 so the logarithm stays finite.
    probs = tf.clip_by_value(probs, eps, 1.0 - eps)

    # Soft Dice per class, accumulated over every voxel of the batch.
    overlap = tf.reduce_sum(onehot * probs, axis=0)
    total = tf.reduce_sum(onehot, axis=0) + tf.reduce_sum(probs, axis=0)
    dice = tf.clip_by_value((2 * overlap + eps) / (total + eps), eps, 1.0 - eps)
    dice_term = tf.reduce_mean(tf.pow(-tf.math.log(dice), g_d))

    # Weighted cross entropy, averaged over all (voxel, class) entries.
    ce_term = tf.reduce_mean(weights * onehot * tf.pow(-tf.math.log(probs), g_c))

    return w_d * dice_term + w_c * ce_term
# metrics
def dice_score(y_true, y_pred, n_class, exp=1e-5):
    """Mean Dice score over the foreground classes ``1 .. n_class - 1``.

    ``y_pred`` holds per-class scores; the arg-max over its last axis is taken
    here. Dice is computed per sample and per class, then averaged over both.

    Args:
        y_true: Array-like ground truth of shape (batch, depth, height, width).
        y_pred: Array-like scores of shape (batch, depth, height, width, n_class).
        n_class: Number of classes, background (class 0) included.
        exp: Smoothing constant added to numerator and denominator, so a class
            absent from both label and prediction scores 1.

    Returns:
        The average Dice score, or NaN when there is no foreground class
        (``n_class < 2``).
    """
    y_true = np.asarray(y_true)
    y_pred = np.argmax(y_pred, axis=-1)
    dices = []
    for i in range(1, n_class):
        pred = y_pred == i
        label = y_true == i
        intersection = 2 * np.sum(label & pred, axis=(1, 2, 3)) + exp
        union = np.sum(label, axis=(1, 2, 3)) + np.sum(pred, axis=(1, 2, 3)) + exp
        dices.append(intersection / union)
    if not dices:
        # Same value np.mean([]) would give, without its RuntimeWarning.
        return float('nan')
    return np.mean(dices)
# training process
def train(display_step=10):
    """Train the 3-D U-Net on Task01_BrainTumour, checkpointing every epoch.

    The first 80% of the cases listed under ``training`` in ``dataset.json``
    are used for optimization and the remainder for validation. Running
    averages of loss and Dice are printed every ``display_step`` steps, and a
    checkpoint named after the epoch and its validation Dice is written at the
    end of each epoch.

    Args:
        display_step: Number of steps between progress printouts.
    """
    batch_size = 1
    epochs = 500
    input_size = [32, 160, 160, 4]
    n_class = 4
    first_channels = 8
    lr = 0.001
    save_model_dir = '../saved_models/'
    data_dir = '../data/Task01_BrainTumour/'

    # Create the checkpoint directory up front: a missing directory would
    # otherwise only surface at the first save, after a full epoch of work.
    os.makedirs(save_model_dir, exist_ok=True)

    with open(os.path.join(data_dir, 'dataset.json'), 'r') as f:
        data_info = json.load(f)
    path_list = data_info['training']
    n_train = int(len(path_list) * 0.8)
    train_path_list = path_list[:n_train]
    val_path_list = path_list[n_train:]

    model = Unet3d(input_size, first_channels, n_class)
    model_input = model.input
    pred = model.output
    label = tf.placeholder(tf.int32, shape=[None] + input_size[:3])
    loss_tf = explog_loss(label, pred, n_class, weights=[1, 10, 20, 20])

    global_step = tf.Variable(0, name='global_step', trainable=False)
    lr_schedule = tf.train.exponential_decay(
        lr,
        global_step,
        decay_steps=5000,
        decay_rate=0.98)
    optimizer = tf.train.AdamOptimizer(learning_rate=lr_schedule)
    train_opt = optimizer.minimize(loss_tf, global_step=global_step)
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()

    print("start optimize")
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(init_op)
        for epoch in range(epochs):
            # training
            print('*' * 20, 'Train Epoch %d' % epoch, '*' * 20)
            steps = 0
            train_loss_avg = 0
            train_dice_avg = 0
            train_dataset = data_generator(data_dir, train_path_list, input_size[:3], batch_size, True)
            for x, y in train_dataset:
                _, loss, pred_logits = sess.run([train_opt, loss_tf, pred],
                                                feed_dict={model_input: x, label: y})
                dice = dice_score(y, pred_logits, n_class)
                train_dice_avg += dice
                train_loss_avg += loss
                steps += 1
                if steps % display_step == 0:
                    print('epoch %d, steps %d, train loss: %.4f, train dice: %.4f' % (
                        epoch, steps, train_loss_avg / steps, train_dice_avg / steps))
            # max(..., 1) keeps an epoch that yielded no batch from dividing
            # by zero.
            train_loss_avg /= max(steps, 1)
            train_dice_avg /= max(steps, 1)

            # validation
            print('*' * 20, 'Valid Epoch %d' % epoch, '*' * 20)
            steps = 0
            val_loss_avg = 0
            val_dice_avg = 0
            val_dataset = data_generator(data_dir, val_path_list, input_size[:3], batch_size, False)
            for x, y in val_dataset:
                val_loss, pred_logits = sess.run([loss_tf, pred],
                                                 feed_dict={model_input: x, label: y})
                dice = dice_score(y, pred_logits, n_class)
                val_dice_avg += dice
                steps += 1
                val_loss_avg += val_loss
                if steps % display_step == 0:
                    print('Epoch {:}, valid steps {:}, loss={:.4f}'.format(epoch, steps, val_loss))
            val_loss_avg /= max(steps, 1)
            val_dice_avg /= max(steps, 1)
            print(
                'epoch %d, steps %d, validation loss: %.4f, val dice: %.4f' % (
                    epoch, steps, val_loss_avg, val_dice_avg))
            print('*' * 20, 'Valid Epoch %d' % epoch, '*' * 20)

            # save model
            saver.save(sess, os.path.join(save_model_dir, "epoch_%d_%.4f_model" % (epoch, val_dice_avg)),
                       write_meta_graph=False)
def save_img(img_np, save_path):
    """Write the array ``img_np`` to ``save_path`` as an image via SimpleITK."""
    sitk.WriteImage(sitk.GetImageFromArray(img_np), save_path)
def recover_img(patch_preds, pos_list, strides, ori_shape):
    """Paste per-patch predictions back into a full-size volume.

    A patch at grid position ``(i, j, k)`` is written at offset
    ``(i * sd, j * sh + 20, k * sw + 20)``; the 20-voxel shift on axes 1 and 2
    matches the crop applied by ``extract_ordered_overlap_patches``. Where
    patches overlap, the later patch overwrites the earlier one. Voxels not
    covered by any patch stay 0.

    Args:
        patch_preds: Sequence of 3-D prediction patches of identical shape.
        pos_list: Grid position ``(i, j, k)`` of every patch.
        strides: ``(sd, sh, sw)`` strides used when the patches were extracted.
        ori_shape: Shape of the volume to rebuild.

    Returns:
        Array of shape ``ori_shape``; all zeros when there are no patches.
    """
    img = np.zeros(ori_shape)
    if len(patch_preds) == 0:
        # No prediction to paste; reading patch_preds[0] would raise.
        return img
    sd, sh, sw = strides
    pd, ph, pw = patch_preds[0].shape
    for patch, (i, j, k) in zip(patch_preds, pos_list):
        img[i * sd:i * sd + pd,
            j * sh + 20:j * sh + 20 + ph,
            k * sw + 20:k * sw + 20 + pw] = patch
    return img
def predict(model_path, patch_list, input_size, first_channels, n_class):
    """Restore a checkpoint and predict a label map for every patch.

    Args:
        model_path: Checkpoint path prefix to restore.
        patch_list: Patches of shape ``input_size``, fed one at a time.
        input_size: Shape of one patch, without the batch dimension.
        first_channels: Channel count of the first U-Net stage.
        n_class: Number of classes.

    Returns:
        List with one arg-max label array (batch axis squeezed) per patch.
    """
    inputs = tf.placeholder(tf.float32, shape=(1,) + tuple(input_size))
    prediction = Unet3d(input_size, first_channels, n_class)(inputs)
    saver = tf.train.Saver()
    init_op = tf.global_variables_initializer()

    def label_map(sess, patch):
        # Add the batch axis, run the network, keep the most likely class.
        probs = sess.run(prediction, feed_dict={inputs: np.expand_dims(patch, 0)})
        return np.squeeze(np.argmax(probs, -1))

    with tf.Session() as sess:
        sess.run(init_op)
        saver.restore(sess, model_path)
        preds = [label_map(sess, patch) for patch in patch_list]
    return preds
def test(model_path, img_path, save_dir):
    """Segment one image with a trained checkpoint and save the prediction.

    The image is read, normalized, cut into patches, predicted patch by patch,
    reassembled to the original spatial shape and written to ``save_dir``
    under the image's own file name.

    Args:
        model_path: Checkpoint path prefix to restore.
        img_path: Path of the image to segment.
        save_dir: Output directory; created (including parents) if missing.

    Returns:
        ``(img0, pred)``: the array returned by ``read_img`` and the
        reassembled prediction volume.
    """
    patch_size = [32, 160, 160, 4]
    n_class = 4
    first_channels = 8
    strides = (16, 32, 32)

    # makedirs handles nested paths and an already existing directory,
    # which os.mkdir after an exists() check does not.
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, os.path.basename(img_path))

    img0 = read_img(img_path)
    img1 = preprocess(img0)
    patch_list, pos_list = extract_ordered_overlap_patches(img1, None, patch_size[:3])
    preds = predict(model_path, patch_list, patch_size, first_channels, n_class)
    pred = recover_img(preds, pos_list, strides, img0.shape[:3])
    save_img(pred, save_path)
    return img0, pred
if __name__ == '__main__':
    # Toggle the two stages by hand: training is enabled, inference on a
    # saved checkpoint is disabled.
    run_training = True
    run_inference = False

    if run_training:
        train()

    if run_inference:
        model_path = '../saved_models/epoch_20_0.6411_model'
        img_path = '../data/Task01_BrainTumour/imagesTs/BRATS_485.nii.gz'
        save_dir = '../data/Task01_BrainTumour/imagesPre/'
        img, pred = test(model_path, img_path, save_dir)
        show_imgs([(img, pred)])
模型训练过程中,会输出当前的进度,具体如下:
训练的时间较长,需要耐心等待。
训练过程中,可以查看生成的文件:
六、模型的测试
模型测试部分的代码如下:
import os
import SimpleITK as sitk
import matplotlib.pyplot as plt
import numpy as np
import json
from skimage import measure
import random
# import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
tf.logging.set_verbosity(tf.logging.ERROR)
from tensorflow.keras.layers import Layer
from tensorflow import keras
from tensorflow.keras.layers import Conv3D, Conv3DTranspose, ReLU, UpSampling3D, Input, MaxPool3D, Concatenate
from tensorflow.keras import Model
os.environ["CUDA_VISIBLE_DEVICES"]="3"
def z_score_norm(img, mean=None, std=None):
    """Standardize ``img`` to zero mean and unit variance (z-score).

    Args:
        img: Array-like of intensities. It is converted with ``np.asarray``,
            so plain (nested) lists are accepted as well as ndarrays.
        mean: Value to subtract; computed from ``img`` when ``None``.
        std: Value to divide by; computed from ``img`` when ``None``.

    Returns:
        ``(img - mean) / (std + 1e-6)`` as an ndarray.
    """
    img = np.asarray(img)
    if mean is None:
        mean = np.mean(img)
    if std is None:
        std = np.std(img)
    # The small constant keeps a constant-valued input (std == 0) from
    # dividing by zero.
    return (img - mean) / (std + 1e-6)
def show_imgs(images, cols=5):
    """Display one row of slices per ``(image, label)`` pair.

    For each pair the slice shown is the median of the axis-0 indices at which
    ``label`` is non-zero; when the label has no non-zero voxel the central
    slice is shown instead. The first ``cols - 1`` columns show
    ``image[slice, :, :, j]`` and the last column shows ``label[slice, :, :]``.

    Args:
        images: Sequence of ``(image, label)`` pairs.
        cols: Number of columns per row (titles exist for up to 5 columns).
    """
    rows = len(images)
    titles = ['FLAIR', 'T1w', 'T1gd', 'T2w', 'Label']
    # squeeze=False always returns a 2-D array of axes, so ravel() is valid
    # even for a 1x1 grid.
    f, axes = plt.subplots(max(1, rows), cols, squeeze=False)
    axes_ravel = axes.ravel()
    for i, (image, label) in enumerate(images):
        label = np.asarray(label)
        fg_slices = np.where(label != 0)[0]
        if fg_slices.size:
            fg_slices.sort()
            slice_idx = fg_slices[len(fg_slices) // 2]
        else:
            # An all-background label (e.g. an empty prediction) has no
            # foreground slice to pick; fall back to the central slice
            # instead of raising IndexError.
            slice_idx = label.shape[0] // 2
        for j in range(cols - 1):
            ax = axes_ravel[i * cols + j]
            ax.set_axis_off()
            ax.imshow(image[slice_idx, :, :, j], cmap='Greys_r')
            ax.set_title(titles[j])
        label_ax = axes_ravel[(i + 1) * cols - 1]
        label_ax.set_axis_off()
        label_ax.imshow(label[slice_idx, :, :], cmap='Greys_r')
        label_ax.set_title(titles[-1])
    f.tight_layout()
    plt.subplots_adjust(wspace=0.01, hspace=0)
    plt.show()
def read_img(img_path, label_path=None):
    """Load an image (and optionally its label) from disk as NumPy arrays.

    The image array obtained from SimpleITK has its first axis moved to
    position 3 (presumably the modality/channel axis -- verify against the
    dataset layout). The label array is returned as SimpleITK provides it.

    Args:
        img_path: Path of the image file.
        label_path: Path of the label file, or ``None`` to load the image only.

    Returns:
        ``img_np`` when ``label_path`` is ``None``, else ``(img_np, label_np)``.
    """
    volume = np.moveaxis(sitk.GetArrayFromImage(sitk.ReadImage(img_path)), 0, 3)
    if label_path is None:
        return volume
    mask = sitk.GetArrayFromImage(sitk.ReadImage(label_path))
    return volume, mask
def extract_ordered_overlap_patches(img, label, patch_size, s=16):
    """Cut a volume into overlapping patches laid out on a regular grid.

    ``img`` (and ``label``) are first cropped by 20 voxels at both ends of
    axes 1 and 2. Patches are then taken with stride ``s`` along axis 0 and
    ``2 * s`` along axes 1 and 2. Voxels past the last full stride are not
    covered by any patch.

    Args:
        img: 4-D array indexed ``[d, h, w, channel]``.
        label: 3-D array aligned with ``img``, or ``None``. When given, a patch
            is kept only if it holds at least 20% of all non-zero voxels of
            the cropped label, and is returned as ``(patch_img, patch_label)``.
        patch_size: ``(patch_d, patch_h, patch_w)``.
        s: Base stride.

    Returns:
        ``(patch_list, pos_list)`` where ``pos_list[n]`` is the grid index
        ``(i, j, k)`` of ``patch_list[n]``. Both are empty when ``label`` is
        given but contains no non-zero voxel after cropping.
    """
    img = img[:, 20:-20, 20:-20, :]
    d, h, w, _ = img.shape
    patch_d, patch_h, patch_w = patch_size
    stride_d, stride_h, stride_w = s, s * 2, s * 2

    patch_list = []
    pos_list = []
    total_fg = 0
    if label is not None:
        label = label[:, 20:-20, 20:-20]
        # Counted once: it is loop-invariant, and a label without any
        # foreground would otherwise cause a division by zero below.
        total_fg = np.count_nonzero(label)
        if total_fg == 0:
            return patch_list, pos_list

    for i in range((d - patch_d) // stride_d + 1):
        for j in range((h - patch_h) // stride_h + 1):
            for k in range((w - patch_w) // stride_w + 1):
                region = (
                    slice(i * stride_d, min(d, i * stride_d + patch_d)),
                    slice(j * stride_h, min(h, j * stride_h + patch_h)),
                    slice(k * stride_w, min(w, k * stride_w + patch_w)),
                )
                patch_img = img[region]
                if label is None:
                    patch_list.append(patch_img)
                    pos_list.append((i, j, k))
                    continue
                patch_label = label[region]
                if patch_label.shape != tuple(patch_size):
                    continue
                if np.count_nonzero(patch_label) / total_fg >= 0.2:
                    patch_list.append((patch_img, patch_label))
                    pos_list.append((i, j, k))
    return patch_list, pos_list
# pre-process images
def preprocess(image):
    """Z-score normalize every 2-D slice of every channel independently.

    Each ``image[z, :, :, c]`` slice is shifted by its own mean and divided by
    its own standard deviation plus ``1e-6``.

    Args:
        image: 4-D array-like indexed ``[z, y, x, channel]``.

    Returns:
        A new floating-point array of the same shape. The input is left
        untouched.
    """
    image = np.asarray(image)
    # Work on a float copy: writing normalized values back into an integer
    # array would silently truncate them, and callers may still need the raw
    # intensities (e.g. for display).
    image = image.astype(np.result_type(image.dtype, np.float32))
    # Statistics over the two in-plane axes give one mean/std per (z, channel),
    # which is the per-slice z-score applied to all slices at once.
    mean = image.mean(axis=(1, 2), keepdims=True)
    std = image.std(axis=(1, 2), keepdims=True)
    return (image - mean) / (std + 1e-6)
# data loader
def data_generator(data_dir, path_list, target_shape, batch_size, is_training, buffer_size=8):
    """Yield ``(x, y)`` mini-batches of image/label patches.

    Cases are processed ``buffer_size`` at a time: each case is read,
    normalized and cut into patches, the patches of the whole buffer are
    pooled (and shuffled when training), then emitted in batches of
    ``batch_size``. Patches left over after the last full batch of a buffer
    are dropped.

    Args:
        data_dir: Dataset root; it replaces the ``'./'`` in every entry path.
        path_list: List of dicts with ``'image'`` and ``'label'`` paths.
            The caller's list is never reordered.
        target_shape: Patch size ``(d, h, w)``.
        batch_size: Number of patches per yielded batch.
        is_training: Shuffle cases and patches when true; when false the
            buffer holds a single case and order is kept.
        buffer_size: Maximum number of cases pooled together when training.

    Yields:
        ``(x, y)`` arrays whose first dimension is ``batch_size``.
    """
    if is_training:
        # Shuffle a copy so the caller's list keeps its order.
        path_list = list(path_list)
        random.shuffle(path_list)
        buffer_size = min(len(path_list), buffer_size)
    else:
        buffer_size = 1
    if not path_list or buffer_size < 1:
        # Nothing to load; also avoids a zero step below.
        return

    # Stepping by buffer_size (instead of floor-dividing the case count)
    # keeps the trailing, partially filled buffer.
    for start in range(0, len(path_list), buffer_size):
        data_list = []
        for entry in path_list[start:start + buffer_size]:
            img_path = entry['image'].replace('./', data_dir)
            label_path = entry['label'].replace('./', data_dir)
            if not os.path.exists(img_path) or not os.path.exists(label_path):
                continue
            img, label = read_img(img_path, label_path)
            img = preprocess(img)
            patch_list, _ = extract_ordered_overlap_patches(img, label, target_shape)
            data_list += patch_list
        if not data_list:
            continue

        X = np.array([it[0] for it in data_list])
        Y = np.array([it[1] for it in data_list])
        if is_training:
            index = np.random.permutation(len(data_list))
            X = X[index, ...]
            Y = Y[index, ...]
        # Every full batch is emitted; the previous ``- 1`` skipped the last
        # one and produced nothing at all for a buffer holding one batch.
        for step in range(X.shape[0] // batch_size):
            x = X[step * batch_size:(step + 1) * batch_size, ...]
            y = Y[step * batch_size:(step + 1) * batch_size, ...]
            yield x, y
class GroupNorm(Layer):
    """Group normalization for 5-D channels-last tensors ``(N, D, H, W, C)``.

    Channels are split into groups; each group is normalized per sample with
    its own mean and variance, then scaled and shifted by the learned
    per-channel ``gamma`` and ``beta``.

    Args:
        groups: Requested number of groups. A single group is used when the
            channel count is not divisible by it.
        **kwargs: Forwarded to the Keras ``Layer`` constructor.
    """

    def __init__(self, groups=4, **kwargs):
        super(GroupNorm, self).__init__(**kwargs)
        self.G = groups
        self.eps = 1e-5

    def build(self, input_shape):
        self.channel = input_shape[-1]
        # Fall back to one group when the channels cannot be split evenly.
        self.group = self.G if self.channel % self.G == 0 else 1
        self.group = min(self.channel, self.group)
        self.split = self.channel // self.group
        self.gamma = self.add_weight(name='gamma_gn', shape=(1, 1, 1, 1, input_shape[-1]),
                                     initializer='ones', trainable=True)
        self.beta = self.add_weight(name='beta_gn', shape=(1, 1, 1, 1, input_shape[-1]),
                                    initializer='zeros', trainable=True)

    def call(self, inputs):
        _, D, H, W, C = tf.keras.backend.int_shape(inputs)
        grouped = tf.reshape(inputs, [-1, D, H, W, self.group, self.split])
        # Statistics per sample and per group: reduce over the spatial axes
        # and over the channels inside each group.
        mean, var = tf.nn.moments(grouped, [1, 2, 3, 5])
        mean = tf.reshape(mean, [-1, 1, 1, 1, self.group, 1])
        var = tf.reshape(var, [-1, 1, 1, 1, self.group, 1])
        normalized = (grouped - mean) / tf.sqrt(var + self.eps)
        # The normalized tensor must be the one reshaped and returned; the
        # previous code reshaped the raw input here, which turned the layer
        # into a plain affine transform.
        return tf.reshape(normalized, [-1, D, H, W, C]) * self.gamma + self.beta
def conv_res_block(x, filters, activation=ReLU(), kernel_size=3, strides=1, padding='same', num_layers=1):
    """Convolution block, with a residual shortcut when ``num_layers > 1``.

    With ``num_layers == 1`` the block is conv -> GroupNorm -> activation.
    Otherwise a 1x1x1 conv -> GroupNorm -> activation shortcut is added to the
    output of ``num_layers`` stacked conv -> activation layers.

    NOTE(review): the source listing lost its indentation; the loop body is
    assumed to be the convolution plus its activation, with the shortcut added
    once after the loop -- confirm against the book's code.

    Args:
        x: Input tensor.
        filters: Number of output channels of every convolution.
        activation: Callable applied after each convolution.
        kernel_size: Kernel size of the main-path convolutions.
        strides: Stride of every convolution, including the shortcut.
        padding: Padding mode of every convolution.
        num_layers: Number of main-path convolutions.

    Returns:
        The output tensor of the block.
    """
    def conv(tensor, size):
        return Conv3D(filters, size, strides=strides, padding=padding)(tensor)

    if num_layers == 1:
        features = conv(x, kernel_size)
        features = GroupNorm()(features)
        return activation(features)

    projected = conv(x, 1)
    projected = GroupNorm()(projected)
    shortcut = activation(projected)
    for _ in range(num_layers):
        x = activation(conv(x, kernel_size))
    return x + shortcut
def upsample_block(x, filters, activation=ReLU(), kernel_size=3, strides=1,
                   padding='same', deconv=False):
    """Double the spatial resolution, then apply GroupNorm and activation.

    NOTE(review): the source listing lost its indentation; the ``Conv3D`` is
    assumed to belong to the ``UpSampling3D`` branch only. Both readings are
    identical for the default ``deconv=False`` -- confirm for ``deconv=True``.

    Args:
        x: Input tensor.
        filters: Number of output channels.
        activation: Callable applied last.
        kernel_size: Kernel size of the convolution following ``UpSampling3D``.
        strides: Stride of that convolution.
        padding: Padding mode of the (transposed) convolution.
        deconv: Use a stride-2 ``Conv3DTranspose`` instead of
            ``UpSampling3D`` + ``Conv3D``.

    Returns:
        The upsampled tensor.
    """
    if not deconv:
        upsampled = UpSampling3D(size=2)(x)
        upsampled = Conv3D(filters, kernel_size, strides=strides, padding=padding)(upsampled)
    else:
        upsampled = Conv3DTranspose(filters, 2, strides=2, padding=padding)(x)
    normalized = GroupNorm()(upsampled)
    return activation(normalized)
# unet3d model
def Unet3d(img_shape, n_filters, n_class):
    """Build the 3-D U-Net: four pooled encoder stages, a bottleneck, and four
    decoder stages with skip connections, ending in a per-voxel softmax.

    Args:
        img_shape: Input shape without the batch dimension.
        n_filters: Channel count of the first stage; deeper stages use
            2x, 4x, 8x and (bottleneck) 16x this value.
        n_class: Number of output classes.

    Returns:
        The Keras ``Model`` mapping the input volume to class probabilities.
    """
    # (channel multiplier, number of conv layers) for each encoder stage;
    # the decoder mirrors it in reverse.
    stages = [(1, 1), (2, 2), (4, 3), (8, 3)]

    inputs = Input(shape=img_shape, name='input')
    x = inputs
    skips = []
    for mult, depth in stages:
        x = conv_res_block(x, n_filters * mult, num_layers=depth)
        skips.append(x)
        x = MaxPool3D()(x)

    x = conv_res_block(x, n_filters * 16, num_layers=3)

    for (mult, depth), skip in zip(reversed(stages), reversed(skips)):
        x = upsample_block(x, n_filters * mult)
        x = conv_res_block(Concatenate()([x, skip]), n_filters * mult, num_layers=depth)

    out = Conv3D(n_class, 1, padding='same', activation=keras.activations.softmax)(x)
    return Model(inputs=inputs, outputs=out, name='output')
# loss function
def explog_loss(y_true, y_pred, n_class, weights=1., w_d=0.8, w_c=0.2, g_d=0.3,
                g_c=0.3, eps=1e-5):
    """Exponential-logarithmic loss: weighted sum of a Dice term and a
    weighted cross-entropy term, each raised to its own exponent.

    Args:
        y_true: Integer ground truth of shape (batch, depth, height, width).
        y_pred: Class probabilities of shape (batch, depth, height, width, n_class).
        n_class: Number of classes.
        weights: Per-class weights of the cross-entropy term; a scalar or a
            vector of length ``n_class``.
        w_d: Weight of the Dice term.
        w_c: Weight of the cross-entropy term.
        g_d: Exponent applied to ``-log(dice)``.
        g_c: Exponent applied to ``-log(probability)``.
        eps: Smoothing / clipping constant.

    Returns:
        Scalar tensor ``w_d * Ld + w_c * Lc``.
    """
    # Flatten both tensors to (voxels, n_class).
    probs = tf.reshape(tf.cast(y_pred, tf.float32), [-1, n_class])
    onehot = tf.reshape(tf.cast(tf.one_hot(y_true, n_class), tf.float32), [-1, n_class])
    # Keep probabilities away from 0 and 1 so the logarithm stays finite.
    probs = tf.clip_by_value(probs, eps, 1.0 - eps)

    # Soft Dice per class, accumulated over every voxel of the batch.
    overlap = tf.reduce_sum(onehot * probs, axis=0)
    total = tf.reduce_sum(onehot, axis=0) + tf.reduce_sum(probs, axis=0)
    dice = tf.clip_by_value((2 * overlap + eps) / (total + eps), eps, 1.0 - eps)
    dice_term = tf.reduce_mean(tf.pow(-tf.math.log(dice), g_d))

    # Weighted cross entropy, averaged over all (voxel, class) entries.
    ce_term = tf.reduce_mean(weights * onehot * tf.pow(-tf.math.log(probs), g_c))

    return w_d * dice_term + w_c * ce_term
# metrics
def dice_score(y_true, y_pred, n_class, exp=1e-5):
    """
    Compute the mean foreground Dice score from labels and un-argmaxed predictions.

    Class 0 (background) is excluded. A Dice value is computed per sample and
    per foreground class, and all of those values are averaged together.
    Args:
        y_true: integer ground truth with dimension of (batch, depth, height, width);
            any (batch, *spatial) layout is accepted
        y_pred: class scores with dimension of (batch, depth, height, width, n_class);
            argmax is taken over the last axis
        n_class: number of classes, background included
        exp: smoothing constant added to numerator and denominator, so a class
            absent from both label and prediction scores 1
    Returns:
        score: mean Dice over samples and foreground classes (nan when
            n_class <= 1, because there is no foreground class to average)
    """
    y_true = np.asarray(y_true)
    hard_pred = np.argmax(y_pred, axis=-1)
    # Reduce over every axis except the batch axis; for the documented 4D
    # layout this is exactly (1, 2, 3).
    reduce_axes = tuple(range(1, hard_pred.ndim))

    dices = []
    for cls in range(1, n_class):
        pred = hard_pred == cls
        label = y_true == cls
        numerator = 2 * np.sum(np.logical_and(label, pred), axis=reduce_axes) + exp
        denominator = np.sum(label, axis=reduce_axes) + np.sum(pred, axis=reduce_axes) + exp
        dices.append(numerator / denominator)
    return np.mean(dices)
# training process
def train():
    """
    Train the 3D U-Net on Task01_BrainTumour and save a checkpoint every epoch.

    The entries under 'training' in dataset.json are split 80/20 (in file
    order) into training and validation lists. Each epoch runs one pass over
    the training generator with Adam and an exponentially decaying learning
    rate, then one pass over the validation generator, and finally writes a
    checkpoint whose file name carries the epoch index and validation Dice.
    """
    batch_size = 1
    epochs = 500
    input_size = [32, 160, 160, 4]
    n_class = 4
    first_channels = 8
    lr = 0.001
    save_model_dir = '../saved_models/'
    data_dir = '../data/Task01_BrainTumour/'

    # Make sure the checkpoint directory exists up front, so that a whole
    # epoch of training is not lost when the first save is attempted.
    os.makedirs(save_model_dir, exist_ok=True)

    with open(os.path.join(data_dir, 'dataset.json'), 'r') as f:
        data_info = json.load(f)
    path_list = data_info['training']
    n_train = int(len(path_list) * 0.8)
    train_path_list = path_list[:n_train]
    val_path_list = path_list[n_train:]

    model = Unet3d(input_size, first_channels, n_class)
    input_tensor = model.input  # avoid shadowing the builtin ``input``
    pred = model.output
    label = tf.placeholder(tf.int32, shape=[None] + input_size[:3])
    loss_tf = explog_loss(label, pred, n_class, weights=[1, 10, 20, 20])

    global_step = tf.Variable(0, name='global_step', trainable=False)
    lr_schedule = tf.train.exponential_decay(
        lr,
        global_step,
        decay_steps=5000,
        decay_rate=0.98)
    optimizer = tf.train.AdamOptimizer(learning_rate=lr_schedule)
    train_opt = optimizer.minimize(loss_tf, global_step=global_step)
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(init_op)
        for epoch in range(epochs):
            # training
            steps = 0
            train_loss_avg = 0
            train_dice_avg = 0
            # NOTE(review): the trailing True/False flag presumably toggles
            # shuffling/augmentation in data_generator - confirm there.
            train_dataset = data_generator(data_dir, train_path_list, input_size[:3], batch_size, True)
            for x, y in train_dataset:
                # The model output is already softmax-activated, so these
                # are class probabilities.
                _, loss, pred_probs = sess.run(
                    [train_opt, loss_tf, pred], feed_dict={input_tensor: x, label: y})
                dice = dice_score(y, pred_probs, n_class)
                train_dice_avg += dice
                train_loss_avg += loss
                steps += 1
                # Running averages over the steps seen so far in this epoch.
                print('epoch %d, steps %d, train loss: %.4f, train dice: %.4f' % (
                    epoch, steps, train_loss_avg / steps, train_dice_avg / steps))
            train_loss_avg /= steps
            train_dice_avg /= steps

            # validation
            steps = 0
            val_loss_avg = 0
            val_dice_avg = 0
            val_dataset = data_generator(data_dir, val_path_list, input_size[:3], batch_size, False)
            for x, y in val_dataset:
                val_loss, pred_probs = sess.run(
                    [loss_tf, pred], feed_dict={input_tensor: x, label: y})
                dice = dice_score(y, pred_probs, n_class)
                val_dice_avg += dice
                steps += 1
                val_loss_avg += val_loss
            val_loss_avg /= steps
            val_dice_avg /= steps
            print(
                'epoch %d, steps %d, validation loss: %.4f, val dice: %.4f' % (epoch, steps, val_loss_avg, val_dice_avg))

            # save model
            saver.save(sess, os.path.join(save_model_dir, "epoch_%d_%.4f_model" % (epoch, val_dice_avg)),
                       write_meta_graph=False)
def save_img(img_np, save_path):
    """
    Convert a numpy array to a SimpleITK image and write it to disk.
    Args:
        img_np: numpy array holding the image data
        save_path: destination file path handed to SimpleITK
    """
    sitk.WriteImage(sitk.GetImageFromArray(img_np), save_path)
def recover_img(patch_preds, pos_list, strides, ori_shape, offsets=(0, 20, 20)):
    """
    Paste predicted patches back into a full-size volume.

    Patches are written in order, so where patches overlap the later one
    overwrites the earlier one.
    Args:
        patch_preds: sequence of 3D patch predictions, all shaped like the first
        pos_list: sequence of (i, j, k) patch grid indices, one per patch
        strides: (depth, height, width) step between neighbouring patches
        ori_shape: shape of the volume to reconstruct
        offsets: (depth, height, width) shift added to every patch origin.
            The default keeps the 20-voxel in-plane shift that used to be
            hard-coded here - presumably the margin cropped during patch
            extraction; confirm against extract_ordered_overlap_patches.
    Returns:
        img: array of shape ori_shape, zero wherever no patch was written
    """
    img = np.zeros(ori_shape)
    # Nothing to paste: return the empty volume instead of failing on
    # patch_preds[0].
    if len(patch_preds) == 0:
        return img

    sd, sh, sw = strides
    od, oh, ow = offsets
    pd, ph, pw = patch_preds[0].shape
    for patch, (i, j, k) in zip(patch_preds, pos_list):
        d0 = i * sd + od
        h0 = j * sh + oh
        w0 = k * sw + ow
        img[d0:d0 + pd, h0:h0 + ph, w0:w0 + pw] = patch
    return img
def predict(model_path, patch_list, input_size, first_channels, n_class):
    """
    Restore a checkpoint and predict a label map for every patch.
    Args:
        model_path: checkpoint path passed to the saver's restore
        patch_list: sequence of patches, each shaped like input_size
        input_size: shape of one patch, without the batch axis
        first_channels: filter count handed to Unet3d
        n_class: number of output classes
    Returns:
        preds: list with one argmax label array per patch, singleton axes
            squeezed away
    """
    inputs = tf.placeholder(tf.float32, shape=(1,) + tuple(input_size))
    network = Unet3d(input_size, first_channels, n_class)
    prediction = network(inputs)
    saver = tf.train.Saver()
    init_op = tf.global_variables_initializer()

    preds = []
    with tf.Session() as sess:
        sess.run(init_op)
        saver.restore(sess, model_path)
        for patch in patch_list:
            # The placeholder has a fixed batch size of 1, so patches are
            # fed one at a time.
            batch = np.expand_dims(patch, 0)
            probs = sess.run(prediction, feed_dict={inputs: batch})
            labels = np.argmax(probs, -1)
            preds.append(np.squeeze(labels))
    return preds
def test(model_path, img_path, save_dir):
    """
    Segment one image with a trained checkpoint and save the prediction.

    The image is read, preprocessed, cut into patches, predicted patch by
    patch, stitched back to the shape of the original image and written to
    save_dir under the input file's name.
    Args:
        model_path: checkpoint path to restore
        img_path: path of the image to segment
        save_dir: directory for the predicted label map; created if missing
    Returns:
        (img0, pred): the image as read from disk and the stitched prediction
    """
    patch_size = [32, 160, 160, 4]
    n_class = 4
    first_channels = 8
    # NOTE(review): these strides must agree with the ones used inside
    # extract_ordered_overlap_patches - confirm against that function.
    strides = (16, 32, 32)

    # makedirs with exist_ok handles nested paths and avoids the
    # check-then-create race of os.path.exists + os.mkdir.
    os.makedirs(save_dir, exist_ok=True)
    # basename copes with the platform's path separators, unlike split('/').
    save_path = os.path.join(save_dir, os.path.basename(img_path))

    img0 = read_img(img_path)
    img1 = preprocess(img0)
    patch_list, pos_list = extract_ordered_overlap_patches(img1, None, patch_size[:3])
    preds = predict(model_path, patch_list, patch_size, first_channels, n_class)
    pred = recover_img(preds, pos_list, strides, img0.shape[:3])
    save_img(pred, save_path)
    return img0, pred
if __name__ == '__main__':
    # Stage switches: enable training to produce checkpoints, inference to
    # segment one test image and display it.
    run_training = False
    run_inference = True

    if run_training:
        train()

    if run_inference:
        # model_path must point at a checkpoint written by train().
        model_path = '../saved_models/epoch_271_0.6776_model'
        img_path = '../data/Task01_BrainTumour/imagesTs/BRATS_485.nii.gz'
        save_dir = '../data/Task01_BrainTumour/imagesPre/'
        img, pred = test(model_path, img_path, save_dir)
        show_imgs([(img, pred)])
需要注意的是,model_path 需要根据模型训练阶段最终保存的模型文件(检查点)来填写,才能进行实际的测试。
在上述代码中,使用 BRATS_485.nii.gz 作为输入数据,对模型进行测试。
代码运行后,结果如下:
从上图中可以看到,与肿瘤相关的部分,被准确识别并分割出来了。
七、总结
通过5、6两章的学习和实战,才算叩开了医学图像深度学习处理的大门,了解到在医学图像处理方面,模型从数据准备到训练、测试的完整步骤,为后续的学习打下基础。