
wilile26811249/MobileViT: Unofficial PyTorch implementation of MobileViT based o ...


Open-source project name:

wilile26811249/MobileViT

Repository URL:

https://github.com/wilile26811249/MobileViT

Programming language:

Python 100.0%

Introduction:

MobileViT

Unofficial PyTorch implementation of MobileViT, based on the paper MOBILEVIT: LIGHT-WEIGHT, GENERAL-PURPOSE, AND MOBILE-FRIENDLY VISION TRANSFORMER.


Model Architecture

MobileViT Architecture

Usage

import torch
import models

img = torch.randn(1, 3, 256, 256)
net = models.MobileViT_S()

# Parameter counts: XXS 1.3M, XS 2.3M, S 5.6M
print("MobileViT-S params: ", sum(p.numel() for p in net.parameters()))
print(f"Output shape: {net(img).shape}")
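
The count printed above sums every parameter tensor in the model. As a minimal sketch, the helper below separates the total from the trainable count and formats it in millions. The names `count_parameters` and `format_millions` are hypothetical and not part of this repository; the code only assumes an object that behaves like a `torch.nn.Module`, that is, `parameters()` yields items exposing `numel()` and `requires_grad`.

```python
def count_parameters(model):
    """Return (total, trainable) parameter counts for a model.

    Hypothetical helper, not part of this repository. `model` is assumed
    to behave like a torch.nn.Module: `model.parameters()` yields tensors
    that expose `numel()` and `requires_grad`.
    """
    total = 0
    trainable = 0
    for p in model.parameters():
        n = p.numel()
        total += n
        # Frozen parameters (requires_grad=False) count toward the total only
        if p.requires_grad:
            trainable += n
    return total, trainable


def format_millions(count):
    """Format a raw parameter count in millions, e.g. 5600000 -> '5.6M'."""
    return f"{count / 1e6:.1f}M"
```

With the `net` created in the usage snippet, `format_millions(count_parameters(net)[0])` should land near the 5.6M quoted for MobileViT-S, assuming the implementation matches that figure.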

Training

  • Single node with one GPU
python main.py
  • Single node with multiple GPUs
CUDA_VISIBLE_DEVICES=3,4 python -m torch.distributed.launch --nproc_per_node=2 --master_port=6666 main_ddp.py
optional arguments:
  -h, --help            show this help message and exit
  --gpu_device GPU_DEVICE
                        Select specific GPU to run the model
  --batch-size N        Input batch size for training (default: 64)
  --epochs N            Number of epochs to train (default: 20)
  --num-class N         Number of classes to classify (default: 10)
  --lr LR               Learning rate (default: 0.01)
  --weight-decay WD     Weight decay (default: 1e-5)
  --model-path PATH     Path to save the model
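
For reference, a parser producing help text like the listing above could be declared roughly as follows. This is a sketch reconstructed from the documented flags and defaults, not the actual contents of `main.py`; the function name `build_parser`, the argument types, and the `None` defaults for `--gpu_device` and `--model-path` are assumptions.

```python
import argparse


def build_parser():
    """Sketch of an argparse parser matching the documented options.

    Illustrative only: the repository's real parser may differ in types,
    defaults, and wording.
    """
    parser = argparse.ArgumentParser(description="MobileViT training (illustrative)")
    # Type and default for --gpu_device are not documented; both are assumptions
    parser.add_argument("--gpu_device", type=int, default=None,
                        help="Select specific GPU to run the model")
    parser.add_argument("--batch-size", type=int, default=64, metavar="N",
                        help="Input batch size for training (default: 64)")
    parser.add_argument("--epochs", type=int, default=20, metavar="N",
                        help="Number of epochs to train (default: 20)")
    parser.add_argument("--num-class", type=int, default=10, metavar="N",
                        help="Number of classes to classify (default: 10)")
    parser.add_argument("--lr", type=float, default=0.01, metavar="LR",
                        help="Learning rate (default: 0.01)")
    parser.add_argument("--weight-decay", type=float, default=1e-5, metavar="WD",
                        help="Weight decay (default: 1e-5)")
    # No default path is documented; None is an assumption
    parser.add_argument("--model-path", type=str, default=None, metavar="PATH",
                        help="Path to save the model")
    return parser


# argparse maps dashes to underscores, so --batch-size becomes args.batch_size
args = build_parser().parse_args(["--batch-size", "128", "--lr", "0.05"])
print(args.batch_size, args.lr, args.epochs)
```

Note that the documented default of 10 classes would need to be overridden (for example `--num-class 1000`) when training on ImageNet-1k.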

Experiment

Accuracy on ImageNet

Loss on ImageNet

MobileViT-S pretrained weights: weight

MobileViT-XXS pretrained weights: weight

How to load pretrained weights (trained with DataParallel)

Solution by @Sehaba95:

import torch
import models


def load_mobilevit_weights(model_path):
    # Create an instance of the MobileViT-S model
    net = models.MobileViT_S()

    # Load the checkpoint on the CPU and take its state_dict
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))['state_dict']

    # The weights were saved from a DataParallel-wrapped model, so each key
    # starts with 'module.'; strip that prefix to match the plain model
    for key in list(state_dict.keys()):
        new_key = key[len('module.'):] if key.startswith('module.') else key
        state_dict[new_key] = state_dict.pop(key)

    # With the keys fixed, load the parameters into the model
    net.load_state_dict(state_dict)

    return net


net = load_mobilevit_weights("MobileViT_S_model_best.pth.tar")
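
The key renaming in the loader can be isolated into a small helper that works on any mapping, which makes it easy to check without a checkpoint file. The sketch below is illustrative: `strip_module_prefix` is a hypothetical name, and the toy keys and values stand in for a real `state_dict`. It removes the prefix only when it appears at the start of a key and leaves the input mapping untouched.

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading prefix removed from each key.

    Hypothetical helper, not part of this repository. Keys that do not
    start with the prefix are copied unchanged.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned


# Toy stand-in for a checkpoint saved from a DataParallel-wrapped model;
# the key names and values are made up for illustration.
saved = OrderedDict([("module.stem.weight", 1), ("module.fc.bias", 2)])
cleaned = strip_module_prefix(saved)
print(list(cleaned))  # ['stem.weight', 'fc.bias']
```

The result could then be passed to `net.load_state_dict(...)` in place of the in-place renaming loop.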

| Model | Dataset | Learning Rate | LR Scheduler | Optimizer | Weight Decay | Acc@1/Val | Acc@5/Val |
|---|---|---|---|---|---|---|---|
| MobileViT | ImageNet-1k | 0.05 | Cosine LR | SGDM | 1e-5 | 61.918% | 83.05% |

Citation

@InProceedings{Sachin2021,
  title = {MOBILEVIT: LIGHT-WEIGHT, GENERAL-PURPOSE, AND MOBILE-FRIENDLY VISION TRANSFORMER},
  author = {Sachin Mehta and Mohammad Rastegari},
  booktitle = {},
  year = {2021}
}

If you find any problems with this implementation, please let me know. Thank you.



