# Copyright (C) 2017-2022 Cleanlab Inc.# This file is part of cleanlab.# # cleanlab is free software: you can redistribute it and/or modify# it under the terms of the GNU Affero General Public License as published# by the Free Software Foundation, either version 3 of the License, or# (at your option) any later version.# # cleanlab is distributed in the hope that it will be useful,# but WITHOUT ANY WARRANTY; without even the implied warranty of# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the# GNU Affero General Public License for more details.# # You should have received a copy of the GNU Affero General Public License# along with cleanlab. If not, see <https://www.gnu.org/licenses/>."""A PyTorch CNN for training CIFAR-10 using Co-Teaching.Code adapted from: https://github.com/bhanML/Co-teaching/blob/master/model.pyThis code requires you have PyTorch installedSee: https://pytorch.org/get-started/locally/"""# Python 2 and 3 compatibilityfrom__future__import(print_function,absolute_import,division,unicode_literals,with_statement)importtorch.nnasnnimporttorch.nn.functionalasF
class CNN(nn.Module):
    """A CNN architecture shown to be a good baseline for a CIFAR-10 benchmark.

    Nine 3x3 convolutions, each with its own ``BatchNorm2d`` layer, followed by
    a single linear output layer.

    Parameters
    ----------
    input_channel : int
        Number of channels of the input images, used as ``in_channels`` of
        the first convolution ``c1`` (3 for RGB images such as CIFAR-10).
    n_outputs : int
        Number of output units of the final linear layer ``l_c1``
        (e.g. the number of classes: 10 for CIFAR-10).
    dropout_rate : float
        Stored on the instance as ``self.dropout_rate``. The constructor
        builds no dropout layer, so the rate is presumably applied
        functionally in ``forward``.
    top_bn : bool
        Stored on the instance as ``self.top_bn``. The constructor creates no
        extra normalization layer for this flag, so any effect it has lives
        entirely in ``forward``.

    Methods
    -------
    forward
        forward pass in PyTorch
    """

    def __init__(self, input_channel=3, n_outputs=10, dropout_rate=0.25, top_bn=False):
        # nn.Module.__init__ must run before anything is assigned on the
        # subclass: it creates the registries (_parameters, _modules, ...)
        # that nn.Module.__setattr__ consults. Assigning first only worked
        # because these values are plain float/bool objects.
        super(CNN, self).__init__()
        self.dropout_rate = dropout_rate
        self.top_bn = top_bn

        # NOTE: submodules are registered in attribute-assignment order, which
        # fixes the order of state_dict() keys and parameters() (and hence
        # optimizer parameter indices). Keep this order to stay compatible
        # with existing checkpoints.

        # Convolutional trunk. c1-c6 use padding=1 and keep the spatial size.
        # c7-c9 use padding=0, so each one shrinks height and width by 2.
        self.c1 = nn.Conv2d(input_channel, 128, kernel_size=3, stride=1, padding=1)
        self.c2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.c3 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.c4 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.c5 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.c6 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.c7 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=0)
        self.c8 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=0)
        self.c9 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=0)

        # Output layer: its 128 input features equal the out_channels of c9.
        # NOTE(review): forward presumably pools/flattens the c9 feature map
        # down to 128 features per sample before this layer; confirm.
        self.l_c1 = nn.Linear(128, n_outputs)

        # One BatchNorm2d per convolution; num_features matches the
        # out_channels of the convolution with the same index.
        self.bn1 = nn.BatchNorm2d(128)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm2d(256)
        self.bn6 = nn.BatchNorm2d(256)
        self.bn7 = nn.BatchNorm2d(512)
        self.bn8 = nn.BatchNorm2d(256)
        self.bn9 = nn.BatchNorm2d(128)