TransWikia.com

keras-yolov3のtrain.pyのコードについて

スタック・オーバーフロー Asked by chocora on December 9, 2020

前提

【物体検出】keras−yolo3の学習方法 を参考に、独自データにおけるYOLOv3の学習を行っています。

上記サイトで使用されているコード:
https://github.com/sleepless-se/keras-yolo3

すべてでなくても良いので教えていただけると助かります。

質問したいこと

1.
train.pyのコードの中で不明な点があります。
52行目から90行目の以下のコードの部分で、2回model.fit_generater()を行っているのはなぜでしょうか。2つの違いを教えていただきたいです。

また、それぞれのepochsとinitial_epochsの数値は何を指定しているのでしょうか。この状態だと50epochsずつの計100epochs繰り返すのですが、自分でepochs数を変えたい場合、どの数値を変更したらどう変わるのか教えていただきたいです。

 # Train with frozen layers first, to get a stable loss.
    # Adjust num epochs to your dataset. This step is enough to obtain a not bad model.
    if True:
        model.compile(optimizer=Adam(lr=1e-3), loss={
            # use custom yolo_loss Lambda layer.
            'yolo_loss': lambda y_true, y_pred: y_pred})

        batch_size = 32
        if len(sys.argv) > 2:
            batch_size = int(sys.argv[2])

        print('Train on {} samples, val on {} samples, with batch size {}.'.format(num_train, num_val, batch_size))
        model.fit_generator(data_generator_wrapper(lines[:num_train], batch_size, input_shape, anchors, num_classes),
                steps_per_epoch=max(1, num_train//batch_size),
                validation_data=data_generator_wrapper(lines[num_train:], batch_size, input_shape, anchors, num_classes),
                validation_steps=max(1, num_val//batch_size),
                epochs=50,
                initial_epoch=0,
                callbacks=[logging, checkpoint])
        model.save_weights(log_dir + 'trained_weights_stage_1.h5')

    # Unfreeze and continue training, to fine-tune.
    # Train longer if the result is not good.
    if True:
        for i in range(len(model.layers)):
            model.layers[i].trainable = True
        model.compile(optimizer=Adam(lr=1e-4), loss={'yolo_loss': lambda y_true, y_pred: y_pred}) # recompile to apply the change
        print('Unfreeze all of the layers.')

        batch_size = 32 # note that more GPU memory is required after unfreezing the body
        print('Train on {} samples, val on {} samples, with batch size {}.'.format(num_train, num_val, batch_size))
        model.fit_generator(data_generator_wrapper(lines[:num_train], batch_size, input_shape, anchors, num_classes),
            steps_per_epoch=max(1, num_train//batch_size),
            validation_data=data_generator_wrapper(lines[num_train:], batch_size, input_shape, anchors, num_classes),
            validation_steps=max(1, num_val//batch_size),
            epochs=100,
            initial_epoch=50,
            callbacks=[logging, checkpoint, reduce_lr, early_stopping])
        model.save_weights(log_dir + 'trained_weights_final.h5')

2.学習中のログに関して

このログを出力するコードは何行目に書かれているのでしょうか。また、デフォルトだとlossとval_lossが出力されているのですが、accuracyを一緒に表示するのはどうすればよいのでしょうか。

3.データ拡張に関して

170行目から192行目で定義されるdata_generatorの部分で学習データの水増しはされているのでしょうか。
また、されていない場合どこにどのようなコードを追加すればできるのでしょうか。

def data_generator(annotation_lines, batch_size, input_shape, anchors, num_classes):
    '''data generator for fit_generator'''
    n = len(annotation_lines)
    i = 0
    while True:
        image_data = []
        box_data = []
        for b in range(batch_size):
            if i==0:
                np.random.shuffle(annotation_lines)
            image, box = get_random_data(annotation_lines[i], input_shape, random=True)
            image_data.append(image)
            box_data.append(box)
            i = (i+1) % n
        image_data = np.array(image_data)
        box_data = np.array(box_data)
        y_true = preprocess_true_boxes(box_data, input_shape, anchors, num_classes)
        yield [image_data, *y_true], np.zeros(batch_size)

def data_generator_wrapper(annotation_lines, batch_size, input_shape, anchors, num_classes):
    n = len(annotation_lines)
    if n==0 or batch_size<=0: return None
    return data_generator(annotation_lines, batch_size, input_shape, anchors, num_classes)

Add your own answers!

Ask a Question

Get help from others!

© 2024 TransWikia.com. All rights reserved. Sites we Love: PCI Database, UKBizDB, Menu Kuliner, Sharing RPP