Blog Full Notice
back to main page

4 분 소요

motivation: subprocess만들어서 모델 돌리기

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Project given image to the latent space of pretrained network pickle."""

import copy
import os
from time import perf_counter

import click
import imageio
import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F

import dnnlib
import legacy

import multiprocessing
from tqdm import tqdm
from itertools import chain

def project(
    G,
    target: torch.Tensor,  # d[C,H,W] and dynamic range [0,255], W & H must match G output resolution
    *,
    num_steps                  = 150,
    w_avg_samples              = 10000,
    initial_learning_rate      = 0.1,
    initial_noise_factor       = 0.05,
    lr_rampdown_length         = 0.25,
    lr_rampup_length           = 0.05,
    noise_ramp_length          = 0.75,
    regularize_noise_weight    = 1e5,
    verbose                    = False,
    device: torch.device
):
    assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)

    def logprint(*args):
        if verbose:
            print(*args)

    G = copy.deepcopy(G).eval().requires_grad_(False).to(device) # type: ignore

    # Compute w stats.
    logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
    z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
    w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None)
    w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)
    w_avg = np.mean(w_samples, axis=0, keepdims=True)
    w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5

    # Setup noise inputs.
    noise_bufs = { name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name }

    # Load VGG16 feature detector.
    url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
    with dnnlib.util.open_url(url) as f:
        vgg16 = torch.jit.load(f).eval().to(device)

    # Features for target image.
    target_images = target.unsqueeze(0).to(device).to(torch.float32)
    if target_images.shape[2] > 256:
        target_images = F.interpolate(target_images, size=(256, 256), mode='area')
    target_features = vgg16(target_images, resize_images=False, return_lpips=True)

    w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True)
    optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate)

    # Init noise.
    for buf in noise_bufs.values():
        buf[:] = torch.randn_like(buf)
        buf.requires_grad = True

    for step in range(num_steps):
        # Learning rate schedule.
        t = step / num_steps
        w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
        lr = initial_learning_rate * lr_ramp
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Synth images from opt_w.
        w_noise = torch.randn_like(w_opt) * w_noise_scale
        ws = (w_opt + w_noise).repeat([1, G.mapping.num_ws, 1])
        synth_images = G.synthesis(ws, noise_mode='const')

        # Downsample image to 256x256 if it's larger than that.
        synth_images = (synth_images + 1) * (255/2)
        if synth_images.shape[2] > 256:
            synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')

        # Features for synth images.
        synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
        dist = (target_features - synth_features).square().sum()

        # Noise regularization.
        reg_loss = 0.0
        for v in noise_bufs.values():
            noise = v[None,None,:,:]
            while True:
                reg_loss += (noise*torch.roll(noise, shifts=1, dims=3)).mean()**2
                reg_loss += (noise*torch.roll(noise, shifts=1, dims=2)).mean()**2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)
        loss = dist + reg_loss * regularize_noise_weight

        # Step
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        # logprint(f'step {step+1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')
        # Normalize noise.
        with torch.no_grad():
            for buf in noise_bufs.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

    # Return single w_opt as [1, 1, 512] for all layers.
    return w_opt.detach()

#----------------------------------------------------------------------------

def process_images(subdir_list, network_pkl, device, num_steps):
    # 각 프로세스에서 네트워크를 별도로 로드
    with dnnlib.util.open_url(network_pkl) as fp:
        G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device)
    
    for subdir in subdir_list:
        for filename in os.listdir(subdir):
            if filename.lower().endswith('.jpg'):
                file_path = os.path.join(subdir, filename)
                target_pil = PIL.Image.open(file_path).convert('RGB')
                w, h = target_pil.size
                s = min(w, h)
                target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
                target_pil = target_pil.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS)
                target_uint8 = np.array(target_pil, dtype=np.uint8)
                
                projected_w = project(
                    G,
                    target=torch.tensor(target_uint8.transpose([2, 0, 1]), device=device),
                    num_steps=num_steps,
                    device=device,
                    verbose=False
                )
                
                npz_out_path = os.path.join(subdir, os.path.splitext(filename)[0] + '.npz')
                np.savez(npz_out_path, w=projected_w.cpu().numpy())
                print(f"Saved {npz_out_path}")

# 메인 함수 및 멀티프로세싱 설정은 동일하게 유지


@click.command()
@click.option('--network', 'network_pkl', required=True, help='Network pickle filename')
@click.option('--target-dir', 'target_dir', required=True, help='Directory containing target images', metavar='DIR')
@click.option('--num-steps', default=150, help='Number of optimization steps', show_default=True, type=int)
@click.option('--num-processes', default=8, help='Number of parallel processes', show_default=True, type=int)
def run_projection(network_pkl, target_dir, num_steps, num_processes):
    device = torch.device('cuda')
    subdir_list = [os.path.join(target_dir, d) for d in os.listdir(target_dir) if os.path.isdir(os.path.join(target_dir, d))]
    chunks = [subdir_list[i::num_processes] for i in range(num_processes)]
    
    processes = []
    for chunk in chunks:
        p = multiprocessing.Process(target=process_images, args=(chunk, network_pkl, device, num_steps))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()



if __name__ == "__main__":
    run_projection() # pylint: disable=no-value-for-parameter
    
#----------------------------------------------------------------------------

각 subprocess마다 G network을 받아오도록 설계했다.

질문: 근데 보통 multi gpu 환경에서 subprocess를 만들어서 사용하진 않지?

ChatGPT

맞습니다. 일반적으로 멀티 GPU 환경에서는 서브프로세스를 사용하는 것보다 PyTorch의 기본 제공 기능인 DataParallel 또는 DistributedDataParallel (DDP)을 사용하는 것이 더 일반적이고 효율적입니다. 이러한 방법은 GPU 간의 통신을 최적화하고 프로그램의 복잡성을 줄여줍니다.

DataParallel

torch.nn.DataParallel은 싱글 머신 내의 여러 GPU에 모델을 복제하여 작업을 분배하고, 결과를 자동으로 합칩니다. 이 방식은 간단한 사용법으로 인기가 있지만, 각 GPU 간의 데이터 복사로 인해 오버헤드가 크고, 배치 사이즈가 커질수록 스케일링 효율이 떨어지는 단점이 있습니다.

DistributedDataParallel (DDP)

torch.nn.parallel.DistributedDataParallel은 멀티프로세스를 사용하여 각 프로세스가 하나의 GPU를 독립적으로 사용하게 합니다. DDP는 네트워크 상에서 데이터를 효율적으로 분배하고, 각 GPU에서 독립적으로 연산을 수행 후 결과를 합치는 과정을 최적화합니다. 이는 대규모 분산 학습 환경에서 더 나은 성능과 확장성을 제공합니다.

코드 예시: DistributedDataParallel 사용

멀티 GPU 환경에서 이미지 프로젝션 작업을 DDP로 구현하는 간단한 예제를 제공하겠습니다. 이 예제는 각 GPU에서 독립적으로 모델을 로드하고 연산을 수행하는 방식을 보여줍니다.

python


import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def setup(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def project_ddp(rank, world_size, network_pkl, target_dir, num_steps):
    setup(rank, world_size)
    device = torch.device(f'cuda:{rank}')
    
    # 모델과 기타 필요 자원을 device에 로드
    # 예: G = load_model(network_pkl).to(device)
    # DDP로 모델을 래핑
    # G = DDP(G, device_ids=[rank])

    # 이미지 프로젝션 로직 구현
    # 프로세스 종료 후 cleanup 호출
    cleanup()

def run_projection_ddp(network_pkl, target_dir, num_steps, num_gpus):
    mp.spawn(project_ddp,
             args=(num_gpus, network_pkl, target_dir, num_steps),
             nprocs=num_gpus,
             join=True)

if __name__ == "__main__":
    network_pkl = 'path_to_network.pkl'
    target_dir = 'path_to_images'
    num_steps = 150
    num_gpus = torch.cuda.device_count()
    run_projection_ddp(network_pkl, target_dir, num_steps, num_gpus)

이 코드는 각 GPU에서 독립적으로 작업을 수행하도록 설정하며, DDP를 사용하여 각 GPU의 결과를 효율적으로 결합합니다. 이 방식은 복잡한 네트워크 통신을 최적화하고, 멀티 GPU 환경에서 더 높은 성능을 제공합니다.

댓글남기기