| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from __future__ import absolute_import |
| | from __future__ import division |
| | from __future__ import print_function |
| | from __future__ import unicode_literals |
| |
|
| | import time |
| |
|
| | import torch |
| |
|
| | from autoattack.fab_projections import projection_linf, projection_l2,\ |
| | projection_l1 |
| |
|
# Fallback perturbation budget per norm, looked up by FABAttack.__init__
# when the caller does not pass an explicit `eps`.
DEFAULT_EPS_DICT_BY_NORM = {
    'Linf': .3,
    'L2': 1.,
    'L1': 5.0,
}
| |
|
| |
|
class FABAttack:
    """
    Fast Adaptive Boundary Attack (Linf, L2, L1)
    https://arxiv.org/abs/1907.02044

    Model-agnostic base class. Subclasses must implement ``_predict_fn``,
    ``_get_predicted_label``, ``get_diff_logits_grads_batch`` and
    ``get_diff_logits_grads_batch_targeted``; the versions defined here only
    raise ``NotImplementedError``.

    :param norm:              Lp-norm to minimize ('Linf', 'L2', 'L1' supported)
    :param n_restarts:        number of random restarts
    :param n_iter:            number of iterations
    :param eps:               epsilon for the random restarts; also the budget
                              used by ``perturb`` to decide whether an
                              adversarial point counts as a success. If None,
                              the default for ``norm`` is taken from
                              ``DEFAULT_EPS_DICT_BY_NORM``
    :param alpha_max:         alpha_max
    :param eta:               overshooting
    :param beta:              backward step
    :param loss_fn:           accepted for interface compatibility; not used
                              by this class
    :param verbose:           if True, print progress information
    :param seed:              seed passed to the torch CPU and CUDA generators
                              at the start of ``perturb``
    :param targeted:          if True, ``perturb`` runs the targeted version
    :param device:            device to run on; if None it is taken from the
                              first input batch
    :param n_target_classes:  number of target classes tried by the targeted
                              version
    """

    def __init__(
            self,
            norm='Linf',
            n_restarts=1,
            n_iter=100,
            eps=None,
            alpha_max=0.1,
            eta=1.05,
            beta=0.9,
            loss_fn=None,
            verbose=False,
            seed=0,
            targeted=False,
            device=None,
            n_target_classes=9):
        """ FAB-attack implementation in pytorch """

        self.norm = norm
        self.n_restarts = n_restarts
        self.n_iter = n_iter
        self.eps = eps if eps is not None else DEFAULT_EPS_DICT_BY_NORM[norm]
        self.alpha_max = alpha_max
        self.eta = eta
        self.beta = beta
        self.targeted = targeted
        self.verbose = verbose
        self.seed = seed
        # Set by `perturb` before each targeted run.
        self.target_class = None
        self.device = device
        self.n_target_classes = n_target_classes

    def check_shape(self, x):
        """Return ``x`` unchanged, or with a leading axis if it is 0-dim.

        Used after ``nonzero().squeeze()``, which yields a 0-dim tensor when
        exactly one index is selected.
        """
        return x if len(x.shape) > 0 else x.unsqueeze(0)

    def _predict_fn(self, x):
        """Return the model output for ``x`` (to be implemented by subclasses)."""
        raise NotImplementedError("Virtual function.")

    def _get_predicted_label(self, x):
        """Return the predicted label of ``x`` (to be implemented by subclasses)."""
        raise NotImplementedError("Virtual function.")

    def get_diff_logits_grads_batch(self, imgs, la):
        """Return ``(df, dg)`` for the untargeted attack (subclass hook)."""
        raise NotImplementedError("Virtual function.")

    def get_diff_logits_grads_batch_targeted(self, imgs, la, la_target):
        """Return ``(df, dg)`` for the targeted attack (subclass hook)."""
        raise NotImplementedError("Virtual function.")

    def _lp_norm(self, delta):
        """Per-sample ``self.norm`` norm of ``delta`` (first axis = samples).

        :raises ValueError: if ``self.norm`` is not 'Linf', 'L2' or 'L1'
        """
        flat = delta.reshape(delta.shape[0], -1)
        if self.norm == 'Linf':
            return flat.abs().max(dim=1)[0]
        if self.norm == 'L2':
            return (flat ** 2).sum(dim=-1).sqrt()
        if self.norm == 'L1':
            return flat.abs().sum(dim=-1)
        raise ValueError('norm not supported')

    def _grad_dual_norm(self, dg):
        """Norm of each gradient in ``dg`` used to scale ``df`` into a distance.

        ``dg`` is reshaped to ``(dg.shape[0], dg.shape[1], -1)`` and reduced
        over the last axis: sum of absolute values for 'Linf', Euclidean norm
        for 'L2', maximum absolute value for 'L1'.

        :raises ValueError: if ``self.norm`` is not 'Linf', 'L2' or 'L1'
        """
        flat = dg.reshape(dg.shape[0], dg.shape[1], -1)
        if self.norm == 'Linf':
            return flat.abs().sum(dim=-1)
        if self.norm == 'L2':
            return (flat ** 2).sum(dim=-1).sqrt()
        if self.norm == 'L1':
            return flat.abs().max(dim=2)[0]
        raise ValueError('norm not supported')

    def _project(self, points, w, b):
        """Dispatch to the projection routine matching ``self.norm``.

        :raises ValueError: if ``self.norm`` is not 'Linf', 'L2' or 'L1'
        """
        if self.norm == 'Linf':
            return projection_linf(points, w, b)
        if self.norm == 'L2':
            return projection_l2(points, w, b)
        if self.norm == 'L1':
            return projection_l1(points, w, b)
        raise ValueError('norm not supported')

    def _random_start(self, im2, res2):
        """Return ``im2`` plus random noise of norm ``min(res2, eps) / 2``.

        Uniform noise in [-1, 1) is used for 'Linf', Gaussian noise for 'L2'
        and 'L1'; the noise is rescaled per sample by its own ``self.norm``
        norm. The result is not clamped here.

        :raises ValueError: if ``self.norm`` is not 'Linf', 'L2' or 'L1'
        """
        bshape = [-1] + [1] * self.ndims
        # Noise is drawn without a device argument and moved afterwards,
        # exactly as before, so seeded runs keep producing the same values.
        if self.norm == 'Linf':
            t = 2 * torch.rand(im2.shape).to(self.device) - 1
        elif self.norm in ('L2', 'L1'):
            t = torch.randn(im2.shape).to(self.device)
        else:
            raise ValueError('norm not supported')
        radius = res2.clamp(max=self.eps).reshape(bshape)
        return im2 + radius * t / self._lp_norm(t).reshape(bshape) * .5

    def attack_single_run(self, x, y=None, use_rand_start=False, is_targeted=False):
        """
        Run one restart of the attack.

        :param x:               clean images
        :param y:               clean labels, if None we use the predicted labels
        :param use_rand_start:  if True, start from a random point around each
                                clean image instead of the image itself
        :param is_targeted:     True if we use the targeted version. Targeted
                                class is assigned by `self.target_class`
        :return:                float copy of ``x`` (on ``self.device``) in
                                which every correctly classified point for
                                which an adversarial example was found is
                                replaced by the smallest-norm one found;
                                all other points are returned unchanged
        :raises ValueError:     if ``self.norm`` is unsupported, or if
                                ``is_targeted`` is True while
                                ``self.target_class`` is None
        """

        if self.device is None:
            self.device = x.device
        self.orig_dim = list(x.shape[1:])
        self.ndims = len(self.orig_dim)
        # Shape that broadcasts a per-sample value over all image axes.
        bshape = [-1] + [1] * self.ndims

        x = x.detach().clone().float().to(self.device)

        y_pred = self._get_predicted_label(x)
        if y is None:
            y = y_pred.detach().clone().long().to(self.device)
        else:
            y = y.detach().clone().long().to(self.device)
        pred = y_pred == y
        corr_classified = pred.float().sum()
        if self.verbose:
            print('Clean accuracy: {:.2%}'.format(pred.float().mean()))
        if pred.sum() == 0:
            # Nothing is classified correctly, so there is nothing to attack.
            return x
        pred = self.check_shape(pred.nonzero().squeeze())

        la_target2 = None
        if is_targeted:
            if self.target_class is None:
                raise ValueError(
                    'self.target_class must be set before a targeted run')
            output = self._predict_fn(x)
            # Class at position `target_class` from the top of the sorted
            # outputs (position 1 being the largest output).
            la_target = output.sort(dim=-1)[1][:, -self.target_class]
            la_target2 = la_target[pred].detach().clone()

        startt = time.time()
        # The attack runs only on the correctly classified points.
        im2 = x[pred].detach().clone()
        la2 = y[pred].detach().clone()
        if len(im2.shape) == self.ndims:
            im2 = im2.unsqueeze(0)
        bs = im2.shape[0]
        u1 = torch.arange(bs)
        adv = im2.clone()    # best adversarial point found so far
        adv_c = x.clone()    # final output, same layout as x
        # Norm of the best perturbation so far; 1e10 marks "none found yet".
        res2 = 1e10 * torch.ones([bs]).to(self.device)
        x1 = im2.clone()     # current iterate
        x0 = im2.clone().reshape([bs, -1])

        if use_rand_start:
            x1 = self._random_start(im2, res2).clamp(0.0, 1.0)

        counter_iter = 0
        while counter_iter < self.n_iter:
            with torch.no_grad():
                if is_targeted:
                    df, dg = self.get_diff_logits_grads_batch_targeted(
                        x1, la2, la_target2)
                else:
                    df, dg = self.get_diff_logits_grads_batch(x1, la2)

                # Pick, per sample, the class with the smallest ratio of
                # |df| to the gradient norm.
                dist1 = df.abs() / (1e-12 + self._grad_dual_norm(dg))
                ind = dist1.min(dim=1)[1]
                dg2 = dg[u1, ind]
                b = (- df[u1, ind]
                     + (dg2 * x1).reshape(x1.shape[0], -1).sum(dim=-1))
                w = dg2.reshape([bs, -1])

                # One batched projection call handles both the current
                # iterate (first bs rows) and the clean point (last bs rows).
                d3 = self._project(
                    torch.cat((x1.reshape([bs, -1]), x0), 0),
                    torch.cat((w, w), 0),
                    torch.cat((b, b), 0))
                d1 = torch.reshape(d3[:bs], x1.shape)
                d2 = torch.reshape(d3[-bs:], x1.shape)

                # Step norms, floored at 1e-8 to keep the ratio well defined.
                a0 = self._lp_norm(d3).reshape(bshape).clamp(min=1e-8)
                a1 = a0[:bs]
                a2 = a0[-bs:]
                alpha = (a1 / (a1 + a2)).clamp(0.0, self.alpha_max)
                x1 = ((x1 + self.eta * d1) * (1 - alpha) +
                      (im2 + d2 * self.eta) * alpha).clamp(0.0, 1.0)

                is_adv = self._get_predicted_label(x1) != la2

                if is_adv.sum() > 0:
                    ind_adv = self.check_shape(is_adv.nonzero().squeeze())
                    t = self._lp_norm(x1[ind_adv] - im2[ind_adv])
                    # Keep the new point only where it is strictly closer.
                    is_better = (t < res2[ind_adv]).float()
                    not_better = (t >= res2[ind_adv]).float()
                    adv[ind_adv] = (
                        x1[ind_adv] * is_better.reshape(bshape)
                        + adv[ind_adv] * not_better.reshape(bshape))
                    res2[ind_adv] = (t * is_better
                                     + res2[ind_adv] * not_better)
                    # Backward step: pull adversarial iterates towards the
                    # clean point by a factor beta.
                    x1[ind_adv] = im2[ind_adv] + (
                        x1[ind_adv] - im2[ind_adv]) * self.beta

                counter_iter += 1

        ind_succ = res2 < 1e10
        if self.verbose:
            print('success rate: {:.0f}/{:.0f}'
                  .format(ind_succ.float().sum(), corr_classified) +
                  ' (on correctly classified points) in {:.1f} s'
                  .format(time.time() - startt))

        ind_succ = self.check_shape(ind_succ.nonzero().squeeze())
        adv_c[pred[ind_succ]] = adv[ind_succ].clone()

        return adv_c

    def _run_restarts(self, x, y, adv, acc, startt, is_targeted):
        """Run ``n_restarts`` attack runs on the points still marked robust.

        Updates ``adv`` and ``acc`` in place: a point is marked as fooled
        (``acc`` set to False, ``adv`` overwritten) when the run returns a
        misclassified point whose perturbation norm is at most ``self.eps``.
        The first restart starts from the clean point, later ones from a
        random start.
        """
        for counter in range(self.n_restarts):
            ind_to_fool = self.check_shape(acc.nonzero().squeeze())
            if ind_to_fool.numel() == 0:
                continue

            x_to_fool = x[ind_to_fool].clone()
            y_to_fool = y[ind_to_fool].clone()
            adv_curr = self.attack_single_run(
                x_to_fool, y_to_fool,
                use_rand_start=(counter > 0), is_targeted=is_targeted)

            acc_curr = self._predict_fn(adv_curr).max(1)[1] == y_to_fool
            res = self._lp_norm(x_to_fool - adv_curr)
            # Perturbations larger than eps do not count as successes.
            acc_curr = acc_curr | (res > self.eps)

            ind_curr = (acc_curr == 0).nonzero().squeeze()
            acc[ind_to_fool[ind_curr]] = 0
            adv[ind_to_fool[ind_curr]] = adv_curr[ind_curr].clone()

            if self.verbose:
                if is_targeted:
                    print('restart {} - target_class {} - robust accuracy: {:.2%} at eps = {:.5f} - cum. time: {:.1f} s'.format(
                        counter, self.target_class, acc.float().mean(), self.eps, time.time() - startt))
                else:
                    print('restart {} - robust accuracy: {:.2%} at eps = {:.5f} - cum. time: {:.1f} s'.format(
                        counter, acc.float().mean(), self.eps, time.time() - startt))

    def perturb(self, x, y):
        """
        Attack the batch ``x`` with labels ``y``.

        :param x:  clean images
        :param y:  clean labels
        :return:   copy of ``x`` where every point that was initially
                   classified as ``y`` and could be fooled within
                   ``self.eps`` is replaced by its adversarial example
        """
        if self.device is None:
            self.device = x.device
        adv = x.clone()
        with torch.no_grad():
            acc = self._predict_fn(x).max(1)[1] == y

            startt = time.time()

            torch.random.manual_seed(self.seed)
            torch.cuda.random.manual_seed(self.seed)

            if not self.targeted:
                self._run_restarts(x, y, adv, acc, startt, is_targeted=False)
            else:
                # target_class = k selects the k-th largest model output.
                for target_class in range(2, self.n_target_classes + 2):
                    self.target_class = target_class
                    self._run_restarts(
                        x, y, adv, acc, startt, is_targeted=True)

        return adv
| |
|