OpenVINO Blog

Fiona Zhao

Use Encrypted Model with OpenVINO

November 9, 2023

Deploying deep-learning capabilities to edge devices can present security challenges, such as ensuring inference integrity or providing copyright protection of your deep-learning models. OpenVINO provides a simple method to protect a model on disk with a cryptographic algorithm. Model encryption, decryption and authentication are not provided by OpenVINO itself but can be implemented with third-party tools (e.g., OpenSSL). In this example, we use the AES-128-CBC algorithm from OpenSSL to demonstrate model cryptography.

As you can see from the mechanism in the image below, there are two parts to the process:

  1. First, encrypt your plain IR model into an encrypted model.
  2. Second, at model loading time, decrypt the model using the same key and IV that were used for encryption.
The schema of model encryption and decryption by OpenVINO
Step 1: Encrypt model

Make sure OpenSSL is installed; for example, on Ubuntu:


$ sudo apt install openssl

Then use the command line to encrypt the model with the OpenSSL AES-128-CBC algorithm. In this simple example, I use the same value for the key and the IV: the hexadecimal representation of the string "openvino encrypt". You can use any string-to-hex tool to generate the hex representation of your string password.
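
For instance, this Python one-liner reproduces the hex value used in the commands below:


print("openvino encrypt".encode().hex())  # prints 6f70656e76696e6f20656e6372797074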


$ openssl enc -aes-128-cbc -in openvino_model.xml -out openvino_model_enc.xml -K 6f70656e76696e6f20656e6372797074 -iv 6f70656e76696e6f20656e6372797074
$ openssl enc -aes-128-cbc -in openvino_model.bin -out openvino_model_enc.bin -K 6f70656e76696e6f20656e6372797074 -iv 6f70656e76696e6f20656e6372797074

Step 2: Decrypt model

The sample code below reads the encrypted model into a buffer, decrypts it into the plain model binary, and then reads and compiles the model.


#include <fstream>
#include <iostream>
#include <vector>
#include <cstdint>
#include <cctype>
#include <string>
#include <openvino/runtime/core.hpp>
#include <openssl/aes.h>
#include <boost/algorithm/hex.hpp>

using namespace std;

vector<unsigned char> aes_128_cbc_decrypt(
    vector<unsigned char> &cipher,
    std::vector<unsigned char> &key,
    std::vector<unsigned char> iv) {

    AES_KEY ctx;
    AES_set_decrypt_key(key.data(), 128, &ctx);
    std::vector<uint8_t> plain;
    // CBC ciphertext is already padded to a multiple of the 16-byte block size,
    // so a buffer of the same size is large enough for the decrypted data
    size_t plain_size = cipher.size();
    plain.resize(plain_size);
    std::cout << "AES_cbc_encrypt start:" << std::endl;
    AES_cbc_encrypt(cipher.data(), plain.data(), plain.size(), &ctx, iv.data(), AES_DECRYPT);
    std::cout << "AES_cbc_encrypt done" << std::endl;
    return plain;
}

void decrypt_file(std::ifstream & stream,
                  std::vector<unsigned char> & key,
                  std::vector<unsigned char> & iv,
                  std::vector<uint8_t> & result) {
    std::vector<unsigned char> cipher((std::istreambuf_iterator<char>(stream)),  std::istreambuf_iterator<char>());
    std::cout << "aes_128_cbc_decrypt" << std::endl;
    std::vector<unsigned char> decrypt_model = aes_128_cbc_decrypt(cipher, key, iv);
    result = decrypt_model;

}

int main() {
    std::string key_hex = "6f70656e76696e6f20656e6372797074";
    std::string iv_hex = "6f70656e76696e6f20656e6372797074";
    std::vector<unsigned char> key_bytes;
    std::vector<unsigned char> iv_bytes;
    boost::algorithm::unhex(key_hex, std::back_inserter(key_bytes));
    boost::algorithm::unhex(iv_hex, std::back_inserter(iv_bytes));
    std::vector<uint8_t> model_data, weights_data;
    std::ifstream model_file("openvino_model_enc.xml",std::ios::in | std::ios::binary), weights_file("openvino_model_enc.bin",std::ios::in | std::ios::binary);
    // Read model files and decrypt them into temporary memory block
    std::cout << "decrypt file" << std::endl;
    decrypt_file(model_file, key_bytes, iv_bytes, model_data); //key & iv is the same
    decrypt_file(weights_file, key_bytes, iv_bytes, weights_data);
    ov::Core core;
    // Load model from temporary memory block
    std::string str_model(model_data.begin(), model_data.end());
    ov::CompiledModel compiled_model = core.compile_model(str_model, ov::Tensor(ov::element::u8, {weights_data.size()}, weights_data.data()), "CPU");
    std::cout << "compile success" << std::endl;
    return 0;
}

This blog only provides an example of model encryption with OpenSSL. This method protects your model on disk only. For protection in memory as well, you can refer to technologies such as the OpenVINO™ Security Add-on, which uses a virtual machine to provide an isolated environment for security-sensitive operations, and Intel® SGX (Software Guard Extensions), which allows developers to split a computer's memory into private, predefined, highly secure areas called enclaves that better protect sensitive information.

References:
  1. OpenVINO model protection: https://docs.openvino.ai/2023.1/openvino_docs_OV_UG_protecting_model_guide.html
  2. OpenVINO™ Security Add-on: https://docs.openvino.ai/2023.1/ovsa_get_started.html
  3. OpenSSL official website: https://www.openssl.org/
Kunda Xu

How to use OpenVINO Extension to enable custom operation on GPU

Authors: Kunda Xu, Song Bell

This chapter introduces the extension mechanism of OpenVINO on GPU and uses code snippets to explain how a custom operation can be enabled on the GPU.

First, we need to understand the OpenVINO Extensibility Mechanism. Custom operations, which are not included in the official operation set, are not recognized by OpenVINO out-of-the-box. The need for a custom operation may appear in two cases:

-         A new or rarely used regular framework operation is not supported in OpenVINO yet.

-         A new user operation that was created for some specific model topology by the author of the model using framework extension capabilities.

Importing models with such operations requires additional steps. This guide illustrates the workflow for running inference on models featuring custom operations. 

Introduction

How to implement custom operations with the OpenVINO Extensibility API

OpenVINO Extensibility API enables adding support for those custom operations and using one implementation for Model Optimizer and OpenVINO Runtime.

Defining a new custom operation basically consists of two parts:

-         Definition of operation semantics in OpenVINO, the code that describes how this operation should be inferred consuming input tensor(s) and producing output tensor(s). The implementation of execution kernels for GPU is described in separate guides.

-         Mapping rule that facilitates conversion of framework operation representation to OpenVINO defined operation semantics.

The first part is required for inference. The second part is required for successful import of a model containing such operations from the original framework model format.

How to implement GPU custom operations

To enable operations not supported by OpenVINO™ out of the box, you may need an extension for OpenVINO operation set, and a custom kernel for the device you will target. This article describes custom kernel support for the GPU device.

The GPU code path abstracts many details about OpenCL. You need to provide the kernel code in OpenCL C and an XML configuration file that connects the kernel and its parameters to the parameters of the operation.

The process of GPU operation extension is the same as the overall process of CPU operation, but because OpenVINO GPU operations are defined using OpenCL, additional files are required to provide interface and function definitions for custom ops.

Picture1-GPU extension file list

In the above figure, custom_op.cl and custom_op.xml are additional files required by the GPU extension compared to the CPU extension.

-         custom_op.cl contains the function definition of the custom operation written in OpenCL

-         custom_op.xml is the GPU extension configuration file that provides the interface definition of the custom operation function.

You also need to call the core.set_property() method from your application with the "CONFIG_FILE" key and the configuration file name as the value before loading the network that uses custom operations to the plugin.
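
For example, in Python (a minimal sketch; the configuration file name custom_op.xml follows the example used in this article, so adjust the path for your setup):


from openvino import Core

core = Core()
# Register the OpenCL kernel configuration with the GPU plugin before compiling the model
core.set_property("GPU", {"CONFIG_FILE": "path/to/custom_op.xml"})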

Picture2-custom_op.cl - custom op OpenCL implement
Picture3-custom_op.xml Define the operation interface implement

Quick Start

Similar to the OpenVINO CPU extension process, the GPU extension process is as follows; there are several options to implement each part, and the following sections will describe them in detail.
It is recommended to familiarize yourself with the implementation of the OpenVINO extension on the CPU first, so that you can better understand it when using GPU custom operations.

Definition of Operation Semantics

There are two ways to define custom operations:

-         OpenVINO operation set combination, if the custom operation can be mathematically represented as a combination of existing OpenVINO operations and such decomposition gives the desired performance.

-         Implementing the custom operation in C++/OpenCL. As general advice, try to decompose the operation with the OpenVINO operation set first. If such decomposition is not possible, or appears too bulky with a large number of constituent operations that do not perform well, then a new class for the custom operation should be implemented.

-         Custom Operation Guide Reference Link: https://docs.openvino.ai/2023.3/openvino_docs_Extensibility_UG_add_openvino_ops.html

Here we take a simple OpenCL custom op as an example to illustrate.

Picture4-operation definition by OpenVINO extension API

Mapping from Framework Operation

 

Mapping of a custom operation is implemented differently depending on the model format used for import. If a model is represented in the ONNX (including models exported from PyTorch to ONNX), TensorFlow Lite, PaddlePaddle or TensorFlow format, one of the Frontend Extension API classes should be used.

Frontend Extension API Reference Link:  https://docs.openvino.ai/2024/documentation/openvino-extensibility/frontend-extensions.html
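
As a sketch of how such a mapping can look in Python, the example below converts a hypothetical framework operation named "MyCustomOp" into a decomposition built from existing OpenVINO operations (the operation name and the ReLU decomposition are placeholders, not part of the original article):


import openvino.runtime.opset12 as ops
from openvino import Core
from openvino.frontend import ConversionExtension, NodeContext


def convert_my_custom_op(node: NodeContext):
    # Decompose the framework operation into existing OpenVINO operations
    data = node.get_input(0)
    return ops.relu(data).outputs()


core = Core()
core.add_extension(ConversionExtension("MyCustomOp", convert_my_custom_op))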

 

Registering Extensions

A custom operation class and a new mapping frontend extension class object should be registered to be usable in OpenVINO runtime.

Picture5-GPU custom op registering

Create a Library with Extensions

To create an extension library, for example to load the extensions into the OpenVINO Runtime, perform the following steps:

Define the CMake file.

Picture6-GPU extension-CMake file define

Build the extension library by running the commands below.

Picture7-GPU extension-cmake cmd line

OpenVINO GPU extension code snippets

Picture8-GPU extension-python sample code snippets
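
Since the original Python snippet is only shown as an image, the following is a minimal sketch of what such a sample could look like (the extension library, configuration file and model paths, as well as the input shape, are placeholders):


import numpy as np
from openvino import Core

core = Core()
# Load the compiled extension library that registers the custom operation
core.add_extension("path/to/libcustom_op_extension.so")
# Point the GPU plugin to the OpenCL kernel configuration for the custom op
core.set_property("GPU", {"CONFIG_FILE": "path/to/custom_op.xml"})

# Compile a model that contains the custom operation and run it on GPU
compiled_model = core.compile_model("model_with_custom_op.xml", "GPU")
input_tensor = np.random.rand(1, 3, 224, 224).astype(np.float32)  # placeholder input shape
results = compiled_model([input_tensor])
print(results[compiled_model.output(0)])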

Nikita Savelyev

Optimizing Whisper and Distil-Whisper for Speech Recognition with OpenVINO and NNCF

January 29, 2024

Authors: Nikita Savelyev, Alexander Kozlov, Ekaterina Aidova, Maxim Proshin

Introduction

Whisper is a general-purpose speech recognition model from OpenAI. The model can transcribe speech across dozens of languages and even handle poor audio quality or excessive background noise. You can find more information about this model in the research paper, OpenAI blog, model card and GitHub repository.

Recently, a distilled variant of the model called Distil-Whisper has been proposed in the paper Robust Knowledge Distillation via Large-Scale Pseudo Labelling. Compared to Whisper, Distil-Whisper runs several times faster with 50% fewer parameters, while performing to within 1% word error rate (WER) on out-of-distribution evaluation data.

Whisper is a Transformer-based encoder-decoder model, also referred to as a sequence-to-sequence model. It maps a sequence of audio spectrogram features to a sequence of text tokens. First, the raw audio inputs are converted to a log-Mel spectrogram by action of the feature extractor. Then, the Transformer encoder encodes the spectrogram to form a sequence of encoder hidden states. Finally, the decoder autoregressively predicts text tokens, conditional on both the previous tokens and the encoder's hidden states.

You can see the model architecture in the diagram below:

In this article, we would like to demonstrate how to improve Whisper and Distil-Whisper inference speed with OpenVINO for Intel hardware. Additionally, we show how to make models even faster by applying 8-bit Post-training Quantization with Neural Network Compression Framework (NNCF). In the end we present evaluation results from accuracy and performance standpoints on a large-scale dataset.

All code snippets presented in this article are from the Automatic speech recognition using Distil-Whisper and OpenVINO Jupyter notebook, so you can follow along.

Converting Model to OpenVINO format

We are going to load models from the Hugging Face Hub with the help of the Optimum Intel library, which makes it easier to load and run OpenVINO-optimized models. For more details, please refer to the Hugging Face Optimum documentation.

For example, the following code loads the Distil-Whisper large-v2 model ready for inference with OpenVINO.


from pathlib import Path

from optimum.intel.openvino import OVModelForSpeechSeq2Seq

model_id = "distil-whisper/distil-large-v2"
model_path = Path(model_id)
if not model_path.exists():
    ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
        model_id, export=True, compile=False, load_in_8bit=False)
    ov_model.half()
    ov_model.save_pretrained(model_path)
else:
    ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
        model_path, compile=False)

Models from the Distil-Whisper family are available at Distil-Whisper Models collection and Whisper models are available at OpenAI Hugging Face page.

To transcribe an input audio with the loaded model, we first compile the model to the device of choice and then call the generate() method on input features prepared by the corresponding processor.


from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(model_id)

ov_model.to("AUTO")
ov_model.compile()

# ... load input audio and reference text
input_features = processor(input_audio).input_features
predicted_ids = ov_model.generate(input_features)
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
print(f"Reference: {reference_text}")
print(f"Result: {transcription}")

The output is the following. As you can see, the transcription matches the reference text up to casing and punctuation.

Reference: MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL
Result:  Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.

Running Post-Training Quantization with NNCF

NNCF enables post-training quantization by adding quantization layers into the model graph and then using a subset of the training dataset to initialize parameters of these additional quantization layers. During quantization, some layers (e.g., MatMuls, Convolutions) are transformed to be executed in INT8 instead of FP16/FP32. If a quantized operation is parameterized then its corresponding weight variable is also converted to INT8.

In general, the optimization process contains the following steps:

  1. Create a calibration dataset for quantization.
  2. Run nncf.quantize() to obtain quantized encoder and decoder models.
  3. Serialize the INT8 models using openvino.save_model() function.

The Whisper model consists of encoder and decoder submodels. Furthermore, for the decoder model the forward() signature is different for the first call compared to all subsequent calls. During the first call, the key-value cache is empty and is not needed for decoder inference. Starting from the second call, the key-value cache is fed to the decoder. Because of this, these two cases are represented by two separate OpenVINO models: openvino_decoder_model.xml and openvino_decoder_with_past_model.xml. Since the first decoder model is inferred only once, it does not make much sense to quantize it. So, we apply quantization to the encoder and the decoder-with-past models.

The first step towards quantization is collecting calibration data. For that, we need to collect some number of model inputs for both models. To do that, we patch OpenVINO model request objects with an InferRequestWrapper class instance that will intercept model inputs during inference and store them in a list. We infer the model on about 50 samples from the validation split of the librispeech_asr dataset.


from itertools import islice

from datasets import load_dataset


# InferRequestWrapper and extract_input_features are helpers defined in the referenced notebook
def collect_calibration_dataset(ov_model: OVModelForSpeechSeq2Seq, calibration_dataset_size: int):
    # Overwrite model request properties, saving the original ones for restoring later
    original_encoder_request = ov_model.encoder.request
    original_decoder_with_past_request = ov_model.decoder_with_past.request
    encoder_calibration_data = []
    decoder_calibration_data = []
    ov_model.encoder.request = InferRequestWrapper(original_encoder_request, encoder_calibration_data)
    ov_model.decoder_with_past.request = InferRequestWrapper(original_decoder_with_past_request,
                                                             decoder_calibration_data)

    calibration_dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)
    for sample in islice(calibration_dataset, calibration_dataset_size):
        input_features = extract_input_features(sample)
        ov_model.generate(input_features)

    ov_model.encoder.request = original_encoder_request
    ov_model.decoder_with_past.request = original_decoder_with_past_request

    return encoder_calibration_data, decoder_calibration_data

With the collected calibration data for encoder and decoder models we can proceed to quantization itself. Let's examine the quantization call for the encoder model. For the decoder model, it is similar.


import nncf
import openvino as ov

quantized_encoder = nncf.quantize(
    ov_model.encoder.model,                     # ov.Model object of the encoder model
    nncf.Dataset(encoder_calibration_data),     # calibration data wrapped in a nncf.Dataset object
    subset_size=len(encoder_calibration_data),  # number of samples to calibrate on (all are chosen)
    model_type=nncf.ModelType.TRANSFORMER,      # providing the information that Whisper encoder is of
    # a Transformer architecture
    advanced_parameters=nncf.AdvancedQuantizationParameters(smooth_quant_alpha=0.50)    # Smooth Quant 
    # algorithm reduces activation quantization error; optimal alpha was obtained through grid search
)
ov.save_model(quantized_encoder, quantized_model_path / "openvino_encoder_model.xml")
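
For the decoder-with-past model the call is analogous; a sketch is shown below (the smooth_quant_alpha value here is illustrative, the notebook selects its own value):


quantized_decoder_with_past = nncf.quantize(
    ov_model.decoder_with_past.model,
    nncf.Dataset(decoder_calibration_data),
    subset_size=len(decoder_calibration_data),
    model_type=nncf.ModelType.TRANSFORMER,
    advanced_parameters=nncf.AdvancedQuantizationParameters(smooth_quant_alpha=0.95)
)
ov.save_model(quantized_decoder_with_past, quantized_model_path / "openvino_decoder_with_past_model.xml")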

After both models are quantized and saved, the quantized Whisper model can be loaded and run the same way as shown previously. Comparing the transcriptions produced by the original and quantized models results in the following.

Original :  Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.
Quantized:  Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.

As you can see, for the quantized distil-whisper-large-v2 the transcription is the same.

Evaluating on Common Voice Dataset

We evaluate the Whisper and Distil-Whisper large-v2 model variants on the Common Voice 13.0 speech-to-text dataset. We use the en/test split containing 16,372 audio samples amounting to about 27 hours of recordings.

The evaluation is done across three model types: original PyTorch model, original OpenVINO model and quantized OpenVINO model. Additionally, we run tests on three Intel CPUs: Cascade Lake Intel(R) Core(TM) i9-10980XE, Ice Lake Intel(R) Xeon(R) Gold 6338 and Sapphire Rapids Intel(R) Xeon(R) Gold 6430L.

For all combinations above we measure transcription time and accuracy. When measuring time for a model we sum up generate() call durations for all audio samples. Transcription accuracy is represented as Accuracy = (100 - WER), where WER stands for Word Error Rate. We compute accuracy for each audio sample and then take the average value across the dataset. The results are given in the table below.

Please note that we report transcription time in relative terms such that the values for each CPU are normalized over its corresponding column. The duration of audio data in the dataset is 27.06 hours and the absolute transcription time values for Whisper large-v2 PyTorch on each CPU are:

  • 20.35 hours for Core i9-10980XE
  • 14.09 hours for Xeon Gold 6338
  • 15.03 hours for Xeon Gold 6430L

Based on the results we can conclude that:

  1. OpenVINO models execute 1.4x - 5.1x faster than PyTorch models with pretty much the same accuracy across all cases.
  2. When compared to original PyTorch models, quantized OpenVINO models provide 2.1x - 6.1x performance boost with 2-3% accuracy drop.

NOTE: In this article we focus on presenting performance values. The accuracy of quantized models can be improved with a more careful selection of calibration data.

Notices and Disclaimers:

Performance varies by use, configuration, and other factors. Learn more at www.intel.com/PerformanceIndex. Performance results are based on testing as of dates shown in configurations and may not reflect all publicly available updates. No product or component can be absolutely secure. Intel technologies may require enabled hardware, software or service activation.

The products described may contain design defects or errors known as errata which may cause the product to deviate from published specifications. Current characterized errata are available on request.

Test Configuration: Intel® Core™ i9-10980XE CPU Processor at 3.00GHz with DDR4 128 GB at 3000MHz, OS: Ubuntu 20.04.3 LTS; Intel® Xeon® Gold 6338 CPU Processor at 2.00GHz with DDR4 256 GB at 3200MHz, OS: Ubuntu 20.04.3 LTS; Intel® Xeon® Gold 6430L CPU Processor at 1.90GHz with DDR5 1024 GB at 4800MHz, OS: Ubuntu 20.04.6 LTS. Testing was performed using 267-distil-whisper-asr notebook for model export and whisper evaluation notebook for model evaluation.

The test was conducted by Intel in December 2023.

Conclusion

We demonstrated how to load and run Whisper and Distil-Whisper models for the audio transcription task with OpenVINO and Optimum Intel, and how to perform INT8 post-training quantization of these models with NNCF. Further, we evaluated these models on a large-scale speech-to-text dataset across multiple CPU devices. The evaluation results show a significant performance boost of OpenVINO vs. PyTorch models without loss of transcription quality, and an even larger boost with a tolerable accuracy drop when we apply INT8 quantization.

Wenyi Zou

Enable tokenize and detokenize by creating OpenVINO™ model and CPP runtime pipeline

Authors: Wenyi Zou, Su Yang

Introduction

Tokenization is the process of breaking down text into smaller units, such as words or subwords, known as tokens. These tokens are the building blocks for various NLP tasks, including text classification, sentiment analysis, machine translation, and more. Tokenization makes it easier for machines to understand and work with textual data by converting unstructured text into a structured format.

This article will demonstrate a C++ application of the tokenize and detokenize model with Intel’s OpenVINO™ C++ API on Linux/Windows systems.

Model Conversion

The OpenVINO™ Tokenizers project has a tool to convert a HuggingFace tokenizer into OpenVINO™ IR tokenizer and detokenizer models: it provides the convert_tokenizer function that accepts a tokenizer object.

Install system dependencies and set up the environment

Step 1: Download the code

git clone https://github.com/openvinotoolkit/openvino.genai.git
cd openvino.genai
git submodule update --init

Step 2: Install OpenVINO™ Archives >= 2023.3. <INSTALL_DIR> below refers to the extraction location.

For Linux

source <INSTALL_DIR>/setupvars.sh

For Windows

<INSTALL_DIR>\setupvars.bat

Step 3: Create a Python environment

conda create -n ov_genai python=3.10
cd text_generation/causal_lm/cpp
python -m pip install --upgrade-strategy eager "optimum>=1.14" -r ../../../llm_bench/python/requirements.txt ../../../thirdparty/openvino_contrib/modules/custom_operations/[transformers] --extra-index-url https://download.pytorch.org/whl/cpu

Step 4: Convert the tokenizer to OpenVINO™ IR.

Take the tokenizer and detokenizer of chatglm3-6b as an example.

convert_tokenizer ./chatglm3-6b/ --output ./chatglm3-6b/ov/ --with-detokenizer --trust-remote-code

Build custom OpenVINO operation extension library

cd thirdparty/openvino_contrib/modules/custom_operations
mkdir build && cd build
cmake ../ -DCMAKE_BUILD_TYPE=Release
cmake --build . --parallel 4

Load and use custom OpenVINO operation extension library


#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <openvino/openvino.hpp>
namespace {
std::pair<ov::Tensor, ov::Tensor> tokenize(ov::InferRequest& tokenizer, std::string&& prompt) {
    constexpr size_t BATCH_SIZE = 1;
    tokenizer.set_input_tensor(ov::Tensor{ov::element::string, {BATCH_SIZE}, &prompt});
    tokenizer.infer();
    return {tokenizer.get_tensor("input_ids"), tokenizer.get_tensor("attention_mask")};
}

std::string detokenize(ov::InferRequest& detokenizer, std::vector<int64_t>& tokens) {
    constexpr size_t BATCH_SIZE = 1;
    detokenizer.set_input_tensor(ov::Tensor{ov::element::i64, {BATCH_SIZE, tokens.size()}, tokens.data()});
    detokenizer.infer();
    return detokenizer.get_output_tensor().data<std::string>()[0];
}
}  // namespace

int main(int argc, char* argv[]) try {
    if (argc != 3) {
        throw std::runtime_error(std::string{"Usage: "} + argv[0] + " <MODEL_DIR> '<PROMPT>'");
}

    ov::Core core;

#ifdef _WIN32
    core.add_extension("path/to/user_ov_extensions.dll");
#else
    core.add_extension("path/to/libuser_ov_extensions.so");
#endif

    std::cout << "promt " << argv[2] << std::endl;

    // tokenizer and detokenizer work on CPU only
    ov::InferRequest tokenizer = core.compile_model(
        std::string{argv[1]} + "/openvino_tokenizer.xml", "CPU").create_infer_request();
    auto [input_ids, attention_mask] = tokenize(tokenizer, argv[2]);
    ov::InferRequest detokenizer = core.compile_model(
        std::string{argv[1]} + "/openvino_detokenizer.xml", "CPU").create_infer_request();
    
    std::cout << "input lenghth " << input_ids.get_size() << std::endl;

    std::vector<int64_t> tokens;
    for (size_t idx = 0; idx < input_ids.get_size(); ++idx) {
        tokens.push_back(input_ids.data<int64_t>()[idx]);
    }
    
    std::cout << detokenize(detokenizer, tokens) << std::endl;
    
} catch (const std::exception& error) {
    std::cerr << error.what() << '\n';
    return EXIT_FAILURE;
} catch (...) {
    std::cerr << "Non-exception object thrown\n";
    return EXIT_FAILURE;
}

Anna Likholat

OpenVINO Latent Consistency Model C++ pipeline with LoRA model support

January 25, 2024

Introduction

Latent Consistency Models (LCMs) are the next generation of generative models after Latent Diffusion Models (LDMs). While Latent Diffusion Models (LDMs) like Stable Diffusion are capable of achieving outstanding generation quality, they often suffer from the slowness of the iterative image denoising process. LCM is an optimized version of LDM. Inspired by Consistency Models (CM), Latent Consistency Models (LCMs) enable swift inference with minimal steps on any pre-trained LDM, including Stable Diffusion. Consistency Models are a new family of generative models that enable one-step or few-step generation. More details about the proposed approach and models can be found using the following resources: project page, paper, original repository.

This article will demonstrate a C++ application of the LCM model with Intel’s OpenVINO™ C++ API on Linux systems. For model inference performance and accuracy, the C++ pipeline is well aligned with the Python implementation.

The full implementation of the LCM C++ demo described in this post is available on GitHub: openvino.genai/lcm_dreamshaper_v7.

Model Conversion

To leverage efficient inference with OpenVINO™ runtime on Intel platforms, the original model should be converted to OpenVINO™ Intermediate Representation (IR).

LCM model

Optimum Intel can be used to load the SimianLuo/LCM_Dreamshaper_v7 model from the Hugging Face Hub and convert the PyTorch checkpoint to OpenVINO™ IR on the fly by setting export=True when loading the model, like:

from optimum.intel.openvino import OVLatentConsistencyModelPipeline

model = OVLatentConsistencyModelPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", export=True)
model.save_pretrained("ov_lcm_model")
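
As an optional sanity check before moving to C++, the converted pipeline can be run directly from Python (the prompt and step count below are illustrative):

prompt = "a beautiful pink unicorn"
image = model(prompt, num_inference_steps=4, guidance_scale=8.0).images[0]
image.save("lcm_result.png")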

Tokenizer

OpenVINO Tokenizers is an extension that adds text processing operations to OpenVINO Inference Engine. In addition, the OpenVINO Tokenizers project has a tool to convert a HuggingFace tokenizer into OpenVINO IR model tokenizer and detokenizer: it provides the convert_tokenizer function that accepts a tokenizer Python object and returns an OpenVINO Model object:

from transformers import AutoTokenizer
from openvino_tokenizers import convert_tokenizer
from openvino import compile_model, save_model

# tokenizer_path is the Hugging Face model id or a local directory containing the tokenizer files
hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
ov_tokenizer_encoder = convert_tokenizer(hf_tokenizer)
save_model(ov_tokenizer_encoder, "ov_tokenizer.xml")
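
Optionally, the converted tokenizer can be compiled right away as a minimal check that the conversion produced a loadable model:

compiled_tokenizer = compile_model(ov_tokenizer_encoder, "CPU")
print(compiled_tokenizer.inputs, compiled_tokenizer.outputs)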

Note: Currently OpenVINO Tokenizers can be inferred on CPU devices only.

Conversion step

You can find the full script for model conversion at the original repo.

Note: The tutorial assumes that the current working directory is <openvino.genai repo>/image_generation/lcm_dreamshaper_v7/cpp and all paths are relative to this folder.

Let’s prepare a Python environment and install dependencies:

conda create -n openvino_lcm_cpp python==3.10
conda activate openvino_lcm_cpp
conda install -c conda-forge 'openvino>=2023.3.0'
python -m pip install -r scripts/requirements.txt
python -m pip install ../../../thirdparty/openvino_contrib/modules/custom_operations/[transformers]

Now we can use the script scripts/convert_model.py to download and convert models:

cd scripts
python convert_model.py -lcm "SimianLuo/LCM_Dreamshaper_v7" -t FP16

C++ Pipeline

Pipeline flow

Let’s now talk about the logical structure of the LCM model pipeline.

Just like the classic Stable Diffusion pipeline, the LCM pipeline consists of three important parts:
- A text encoder to create a condition to generate an image from a text prompt.
- U-Net for step-by-step denoising the latent image representation.
- Autoencoder (VAE) for decoding the latent space to an image.

The pipeline takes a latent image representation and a text prompt transformed to a text embedding via CLIP’s text encoder as an input. The initial latent image representation is generated using a random noise generator. LCM uses the guidance scale to obtain time-step conditional embeddings as input for the diffusion process, while in Stable Diffusion it is used for scaling the output latents.

Next, the U-Net iteratively denoises the random latent image representations while being conditioned on the text embeddings. The output of the U-Net, being the noise residual, is used to compute a denoised latent image representation via a scheduler algorithm. LCM introduces its own scheduling algorithm that extends the denoising procedure introduced by denoising diffusion probabilistic models (DDPMs) with non-Markovian guidance. The denoising process is repeated a given number of times to retrieve progressively better latent image representations. When complete, the latent image representation is decoded by the decoder part of the variational autoencoder.

The C++ implementations of the scheduler algorithm and LCM pipeline are available at the following links: LCM Scheduler, LCM Pipeline.

LoRA support

LoRA (Low-Rank Adaptation) is a training technique for fine-tuning Stable Diffusion models. There are various LoRA models available on https://civitai.com/tag/lora.

The main idea of LoRA weights enabling is to append the weights onto the OpenVINO LCM models at runtime before compiling the Unet/text_encoder models. The method is to extract the LoRA weights from the safetensors file, find the corresponding weights in the Unet/text_encoder models and insert the LoRA bias weights. The common approach to adding LoRA weights looks like this:

The original LoRA safetensors model is loaded via safetensors.h. The LoRA layer names and weights are modified with the Eigen library and inserted into the Unet/text_encoder OpenVINO models using ov::pass::MatcherPass - you can see the implementation in the file common/diffusers/src/lora.cpp.

To run the LCM demo with a LoRA model, first download the LoRA weights, for example: LoRa/Soulcard.

Build and Run LCM demo

Let’s start with the dependencies installation:

conda activate openvino_lcm_cpp
conda install -c conda-forge eigen c-compiler cxx-compiler make

Now we can build the application:

cmake -DCMAKE_BUILD_TYPE=Release -S . -B build
cmake --build build --config Release --parallel
cd build

And finally we’re ready to run the LCM demo. By default the positive prompt is set to: “a beautiful pink unicorn”.

Please note that the quality of the resulting image depends on the quality of the random noise generator, so there is a difference between output images generated by the C++ noise generator and the PyTorch generator. Use option -r to read the PyTorch-generated noise from the provided text files for alignment with the Python pipeline.

Note: Run ./lcm_dreamshaper -h to see all the available demo options

Let’s try to run the application in a few modes:

Read the NumPy latent input and noise for the scheduler instead of the C++ std lib for alignment with the Python pipeline: ./lcm_dreamshaper -r

Generate an image with C++ std lib generated latent and noise: ./lcm_dreamshaper

Generate an image with Soulcard LoRA and C++ generated latent and noise: ./lcm_dreamshaper -r -l path/to/soulcard.safetensors

See Also

  1. Optimizing Latent Consistency Model for Image Generation with OpenVINO™ and NNCF
  2. Image generation with Latent Consistency Model and OpenVINO
  3. C++ Pipeline for Stable Diffusion v1.5 with Pybind for Lora Enabling
  4. Enable LoRA weights with Stable Diffusion Controlnet Pipeline

Alexander Kozlov

Q4'23: Technology Update – Low Precision and Model Optimization

December 20, 2023

Authors

Alexander Kozlov, Nikita Savelyev, Nikolay Lyalyushkin, Vui Seng Chua, Pablo Munoz, Alexander Suslov, Andrey Anufriev, Liubov Talamanova, Yury Gorbachev, Nilesh Jain, Maxim Proshin

Summary

This quarter we observe that most of the work is still dedicated to Large Language Model optimization. Researchers try to break through the W4A8 quantization setup for LLMs, achieving accuracy results that allow considering such optimized models for deployment scenarios. Some teams work on lower-precision settings such as 2-bit weight quantization or even binary weight compression. Interestingly, some teams propose to stick to a higher bit-width (FP6) and a data-free optimization approach to avoid overfitting to calibration data. We also see an increasing interest in applying various types of weight sparsity in LLMs. And, of course, we should note the tremendous improvement in the inference time of Diffusion models caused by the decrease in the overall number of iterations in the diffusion process. This allows running variations of Stable Diffusion on mobile devices in below 1 second.

Papers with notable results

Quantization

  • AWEQ: Post-Training Quantization with Activation-Weight Equalization for Large Language Models by Jilin University (https://arxiv.org/pdf/2311.01305.pdf). Authors apply a known recipe for DL models quantization to LLM models. It contains weight equalization and bias correction methods stacked together. The difference is in how they estimate parameters and where to apply both methods. The method shows good results for W8A8 and W4A8 settings and outperforms GPTQ method on LLAMA and OPT models.
  • AFPQ: Asymmetric Floating Point Quantization for LLMs by China Universities and Microsoft Research Asia (https://arxiv.org/pdf/2311.01792.pdf). Authors propose an accurate asymmetric scheme for floating-point quantization. Instead of using the typical asymmetric scheme with a scale and zero point, they use just 2 scales: one for positive values and another for negative ones. It gives better accuracy for NF4/NF3 quantization on different LLAMA models with no memory overhead. Code is available: https://github.com/zhangsichengsjtu/AFPQ.
  • Enhancing Computation Efficiency in Large Language Models through Weight and Activation Quantization by Hanyang University, SAPEON Korea Inc., Seoul National University (https://arxiv.org/pdf/2311.05161.pdf). Authors present two techniques: activation-quantization-aware scaling (a trade-off between SQ and AWQ) and sequence-length-aware calibration (adaptation of OPTQ to various sequence lengths) to enhance PTQ by considering the combined effects on weights and activations and aligning calibration sequence lengths to target tasks. They also introduce dINT, a hybrid data format combining integer and denormal representations, to address the underflow issue in W4A8 quantization, where small values are rounded to zero. The combined approach allows for achieving superior results compared to baselines. However, dINT has a limitation of efficient implementation on a general-purpose HW, such as CPU and GPU.
  • POST-TRAINING QUANTIZATION WITH LOW-PRECISION MINIFLOATS AND INTEGERS ON FPGAS by AMD Research, National University of Singapore, and Tampere University (https://arxiv.org/pdf/2311.12359.pdf). Authors compare integer and mini float quantization techniques, encompassing a combination of state-of-the-art PTQ methods, such as weight equalization, bias correction, SmoothQuant, learned rounding, and GPTQ. They explore the accuracy-hardware tradeoffs, providing analysis for three models - ResNet-18, MobileNetV2, and ViT-B32 - based on a custom FPGA implementation. Experiments indicate that mini float quantization typically outperforms integer quantization for bit-widths of four or more, both for weights and activations. However, when compared against the FPGA hardware cost model, integer quantization often retains its Pareto optimality due to its smaller hardware footprint at a given precision.
  • I&S-ViT: An Inclusive & Stable Method for Pushing the Limit of Post-Training ViTs Quantization by Xiamen University, Tencent, and Peng Cheng Laboratory (https://arxiv.org/pdf/2311.10126.pdf). The paper introduces a method that regulates the PTQ of ViTs in an inclusive and stable fashion. It first identifies two issues in the PTQ of ViTs: (1) Quantization inefficiency in the prevalent log2 quantizer for post-Softmax activations; (2) Rugged and magnified loss landscape in coarse-grained quantization granularity for post-LayerNorm activations. Then, the method addresses these issues by introducing: (1) A novel shift-uniform-log2 quantizer (SULQ) that incorporates a shift mechanism followed by uniform quantization to achieve both an inclusive domain representation and accurate distribution approximation; (2) A three-stage smooth optimization strategy that amalgamates the strengths of channel-wise and layer-wise quantization to enable stable learning. The method achieves comparable results in the W4A4 and W3A3 quantization settings.
  • Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing by Qualcomm AI Research (https://arxiv.org/pdf/2306.12929.pdf). This work aims to remove outliers by construction (pretraining) so that transformer can be quantized easily without the need of finer quantization granularity (e.g. per channel, per group). The authors root-caused that outliers in trained transformers are essentially the artifact of attention head attenuating uninformative tokens and outliers emerged in the formulation/backpropagation of softmax, residual connections and layer normalization to sustain the effect of these tokens. Two independent solutions are proposed - (1) Clipped softmax that allows exact zeros and ones in softmax to avoid growing outliers during training. (2) Gated attention which is a tiny neural network (linear+sigmoid) added to the vanilla attention to decouple the needs of large attention output for disregarding the uninformative tokens. Pretraining with the proposed formulation on BERT, OPT and ViT has been shown to converge similarly if not better than baseline recipe. Most notably, the ease of per-tensor int8 static quantization to both weight and activation in post-training fashion has been empirically verified. Code is coming soon at https://github.com/qualcomm-ai-research/outlier-free-transformers
  • A Speed Odyssey for Deployable Quantization of LLMs by Meituan (https://arxiv.org/pdf/2311.09550.pdf). Authors propose a solution for deployable W4A8 quantization that comprises a tailored quantization configuration and a novel Fast GEMM kernel for 4-bit integer matrix multiplication that reduces the cost, and it achieves 2.23× and 1.45× speed boosting over the TensorRT-LLM FP16 and INT8 implementation respectively. The W4A8 recipe is proven mostly on par with the state-of-the-art W8A8 quantization method SmoothQuant on a variety of common language benchmarks for the state-of-the-art LLMs.
  • TFMQ-DM: Temporal Feature Maintenance Quantization for Diffusion Models by SenseTime and universities of US, China and Australia (https://arxiv.org/pdf/2311.16503.pdf). Authors investigate the problems in quantization of diffusion models and claim that these models heavily depend on the time-step t to achieve satisfactory multi-round denoising when t is encoded to a temporal feature by a few modules totally irrespective of the sampling data. They propose a Temporal Feature Maintenance Quantization framework building upon a Temporal Information Block which is just related to the time-step t and unrelated to the sampling data. Powered by this block design, authors propose temporal information aware reconstruction and finite set calibration to align the full-precision temporal features in a limited time. The method achieves accurate results even in W4A8quantization setting.
  • QUIK: TOWARDS END-TO-END 4-BIT INFERENCE ON GENERATIVE LARGE LANGUAGE MODELS by ETH Zurich, Institute of Science and Technology Austria, Xidian University, KAUST, Neural Magic (https://arxiv.org/pdf/2310.09259v2.pdf). The paper addresses the problem where both weights and activations should be quantized. Authors show that the majority of inference computations for large generative models such as LLaMA, OPT, and Falcon can be performed with both weights and activations being cast to 4 bits, in a way that leads to practical speedups, while at the same time maintaining good accuracy. They achieve this via a hybrid quantization strategy called QUIK, which compresses most of the weights and activations to 4-bit, while keeping some outlier weights and activations in higher precision which leads to practical end-to-end throughput improvements of up to 3.4x relative to FP16 execution. Code is available at: https://github.com/IST-DASLab/QUIK.
  • Post-training Quantization with Progressive Calibration and Activation Relaxing for Text-to-Image Diffusion Models by Tsinghua University (https://arxiv.org/pdf/2311.06322.pdf). Authors propose a post-training quantization method for text-to-image diffusion models, which consists of a progressive calibration strategy that considers the accumulated quantization error across timesteps, and an activation relaxing strategy that improves the performance with a small cost. They also propose a new QDiffBench benchmark, which utilizes data in the same domain for a more accurate evaluation of the generation accuracy.
  • Enabling Fast 2-bit LLM on GPUs: Memory Alignment, Sparse Outlier, and Asynchronous Dequantization by Shanghai Jiao Tong University and Tsinghua University (https://arxiv.org/pdf/2311.16442.pdf). The paper proposes range-aware quantization with memory alignment. It points out that the range of weights by groups varies. Thus, only 25% of the weights are quantized using 4-bit with memory alignment. Such a method reduces the accuracy loss for 2-bit Llama2-7b quantization from 8.7% to 2.9%. Authors show as well that only a small fraction of outliers exist in weights quantized using 2-bit. They quantize these sparse outliers with < 3% increased average weight bit and improve the accuracy by >0.5%. They also accelerate GPU kernels by introducing asynchronous dequantization achieving 3.92× improvement on a kernel level.
  • ZeroQuant(4+2): Redefining LLMs Quantization with a New FP6-Centric Strategy for Diverse Generative Tasks by DeepSpeed (https://arxiv.org/pdf/2312.08583.pdf). Authors show that popular data-aware LLM compression methods such as GPTQ can overfit to calibrated datasets especially on moderate-size LLMs (<=1B). They also illustrate that FP6, employing a basic round-to-nearest (RTN) algorithm and a per-channel quantization approach, consistently achieves accuracy on par with full-precision models. They propose an unpacking (mapping) scheme for FP8 so that it can be efficiently used with FP16 inference.

Pruning/Sparsity

  • Sparse Fine-tuning for Inference Acceleration of Large Language Models by IST Austria, Skoltech & Yandex, and Neural Magic (https://arxiv.org/pdf/2310.06927.pdf). The paper analyses the challenges of LLMs pruning, namely loss spikes leading to divergence, poor recovery from fine-tuning, and overfitting. To overcome these issues, authors incorporate standard cross-entropy, output knowledge distillation, and a type of per-token ℓ2 knowledge distillation on top of SparseGPT method. They show that the resulting sparse models can be executed with inference speedups on CPU and GPU, especially when stacking with INT8quantization. The code is available: https://github.com/IST-DASLab/SparseFinetuning.
  • ReLU Strikes Back: Exploiting Activation Sparsity in LLMs by Apple. (https://arxiv.org/pdf/2310.04564.pdf). This work advocates to reinstate ReLU as main activation function in LLMs due to its intriguing property – high post-ReLU activation sparsity can be translated to computational efficiency with sparse runtime. To overcome training from scratch and for LLMs employing GeLU/SiLU, the paper proposes “Relufication”, a two-stage uptraining by first replacing non-ReLU activations in pre-trained LLMs with ReLU, and then appending ReLU to normalization layers in second stage for more sparsity. With increase of activation sparsity, the authors also observe high overlapping activated neurons during decoding (termed aggregated sparsity) and suggest weight reuse to alleviate memory transfer. The authors show application of aggregated sparsity in speculative decoding and demonstrate 27% speedup of OPT-6.7B at a minor degradation of perplexity.
  • Deja Vu: Contextual Sparsity for Efficient LLMs at Inference Time by Rice University, Zhejiang University, Stanford University, University of California, ETH Zurich, Adobe Research, Meta AI, Carnegie Mellon University (https://proceedings.mlr.press/v202/liu23am/liu23am.pdf). Authors propose a contextual dynamic sparsity method for LLMs. Contrary to a usual sparsity approach where a model is pruned once and then inferenced on every input sample in the same pruned state, here authors compute the set of pruned operations on the run, resulting in a more flexible pruning scheme. For each transformer layer this is achieved by predicting the set of MHA heads and MLP matrix columns to exclude based on previous layers' activations. Prediction is performed by small separately trained perceptron networks. To remove the performance bottleneck produced by the need to inference the perceptron networks, authors propose to make sparsity predictions for the (i+1)-th transformer layer based on activations from the (i-1)-th layer, resulting in parallel computation of the i-th layer and sparsity sets for the (i+1)-th layer. This is viable due to the shown similarity between activations of neighboring layers in LLMs. For the OPT-175B model the approach achieves over 6x performance improvement compared to the Hugging Face implementation and over 2x improvement compared to the state-of-the-art FasterTransformer model.
  • SparseByteNN: A Novel Mobile Inference Acceleration Framework Based on Fine-Grained Group Sparsity by ByteDance (https://arxiv.org/pdf/2310.19509.pdf). The paper introduces SparseByteNN, consisting of three components: a) compression algorithm component, which provides out-of-the-box pruning capabilities for pre-trained models b) model conversion tool, which converts the model IR of the training framework into Model IR of sparse engine c) sparse inference engine, which provides efficient inference implementation compatible with CPUs for fine-grained kernel group sparsity. Experimental results on Qualcomm 855 show that for 30% sparse MobileNet-v1, SparseByteNN achieves 1.27× speedup over the dense version. The code will be available at: https://github.com/lswzjuer/SparseByteNN.

Neural Architecture Search

  • LoNAS: Elastic Low-Rank Adapters for Efficient Large Language Models by Anonymous (https://openreview.net/pdf?id=pzB-1OCS6gd). Researchers demonstrate a novel integration of low-rank (LoRA) adapters with Neural Architecture Search. LoNAS efficiently fine-tunes and compresses large language models (LLMs). A weight-sharing super-network is generated using the frozen weights of the input model, and the attached elastic low-rank adapters. The reduction in trainable parameters results in lower memory requirements to train the super-network, enabling the manipulation of LLMs in resource-constrained devices, without sacrificing the performance of the resulting compressed models. LoNAS’ high-performing compressed models result in faster inference times, cost savings during the model’s lifetime, and an increase in the range of devices in which large language models can be deployed. Experiments’ results on six reasoning datasets demonstrate the benefits of LoNAS.
  • Bridging the Gap between Foundation Models and Heterogenous Federated Learning by Iowa State U. and Intel Labs (https://arxiv.org/abs/2310.00247). This paper explores the application of Neural Architecture Search (NAS) in combination with Federated Learning (FL). The proposed framework, Resource-aware Federated Foundation Models (RaFFM) introduces model compression and salient parameter prioritization in the context of Federated Learning, allowing for the collaborative training of large foundation models using heterogeneous devices. Compared to traditional FL methods, RaFFM yields better resource utilization, without sacrificing in model performance.
  • Rankitect: Ranking Architecture Search Battling World-class Engineers at Meta Scale by Meta Platforms (https://arxiv.org/pdf/2311.08430.pdf). Researchers at Meta demonstrate the real-world applications of Neural Architecture Search (NAS). They apply NAS to production models, e.g., Click Through Rate (CTR) model, on a system that serves billions of users. The baseline models explored in this work have already been optimized by world-class engineers. The proposed NAS framework, Rankitect, improves over existing models by exploring search spaces with no inductive bias from the baseline models, and discovers new models from scratch that outperform those hand-crafted by human experts. Rankitect also keeps human engineers in-the-loop by allowing the manual design of search spaces, which results in even more efficient models. 
  • QuadraNet: Improving High-Order Neural Interaction Efficiency with Hardware-Aware Quadratic Neural Networks by George Mason University, University of Maryland, University at Buffalo, Peking University (https://arxiv.org/pdf/2311.17956.pdf). This paper presents QuadraNet, a new neural network design methodology based on efficient quadratic neurons that captures high-order neural interactions similar to Transformer-based models. The design of this alternative to Transformer-based models is hardware-aware through the application of Neural Architecture Search(NAS). Experiments with QuadraNets show improvements of 1.5x in throughput without any reduction in accuracy compared to their Transformer-based counterparts.

Other

  • Divergent Token Metrics: Measuring degradation to prune away LLM components – and optimize quantization by Aleph Alpha, Hessian.AI and German Universities. (https://arxiv.org/pdf/2311.01544.pdf). The work highlights that the commonly used perplexity (PPL) metric in compression research does not reflect the degradation of compressed model and cannot distinguish subtleties (see figure below). The authors propose a family of divergent token metrics (DTM), namely First Token Divergence, i.e., when the first diverging token happens w.r.t baseline generated text, as well as Share of Divergent Tokens denoting the total number of divergent tokens. In a series of experiments pertaining to layer-wise quantization or pruning, DTM-based ranking consistently outperforms PPL-based ranking methods.
  • SIMPLIFYING TRANSFORMER BLOCKS by ETH Zurich (https://arxiv.org/pdf/2311.01906.pdf). The paper introduces a set of Transformer block pruning techniques that makes them more lightweight from the number of parameters and computations standpoint. This set includes: removing the skip connection both in the Attention sub-block and the Feed-Forward sub-block, removing value and projection parameters, removing normalization layers, and model depth scaling. Authors also show how to recover the accuracy after model perturbation using fine-tuning. The proposed method produces decoder and encoder Transformer models that perform on par with their baselines: 15% faster training throughput, and using 15% fewer parameters.
  • MobileDiffusion: Subsecond Text-to-Image Generation on Mobile Devices by Google (https://arxiv.org/pdf/2311.16567.pdf). The paper presents a comprehensive guide for crafting highly efficient text-to-image diffusion models. Authors applied the following tricks to highly optimize the UNet model in the Diffusion pipeline: more transformers in the middle of the UNet (at lower resolution), retaining cross-attention layers while discarding only the self-attention layers at high resolutions, sharing key-value projections, replacing gelu with swish, fine-tune softmax into relu, trim feed-forward layers, use separable convolution, prune redundant residual blocks, reduce sampling iterations, knowledge distillation. The resulting model is able to generate 512×512 images in sub-second on mobile devices: 0.2 second on iPhone 15 Pro.
  • Online Speculative Decoding by UC Berkeley, UCSD, Sisu Data, SJTU (https://arxiv.org/pdf/2310.07177.pdf). Practical speedup of speculative decoding is often impeded by the capability gap between draft and target model which can be 10-20X gap in parameters, leading to high rejection of draft predictions and fallback to more forward passes of target model. This work proposes online fine-tuning of draft model by distillation with the readily available rejected predictions. The proposed solution incorporates a replay buffer tracking logits of draft and target model, and distillation backpropagation is executed at a regular interval. Experimental results demonstrate not only significant improvement in acceptance rate, translating up to a theoretical 3X of latency reduction, but also adaptability against distribution shift in input queries.
  • Token Fusion: Bridging the Gap between Token Pruning and Token Merging by Michigan State University and Samsung Research America (https://arxiv.org/pdf/2312.01026.pdf). The paper introduces a method (ToFu) that combines token pruning and token merging. ToFu dynamically adapts to each layer’s properties, ensuring optimal performance based on the model’s functional linearity with respect to the interpolation in its input. Authors exploit the MLERP merging technique, an enhancement over traditional average merging, inspired by the SLERP method. This approach merges tokens while preserving their norm distribution. Evaluation shows that ToFu outperforms ToMe in terms of accuracy while showing a similar performance gain at inference.

Deep Learning Software

  • Medusa: Simple Framework for Accelerating LLM Generation with Multiple Decoding Heads by Together.AI. Contemporary speculative decoding solutions (Leviathan et al., Chen et al.) require multiple models (target and draft models) which often involves intricate optimization & selection of draft models to attain practical acceleration. For simplicity, Together.AI unveils a user-friendly framework, Medusa, built on top of a 2018 research work, "Blockwise Parallel Decoding", with multiple enhancements. Medusa simplifies the creation of draft models without separate models by extending the base model with multiple decoding heads. By keeping the base model frozen, the Medusa heads are trained in a parameter-efficient way and all on a single GPU. Medusa also features a tree-based attention mechanism for parallel evaluation of the proposed candidates, and a truncated sampling for efficient creative generation. Results and framework can be found at https://github.com/FasterDecoding/Medusa.
  • HyperAttention: Long-context Attention in Near-Linear Time by Yale University and Google. The authors propose an algorithm that consists of (1) finding heavy entries in the attention matrix and (2) column subsampling. For (1), the authors use sorted locality-sensitive hashing (sortLSH) based on the Hamming distance. Applying sortLSH (sorting rows/columns) places the heavy entries of the attention matrix near the diagonal, so the authors can use a block-diagonal approximation, which can be computed quickly. The method supports causal masking. Code is available here: https://github.com/insuhan/hyper-attn.
  • Flash-Decoding for long-context inference by Stanford University. Flash Attention v1 & v2 are designed and optimized primarily for training and exhibit low utilization of compute units when applied to LLM generation, especially for long contexts. The identified cause is the low batch size (of query tokens) relative to the context length. Flash-Decoding extends flash attention by adding a second level of tiling over the keys/values to improve compute utilization while retaining the memory efficiency of flash attention. On A100, the micro-benchmark for multi-head attention with the flash decoding kernel achieves almost constant run-time as the sequence length scales up to 64k, translating to up to an 8X speedup of CodeLLaMa-34b over vanilla flash attention at very long sequences. The implementation is available in the official flash attention repo and xformers.
  • LLM in a flash: Efficient Large Language Model Inference with Limited Memory by Apple (https://arxiv.org/pdf/2312.11514.pdf). The paper tackles the challenge of efficiently running LLMs that exceed the available DRAM capacity by storing the model parameters on flash memory and bringing them into DRAM on demand. The method involves constructing an inference cost model that harmonizes with flash memory behavior, which guides optimization in two critical areas: reducing the volume of data transferred from flash and reading data in larger, more contiguous chunks. Within this flash-memory-informed framework, the authors introduce two principal techniques. First, “windowing” strategically reduces data transfer by reusing previously activated neurons; second, “row-column bundling”, tailored to the sequential data access strengths of flash memory, increases the size of the data chunks read from flash memory. These methods collectively enable running models up to twice the size of the available DRAM, with a 4-5x and 20-25x increase in inference speed compared to naive loading approaches on CPU and GPU, respectively.
Read More...
Liubov
Talamanova

Optimizing Latent Consistency Model for Image Generation with OpenVINO™ and NNCF

November 23, 2023

Authors: Liubov Talamanova, Ekaterina Aidova, Alexander Kozlov

Introduction

Latent Diffusion Models (LDMs) have revolutionized AI-generated art. This technology enables the creation of high-quality images simply by writing a text prompt. While LDMs like Stable Diffusion achieve outstanding generation quality, they often suffer from the slowness of the iterative image denoising process. The Latent Consistency Model (LCM) is an optimized version of an LDM. Inspired by Consistency Models (CMs), LCMs enable swift inference with minimal steps on any pre-trained LDM, including Stable Diffusion. Consistency Models are a new family of generative models that enable one-step or few-step generation. More details about the proposed approach and models can be found using the following resources: project page, paper, original repository.

Similar to the original Stable Diffusion pipeline, the LCM pipeline consists of three important parts (a simplified sketch of how they chain together follows the list):

  • Text Encoder to create a condition for generating an image from a text prompt.
  • U-Net for step-by-step denoising of the latent image representation.
  • Autoencoder (VAE) for decoding the latent space into an image.
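
The sketch below illustrates how these three stages chain together at inference time once each part has been converted to OpenVINO IR. It is a simplified illustration only: the IR file names, the input layout of each model, the timestep values, and the latent update are placeholder assumptions, not the exact interfaces used in the OpenVINO LCM Notebook (which drives the loop with the LCM scheduler).

import numpy as np
import openvino as ov

core = ov.Core()
# Hypothetical IR file names; the real paths come from the conversion step described below.
text_encoder = core.compile_model("text_encoder.xml", "CPU")
unet = core.compile_model("unet.xml", "CPU")
vae_decoder = core.compile_model("vae_decoder.xml", "CPU")

# 1. Text Encoder: tokenized prompt -> text embeddings (tokenization omitted here).
token_ids = np.zeros((1, 77), dtype=np.int64)  # placeholder token ids
text_embeddings = text_encoder([token_ids])[0]

# 2. U-Net: a few denoising iterations over a random latent.
#    The real pipeline uses the LCM scheduler; the update below is only a stand-in.
latents = np.random.randn(1, 4, 64, 64).astype(np.float32)
for timestep in (999, 759, 519, 279):  # example 4-step schedule, not the exact one
    noise_pred = unet([latents, np.array([timestep], dtype=np.int64), text_embeddings])[0]
    latents = latents - 0.1 * noise_pred  # placeholder for scheduler.step(...)

# 3. VAE decoder: latent representation -> image tensor.
image = vae_decoder([latents])[0]
print(image.shape)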

In this post, we explain how to optimize LCM inference with OpenVINO for Intel hardware. Since LCM is trained to be resistant to perturbations, we can apply common optimization methods such as quantization to lower the precision while still expecting a consistent generation result. Here we apply 8-bit Post-Training Quantization from the Neural Network Compression Framework (NNCF).

Convert models to OpenVINO format

To leverage efficient inference with the OpenVINO runtime on Intel platforms, the original model should be converted to OpenVINO Intermediate Representation (IR). OpenVINO supports the conversion of PyTorch models directly via the Model Conversion API. The ov.convert_model function accepts an instance of a PyTorch model and example inputs for tracing, and returns an object of the ov.Model class that is ready to use or to save on disk using the ov.save_model function. You can find the conversion details for LCM in the OpenVINO LCM Notebook.
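
As an illustration, here is a minimal, self-contained sketch of this API using a toy module as a stand-in for the real LCM components (the exact export code for the text encoder, U-Net, and VAE is in the notebook):

import torch
import openvino as ov

# A tiny stand-in module; in the notebook the real LCM parts (text encoder,
# U-Net, VAE decoder) are passed to ov.convert_model instead.
class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)

ov_model = ov.convert_model(TinyModel(), example_input=torch.randn(1, 4, 64, 64))
ov.save_model(ov_model, "model.xml")  # writes model.xml and model.bin (weights compressed to FP16 by default)

# The saved IR can later be loaded and compiled for a target device:
core = ov.Core()
compiled_model = core.compile_model("model.xml", "CPU")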

Processing time of the diffusion model

The diffusion pipeline requires multiple iterations to generate an image. Each iteration takes a non-negligible amount of time, depending on your inference device. We benchmarked the Stable Diffusion pipeline on an Intel(R) Core(TM) i9-10980XE CPU @ 3.00GHz. The number of iterations was set to 10.

Benchmarking results:

Average Latency: 6.54 seconds

  • Encoding Phase: text encoding: 0.05 seconds
  • Denoising Loop: 4.28 seconds (U-Net part, 4 iterations: 4.27 seconds; scheduler: 0.01 seconds)
  • Decoding Phase: VAE decoding: 2.21 seconds

The U-Net part of the denoising loop takes more than 60% of the full pipeline execution time. That is why the computational cost and speed of U-Net denoising become the critical path in the pipeline.
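
A per-stage breakdown like the one above can be collected with simple wall-clock timing around each stage. The sketch below is a generic illustration only: encode_prompt, denoise_step, and decode_latents are hypothetical stand-ins for the pipeline stages built in the notebook (here replaced by toy lambdas so the snippet is executable).

import time
from contextlib import contextmanager

@contextmanager
def timed(name, results):
    # Accumulate elapsed wall-clock time for a named stage.
    start = time.perf_counter()
    yield
    results[name] = results.get(name, 0.0) + (time.perf_counter() - start)

def run_pipeline(encode_prompt, denoise_step, decode_latents, prompt, steps=4):
    results = {}
    with timed("text_encoding", results):
        embeddings = encode_prompt(prompt)
    latents = None
    for _ in range(steps):
        with timed("unet", results):
            latents = denoise_step(latents, embeddings)
    with timed("vae_decoding", results):
        image = decode_latents(latents)
    return image, results

# Toy stand-ins just to make the sketch runnable; in practice these would call
# the compiled OpenVINO models.
image, timings = run_pipeline(
    encode_prompt=lambda p: p.upper(),
    denoise_step=lambda lat, emb: emb,
    decode_latents=lambda lat: lat,
    prompt="a beautiful pink unicorn, 8k",
)
print(timings)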

In this blog, we use the Neural Network Compression Framework (NNCF) Post-Training Quantization (PTQ) API to quantize the U-Net model, which can further boost inference performance while keeping acceptable accuracy without fine-tuning. Quantizing the rest of the diffusion pipeline does not significantly improve inference performance but can lead to a substantial degradation of accuracy.

Quantization

The quantization process includes the following steps:

  1. Create a calibration dataset for the quantization.
  2. Run nncf.quantize to obtain a quantized model.
  3. Save the INT8 model using the ov.save_model function.

You can look at the dataset preparation for the U-Net model in the OpenVINO LCM Notebook. General rules about dataset preparation can be found in the OpenVINO documentation.

For INT8 quantization of LCM, we found some useful tricks to mitigate accuracy degradation caused by accuracy-sensitive layers:

  • The U-Net part of the LCM pipeline has a backbone with a transformer that operates on latent patches. To better preserve accuracy after NNCF PTQ, we should pass model_type=nncf.ModelType.TRANSFORMER to the nncf.quantize function. This keeps several accuracy-sensitive layers in FP16 precision.
  • The default symmetric quantization of weights and activations also leads to accuracy degradation of LCM. We recommend using preset=nncf.QuantizationPreset.MIXED, which applies symmetric quantization to weights and asymmetric quantization to activations, since activations are more sensitive and impact the generation results more. Asymmetric quantization represents activation values better and leads to better accuracy with no impact on inference latency.
  • It was also discovered that the Fast Bias (error) Correction algorithm (FBC), which is a default part of NNCF PTQ, results in unexpected artifacts in the generated images. To disable FBC, we should pass advanced_parameters=nncf.AdvancedQuantizationParameters(disable_bias_correction=True) to the nncf.quantize function.

Once the dataset is ready and the model object is instantiated, you can apply 8-bit quantization to it using the optimization workflow below:


import nncf
import openvino as ov

core = ov.Core()

# Load the FP16 U-Net IR produced during model conversion.
unet = core.read_model(UNET_OV_PATH)

# Apply 8-bit post-training quantization with the tricks described above:
# transformer-aware quantization, mixed (symmetric weights / asymmetric
# activations) preset, and disabled Fast Bias Correction.
quantized_unet = nncf.quantize(
    model=unet,
    preset=nncf.QuantizationPreset.MIXED,
    calibration_dataset=nncf.Dataset(unet_calibration_data),
    model_type=nncf.ModelType.TRANSFORMER,
    advanced_parameters=nncf.AdvancedQuantizationParameters(
        disable_bias_correction=True
    )
)

# Save the INT8 model to disk.
ov.save_model(quantized_unet, UNET_INT8_OV_PATH)
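
After quantization, the INT8 U-Net is a regular OpenVINO IR: it can be compiled for the target device and swapped into the pipeline in place of the FP16 U-Net. A minimal sketch, reusing the path from the snippet above:

import openvino as ov

core = ov.Core()
# Compile the quantized U-Net for the target device; the text encoder and VAE
# parts of the pipeline stay in FP16.
compiled_int8_unet = core.compile_model(UNET_INT8_OV_PATH, device_name="CPU")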

Text-to-image generation

The left image was generated using the original LCM pipeline from PyTorch. The middle image was generated using the model converted to OpenVINO FP16. The right image was generated using LCM with the quantized INT8 U-Net. The input prompt is “a beautiful pink unicorn, 8k”, the seed is 1234567, and the number of inference steps is 4.

If you would like to generate your own images and compare the original and quantized models, you can run the interactive demo at the end of the OpenVINO LCM Notebook.

We also measured the image generation time of the LCM pipeline with the input prompt “a beautiful pink unicorn, 8k”, seed 1234567, and 4 inference steps.

*Average time across 3 independent runs.

The performance speedup of OpenVINO + NNCF over PyTorch is 1.38x.

Notices and Disclaimers:

Performance varies by use, configuration, and other factors. Learn more at www.intel.com/PerformanceIndex. Performance results are based on testing as of dates shown in configurations and may not reflect all publicly available updates. No product or component can be absolutely secure. Intel technologies may require enabled hardware, software or service activation.

The products described may contain design defects or errors known as errata which may cause the product to deviate from published specifications. Current characterized errata are available on request.​​

Test Configuration: Intel® Core™ i9-10980XE CPU Processor at 3.00GHz with DDR4 128 GB at 3600MHz, OS: Ubuntu 22.04.2 LTS. Tested with OpenVINO LCM Notebook.

The test was conducted by Intel on November 7, 2023.

Conclusion

In this blog, we introduced how to enable and quantize the Latent Consistency Model with OpenVINO™ runtime and NNCF:

  • The proposed NNCF INT8 PTQ improves the performance of the image generation pipeline while preserving generation quality.
  • The provided OpenVINO LCM Notebook covers model enabling, quantization, comparison of FP16 and INT8 model inference times, and deployment with OpenVINO™ and NNCF.

As a next step, you can consider migrating to the native OpenVINO C++ API for even faster generation pipeline inference and the possibility to embed it into a client or edge-device application. You can find an example of such a pipeline here.

Please give a star to the NNCF and OpenVINO repositories if you find them useful.

Read More...
Anna
Likholat

How to build and run OpenVINO™ C++ Benchmark Application for Linux

October 10, 2023

Introduction

The OpenVINO™ Benchmark Application estimates deep learning inference performance on supported devices for synchronous and asynchronous modes.

NOTE: This guide describes the usage of the C++ implementation of the Benchmark Tool. For the Python implementation, refer to the Benchmark Python Tool page. The Python version is recommended for benchmarking models used in Python applications, and the C++ version is recommended for benchmarking models used in C++ applications.

In this tutorial, we will guide you through building and running the C++ implementation of the Benchmark Tool on Ubuntu with the OpenVINO™ 2023.1.0 release and demonstrate its usage by benchmarking the Inception (GoogleNet) V3 deep learning model. The following steps outline the process:

  1. Install OpenVINO™ Runtime
  2. Build OpenVINO™ C++ Runtime Samples
  3. Run the Benchmark Application
  4. Convert the Model to OpenVINO™ IR format (if needed)


The benchmark application works with models in the OpenVINO™ IR (.xml and .bin), ONNX (.onnx), TensorFlow (*.pb), TensorFlow Lite (*.tflite) and PaddlePaddle (*.pdmodel) formats. Make sure to convert your models if necessary (see "Model conversion to OpenVINO™ IR format" step below).

Requirements

Before getting started, ensure that you have the following requirements in place:

  • Ubuntu 18.04 or higher
  • CMake version 3.10 or higher


Step 1: Install OpenVINO™

To get started, first install OpenVINO™ Runtime C++ API.

Download and set up the OpenVINO™ Runtime archive file for Linux for your system. The following steps describe the installation process for an Ubuntu 20.04 x86_64 system:

1. Download the archive file, extract the files, rename the extracted folder, and move it to the desired path:

curl -L https://storage.openvinotoolkit.org/repositories/openvino/packages/2023.1/linux/l_openvino_toolkit_ubuntu20_2023.1.0.12185.47b736f63ed_x86_64.tgz --output openvino_2023.1.0.tgz
tar -xf openvino_2023.1.0.tgz
sudo mkdir /opt/intel/openvino_2023.1.0
sudo mv -v l_openvino_toolkit_ubuntu20_2023.1.0.12185.47b736f63ed_x86_64/* /opt/intel/openvino_2023.1.0


2. Install required system dependencies on Linux. To do this, OpenVINO provides a script in the extracted installation directory. Run the following command:

cd /opt/intel/openvino_2023.1.0
sudo -E ./install_dependencies/install_openvino_dependencies.sh


3. For simplicity, it is useful to create a symbolic link as below:

cd /opt/intel
sudo ln -s openvino_2023.1.0 openvino_2023


4. Set OpenVINO™ environment variables. Open a terminal window and run the setupvars.sh script to temporarily set your environment variables. If your <INSTALL_DIR> is not /opt/intel/openvino_2023, use the correct one instead:

source /opt/intel/openvino_2023/setupvars.sh

Step 2: Build OpenVINO™ C++ Runtime Samples

In the existing terminal window where the OpenVINO™ environment is set up, navigate to the /opt/intel/openvino_2023.1.0/samples/cpp directory and run the build_samples.sh script:

cd /opt/intel/openvino_2023.1.0/samples/cpp
./build_samples.sh


As a result of a successful build, you'll get a message with the path to the sample binaries:

...
[100%] Linking CXX executable ../intel64/Release/benchmark_app
[100%] Built target benchmark_app
[100%] Built target ie_samples

Build completed, you can find binaries for all samples in the /home/user/openvino_cpp_samples_build/intel64/Release subfolder.


NOTE: You can also use the -b option to specify the sample build directory and -i to specify the sample install directory, for example:

./build_samples.sh -b /home/user/ov_samples/build -i /home/user/ov_samples

NOTE: The build_samples.sh script will build all the samples in the /opt/intel/openvino_2023.1.0/samples/cpp folder. Remove the other samples from the folder if you want to build only a few samples or only the benchmark_app.

Step 3: Run the Benchmark Application

NOTE: You can use your own model for benchmarking or, if necessary, download a model for the demo using the Model Downloader. You can find pre-trained models from either the public models or Intel’s pre-trained models in the OpenVINO™ Open Model Zoo. The following are the steps to install the tools and obtain the IR for the Inception (GoogleNet) V3 PyTorch model:

pip install "openvino-dev>=2023.1.0"
omz_downloader --name googlenet-v3-pytorch
omz_converter --name googlenet-v3-pytorch --precisions FP32 


The googlenet-v3-pytorch IR files will be located at: <CURRENT_DIRECTORY>/public/googlenet-v3-pytorch/FP32

Navigate to the samples binaries folder and run the benchmark_app with the following command:

cd /home/user/openvino_cpp_samples_build/intel64/Release
./benchmark_app -m path/to/public/googlenet-v3-pytorch/FP32/googlenet-v3-pytorch.xml


By default, the application will load the specified model onto the CPU and perform inference on batches of randomly generated data inputs for 60 seconds. As it loads, it prints information about the benchmark parameters. When benchmarking is completed, it reports the minimum, average, and maximum inference latency and the average throughput.

NOTE: You can use images from the media files collection available at test_data and infer with specific input data using the -i argument to benchmark_app.

You may be able to improve benchmark results beyond the default configuration by tuning some of the execution parameters for your model. You can find other options for configuring execution parameters here: Benchmark C++ Tool Configuration Options

Model conversion to OpenVINO™ IR format

You can use OpenVINO™ Model Converter to convert your model to Intermediate Representation (IR) when necessary:

1. Install OpenVINO™ for Python which includes the necessary components for utilizing the OpenVINO™ Model Converter.

NOTE: Ensure you install the same version of the OpenVINO™ Runtime Package for Python as the OpenVINO™ Runtime C++ API you installed earlier.

pip install "openvino>=2023.1.0"


2. To convert the model to IR, run Model Converter:

ovc INPUT_MODEL
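
Alternatively, if you prefer Python over the ovc command line, the same conversion can be done with a short script. A minimal sketch (the input file name is a placeholder for your model in any supported format):

import openvino as ov

# Placeholder input: an ONNX, TensorFlow, TensorFlow Lite, or PaddlePaddle model path.
ov_model = ov.convert_model("model.onnx")
ov.save_model(ov_model, "model.xml")  # writes model.xml and model.bin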


Related Articles

Install OpenVINO™ Runtime on Linux from an Archive File

Transition from Legacy Conversion API

OpenVINO™ Benchmark C++ Tool

OpenVINO™ Samples Overview

OpenVINO™ Development Tools

Running OpenVINO™ C++ samples on Visual Studio

Read More...