ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [PyTorch/에러] RuntimeError: Error(s) in loading state_dict
    Machine Learning/PyTorch 2022. 12. 18. 00:55
    728x90

     

    pretrain 된 모델을 불러와 추가 데이터에 대해 finetuning 할 때 더 좋은 성능을 보여주는 경우가 많습니다.

    이때  load_state_dict 을 이용해 모델을 load 하는데,

    사전에 정의한 모델과 불러오려는 모델의  state_dict - key 가 다를 경우 아래와 같은 에러가 발생할 수 있습니다.

    RuntimeError: Error(s) in loading state_dict for ResNet:
        Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.c
    onv2.weight", "layer1.0.bn2.weight", ...
    
        Unexpected key(s) in state_dict: "model.backbone.conv1.weight", "model.backbone.bn1.weight", "model.backbone.bn1.bias", "model.backbone.bn1.running_mean", "model.backbone.bn1.running_var", "model.backbone.bn1.num_batches_tracked", "model.backbone.l
    ayer1.0.conv1.weight", "model.backbone.layer1.0.bn1.weight", ...

     

    ResNet 예시이며, 불러오려는 모델의 모든 key에  model.backbone. 이 추가로 붙어있습니다.

    이러한 Unexpected key(s)의 앞부분을 제거하는 방법은 아래와 같습니다.

    backbone = timm.create_model(args.model_name, num_classes=0, pretrained=True)
    checkpoint_path = './pretrained_model/model.ckpt'
    checkpoint = torch.load(checkpoint_path)["state_dict"]
    
    # 1st)
    for key in list(checkpoint.keys()):
        if 'model.backbone.' in key:
            checkpoint[key.replace('model.backbone.','')] = checkpoint[key]
            del checkpoint[key]
    
    # 2nd)
    '''
    for key in list(checkpoint):
         checkpoint[key.replace("model.backbone.", "")] = checkpoint.pop(key)
    '''
    
    backbone.load_state_dict(checkpoint, strict=False)

     

     

     

    728x90

    'Machine Learning > PyTorch' 카테고리의 다른 글

    [PyTorch] Deep learning with PyTorch - Intro  (0) 2022.04.27

    댓글

Designed by Tistory.