SampleQAT
Inference of quantization aware trained networks using TensorRT
TensorRT inference of Resnet-50 trained with QAT.
Table Of Contents
- Description
- How does this sample work?
- Prerequisites
- Running the sample
- Additional resources
- Changelog
- Known issues
- License
Description
This sample demonstrates the workflow for training and inference of a Resnet-50 model trained using Quantization Aware Training (QAT). The inference implementation is an experimental prototype and is provided with no guarantee of support.
How does this sample work?
This sample demonstrates:
- Training a Resnet-50 model using quantization aware training.
- Post-processing the trained graph and converting it to an ONNX graph that TensorRT can parse successfully.
- Inference of Resnet-50 QAT graph with TensorRT.
Prerequisites
Dependencies required for this sample:
- <a href="https://ngc.nvidia.com/catalog/containers/nvidia:tensorflow">TensorFlow NGC containers</a> (20.01-tf1-py3 NGC container or above for Steps 1-4). Please use the tf1 variants, which have TensorFlow 1.15.2 installed. This sample does not work with the public version of the TensorFlow 1.15.2 library.
- Install the Python 3 dependencies inside the NGC container. From the root directory, run:
python3 -m pip install -r requirements.txt
- TensorRT-7.1
- <a href="https://github.com/NVIDIA/TensorRT/tree/release/7.1/tools/onnx-graphsurgeon">ONNX-Graphsurgeon 0.2.1</a>
Running the sample
NOTE: Steps 1-4 require <a href="https://ngc.nvidia.com/catalog/containers/nvidia:tensorflow">NGC containers</a> (TensorFlow 20.01-tf1-py3 NGC container or above). Steps 5-7 can be executed either inside or outside the NGC container.
Step 1: Quantization Aware Training
Please follow the detailed instructions on how to <a href="https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow/Classification/ConvNets/resnet50v1.5#quantization-aware-training">fine-tune an RN50 model using QAT</a>.
This stage involves:
- Fine-tuning an RN50 model with quantization nodes and saving the final checkpoint.
- Post-processing the RN50 QAT checkpoint by reshaping the weights of the final FC layer into a 1x1 conv layer.
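Reshaping the final FC weights into a 1x1 conv kernel does not change the layer's output, because a 1x1 convolution over a 1x1 spatial map is just a matrix multiply over channels. The NumPy sketch below checks this numerically. It assumes TensorFlow's HWIO kernel layout, the 2048-dimensional pooled feature vector of ResNet-50, and 1000 output classes chosen purely for illustration; the helpers fc_to_conv1x1 and conv1x1_nhwc are hypothetical and not part of the sample.

```python
import numpy as np

def fc_to_conv1x1(fc_weights: np.ndarray) -> np.ndarray:
    """Reshape FC weights [in_features, out_features] into a 1x1 conv
    kernel [1, 1, in_features, out_features] (TensorFlow HWIO layout)."""
    in_features, out_features = fc_weights.shape
    return fc_weights.reshape(1, 1, in_features, out_features)

def conv1x1_nhwc(x: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """Naive 1x1 convolution (stride 1, no padding) on an NHWC tensor.
    A 1x1 kernel only mixes channels, so it is a matmul over the last axis."""
    return np.einsum("nhwc,co->nhwo", x, kernel[0, 0])

rng = np.random.default_rng(0)
features = rng.standard_normal((2, 2048))   # pooled features, batch of 2
fc_w = rng.standard_normal((2048, 1000))    # illustrative class count
fc_out = features @ fc_w                    # original FC layer (bias omitted)

kernel = fc_to_conv1x1(fc_w)
# Treat the pooled features as a 1x1 spatial map before the conv.
conv_out = conv1x1_nhwc(features.reshape(2, 1, 1, 2048), kernel)
```

A bias vector, if present, carries over to the conv layer unchanged since it is added per output channel in both cases.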
Step 2: Export frozen graph of RN50 QAT
Export the RN50 QAT graph, replacing the final FC layer with a 1x1 conv layer. Please follow these <a href="https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow/Classification/ConvNets/resnet50v1.5#exporting-frozen-graphs">instructions</a> to generate a frozen graph in the desired data format.
Step 3: Constant folding
Once you have the frozen graph from Step 2, run the following command to perform constant folding on the TF graph:
python fold_constants.py --input <input_pb> --output <output_pb_name>
Arguments:
- --input: Input TensorFlow graph.
- --output_node: Output node name of the RN50 graph (default: resnet50_v1.5/output/softmax_1).
- --output: Output name of the constant-folded TF graph.
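Constant folding replaces every subgraph whose inputs are all constants with a single precomputed constant, so fewer ops reach the converter. The toy sketch below shows the idea on a tiny expression graph; the Const, Placeholder, and Op classes are hypothetical and this does not reflect how fold_constants.py is implemented.

```python
from dataclasses import dataclass, field
from typing import Union

@dataclass
class Const:
    value: float

@dataclass
class Placeholder:
    name: str

@dataclass
class Op:
    kind: str                                  # "add" or "mul" in this toy
    inputs: list = field(default_factory=list)

Node = Union[Const, Placeholder, Op]

_EVAL = {"add": lambda a, b: a + b, "mul": lambda a, b: a * b}

def fold_constants(node: Node) -> Node:
    """Recursively replace every Op whose inputs are all Const with a Const."""
    if not isinstance(node, Op):
        return node
    folded = [fold_constants(i) for i in node.inputs]
    if all(isinstance(i, Const) for i in folded):
        return Const(_EVAL[node.kind](*(i.value for i in folded)))
    return Op(node.kind, folded)

# y = x * (2 * 3) + (1 + 1): both constant sub-expressions fold away,
# while the branch depending on the placeholder x is kept.
graph = Op("add", [
    Op("mul", [Placeholder("x"), Op("mul", [Const(2.0), Const(3.0)])]),
    Op("add", [Const(1.0), Const(1.0)]),
])
folded = fold_constants(graph)
print(folded)
```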
Step 4: TF2ONNX conversion
The TF2ONNX converter is used to convert the constant-folded TensorFlow frozen graph into an ONNX graph. For RN50 QAT, the tf.quantization.quantize_and_dequantize operation (QDQ) is converted into QuantizeLinear and DequantizeLinear operations.
Support for converting QDQ operations was added in version 1.6.1 of TF2ONNX.
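The ONNX operators that replace each QDQ operation compute QuantizeLinear as saturate(round(x / scale) + zero_point), with round-half-to-even, and DequantizeLinear as (q - zero_point) * scale. The NumPy sketch below implements both for int8 and chains them into a fake-quantization step. It assumes symmetric per-tensor quantization with zero point 0 and scale = amax / 127; those choices, and the helper names, are illustrative assumptions rather than the exact parameters TF2ONNX emits.

```python
import numpy as np

def quantize_linear(x, scale, zero_point=0):
    """ONNX QuantizeLinear for int8: round half to even, then saturate."""
    q = np.round(x / scale) + zero_point       # np.round rounds half to even
    return np.clip(q, -128, 127).astype(np.int8)

def dequantize_linear(q, scale, zero_point=0):
    """ONNX DequantizeLinear: (q - zero_point) * scale."""
    return (q.astype(np.float32) - zero_point) * scale

def fake_quant(x, amax):
    """Quantize then dequantize with a symmetric per-tensor scale."""
    scale = amax / 127.0
    return dequantize_linear(quantize_linear(x, scale), scale)

x = np.array([-3.0, -0.01, 0.5, 1.27, 2.0], dtype=np.float32)
print(fake_quant(x, amax=1.27))  # out-of-range values saturate
```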
Command to convert the RN50 QAT TF graph to ONNX:
python3 -m tf2onnx.convert --input <path_to_rn50_qat_graph> --output <output_file_name> --inputs input:0 --outputs resnet50/output/softmax_1:0 --opset 11
Arguments:
- --input: Name of the TF input graph.
- --output: Name of the ONNX output graph.
- --inputs: Names of the input tensors.
- --outputs: Names of the output tensors.
- --opset: ONNX opset version.
Step 5: Post processing ONNX
Run the following command to postprocess the ONNX graph using the ONNX-Graphsurgeon API. This step removes the Transpose nodes that follow Dequantize nodes.
python postprocess_onnx.py --input <input_onnx_file> --output <output_onnx_file>
Arguments:
- --input: Input ONNX graph.
- --output: Output name of the postprocessed ONNX graph.
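One way to see why a Transpose that follows a Dequantize node on constant weights can be removed: with a per-tensor scale, dequantization is elementwise, so it commutes with a transpose, and the permutation can instead be applied once to the stored int8 weights. The NumPy sketch below checks this. It assumes the Transpose converts conv weights from TensorFlow's HWIO layout to the OIHW layout ONNX Conv expects; the helper name and shapes are hypothetical, and whether postprocess_onnx.py rewrites the weights in exactly this way is not stated in this README.

```python
import numpy as np

def dequantize(q, scale):
    """Per-tensor DequantizeLinear with zero point 0."""
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
# Hypothetical int8 conv weights in HWIO layout: [kh, kw, in, out].
q_hwio = rng.integers(-127, 128, size=(3, 3, 16, 32)).astype(np.int8)
scale = np.float32(0.02)
perm = (3, 2, 0, 1)  # HWIO -> OIHW

# Original pattern: constant -> Dequantize -> Transpose -> Conv weight input.
w_original = np.transpose(dequantize(q_hwio, scale), perm)

# Rewritten pattern: permute the constant offline and drop the Transpose node.
q_oihw = np.ascontiguousarray(np.transpose(q_hwio, perm))
w_rewritten = dequantize(q_oihw, scale)
```

With per-channel scales the same idea holds, but the quantization axis would also have to be remapped through the permutation.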
Step 6: Build TensorRT engine from ONNX graph
python build_engine.py --onnx <input_onnx_graph>
Arguments:
- --onnx: Path to the RN50 QAT ONNX graph.
- --engine: Output file name of the TensorRT engine.
- --verbose: Flag to enable verbose logging.
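If a script's command line is built with Python's argparse, the -h/--help output is generated automatically from the declared flags. The sketch below declares a parser that mirrors the documented build_engine.py arguments; it is a hypothetical reconstruction, not the script's actual code. Treating --onnx as required and leaving --engine with no default are assumptions.

```python
import argparse

def parse_build_args(argv=None):
    """Hypothetical parser mirroring the documented build_engine.py flags."""
    parser = argparse.ArgumentParser(
        description="Build a TensorRT engine from an RN50 QAT ONNX graph.")
    parser.add_argument("--onnx", required=True,
                        help="Path to the RN50 QAT ONNX graph")
    parser.add_argument("--engine", default=None,
                        help="Output file name of the TensorRT engine")
    parser.add_argument("--verbose", action="store_true",
                        help="Flag to enable verbose logging")
    return parser.parse_args(argv)

args = parse_build_args(["--onnx", "rn50_qat.onnx", "--verbose"])
print(args.onnx, args.engine, args.verbose)
```

Passing ["-h"] instead would print the flag descriptions and exit.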
Step 7: TensorRT Inference
Command to run inference on a sample image:
python infer.py --engine <input_trt_engine>
Arguments:
- --engine: Path to the input RN50 TensorRT engine.
- --labels: Path to the provided ImageNet 1k labels text file.
- --image: Path to the sample image.
- --verbose: Flag to enable verbose logging.
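The network ends in a softmax, so turning its output into a prediction amounts to ranking the scores and looking up the matching lines in the labels file. The sketch below shows that post-processing step. It assumes batch size 1, one label per line in the labels file, and exactly one score per label; top_k_labels and load_labels are hypothetical helpers, not functions from infer.py.

```python
import numpy as np

def top_k_labels(probs, labels, k=5):
    """Return the k (label, probability) pairs with the highest scores."""
    probs = np.asarray(probs).reshape(-1)   # flatten singleton dims (batch 1)
    if probs.size != len(labels):
        raise ValueError(f"{probs.size} scores but {len(labels)} labels")
    order = np.argsort(probs)[::-1][:k]     # indices sorted by descending score
    return [(labels[i], float(probs[i])) for i in order]

def load_labels(path):
    """Read one label per line (assumed file format)."""
    with open(path, encoding="utf-8") as f:
        return [line.strip() for line in f]

labels = ["tench", "goldfish", "great white shark"]
print(top_k_labels([0.1, 0.7, 0.2], labels, k=2))
```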
Sample --help options
To see the full list of available options and their descriptions, use the -h or --help command line option. For example:
usage: python <filename>.py [-h]
Additional resources
The following resources provide a deeper understanding of quantization aware training, TF2ONNX, and importing a model into TensorRT using Python:
Quantization Aware Training
- Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
- Quantization Aware Training guide
- Resnet-50 Deep Learning Example
- Deep Residual Learning for Image Recognition
Parsers
Documentation
- Introduction To NVIDIA’s TensorRT Samples
- Working With TensorRT Using The Python API
- Importing A Model Using A Parser In Python
- NVIDIA’s TensorRT Documentation Library
Changelog
June 2020: Initial release of this sample
Known issues
The TensorFlow operation tf.quantization.quantize_and_dequantize is used for quantization during training. The gradient of this operation is not clipped based on the input range.
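The issue can be illustrated with the straight-through estimator (STE) commonly used to backpropagate through fake quantization. The forward pass saturates inputs outside [-amax, amax], yet an unclipped STE still passes their gradient through unchanged; a clipped STE zeroes it there. The NumPy sketch below contrasts the two; it illustrates the general technique only and is not TensorFlow's gradient implementation. The function names and the symmetric 8-bit scheme are assumptions.

```python
import numpy as np

def fake_quant_forward(x, amax):
    """Symmetric 8-bit fake quantization: values beyond +/-amax saturate."""
    scale = amax / 127.0
    return np.clip(np.round(x / scale), -127, 127) * scale

def ste_grad_unclipped(upstream, x, amax):
    """Pass-through STE: gradient flows even where the forward pass saturated.
    (x and amax are unused; kept for a signature matching the clipped form.)"""
    return upstream

def ste_grad_clipped(upstream, x, amax):
    """Clipped STE: zero gradient where the input lies outside [-amax, amax]."""
    return upstream * (np.abs(x) <= amax)

x = np.array([-2.0, -0.5, 0.3, 1.5])
g = np.ones_like(x)
print(fake_quant_forward(x, amax=1.0))
print(ste_grad_unclipped(g, x, amax=1.0))
print(ste_grad_clipped(g, x, amax=1.0))
```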
License
The sampleQAT license can be found in the LICENSE file.
