-
[PyTorch/에러] RuntimeError: Error(s) in loading state_dictMachine Learning/PyTorch 2022. 12. 18. 00:55728x90
pretrain 된 모델을 불러와 추가 데이터에 대해 finetuning 할 때 더 좋은 성능을 보여주는 경우가 많습니다.
이때
load_state_dict
을 이용해 모델을 load 하는데,사전에 정의한 모델과 불러오려는 모델의
state_dict - key
가 다를 경우 아래와 같은 에러가 발생할 수 있습니다.RuntimeError: Error(s) in loading state_dict for ResNet: Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.c onv2.weight", "layer1.0.bn2.weight", ... Unexpected key(s) in state_dict: "model.backbone.conv1.weight", "model.backbone.bn1.weight", "model.backbone.bn1.bias", "model.backbone.bn1.running_mean", "model.backbone.bn1.running_var", "model.backbone.bn1.num_batches_tracked", "model.backbone.l ayer1.0.conv1.weight", "model.backbone.layer1.0.bn1.weight", ...
ResNet 예시이며, 불러오려는 모델의 모든 key에
model.backbone.
이 추가로 붙어있습니다.이러한 Unexpected key(s)의 앞부분을 제거하는 방법은 아래와 같습니다.
backbone = timm.create_model(args.model_name, num_classes=0, pretrained=True) checkpoint_path = './pretrained_model/model.ckpt' checkpoint = torch.load(checkpoint_path)["state_dict"] # 1st) for key in list(checkpoint.keys()): if 'model.backbone.' in key: checkpoint[key.replace('model.backbone.','')] = checkpoint[key] del checkpoint[key] # 2nd) ''' for key in list(checkpoint): checkpoint[key.replace("model.backbone.", "")] = checkpoint.pop(key) ''' backbone.load_state_dict(checkpoint, strict=False)
728x90'Machine Learning > PyTorch' 카테고리의 다른 글
[PyTorch] Deep learning with PyTorch - Intro (0) 2022.04.27