mirror of
https://github.com/Azure/MachineLearningNotebooks.git
synced 2025-12-20 01:27:06 -05:00
update samples from Release-53 as a part of SDK release
This commit is contained in:
@@ -21,9 +21,10 @@ image_size = 299
|
||||
num_channel = 3
|
||||
|
||||
|
||||
def get_class_label_dict():
|
||||
def get_class_label_dict(labels_dir):
|
||||
label = []
|
||||
proto_as_ascii_lines = tf.gfile.GFile("labels.txt").readlines()
|
||||
labels_path = os.path.join(labels_dir, 'labels.txt')
|
||||
proto_as_ascii_lines = tf.gfile.GFile(labels_path).readlines()
|
||||
for l in proto_as_ascii_lines:
|
||||
label.append(l.rstrip())
|
||||
return label
|
||||
@@ -34,14 +35,10 @@ def init():
|
||||
|
||||
parser = argparse.ArgumentParser(description="Start a tensorflow model serving")
|
||||
parser.add_argument('--model_name', dest="model_name", required=True)
|
||||
parser.add_argument('--labels_name', dest="labels_name", required=True)
|
||||
parser.add_argument('--labels_dir', dest="labels_dir", required=True)
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
workspace = Run.get_context(allow_offline=False).experiment.workspace
|
||||
label_ds = Dataset.get_by_name(workspace=workspace, name=args.labels_name)
|
||||
label_ds.download(target_path='.', overwrite=True)
|
||||
|
||||
label_dict = get_class_label_dict()
|
||||
label_dict = get_class_label_dict(args.labels_dir)
|
||||
classes_num = len(label_dict)
|
||||
|
||||
with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
|
||||
|
||||
Reference in New Issue
Block a user