Compare commits
4 Commits
release_up
...
release_up
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
883e4a4c59 | ||
|
|
e90826b331 | ||
|
|
ac04172f6d | ||
|
|
8c0000beb4 |
@@ -94,6 +94,17 @@ def main():
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
|
||||
# Use Azure Open Datasets for MNIST dataset
|
||||
datasets.MNIST.resources = [
|
||||
("https://azureopendatastorage.azurefd.net/mnist/train-images-idx3-ubyte.gz",
|
||||
"f68b3c2dcbeaaa9fbdd348bbdeb94873"),
|
||||
("https://azureopendatastorage.azurefd.net/mnist/train-labels-idx1-ubyte.gz",
|
||||
"d53e105ee54ea40749a09fcbcd1e9432"),
|
||||
("https://azureopendatastorage.azurefd.net/mnist/t10k-images-idx3-ubyte.gz",
|
||||
"9fb629c4189551a2d022fa330f9573f3"),
|
||||
("https://azureopendatastorage.azurefd.net/mnist/t10k-labels-idx1-ubyte.gz",
|
||||
"ec29112dd5afa0611ce80d1b7f02629c")
|
||||
]
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
datasets.MNIST('data', train=True, download=True,
|
||||
transform=transforms.Compose([transforms.ToTensor(),
|
||||
|
||||
@@ -3,7 +3,11 @@ dependencies:
|
||||
- pip:
|
||||
- azureml-sdk
|
||||
- azureml-interpret
|
||||
- interpret-community[visualization]
|
||||
- flask
|
||||
- flask-cors
|
||||
- gevent>=1.3.6
|
||||
- jinja2
|
||||
- ipython
|
||||
- matplotlib
|
||||
- azureml-dataset-runtime
|
||||
- ipywidgets
|
||||
|
||||
@@ -3,6 +3,10 @@ dependencies:
|
||||
- pip:
|
||||
- azureml-sdk
|
||||
- azureml-interpret
|
||||
- interpret-community[visualization]
|
||||
- flask
|
||||
- flask-cors
|
||||
- gevent>=1.3.6
|
||||
- jinja2
|
||||
- ipython
|
||||
- matplotlib
|
||||
- ipywidgets
|
||||
|
||||
@@ -3,6 +3,10 @@ dependencies:
|
||||
- pip:
|
||||
- azureml-sdk
|
||||
- azureml-interpret
|
||||
- interpret-community[visualization]
|
||||
- flask
|
||||
- flask-cors
|
||||
- gevent>=1.3.6
|
||||
- jinja2
|
||||
- ipython
|
||||
- matplotlib
|
||||
- ipywidgets
|
||||
|
||||
@@ -3,7 +3,11 @@ dependencies:
|
||||
- pip:
|
||||
- azureml-sdk
|
||||
- azureml-interpret
|
||||
- interpret-community[visualization]
|
||||
- flask
|
||||
- flask-cors
|
||||
- gevent>=1.3.6
|
||||
- jinja2
|
||||
- ipython
|
||||
- matplotlib
|
||||
- azureml-dataset-runtime
|
||||
- azureml-core
|
||||
|
||||
@@ -51,6 +51,17 @@ if args.cuda:
|
||||
|
||||
|
||||
kwargs = {}
|
||||
# Use Azure Open Datasets for MNIST dataset
|
||||
datasets.MNIST.resources = [
|
||||
("https://azureopendatastorage.azurefd.net/mnist/train-images-idx3-ubyte.gz",
|
||||
"f68b3c2dcbeaaa9fbdd348bbdeb94873"),
|
||||
("https://azureopendatastorage.azurefd.net/mnist/train-labels-idx1-ubyte.gz",
|
||||
"d53e105ee54ea40749a09fcbcd1e9432"),
|
||||
("https://azureopendatastorage.azurefd.net/mnist/t10k-images-idx3-ubyte.gz",
|
||||
"9fb629c4189551a2d022fa330f9573f3"),
|
||||
("https://azureopendatastorage.azurefd.net/mnist/t10k-labels-idx1-ubyte.gz",
|
||||
"ec29112dd5afa0611ce80d1b7f02629c")
|
||||
]
|
||||
train_dataset = \
|
||||
datasets.MNIST('data-%d' % hvd.rank(), train=True, download=True,
|
||||
transform=transforms.Compose([
|
||||
|
||||
@@ -102,6 +102,17 @@ torch.manual_seed(args.seed)
|
||||
device = torch.device("cuda" if use_cuda else "cpu")
|
||||
|
||||
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
|
||||
# Use Azure Open Datasets for MNIST dataset
|
||||
datasets.MNIST.resources = [
|
||||
("https://azureopendatastorage.azurefd.net/mnist/train-images-idx3-ubyte.gz",
|
||||
"f68b3c2dcbeaaa9fbdd348bbdeb94873"),
|
||||
("https://azureopendatastorage.azurefd.net/mnist/train-labels-idx1-ubyte.gz",
|
||||
"d53e105ee54ea40749a09fcbcd1e9432"),
|
||||
("https://azureopendatastorage.azurefd.net/mnist/t10k-images-idx3-ubyte.gz",
|
||||
"9fb629c4189551a2d022fa330f9573f3"),
|
||||
("https://azureopendatastorage.azurefd.net/mnist/t10k-labels-idx1-ubyte.gz",
|
||||
"ec29112dd5afa0611ce80d1b7f02629c")
|
||||
]
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
datasets.MNIST('../data', train=True, download=True,
|
||||
transform=transforms.Compose([
|
||||
|
||||
@@ -332,6 +332,18 @@
|
||||
"import random\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"# Use Azure Open Datasets for MNIST dataset\n",
|
||||
"datasets.MNIST.resources = [\n",
|
||||
" (\"https://azureopendatastorage.azurefd.net/mnist/train-images-idx3-ubyte.gz\",\n",
|
||||
" \"f68b3c2dcbeaaa9fbdd348bbdeb94873\"),\n",
|
||||
" (\"https://azureopendatastorage.azurefd.net/mnist/train-labels-idx1-ubyte.gz\",\n",
|
||||
" \"d53e105ee54ea40749a09fcbcd1e9432\"),\n",
|
||||
" (\"https://azureopendatastorage.azurefd.net/mnist/t10k-images-idx3-ubyte.gz\",\n",
|
||||
" \"9fb629c4189551a2d022fa330f9573f3\"),\n",
|
||||
" (\"https://azureopendatastorage.azurefd.net/mnist/t10k-labels-idx1-ubyte.gz\",\n",
|
||||
" \"ec29112dd5afa0611ce80d1b7f02629c\")\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"test_data = datasets.MNIST('../data', train=False, transform=transforms.Compose([\n",
|
||||
" transforms.ToTensor(),\n",
|
||||
" transforms.Normalize((0.1307,), (0.3081,))]))\n",
|
||||
|
||||
Reference in New Issue
Block a user