Enhancing UNet: Tailoring Superior Segmentation Models through Transfer Learning

Oleg Belkovskiy

Dec 9, 2023

This post focuses on implementing a transfer-learning-based variation of the UNet architecture in PyTorch. Originally developed by Olaf Ronneberger et al. in 2015 for biomedical image segmentation at the University of Freiburg, Germany, the UNet architecture derives its name from its contracting and expansive pathways, which form a U-shaped structure of layers. This architecture and its variations have proven effective at capturing intricate details while preserving spatial information across numerous applications. However, in the pursuit of better performance, exploring additional techniques remains worthwhile, and one such technique is transfer learning.

Original UNet Architecture

The UNet architecture is structured around two core components: an encoder and a decoder. The encoder progressively downsamples the input image, capturing hierarchical features at multiple scales. This process allows the model to learn various patterns and representations within the input image. Following the encoding phase, the decoder takes over and performs symmetrical upsampling of the encoded features. The decoder’s primary objective is to reconstruct the high-resolution segmentation map from the learned features, thereby restoring spatial detail. Ultimately, the decoder produces an output matrix with the same dimensions as the initial input image. Up to this point, this may resemble a sequential convolutional “image-to-image” network, yet the architecture’s strength lies in its use of skip connections.

These skip connections establish links between corresponding layers in the encoder and decoder. By doing so, the decoder gains access to high-resolution feature maps from earlier encoder stages, facilitating the preservation of fine-grained details during upsampling. This interaction between the encoder, decoder and skip connections allows UNet to effectively characterize objects within images, making it particularly well-suited for tasks such as image segmentation.

Harnessing a pre-trained model

UNet training time can be shortened using a transfer learning approach: incorporating a pre-trained model that has already learned to extract complex features from diverse datasets, which also enhances the model’s generalization capabilities.

In UNet, it’s common to replace the encoder with a pre-trained model, while retaining a trainable decoder. This is because the encoder learns hierarchical features applicable to various tasks, while the decoder’s role is more task-specific.

Incorporating a pre-trained model would involve selecting and loading the model, identifying relevant layers for feature extraction, and seamlessly integrating it into the UNet architecture alongside the trainable decoder.

In this tutorial, we will perform all of these steps with the EfficientNet_B0 pre-trained model and address challenges such as accessing specific layer outputs for skip connections — a critical step in preserving high-resolution details. Additionally, we will ensure the matching of feature and dimension sizes between the encoder and decoder and tackle the issue of training the encoder and decoder with different learning rates.

For this tutorial I used the Defence Science and Technology Laboratory (DSTL) dataset of satellite imagery.

Originally the dataset contains 25 high resolution images (~11 megapixel each, or over 3350 pixels in height and width) provided with segmentation masks for several object classes. To adapt it to my approach, I divided each image into 224x224 tiles with corresponding masks of the same size, resulting in over 5000 non-overlapping sub-images, or about 18000 with 50% overlap. This sub-image size matches the 224x224 input resolution that EfficientNet_B0 was designed for.
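To make the tiling arithmetic concrete, here is a minimal sketch of how the number of 224x224 tiles depends on the stride. The helper names and the 3360-pixel image side are assumptions chosen for a clean example; the real DSTL images vary in size, so the totals will differ from the numbers quoted above.

```python
def tile_origins(length, tile=224, stride=224):
    """Top-left offsets of all full tiles that fit along one image side."""
    return list(range(0, length - tile + 1, stride))


def count_tiles(height, width, tile=224, overlap=0.0):
    """Number of tile x tile sub-images for a given fractional overlap."""
    stride = int(tile * (1.0 - overlap))
    rows = tile_origins(height, tile, stride)
    cols = tile_origins(width, tile, stride)
    return len(rows) * len(cols)


# Hypothetical image side (a multiple of 224), not an actual DSTL image size.
side = 3360
per_image_plain = count_tiles(side, side)                 # stride 224
per_image_overlap = count_tiles(side, side, overlap=0.5)  # stride 112
print(per_image_plain, per_image_overlap)
```

Halving the stride roughly quadruples the number of tiles per image, which is why the overlapping variant yields several times more training samples.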

The resulting images are divided into non-overlapping test and train groups and normalized with the mean and standard deviation of the ImageNet dataset, the same dataset used to train EfficientNet_B0.
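As an illustration of that normalization, the sketch below applies the commonly used ImageNet per-channel statistics to a single RGB pixel. It assumes the pixel values have already been scaled to the [0, 1] range, and the function name is made up for this example; in a PyTorch pipeline the same per-channel operation is typically done with `torchvision.transforms.Normalize`.

```python
# Per-channel ImageNet statistics (R, G, B) for inputs scaled to [0, 1].
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Normalize one RGB pixel: subtract the channel mean, divide by the channel std."""
    return tuple((value - m) / s for value, m, s in zip(rgb, mean, std))


# A pixel equal to the ImageNet mean maps to zero in every channel.
print(normalize_pixel(IMAGENET_MEAN))
print(normalize_pixel((1.0, 1.0, 1.0)))
```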

Please note that the code for raw data preprocessing and dataset preparation is available on GitHub (link provided at the end of the post).

As for the segmentation class, we will focus only on class 5, which provides masks for trees, woodland and other types of vegetation, as it contains the largest number of objects among all classes.

Now, let’s define both models: the original UNet and the EfficientNet_B0-based version. We will keep them as similar as possible to make the comparison more relevant. For the EfficientNet_B0-based model I’ll use the first 5 layers of its architecture, which has the following structure (note that the 4th column stands for the number of *output* channels):

[Table image not preserved: EfficientNet_B0 layer structure]

Original model

Following the sizes of layers in EfficientNet, here’s a definition of UNet of depth 5 with bottleneck layer and corresponding feature sizes:

import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by BatchNorm and ReLU."""
    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
    def forward(self, inputs):
        return self.layers(inputs)

class EncoderBlock(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = DoubleConv(in_c, out_c, out_c)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, inputs):
        x = self.conv(inputs)   # kept for the skip connection
        p = self.pool(x)        # passed to the next encoder block
        return x, p

class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, upsample=1):
        super().__init__()
        if upsample:
            self.upconv = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
        else:
            self.upconv = nn.Identity()
        # Concatenation with the skip connection doubles the channel count
        self.layers = DoubleConv(in_channels * 2, out_channels, out_channels)

    def forward(self, x, skip_connection):
        upsampled = self.upconv(x)
        concatenated = torch.cat([skip_connection, upsampled], dim=1)
        output = self.layers(concatenated)
        return output

class FinalLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.Sigmoid()
        )
    def forward(self, inputs):
        return self.layers(inputs)

class UNet(nn.Module):
    def __init__(self, num_classes, input_features=3,
                 layer1_features=32, layer2_features=16, layer3_features=24,
                 layer4_features=40, layer5_features=80):
        super(UNet, self).__init__()

        self.num_classes = num_classes

        # Layer feature sizes
        self.input_features = input_features
        self.layer1_features = layer1_features
        self.layer2_features = layer2_features
        self.layer3_features = layer3_features
        self.layer4_features = layer4_features
        self.layer5_features = layer5_features

        # Encoder layers
        self.encoder1 = EncoderBlock(self.input_features, self.layer1_features)
        self.encoder2 = EncoderBlock(self.layer1_features, self.layer2_features)
        self.encoder3 = EncoderBlock(self.layer2_features, self.layer3_features)
        self.encoder4 = EncoderBlock(self.layer3_features, self.layer4_features)
        self.encoder5 = EncoderBlock(self.layer4_features, self.layer5_features)

        # Bottleneck Layer
        self.bottleneck = DoubleConv(self.layer5_features, self.layer5_features, self.layer5_features)

        # Decoder layers
        self.decoder1 = DecoderBlock(self.layer5_features, self.layer4_features)
        self.decoder2 = DecoderBlock(self.layer4_features, self.layer3_features)
        self.decoder3 = DecoderBlock(self.layer3_features, self.layer2_features)
        self.decoder4 = DecoderBlock(self.layer2_features, self.layer1_features, upsample=0)
        self.decoder5 = DecoderBlock(self.layer1_features, self.layer1_features)

        # Final convolution
        self.final_conv = FinalLayer(self.layer1_features, self.num_classes)

    def forward(self, x):
        # Encoder (contracting path)
        output1, p1 = self.encoder1(x)
        # The pooled output of encoder2 is discarded so that this stage keeps 112x112
        output2, _ = self.encoder2(p1)
        output3, p3 = self.encoder3(output2)
        output4, p4 = self.encoder4(p3)
        output5, p5 = self.encoder5(p4)

        # Bottleneck Layer
        bn = self.bottleneck(p5)

        # Decoder (expansive path)
        up1 = self.decoder1(bn, output5)
        up2 = self.decoder2(up1, output4)
        up3 = self.decoder3(up2, output3)
        up4 = self.decoder4(up3, output2)
        up5 = self.decoder5(up4, output1)
        # Final convolution to produce segmentation mask
        res = self.final_conv(up5)

        return res

# Instantiate the model
num_classes = 1  # Binary segmentation
model_orig = UNet(num_classes)

Let’s review this definition and its specifics.

At a high level, the hierarchy consists of an encoding chain, a bottleneck layer, a decoding chain and a final segmentation layer.

Now, let’s look at each of them in more detail:

Encoding chain: a series of similarly defined blocks that differ from one another in their feature sizes. As the original paper suggests, each block has 2 convolution layers, each followed by an activation function, and then a pooling layer, which reduces the spatial dimensions (width and height) of the input feature maps by a factor of 2 each time, thereby increasing the receptive field of subsequent layers. Each encoder block returns the tensor both before and after the pooling operation: the latter goes to the next encoder layer, while the former, with higher spatial resolution, is used for the skip connection. It’s worth noting a small improvement relative to the original architecture: a BatchNorm component after each convolutional layer, which contributes to faster and more stable convergence.

Bottleneck layer: a double convolutional layer that doesn’t use or produce skip connections and operates on the layer with the largest receptive field.

Decoding chain: a series of similarly defined blocks that mirror those in the encoder chain. Those blocks perform the following series of actions:

  • Upsample previous layer output with transpose convolution operation to enlarge spatial size by the factor of 2.
  • Concatenate resulting tensor with output from the corresponding encoder layer (the one before pooling).
  • Run the result through 2 convolutional layers similar to the way it is made in encoder.

Concatenation of skip connections from the encoder and the upsampled feature maps enables the decoder to recover spatial information lost during encoding.

Final segmentation layer: a single 1x1 convolution layer that follows the last decoder layer and reduces the features to 1 (or another number, according to the number of segmentation classes).

Two points worth noting:

  1. Concatenation of the last layer output and skip connections from the encoder effectively doubles the number of features (assuming we manage to ensure those layers are matching in feature size). To keep the consistency, the first convolutional layer in each decoder reduces the number of features back by the factor of 2, keeping it symmetrical to the matching encoder layer.
  2. Due to an effort to achieve maximal similarity to EfficientNet, the second encoder layer retains spatial dimensions (112x112 for input and output). This change is reflected in the decoder with an ‘upsample’ parameter, disabled for layer 4.
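To see how these two points play out across the whole network, the following sketch traces channel counts and spatial sizes through the original UNet defined above, using plain Python bookkeeping instead of tensors. The function name and the `pooled` flags are introduced only for this illustration; the flags mirror the forward pass, where the pooled output of the second encoder is not used.

```python
def trace_unet_shapes(size=224, features=(32, 16, 24, 40, 80),
                      pooled=(True, False, True, True, True)):
    """Return (skips, steps): skip shapes and per-decoder (concat, out, size)."""
    skips = []
    for channels, is_pooled in zip(features, pooled):
        skips.append((channels, size))   # tensor kept for the skip connection
        if is_pooled:
            size //= 2                   # 2x2 max pooling halves H and W

    x_channels = features[-1]            # the bottleneck keeps the channel count
    steps = []
    for i in reversed(range(len(features))):
        skip_channels, skip_size = skips[i]
        if pooled[i]:
            size *= 2                    # transpose convolution doubles H and W
        # Concatenation requires matching spatial size; channels match by design
        assert size == skip_size and x_channels == skip_channels
        concat_channels = x_channels + skip_channels
        x_channels = features[i - 1] if i > 0 else features[0]
        steps.append((concat_channels, x_channels, size))
    return skips, steps


skips, steps = trace_unet_shapes()
for concat_channels, out_channels, size in steps:
    print(f"concat={concat_channels:3d} -> out={out_channels:2d} at {size}x{size}")
```

Each decoder step shows the doubling after concatenation followed by the reduction performed by the first convolution, and the fourth step stays at 112x112 because no upsampling is applied there.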

Having all that, we can visualize the resulting model’s structure using:

from torchview import draw_graph
model_graph = draw_graph(model_orig, input_size=(1,3,224,224), expand_nested=True, depth=1)
model_graph.visual_graph

[Figure: torchview graph of the original UNet model]

EfficientNet-based model

The model whose encoder builds upon transfer learning reuses the DoubleConv and FinalLayer classes defined above but requires a different DecoderBlock. Its definition is as follows:

from torchvision import models

class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, upsample=1):
        super().__init__()
        if upsample:
            self.upconv = nn.ConvTranspose2d(in_channels*2, in_channels*2, kernel_size=2, stride=2)
        else:
            self.upconv = nn.Identity()
        self.layers = DoubleConv(in_channels * 2, out_channels, out_channels)

    def forward(self, x, skip_connection):
        # Resize the skip connection to the spatial size of the decoder input
        target_height = x.size(2)
        target_width = x.size(3)
        skip_interp = F.interpolate(
            skip_connection, size=(target_height, target_width), mode='bilinear', align_corners=False)

        concatenated = torch.cat([skip_interp, x], dim=1)

        # Here upsampling happens after concatenation
        concatenated = self.upconv(concatenated)

        output = self.layers(concatenated)
        return output

class UNet(nn.Module):
    def __init__(self, num_classes, pretrained=True,
                 input_features=3, layer1_features=32, layer2_features=16,
                 layer3_features=24, layer4_features=40, layer5_features=80):
        super(UNet, self).__init__()
        # Load ImageNet weights when requested ("weights" API, torchvision >= 0.13)
        weights = models.EfficientNet_B0_Weights.DEFAULT if pretrained else None
        self.effnet = models.efficientnet_b0(weights=weights)

        self.num_classes = num_classes

        # Layer feature sizes
        self.input_features = input_features
        self.layer1_features = layer1_features
        self.layer2_features = layer2_features
        self.layer3_features = layer3_features
        self.layer4_features = layer4_features
        self.layer5_features = layer5_features

        # Encoder layers
        self.encoder1 = nn.Sequential(*list(self.effnet.features.children())[0]) #out 32,112*112
        self.encoder2 = nn.Sequential(*list(self.effnet.features.children())[1]) #out 16,112*112
        self.encoder3 = nn.Sequential(*list(self.effnet.features.children())[2]) #out 24,56*56
        self.encoder4 = nn.Sequential(*list(self.effnet.features.children())[3]) #out 40,28*28
        self.encoder5 = nn.Sequential(*list(self.effnet.features.children())[4]) #out 80,14*14

        del self.effnet

        # Freeze the two lowest-level encoder stages
        for param in self.encoder1.parameters():
            param.requires_grad = False
        for param in self.encoder2.parameters():
            param.requires_grad = False

        # Bottleneck Layer
        self.bottleneck = DoubleConv(self.layer5_features, self.layer5_features, self.layer5_features)

        # Decoder layers
        self.decoder1 = DecoderBlock(self.layer5_features, self.layer4_features)
        self.decoder2 = DecoderBlock(self.layer4_features, self.layer3_features)
        self.decoder3 = DecoderBlock(self.layer3_features, self.layer2_features)
        self.decoder4 = DecoderBlock(self.layer2_features, self.layer1_features, upsample=0)
        self.decoder5 = DecoderBlock(self.layer1_features, self.layer1_features)

        # Final layer
        self.final_conv = FinalLayer(self.layer1_features, self.num_classes)

    def forward(self, x):
        # Encoder (contracting path)
        output1 = self.encoder1(x)
        output2 = self.encoder2(output1)
        output3 = self.encoder3(output2)
        output4 = self.encoder4(output3)
        output5 = self.encoder5(output4)

        # Bottleneck Layer
        bn = self.bottleneck(output5)
        up1 = self.decoder1(bn, output5)
        up2 = self.decoder2(up1, output4)
        up3 = self.decoder3(up2, output3)
        up4 = self.decoder4(up3, output2)
        up5 = self.decoder5(up4, output1)

        # Final convolution to produce segmentation mask
        res = self.final_conv(up5)

        return res

num_classes = 1
pretrained = True
# Instantiate the model (the variable name is chosen here for illustration)
model_effnet = UNet(num_classes, pretrained=pretrained)

Let’s break down key differences between this model and classic UNet definition.

The model employs an encoder built from EfficientNet’s building blocks, accessed via the ‘features’ attribute. Unlike the original UNet, EfficientNet doesn’t have pooling layers; it downsamples with strided convolutions inside its blocks. Consequently, each encoder stage in this model produces a single output, which serves both as the input to the next encoder stage and as the skip connection. Compared with the original UNet, where the skip connection is taken before pooling, the skip connections here therefore have a lower resolution.

The decoder adapts due to the smaller spatial size of skip connections, which prevents the preservation of the original UNet decoder shape. To ensure architectural similarity, interpolation is used to align the skip connections’ size with the previous decoder output before concatenation. This interpolation method resizes the tensor’s spatial dimensions to match the desired resolution, maintaining input-output shape similarity with the original UNet.

Of course, it’s worth mentioning that there are other methods that might allow architecture matching, like applying padding instead of interpolation or using skip connection tensors of reduced spatial size in decoder, which might have their implications on architecture and performance.

Similarly to the previous model, layer 4 of decoder doesn’t perform upsampling to keep spatial dimensions 112x112 — to match layer 2 of EfficientNet-based encoder.
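The alignment step described above can be sketched in isolation. The helper name and tensor sizes below are hypothetical; the point is only that `F.interpolate` brings the skip connection to the spatial size of the decoder input so that concatenation along the channel dimension becomes possible.

```python
import torch
import torch.nn.functional as F


def align_and_concat(x, skip):
    """Resize `skip` to the spatial size of `x`, then concatenate on channels."""
    skip_resized = F.interpolate(
        skip, size=x.shape[2:], mode="bilinear", align_corners=False)
    return torch.cat([skip_resized, x], dim=1)


decoder_in = torch.randn(1, 40, 28, 28)   # hypothetical decoder input
skip = torch.randn(1, 40, 30, 30)         # hypothetical skip with a different size
merged = align_and_concat(decoder_in, skip)
print(tuple(merged.shape))

# When the sizes already match, the resize leaves the skip tensor unchanged.
same_size_skip = torch.randn(1, 40, 28, 28)
unchanged = F.interpolate(
    same_size_skip, size=(28, 28), mode="bilinear", align_corners=False)
print(torch.allclose(unchanged, same_size_skip))
```

With 224x224 inputs the encoder and decoder sizes in this model already line up, so the interpolation mostly acts as a safeguard for other input sizes.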

The resulting model looks like this:

[Figure: graph of the EfficientNet-based UNet model]

To train the models I used the following spec:

Learning rate

Our dataset in this challenge is fairly small and has limited diversity. Due to that fact, the transfer learning approach can be of a significant benefit, introducing embedded knowledge and abilities of a capable network, trained on ImageNet dataset. Given that, an EfficientNet-based encoder would require only a subtle change, possibly only in higher-level features, which are more specific for your task than low-level edge, curve and trivial shape detectors. Having this in mind and after some experiments, I’ve chosen learning rates of 5e-4 for decoder and bottleneck layers, while encoder layers 3, 4, and 5 received a rate of 1.5e-4. Levels 1 and 2 of the encoder remained frozen during training.

As for the original UNet model, I went for a uniform 5e-4 learning rate applied across all layers.
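The differential learning rates for the EfficientNet-based model can be expressed with optimizer parameter groups. The sketch below uses a tiny stand-in module that only mimics the attribute names of the UNet above (an assumption made to keep the example light); the helper name is also made up. Frozen stages are simply left out of the optimizer.

```python
import torch.nn as nn
import torch.optim as optim


class TinyStandIn(nn.Module):
    """Hypothetical stand-in exposing the same attribute names as the UNet."""
    def __init__(self):
        super().__init__()
        for name in ("encoder1", "encoder2", "encoder3", "encoder4", "encoder5",
                     "bottleneck", "decoder1", "decoder2", "decoder3",
                     "decoder4", "decoder5", "final_conv"):
            setattr(self, name, nn.Conv2d(4, 4, kernel_size=1))


def build_param_groups(model, lr_encoder=1.5e-4, lr_decoder=5e-4):
    # Freeze the two lowest encoder stages; they get no parameter group.
    for stage in (model.encoder1, model.encoder2):
        for param in stage.parameters():
            param.requires_grad = False
    encoder_stages = (model.encoder3, model.encoder4, model.encoder5)
    decoder_stages = (model.bottleneck, model.decoder1, model.decoder2,
                      model.decoder3, model.decoder4, model.decoder5,
                      model.final_conv)
    groups = [{"params": s.parameters(), "lr": lr_encoder} for s in encoder_stages]
    groups += [{"params": s.parameters(), "lr": lr_decoder} for s in decoder_stages]
    return groups


model = TinyStandIn()
optimizer = optim.Adam(build_param_groups(model))
print([group["lr"] for group in optimizer.param_groups])
```

The same helper could be applied to the real EfficientNet-based model, since it relies only on the attribute names.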

Scheduler

An exponential LR scheduler with a reduction factor of 0.85 was utilized, allowing for a gradual decrease of the LR to approximately ¼ of the initial LR throughout the epochs. These values showed fairly good convergence and were adopted for training.
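The decay is easy to verify by hand. This small sketch (the function name is an assumption) lists the learning rate used in each of the 10 epochs when it is multiplied by 0.85 after every epoch, starting from 5e-4.

```python
def exponential_lr(initial_lr, gamma, num_epochs):
    """Learning rate used during each epoch when decaying by `gamma` per epoch."""
    return [initial_lr * gamma ** epoch for epoch in range(num_epochs)]


schedule = exponential_lr(initial_lr=5e-4, gamma=0.85, num_epochs=10)
ratio_last_epoch = schedule[-1] / schedule[0]   # equals 0.85 ** 9
print(round(ratio_last_epoch, 3))
```

The last epoch therefore runs at roughly 0.23 of the initial rate, which is close to the quarter mentioned above.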

Number of epochs:

Considering around 15,000 training images and 1,200 test images, a total of 10 epochs were selected, taking into account time constraints and dataset size.

Optimizer:

The Adam optimizer, a popular and effective choice for this kind of task.

Loss function:

The training process combined Binary Cross-Entropy (BCE) loss and Dice loss. BCE loss, the negative log-likelihood of the true class of each pixel, measures the dissimilarity between predicted and ground-truth pixel-wise classifications. Dice loss is one minus the Dice coefficient, which is twice the intersection of the predicted and ground-truth masks divided by the sum of their sizes; it evaluates the overlap of the segmentation masks, emphasizing the model’s ability to capture object boundaries and fine details.

While BCE loss focuses on class probability estimation, when used alone it might produce an ‘ugly’, pixelated and discontinuous segmentation mask, like this:

[Figure: pixelated, discontinuous segmentation mask produced with BCE loss alone]

I used a 75%/25% combination weighted in favor of BCE, as those values showed good performance during hyperparameter tuning.
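To get a feel for how the two terms behave, here is a plain-Python sketch on a four-pixel toy example. The function names and the toy values are assumptions; the smoothing term and the 0.25 Dice weight follow the description in this post.

```python
import math


def bce(probs, targets, eps=1e-12):
    """Mean binary cross-entropy over flattened pixel probabilities."""
    total = 0.0
    for p, t in zip(probs, targets):
        p = min(max(p, eps), 1.0 - eps)   # keep log() finite
        total += -(t * math.log(p) + (1 - t) * math.log(1.0 - p))
    return total / len(probs)


def dice_loss(probs, targets, smooth=1.0):
    """One minus the (smoothed) Dice coefficient."""
    intersection = sum(p * t for p, t in zip(probs, targets))
    total = sum(probs) + sum(targets)
    return 1.0 - (2.0 * intersection + smooth) / (total + smooth)


def combined(probs, targets, weight_dice=0.25):
    """Weighted sum: 25% Dice loss and 75% BCE loss by default."""
    return (weight_dice * dice_loss(probs, targets)
            + (1 - weight_dice) * bce(probs, targets))


targets = [1, 1, 0, 0]
perfect = [1.0, 1.0, 0.0, 0.0]      # confident and correct
uncertain = [0.5, 0.5, 0.5, 0.5]    # maximally unsure everywhere

print(combined(perfect, targets))
print(dice_loss(uncertain, targets), bce(uncertain, targets))
```

A perfect prediction drives both terms to (nearly) zero, while the uniformly uncertain prediction is penalized by both, with BCE contributing the larger share under the 75%/25% weighting.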

The BCE function is available in the torch library, while Dice loss requires a definition:

import torch.optim as optim

class DiceLoss(nn.Module):
    def __init__(self, smooth=1):
        super(DiceLoss, self).__init__()
        self.smooth = smooth  # Smoothing factor to prevent division by zero

    def forward(self, inputs, targets):
        intersection = torch.sum(inputs * targets)
        union = torch.sum(inputs) + torch.sum(targets)
        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
        return 1.0 - dice

class CombinedLoss(nn.Module):
    def __init__(self, gamma=0.85, weight_dice=0.25):
        super(CombinedLoss, self).__init__()
        self.criterion_BCE = nn.BCELoss()
        self.criterion_Dice = DiceLoss()
        self.gamma = gamma              # decay factor of the LR scheduler
        self.weight_dice = weight_dice  # share of Dice loss in the combined loss

    def get_optimizer_and_scheduler(self, model_parameters):
        optimizer = optim.Adam(model_parameters)
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.gamma)
        return optimizer, scheduler

    def forward(self, outputs, targets, epoch=None):
        # `epoch` is accepted for convenience but is not used by the loss
        loss_BCE = self.criterion_BCE(outputs, targets)
        loss_Dice = self.criterion_Dice(outputs, targets)
        loss_comb = loss_Dice * self.weight_dice + \
            loss_BCE * (1 - self.weight_dice)

        return loss_comb

With all that in place, the training loop is defined as follows:

# define and init epoch params
num_epochs = 10
lr_e = 0.0005 # Learning rate for encoder
lr_d = 0.0005 # Learning rate for decoder
# Set learning rates for different model parts
model_parameters = [
    {'params': model_orig.encoder1.parameters(), 'lr': lr_e},
    {'params': model_orig.encoder2.parameters(), 'lr': lr_e},
    {'params': model_orig.encoder3.parameters(), 'lr': lr_e},
    {'params': model_orig.encoder4.parameters(), 'lr': lr_e},
    {'params': model_orig.encoder5.parameters(), 'lr': lr_e},
    {'params': model_orig.bottleneck.parameters(), 'lr': lr_d},
    {'params': model_orig.decoder1.parameters(), 'lr': lr_d},
    {'params': model_orig.decoder2.parameters(), 'lr': lr_d},
    {'params': model_orig.decoder3.parameters(), 'lr': lr_d},
    {'params': model_orig.decoder4.parameters(), 'lr': lr_d},
    {'params': model_orig.decoder5.parameters(), 'lr': lr_d},
    {'params': model_orig.final_conv.parameters(), 'lr': lr_d},
]
# Instantiate CombinedLoss (BCE + Dice)
combined_loss = CombinedLoss()
# Get optimizer and scheduler based on model parameters
optimizer, scheduler = combined_loss.get_optimizer_and_scheduler(model_parameters)

# Define function to train one epoch
def train_one_epoch(model, dataloader, optimizer, combined_loss, epoch):
    model.train()
    total_loss = 0
    num_batches = 0

    # Iterate through training batches
    for batch in dataloader:
        inputs, targets = batch['image'], batch['mask']
        outputs = model(inputs)
        loss_comb = combined_loss(outputs, targets, epoch)

        optimizer.zero_grad()
        loss_comb.backward()
        optimizer.step()

        total_loss += loss_comb.item()
        num_batches += 1

    epoch_loss = total_loss / num_batches
    return epoch_loss

def calc_validation_loss_one_epoch(model, dataloader, combined_loss):
    model.eval()
    total_loss_test = 0
    num_batches_test = 0

    # Iterate through validation batches
    with torch.no_grad():
        for batch in dataloader:
            num_batches_test += 1
            inputs_test, targets_test = batch['image'], batch['mask']
            outputs_test = model(inputs_test)
            # The loss does not use the epoch argument, so None is passed
            loss_test = combined_loss(outputs_test, targets_test, None)
            total_loss_test += loss_test.item()

    epoch_loss_test = total_loss_test / num_batches_test

    return epoch_loss_test

# Initialize lists to store results
train_loss_hist = []
val_loss_hist = []

# dataloader_train and dataloader_test come from the dataset preparation code (see GitHub)
# Iterate through epochs for training
for epoch in range(num_epochs):
    # Train one epoch
    epoch_loss = train_one_epoch(model_orig, dataloader_train, optimizer, combined_loss, epoch)

    # Calculate validation loss for the epoch
    epoch_loss_test = calc_validation_loss_one_epoch(model_orig, dataloader_test, combined_loss)
    model_orig.train()

    # Other updates and storage as needed
    scheduler.step()
    train_loss_hist.append(epoch_loss)
    val_loss_hist.append(epoch_loss_test)

Overview

In this section I’ll compare training results for 3 models: the original UNet, the pre-trained EfficientNet-based UNet and the non-pre-trained EfficientNet-based UNet. The last one serves as a reference for measuring the impact of the knowledge stored in the pre-trained encoder relative to the purely architectural benefits of this model.

Comparison will be based on BCE-Dice combined loss over 10 epochs, on train and validation datasets.

Visualization

Here are sample images for illustration of performance of all 3 models on the following 224x224 subimage:

[Figures: sample sub-image and segmentation results of the three models]

Performance analysis

[Figures: train and validation loss curves of the three models]

Convergence Speed:

The pretrained EfficientNet-based UNet started with lower losses and showed better validation loss values in every epoch. This demonstrates its ability to adapt rapidly and learn meaningful representations from the data by leveraging the knowledge encoded in pre-trained weights, unlike the original UNet and the non-pretrained EfficientNet-based UNet, which learn from scratch.

Stability:

The pretrained EfficientNet-based UNet showed consistent and steady declines in both train and validation losses throughout the training process. This consistency indicates stable learning and reliable model behavior, contributing to its robustness in generating accurate segmentations.

The original UNet and non-pretrained EfficientNet-based UNet models showed a declining trend in training loss, suggesting effective learning. However, their validation loss fluctuated while gradually decreasing, hinting at potential challenges in generalizing to unseen data. These models might benefit from techniques like regularization and more sophisticated learning rate management to improve their adaptability to new data. Note, however, that the smoother convergence might also be attributed to the smaller learning rate used for the encoder in the EfficientNet-based model.

Generalization and Overfitting:

The pretrained EfficientNet-based UNet achieved consistently close train and validation losses, indicating, as expected, strong generalization, since its encoder is a more robust feature extractor that was trained on ImageNet.

The non-pretrained models learned only from our small, limited dataset and naturally generalize poorly, hence the larger and less consistent performance gap between the train and validation datasets.

Model Performance after 10 Epochs:

After 10 epochs, the pretrained EfficientNet-based UNet model achieved notably lower validation losses compared to the other models, suggesting superior performance in segmentation tasks on unseen data.

Summary

From the above we can state that, across the evaluated models, the pretrained EfficientNet-based UNet showed superior performance in several respects. It had faster convergence, a steady decline in both train and validation losses, and notably lower validation loss values after just a few epochs. This suggests reliable and robust learning. Conversely, the original UNet and the non-pretrained EfficientNet-based UNet displayed results similar to each other: fluctuating yet decreasing validation loss, slower convergence, and a poorer ability to generalize to unseen data, indicating potential issues with overfitting.

As we could see, the application of transfer learning in UNet architecture, particularly leveraging pre-trained models like EfficientNet_B0, has shown compelling advantages for image segmentation tasks. The faster convergence rate, steady reduction in both training and validation losses, and notably superior performance on unseen data highlight its enhanced robustness. This technique not only accelerates the training process but also enhances the model’s generalization ability, allowing it to handle diverse inputs while having a small and relatively homogeneous training dataset.

However, adopting pre-trained models like EfficientNet requires careful consideration of memory and the number of trainable parameters. The EfficientNet encoder used in this post has more parameters than the original UNet encoder, leading to higher memory consumption, which can be a constraint in resource-limited environments. Specifically, the original UNet architecture I used has 441137 parameters, all of them trainable, while the transfer learning approach produced a model with 728949 parameters, 726573 of them trainable. That is a model about 65% larger, a size increase that is definitely worth considering. Of course, one can base the encoder on a leaner architecture to mitigate this concern.
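The size comparison can be reproduced from the parameter counts quoted above. The arithmetic below uses only those numbers; the `count_parameters` helper is a hypothetical convenience that would work on any PyTorch module and is not called here.

```python
def count_parameters(model, trainable_only=False):
    """Sum parameter sizes of a PyTorch module (hypothetical helper)."""
    return sum(p.numel() for p in model.parameters()
               if p.requires_grad or not trainable_only)


# Parameter counts quoted in the post
unet_total = 441_137
effnet_total = 728_949
effnet_trainable = 726_573

frozen = effnet_total - effnet_trainable                    # frozen encoder stages
relative_increase = (effnet_total - unet_total) / unet_total
print(frozen, f"{relative_increase:.1%}")
```

The frozen share is tiny (a few thousand parameters in the first two encoder stages), so nearly the entire size increase is also an increase in trainable parameters; the ratio of the two model sizes works out to roughly 1.65.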

In conclusion, transfer learning offers a potent strategy in UNet-based image segmentation by leveraging prior knowledge for efficient learning. Yet, dataset size, model complexity, and computational constraints have to be carefully reviewed to match its requirement, helping harness its full potential effectively. This approach holds immense promise in enhancing model performance and boosting convergence, but a nuanced approach is crucial for optimal results in diverse practical scenarios.

Full notebook is available on:

https://github.com/olegtlv/DSTL_UNet_EfficientNet_3channel
