PLAI¶
PLAI是一个基于PyTorch的神经网络量化工具 - 用于将浮点神经网络转换为定点 神经网络实现(给GTI 2801s使用), 或从头开始训练定点模型。 PLAI使用主机的CPU和GPU进行训练,使用GTI NPU USB Dongle进行推理验证。
PLAI现支持GNet1、GNet18和GNetfc三种基于VGG-16的模型。
运行要求及建议¶
PLAI运行硬件要求如下:
- Intel i5 3.0 GHz及以上主频或者更高性能CPU (Intel i7为较佳选择)
- 8 GB及以上内存
- 独立显卡6GB及以上显存,推荐使用GTX 1060及以上显卡,AMD显卡不适用。(此项为可选,但强烈推荐,可大大缩短训练时间)
PLAI最终会使用USB Dongle进行推理测试,如果有则可以进行配置使用。
PLAI现支持以下系统:
- Ubuntu LTS 16.04
- Windows 10
运行环境配置¶
环境依赖¶
- Python3
- PyTorch
- OpenCV
- CUDA 9.0及以上版本(可选)
Ubuntu¶
这里以使用Miniconda进行环境配置为例。
首先,从https://conda.io/miniconda.html下载Python3.7版本的Miniconda,这里下载了64-bit版本。
安装过程如下:
ubunut16.04:~$ sudo chmod +x Downloads/Miniconda3-latest-Linux-x86_64.sh
ubunut16.04:~$ ./Downloads/Miniconda3-latest-Linux-x86_64.sh
Welcome to Miniconda3 4.5.11
In order to continue the installation process, please review the license
agreement.
Please, press ENTER to continue
>>> (回车)
...
Do you accept the license terms? [yes|no]
[no] >>> yes(回车)
...
- Press CTRL-C to abort the installation
- Or specify a different location below
[/home/firefly/miniconda3] >>> (回车)
... (安装过程)
Do you wish the installer to prepend the Miniconda3 install location
to PATH in your /home/firefly/.bashrc ? [yes|no]
[no] >>> yes(回车)
以上将Miniconda安装在用户根目录miniconda3下,同时设置默认使用Miniconda的程序。
可以通过以下操作使Miniconda生效并测试:
ubunut16.04:~$ source ~/.bashrc
ubunut16.04:~$ conda -V
conda 4.5.11
接着,如果有英伟达独立显卡加速,可以通过以下操作安装PyTorch和OpenCV:
ubunut16.04:~$ conda install pytorch torchvision -c pytorch
ubunut16.04:~$ pip install opencv-contrib-python
否则,请执行以下操作:
ubunut16.04:~$ conda install pytorch-cpu torchvision-cpu -c pytorch
ubunut16.04:~$ pip install opencv-contrib-python
如果有显卡加速可参考此页面进行安装配置,否则可跳过此步骤。
cuda安装操作摘抄如下:
ubunut16.04:~$ wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64/cuda-repo-ubuntu1604_10.0.130-1_amd64.deb
ubunut16.04:~$ sudo dpkg -i cuda-repo-ubuntu1604_10.0.130-1_amd64.deb
ubunut16.04:~$ sudo apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64/7fa2af80.pub
ubunut16.04:~$ sudo apt-get update
ubunut16.04:~$ sudo apt-get install cuda
最后,进行USB Dongle配置(可选),可在PLAI目录下执行以下操作配置并验证:
ubunut16.04:~/PLAI$ sudo cp lib/python/gtilib/*.rules /etc/udev/rules.d/
ubunut16.04:~/PLAI$ ls /dev/sg* -l
crw-rw-rw- 1 root disk 21, 0 11月 20 10:28 /dev/sg0
crw-rw-rw- 1 root disk 21, 1 11月 20 10:28 /dev/sg1
如果出现未找到设备的情况请参考常见问题进行排查。
Windows 10¶
待完善…(可部分参考Ubuntu的配置过程)
环境测试¶
以下操作可测试环境完整性,如无错误,则配置完成。
ubunut16.04:~$ python
Python 3.7.0 (default, Jun 28 2018, 13:15:42)
[GCC 7.2.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch,cv2
>>> torch.cuda.is_available()
True
>>>
参数设置¶
training.json文件¶
- num_classes - 数据集分类个数,即data/train和data/val中的文件夹(分类)数
- max_epoch - 所有训练向量一次用于更新权重的次数
- learning_rate - 确定权重变化的速度
- train_batch_size - 根据GPU内存设置
- test_batch_size - 根据GPU内存设置
- mask_bits - 表示每个主层(卷积层)的掩码
- act_bits - 表示每个主层(卷积层)的激活参数
- resume - 设置是否从一个已知的checkpoint中开始训练
- finetune - 可选项, 启用此项通常能得到更高的精度
- full - 设置是否训练一个全精度模型
mask_bits和act_bits参数参考如下:
- GNetfc
- mask_bits: 3,3,1,1,1,1
- act_bits: 5,5,5,5,5,5
- GNet18
- mask_bits: 3,3,3,3,1
- act_bits: 5,5,5,5,5
- GNet1
- mask_bits: 3,3,1,1,1
- act_bits: 5,5,5,5,5
PLAI.py参数¶
PLAI.py暂不支持命令行参数,需要修改PLAI.py源码。
通常修改源码位置大约在142行,原始内容如下:
gtiPLAI = PLAI(num_classes=2, data_dir=data_dir, checkpoint_dir=checkpoint_dir, model_type=0, module_type=1, device_type=1)
- num_classes - 可在training.json中设置
- data_dir - 默认为data目录
- checkpoint_dir - 默认为checkpoint目录
- model_type - 设置训练的模型,0: GNetfc, 1: GNet18, 2:GNet1
- module_type - 0: Conv(w/o bias) + bn + bias,1: Conv(w/ bias) + bn
- device_type - 用于推理的GTI设备类型,ftdi: 0, emmc: 1
模型训练¶
将训练图片数据按类型存放到PLAI data目录下以类型名命名的文件夹中,然后调整training.json和修改PLAI.py(默认网络模型为GNetfc)在PLAI根目录下执行以下操作即可。
ubunut16.04:~$ python PLAI.py
模型使用¶
训练结束会生成coefDat_2801.dat、coefBin_2801.bin(GNetfc没有此文件)和data/pic_label.txt,如果是GNet1可用到AI资料U盘中的sample中测试。 其中userinput.txt可在PLAI nets目录下找到,如netConfig_2801_gnet1.txt。使用示例如下:
liteSample -c coefDat_2801.dat -u netConfig_2801_gnet1.txt -f coefBin_2801.bin -l pic_label.txt
测试时请注意userinput.txt即-u
参数文件中的设备节点是否正确。