{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Pyro: Deep Universal Probabilistic Programming\n", "Gonzalo Rios (grios@dim.uchile.cl)\n", "\n", "Pyro is a universal probabilistic programming language (PPL) written in Python and supported by PyTorch on the backend. Pyro enables flexible and expressive deep probabilistic modeling, unifying the best of modern deep learning and Bayesian modeling. It was designed with these key principles:\n", "\n", "* Universal: Pyro can represent any computable probability distribution.\n", "* Scalable: Pyro scales to large data sets with little overhead.\n", "* Minimal: Pyro is implemented with a small core of powerful, composable abstractions.\n", "* Flexible: Pyro aims for automation when you want it, control when you need it. \n", "\n", "http://pyro.ai/\n", "\n", "\n", "# Pyro vs Numpy/Scipy" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:46:32.948553Z", "start_time": "2018-11-08T19:46:32.731777Z" } }, "outputs": [], "source": [ "import numpy as np\n", "import scipy.stats as stats" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:46:58.287494Z", "start_time": "2018-11-08T19:46:58.282930Z" } }, "outputs": [], "source": [ "loc = 0. # mean zero\n", "scale = 1. # unit variance\n", "size = int(1e8)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:47:05.604084Z", "start_time": "2018-11-08T19:47:01.857239Z" } }, "outputs": [], "source": [ "x_numpy = loc+scale*np.random.randn(size)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:47:21.909319Z", "start_time": "2018-11-08T19:47:18.092629Z" } }, "outputs": [], "source": [ "x_scipy = stats.norm.rvs(loc=loc, scale=scale, size=size)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:47:43.344360Z", "start_time": "2018-11-08T19:47:38.687470Z" } }, "outputs": [], "source": [ "logp_scipy = stats.norm.logpdf(x_scipy, loc=loc, scale=scale)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:48:07.883663Z", "start_time": "2018-11-08T19:48:07.312813Z" } }, "outputs": [], "source": [ "import torch\n", "import pyro\n", "import pyro.distributions as dist" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:48:18.690884Z", "start_time": "2018-11-08T19:48:18.686222Z" } }, "outputs": [], "source": [ "normal = dist.Normal(loc, scale) # create a normal distribution object\n", "x = normal.sample() # draw a sample from N(0,1)\n", "print(\"sample\", x)\n", "print(\"log prob\", normal.log_prob(x)) # score the sample from N(0,1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:48:32.089244Z", "start_time": "2018-11-08T19:48:31.119004Z" } }, "outputs": [], "source": [ "x_pyro = normal.sample((size,))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:48:36.539068Z", "start_time": "2018-11-08T19:48:36.023745Z" } }, "outputs": [], "source": [ "logp_pyro = normal.log_prob(x_pyro)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:48:44.437024Z", "start_time": "2018-11-08T19:48:44.417859Z" } }, "outputs": [], 
"source": [ "cuda = True\n", "device = torch.device(\"cuda\") if torch.cuda.is_available() and cuda else torch.device(\"cpu\")\n", "device" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:48:52.462975Z", "start_time": "2018-11-08T19:48:50.709840Z" } }, "outputs": [], "source": [ "normal = dist.Normal(torch.Tensor([loc]).to(device), torch.Tensor([scale]).to(device))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:49:04.961980Z", "start_time": "2018-11-08T19:49:04.959490Z" } }, "outputs": [], "source": [ "x_pyro = normal.sample((size,))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:49:10.349893Z", "start_time": "2018-11-08T19:49:10.346734Z" } }, "outputs": [], "source": [ "logp_pyro = normal.log_prob(x_pyro)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Models" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:50:01.579007Z", "start_time": "2018-11-08T19:50:01.568573Z" } }, "outputs": [], "source": [ "x = pyro.sample(\"my_sample\", dist.Normal(loc, scale))\n", "print(x)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:51:02.847539Z", "start_time": "2018-11-08T19:51:02.844421Z" } }, "outputs": [], "source": [ "def weather():\n", " cloudy = pyro.sample('cloudy', dist.Bernoulli(0.3))\n", " cloudy = 'cloudy' if cloudy.item() == 1.0 else 'sunny'\n", " mean_temp = {'cloudy': 55.0, 'sunny': 75.0}[cloudy]\n", " scale_temp = {'cloudy': 10.0, 'sunny': 15.0}[cloudy]\n", " temp = pyro.sample('temp', dist.Normal(mean_temp, scale_temp))\n", " return cloudy, temp" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:51:12.369413Z", "start_time": "2018-11-08T19:51:12.351541Z" } }, "outputs": [], "source": [ "for _ in range(10):\n", " print(weather())" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:51:40.764395Z", "start_time": "2018-11-08T19:51:40.759399Z" } }, "outputs": [], "source": [ "def ice_cream_sales():\n", " cloudy, temp = weather()\n", " expected_sales = 200. 
{ "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:51:40.764395Z", "start_time": "2018-11-08T19:51:40.759399Z" } }, "outputs": [], "source": [ "def ice_cream_sales():\n", "    cloudy, temp = weather()\n", "    expected_sales = 200. if cloudy == 'sunny' and temp > 80.0 else 50.\n", "    ice_cream = pyro.sample('ice_cream', dist.Normal(expected_sales, 10.0))\n", "    return ice_cream.item()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:51:45.596062Z", "start_time": "2018-11-08T19:51:45.579669Z" } }, "outputs": [], "source": [ "for _ in range(10):\n", "    print(ice_cream_sales())" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:51:55.907508Z", "start_time": "2018-11-08T19:51:55.587540Z" } }, "outputs": [], "source": [ "import seaborn as sb\n", "sb.distplot(np.array([ice_cream_sales() for i in range(1000)]), bins=100)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:52:08.086838Z", "start_time": "2018-11-08T19:52:07.951754Z" } }, "outputs": [], "source": [ "sb.countplot(np.array([weather()[0] for i in range(1000)]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:52:15.356275Z", "start_time": "2018-11-08T19:52:15.086278Z" } }, "outputs": [], "source": [ "sb.distplot(np.array([weather()[1].item() for i in range(1000)]), bins=100)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:52:34.780299Z", "start_time": "2018-11-08T19:52:34.777188Z" } }, "outputs": [], "source": [ "def geometric(p, t=0):\n", "    x = pyro.sample(\"x_{}\".format(t), dist.Bernoulli(p))\n", "    if x.item() == 0:\n", "        return x\n", "    else:\n", "        return x + geometric(p, t + 1)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:52:35.514629Z", "start_time": "2018-11-08T19:52:35.341945Z" } }, "outputs": [], "source": [ "sb.distplot(np.array([geometric(0.1).item() for i in range(1000)]), kde=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:52:36.077149Z", "start_time": "2018-11-08T19:52:35.896379Z" } }, "outputs": [], "source": [ "sb.distplot(np.array([geometric(0.5).item() for i in range(1000)]), kde=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:52:36.673375Z", "start_time": "2018-11-08T19:52:36.318521Z" } }, "outputs": [], "source": [ "sb.distplot(np.array([geometric(0.9).item() for i in range(1000)]), kde=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:53:47.038390Z", "start_time": "2018-11-08T19:53:47.033196Z" } }, "outputs": [], "source": [ "def normal_product(loc, scale):\n", "    z1 = pyro.sample(\"z1\", dist.Normal(loc, scale))\n", "    z2 = pyro.sample(\"z2\", dist.Normal(loc, scale))\n", "    y = z1 * z2\n", "    return y\n", "\n", "def make_normal_normal():\n", "    mu_latent = pyro.sample(\"mu_latent\", dist.Normal(0, 1))\n", "    fn = lambda scale: normal_product(mu_latent, scale)\n", "    return fn" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:53:54.327973Z", "start_time": "2018-11-08T19:53:54.165463Z" } }, "outputs": [], "source": [ "sb.distplot(np.array([make_normal_normal()(0.2).item() for i in range(100)]))" ] },
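{ "cell_type": "markdown", "metadata": {}, "source": [ "Because every sample site has a name, a model written this way can be conditioned on observed values without rewriting it. A minimal sketch with `pyro.condition`, fixing the `temp` site of the `weather` model (the observed value 68 is arbitrary):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# condition the weather model on an observed temperature; 'cloudy' is still sampled\n", "conditioned_weather = pyro.condition(weather, data={'temp': torch.tensor(68.0)})\n", "for _ in range(3):\n", "    print(conditioned_weather())" ] },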
"2018-11-08T19:54:26.515587Z", "start_time": "2018-11-08T19:54:26.505587Z" } }, "outputs": [], "source": [ "true_coeffs = torch.tensor([-1., 2., 4.])\n", "data = torch.randn(2000, 3)\n", "labels = dist.Bernoulli(logits=(true_coeffs * data).sum(-1)).sample()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:55:10.190872Z", "start_time": "2018-11-08T19:55:10.187707Z" } }, "outputs": [], "source": [ "def model(inputs, obs):\n", " coefs = pyro.sample('beta', dist.Normal(torch.zeros(3), torch.ones(3)))\n", " y = pyro.sample('y', dist.Bernoulli(logits=(coefs * inputs).sum(-1)), obs=obs)\n", " return y" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:55:32.457142Z", "start_time": "2018-11-08T19:55:25.799176Z" } }, "outputs": [], "source": [ "from pyro.infer import Importance, EmpiricalMarginal\n", "\n", "import_run = Importance(model, num_samples=10000).run(data, labels)\n", "import_run" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:55:48.864103Z", "start_time": "2018-11-08T19:55:48.799944Z" } }, "outputs": [], "source": [ "posterior = pyro.infer.EmpiricalMarginal(import_run, 'beta')\n", "posterior" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:55:51.422269Z", "start_time": "2018-11-08T19:55:51.237832Z" } }, "outputs": [], "source": [ "sample = posterior.sample((10,)).numpy()\n", "sb.distplot(sample[:,0])\n", "sb.distplot(sample[:,1])\n", "sb.distplot(sample[:,2])\n", "sample.mean(axis=0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:56:17.603773Z", "start_time": "2018-11-08T19:56:10.439380Z" } }, "outputs": [], "source": [ "from pyro.infer import mcmc\n", "\n", "hmc_kernel = mcmc.HMC(model, step_size=0.0855, num_steps=4)\n", "mcmc_run = mcmc.MCMC(hmc_kernel, num_samples=1000, warmup_steps=500).run(data, labels)\n", "mcmc_run" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:56:26.716840Z", "start_time": "2018-11-08T19:56:26.391266Z" } }, "outputs": [], "source": [ "posterior = pyro.infer.EmpiricalMarginal(mcmc_run, 'beta')\n", "sample = posterior.sample((1000,)).numpy()\n", "sb.distplot(sample[:,0])\n", "sb.distplot(sample[:,1])\n", "sb.distplot(sample[:,2])\n", "posterior.mean" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Trace" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:56:59.838302Z", "start_time": "2018-11-08T19:56:59.831333Z" } }, "outputs": [], "source": [ "N = 200 # size of toy data\n", "\n", "def build_linear_dataset(N, p=1, w=3, b=1, noise_std=0.01):\n", " X = np.random.rand(N, p)\n", " w = w * np.ones(p)\n", " y = np.matmul(X, w) + np.repeat(b, N) + np.random.normal(0, noise_std, size=N)\n", " y = y.reshape(N, 1)\n", " X, y = torch.tensor(X).type(torch.Tensor), torch.tensor(y).type(torch.Tensor)\n", " data = torch.cat((X, y), 1)\n", " assert data.shape == (N, p + 1)\n", " return data\n", "\n", "data = build_linear_dataset(N, p=1, w=2.5, b=-1.3, noise_std=0.5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:57:14.102842Z", "start_time": "2018-11-08T19:57:14.097618Z" } }, "outputs": [], "source": [ "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:57:14.423270Z", "start_time": "2018-11-08T19:57:14.315308Z" } }, "outputs": [], "source": [ "plt.plot(data[:,0].numpy(), data[:,1].numpy(), '*')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:57:30.569765Z", "start_time": "2018-11-08T19:57:30.563410Z" } }, "outputs": [], "source": [ "import torch.nn as nn\n", "\n", "class RegressionModel(nn.Module):\n", " def __init__(self, p):\n", " # p = number of features\n", " super(RegressionModel, self).__init__()\n", " self.linear = nn.Linear(p, 1)\n", "\n", " def forward(self, x):\n", " return self.linear(x)\n", "\n", "regression_model = RegressionModel(1)\n", "regression_model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:57:33.500040Z", "start_time": "2018-11-08T19:57:33.496260Z" } }, "outputs": [], "source": [ "for p in list(regression_model.named_parameters()):\n", " print(p[0], p[1].shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:58:25.270159Z", "start_time": "2018-11-08T19:58:25.265470Z" } }, "outputs": [], "source": [ "def model(data):\n", " \n", " # Create unit dist.Normal priors over the parameters\n", " zero = torch.zeros(1, 1)\n", " ten = 10 * torch.ones(1, 1)\n", " w_prior = dist.Normal(zero, ten)\n", " b_prior = dist.Normal(zero, ten)\n", " priors = {'linear.weight': w_prior, 'linear.bias': b_prior}\n", " noise = 0.1 * torch.ones(data.size(0))\n", " \n", " # lift module parameters to random variables sampled from the priors\n", " lifted_module = pyro.random_module(\"module\", regression_model, priors)\n", " # sample a regressor (which also samples w and b)\n", " lifted_reg_model = lifted_module()\n", " with pyro.iarange(\"map\", N):\n", " x_data = data[:, :-1]\n", " y_data = data[:, -1]\n", "\n", " # run the regressor forward conditioned on data\n", " prediction_mean = lifted_reg_model(x_data).squeeze(-1)\n", " # condition on the observed data\n", " pyro.sample(\"obs\",\n", " dist.Normal(prediction_mean, 0.1 * torch.ones(data.size(0))),\n", " obs=y_data)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:58:27.282358Z", "start_time": "2018-11-08T19:58:27.271801Z" }, "scrolled": true }, "outputs": [], "source": [ "from pyro import poutine\n", "\n", "model_trace = poutine.trace(model).get_trace(data)\n", "model_trace" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:58:33.640729Z", "start_time": "2018-11-08T19:58:33.629822Z" } }, "outputs": [], "source": [ "model_trace.stochastic_nodes" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:58:35.931742Z", "start_time": "2018-11-08T19:58:35.922703Z" } }, "outputs": [], "source": [ "model_trace.nodes" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:58:37.354493Z", "start_time": "2018-11-08T19:58:37.346901Z" }, "scrolled": true }, "outputs": [], "source": [ "model_trace.nodes['_INPUT']" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:58:40.993614Z", "start_time": "2018-11-08T19:58:40.983874Z" } }, "outputs": [], "source": [ "model_trace.nodes['_RETURN']" ] }, { "cell_type": "code", "execution_count": null, 
"metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:58:42.081861Z", "start_time": "2018-11-08T19:58:42.074940Z" } }, "outputs": [], "source": [ "model_trace.nodes['obs']" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:58:51.730618Z", "start_time": "2018-11-08T19:58:51.720583Z" } }, "outputs": [], "source": [ "model_trace.nodes['module$$$linear.weight']" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:58:56.515782Z", "start_time": "2018-11-08T19:58:56.510983Z" } }, "outputs": [], "source": [ "model_trace.compute_log_prob()\n", "model_trace.log_prob_sum()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:58:59.794714Z", "start_time": "2018-11-08T19:58:59.781667Z" } }, "outputs": [], "source": [ "model_trace.nodes['module$$$linear.weight']" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:59:05.710596Z", "start_time": "2018-11-08T19:59:05.705144Z" } }, "outputs": [], "source": [ "model_trace.nodes['obs']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Stochastic Variational Inference" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T19:59:56.362549Z", "start_time": "2018-11-08T19:59:56.341295Z" } }, "outputs": [], "source": [ "softplus = torch.nn.Softplus()\n", "\n", "def guide(data):\n", " # define our variational parameters\n", " w_loc = torch.randn(1, 1)\n", " # note that we initialize our scales to be pretty narrow\n", " w_log_sig = torch.tensor(-3.0 * torch.ones(1, 1) + 0.05 * torch.randn(1, 1))\n", " b_loc = torch.randn(1)\n", " b_log_sig = torch.tensor(-3.0 * torch.ones(1) + 0.05 * torch.randn(1))\n", " # register learnable params in the param store\n", " mw_param = pyro.param(\"guide_mean_weight\", w_loc)\n", " sw_param = softplus(pyro.param(\"guide_log_scale_weight\", w_log_sig))\n", " \n", " mb_param = pyro.param(\"guide_mean_bias\", b_loc)\n", " sb_param = softplus(pyro.param(\"guide_log_scale_bias\", b_log_sig))\n", " # guide distributions for w and b\n", " w_dist = dist.Normal(mw_param, sw_param).independent(1)\n", " b_dist = dist.Normal(mb_param, sb_param).independent(1)\n", " \n", " dists = {'linear.weight': w_dist, 'linear.bias': b_dist}\n", " # overload the parameters in the module with random samples\n", " # from the guide distributions\n", " lifted_module = pyro.random_module(\"module\", regression_model, dists)\n", " # sample a regressor (which also samples w and b)\n", " return lifted_module()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T20:00:35.834259Z", "start_time": "2018-11-08T20:00:35.831865Z" } }, "outputs": [], "source": [ "from pyro.optim import Adam\n", "from pyro.infer import SVI, Trace_ELBO" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T20:00:36.046902Z", "start_time": "2018-11-08T20:00:36.042028Z" } }, "outputs": [], "source": [ "optim = Adam({\"lr\": 0.05})\n", "svi = SVI(model, guide, optim, loss=Trace_ELBO())" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T20:00:45.150530Z", "start_time": "2018-11-08T20:00:45.138315Z" } }, "outputs": [], "source": [ "losses = []\n", "\n", "def main(num_iterations = 1000):\n", " pyro.clear_param_store()\n", " for j 
in range(num_iterations):\n", "        # calculate the loss and take a gradient step\n", "        loss = svi.step(data)\n", "        losses.append(loss)\n", "        if j % 100 == 0:\n", "            print(\"[iteration %04d] loss: %.4f\" % (j, loss / float(N)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T20:00:50.883403Z", "start_time": "2018-11-08T20:00:47.828107Z" } }, "outputs": [], "source": [ "main()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T20:00:51.873773Z", "start_time": "2018-11-08T20:00:51.759320Z" } }, "outputs": [], "source": [ "plt.plot(losses)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T20:00:56.004359Z", "start_time": "2018-11-08T20:00:55.904859Z" } }, "outputs": [], "source": [ "plt.plot(losses[100:])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T20:01:06.487909Z", "start_time": "2018-11-08T20:01:06.484016Z" } }, "outputs": [], "source": [ "for name in pyro.get_param_store().get_all_param_names():\n", "    print(\"[%s]: %.3f\" % (name, pyro.param(name).data.numpy()))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T20:01:10.195993Z", "start_time": "2018-11-08T20:01:10.146239Z" }, "scrolled": false }, "outputs": [], "source": [ "pred = 0\n", "for i in range(100):\n", "    # guide does not require the data\n", "    sampled_reg_model = guide(None)\n", "    # run the regression model and add prediction to total\n", "    pred += sampled_reg_model(data[:,:-1])\n", "# take the average of the predictions\n", "pred /= 100\n", "obs = data[:,-1:]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T20:01:11.508218Z", "start_time": "2018-11-08T20:01:11.286627Z" } }, "outputs": [], "source": [ "plt.figure(figsize=(20,10))\n", "plt.plot(obs.detach().numpy()[:200])\n", "plt.plot(pred.detach().numpy()[:200])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T20:01:15.882749Z", "start_time": "2018-11-08T20:01:15.878582Z" } }, "outputs": [], "source": [ "loss = nn.MSELoss()\n", "loss(pred, obs)" ] },
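{ "cell_type": "markdown", "metadata": {}, "source": [ "The fitted variational posterior is fully described by the learned parameters. A small sketch that reads them back from the param store, applying the same `softplus` transform the guide uses for the log-scales:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# posterior means and scales implied by the learned variational parameters\n", "w_mean = pyro.param(\"guide_mean_weight\").item()\n", "w_sd = softplus(pyro.param(\"guide_log_scale_weight\")).item()\n", "b_mean = pyro.param(\"guide_mean_bias\").item()\n", "b_sd = softplus(pyro.param(\"guide_log_scale_bias\")).item()\n", "print(\"w ~ Normal(%.3f, %.3f)\" % (w_mean, w_sd))\n", "print(\"b ~ Normal(%.3f, %.3f)\" % (b_mean, b_sd))" ] },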
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Gaussian Processes" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T20:01:20.648251Z", "start_time": "2018-11-08T20:01:20.629746Z" } }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import torch\n", "import pyro\n", "import pyro.contrib.gp as gp\n", "import pyro.distributions as dist\n", "from pyro.infer import SVI, Trace_ELBO\n", "from pyro.optim import Adam\n", "\n", "def plot(plot_observed_data=False, plot_predictions=False, n_prior_samples=0,\n", "         model=None, kernel=None, n_test=500):\n", "\n", "    plt.figure(figsize=(12, 6))\n", "    if plot_observed_data:\n", "        plt.plot(X.numpy(), y.numpy(), 'kx')\n", "    if plot_predictions:\n", "        Xtest = torch.linspace(-0.5, 5.5, n_test)  # test inputs\n", "        # compute predictive mean and variance\n", "        with torch.no_grad():\n", "            if isinstance(model, gp.models.VariationalSparseGP):\n", "                mean, cov = model(Xtest, full_cov=True)\n", "            else:\n", "                mean, cov = model(Xtest, full_cov=True, noiseless=False)\n", "        sd = cov.diag().sqrt()  # standard deviation at each input point x\n", "        plt.plot(Xtest.numpy(), mean.numpy(), 'r', lw=2)  # plot the mean\n", "        plt.fill_between(Xtest.numpy(),  # plot the two-sigma uncertainty about the mean\n", "                         (mean - 2.0 * sd).numpy(),\n", "                         (mean + 2.0 * sd).numpy(),\n", "                         color='C0', alpha=0.3)\n", "    if n_prior_samples > 0:  # plot samples from the GP prior\n", "        Xtest = torch.linspace(-0.5, 5.5, n_test)  # test inputs\n", "        noise = (model.noise if not isinstance(model, gp.models.VariationalSparseGP)\n", "                 else model.likelihood.variance)\n", "        cov = kernel.forward(Xtest) + noise.expand(n_test).diag()\n", "        samples = dist.MultivariateNormal(torch.zeros(n_test), covariance_matrix=cov)\\\n", "            .sample(sample_shape=(n_prior_samples,))\n", "        plt.plot(Xtest.numpy(), samples.numpy().T, lw=2, alpha=0.4)\n", "\n", "    plt.xlim(-0.5, 5.5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T20:01:21.123839Z", "start_time": "2018-11-08T20:01:21.004943Z" } }, "outputs": [], "source": [ "N = 20\n", "X = dist.Uniform(0.0, 5.0).sample(sample_shape=(N,))\n", "y = 0.5 * torch.sin(3*X) + dist.Normal(0.0, 0.2).sample(sample_shape=(N,))\n", "\n", "plot(plot_observed_data=True)  # let's plot the observed data\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T20:01:21.378335Z", "start_time": "2018-11-08T20:01:21.370818Z" } }, "outputs": [], "source": [ "kernel = gp.kernels.RBF(input_dim=1, variance=torch.tensor(5.),\n", "                        lengthscale=torch.tensor(10.))\n", "gpr = gp.models.GPRegression(X, y, kernel, noise=torch.tensor(1.))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T20:01:21.958011Z", "start_time": "2018-11-08T20:01:21.783018Z" } }, "outputs": [], "source": [ "plot(model=gpr, kernel=kernel, n_prior_samples=2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T20:01:22.317355Z", "start_time": "2018-11-08T20:01:22.085788Z" } }, "outputs": [], "source": [ "kernel2 = gp.kernels.RBF(input_dim=1, variance=torch.tensor(10.),\n", "                         lengthscale=torch.tensor(.1))\n", "gpr2 = gp.models.GPRegression(X, y, kernel2, noise=torch.tensor(0.1))\n", "plot(model=gpr2, kernel=kernel2, n_prior_samples=2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T20:01:27.495066Z", "start_time": "2018-11-08T20:01:22.332208Z" } }, "outputs": [], "source": [ "optim = Adam({\"lr\": 0.005})\n", "svi = SVI(gpr.model, gpr.guide, optim, loss=Trace_ELBO())\n", "losses = []\n", "num_steps = 2500\n", "for i in range(num_steps):\n", "    losses.append(svi.step())" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T20:01:27.621640Z", "start_time": "2018-11-08T20:01:27.496986Z" } }, "outputs": [], "source": [ "plt.plot(losses)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T20:01:27.770366Z", "start_time": "2018-11-08T20:01:27.623396Z" } }, "outputs": [], "source": [ "plot(model=gpr, plot_observed_data=True, plot_predictions=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T20:01:27.776880Z", "start_time": "2018-11-08T20:01:27.771871Z" } }, "outputs": [], "source": [ "gpr.kernel.get_param(\"variance\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T20:01:27.871576Z", "start_time": "2018-11-08T20:01:27.778828Z" } }, "outputs": [], "source": [ 
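"# learned kernel lengthscale (the constrained value, after optimization)\n", 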
"gpr.kernel.get_param(\"lengthscale\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T20:01:27.964760Z", "start_time": "2018-11-08T20:01:27.873190Z" } }, "outputs": [], "source": [ "gpr.get_param(\"noise\").item()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T20:01:34.901832Z", "start_time": "2018-11-08T20:01:27.968651Z" } }, "outputs": [], "source": [ "gpr.kernel.set_prior(\"lengthscale\", dist.LogNormal(0.0, 1.0))\n", "gpr.kernel.set_prior(\"variance\", dist.LogNormal(0.0, 1.0))\n", "# we reset the param store so that the previous inference doesn't interfere with this one\n", "pyro.clear_param_store()\n", "optim = Adam({\"lr\": 0.005})\n", "svi = SVI(gpr.model, gpr.guide, optim, loss=Trace_ELBO())\n", "losses = []\n", "num_steps = 2500\n", "for i in range(num_steps):\n", " losses.append(svi.step())\n", "plt.plot(losses)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T20:01:35.089511Z", "start_time": "2018-11-08T20:01:34.903824Z" } }, "outputs": [], "source": [ "plot(model=gpr, plot_observed_data=True, plot_predictions=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T20:01:35.094757Z", "start_time": "2018-11-08T20:01:35.091420Z" } }, "outputs": [], "source": [ "for param_name in pyro.get_param_store().get_all_param_names():\n", " print('{} = {}'.format(param_name, pyro.param(param_name).item()))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T20:01:35.471124Z", "start_time": "2018-11-08T20:01:35.097106Z" } }, "outputs": [], "source": [ "N = 1000\n", "X = dist.Uniform(0.0, 5.0).sample(sample_shape=(N,))\n", "y = 0.5 * torch.sin(3*X) + dist.Normal(0.0, 0.2).sample(sample_shape=(N,))\n", "plot(plot_observed_data=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T20:01:46.852724Z", "start_time": "2018-11-08T20:01:35.472598Z" } }, "outputs": [], "source": [ "Xu = torch.arange(20, dtype=torch.float) / 4.0\n", "\n", "# initialize the kernel and model\n", "kernel = gp.kernels.RBF(input_dim=1)\n", "# we increase the jitter for better numerical stability\n", "sgpr = gp.models.SparseGPRegression(X, y, kernel, Xu=Xu, jitter=1.0e-5)\n", "\n", "# the way we setup inference is similar to above\n", "pyro.clear_param_store()\n", "optim = Adam({\"lr\": 0.005})\n", "svi = SVI(sgpr.model, sgpr.guide, optim, loss=Trace_ELBO())\n", "losses = []\n", "num_steps = 2500\n", "for i in range(num_steps):\n", " losses.append(svi.step())\n", "plt.plot(losses)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T20:01:47.031208Z", "start_time": "2018-11-08T20:01:46.860966Z" } }, "outputs": [], "source": [ "# let's look at the inducing points we've learned\n", "print(\"inducing points:\\n{}\".format(pyro.param(\"SGPR$$$Xu\").data.numpy()))\n", "# and plot the predictions from the sparse GP\n", "plot(model=sgpr, plot_observed_data=True, plot_predictions=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T20:01:54.080732Z", "start_time": "2018-11-08T20:01:47.033370Z" } }, "outputs": [], "source": [ "# initialize the inducing inputs\n", "Xu = torch.arange(10, dtype=torch.float) / 2.0\n", "\n", "# initialize the kernel, likelihood, and 
model\n", "kernel = gp.kernels.RBF(input_dim=1)\n", "likelihood = gp.likelihoods.Gaussian()\n", "# turn on \"whiten\" flag for more stable optimization\n", "vsgp = gp.models.VariationalSparseGP(X, y, kernel, Xu=Xu, likelihood=likelihood, whiten=True)\n", "\n", "pyro.clear_param_store()\n", "# instead of defining our own training loop, we will\n", "# use the built-in support provided by the GP module\n", "num_steps = 1500\n", "losses = vsgp.optimize(optimizer=Adam({\"lr\": 0.01}), num_steps=num_steps)\n", "plt.plot(losses)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-11-08T20:01:54.236350Z", "start_time": "2018-11-08T20:01:54.088078Z" } }, "outputs": [], "source": [ "plot(model=vsgp, plot_observed_data=True, plot_predictions=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.0" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "384px" }, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }