
The EasyNLP Chinese Text-to-Image Generation Model Can Make Anyone an Artist in Seconds!

This article introduces text-to-image generation technology and how to implement it in EasyNLP.

By Chengyu Wang and Tingting Liu

Introduction

[Images generated by EasyNLP]

Nothing conveys meaning better than words, and nothing preserves form better than painting. – Lu Ji [Jin Dynasty (265-420 A.D.)]

Multimodal data (text, image, and sound) is an important carrier for humans to perceive, understand, and express the world. In recent years, the explosive growth of multimodal data has fueled the content Internet and created a strong demand for multimodal content understanding and generation. Unlike common cross-modal understanding tasks, text-to-image generation is a popular cross-modal generation task that aims to generate an image corresponding to a given text. It has unlocked AI's imagination and stimulated human creativity. Typical text-to-image generation models include DALL-E and DALL-E 2, developed by OpenAI. Recently, the industry has trained larger and newer models, such as Parti and Imagen proposed by Google.

However, the models above cannot handle requirements in Chinese. Moreover, their large parameter counts make it difficult for users in the open-source community to perform fine-tuning and inference directly. The open-source EasyNLP framework has been upgraded again: it integrates advanced text-to-image generation architectures based on Transformer and VQGAN, and it releases to the open-source community the checkpoints of Chinese text-to-image generation models of different sizes, together with the corresponding fine-tuning and inference interfaces. Users can perform a small amount of domain-specific fine-tuning based on these checkpoints and easily create various artworks without consuming a large amount of computing resources.

EasyNLP is an easy-to-use Chinese NLP algorithm framework developed by the Alibaba Cloud Machine Learning Platform for Artificial Intelligence team based on PyTorch. It provides an end-to-end NLP development experience from training to deployment and offers simple interfaces for developing NLP models, including the NLP application AppZoo, the pre-trained model ModelZoo, and DataHub. Due to the increasing demand for cross-modal understanding and generation, EasyNLP also supports various cross-modal models, especially Chinese ones, and promotes them to the open-source community.

We hope to serve more NLP and multimodal algorithm developers and researchers and to work with the community to advance NLP/multimodal technology and the adoption of these models. This article briefly introduces text-to-image generation technology and how to implement it in EasyNLP, helping anyone become an artist with ease. The pictures at the beginning of this article were generated by EasyNLP.

A Brief Introduction to the Text-to-Image Generation Model

We briefly introduce text-to-image generation technology using several Transformer-based models. DALL-E (developed by OpenAI) takes a two-stage approach to image generation. In the first stage, it trains a discrete variational autoencoder (dVAE) to convert a 256×256 RGB image into 32×32 image tokens; this step compresses and discretizes the image to facilitate the subsequent generation. In the second stage, it trains an autoregressive Transformer to generate the 1,024 image tokens above from the text.

The CogView model proposed by Tsinghua University and others optimizes both stages of this process. As shown in the following figure, CogView uses SentencePiece as the text tokenizer to enrich the representation of the input text and applies various techniques when fine-tuning the model, such as image super-resolution and style transfer.

[Figure: CogView model architecture]

The ERNIE-ViLG model considers the transferability of the knowledge learned by the Transformer model and jointly learns the text-to-image and image-to-text generation tasks. The following figure shows its architecture:

[Figure: ERNIE-ViLG model architecture]

With the development of text-to-image generation technology, new models and techniques keep emerging. For example, OFA unifies multiple cross-modal generation tasks within the same model architecture. DALL-E 2 (also developed by OpenAI) is an upgraded version of DALL-E; it introduces hierarchical image generation and uses a CLIP encoder to better incorporate the cross-modal representations pre-trained by CLIP. Google subsequently proposed a diffusion-model-based architecture (Imagen) that can effectively generate large, high-definition images, as shown below:

[Figure: diffusion-based text-to-image generation architecture]

This article will not go into detail. Interested readers may refer to the references.

The EasyNLP Text-to-Image Generation Model

The scale of the aforementioned models is often at the level of billions or even tens of billions of parameters. Although huge models can generate high-quality images, their requirements for computing resources and pre-training data make it difficult to apply them widely in the open-source community, especially in vertical domains. This section details the Chinese text-to-image generation model provided by EasyNLP, which still achieves good generation quality with a relatively small number of parameters.

Model Architecture

The following figure shows the model architecture:

[Figure: Architecture of the EasyNLP Chinese text-to-image generation model]

Considering that the complexity of the Transformer model grows quadratically with sequence length, training a text-to-image generation model is generally carried out in two stages: image vector quantization and autoregressive training.

Image vector quantization refers to the discrete encoding of an image. For example, a 256×256 RGB image is downsampled by a factor of 16 to obtain a 16×16 discrete token sequence, where each image token corresponds to an entry in a codebook. Common image vector quantization methods include VQVAE, VQVAE-2, and VQGAN. We use the f16_16384 weights (16-fold downsampling, codebook size of 16384) trained by VQGAN on ImageNet to generate the discretized image sequences.
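
To make the quantization step concrete, the following is a minimal sketch (not the EasyNLP implementation) of how each spatial position of an encoder feature map is mapped to the index of its nearest codebook entry; the feature dimension of 256 is an illustrative assumption:

import torch

def quantize_to_indices(z, codebook):
    # z: encoder features, shape [B, C, 16, 16]; codebook: [16384, C]
    B, C, H, W = z.shape
    flat = z.permute(0, 2, 3, 1).reshape(-1, C)            # [B*H*W, C]
    # squared Euclidean distance from each position to every codebook entry
    dist = (flat.pow(2).sum(1, keepdim=True)
            - 2 * flat @ codebook.t()
            + codebook.pow(2).sum(1))                      # [B*H*W, 16384]
    return dist.argmin(dim=1).view(B, H * W)               # [B, 256] discrete image tokens

# toy example: features of one 256x256 image after 16x down-sampling
z = torch.randn(1, 256, 16, 16)
codebook = torch.randn(16384, 256)
print(quantize_to_indices(z, codebook).shape)  # torch.Size([1, 256])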

Autoregressive training takes the text sequence and the image sequence as input. In the image part, each image token attends only to the text tokens and the preceding image tokens. We use GPT as the backbone, which adapts well to generation tasks at different model scales. In the prediction stage, the model takes a text sequence as input, generates a fixed-length image token sequence autoregressively, and then reconstructs the image through the VQGAN decoder.
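
Because the text tokens are placed in front of the image tokens in a single sequence, a standard causal (lower-triangular) attention mask already produces the behavior described above: every image position can attend to all text positions and only to earlier image positions. A minimal illustration with the lengths used by our models:

import torch

text_len, img_len = 32, 256
seq_len = text_len + img_len
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

first_img_pos = text_len  # index of the first image token
print(causal_mask[first_img_pos, :text_len].all().item())           # True: attends to all text tokens
print(causal_mask[first_img_pos, first_img_pos + 1:].any().item())  # False: no future image tokens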

Parameter Settings for Open-Source Models

In EasyNLP, we provide two versions of the Chinese text-to-image generation model. The model parameter configurations are listed below:

Model Settings        pai-painter-base-zh    pai-painter-large-zh
Parameters            202M                   433M
Number of Layers      12                     24
Attention Heads       12                     16
Hidden Vector Size    768                    1024
Text Length           32                     32
Image Length          16 x 16                16 x 16
Image Size            256 x 256              256 x 256
Codebook Size         16384                  16384
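
As a quick sanity check (not the framework's actual configuration class), note that the sequence length used later in the training and prediction commands is simply the text length plus the number of image tokens:

# illustrative arithmetic only
text_len = 32
img_len = 16 * 16          # 16 x 16 discrete image tokens
print(text_len + img_len)  # 288 -> matches --sequence_length=288 in the commands below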

Model Implementation

In the EasyNLP framework, we build the model on a minGPT backbone at the model layer. The core components are listed below:

self.first_stage_model = VQModel(ckpt_path=vqgan_ckpt_path).eval()  # VQGAN: converts images to/from discrete tokens
self.transformer = GPT(self.config)                                 # autoregressive Transformer backbone

The encoding of VQModel is listed below:

# in easynlp/appzoo/text2image_generation/model.py

@torch.no_grad()
def encode_to_z(self, x):
    quant_z, _, info = self.first_stage_model.encode(x)
    indices = info[2].view(quant_z.shape[0], -1)
    return quant_z, indices

x = inputs['image']
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
# one step to produce the logits
_, z_indices = self.encode_to_z(x)  # z_indices: torch.Size([batch_size, 256])

The decoding of VQModel is listed below:

# in easynlp/appzoo/text2image_generation/model.py

@torch.no_grad()
def decode_to_img(self, index, zshape):
    bhwc = (zshape[0],zshape[2],zshape[3],zshape[1])
    quant_z = self.first_stage_model.quantize.get_codebook_entry(
        index.reshape(-1), shape=bhwc)
    x = self.first_stage_model.decode(quant_z)
    return x

# sample() produces image token indices during training, in the same way as the
# generation procedure used in the prediction phase (see generate() below).
index_sample = self.sample(z_start_indices, c_indices,
                           steps=z_indices.shape[1],
                           ...)
x_sample = self.decode_to_img(index_sample, quant_z.shape)

The Transformer is built with minGPT. It takes the text tokens and the discrete image codes as input and predicts the image tokens. The forward propagation process is listed below:

# in easynlp/appzoo/text2image_generation/model.py

def forward(self, inputs):
    x = inputs['image']
    c = inputs['text']
    x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
    # one step to produce the logits
    _, z_indices = self.encode_to_z(x)  # z_indices: torch.Size([batch_size, 256])
    c_indices = c

    if self.training and self.pkeep < 1.0:
        # during training, randomly replace a fraction (1 - pkeep) of image tokens
        mask = torch.bernoulli(self.pkeep * torch.ones(z_indices.shape,
                                                       device=z_indices.device))
        mask = mask.round().to(dtype=torch.int64)
        r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
        a_indices = mask * z_indices + (1 - mask) * r_indices
    else:
        a_indices = z_indices

    cz_indices = torch.cat((c_indices, a_indices), dim=1)
    # target includes all sequence elements (no need to handle first one
    # differently because we are conditioning)
    target = z_indices
    # make the prediction
    logits, _ = self.transformer(cz_indices[:, :-1])
    # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
    logits = logits[:, c_indices.shape[1]-1:]
    return logits, target
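
The forward pass returns the logits over image tokens and the corresponding target indices. A minimal sketch of how a training step could turn them into a loss (the actual loss wiring in EasyNLP may differ; compute_loss is a hypothetical helper):

import torch.nn.functional as F

def compute_loss(logits, target):
    # logits: [batch_size, 256, vocab_size]; target: [batch_size, 256]
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))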

In the prediction phase, the input is text tokens, and the output is a 256×256 image. First, the input text is preprocessed into token sequences:

# in easynlp/appzoo/text2image_generation/predictor.py

def preprocess(self, in_data):
    if not in_data:
        raise RuntimeError("Input data should not be None.")

    if not isinstance(in_data, list):
        in_data = [in_data]
    rst = {"idx": [], "input_ids": []}
    max_seq_length = -1
    for record in in_data:
        if "sequence_length" not in record:
            break
        max_seq_length = max(max_seq_length, record["sequence_length"])
    max_seq_length = self.sequence_length if (max_seq_length == -1) else max_seq_length

    for record in in_data:
        text = record[self.first_sequence]
        try:
            self.MUTEX.acquire()
            text_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
            text_ids = text_ids[: self.text_len]
            n_pad = self.text_len - len(text_ids)
            text_ids += [self.pad_id] * n_pad
            # shift text ids above the image codebook so text and image share one vocabulary
            text_ids = np.array(text_ids) + self.img_vocab_size
        finally:
            self.MUTEX.release()

        rst["idx"].append(record["idx"])
        rst["input_ids"].append(text_ids)
    return rst
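
Note that the text token ids are shifted by img_vocab_size, so image tokens and text tokens live in a single shared vocabulary: image ids occupy [0, 16384) and text ids start at 16384. A small illustration with made-up token ids:

# illustrative only: the shared text/image vocabulary offset
img_vocab_size = 16384

raw_text_ids = [101, 872, 102]                        # hypothetical tokenizer output
shifted = [t + img_vocab_size for t in raw_text_ids]  # ids fed to the model
assert [t - img_vocab_size for t in shifted] == raw_text_ids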

Then, a discrete image token sequence of length 16×16 = 256 is generated:

# in easynlp/appzoo/text2image_generation/model.py

def generate(self, inputs, top_k=100, temperature=1.0):
    cidx = inputs
    sample = True
    steps = 256
    for k in range(steps):
        x_cond = cidx
        logits, _ = self.transformer(x_cond)
        # pluck the logits at the final step and scale by temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop probabilities to only the top k options
        if top_k is not None:
            logits = self.top_k_logits(logits, top_k)
        # apply softmax to convert to probabilities
        probs = torch.nn.functional.softmax(logits, dim=-1)
        # sample from the distribution or take the most likely
        if sample:
            ix = torch.multinomial(probs, num_samples=1)
        else:
            _, ix = torch.topk(probs, k=1, dim=-1)
        # append to the sequence and continue
        cidx = torch.cat((cidx, ix), dim=1)
    # drop the 32 text tokens; keep the 256 generated image tokens
    img_idx = cidx[:, 32:]
    return img_idx

Finally, we call the VQModel decoder to convert these discrete token sequences into images.
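
Putting these pieces together, the prediction path can be sketched as follows. The method names follow the code above; the codebook embedding dimension (256) and the output value range are assumptions about the VQGAN decoder:

import numpy as np
from PIL import Image

def text_to_images(model, input_ids):
    # input_ids: [batch_size, 32] text token ids produced by preprocess()
    img_idx = model.generate(input_ids)                       # [batch_size, 256] image tokens
    zshape = (img_idx.shape[0], 256, 16, 16)                  # embedding dim assumed to be 256
    x = model.decode_to_img(img_idx, zshape)                  # [batch_size, 3, 256, 256], roughly in [-1, 1]
    x = ((x.clamp(-1, 1) + 1) / 2 * 255).permute(0, 2, 3, 1)  # to [0, 255], channel-last
    return [Image.fromarray(img.cpu().numpy().astype(np.uint8)) for img in x]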

Model Effects

We have verified the effectiveness of the Chinese text-to-image generation model in the EasyNLP framework on four Chinese public datasets: COCO-CN, MUGE, Flickr8k-CN, and Flickr30k-CN. We have also compared this model with CogView and DALL-E. The results are listed below:

[Table: Evaluation results compared with CogView and DALL-E on the four datasets]

1) MUGE is a large-scale Chinese multimodal evaluation benchmark for e-commerce scenarios published on the Tianchi platform. To facilitate metric calculation, we use the validation set results for MUGE and the test set results for the other datasets.

2) The official CogView code is available at https://github.com/THUDM/CogView

3) There is no complete official open-source code for DALL-E; what has been released covers only the discrete image VAE, not the Transformer. We therefore reproduce the model based on the widely followed implementation at https://github.com/lucidrains/DALLE-pytorch and the checkpoints recommended by that repository. These checkpoints have 0.209 billion parameters, far fewer than OpenAI's DALL-E, which has 12 billion parameters (with an additional 0.4-billion-parameter CLIP model used for reranking).

Typical Cases

We fine-tune the base and large models on the natural scenery dataset COCO-CN. The results are shown below:

Example 1: A Playful Dog Is Running across the Grass.

[Generated images for Example 1]

Example 2: A Body of Water with a Sunset as the Background.

[Generated images for Example 2]

We have also accumulated a large amount of e-commerce product data within Alibaba and obtained a text-to-image generation model for e-commerce products through fine-tuning. The results are shown below:

Example 3: Girls' Pullover Bottoming Sweater for Autumn and Winter

[Generated images for Example 3]

Example 4: Comfortable Women's Dark Leather Shoes for Job Interviews and Long-Hour Work in Spring and Summer.

[Generated images for Example 4]

In addition to being applied in specific fields, the text-to-image generation model can also assist human artistic creation. With the trained model, anyone can become a master of Chinese painting, as shown in the following example:

[Generated Chinese landscape paintings]

More examples:

[More generated examples]

Tutorial

After appreciating the works generated by the model, how can we train our own text-to-image generation model? In the following, we briefly introduce how to fine-tune and run inference with the pre-trained text-to-image generation models in the EasyNLP framework.

Install EasyNLP

Please refer to the EasyNLP GitHub repository (https://github.com/alibaba/EasyNLP) for more information about how to install EasyNLP.

Data Preparation

First, prepare the training and validation data as tsv files. Each file contains three tab-separated (\t) columns: the first column is the index, the second column is the text, and the third column is the base64 encoding of the image. The input file for prediction contains only two columns: the index and the text.

For the convenience of developers, we also provide the following sample code to convert an image to base64:

import base64
from io import BytesIO
from PIL import Image

img = Image.open(fn)  # fn: path to the local image file
img_buffer = BytesIO()
img.save(img_buffer, format=img.format)
byte_data = img_buffer.getvalue()
base64_str = base64.b64encode(byte_data) # bytes
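
Building on this snippet, a small sketch of assembling one line of the training tsv could look like the following; the file names are placeholders:

import base64
from io import BytesIO
from PIL import Image

def image_to_base64(fn):
    img = Image.open(fn)
    img_buffer = BytesIO()
    img.save(img_buffer, format=img.format)
    return base64.b64encode(img_buffer.getvalue()).decode('utf-8')

# hypothetical example: index, text, and image base64 separated by tabs
with open('train.tsv', 'w', encoding='utf-8') as f:
    f.write('\t'.join(['0', 'A playful dog is running across the grass',
                       image_to_base64('dog.jpg')]) + '\n')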

The following files have been preprocessed and can be used for the test:

# train
https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/painter_text2image/MUGE_train_text_imgbase64.tsv

# valid
https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/painter_text2image/MUGE_val_text_imgbase64.tsv

# test
https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/painter_text2image/MUGE_test.text.tsv

Model Training

Run the following command to fine-tune the model:

easynlp \
    --mode=train \
    --worker_gpu=1 \
    --tables=MUGE_train_text_imgbase64.tsv,MUGE_val_text_imgbase64.tsv \
    --input_schema=idx:str:1,text:str:1,imgbase64:str:1 \
    --first_sequence=text \
    --second_sequence=imgbase64 \
    --checkpoint_dir=./finetuned_model/ \
    --learning_rate=4e-5 \
    --epoch_num=1 \
    --random_seed=42 \
    --logging_steps=100 \
    --save_checkpoint_steps=1000 \
    --sequence_length=288 \
    --micro_batch_size=16 \
    --app_name=text2image_generation \
    --user_defined_parameters='
        pretrain_model_name_or_path=alibaba-pai/pai-painter-large-zh
        size=256
        text_len=32
        img_len=256
        img_vocab_size=16384
' 

We provide base and large versions of the pre-trained model; set pretrain_model_name_or_path to alibaba-pai/pai-painter-base-zh or alibaba-pai/pai-painter-large-zh, respectively.

After training, the model is saved to ./finetuned_model/.

Batch Inference of the Model

After the model is trained, we can use it for image generation, as shown in the following example:

easynlp \
    --mode=predict \
    --worker_gpu=1 \
    --tables=MUGE_test.text.tsv \
    --input_schema=idx:str:1,text:str:1 \
    --first_sequence=text \
    --outputs=./T2I_outputs.tsv \
    --output_schema=idx,text,gen_imgbase64 \
    --checkpoint_dir=./finetuned_model/ \
    --sequence_length=288 \
    --micro_batch_size=8 \
    --app_name=text2image_generation \
    --user_defined_parameters='
        size=256
        text_len=32
        img_len=256
        img_vocab_size=16384
'

The result is stored in a tsv file. Each line corresponds to an input text. The output image is encoded in base64.
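
To turn the generated base64 strings back into image files, a small post-processing sketch could look like the following; the output columns follow the --output_schema above, and the urlsafe decoding mirrors the pipeline example in the next section (use base64.b64decode instead if your outputs use standard base64):

import base64
from io import BytesIO
from PIL import Image

with open('T2I_outputs.tsv', encoding='utf-8') as f:
    for line in f:
        idx, text, gen_imgbase64 = line.rstrip('\n').split('\t')
        image = Image.open(BytesIO(base64.urlsafe_b64decode(gen_imgbase64)))
        image.save('{}.png'.format(idx))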

Use the Pipeline Interface to Quickly Experience the Effect of Text-to-Image Generation

To further facilitate developers, we have implemented the inference pipeline feature in the EasyNLP framework. Users can run the following code to call the text-to-image generation model fine-tuned for the e-commerce scenario:

import base64
from io import BytesIO
from PIL import Image
from easynlp.pipelines import pipeline

# Directly build a pipeline.
default_ecommercial_pipeline = pipeline("pai-painter-commercial-base-zh")

# Model prediction
data = ["Loose T-shirt"]
results = default_ecommercial_pipeline(data)  # Each result is a base64 encoding of the generated image

# Convert base64 to image
def base64_to_image(imgbase64_str):
    image = Image.open(BytesIO(base64.urlsafe_b64decode(imgbase64_str)))
    return image

# Save each generated image, named after its text
for text, result in zip(data, results):
    imgpath = '{}.png'.format(text)
    imgbase64_str = result['gen_imgbase64']
    image = base64_to_image(imgbase64_str)
    image.save(imgpath)
    print('text: {}, save generated image: {}'.format(text, imgpath))

In addition to e-commerce scenarios, we provide models for the following scenarios:

  • Natural Scenery: "pai-painter-scenery-base-zh"
  • Chinese Landscape Painting: "pai-painter-painting-base-zh"

After replacing "pai-painter-commercial-base-zh" in the sample code above with one of these names, you can try it directly. You are welcome to give it a try.

For text-to-image generation models fine-tuned by users themselves, the pipeline interface also supports loading custom models:

# Load the model and build the pipeline.
local_model_path = ...
text_to_image_pipeline = pipeline("text2image_generation", local_model_path)

# Model prediction
data = ["xxxx"]
results = text_to_image_pipeline(data)  # Each result is a base64 encoding of the generated image

Outlook

In this work, we integrated the Chinese text-to-image generation function into the EasyNLP framework and released the model checkpoints, so that users in the open-source community can perform a small amount of domain-specific fine-tuning and carry out various artistic creations with limited resources. In the future, we plan to launch more related models in the EasyNLP framework, so please stay tuned. We will also integrate more SOTA models (especially Chinese models) into EasyNLP to support various NLP and multimodal tasks. In addition, the Alibaba Cloud Machine Learning Platform for Artificial Intelligence team continues to promote the self-development of Chinese multimodal models. You are welcome to follow us and join our open-source community to build Chinese NLP and multimodal algorithm libraries together!

EasyNLP address on GitHub: https://github.com/alibaba/EasyNLP

References

  1. Chengyu Wang, Minghui Qiu, Taolin Zhang, Tingting Liu, Lei Li, Jianing Wang, Ming Wang, Jun Huang, Wei Lin. EasyNLP: A Comprehensive and Easy-to-use Toolkit for Natural Language Processing. arXiv
  2. Aditya Ramesh, Mikhail Pavlov, Gabriel Goh, Scott Gray, Chelsea Voss, Alec Radford, Mark Chen, Ilya Sutskever. Zero-Shot Text-to-Image Generation. ICML 2021: 8821-8831
  3. Ming Ding, Zhuoyi Yang, Wenyi Hong, Wendi Zheng, Chang Zhou, Da Yin, Junyang Lin, Xu Zou, Zhou Shao, Hongxia Yang, Jie Tang. CogView: Mastering Text-to-Image Generation via Transformers. NeurIPS 2021: 19822-19835
  4. Han Zhang, Weichong Yin, Yewei Fang, Lanxin Li, Boqiang Duan, Zhihua Wu, Yu Sun, Hao Tian, Hua Wu, Haifeng Wang. ERNIE-ViLG: Unified Generative Pre-training for Bidirectional Vision-Language Generation. arXiv
  5. Peng Wang, An Yang, Rui Men, Junyang Lin, Shuai Bai, Zhikang Li, Jianxin Ma, Chang Zhou, Jingren Zhou, Hongxia Yang. Unifying Architectures, Tasks, and Modalities Through a Simple Sequence-to-Sequence Learning Framework. ICML 2022
  6. Aditya Ramesh, Prafulla Dhariwal, Alex Nichol, Casey Chu, Mark Chen. Hierarchical Text-Conditional Image Generation with CLIP Latents. arXiv
  7. Aäron van den Oord, Oriol Vinyals, Koray Kavukcuoglu. Neural Discrete Representation Learning. NIPS 2017
  8. Patrick Esser, Robin Rombach, Björn Ommer. Taming Transformers for High-Resolution Image Synthesis. CVPR 2021: 12873-12883
  9. Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S. Sara Mahdavi, Rapha Gontijo Lopes, Tim Salimans, Jonathan Ho, David J. Fleet, Mohammad Norouzi: Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding. arXiv
  10. Jiahui Yu, Yuanzhong Xu, Jing Yu Koh, Thang Luong, Gunjan Baid, Zirui Wang, Vijay Vasudevan, Alexander Ku, Yinfei Yang, Burcu Karagol Ayan, Ben Hutchinson, Wei Han, Zarana Parekh, Xin Li, Han Zhang, Jason Baldridge, Yonghui Wu. Scaling Autoregressive Models for Content-Rich Text-to-Image Generation. arXiv