使用JojoGAN创建风格化的面部图

磐创AI
关注

介绍

风格迁移是神经网络的一个发展领域,它是一个非常有用的功能,可以集成到社交媒体和人工智能应用程序中。几个神经网络可以根据训练数据将图像样式映射和传输到输入图像。在本文中,我们将研究 JojoGAN,以及仅使用一种参考样式来训练和生成具有该样式的任何图像的过程。

JoJoGAN:One Shot Face Stylization

One Shot Face Stylization(一次性面部风格化)可用于 AI 应用程序、社交媒体过滤器、有趣的应用程序和业务用例。随着 AI 生成的图像和视频滤镜的日益普及,以及它们在社交媒体和短视频、图像中的使用,一次性面部风格化是一个有用的功能,应用程序和社交媒体公司可以将其集成到最终产品中。

因此,让我们来看看用于一次性生成人脸样式的流行 GAN 架构——JojoGAN。

JojoGAN 架构

JojoGAN 是一种风格迁移程序,可让将人脸图像的风格迁移为另一种风格。它通过GAN将参考风格图像反转为近似的配对训练数据,根据风格化代码生成真实的人脸图像,并与参考风格图像相匹配。然后将该数据集用于微调 StyleGAN,并且可以使用新的输入图像,JojoGAN 将根据 GAN 反转(inversion)将其转换为该特定样式。

JojoGAN 架构和工作流程

JojoGAN 只需一种参考风格即可在很短的时间内(不到 1 分钟)进行训练,并生成高质量的风格化图像。

JojoGan 的一些例子

JojoGAN 生成的风格化图像的一些示例:

风格化的图像可以在各种不同的输入风格上生成并且可以修改。

JojoGan 代码深潜

让我们看看 JojoGAN 生成风格化人像的实现。有几个预训练模型可用,它们可以在我们的风格图像上进行训练,或者可以修改模型以在几分钟内更改风格。

JojoGAN 的设置和导入

克隆 JojoGAN 存储库并导入必要的库。在 Google Colab 存储中创建一些文件夹,用于存储反转代码、样式图像和模型。

!git clone https://github.com/mchong6/JoJoGAN.git

%cd JoJoGAN

!pip install tqdm gdown scikit-learn==0.22 scipy lpips dlib opencv-python wandb

!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip

!sudo unzip ninja-linux.zip -d /usr/local/bin/

import torch

torch.backends.cudnn.benchmark = True

from torchvision import transforms, utils

from util import *

from PIL import Image

import math

import random

import os

import numpy

from torch import nn, autograd, optim

from torch.nn import functional

from tqdm import tqdm

import wandb

from model import *

from e4e_projection import projection

from google.colab import files

from copy import deepcopy

from pydrive.auth import GoogleAuth

from pydrive.drive import GoogleDrive

from google.colab import auth

from oauth2client.client import GoogleCredentials

模型文件

使用 Pydrive 下载模型文件。一组驱动器 ID 可用于预训练模型。这些预训练模型可用于随时随地生成风格化图像,并具有不同的准确度。之后,可以训练用户创建的模型。

#Download models

#optionally enable downloads with pydrive in order to authenticate and avoid drive download limits.

download_with_pydrive = True  

device = 'cuda' #['cuda', 'cpu']

!wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2

!bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2

!mv shape_predictor_68_face_landmarks.dat models/dlibshape_predictor_68_face_landmarks.dat

%matplotlib inline

drive_ids = {

   "stylegan2-ffhq-config-f.pt": "1Yr7KuD959btpmcKGAUsbAk5rPjX2MytK",

   "e4e_ffhq_encode.pt": "1o6ijA3PkcewZvwJJ73dJ0fxhndn0nnh7",

   "restyle_psp_ffhq_encode.pt": "1nbxCIVw9H3YnQsoIPykNEFwWJnHVHlVd",

   "arcane_caitlyn.pt": "1gOsDTiTPcENiFOrhmkkxJcTURykW1dRc",

   "arcane_caitlyn_preserve_color.pt": "1cUTyjU-q98P75a8THCaO545RTwpVV-aH",

   "arcane_jinx_preserve_color.pt": "1jElwHxaYPod5Itdy18izJk49K1nl4ney",

   "arcane_jinx.pt": "1quQ8vPjYpUiXM4k1_KIwP4EccOefPpG_",

   "arcane_multi_preserve_color.pt": "1enJgrC08NpWpx2XGBmLt1laimjpGCyfl",

   "arcane_multi.pt": "15V9s09sgaw-zhKp116VHigf5FowAy43f",

   "sketch_multi.pt": "1GdaeHGBGjBAFsWipTL0y-ssUiAqk8AxD",

   "disney.pt": "1zbE2upakFUAx8ximYnLofFwfT8MilqJA",

   "disney_preserve_color.pt": "1Bnh02DjfvN_Wm8c4JdOiNV4q9J7Z_tsi",

   "jojo.pt": "13cR2xjIBj8Ga5jMO7gtxzIJj2PDsBYK4",

   "jojo_preserve_color.pt": "1ZRwYLRytCEKi__eT2Zxv1IlV6BGVQ_K2",

   "jojo_yasuho.pt": "1grZT3Gz1DLzFoJchAmoj3LoM9ew9ROX_",

   "jojo_yasuho_preserve_color.pt": "1SKBu1h0iRNyeKBnya_3BBmLr4pkPeg_L",

   "art.pt": "1a0QDEHwXQ6hE_FcYEyNMuv5r5UnRQLKT",



# from StyelGAN-NADA

class Downloader(object):

   def __init__(self, use_pydrive):

       self.use_pydrive = use_pydrive

       if self.use_pydrive:

           self.authenticate()

   def authenticate(self):

       auth.authenticate_user()

       gauth = GoogleAuth()

       gauth.credentials = GoogleCredentials.get_application_default()

       self.drive = GoogleDrive(gauth)

   def download_file(self, file_name):

       file_dst = os.path.join('models', file_name)

       file_id = drive_ids[file_name]

       if not os.path.exists(file_dst):

           print(f'Downloading {file_name}')

           if self.use_pydrive:

               downloaded = self.drive.CreateFile({'id':file_id})

               downloaded.FetchMetadata(fetch_all=True)

               downloaded.GetContentFile(file_dst)

           else:

               !gdown --id $file_id -O $file_dst

downloader = Downloader(download_with_pydrive)


downloader.download_file('stylegan2-ffhq-config-f.pt')

downloader.download_file('e4e_ffhq_encode.pt')


加载生成器

加载原始和微调生成器。设置用于调整图像大小和规范化图像的 transforms。

latent_dim = 512

# Load original generator

original_generator = Generator(1024, latent_dim, 8, 2).to(device)

ckpt = torch.load('models/stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)

original_generator.load_state_dict(ckpt["g_ema"], strict=False)

mean_latent = original_generator.mean_latent(10000)

# to be finetuned generator

generator = deepcopy(original_generator)

transform = transforms.Compose(

   [

       transforms.Resize((1024, 1024)),

       transforms.ToTensor(),

       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),

   ]

输入图像

设置输入图像位置。对齐和裁剪面并重新设置映射的样式。

#image to the test_input directory and put the name here

filename = 'face.jpeg' #@param {type:"string"}

filepath = f'test_input/{filename}'

name = strip_path_extension(filepath)+'.pt'

# aligns and crops face

aligned_face = align_face(filepath)

# my_w = restyle_projection(aligned_face, name, device, n_iters=1).unsqueeze(0)

my_w = projection(aligned_face, name, device).unsqueeze(0)

预训练图

选择预训练好的图类型,选择不保留颜色的检查点,效果更好。

plt.rcParams['figure.dpi'] = 150

pretrained = 'sketch_multi' #['art', 'arcane_multi', 'sketch_multi', 'arcane_jinx', 'arcane_caitlyn', 'jojo_yasuho', 'jojo', 'disney']

#Preserve color tries to preserve color of original image by limiting family of allowable transformations.

if preserve_color:

   ckpt = f'{pretrained}_preserve_color.pt'

else:

   ckpt = f'{pretrained}.pt'

生成结果

加载检查点和生成器并设置种子值,然后开始生成风格化图像。用于 Elon Musk 的输入图像将根据图类型进行风格化。

#Generate results

n_sample =  5#{type:"number"}

seed = 3000 #{type:"number"}

torch.manual_seed(seed)

with torch.no_grad():

   generator.eval()

   z = torch.randn(n_sample, latent_dim, device=device)

   original_sample = original_generator([z], truncation=0.7, truncation_latent=mean_latent)

   sample = generator([z], truncation=0.7, truncation_latent=mean_latent)

   original_my_sample = original_generator(my_w, input_is_latent=True)

   my_sample = generator(my_w, input_is_latent=True)

# display reference images

if pretrained == 'arcane_multi':

   style_path = f'style_images_aligned/arcane_jinx.png'

elif pretrained == 'sketch_multi':

   style_path = f'style_images_aligned/sketch.png'

else:   

   style_path = f'style_images_aligned/{pretrained}.png'

style_image = transform(Image.open(style_path)).unsqueeze(0).to(device)

face = transform(aligned_face).unsqueeze(0).to(device)


my_output = torch.cat([style_image, face, my_sample], 0)

生成的结果

结果生成为预先训练的类型“Jojo”,看起来相当准确。

现在让我们看一下在自创样式上训练 GAN。

使用你的风格图像进行训练

选择一些面部图,甚至创建一些自己的面部图并加载这些图像以训练 GAN,并设置路径。裁剪和对齐人脸并执行 GAN 反转。

names = ['1.jpg', '2.jpg', '3.jpg']

targets = []

latents = []

for name in names:

   style_path = os.path.join('style_images', name)

   assert os.path.exists(style_path), f"{style_path} does not exist!"

   name = strip_path_extension(name)

   # crop and align the face

   style_aligned_path = os.path.join('style_images_aligned', f'{name}.png')

   if not os.path.exists(style_aligned_path):

       style_aligned = align_face(style_path)

       style_aligned.save(style_aligned_path)

   else:

       style_aligned = Image.open(style_aligned_path).convert('RGB')

   # GAN invert

   style_code_path = os.path.join('inversion_codes', f'{name}.pt')

   if not os.path.exists(style_code_path):

       latent = projection(style_aligned, style_code_path, device)

   else:

       latent = torch.load(style_code_path)['latent']

   latents.append(latent.to(device))

targets = torch.stack(targets, 0)

latents = torch.stack(latents, 0)

微调 StyleGAN

通过调整 alpha、颜色保留和设置迭代次数来微调 StyleGAN。加载感知损失的鉴别器并重置生成器。

#Finetune StyleGAN

#alpha controls the strength of the style

alpha =  1.0 # min:0, max:1, step:0.1

alpha = 1-alpha

#preserve color of original image by limiting family of allowable transformations

preserve_color = False 

#Number of finetuning steps.

num_iter = 300

#Log training on wandb and interval for image logging

use_wandb = False 

log_interval = 50

if use_wandb:

   wandb.init(project="JoJoGAN")

   config = wandb.config

   config.num_iter = num_iter

   config.preserve_color = preserve_color

   wandb.log(

   {"Style reference": [wandb.Image(transforms.ToPILImage()(target_im))]},

   step=0)

# load discriminator for perceptual loss

discriminator = Discriminator(1024, 2).eval().to(device)

ckpt = torch.load('models/stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)

discriminator.load_state_dict(ckpt["d"], strict=False)

# reset generator

del generator

generator = deepcopy(original_generator)

g_optim = optim.Adam(generator.parameters(), lr=2e-3, betas=(0, 0.99))

训练生成器从潜在空间生成图像,并优化损失。

if preserve_color:

   id_swap = [9,11,15,16,17]

z = range(numiter)

for idx in tqdm( z):

   mean_w = generator.get_latent(torch.randn([latents.size(0), latent_dim]).to(device)).unsqueeze(1).repeat(1, generator.n_latent, 1)

   

in_latent = latents.clone()

   in_latent[:, id_swap] = alpha*latents[:, id_swap] + (1-alpha*mean_w[:, id_swap]


   img = generator(in_latent, input_is_latent=True)

   with torch.no_grad():

       real_feat = discriminator(targets)
   

   fake_feat = discriminator(img)

   loss = sum([functional.l1_loss(a, b) for a, b in zip(fake_feat, real_feat)])/len(fake_feat)  
   

   if use_wandb:

       wandb.log({"loss": loss}, step=idx)

       if idx % log_interval == 0:

           generator.eval()

           my_sample = generator(my_w, input_is_latent=True)

           generator.train()

           wandb.log(

           {"Current stylization": [wandb.Image(my_sample)]},

           step=idx)

   g_optim.zero_grad()

   loss.backward()

   g_optim.step()

使用 JojoGAN 生成结果

现在生成结果。下面已经为原始图像和示例图像生成了结果以进行比较。

#Generate resultsn_sample =  5

seed = 3000

torch.manual_seed(seed)

with torch.no_grad():

   generator.eval()

   z = torch.randn(n_sample, latent_dim, device=device)

   original_sample = original_generator([z], truncation=0.7, truncation_latent=mean_latent)

   sample = generator([z], truncation=0.7, truncation_latent=mean_latent)

   original_my_sample = original_generator(my_w, input_is_latent=True)

   my_sample = generator(my_w, input_is_latent=True)


# display reference images

style_images = []

for name in names:

   style_path = f'style_images_aligned/{strip_path_extension(name)}.png'

   style_image = transform(Image.open(style_path))

   style_images.append(style_image)

face = transform(aligned_face).to(device).unsqueeze(0)

style_images = torch.stack(style_images, 0).to(device)

my_output = torch.cat([face, my_sample], 0)

output = torch.cat([original_sample, sample], 0)

生成的结果

现在,你可以使用 JojoGAN 生成你自己风格的图像。结果令人印象深刻,但可以通过调整训练方法和训练图像中的更多特征来进一步改进。

结论

JojoGAN 能够以快速有效的方式准确地映射和迁移用户定义的样式。关键要点是:

· JojoGAN 可以只用一种风格进行训练,以轻松映射并创建任何面部的风格化图

· JojoGAN 非常快速有效,可以在不到一分钟的时间内完成训练

· 结果非常准确,类似于逼真的肖像

· JojoGAN 可以轻松微调和修改,使其适用于 AI 应用程序

因此,无论风格类型、形状和颜色如何,JojoGAN 都是用于风格转移的理想神经网络,因此可以成为各种社交媒体应用程序和 AI 应用程序中非常有用的功能。

       原文标题 : 使用JojoGAN创建风格化的面部图

声明: 本文由入驻OFweek维科号的作者撰写,观点仅代表作者本人,不代表OFweek立场。如有侵权或其他问题,请联系举报。
侵权投诉

下载OFweek,一手掌握高科技全行业资讯

还不是OFweek会员,马上注册
打开app,查看更多精彩资讯 >
  • 长按识别二维码
  • 进入OFweek阅读全文
长按图片进行保存