\n", "_Bad news_: the regularization is not differentiable near the optimum, so plain gradient descent won't do.\n", "\n", "-----" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For the sake of clarity, let us distinguish two functions in this objective:\n", "$$ f : \\mathbf{M} \\mapsto \\mathbb{E}_X \\> \\frac{1}{2} \\lVert \\mathbf{W} X - \\mathbf{M} X \\rVert_2^2\n", "\\quad \\text{and} \\quad\n", "g : \\mathbf{M} \\mapsto \\lVert \\mathbf{M} \\rVert_{2,1} $$\n", "\n", "$$ \\mathcal{L} = f + \\lambda \\cdot g $$\n", "\n", "We will switch to arbitrary output reconstruction with the change of variable $Y = \\mathbf{W} X$

\n", "and adopt an empirical point of view, with $\\mathbf{X} \\in \\mathbb{R}^{d \\times n}$ the matrix of concatenated samples\n", "replacing the expectation over $X$:\n", "\n", "$$f(\\mathbf{M}) = \\frac{1}{2n} \\operatorname{Tr}\\left[(\\mathbf{Y} - \\mathbf{M} \\mathbf{X}) (\\mathbf{Y} - \\mathbf{M} \\mathbf{X})^T\\right]$$\n", "\n", "$f$ is differentiable, and can be written with the correlation matrices $R_{x,x} = \\frac{1}{n} \\mathbf{X} \\mathbf{X}^T\\>$ and $\\> R_{y,x} = \\frac{1}{n} \\mathbf{Y} \\mathbf{X}^T$. Derive its gradient and Hessian.\n", "\n", "Write $f$ and its gradient as a function of those matrices. Note how the dimensions of the matrices involved\n", "no longer depend on the number of samples $n$. This means that once the correlations are computed, we can\n", "solve everything in memory, without I/O calls to read data from disk or having to pass it through all the previous\n", "layers (even for arbitrarily deep networks).\n", "\n", "-----\n", "\n", "$g$ is not differentiable. 
The trick we will use to optimize it is to rely on its associated proximal operator instead:\n", "\n", "$$\\operatorname{prox}_{\\>t \\cdot g} : \\mathbf{M} \\mapsto \\operatorname{argmin}_\\mathbf{U} \\left( t \\cdot g(\\mathbf{U}) + \\frac{1}{2} \\lVert \\mathbf{U} - \\mathbf{M} \\rVert^2_F \\right)$$\n", "\n", "In our case, it is given in closed form by:\n", "\n", "$$\n", "\\forall t \\in \\mathbb{R}^+ \\quad \\left( \\operatorname{prox}_{\\>t \\cdot g}(\\mathbf{M})\\right)_{i,j} =\n", "\\mathbf{M}_{i,j} \\cdot \\left(1 - \\dfrac{t}{\\sqrt{ \\sum_k \\mathbf{M}_{k,j}^2 }}\\right)\n", "\\>\\> \\text{if} \\>\\> \\sqrt{ \\sum_k \\mathbf{M}_{k,j}^2 } > t \\>,\\>\\> \\text{and} \\>\\>\n", "0\n", "\\>\\>\\text{otherwise}\n", "$$" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# R_xx : tensor(d, d), R_yx : tensor(h, d), m : tensor(h, d) -> f(m) : tensor(1)\n", "def f(r_xx, r_yx, m):\n", " # TODO: Implement f\n", " pass\n", "\n", "# R_xx : tensor(d, d), R_yx : tensor(h, d), m : tensor(h, d) -> grad_f(m) : tensor(h, d)\n", "def grad_f(r_xx, r_yx, m):\n", " # TODO: Implement the gradient of f\n", " pass" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# R_xx : tensor(d, d), R_yx : tensor(h, d), m : tensor(h, d) -> g(m) : tensor(1)\n", "def g(m):\n", " # TODO: Implement g\n", " pass\n", "\n", "# t : float, m : tensor(h, d) -> prox_{tg}(m) : tensor(h, d)\n", "def prox_g(t, m):\n", " # TODO: Implement the proximal operator for g\n", " pass" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.A - Slowly but surely : Forward-Backward Descent" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The forward-backward algorithm (also known as \"proximal gradient descent\" in this particular case) is a descent method used to minimize the sum of a smooth function and a simple function (i.e. one whose prox is easily computable). 
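A hedged sketch of the four stubs above, following the closed-form prox given in the text and the gradient $\nabla f(\mathbf{M}) = \mathbf{M} R_{x,x} - R_{y,x}$ derived from the quadratic expansion of $f$. One assumption to note: $f$ is implemented up to the additive constant $\frac{1}{2}\operatorname{Tr}[R_{y,y}]$, which cannot be recovered from $R_{x,x}$ and $R_{y,x}$ alone and does not change the minimizer (it only shifts the loss curves).

```python
import torch

# f up to the dropped constant (1/2) Tr[R_yy]:
# f(M) = (1/2) Tr[M R_xx M^T] - Tr[M R_yx^T]  (+ constant)
def f(r_xx, r_yx, m):
    return 0.5 * torch.trace(m @ r_xx @ m.T) - torch.trace(m @ r_yx.T)

# grad f(M) = M R_xx - R_yx ; the Hessian acts as U -> U R_xx,
# so L = lambda_max(R_xx) bounds it.
def grad_f(r_xx, r_yx, m):
    return m @ r_xx - r_yx

# ||M||_{2,1} = sum over columns j of ||M_{:,j}||_2
def g(m):
    return torch.norm(m, dim=0).sum()

# Column-wise shrinkage: scale each column by (1 - t / ||M_{:,j}||)_+ ,
# which reproduces the closed form above (the clamp handles both branches).
def prox_g(t, m):
    col_norms = torch.norm(m, dim=0, keepdim=True)
    return m * torch.clamp(1.0 - t / col_norms.clamp_min(1e-12), min=0.0)
```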
It gives an $\\varepsilon$-approximation in $\\mathcal{O}(\\frac{1}{\\varepsilon})$ iterations, and has the advantage of being a descent method, so almost any stopping condition will do the trick." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$ M_{k+1} = \\operatorname{prox}_{(\\lambda / L) \\,\\cdot\\, g} \\left( M_k - \\frac{1}{L} \\nabla f (M_k) \\right)\n", "\\quad \\text{where} \\>\\> \\nabla^2 f \\preccurlyeq L $$" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# R_xx : tensor(d, d), R_yx : tensor(h, d), lam : float, m : tensor(h, d)\n", "# -> m^* : tensor(h, d), loss_hist : list(float)\n", "def __proximal_gradient_descent(r_xx, r_yx, lam, m):\n", " # TODO: Implement proximal gradient descent\n", " pass\n", "\n", "def proximal_gradient_descent(r_xx, r_yx, lam):\n", " m0 = torch.zeros_like(r_yx)\n", " m, loss_hist = __proximal_gradient_descent(r_xx, r_yx, lam, m0)\n", " return m" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "_, hist = __proximal_gradient_descent(r_xx, r_yx, part1_test_lam, torch.zeros_like(r_yx))\n", "plt.plot(np.log10(np.array(hist) - np.min(hist) + 1e-12))\n", "print(hist[-1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.B - Bouncing faster : Fast Iterative Shrinkage-Thresholding Algorithm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\"A Fast Iterative Shrinkage-Thresholding Algorithm\" [Beck & Teboulle 2009] presents a way to obtain an $\\varepsilon$-approximation in only $\\mathcal{O}(\\frac{1}{\\sqrt \\varepsilon})$ iterations, by performing the gradient step with respect to a look-ahead variable ($\\mathbf{A}$) instead of the current estimate of the solution ($\\mathbf{M}$).\n", "\n", "The price to pay is that this is no longer a descent method.
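The forward-backward update from section 1.A can be sketched as follows. This is a hedged sketch, not the notebook's reference solution: the helpers are restated inline (prefixed with `_` to avoid clashing with the stubs), `f` is computed up to the constant $\frac{1}{2}\operatorname{Tr}[R_{y,y}]$, and the stopping rule (objective decrease below a tolerance, valid because this is a descent method) is one assumed choice among many.

```python
import torch

def _grad_f(r_xx, r_yx, m):
    return m @ r_xx - r_yx

def _loss(r_xx, r_yx, lam, m):
    # f (up to an additive constant) plus lam * ||M||_{2,1}
    smooth = 0.5 * torch.trace(m @ r_xx @ m.T) - torch.trace(m @ r_yx.T)
    return smooth + lam * torch.norm(m, dim=0).sum()

def _prox_g(t, m):
    col_norms = torch.norm(m, dim=0, keepdim=True)
    return m * torch.clamp(1.0 - t / col_norms.clamp_min(1e-12), min=0.0)

def proximal_gradient_descent_sketch(r_xx, r_yx, lam, m0, max_iter=1000, tol=1e-9):
    # Step size 1/L with L = lambda_max(R_xx), the Hessian bound
    # (R_xx is symmetric PSD, so eigvalsh applies; it returns ascending order).
    L = torch.linalg.eigvalsh(r_xx)[-1].item()
    m, hist = m0.clone(), []
    for _ in range(max_iter):
        m = _prox_g(lam / L, m - _grad_f(r_xx, r_yx, m) / L)
        hist.append(_loss(r_xx, r_yx, lam, m).item())
        # Descent method: stop once the objective stops decreasing.
        if len(hist) > 1 and hist[-2] - hist[-1] < tol:
            break
    return m, hist
```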

\n", "What would be a reasonable stopping condition for this algorithm?\n", "Implement it and plot $\\> k \\mapsto \\log \\left(\\mathcal{L}(\\mathbf{M_k}) - \\mathcal{L}(\\mathbf{M}^*)\\right)\\>$. You should see it bounce." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$ A_k = M_k + \\dfrac{k}{k + q + 1} \\cdot (M_k - M_{k-1}) \\quad \\text{where} \\>\\> q \\in \\mathbb{R}^*_+$$\n", "\n", "$$ M_{k+1} = \\operatorname{prox}_{(\\lambda / L) \\,\\cdot\\, g} \\left( A_k - \\frac{1}{L} \\nabla f (A_k) \\right)\n", "\\quad \\text{where} \\>\\> \\nabla^2 f \\preccurlyeq L $$" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# R_xx : tensor(d, d), R_yx : tensor(h, d), lam : float, q : float, m : tensor(h, d)\n", "# -> m^* : tensor(h, d), loss_hist : list(float)\n", "def __fista(r_xx, r_yx, lam, q, m):\n", " # TODO: Implement FISTA\n", " pass\n", "\n", "def fista(r_xx, r_yx, lam, q=10):\n", " m0 = torch.zeros_like(r_yx)\n", " m, loss_hist = __fista(r_xx, r_yx, lam, q, m0)\n", " return m" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The look-ahead step is parameterized by a positive real value $q$. The bound on the rate of convergence\n", "holds for all $q$, but different values will generate different behaviors. Higher values diminish\n", "the amplitude of the bouncing effect, but slow down convergence during the first few iterations.\n", "\n", "The ideal $q$ is hard to determine because it depends on the number of iterations that will be performed,\n", "which itself depends on the target precision and the value of the hyperparameter chosen. 
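The two FISTA updates above can be sketched as follows. This is a hedged sketch under stated assumptions: the helpers are restated inline, and the stopping rule tests the movement of the iterate rather than the objective (one reasonable answer to the stopping-condition question, since the objective is not monotone here), with the tolerance as an illustrative value.

```python
import torch

def _grad_f(r_xx, r_yx, m):
    return m @ r_xx - r_yx

def _prox_g(t, m):
    col_norms = torch.norm(m, dim=0, keepdim=True)
    return m * torch.clamp(1.0 - t / col_norms.clamp_min(1e-12), min=0.0)

def fista_sketch(r_xx, r_yx, lam, q, m0, max_iter=2000, tol=1e-10):
    L = torch.linalg.eigvalsh(r_xx)[-1].item()
    m_prev, m = m0.clone(), m0.clone()
    for k in range(1, max_iter + 1):
        # Look-ahead variable A_k = M_k + k/(k+q+1) (M_k - M_{k-1})
        a = m + (k / (k + q + 1)) * (m - m_prev)
        m_prev = m
        # Gradient step taken at A_k, not at M_k
        m = _prox_g(lam / L, a - _grad_f(r_xx, r_yx, a) / L)
        # Not a descent method: stop on the (relative) movement of the iterate.
        if torch.norm(m - m_prev) < tol * max(1.0, torch.norm(m).item()):
            break
    return m
```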
Smaller values of\n", "the hyperparameter will tend to require more iterations, which suggests we should aim for a relatively low\n", "value of $q$, but raising it a little will have a smoothing effect that will be beneficial for our case because\n", "it will allow the stopping condition to fire earlier when the objective reaches good precision.\n", "\n", "Experiment with different values of $q$ to observe these behaviors." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for q in [ 1, 5, 50, 250 ]:\n", " _, hist = __fista(r_xx, r_yx, part1_test_lam, q, torch.zeros_like(r_yx) + .01)\n", " plt.plot(np.log10(np.array(hist) - np.min(hist) + 1e-12), label=f\"q = {q}\")\n", " print(\"q = \", q, \":\", hist[-1])\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.C - Straight to Hell : Alternate Minimization + FISTA" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The problem with FISTA is the slow movement of the estimate solution, which requires a lot of iterations\n", "to go all the way from the initial point $\\mathbf{M}_0$ to the optimal solution $\\mathbf{M}^*$.\n", "The exact convergence bound for FISTA is\n", "\n", "$$ \\mathcal{L}(\\mathbf{M}_k) - \\mathcal{L}(\\mathbf{M}^*) \\leq \\dfrac{2 L}{k^2} \\lVert \\mathbf{M}_0 - \\mathbf{M}^* \\rVert_F^2 $$\n", "\n", "If we were to give an initial point $\\mathbf{M}_0$ not too far from the optimum, the convergence to the exact optimum would be very fast.\n", "\n", "------" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In order to compute such an initial point, we will use the variational expression of the $\\ell_1$-norm\n", "\n", "$$ \\forall w \\in \\mathbb{R}, \\quad |w| = \\min_{\\eta > 0} \\frac{w^2}{2 \\eta} + \\frac{\\eta}{2} $$\n", "\n", "Which extends to the $\\ell_{2,1}$ norm we are interested in\n", "\n", "$$ \\lVert \\mathbf{M} \\rVert_{2,1} = \\sum_{j
=1}^{d} \\min_{\\eta_j > 0} \\left( \\frac{\\sum_k \\mathbf{M}_{k,j}^2}{2 \\eta_j} + \\frac{\\eta_j}{2} \\right) $$\n", "\n", "At fixed $\\eta$, the problem in $\\mathbf{M}$ becomes smooth and quadratic, and at fixed $\\mathbf{M}$ the optimal $\\eta_j$ is simply the norm of the $j$-th column. Alternating between these two closed-form updates gives a cheap approximate solution to warm-start FISTA.
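Under the variational expression above, each block update has a closed form: the $\eta$-step sets $\eta_j = \lVert \mathbf{M}_{:,j} \rVert_2$, and the $\mathbf{M}$-step solves $\mathbf{M}\,(R_{x,x} + \lambda \operatorname{diag}(1/\eta)) = R_{y,x}$. A hedged sketch of this IRLS-style scheme (the floor `eta_min` and the iteration count are illustrative assumptions, not tuned values):

```python
import torch

def alternate_minimization_sketch(r_xx, r_yx, lam, n_iter=30, eta_min=1e-8):
    m = r_yx.clone()  # any starting point works
    for _ in range(n_iter):
        # eta-step: closed form, eta_j = ||M_{:,j}||_2 (floored to stay > 0)
        eta = torch.norm(m, dim=0).clamp_min(eta_min)
        # M-step: minimize f(M) + (lam/2) sum_j ||M_{:,j}||^2 / eta_j,
        # i.e. solve M (R_xx + lam * diag(1/eta)) = R_yx (A is symmetric)
        a = r_xx + lam * torch.diag(1.0 / eta)
        m = torch.linalg.solve(a, r_yx.T).T
    return m
```

Note the smoothing role of `eta_min`: it keeps the linear systems well-posed, at the price of never producing exactly zero columns, which is why a few FISTA iterations are still needed to finish the job.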
\n", "Would you rather spend your computation budget on alternate minimization or FISTA? Choose a stopping condition accordingly." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# R_xx : tensor(d, d), R_yx : tensor(h, d), lam : float, m : tensor(h, d)\n", "# -> m : tensor(h, d), loss_hist : list(float)\n", "def __alternate_minimization(r_xx, r_yx, lam, m):\n", " # TODO: Implement alternate minimization\n", " pass\n", "\n", "def alternate_minimization(r_xx, r_yx, lam, _d=10):\n", " m = torch.zeros_like(r_yx)\n", " m, _ = __alternate_minimization(r_xx, r_yx, lam, m)\n", " m, _ = __fista(r_xx, r_yx, lam, _d, m)\n", " return m" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "m = torch.zeros_like(r_yx)\n", "m, hist1 = __alternate_minimization(r_xx, r_yx, part1_test_lam, m)\n", "m, hist2 = __fista(r_xx, r_yx, part1_test_lam, 10, m)\n", "hist = hist1 + hist2\n", "\n", "plt.plot(np.log10(np.array(hist) - np.min(hist) + 1e-12))\n", "print(f(r_xx, r_yx, m) + part1_test_lam * g(m))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation & Checkpoint" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Choose your preferred method as default for the following section\n", "DEFAULT_METHOD = 'alternate_minimization'\n", "\n", "def solve_linear_reconstruction(r_xx, r_yx, lam, method=DEFAULT_METHOD):\n", " try:\n", " return {\n", " 'proximal_gradient_descent': proximal_gradient_descent,\n", " 'fista': fista,\n", " 'alternate_minimization': alternate_minimization,\n", " }[method](r_xx, r_yx, lam)\n", " except KeyError:\n", " raise ValueError(f\"No such method for linear reconstruction: '{method}'\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "d, r, h, n = 500, 450, 200, 60000\n", "x = rand_matrix(d, r) @ rand_matrix(r, n)\n", "y = rand_matrix(h, d) @ x\n", "\n", 
"r_xx = (x / d) @ (x.T / n)\n", "r_yx = (y / d) @ (x.T / n)\n", "\n", "# %timeit -r5 solve_linear_reconstruction(r_xx, r_yx, .1, 'proximal_gradient_descent') # maybe don't\n", "%timeit -r5 solve_linear_reconstruction(r_xx, r_yx, .1, 'fista')\n", "%timeit -r5 solve_linear_reconstruction(r_xx, r_yx, .1, 'alternate_minimization')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Part 2 - Network compression via linear activation reconstruction" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Setup and boilerplate code\n", "\n", "This is just the setup for training and testing an MNIST model.\n", "You probably have done this often enough by now to be able to write it on your own, but this is not\n", "the focus of this practical session." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "normalizer = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,),(0.3081,)) ])\n", "mnist_get = lambda train: datasets.MNIST('data', train=train, download=True, transform=normalizer)\n", "mnist_train, mnist_test = mnist_get(train=True), mnist_get(train=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "device = 'cpu' # change to 'cuda:0' if you have a GPU, or whatever your processing unit is named\n", "dataloader_settings = { 'batch_size': 256, 'shuffle': True }" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def dataset_iterator(dataset, model, action):\n", " action_results = []\n", " for data, target in torch.utils.data.DataLoader(dataset, **dataloader_settings):\n", " data, target = data.to(device), target.to(device)\n", " action_results.append(action(model(data), target))\n", " return action_results" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def count_correct(output, target):\n", " pred = output.argmax(dim=1, keepdim=True)\n", 
" correct = pred.eq(target.view_as(pred)).sum().item()\n", " return correct\n", " \n", "def gradient_step(optimizer, output, target):\n", " F.cross_entropy(output, target).backward()\n", " optimizer.step(); optimizer.zero_grad()\n", " return count_correct(output, target)\n", "\n", "def train(model, optimizer):\n", " iter_step = lambda out, target: gradient_step(optimizer, out, target)\n", " correct_counts = dataset_iterator(mnist_train, model, iter_step)\n", " return sum(correct_counts) / len(mnist_train)\n", "\n", "test = lambda model: sum(dataset_iterator(mnist_test, model, count_correct)) / len(mnist_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### MNIST Model : LeNet-5 [LeCun et al., 1998] - (Caffe flavor)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The LeNet-5 model was introduced by Yann Le Cun in 1998, and easily achieves 99% accuracy on MNIST.

\n", "We will be using the version provided in the Caffe framework, which has more parameters but is faster to train.\n", "\n", "Additionally, we will need to extract outputs after a specific layer for our reconstructions,\n", "which is why we use this odd definition instead of just a `torch.nn.Sequential` construction." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class LeNet5(torch.nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.conv1 = torch.nn.Conv2d(1, 20, 5)\n", " self.conv2 = torch.nn.Conv2d(20, 50, 5)\n", " self.fc1 = torch.nn.Linear(800, 500)\n", " self.fc2 = torch.nn.Linear(500, 10)\n", " self.layers = [ self.conv1, self.conv2, self.fc1, self.fc2 ]\n", " self.layer_forward = [\n", " lambda x: F.relu(F.max_pool2d(self.conv1(x), 2)),\n", " lambda x: F.relu(F.max_pool2d(self.conv2(x), 2)).view(x.size(0), -1),\n", " lambda x: F.relu(self.fc1(x)),\n", " lambda x: self.fc2(x),\n", " ]\n", " \n", " def forward(self, x, steps=4):\n", " for i in range(steps):\n", " x = self.layer_forward[i](x)\n", " return x\n", "\n", "model = LeNet5()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "PATH = \"lenet5_trained\"\n", "\n", "def save_lenet(model):\n", " torch.save(model.state_dict(), PATH)\n", "\n", "def load_lenet():\n", " model = LeNet5()\n", " model.load_state_dict(torch.load(PATH))\n", " return model.eval()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model.train(); model.to(device)\n", "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n", "for i in range(8):\n", " acc = train(model, optimizer)\n", " print(f\"Epoch {i:2d} - train accuracy: {acc:.4f}\")\n", "save_lenet(model)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model.eval() # No need to compute gradients from now on\n", "print(f\"Test accuracy: {test(model):.4f}\")" ] 
}, { "cell_type": "markdown", "metadata": {}, "source": [ "### Single Layer Compression" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now to the real compression.\n", "\n", "We start by performing one pass over the dataset to get the correlation matrices for the\n", "layer that we are targeting.\n", "\n", "You may have noticed that we ignored the layer bias in our reconstruction formulation.\n", "For now, you can ignore it as well, and just leave it with the same value. We will tackle this later on." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# dataset, model : LeNet5, target_layer_idx : int -> R_xx : tensor(d, d), R_yx : tensor(h, d)\n", "def gather_correlations(dataset, model, target_layer_idx):\n", " # TODO: Compute the correlation matrices for the training set\n", " pass" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "r_xx, r_yx = gather_correlations(mnist_train, model, target_layer_idx=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now use `solve_linear_reconstruction` to get the compressed weight, with $\\lambda = .01$, and print the number of non-zero weights kept." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# TODO: Use your solver and print the proportion of weights kept" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's see what accuracy we get for this compression rate. Note that the size of the weight doesn't change yet." 
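The correlation-gathering pass can be sketched in a framework-agnostic form that accumulates $R_{x,x}$ and $R_{y,x}$ over batches of layer inputs `x` (shape `(b, d)`) and target outputs `y` (shape `(b, h)`), with the bias ignored as stated above. The function name and batch-iterable interface are illustrative assumptions, not the notebook's exact signature:

```python
import torch

def gather_correlations_from_batches(batches):
    # Single streaming pass: accumulate X X^T and Y X^T, then divide by n.
    # Batches are row-major (b, d) / (b, h), so x.T @ x is the (d, d) sum.
    r_xx = r_yx = None
    n = 0
    for x, y in batches:
        if r_xx is None:
            r_xx = torch.zeros(x.size(1), x.size(1))
            r_yx = torch.zeros(y.size(1), x.size(1))
        r_xx += x.T @ x
        r_yx += y.T @ x
        n += x.size(0)
    return r_xx / n, r_yx / n
```

In the notebook itself, `x` would come from `model(data, steps=target_layer_idx)` (the input of the targeted layer) and `y` from applying that layer's weight to `x` (e.g. `x @ model.layers[target_layer_idx].weight.T` for a linear layer, leaving the bias aside), all under `torch.no_grad()`.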
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with torch.no_grad():\n", " model.fc2.weight = torch.nn.Parameter(m.to(dtype=model.fc2.weight.dtype))\n", "print(f\"Test accuracy: {test(model):.4f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now perform the operation that justified the usage of this reconstruction in the first place:\n", "commutation of the feature extractor and non-linearity, to collapse with the previous layer (fc1), which is\n", "dense and will have some rows removed thanks to our parameter reduction on layer fc2.\n", "\n", "Don't forget to also collapse the bias of fc1." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def set_weight(layer, new_weight, bias=None):\n", " param_like = lambda ow, nw: torch.nn.Parameter(nw.to(dtype=ow.dtype, device=ow.device))\n", " with torch.no_grad():\n", " layer.weight = param_like(layer.weight, new_weight)\n", " if bias is not None:\n", " layer.bias = param_like(layer.bias, bias)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# keep : tensor