Spaces:
Runtime error
Runtime error
import math

import numpy as np
import torch
import torch.nn as nn
class Quantizer(nn.Module):
    """Uniform (or ternary) fake-quantizer with per-tensor or per-channel parameters.

    State is kept in three registered buffers:

    * ``maxq``  -- largest integer grid level (``2**bits - 1``), or ``-1`` when
      ternary ("trits") quantization is selected in :meth:`configure`.
    * ``scale`` -- grid step size (ternary mode: the positive level).
    * ``zero``  -- zero-point on the integer grid (ternary mode: the negative level).

    Typical use: :meth:`configure`, then :meth:`find_params` on a sample
    tensor, then :meth:`quantize`.
    """

    def __init__(self, shape=1):
        """Register zero-initialised ``maxq`` / ``scale`` / ``zero`` buffers.

        Args:
            shape: initial shape of the ``scale`` and ``zero`` buffers.
        """
        super().__init__()
        self.register_buffer('maxq', torch.tensor(0))
        self.register_buffer('scale', torch.zeros(shape))
        self.register_buffer('zero', torch.zeros(shape))

    def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4,
                  grid=100, maxshrink=.8, trits=False):
        """Select the quantization grid and the parameter-search strategy.

        Args:
            bits: bit width; the integer grid is ``[0, 2**bits - 1]``
                (ignored when ``trits`` is set).
            perchannel: compute one scale/zero per channel instead of per tensor.
            sym: use a range symmetric around 0 with a fixed mid-grid zero-point.
            mse: search shrunken ranges ``p * [xmin, xmax]`` for the one with the
                smallest ``norm``-power quantization error.
            norm: exponent applied to the per-element error in the ``mse`` search.
            grid: number of shrink steps per unit of ``p`` in the ``mse`` search.
            maxshrink: the ``mse`` search tries ``int(maxshrink * grid)`` values of ``p``.
            trits: quantize to the three levels ``{zero, 0, scale}`` instead.

        ``scale`` is reset to zeros, so :meth:`ready` is False until
        :meth:`find_params` runs again.
        """
        self.maxq = torch.tensor(2 ** bits - 1)
        self.perchannel = perchannel
        self.sym = sym
        self.mse = mse
        self.norm = norm
        self.grid = grid
        self.maxshrink = maxshrink
        if trits:
            # Negative maxq is the ternary-mode flag checked by _quantize.
            self.maxq = torch.tensor(-1)
        # Parameters from a previous configuration are no longer valid.
        self.scale = torch.zeros_like(self.scale)

    def _quantize(self, x, scale, zero, maxq):
        """Fake-quantize ``x`` with explicit parameters (does not read ``self``).

        For ``maxq < 0`` (ternary mode) values above ``scale / 2`` map to
        ``scale``, values below ``zero / 2`` map to ``zero`` and the rest to 0.
        Otherwise ``x`` is rounded onto the integer grid ``[0, maxq]`` (offset
        by ``zero``) and mapped back to the input domain.
        """
        if maxq < 0:
            return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
        q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
        return scale * (q - zero)

    def find_params(self, x, weight=False):
        """Compute ``scale`` and ``zero`` from the value range of ``x``.

        Args:
            x: sample tensor. With ``weight=True`` dim 0 indexes output channels;
                otherwise 2-D, 3-D (channels last) and 4-D (channels in dim 1)
                inputs are handled.
            weight: treat ``x`` as a weight tensor.

        On return ``scale`` / ``zero`` are shaped to broadcast against tensors
        shaped like ``x``.
        """
        self.maxq = self.maxq.to(x.device)
        shape = x.shape

        rows = self._rows_for_stats(x, weight)
        xmin, xmax = self._row_range(rows)
        self.scale, self.zero = self._grid_params(xmin, xmax)
        if self.mse:
            self._mse_search(rows, xmin, xmax)
        self._expand_params(shape, weight)

    def _rows_for_stats(self, x, weight):
        """Return a 2-D view of ``x`` with one row per parameter group."""
        if not self.perchannel:
            # A single row: one scale/zero for the whole tensor.
            return x.flatten().unsqueeze(0)
        if weight:
            return x.flatten(1)
        ndim = x.dim()
        if ndim == 4:
            return x.permute([1, 0, 2, 3]).flatten(1)
        if ndim == 3:
            return x.reshape((-1, x.shape[-1])).t()
        if ndim == 2:
            return x.t()
        return x

    def _row_range(self, rows):
        """Per-row ``(xmin, xmax)``, widened so that 0 always lies inside."""
        zeros = torch.zeros(rows.shape[0], device=rows.device)
        xmin = torch.minimum(rows.min(1)[0], zeros)
        xmax = torch.maximum(rows.max(1)[0], zeros)
        if self.sym:
            # The zero-point is fixed at mid-grid in symmetric mode, so every row's
            # range must be symmetric. A row without negative values would
            # otherwise keep xmin == 0 and have the upper half of its range clipped.
            xmax = torch.maximum(torch.abs(xmin), xmax)
            xmin = -xmax
        # All-zero rows get a non-degenerate range so the scale is never 0.
        empty = (xmin == 0) & (xmax == 0)
        xmin[empty] = -1
        xmax[empty] = +1
        return xmin, xmax

    def _grid_params(self, xmin, xmax):
        """Return new ``(scale, zero)`` tensors covering ``[xmin, xmax]`` per row."""
        if self.maxq < 0:
            # Ternary mode: the non-zero levels are the range ends themselves.
            # Clone so in-place updates of scale/zero never write through to xmin/xmax.
            return xmax.clone(), xmin.clone()
        scale = (xmax - xmin) / self.maxq
        if self.sym:
            zero = torch.full_like(scale, (self.maxq.item() + 1) / 2)
        else:
            zero = torch.round(-xmin / scale)
        return scale, zero

    def _mse_search(self, rows, xmin, xmax):
        """Shrink each row's range to minimise the ``norm``-power quantization error.

        Tries ``p * [xmin, xmax]`` for ``p = 1, 1 - 1/grid, ...`` and, for rows
        where a candidate beats the best error so far, overwrites
        ``self.scale`` / ``self.zero`` in place.
        """
        best = torch.full([rows.shape[0]], float('inf'), device=rows.device)
        for i in range(int(self.maxshrink * self.grid)):
            p = 1 - i / self.grid
            scale1, zero1 = self._grid_params(p * xmin, p * xmax)
            q = self._quantize(rows, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
            err = (q - rows).abs_().pow_(self.norm).sum(1)
            better = err < best
            if torch.any(better):
                best[better] = err[better]
                self.scale[better] = scale1[better]
                self.zero[better] = zero1[better]

    def _expand_params(self, shape, weight):
        """Reshape ``scale`` / ``zero`` so they broadcast against a tensor of ``shape``."""
        if not self.perchannel:
            # Per-tensor: replicate the single value once per channel.
            if weight:
                reps = shape[0]
            else:
                reps = shape[2] if len(shape) == 3 else shape[1]
            self.scale = self.scale.repeat(reps)
            self.zero = self.zero.repeat(reps)

        if weight:
            target = [-1] + [1] * (len(shape) - 1)
            self.scale = self.scale.reshape(target)
            self.zero = self.zero.reshape(target)
            return
        if len(shape) == 4:
            self.scale = self.scale.reshape((1, -1, 1, 1))
            self.zero = self.zero.reshape((1, -1, 1, 1))
        elif len(shape) == 3:
            self.scale = self.scale.reshape((1, 1, -1))
            self.zero = self.zero.reshape((1, 1, -1))
        elif len(shape) == 2:
            self.scale = self.scale.unsqueeze(0)
            self.zero = self.zero.unsqueeze(0)

    def quantize(self, x):
        """Fake-quantize ``x`` with the stored parameters; return ``x`` unchanged if not ready."""
        if self.ready():
            return self._quantize(x, self.scale, self.zero, self.maxq)
        return x

    def enabled(self):
        """Return ``maxq > 0`` (False before configure() and in ternary mode, where maxq == -1)."""
        return self.maxq > 0

    def ready(self):
        """Return whether every stored scale is non-zero, i.e. parameters have been found."""
        return torch.all(self.scale != 0)