30 lines
854 B
Python
30 lines
854 B
Python
import torch
|
||
from torch import nn
|
||
from utils import *
|
||
|
||
|
||
class My_Dropout(nn.Module):
    """Inverted dropout implemented by hand.

    In training mode every element of the input is zeroed independently with
    probability ``p`` and the surviving elements are divided by ``1 - p``.
    In evaluation mode the input is returned unchanged.

    Args:
        p: Probability of zeroing an element; must lie in ``[0, 1]``.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``p`` is outside ``[0, 1]``.
    """

    def __init__(self, p, **kwargs):
        super().__init__()
        if not 0.0 <= p <= 1.0:
            raise ValueError(
                f"dropout probability has to be between 0 and 1, but got {p}"
            )
        self.p = p
        # Keep-mask (1 = kept, 0 = dropped) sampled by the most recent
        # training-mode forward pass; None until the first such pass.
        self.mask = None

    def forward(self, x: torch.Tensor):
        """Apply dropout to ``x``.

        Args:
            x: Input tensor.

        Returns:
            ``x`` itself when the module is in evaluation mode; otherwise a
            new tensor holding ``x * mask / (1 - p)`` (all zeros when
            ``p == 1``).
        """
        if not self.training:
            return x

        # Floating-point inputs keep their own dtype so half/bfloat16 tensors
        # are not promoted to float32; any other dtype falls back to float32.
        mask_dtype = x.dtype if x.is_floating_point() else torch.float32

        # Sample directly on x's device to avoid a host-to-device copy.
        # torch.rand draws from [0, 1), so ">=" keeps every element when
        # p == 0 and drops every element when p == 1.
        keep = torch.rand(x.shape, device=x.device) >= self.p
        self.mask = keep.to(dtype=mask_dtype)

        masked = x * self.mask
        keep_prob = 1 - self.p
        if keep_prob == 0:
            # Everything was dropped; skip the rescale to avoid 0 / 0 = NaN.
            return masked
        return masked / keep_prob
|
||
|
||
|
||
if __name__ == "__main__":
    # Side-by-side demo: run the hand-written layer and the built-in
    # nn.Dropout on the same input with the same drop probability.
    # The two layers draw independent random masks, so the zeroed
    # positions are not expected to match between the two outputs.
    drop_prob = 0.5
    my_dropout = My_Dropout(p=drop_prob)
    nn_dropout = nn.Dropout(p=drop_prob)

    # 2 x 5 float32 tensor holding the values 1.0 .. 10.0 in row-major order.
    x = torch.arange(1.0, 11.0).reshape(2, 5)
    print(f"输入:\n{x}")

    # Both modules are freshly constructed and therefore in training mode.
    output_my_dropout = my_dropout(x)
    output_nn_dropout = nn_dropout(x)

    print(f"My_Dropout输出:\n{output_my_dropout}")
    print(f"nn.Dropout输出:\n{output_nn_dropout}")
|