# Deep_Learning/Lab5/utils.py
#
# 176 lines
# 5.1 KiB
# Python

import math
import torch
from torch.utils import data
import torch.nn as nn
from matplotlib import pyplot as plt
import numpy as np
import time
def mse_fn(y, pred):
    """Return the mean squared error between ``y`` and ``pred``.

    Args:
        y: Ground-truth values; anything ``np.asarray`` accepts.
        pred: Predicted values; must be broadcastable against ``y``.

    Returns:
        The mean of the element-wise squared differences.
    """
    # asarray avoids copying inputs that are already ndarrays.
    residual = np.asarray(y) - np.asarray(pred)
    return np.mean(residual ** 2)
def mae_fn(y, pred):
    """Return the mean absolute error between ``y`` and ``pred``.

    Args:
        y: Ground-truth values; anything ``np.asarray`` accepts.
        pred: Predicted values; must be broadcastable against ``y``.

    Returns:
        The mean of the element-wise absolute differences.
    """
    # asarray avoids copying inputs that are already ndarrays.
    residual = np.asarray(y) - np.asarray(pred)
    return np.mean(np.abs(residual))
def mape_fn(y, pred):
    """Return the mean absolute percentage error (in percent).

    Entries whose ground truth is exactly zero are excluded, because the
    relative error is undefined there.

    Args:
        y: Ground-truth values; anything ``np.asarray`` accepts.
        pred: Predicted values with the same shape as ``y``.

    Returns:
        ``100 * mean(|y - pred| / |y|)`` over the non-zero targets, or NaN
        when every target is zero.
    """
    # Coerce so that plain sequences work, like they do in mse_fn/mae_fn.
    y = np.asarray(y)
    pred = np.asarray(pred)
    nonzero = y != 0
    if not nonzero.any():
        # Nothing to average; return NaN explicitly instead of triggering
        # NumPy's "mean of empty slice" warning.
        return float("nan")
    y = y[nonzero]
    pred = pred[nonzero]
    return np.mean(np.abs((y - pred) / y)) * 100
def eval(y, pred):
    """Compute RMSE, MAE and MAPE for a pair of tensors.

    NOTE: this name shadows the builtin ``eval``; it is kept because the
    other functions in this module call it under this name.

    Args:
        y: Ground-truth values; must support ``.cpu().numpy()``.
        pred: Predicted values; must support ``.cpu().numpy()``.

    Returns:
        ``[rmse, mae, mape]`` as computed by ``mse_fn``, ``mae_fn`` and
        ``mape_fn`` on the NumPy copies of the inputs.
    """
    truth = y.cpu().numpy()
    estimate = pred.cpu().numpy()
    return [
        math.sqrt(mse_fn(truth, estimate)),
        mae_fn(truth, estimate),
        mape_fn(truth, estimate),
    ]
def test(net, data_iter, loss_fn, denormalize_fn, device='cpu'):
    """Evaluate ``net`` on ``data_iter`` and return batch-averaged metrics.

    The network is switched to evaluation mode and is left in that mode
    when the function returns.

    Args:
        net: Model; called as ``net(seqs)``.
        data_iter: Iterable yielding ``(seqs, targets)`` batches.
        loss_fn: Called as ``loss_fn(y_hat, targets)`` on the normalised
            values; must return a single-element tensor.
        denormalize_fn: Applied to both targets and predictions before the
            metrics are computed.
        device: Device the batches are moved to.

    Returns:
        ``([rmse, mae, mape], loss)`` where each value is the mean of the
        per-batch values (not a metric over the pooled samples).

    Raises:
        ZeroDivisionError: If ``data_iter`` yields no batches.
    """
    rmse, mae, mape = 0, 0, 0
    batch_count = 0
    total_loss = 0.0
    net.eval()
    # Evaluation never back-propagates, so skip building the autograd graph.
    with torch.no_grad():
        for seqs, targets in data_iter:
            seqs = seqs.to(device).float()
            targets = targets.to(device).float()
            y_hat = net(seqs)
            loss = loss_fn(y_hat, targets)
            # Metrics are reported on the original (denormalised) scale.
            targets = denormalize_fn(targets)
            y_hat = denormalize_fn(y_hat)
            a, b, c = eval(targets.detach(), y_hat.detach())
            rmse += a
            mae += b
            mape += c
            total_loss += loss.item()
            batch_count += 1
    return [rmse / batch_count, mae / batch_count, mape / batch_count], total_loss / batch_count
def train(net, train_iter, val_iter, test_iter, loss_fn, denormalize_fn, optimizer, num_epoch,
          early_stop=10, device='cpu', num_print_epoch_round=0):
    """Train ``net`` with early stopping, then report test-set metrics.

    Args:
        net: Model to optimise; called as ``net(seqs)``.
        train_iter: Iterable yielding ``(seqs, targets)`` training batches.
        val_iter: Validation batches, evaluated after every epoch via ``test``.
        test_iter: Test batches, evaluated once after training ends.
        loss_fn: Called as ``loss_fn(y_hat, targets)`` on normalised values.
        denormalize_fn: Applied to targets and predictions before metrics
            are computed.
        optimizer: Optimiser whose ``zero_grad``/``step`` are called per batch.
        num_epoch: Maximum number of epochs.
        early_stop: Stop once the validation RMSE has failed to improve for
            this many consecutive epochs.
        device: Device the batches are moved to.
        num_print_epoch_round: Print progress every this many epochs;
            ``0`` disables progress output.

    Returns:
        ``(train_loss_lst, val_loss_lst, train_score_lst, val_score_lst,
        epoch)``. The score lists hold ``[rmse, mae, mape]`` per epoch and
        ``epoch`` is the zero-based index of the last executed epoch
        (``-1`` if no epoch ran).

    Note:
        The test set is evaluated with the weights from the last executed
        epoch; the weights of the best validation epoch are not restored.
    """
    train_loss_lst = []
    val_loss_lst = []
    train_score_lst = []
    val_score_lst = []
    epoch_time = []
    best_epoch = 0
    # Start from +inf so the first epoch always counts as an improvement,
    # whatever the scale of the denormalised targets is.
    best_val_rmse = float('inf')
    early_stop_flag = 0
    # Keeps the return value defined when num_epoch is 0.
    epoch = -1
    for epoch in range(num_epoch):
        net.train()
        epoch_loss = 0
        batch_count = 0
        batch_time = []
        rmse, mae, mape = 0, 0, 0
        for seqs, targets in train_iter:
            batch_s = time.time()
            seqs = seqs.to(device).float()
            targets = targets.to(device).float()
            optimizer.zero_grad()
            y_hat = net(seqs)
            loss = loss_fn(y_hat, targets)
            loss.backward()
            optimizer.step()
            # Metrics are reported on the original (denormalised) scale.
            targets = denormalize_fn(targets)
            y_hat = denormalize_fn(y_hat)
            a, b, c = eval(targets.detach(), y_hat.detach())
            rmse += a
            mae += b
            mape += c
            epoch_loss += loss.item()
            batch_count += 1
            batch_time.append(time.time() - batch_s)
        train_loss = epoch_loss / batch_count
        train_loss_lst.append(train_loss)
        train_score_lst.append([rmse/batch_count, mae/batch_count, mape/batch_count])
        # Validation set
        val_score, val_loss = test(net, val_iter, loss_fn, denormalize_fn, device)
        val_score_lst.append(val_score)
        val_loss_lst.append(val_loss)
        # Only the time spent in training batches is counted.
        epoch_time.append(np.array(batch_time).sum())
        # Print the results of this epoch
        if num_print_epoch_round > 0 and (epoch+1) % num_print_epoch_round == 0:
            print(
                f"Epoch [{epoch + 1}/{num_epoch}],",
                f"Train Loss: {train_loss:.4f},",
                f"Train RMSE: {train_score_lst[-1][0]:.4f},",
                f"Val Loss: {val_loss:.4f},",
                f"Val RMSE: {val_score[0]:.6f},",
                f"Time Use: {epoch_time[-1]:.3f}s"
            )
        # Early stopping
        if val_score[0] < best_val_rmse:
            best_val_rmse = val_score[0]
            best_epoch = epoch
            early_stop_flag = 0
        else:
            early_stop_flag += 1
            if early_stop_flag == early_stop:
                print(f'The model has not been improved for {early_stop} rounds. Stop early!')
                break
    # Print the final training summary (best_epoch is zero-based)
    print(
        f'Final result:',
        f'Get best validation rmse {best_val_rmse:.4f} at epoch {best_epoch},',
        f'Total time {np.array(epoch_time).sum():.2f}s'
    )
    # Evaluate on the test set
    test_score, test_loss = test(net, test_iter, loss_fn, denormalize_fn, device)
    print(
        'Test result:',
        f'Test RMSE: {test_score[0]},',
        f'Test MAE: {test_score[1]},',
        f'Test MAPE: {test_score[2]}'
    )
    return train_loss_lst, val_loss_lst, train_score_lst, val_score_lst, epoch
def visualize(num_epochs, train_data, test_data, x_label='epoch', y_label='loss'):
    """Plot a training curve and a validation curve on shared axes.

    Args:
        num_epochs: Largest x value; the x axis is ``0 .. num_epochs``
            inclusive, i.e. ``num_epochs + 1`` points.
        train_data: Series drawn with the ``train_`` legend prefix.
        test_data: Series drawn with the ``val_`` legend prefix.
        x_label: Label of the x axis.
        y_label: Label of the y axis; also the suffix of both legend entries.
    """
    epochs = np.arange(0, num_epochs + 1).astype(dtype=np.int32)
    plt.figure(figsize=(5, 3.5))
    curves = (
        (train_data, f"train_{y_label}"),
        (test_data, f"val_{y_label}"),
    )
    for series, legend_name in curves:
        plt.plot(epochs, series, label=legend_name, linewidth=1.5)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.legend()
    plt.show()
def plot_metric(score_log):
    """Plot the three metric columns of ``score_log`` side by side.

    Args:
        score_log: Sequence of rows convertible to a 2-D array; columns
            0, 1 and 2 are plotted under the labels RMSE, MAE and MAPE.
    """
    scores = np.array(score_log)
    panels = (
        ('RMSE', '#d28ad4'),
        ('MAE', '#e765eb'),
        ('MAPE', '#6b016d'),
    )
    plt.figure(figsize=(13, 3.5))
    for column, (metric_name, colour) in enumerate(panels):
        plt.subplot(1, 3, column + 1)
        plt.plot(scores[:, column], c=colour)
        plt.ylabel(metric_name)
    plt.show()