블로그 글을 많은 부분 참고했다.
모델 Freeze 하는 방법에 대해 정리한 글이다.
ULMFiT, Adapter, P-tuning 등의 architechture에서 처럼 일부 모델을 freeze 시키고 다른 모델의 일부에 대해서만 paramter update를 하고싶은 경우가 있다. 대표적으로 transfer learning이나 generative adversarial network의 경우가 있다. 이럴 경우에 파이토치에서 사용하는 no_grad 와 requires_grad 에 대한 차이점을 기록한다.
첫번째 경우, 위 그림과 같이 모델 A, B가 있는데, 모델 A는 freeze 시키고 모델 B만 학습시키는 방법입니다. 쉽게 생각하면 B에서 A로 가는 gradient를 막아버리면 되겠죠.
- 예를들면 transfer learning : fine-tuning?
# 첫번째 경우_1
z = A(x)
y = B(z.detach())
위와 같이 detach 함수를 사용하면 z를 static tensor로 만들어 update 할때, z 이전 파라미터에는 gradient가 흐르지 않습니다.
# 첫번째 경우_2
with torch.no_grad():
z = A(x)
y = B(z)
위와 같이 no_grad 함수를 쓰는 방법도 있습니다. no_grad 사용하면 내부에 있는 operation들의 gradient는 계산이 되지 않게 됩니다. gradient가 계산되지 않기 때문에 당연히 모델 A의 파라미터 업데이트가 될 수 없겠죠.
# 첫번째 경우_3
for p in A.parameter():
p.requires_grad = False
z = A(x)
y = B(z)
마지막 방법은 requires_grad를 쓰는 방법입니다. A의 파라미터를 하나씩 불러와서 gradient를 꺼주는 것입니다. 이렇게 하면 A의 parameter들을 상수 취급해주어 업데이트도 되지 않습니다. 이후에 다시 A의 파라미터를 업데이트 해주고 싶다면 requires_grad=True 로 gradient를 켜주면 됩니다.
두번째 경우는, 위 그림처럼 A만 update 시키고 B를 freeze 하는 것입니다. 위에 경우랑 똑같은거 아닌가 싶을 수 있지만 주의할 사항이 있습니다. 위 상황처럼 B로 가는 gradient를 끊어버리면 안된다는 것입니다. 우선 detach 함수는 이 경우 쓰기 힘들고 2, 3번째 경우를 비교해 보도록 하겠습니다.
- 예를들면 adapter
# 두번째 경우_1
z = A(x)
with torch.no_grad():
y = B(z)
# 두번째 경우_2
for p in B.parameter():
p.requires_grad = False
z = A(x)
y = B(z)
두 방법 모두 똑같아 보이지만, 이 중 하나는 작동하고 하나는 작동하지 않습니다. 첫번째 경우는 오류가 발생하는데요, no_grad 와 requires_grad 차이에 있습니다.
- no_grad는 아예 gradient 자체를 계산하지 않겠다는 것입니다. 따라서 B 모델의 gradient 자체를 계산하지 않는다면 자연스럽게 A는 당연히 gradient가 계산될 수 없겠죠.
- requires_grad는 B의 상수 취급을 해서 B 모델의 파라미터가 업데이트 되지는 않지만, 여전히 gradient는 흐를 수 있는 상태입니다. 따라서 두번째 경우는 A 모델이 업데이트 될 수 있습니다.
이 외에도 여러가지 방법이 있을 수 있지만 이번 글에서 no_grad와 requires_grad의 차이를 설명드리고 싶어서 두 방법 위주로 설명했습니다. 다른 방법으로는 zero_grad() 또는 optimizer를 다르게 쓰는 방법 등이 있을 것 같습니다.
'Language > Python' 카테고리의 다른 글
파이썬 자료형 - 정수형, 실수형 (2) | 2023.01.17 |
---|---|
[python] 한글 파일 텍스트 파일로 변환하기 hwp to txt (1) | 2021.12.08 |
댓글