AnswerBun.com

Проверка нейросети на своих данных tensorflow

Stack Overflow на русском Asked by Ylanaish on September 4, 2020

Есть обученная на данных mnist нейросеть. Когда загружаю свое изображение цифры, получаю следующую ошибку:

TypeError                                 Traceback (most recent call last)

<ipython-input-169-3810d7660ceb> in <module>()
      1 i = 0
      2 plt.figure()
----> 3 plot_images(i, predictions, test_labels, x)
      4 plt.show()

6 frames

/usr/local/lib/python3.6/dist-packages/matplotlib/image.py in set_data(self, A)
    697                 or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]):
    698             raise TypeError("Invalid shape {} for image data"
--> 699                             .format(self._A.shape))
    700 
    701         if self._A.ndim == 3:

TypeError: Invalid shape (28, 28, 1) for image data

На тестовых данных ошибки нет. Код пишу в google colab. Вот полный код:

import tensorflow as tf
from tensorflow import keras

import matplotlib.pyplot as plt
import numpy as np
from google.colab import files

(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()

train_images = train_images / 255.0
test_images = test_images / 255.0

model = keras.Sequential([
                          keras.layers.Flatten(input_shape = (28, 28)),
                          keras.layers.Dropout(0.2),
                          keras.layers.Dense(128, activation = 'relu'),
                          keras.layers.Dropout(0.2),
                          keras.layers.Dense(10, activation = 'softmax')
])
model.compile(optimizer = 'adam', loss = 'sparse_categorical_crossentropy', metrics = ['accuracy'])

model.fit(train_images, train_labels, epochs = 10)

files.upload()
images = keras.preprocessing.image.load_img("three.png", target_size=(28, 28))    
x = keras.preprocessing.image.img_to_array(images)
x = tf.image.rgb_to_grayscale(x)
x = np.expand_dims(x, axis=0)
x = x/255.0
img = []
img.append(x)

predictions = model.predict(img)

def plot_images(i, predictions_array, true_label, img):
  predictions_array, true_label, img = predictions_array[i], true_label[i], img[i]
  plt.grid(False)
  plt.xticks([])
  plt.yticks([])

  plt.imshow(img, cmap = plt.cm.binary)
  predictions_label = np.argmax(predictions_array)

  plt.xlabel('{}'.format(class_names[predictions_label]))

i = 0
plt.figure()
plot_images(i, predictions, test_labels, x)
plt.show()

Только начал изучать tensorflow, прошу помочь. Спасибо.

введите сюда описание изображения

Add your own answers!

Related Questions

Real time Laravel

0  Asked on January 2, 2021 by htmlprogrammer

       

Ошибка работы бота

1  Asked on January 1, 2021 by alexsn2020

       

Internal Server Error 500 Postman

1  Asked on January 1, 2021 by blacit

   

Ask a Question

Get help from others!

© 2022 AnswerBun.com. All rights reserved. Sites we Love: PCI Database, MenuIva, UKBizDB, Menu Kuliner, Sharing RPP