<!DOCTYPE html>
<html lang="en"><head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1"><link rel="shortcut icon" type="image/x-icon" href="/narsil.github.io/favicon.ico"><!-- Begin Jekyll SEO tag v2.6.1 -->
<title>Can we train neural networks without gradient descent ? | Narsil</title>
<meta name="generator" content="Jekyll v3.8.5" />
<meta property="og:title" content="Can we train neural networks without gradient descent ?" />
<meta property="og:locale" content="en_US" />
<meta name="description" content="If the lottery ticket hypothesis is real, does that mean we can train a neural network without gradient descent?" />
<meta property="og:description" content="If the lottery ticket hypothesis is real, does that mean we can train a neural network without gradient descent?" />
<link rel="canonical" href="http://localhost:4000/narsil.github.io/ml/2020/03/10/no-gd-training.html" />
<meta property="og:url" content="http://localhost:4000/narsil.github.io/ml/2020/03/10/no-gd-training.html" />
<meta property="og:site_name" content="Narsil" />
<meta property="og:type" content="article" />
<meta property="article:published_time" content="2020-03-10T00:00:00+01:00" />
<script type="application/ld+json">
{"description":"If the lottery ticket hypothesis is real, does that mean we can train a neural network without gradient descent?","mainEntityOfPage":{"@type":"WebPage","@id":"http://localhost:4000/narsil.github.io/ml/2020/03/10/no-gd-training.html"},"@type":"BlogPosting","url":"http://localhost:4000/narsil.github.io/ml/2020/03/10/no-gd-training.html","headline":"Can we train neural networks without gradient descent ?","dateModified":"2020-03-10T00:00:00+01:00","datePublished":"2020-03-10T00:00:00+01:00","@context":"https://schema.org"}</script>
<!-- End Jekyll SEO tag -->
<link href="https://unpkg.com/@primer/css/dist/primer.css" rel="stylesheet" />
<link rel="stylesheet" href="/narsil.github.io/assets/main.css">
<link rel="stylesheet" href="//use.fontawesome.com/releases/v5.0.7/css/all.css"><link type="application/atom+xml" rel="alternate" href="http://localhost:4000/narsil.github.io/feed.xml" title="Narsil" />
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.11.1/dist/katex.min.css" integrity="sha384-zB1R0rpPzHqg7Kpt0Aljp8JPLqbXI3bhnPWROx27a9N0Ll6ZP/+DiW/UqRcLbRjq" crossorigin="anonymous">
<script type="text/javascript" async src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-MML-AM_CHTML"> </script>
<script defer src="https://cdn.jsdelivr.net/npm/katex@0.11.1/dist/katex.min.js" integrity="sha384-y23I5Q6l+B6vatafAwxRu/0oK/79VlbSz7Q9aiSZUvyWYIYsd+qj+o24G5ZU2zJz" crossorigin="anonymous"></script>
<script defer src="https://cdn.jsdelivr.net/npm/katex@0.11.1/dist/contrib/auto-render.min.js" integrity="sha384-kWPLUVMOks5AQFrykwIup5lo0m3iMkkHrD0uJ4H5cjeGihAutqP0yW0J6dpFiVkI" crossorigin="anonymous"></script>
<script>
document.addEventListener("DOMContentLoaded", function() {
renderMathInElement( document.body, {
delimiters: [
{left: "$$", right: "$$", display: true},
{left: "[%", right: "%]", display: true},
{left: "$", right: "$", display: false}
]}
);
});
</script>
<script>
function wrap_img(fn) {
if (document.attachEvent ? document.readyState === "complete" : document.readyState !== "loading") {
var elements = document.querySelectorAll(".post img");
Array.prototype.forEach.call(elements, function(el, i) {
if (el.getAttribute("title")) {
const caption = document.createElement('figcaption');
var node = document.createTextNode(el.getAttribute("title"));
caption.appendChild(node);
const wrapper = document.createElement('figure');
wrapper.className = 'image';
el.parentNode.insertBefore(wrapper, el);
el.parentNode.removeChild(el);
wrapper.appendChild(el);
wrapper.appendChild(caption);
}
});
} else { document.addEventListener('DOMContentLoaded', fn); }
}
window.onload = wrap_img;
</script>
<script>
document.addEventListener("DOMContentLoaded", function(){
// add link icon to anchor tags
var elem = document.querySelectorAll(".anchor-link")
elem.forEach(e => (e.innerHTML = '<i class="fas fa-link fa-xs"></i>'));
// remove paragraph tags in rendered toc (happens from notebooks)
var toctags = document.querySelectorAll(".toc-entry")
toctags.forEach(e => (e.firstElementChild.innerText = e.firstElementChild.innerText.replace('¶', '')))
});
</script>
</head><body><header class="site-header" role="banner">
<div class="wrapper"><a class="site-title" rel="author" href="/narsil.github.io/">Narsil</a><nav class="site-nav">
<input type="checkbox" id="nav-trigger" class="nav-trigger" />
<label for="nav-trigger">
<span class="menu-icon">
<svg viewBox="0 0 18 15" width="18px" height="15px">
<path d="M18,1.484c0,0.82-0.665,1.484-1.484,1.484H1.484C0.665,2.969,0,2.304,0,1.484l0,0C0,0.665,0.665,0,1.484,0 h15.032C17.335,0,18,0.665,18,1.484L18,1.484z M18,7.516C18,8.335,17.335,9,16.516,9H1.484C0.665,9,0,8.335,0,7.516l0,0 c0-0.82,0.665-1.484,1.484-1.484h15.032C17.335,6.031,18,6.696,18,7.516L18,7.516z M18,13.516C18,14.335,17.335,15,16.516,15H1.484 C0.665,15,0,14.335,0,13.516l0,0c0-0.82,0.665-1.483,1.484-1.483h15.032C17.335,12.031,18,12.695,18,13.516L18,13.516z"/>
</svg>
</span>
</label>
<div class="trigger"><a class="page-link" href="/narsil.github.io/about/">About Me</a><a class="page-link" href="/narsil.github.io/search/">Search</a><a class="page-link" href="/narsil.github.io/categories/">Tags</a></div>
</nav></div>
</header>
<main class="page-content" aria-label="Content">
<div class="wrapper">
<article class="post h-entry" itemscope itemtype="http://schema.org/BlogPosting">
<header class="post-header">
<h1 class="post-title p-name" itemprop="name headline">Can we train neural networks without gradient descent ?</h1><p class="page-description">If the lottery ticket hypothesis is real, does that mean we can train a neural network without gradient descent?</p><p class="post-meta post-meta-title"><time class="dt-published" datetime="2020-03-10T00:00:00+01:00" itemprop="datePublished">
Mar 10, 2020
</time>
<span class="read-time" title="Estimated read time">
4 min read
</span></p>
<p class="category-tags"><i class="fas fa-tags category-tags-icon"></i></i>
<a class="category-tags-link" href="/narsil.github.io/categories/#ml">ml</a>
</p>
<div class="pb-5 d-flex flex-wrap flex-justify-end">
<div class="px-2">
<a href="https://github.com/Narsil/narsil.github.io/tree/master/_notebooks/2020-03-10-no-gd-training.ipynb" role="button">
<img class="notebook-badge-image" src="https://img.shields.io/static/v1?label=&message=View%20On%20GitHub&color=586069&logo=github&labelColor=2f363d">
</a>
</div><div class="px-2">
<a href="https://colab.research.google.com/github/Narsil/narsil.github.io/blob/master/_notebooks/2020-03-10-no-gd-training.ipynb">
<img class="notebook-badge-image" src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
</div>
</div>
</header>
<div class="post-content e-content" itemprop="articleBody">
<ul class="section-nav">
<li class="toc-entry toc-h2"><a href="#What's-the-problem-?">What&#39;s the problem ? </a></li>
<li class="toc-entry toc-h2"><a href="#Experiments">Experiments </a></li>
</ul><!--
#################################################
### THIS FILE WAS AUTOGENERATED! DO NOT EDIT! ###
#################################################
# file to edit: _notebooks/2020-03-10-no-gd-training.ipynb
-->
<div class="container" id="notebook-container">
<div class="cell border-box-sizing code_cell rendered">
</div>
<div class="cell border-box-sizing text_cell rendered">
<div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<h2 id="What's-the-problem-?">
<a class="anchor" href="#What's-the-problem-?" aria-hidden="true"><span class="octicon octicon-link"></span></a>What's the problem ?<a class="anchor-link" href="#What's-the-problem-?"> </a>
</h2>
<p>ML models are usually not able to tell how close the data you feed them is to what they saw in the training set. This matters a lot in production, because a model can make absurd mistakes simply because its input lies outside the training distribution (a concrete illustration follows the training code below).</p>
</div>
</div>
</div>
<div class="cell border-box-sizing text_cell rendered">
<div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<p>Let's train a simple MNIST model, taken straight from the PyTorch examples: <a href="https://github.com/pytorch/examples/tree/master/mnist">https://github.com/pytorch/examples/tree/master/mnist</a>.</p>
</div>
</div>
</div>
<div class="cell border-box-sizing code_cell rendered">
<details class="description">
<summary class="btn btn-sm" data-open="Hide Code" data-close="Show Code"></summary>
<p></p>
<div class="input">
<div class="inner_cell">
<div class="input_area">
<div class=" highlight hl-ipython3"><pre><span></span><span class="c1">#collapse</span>
<span class="kn">from</span> <span class="nn">__future__</span> <span class="kn">import</span> <span class="n">print_function</span>
<span class="kn">import</span> <span class="nn">argparse</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>
<span class="kn">import</span> <span class="nn">torch.optim</span> <span class="k">as</span> <span class="nn">optim</span>
<span class="kn">from</span> <span class="nn">torchvision</span> <span class="kn">import</span> <span class="n">datasets</span><span class="p">,</span> <span class="n">transforms</span>
<span class="kn">from</span> <span class="nn">torch.optim.lr_scheduler</span> <span class="kn">import</span> <span class="n">StepLR</span>
<span class="kn">import</span> <span class="nn">os</span>
<span class="k">class</span> <span class="nc">Net</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">Net</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">conv1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">conv2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">dropout1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout2d</span><span class="p">(</span><span class="mf">0.25</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">dropout2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout2d</span><span class="p">(</span><span class="mf">0.5</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">fc1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">9216</span><span class="p">,</span> <span class="mi">128</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">fc2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">max_pool2d</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">log_softmax</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="k">return</span> <span class="n">output</span>
<span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">train_loader</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">epoch</span><span class="p">):</span>
<span class="n">model</span><span class="o">.</span><span class="n">train</span><span class="p">()</span>
<span class="k">for</span> <span class="n">batch_idx</span><span class="p">,</span> <span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">train_loader</span><span class="p">):</span>
<span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">target</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">nll_loss</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
<span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
<span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
<span class="k">if</span> <span class="n">batch_idx</span> <span class="o">%</span> <span class="n">args</span><span class="o">.</span><span class="n">log_interval</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="s1">'Train Epoch: </span><span class="si">{}</span><span class="s1"> [</span><span class="si">{}</span><span class="s1">/</span><span class="si">{}</span><span class="s1"> (</span><span class="si">{:.0f}</span><span class="s1">%)]</span><span class="se">\t</span><span class="s1">Loss: </span><span class="si">{:.6f}</span><span class="s1">'</span><span class="o">.</span><span class="n">format</span><span class="p">(</span>
<span class="n">epoch</span><span class="p">,</span> <span class="n">batch_idx</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="n">train_loader</span><span class="o">.</span><span class="n">dataset</span><span class="p">),</span>
<span class="mf">100.</span> <span class="o">*</span> <span class="n">batch_idx</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">train_loader</span><span class="p">),</span> <span class="n">loss</span><span class="o">.</span><span class="n">item</span><span class="p">()))</span>
<span class="k">def</span> <span class="nf">test</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">test_loader</span><span class="p">):</span>
<span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
<span class="n">test_loss</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">correct</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
<span class="k">for</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="ow">in</span> <span class="n">test_loader</span><span class="p">:</span>
<span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">target</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="n">test_loss</span> <span class="o">+=</span> <span class="n">F</span><span class="o">.</span><span class="n">nll_loss</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">'sum'</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="c1"># sum up batch loss</span>
<span class="n">pred</span> <span class="o">=</span> <span class="n">output</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="c1"># get the index of the max log-probability</span>
<span class="n">correct</span> <span class="o">+=</span> <span class="n">pred</span><span class="o">.</span><span class="n">eq</span><span class="p">(</span><span class="n">target</span><span class="o">.</span><span class="n">view_as</span><span class="p">(</span><span class="n">pred</span><span class="p">))</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
<span class="n">test_loss</span> <span class="o">/=</span> <span class="nb">len</span><span class="p">(</span><span class="n">test_loader</span><span class="o">.</span><span class="n">dataset</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s1">'</span><span class="se">\n</span><span class="s1">Test set: Average loss: </span><span class="si">{:.4f}</span><span class="s1">, Accuracy: </span><span class="si">{}</span><span class="s1">/</span><span class="si">{}</span><span class="s1"> (</span><span class="si">{:.0f}</span><span class="s1">%)</span><span class="se">\n</span><span class="s1">'</span><span class="o">.</span><span class="n">format</span><span class="p">(</span>
<span class="n">test_loss</span><span class="p">,</span> <span class="n">correct</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">test_loader</span><span class="o">.</span><span class="n">dataset</span><span class="p">),</span>
<span class="mf">100.</span> <span class="o">*</span> <span class="n">correct</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">test_loader</span><span class="o">.</span><span class="n">dataset</span><span class="p">)))</span>
<span class="k">def</span> <span class="nf">mnist</span><span class="p">():</span>
<span class="n">filename</span> <span class="o">=</span><span class="s2">"mnist_cnn.pt"</span>
<span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">filename</span><span class="p">):</span>
<span class="k">return</span>
<span class="c1"># Training settings</span>
<span class="n">parser</span> <span class="o">=</span> <span class="n">argparse</span><span class="o">.</span><span class="n">ArgumentParser</span><span class="p">(</span><span class="n">description</span><span class="o">=</span><span class="s1">'PyTorch MNIST Example'</span><span class="p">)</span>
<span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">'--batch-size'</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">'N'</span><span class="p">,</span>
<span class="n">help</span><span class="o">=</span><span class="s1">'input batch size for training (default: 64)'</span><span class="p">)</span>
<span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">'--test-batch-size'</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">1000</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">'N'</span><span class="p">,</span>
<span class="n">help</span><span class="o">=</span><span class="s1">'input batch size for testing (default: 1000)'</span><span class="p">)</span>
<span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">'--epochs'</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">'N'</span><span class="p">,</span>
<span class="n">help</span><span class="o">=</span><span class="s1">'number of epochs to train (default: 14)'</span><span class="p">)</span>
<span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">'--lr'</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">float</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">'LR'</span><span class="p">,</span>
<span class="n">help</span><span class="o">=</span><span class="s1">'learning rate (default: 1.0)'</span><span class="p">)</span>
<span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">'--gamma'</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">float</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mf">0.7</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">'M'</span><span class="p">,</span>
<span class="n">help</span><span class="o">=</span><span class="s1">'Learning rate step gamma (default: 0.7)'</span><span class="p">)</span>
<span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">'--no-cuda'</span><span class="p">,</span> <span class="n">action</span><span class="o">=</span><span class="s1">'store_true'</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">help</span><span class="o">=</span><span class="s1">'disables CUDA training'</span><span class="p">)</span>
<span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">'--seed'</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">'S'</span><span class="p">,</span>
<span class="n">help</span><span class="o">=</span><span class="s1">'random seed (default: 1)'</span><span class="p">)</span>
<span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">'--log-interval'</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">'N'</span><span class="p">,</span>
<span class="n">help</span><span class="o">=</span><span class="s1">'how many batches to wait before logging training status'</span><span class="p">)</span>
<span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">'--save-model'</span><span class="p">,</span> <span class="n">action</span><span class="o">=</span><span class="s1">'store_true'</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">help</span><span class="o">=</span><span class="s1">'For Saving the current Model'</span><span class="p">)</span>
<span class="n">args</span> <span class="o">=</span> <span class="n">parser</span><span class="o">.</span><span class="n">parse_args</span><span class="p">()</span>
<span class="n">use_cuda</span> <span class="o">=</span> <span class="ow">not</span> <span class="n">args</span><span class="o">.</span><span class="n">no_cuda</span> <span class="ow">and</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span>
<span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="n">args</span><span class="o">.</span><span class="n">seed</span><span class="p">)</span>
<span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s2">"cuda"</span> <span class="k">if</span> <span class="n">use_cuda</span> <span class="k">else</span> <span class="s2">"cpu"</span><span class="p">)</span>
<span class="n">kwargs</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'num_workers'</span><span class="p">:</span> <span class="mi">1</span><span class="p">,</span> <span class="s1">'pin_memory'</span><span class="p">:</span> <span class="kc">True</span><span class="p">}</span> <span class="k">if</span> <span class="n">use_cuda</span> <span class="k">else</span> <span class="p">{}</span>
<span class="n">train_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span>
<span class="n">datasets</span><span class="o">.</span><span class="n">MNIST</span><span class="p">(</span><span class="s1">'../data'</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
<span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
<span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">((</span><span class="mf">0.1307</span><span class="p">,),</span> <span class="p">(</span><span class="mf">0.3081</span><span class="p">,))</span>
<span class="p">])),</span>
<span class="n">batch_size</span><span class="o">=</span><span class="n">args</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<span class="n">test_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span>
<span class="n">datasets</span><span class="o">.</span><span class="n">MNIST</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">expanduser</span><span class="p">(</span><span class="s1">'../data'</span><span class="p">),</span> <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
<span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
<span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">((</span><span class="mf">0.1307</span><span class="p">,),</span> <span class="p">(</span><span class="mf">0.3081</span><span class="p">,))</span>
<span class="p">])),</span>
<span class="n">batch_size</span><span class="o">=</span><span class="n">args</span><span class="o">.</span><span class="n">test_batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">Net</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">optim</span><span class="o">.</span><span class="n">Adadelta</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="n">args</span><span class="o">.</span><span class="n">lr</span><span class="p">)</span>
<span class="n">scheduler</span> <span class="o">=</span> <span class="n">StepLR</span><span class="p">(</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">step_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">gamma</span><span class="o">=</span><span class="n">args</span><span class="o">.</span><span class="n">gamma</span><span class="p">)</span>
<span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">args</span><span class="o">.</span><span class="n">epochs</span> <span class="o">+</span> <span class="mi">1</span><span class="p">):</span>
<span class="n">train</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">train_loader</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">epoch</span><span class="p">)</span>
<span class="n">test</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">test_loader</span><span class="p">)</span>
<span class="n">scheduler</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
<span class="k">if</span> <span class="n">args</span><span class="o">.</span><span class="n">save_model</span><span class="p">:</span>
<span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="n">filename</span><span class="p">)</span>
<span class="c1"># mnist()</span>
</pre></div>
</div>
</div>
</div>
</details>
</div>
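<div class="cell border-box-sizing text_cell rendered">
<div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<p>To make the problem above concrete, here is a small illustration. It is not part of the original notebook, and it assumes <code>mnist()</code> has been run so that <code>mnist_cnn.pt</code> exists: a trained MNIST classifier will still produce a confident prediction on pure noise, because nothing in the model measures how far the input is from the training distribution.</p>
</div>
</div>
</div>
<div class="cell border-box-sizing code_cell rendered">
<div class="input">
<div class="inner_cell">
<div class="input_area">
<div class=" highlight hl-ipython3"><pre>
import torch

# Load the trained network from the cell above (assumes mnist() was executed).
model = Net()
model.load_state_dict(torch.load("mnist_cnn.pt", map_location="cpu"))
model.eval()

# Feed the model something that is clearly not a digit.
noise = torch.randn(1, 1, 28, 28)
with torch.no_grad():
    probs = torch.exp(model(noise))  # the network outputs log-probabilities

# The maximum class probability is often surprisingly high even on garbage input.
print(probs.max().item())
</pre></div>
</div>
</div>
</div>
</div>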
<div class="cell border-box-sizing text_cell rendered">
<div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<p>Several gradient-free optimization methods have been proposed as alternatives to gradient descent for training. Here is a sample (a minimal sketch of one of them follows this list):</p>
<ul>
<li>Genetic algorithms</li>
<li>Derivative-free optimization (DFO)</li>
<li>Simulated annealing</li>
</ul>
</div>
</div>
</div>
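<div class="cell border-box-sizing text_cell rendered">
<div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<p>As a toy example of what gradient-free training can look like, here is a minimal sketch of a simulated-annealing step for a PyTorch model. This snippet is not from the original post and the function name and signature are made up for illustration: perturb the weights at random, keep the change if the loss improves, and occasionally accept a worse loss depending on a temperature.</p>
</div>
</div>
</div>
<div class="cell border-box-sizing code_cell rendered">
<div class="input">
<div class="inner_cell">
<div class="input_area">
<div class=" highlight hl-ipython3"><pre>
import math
import random
import torch


def simulated_annealing_step(model, loss_fn, data, target,
                             step_size=0.01, temperature=1.0):
    """One gradient-free update: propose a random perturbation of all weights
    and accept or reject it with the Metropolis rule."""
    params = list(model.parameters())
    with torch.no_grad():
        old_loss = loss_fn(model(data), target).item()
        # Propose a random perturbation of every weight.
        noise = [torch.randn_like(p) * step_size for p in params]
        for p, n in zip(params, noise):
            p.add_(n)
        new_loss = loss_fn(model(data), target).item()
        # Accept improvements, and sometimes accept a worse loss.
        accept = (new_loss &lt; old_loss or
                  random.random() &lt; math.exp((old_loss - new_loss) / temperature))
        if not accept:
            # Roll the perturbation back.
            for p, n in zip(params, noise):
                p.sub_(n)
    return new_loss if accept else old_loss
</pre></div>
</div>
</div>
</div>
</div>
<div class="cell border-box-sizing text_cell rendered">
<div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<p>In practice you would call this in a loop over mini-batches while slowly lowering <code>temperature</code> and <code>step_size</code>. It works on tiny models but scales poorly to networks with millions of parameters, which is what makes the lottery-ticket angle below interesting.</p>
</div>
</div>
</div>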
<div class="cell border-box-sizing text_cell rendered">
<div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<h2 id="Experiments">
<a class="anchor" href="#Experiments" aria-hidden="true"><span class="octicon octicon-link"></span></a>Experiments<a class="anchor-link" href="#Experiments"> </a>
</h2>
</div>
</div>
</div>
<div class="cell border-box-sizing code_cell rendered">
<div class="input">
<div class="inner_cell">
<div class="input_area">
<div class=" highlight hl-ipython3"><pre><span></span><span class="k">def</span> <span class="nf">train_ticket</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">train_loader</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">epoch</span><span class="p">):</span>
<span class="n">model</span><span class="o">.</span><span class="n">train</span><span class="p">()</span>
<span class="k">for</span> <span class="n">batch_idx</span><span class="p">,</span> <span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">train_loader</span><span class="p">):</span>
<span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">target</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">nll_loss</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
<span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
<span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
<span class="k">if</span> <span class="n">batch_idx</span> <span class="o">%</span> <span class="n">args</span><span class="o">.</span><span class="n">log_interval</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="s1">'Train Epoch: </span><span class="si">{}</span><span class="s1"> [</span><span class="si">{}</span><span class="s1">/</span><span class="si">{}</span><span class="s1"> (</span><span class="si">{:.0f}</span><span class="s1">%)]</span><span class="se">\t</span><span class="s1">Loss: </span><span class="si">{:.6f}</span><span class="s1">'</span><span class="o">.</span><span class="n">format</span><span class="p">(</span>
<span class="n">epoch</span><span class="p">,</span> <span class="n">batch_idx</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="n">train_loader</span><span class="o">.</span><span class="n">dataset</span><span class="p">),</span>
<span class="mf">100.</span> <span class="o">*</span> <span class="n">batch_idx</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">train_loader</span><span class="p">),</span> <span class="n">loss</span><span class="o">.</span><span class="n">item</span><span class="p">()))</span>
<span class="k">def</span> <span class="nf">test_ticket</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">test_loader</span><span class="p">):</span>
<span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
<span class="n">test_loss</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">correct</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
<span class="k">for</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="ow">in</span> <span class="n">test_loader</span><span class="p">:</span>
<span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">target</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="n">test_loss</span> <span class="o">+=</span> <span class="n">F</span><span class="o">.</span><span class="n">nll_loss</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">'sum'</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="c1"># sum up batch loss</span>
<span class="n">pred</span> <span class="o">=</span> <span class="n">output</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="c1"># get the index of the max log-probability</span>
<span class="n">correct</span> <span class="o">+=</span> <span class="n">pred</span><span class="o">.</span><span class="n">eq</span><span class="p">(</span><span class="n">target</span><span class="o">.</span><span class="n">view_as</span><span class="p">(</span><span class="n">pred</span><span class="p">))</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
<span class="n">test_loss</span> <span class="o">/=</span> <span class="nb">len</span><span class="p">(</span><span class="n">test_loader</span><span class="o">.</span><span class="n">dataset</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s1">'</span><span class="se">\n</span><span class="s1">Test set: Average loss: </span><span class="si">{:.4f}</span><span class="s1">, Accuracy: </span><span class="si">{}</span><span class="s1">/</span><span class="si">{}</span><span class="s1"> (</span><span class="si">{:.0f}</span><span class="s1">%)</span><span class="se">\n</span><span class="s1">'</span><span class="o">.</span><span class="n">format</span><span class="p">(</span>
<span class="n">test_loss</span><span class="p">,</span> <span class="n">correct</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">test_loader</span><span class="o">.</span><span class="n">dataset</span><span class="p">),</span>
<span class="mf">100.</span> <span class="o">*</span> <span class="n">correct</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">test_loader</span><span class="o">.</span><span class="n">dataset</span><span class="p">)))</span>
<span class="k">def</span> <span class="nf">ticket_finder</span><span class="p">():</span>
<span class="n">filename</span> <span class="o">=</span><span class="s2">"ticket_finder.pt"</span>
<span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">filename</span><span class="p">):</span>
<span class="k">return</span>
<span class="c1"># Training settings</span>
<span class="n">parser</span> <span class="o">=</span> <span class="n">argparse</span><span class="o">.</span><span class="n">ArgumentParser</span><span class="p">(</span><span class="n">description</span><span class="o">=</span><span class="s1">'PyTorch MNIST Example'</span><span class="p">)</span>
<span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">'--batch-size'</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">'N'</span><span class="p">,</span>
<span class="n">help</span><span class="o">=</span><span class="s1">'input batch size for training (default: 64)'</span><span class="p">)</span>
<span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">'--test-batch-size'</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">1000</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">'N'</span><span class="p">,</span>
<span class="n">help</span><span class="o">=</span><span class="s1">'input batch size for testing (default: 1000)'</span><span class="p">)</span>
<span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">'--epochs'</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">'N'</span><span class="p">,</span>
<span class="n">help</span><span class="o">=</span><span class="s1">'number of epochs to train (default: 14)'</span><span class="p">)</span>
<span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">'--lr'</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">float</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">'LR'</span><span class="p">,</span>
<span class="n">help</span><span class="o">=</span><span class="s1">'learning rate (default: 1.0)'</span><span class="p">)</span>
<span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">'--gamma'</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">float</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mf">0.7</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">'M'</span><span class="p">,</span>
<span class="n">help</span><span class="o">=</span><span class="s1">'Learning rate step gamma (default: 0.7)'</span><span class="p">)</span>
<span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">'--no-cuda'</span><span class="p">,</span> <span class="n">action</span><span class="o">=</span><span class="s1">'store_true'</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">help</span><span class="o">=</span><span class="s1">'disables CUDA training'</span><span class="p">)</span>
<span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">'--seed'</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">'S'</span><span class="p">,</span>
<span class="n">help</span><span class="o">=</span><span class="s1">'random seed (default: 1)'</span><span class="p">)</span>
<span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">'--log-interval'</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">metavar</span><span class="o">=</span><span class="s1">'N'</span><span class="p">,</span>
<span class="n">help</span><span class="o">=</span><span class="s1">'how many batches to wait before logging training status'</span><span class="p">)</span>
<span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s1">'--save-model'</span><span class="p">,</span> <span class="n">action</span><span class="o">=</span><span class="s1">'store_true'</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">help</span><span class="o">=</span><span class="s1">'For Saving the current Model'</span><span class="p">)</span>
<span class="n">args</span> <span class="o">=</span> <span class="n">parser</span><span class="o">.</span><span class="n">parse_args</span><span class="p">()</span>
<span class="n">use_cuda</span> <span class="o">=</span> <span class="ow">not</span> <span class="n">args</span><span class="o">.</span><span class="n">no_cuda</span> <span class="ow">and</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span>
<span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="n">args</span><span class="o">.</span><span class="n">seed</span><span class="p">)</span>
<span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s2">"cuda"</span> <span class="k">if</span> <span class="n">use_cuda</span> <span class="k">else</span> <span class="s2">"cpu"</span><span class="p">)</span>
<span class="n">kwargs</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'num_workers'</span><span class="p">:</span> <span class="mi">1</span><span class="p">,</span> <span class="s1">'pin_memory'</span><span class="p">:</span> <span class="kc">True</span><span class="p">}</span> <span class="k">if</span> <span class="n">use_cuda</span> <span class="k">else</span> <span class="p">{}</span>
<span class="n">train_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span>
<span class="n">datasets</span><span class="o">.</span><span class="n">MNIST</span><span class="p">(</span><span class="s1">'../data'</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
<span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
<span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">((</span><span class="mf">0.1307</span><span class="p">,),</span> <span class="p">(</span><span class="mf">0.3081</span><span class="p">,))</span>
<span class="p">])),</span>
<span class="n">batch_size</span><span class="o">=</span><span class="n">args</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<span class="n">test_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span>
<span class="n">datasets</span><span class="o">.</span><span class="n">MNIST</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">expanduser</span><span class="p">(</span><span class="s1">'../data'</span><span class="p">),</span> <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
<span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
<span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">((</span><span class="mf">0.1307</span><span class="p">,),</span> <span class="p">(</span><span class="mf">0.3081</span><span class="p">,))</span>
<span class="p">])),</span>
<span class="n">batch_size</span><span class="o">=</span><span class="n">args</span><span class="o">.</span><span class="n">test_batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">Net</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">TicketFinder</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span>
<span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">args</span><span class="o">.</span><span class="n">epochs</span> <span class="o">+</span> <span class="mi">1</span><span class="p">):</span>
<span class="n">train_ticket</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">train_loader</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">epoch</span><span class="p">)</span>
<span class="n">test_ticket</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">test_loader</span><span class="p">)</span>
<span class="k">if</span> <span class="n">args</span><span class="o">.</span><span class="n">save_model</span><span class="p">:</span>
<span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="n">filename</span><span class="p">)</span>
</pre></div>
</div>
</div>
</div>
</div>
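<div class="cell border-box-sizing text_cell rendered">
<div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<p>The <code>TicketFinder</code> optimizer used above is not defined anywhere in this post. As a placeholder, here is a minimal sketch of one possible interpretation, in the "supermask" spirit of the lottery ticket hypothesis; it is an assumption for illustration, not the author's implementation. The randomly initialised weights are never updated: only a per-weight score is trained (via a straight-through estimate), and at each step the lowest-scoring weights are zeroed out.</p>
</div>
</div>
</div>
<div class="cell border-box-sizing code_cell rendered">
<div class="input">
<div class="inner_cell">
<div class="input_area">
<div class=" highlight hl-ipython3"><pre>
import torch
from torch.optim.optimizer import Optimizer


class TicketFinder(Optimizer):
    """Keep the initial random weights frozen and only learn which ones to keep."""

    def __init__(self, params, lr=0.1, keep_ratio=0.5):
        defaults = dict(lr=lr, keep_ratio=keep_ratio)
        super().__init__(params, defaults)
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                state["w0"] = p.detach().clone()       # frozen initial weights
                state["score"] = torch.zeros_like(p)   # learnable "keep me" scores

    @torch.no_grad()
    def step(self, closure=None):
        for group in self.param_groups:
            lr, keep_ratio = group["lr"], group["keep_ratio"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                w0, score = state["w0"], state["score"]
                # Straight-through estimate: the gradient of the loss w.r.t. the
                # score of a masked weight m * w0 is (dL/dw) * w0.
                score.add_(p.grad * w0, alpha=-lr)
                # Keep only the top keep_ratio fraction of scores.
                k = max(1, int(keep_ratio * score.numel()))
                threshold = score.flatten().kthvalue(score.numel() - k + 1).values
                mask = (score &gt;= threshold).to(w0.dtype)
                p.copy_(w0 * mask)
</pre></div>
</div>
</div>
</div>
</div>
<div class="cell border-box-sizing text_cell rendered">
<div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<p>This only makes the training loop above executable. Note that the scores are still updated using the gradients computed by <code>loss.backward()</code>, so while the original weights are never touched, this scheme is not entirely gradient-free.</p>
</div>
</div>
</div>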
</div>
</div><!-- from https://github.com/utterance/utterances -->
<script src="https://utteranc.es/client.js"
repo="Narsil/narsil.github.io"
issue-term="title"
label="blogpost-comment"
theme="github-light"
crossorigin="anonymous"
async>
</script><a class="u-url" href="/narsil.github.io/ml/2020/03/10/no-gd-training.html" hidden></a>
</article>
</div>
</main><footer class="site-footer h-card">
<data class="u-url" href="/narsil.github.io/"></data>
<div class="wrapper">
<h2 class="footer-heading">Narsil</h2>
<div class="footer-col-wrapper">
<div class="footer-col footer-col-1">
<ul class="contact-list">
<li class="p-name">Narsil</li></ul>
</div>
<div class="footer-col footer-col-2"><ul class="social-media-list">
<li><a href="https://github.com/Narsil"><svg class="social svg-icon"><use xlink:href="/narsil.github.io/assets/minima-social-icons.svg#github"></use></svg> <span class="username">Narsil</span></a></li><li><a href="https://www.twitter.com/narsilou"><svg class="social svg-icon"><use xlink:href="/narsil.github.io/assets/minima-social-icons.svg#twitter"></use></svg> <span class="username">narsilou</span></a></li></ul>
</div>
<div class="footer-col footer-col-3">
<p>Small experiments and insights from ML and software development.</p>
</div>
</div>
</div>
</footer>
</body>
</html>