-
[Loss Landscape 시각화] PyHessian: Neural Networks Through the Lens of the HessianMachine Learning/Foundation 2022. 4. 25. 16:19728x90
Deep neural networks의 Hessian(second-order derivative) information을 빠르게 계산할 수 있는
PyHessian에 대한 간략한 소개 및 이를 활용한 Loss Landscape 시각화 방법을 설명하겠습니다.
Hessian 정보를 활용하면 DNN 모델의 중요한 weights에 perturbation을 주어 Loss landscape을 그릴 수 있습니다.
예를 들어, 저희는 [ paper / code ]를 참고해 loss landscape을 3차원으로 그려
새롭게 제안한 방법이 이전 method 및 모델과 비교해 wide 한 landscape을 형성함을 강조했습니다.
[ CPR: Classifier-Projection Regularization for Continual Learning, ICLR 2021 ]
모델의 Loss landscape을 시각화하고 싶거나, 새롭게 개발한 모델이 이전 모델과 비교해 더 wide(or flat) local minima를 형성함을 검증하고 싶을 때 PyHessian 활용을 추천드립니다.
Outline
PyHessian의 저자들은 딥뉴럴넷에서 많은 계산량을 요구하는 Hessian information (the top Hessian eignevalues, the Hessian trace 등)을 빠르게 구할 수 있는 framework을 제안했습니다. (explicit 한 값이 아닌 approximation) 논문에서는 이렇게 구한 Hessian information을 NNs를 분석하는 데 사용했으며, 이번 포스트에서는 PyHessian을 활용해 3차원 Loss landscape을 시각화 하는 데 초점을 맞춰 설명 하겠습니다. [ code 참고 ]
Loss Landscape Visualization
1. Method
loss landscape을 그리는 원리는 model의 parameter(=weights)에 약간의 perturbation(weights에 작은 값을 더하거나 빼서 원래 값에 변형을 주는 과정)를 주고, perturbation이 커짐에 따라 loss 값이 얼마나 달라지는지 시각화하는 것입니다. 이때 고려해야 할 것은 'weights 중 어느 element에 얼마만큼의 perturbation을 줄 것인가' 입니다.
저자는 github에서 3개의 예시를 보여주고 있습니다. [ code ]
제일 왼쪽 그림은 모델 parameter에 random direction(value)으로 perturb을 준 것이고, 이때 loss를 뜻하는 y축을 보면 다른 그림들에 비해 loss의 변화가 매우 작은 것을 확인할 수 있습니다. parameter가 변화했음에도 loss가 크게 변하지 않는 것은 모델이 예측하는데 많은 영향을 주는 중요한 parameter에 큰 변화를 주지 못했다고 해석할 수 있습니다.
이에 반해 gradient direction과 top Hessian eigenvector direction으로 perturb을 줄 경우 loss 변화가 크고, 이는 예측에 영향을 많이 끼치는 중요한 parameter를 잘 선택했다고 해석할 수 있습니다. 이때 gradient direction은 하나밖에 없기 때문에 3D loss landscape를 그릴 수 없습니다. 하지만 top Hessian eigenvector의 경우 top1, top2 두 개를 구할 수 있고, 이를 이용해 3차원 시각화가 가능합니다.
2D loss landscape에 대한 설명은 저자의 github에 자세히 나와있지만, 3D 시각화 방법에 대한 자세한 설명이 없기에 이번 포스트에서는 이를 자세히 다룹니다.
2. Practice [ code ]
Cloning the PyHessian
git clone https://github.com/amirgholami/PyHessian.git
PyHessian 디렉토리 안에서 새로운 ipython 파일을 생성하고, 필요한 라이브러리를 import
import numpy as np import torch from torchvision import datasets, transforms from utils import * # get the dataset from pyhessian import hessian # Hessian computation from density_plot import get_esd_plot # ESD plot from pytorchcv.model_provider import get_model as ptcv_get_model # model import matplotlib.pyplot as plt %matplotlib inline
더보기## ModuleNotFoundError: No module named 'pytorchcv' ## 위와 같은 에러가 뜰 경우 아래 명령 실행 !pip install pytorchcv
# enable cuda devices import os os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" os.environ["CUDA_VISIBLE_DEVICES"]="0"
model, dataset 불러오기 (예시 : ResNet, CIFAR-10)
# get the model model = ptcv_get_model("resnet20_cifar10", pretrained=True) # change the model to eval mode to disable running stats upate model.eval() # create loss function criterion = torch.nn.CrossEntropyLoss() # get dataset train_loader, test_loader = getData() # for illustrate, we only use one batch to do the tutorial for inputs, targets in train_loader: break # we use cuda to make the computation fast model = model.cuda() inputs, targets = inputs.cuda(), targets.cuda()
Hessian computation module 생성
# create the hessian computation module hessian_comp = hessian(model, criterion, data=(inputs, targets), cuda=True)
3차원으로 그리기 위해 2개의 top eigenvector 얻음
(주의 : tutorial에서는 2차원으로 그리기 때문에 top_n 설정이 생략)# get the top eigenvector top_eigenvalues, top_eigenvector = hessian_comp.eigenvalues(top_n=2)
perturbation 설정(lambda) 및 그에 따른 Loss 계산
# lambda is a small scalar that we use to perturb the model parameters along the eigenvectors lams1 = np.linspace(-0.5, 0.5, 21).astype(np.float32) lams2 = np.linspace(-0.5, 0.5, 21).astype(np.float32) loss_list = [] # create a copy of the model model_perb1 = ptcv_get_model("resnet20_cifar10", pretrained=True) model_perb1.eval() model_perb1 = model_perb1.cuda() model_perb2 = ptcv_get_model("resnet20_cifar10", pretrained=True) model_perb2.eval() model_perb2 = model_perb2.cuda() for lam1 in lams1: for lam2 in lams2: model_perb1 = get_params(model, model_perb1, top_eigenvector[0], lam1) model_perb2 = get_params(model_perb1, model_perb2, top_eigenvector[1], lam2) loss_list.append((lam1, lam2, criterion(model_perb2(inputs), targets).item())) loss_list = np.array(loss_list)
matplotlib 활용한 시각화
fig = plt.figure() landscape = fig.gca(projection='3d') landscape.plot_trisurf(loss_list[:,0], loss_list[:,1], loss_list[:,2],alpha=0.8, cmap='viridis') #cmap=cm.autumn, cmamp = 'hot') landscape.set_title('Loss Landscape') landscape.set_xlabel('ε_1') landscape.set_ylabel('ε_2') landscape.set_zlabel('Loss') ## 시각화 원하는 각도 및 거리를 설정 #landscape.view_init(elev=15, azim=75) landscape.dist = 6 plt.show()
위의 코드를 활용해 2개의 모델을 겹쳐서 어느 모델이 더 flat/sharp 한지 비교할 수 있습니다.
2개 모델을 함께 표현하는 코드는 github에 공유하겠습니다.728x90