[PyTorch] model.eval()
์๋ฏธ
๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ์ฝ๋๋ฅผ ์ดํด๋ณด๋ค ๋ณด๋ฉด Evaluation ๋ถ๋ถ์์ ๊ผญ ์ด๋ฐ ์ฝ๋๊ฐ ๋ฑ์ฅํ๋ค.
def evaluation(model, criterion, ...):
model.eval()
criterion.eval()
...
๋ฌด์จ ์๋ฏธ์ธ์ง ๊ถ๊ธํด์ ์ฐพ์๋ณด๋, nn.Module
์์ train time๊ณผ eval time์์ ์ํํ๋ ๋ค๋ฅธ ์์
์ ์ํํ ์ ์๋๋ก switching ํ๋ ํจ์๋ผ๊ณ ํ๋ค. stackoverflow
train time๊ณผ eval time์์ ๋ค๋ฅด๊ฒ ๋์ํด์ผ ํ๋ ๋ํ์ ์ธ ์๋ค์
Dropout
LayerBatchNorm
Layer
๋ฑ๋ฑ์ด ์๋ค๊ณ ํ๋ค.
.eval()
ํจ์๋ evaluation ๊ณผ์ ์์ ์ฌ์ฉํ์ง ์์์ผ ํ๋ layer๋ค์ ์์์ off ์ํค๋๋ก ํ๋ ํจ์์ธ ์
์ด๋ค.
evaluation/validation ๊ณผ์ ์์ ๋ณดํต model.eval()
๊ณผ torch.no_grad()
๋ฅผ ํจ๊ป ์ฌ์ฉํ๋ค๊ณ ํ๋ค.
# evaluate model:
model.eval()
with torch.no_grad():
...
out_data = model(data)
...
eval/val ์์
์ด ๋๋ ํ์๋ ์์ง๋ง๊ณ train mode๋ก ๋ชจ๋ธ์ ๋ณ๊ฒฝํด์ค์ผ ํ๋ค. ์ด๊ฒ์ .train()
ํจ์๋ฅผ ์คํ์ํค๋ฉด ๋๋ค.
# after eval/val, and in training step
model.train()
PyTorch ๊ณต์ ๋ฌธ์์์ .eval()
์ ๋ํ ์์ธํ ๋ด์ฉ์ ํ์ธํ ์ ์๋ค. nn.Module.eval()