[PyTorch] model.eval() ์๋ฏธ
๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ์ฝ๋๋ฅผ ์ดํด๋ณด๋ค ๋ณด๋ฉด Evaluation ๋ถ๋ถ์์ ๊ผญ ์ด๋ฐ ์ฝ๋๊ฐ ๋ฑ์ฅํ๋ค.
def evaluation(model, criterion, ...):
model.eval()
criterion.eval()
...
๋ฌด์จ ์๋ฏธ์ธ์ง ๊ถ๊ธํด์ ์ฐพ์๋ณด๋, nn.Module์์ train time๊ณผ eval time์์ ์ํํ๋ ๋ค๋ฅธ ์์
์ ์ํํ ์ ์๋๋ก switching ํ๋ ํจ์๋ผ๊ณ ํ๋ค. stackoverflow
train time๊ณผ eval time์์ ๋ค๋ฅด๊ฒ ๋์ํด์ผ ํ๋ ๋ํ์ ์ธ ์๋ค์
DropoutLayerBatchNormLayer
๋ฑ๋ฑ์ด ์๋ค๊ณ ํ๋ค.
.eval() ํจ์๋ evaluation ๊ณผ์ ์์ ์ฌ์ฉํ์ง ์์์ผ ํ๋ layer๋ค์ ์์์ off ์ํค๋๋ก ํ๋ ํจ์์ธ ์
์ด๋ค.
evaluation/validation ๊ณผ์ ์์ ๋ณดํต model.eval()๊ณผ torch.no_grad()๋ฅผ ํจ๊ป ์ฌ์ฉํ๋ค๊ณ ํ๋ค.
# evaluate model:
model.eval()
with torch.no_grad():
...
out_data = model(data)
...
eval/val ์์
์ด ๋๋ ํ์๋ ์์ง๋ง๊ณ train mode๋ก ๋ชจ๋ธ์ ๋ณ๊ฒฝํด์ค์ผ ํ๋ค. ์ด๊ฒ์ .train() ํจ์๋ฅผ ์คํ์ํค๋ฉด ๋๋ค.
# after eval/val, and in training step
model.train()
PyTorch ๊ณต์ ๋ฌธ์์์ .eval()์ ๋ํ ์์ธํ ๋ด์ฉ์ ํ์ธํ ์ ์๋ค. nn.Module.eval()