Version 0.1.80

This commit is contained in:
rastala
2018-11-20 11:00:48 -05:00
parent d10b1fa796
commit 5726fe3ddb
61 changed files with 23040 additions and 24681 deletions

View File

@@ -59,6 +59,7 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs, data_dir):
dataloaders, dataset_sizes, class_names = load_data(data_dir)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
since = time.time()
best_model_wts = copy.deepcopy(model.state_dict())
@@ -146,12 +147,15 @@ def fine_tune_model(num_epochs, data_dir, learning_rate, momentum):
criterion = nn.CrossEntropyLoss()
# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=learning_rate, momentum=momentum)
optimizer_ft = optim.SGD(model_ft.parameters(),
lr=learning_rate, momentum=momentum)
# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
exp_lr_scheduler = lr_scheduler.StepLR(
optimizer_ft, step_size=7, gamma=0.1)
model = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs, data_dir)
model = train_model(model_ft, criterion, optimizer_ft,
exp_lr_scheduler, num_epochs, data_dir)
return model
@@ -159,15 +163,19 @@ def fine_tune_model(num_epochs, data_dir, learning_rate, momentum):
def main():
# get command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=str, help='directory of training data')
parser.add_argument('--num_epochs', type=int, default=25, help='number of epochs to train')
parser.add_argument('--data_dir', type=str,
help='directory of training data')
parser.add_argument('--num_epochs', type=int, default=25,
help='number of epochs to train')
parser.add_argument('--output_dir', type=str, help='output directory')
parser.add_argument('--learning_rate', type=float, default=0.001, help='learning rate')
parser.add_argument('--learning_rate', type=float,
default=0.001, help='learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
args = parser.parse_args()
print("data directory is: " + args.data_dir)
model = fine_tune_model(args.num_epochs, args.data_dir, args.learning_rate, args.momentum)
model = fine_tune_model(args.num_epochs, args.data_dir,
args.learning_rate, args.momentum)
os.makedirs(args.output_dir, exist_ok=True)
torch.save(model, os.path.join(args.output_dir, 'model.pt'))