1.1 Fine-tuning là gì ?

Chắc hẳn phần lớn ai thao tác làm việc với các mã sản phẩm trong deep learning phần nhiều đã nghe/quen với quan niệm Transfer learningFine tuning. Khái niệm tổng quát: Transfer learning là tận dụng tri thức học được từ 1 vấn đề để áp dụng vào 1 vấn đề có tương quan khác. Một ví dụ đối chọi giản: thay vì chưng train 1 mã sản phẩm mới trọn vẹn cho việc phân các loại chó/mèo, bạn ta rất có thể tận dụng 1 mã sản phẩm đã được train trên ImageNet dataset với hằng triệu ảnh. Pre-trained mã sản phẩm này sẽ tiến hành train tiếp trên tập dataset chó/mèo, quá trình train này diễn ra nhanh hơn, công dụng thường tốt hơn. Có nhiều kiểu Transfer learning, các bạn cũng có thể tham khảo trong bài xích này: Tổng hòa hợp Transfer learning. Trong bài xích này, mình vẫn viết về 1 dạng transfer learning phổ biến: Fine-tuning.

Bạn đang xem: Fine-tune là gì

Bạn vẫn xem: Fine tune là gì

Đang xem: Fine tuning là gì

Hiểu solo giản, fine-tuning là các bạn lấy 1 pre-trained model, tận dụng một phần hoặc cục bộ các layer, thêm/sửa/xoá 1 vài ba layer/nhánh để tạo ra 1 mã sản phẩm mới. Thường các layer đầu của mã sản phẩm được freeze (đóng băng) lại – tức weight các layer này sẽ không bị thay đổi giá trị trong quy trình train. Vì sao bởi các layer này đã có khả năng trích xuất tin tức mức trìu tượng thấp , khả năng này được học từ quá trình training trước đó. Ta freeze lại nhằm tận dụng được kỹ năng này với giúp việc train ra mắt nhanh hơn (model chỉ phải update weight ở những layer cao). Có khá nhiều các Object detect mã sản phẩm được desgin dựa trên những Classifier model. VD Retina model (Object detect) được tạo ra với backbone là Resnet.


*

1.2 vì sao pytorch thay vày Keras ?

Chủ đề bài viết hôm nay, bản thân sẽ gợi ý fine-tuning Resnet50 – 1 pre-trained mã sản phẩm được hỗ trợ sẵn vào torchvision của pytorch. Nguyên nhân là pytorch mà chưa phải Keras ? tại sao bởi việc fine-tuning model trong keras rất đối chọi giản. Dưới đấy là 1 đoạn code minh hoạ cho việc xây dựng 1 Unet dựa trên Resnet vào Keras:

from tensorflow.keras import applicationsresnet = applications.resnet50.ResNet50()layer_3 = resnet.get_layer(“activation_9”).outputlayer_7 = resnet.get_layer(“activation_21”).outputlayer_13 = resnet.get_layer(“activation_39”).outputlayer_16 = resnet.get_layer(“activation_48”).output#Adding outputs decoder with encoder layersfcn1 = Conv2D(…)(layer_16)fcn2 = Conv2DTranspose(…)(fcn1)fcn2_skip_connected = Add()()fcn3 = Conv2DTranspose(…)(fcn2_skip_connected)fcn3_skip_connected = Add()()fcn4 = Conv2DTranspose(…)(fcn3_skip_connected)fcn4_skip_connected = Add()()fcn5 = Conv2DTranspose(…)(fcn4_skip_connected)Unet = Model(inputs = resnet.input, outputs=fcn5)Bạn rất có thể thấy, fine-tuning mã sản phẩm trong Keras đích thực rất đối chọi giản, dễ dàng làm, dễ dàng hiểu. Việc showroom thêm các nhánh rất đơn giản bởi cú pháp 1-1 giản. Trong pytorch thì ngược lại, kiến thiết 1 model Unet tương tự sẽ tương đối vất vả và phức tạp. Người mới học sẽ chạm mặt khó khăn vì chưng trên mạng ko nhiều những hướng dẫn cho câu hỏi này. Vậy nên bài xích này mình đang hướng dẫn cụ thể cách fine-tune vào pytorch để vận dụng vào việc Visual Saliency prediction

2. Visual Saliency prediction

2.1 What is Visual Saliency ?


*

Khi nhìn vào 1 bức ảnh, mắt thường xuyên có xu hướng tập trung nhìn vào 1 vài cửa hàng chính. Ảnh trên đấy là 1 minh hoạ, màu kim cương được sử dụng để thể hiện mức độ thu hút. Saliency prediction là câu hỏi mô rộp sự triệu tập của mắt người khi quan ngay cạnh 1 bức ảnh. Thay thể, bài bác toán yên cầu xây dựng 1 model, model này nhận ảnh đầu vào, trả về 1 mask mô bỏng mức độ thu hút. Như vậy, mã sản phẩm nhận vào 1 input đầu vào image với trả về 1 mask có kích thước tương đương.

Để rõ hơn về việc này, chúng ta có thể đọc bài: Visual Saliency Prediction with Contextual Encoder-Decoder Network.Dataset thông dụng nhất: SALICON DATASET

2.2 Unet

Note: Bạn hoàn toàn có thể bỏ qua phần này nếu sẽ biết về Unet

Đây là 1 bài toán Image-to-Image. Để xử lý bài toán này, mình sẽ xây dựng 1 model theo bản vẽ xây dựng Unet. Unet là 1 trong kiến trúc được sử dụng nhiều trong việc Image-to-image như: semantic segmentation, tự động color, super resolution … bản vẽ xây dựng của Unet tất cả điểm tựa như với bản vẽ xây dựng Encoder-Decoder đối xứng, được thêm những skip connection từ Encode thanh lịch Decode tương ứng. Về cơ bản, những layer càng tốt càng trích xuất thông tin ở nấc trìu tượng cao, điều ấy đồng nghĩa với việc các thông tin nút trìu tượng thấp như con đường nét, color sắc, độ phân giải… sẽ ảnh hưởng mất mát đi trong quy trình lan truyền. Fan ta thêm những skip-connection vào để giải quyết và xử lý vấn đề này.

Với phần Encode, feature-map được downscale bằng các Convolution. Ngược lại, tại đoạn decode, feature-map được upscale bởi những Upsampling layer, trong bài này mình sử dụng các Convolution Transpose.


*

2.3 Resnet

Để xử lý bài toán, mình sẽ xây dựng dựng mã sản phẩm Unet cùng với backbone là Resnet50. Bạn nên tìm hiểu về Resnet nếu chưa biết về kiến trúc này. Hãy quan gần kề hình minh hoạ dưới đây. Resnet50 được tạo thành các khối béo . Unet được xây dừng với Encoder là Resnet50. Ta sẽ kéo ra output của từng khối, tạo các skip-connection kết nối từ Encoder quý phái Decoder. Decoder được kiến tạo bởi những Convolution Transpose layer (xen kẽ trong những số ấy là các lớp Convolution nhằm mục đích giảm số chanel của feature bản đồ -> giảm số lượng weight mang lại model).

Theo ý kiến cá nhân, pytorch rất dễ dàng code, dễ nắm bắt hơn không hề ít so cùng với Tensorflow 1.x hoặc ngang ngửa Keras. Mặc dù nhiên, việc fine-tuning model trong pytorch lại cạnh tranh hơn tương đối nhiều so cùng với Keras. Vào Keras, ta không phải quá vồ cập tới loài kiến trúc, luồng cách xử lý của model, chỉ cần lấy ra những output tại 1 số layer nhất quyết làm skip-connection, ghép nối và tạo ra ra model mới.


*

3. Code

Tất cả code của chính mình được đóng gói trong file notebook Salicon_main.ipynb. Chúng ta có thể tải về với run code theo link github: github/trungthanhnguyen0502 . Trong nội dung bài viết mình đang chỉ chuyển ra hầu hết đoạn code chính.

Import các package

import albumentations as Aimport numpy as npimport torchimport torchvisionimport torch.nn as nn import torchvision.transforms as Timport torchvision.models as modelsfrom torch.utils.data import DataLoader, Datasetimport ….

3.1 utils functions

Trong pytorch, tài liệu có thứ tự dimension không giống với Keras/TF/numpy. Thông thường với numpy xuất xắc keras, hình ảnh có dimension theo trang bị tự (batchsize,h,w,chanel)(batchsize, h, w, chanel)(batchsize,h,w,chanel). Thứ tự vào Pytorch ngược lại là (batchsize,chanel,h,w)(batchsize, chanel, h, w)(batchsize,chanel,h,w). Mình sẽ xây dựng 2 hàm toTensor với toNumpy để biến hóa qua lại giữa hai format này.

def toTensor(np_array, axis=(2,0,1)): return torch.tensor(np_array).permute(axis)def toNumpy(tensor, axis=(1,2,0)): return tensor.detach().cpu().permute(axis).numpy() ## display one image in notebookdef plot_img(img): … ## display multi imagedef plot_imgs(imgs): …

3.2 Define model

3.2.1 Conv và Deconv

Mình sẽ xây dựng dựng 2 function trả về module Convolution cùng Convolution Transpose (Deconv)

def Deconv(n_input, n_output, k_size=4, stride=2, padding=1): Tconv = nn.ConvTranspose2d( n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False) block = return nn.Sequential(*block) def Conv(n_input, n_output, k_size=4, stride=2, padding=0, bn=False, dropout=0): conv = nn.Conv2d( n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False) block = return nn.Sequential(*block)

3.2.2 Unet model

Init function: ta đang copy các layer đề nghị giữ từ resnet50 vào unet. Tiếp nối khởi tạo những Conv / Deconv layer và các layer cần thiết.

Forward function: cần bảo vệ luồng cách xử lý của resnet50 được không thay đổi giống code cội (trừ Fully-connected layer). Sau đó ta ghép nối những layer lại theo phong cách xây dựng Unet đã thể hiện trong phần 2.

Xem thêm: Tìm Hiểu Về Thư Viện Axios, Cơ Bản Về Axios Trong Vuejs

class Unet(nn.Module): def __init__(self, resnet): super().__init__() self.conv1 = resnet.conv1 self.bn1 = resnet.bn1 self.relu = resnet.relu self.maxpool = resnet.maxpool self.tanh = nn.Tanh() self.sigmoid = nn.Sigmoid() # get some layer from resnet khổng lồ make skip connection self.layer1 = resnet.layer1 self.layer2 = resnet.layer2 self.layer3 = resnet.layer3 self.layer4 = resnet.layer4 # convolution layer, use lớn reduce the number of channel => reduce weight number self.conv_5 = Conv(2048, 512, 1, 1, 0) self.conv_4 = Conv(1536, 512, 1, 1, 0) self.conv_3 = Conv(768, 256, 1, 1, 0) self.conv_2 = Conv(384, 128, 1, 1, 0) self.conv_1 = Conv(128, 64, 1, 1, 0) self.conv_0 = Conv(32, 1, 3, 1, 1) # deconvolution layer self.deconv4 = Deconv(512, 512, 4, 2, 1) self.deconv3 = Deconv(512, 256, 4, 2, 1) self.deconv2 = Deconv(256, 128, 4, 2, 1) self.deconv1 = Deconv(128, 64, 4, 2, 1) self.deconv0 = Deconv(64, 32, 4, 2, 1) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) skip_1 = x x = self.maxpool(x) x = self.layer1(x) skip_2 = x x = self.layer2(x) skip_3 = x x = self.layer3(x) skip_4 = x x5 = self.layer4(x) x5 = self.conv_5(x5) x4 = self.deconv4(x5) x4 = torch.cat(, dim=1) x4 = self.conv_4(x4) x3 = self.deconv3(x4) x3 = torch.cat(, dim=1) x3 = self.conv_3(x3) x2 = self.deconv2(x3) x2 = torch.cat(, dim=1) x2 = self.conv_2(x2) x1 = self.deconv1(x2) x1 = torch.cat(, dim=1) x1 = self.conv_1(x1) x0 = self.deconv0(x1) x0 = self.conv_0(x0) x0 = self.sigmoid(x0) return x0 device = torch.device(“cuda”)resnet50 = models.resnet50(pretrained=True)model = Unet(resnet50)model.to(device)## Freeze resnet50″s layers in Unetfor i, child in enumerate(model.children()): if i 7: for param in child.parameters(): param.requires_grad = False

3.3 Dataset and Dataloader

Dataset trả dấn 1 list các image_path và mask_dir, trả về image với mask tương ứng.

Define MaskDataset

class MaskDataset(Dataset): def __init__(self, img_fns, mask_dir, transforms=None): self.img_fns = img_fns self.transforms = transforms self.mask_dir = mask_dir def __getitem__(self, idx): img_path = self.img_fns img_name = img_path.split(“/”).split(“.”) mask_fn = f”self.mask_dir/img_name.png” img = cv2.imread(img_path) mask = cv2.imread(mask_fn) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) if self.transforms: sample = “image”: img, “mask”: mask sample = self.transforms(**sample) img = sample mask = sample # to Tensor img = img/255.0 mask = np.expand_dims(mask, axis=-1)/255.0 mask = toTensor(mask).float() img = toTensor(img).float() return img, mask def __len__(self): return len(self.img_fns)Test dataset

img_fns = glob(“./Salicon_dataset/image/train/*.jpg”)mask_dir = “./Salicon_dataset/mask/train”train_transform = A.Compose(, height=256, width=256, p=0.4), A.HorizontalFlip(p=0.5), A.Rotate(limit=(-10,10), p=0.6),>)train_dataset = MaskDataset(img_fns, mask_dir, train_transform)train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, drop_last=True)# kiểm tra datasetimg, mask = next(iter(train_dataset))img = toNumpy(img)mask = toNumpy(mask)img = (img*255.0).astype(np.uint8)mask = (mask*255.0).astype(np.uint8)heatmap_img = cv2.applyColorMap(mask, cv2.COLORMAP_JET)combine_img = cv2.addWeighted(img, 0.7, heatmap_img, 0.3, 0)plot_imgs(

3.4 Train model

Vì bài toán đơn giản và dễ dàng và khiến cho dễ hiểu, mình vẫn train theo cách đơn giản dễ dàng nhất, ko validate trong qúa trình train nhưng mà chỉ lưu mã sản phẩm sau 1 số epoch độc nhất định

train_params = optimizer = torch.optim.Adam(train_params, lr=0.001, betas=(0.9, 0.99))epochs = 5model.train()saved_dir = “model”os.makedirs(saved_dir, exist_ok=True)loss_function = nn.MSELoss(reduce=”mean”)for epoch in range(epochs): for imgs, masks in tqdm(train_loader): imgs_gpu = imgs.to(device) outputs = model(imgs_gpu) masks = masks.to(device) loss = loss_function(outputs, masks) loss.backward() optimizer.step()

3.5 chạy thử model

img_fns = glob(“./Salicon_dataset/image/val/*.jpg”)mask_dir = “./Salicon_dataset/mask/val”val_transform = A.Compose()model.eval()val_dataset = MaskDataset(img_fns, mask_dir, val_transform)val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, drop_last=True)imgs, mask_targets = next(iter(val_loader))imgs_gpu = imgs.to(device)mask_outputs = model(imgs_gpu)mask_outputs = toNumpy(mask_outputs, axis=(0,2,3,1))imgs = toNumpy(imgs, axis=(0,2,3,1))mask_targets = toNumpy(mask_targets, axis=(0,2,3,1))for i, img in enumerate(imgs): img = (img*255.0).astype(np.uint8) mask_output = (mask_outputs*255.0).astype(np.uint8) mask_target = (mask_targets*255.0).astype(np.uint8) heatmap_label = cv2.applyColorMap(mask_target, cv2.COLORMAP_JET) heatmap_pred = cv2.applyColorMap(mask_output, cv2.COLORMAP_JET) origin_img = cv2.addWeighted(img, 0.7, heatmap_label, 0.3, 0) predict_img = cv2.addWeighted(img, 0.7, heatmap_pred, 0.3, 0) result = np.concatenate((img,origin_img, predict_img),axis=1) plot_img(result)Kết quả thu được: