mirror of
https://github.com/Azure/MachineLearningNotebooks.git
synced 2025-12-20 17:45:10 -05:00
Version 0.1.80
This commit is contained in:
@@ -59,6 +59,7 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs, data_dir):
|
||||
dataloaders, dataset_sizes, class_names = load_data(data_dir)
|
||||
|
||||
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
since = time.time()
|
||||
|
||||
best_model_wts = copy.deepcopy(model.state_dict())
|
||||
@@ -146,12 +147,15 @@ def fine_tune_model(num_epochs, data_dir, learning_rate, momentum):
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
# Observe that all parameters are being optimized
|
||||
optimizer_ft = optim.SGD(model_ft.parameters(), lr=learning_rate, momentum=momentum)
|
||||
optimizer_ft = optim.SGD(model_ft.parameters(),
|
||||
lr=learning_rate, momentum=momentum)
|
||||
|
||||
# Decay LR by a factor of 0.1 every 7 epochs
|
||||
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
|
||||
exp_lr_scheduler = lr_scheduler.StepLR(
|
||||
optimizer_ft, step_size=7, gamma=0.1)
|
||||
|
||||
model = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs, data_dir)
|
||||
model = train_model(model_ft, criterion, optimizer_ft,
|
||||
exp_lr_scheduler, num_epochs, data_dir)
|
||||
|
||||
return model
|
||||
|
||||
@@ -159,15 +163,19 @@ def fine_tune_model(num_epochs, data_dir, learning_rate, momentum):
|
||||
def main():
|
||||
# get command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--data_dir', type=str, help='directory of training data')
|
||||
parser.add_argument('--num_epochs', type=int, default=25, help='number of epochs to train')
|
||||
parser.add_argument('--data_dir', type=str,
|
||||
help='directory of training data')
|
||||
parser.add_argument('--num_epochs', type=int, default=25,
|
||||
help='number of epochs to train')
|
||||
parser.add_argument('--output_dir', type=str, help='output directory')
|
||||
parser.add_argument('--learning_rate', type=float, default=0.001, help='learning rate')
|
||||
parser.add_argument('--learning_rate', type=float,
|
||||
default=0.001, help='learning rate')
|
||||
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
|
||||
args = parser.parse_args()
|
||||
|
||||
print("data directory is: " + args.data_dir)
|
||||
model = fine_tune_model(args.num_epochs, args.data_dir, args.learning_rate, args.momentum)
|
||||
model = fine_tune_model(args.num_epochs, args.data_dir,
|
||||
args.learning_rate, args.momentum)
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
torch.save(model, os.path.join(args.output_dir, 'model.pt'))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user