AnswerBun.com

Проверка нейросети на своих данных tensorflow

Stack Overflow на русском Asked by Ylanaish on September 4, 2020

Есть обученная на данных mnist нейросеть. Когда загружаю свое изображение цифры, получаю следующую ошибку:

TypeError                                 Traceback (most recent call last)

<ipython-input-169-3810d7660ceb> in <module>()
      1 i = 0
      2 plt.figure()
----> 3 plot_images(i, predictions, test_labels, x)
      4 plt.show()

6 frames

/usr/local/lib/python3.6/dist-packages/matplotlib/image.py in set_data(self, A)
    697                 or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]):
    698             raise TypeError("Invalid shape {} for image data"
--> 699                             .format(self._A.shape))
    700 
    701         if self._A.ndim == 3:

TypeError: Invalid shape (28, 28, 1) for image data

На тестовых данных ошибки нет. Код пишу в google colab. Вот полный код:

import tensorflow as tf
from tensorflow import keras

import matplotlib.pyplot as plt
import numpy as np
from google.colab import files

(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()

train_images = train_images / 255.0
test_images = test_images / 255.0

model = keras.Sequential([
                          keras.layers.Flatten(input_shape = (28, 28)),
                          keras.layers.Dropout(0.2),
                          keras.layers.Dense(128, activation = 'relu'),
                          keras.layers.Dropout(0.2),
                          keras.layers.Dense(10, activation = 'softmax')
])
model.compile(optimizer = 'adam', loss = 'sparse_categorical_crossentropy', metrics = ['accuracy'])

model.fit(train_images, train_labels, epochs = 10)

files.upload()
images = keras.preprocessing.image.load_img("three.png", target_size=(28, 28))    
x = keras.preprocessing.image.img_to_array(images)
x = tf.image.rgb_to_grayscale(x)
x = np.expand_dims(x, axis=0)
x = x/255.0
img = []
img.append(x)

predictions = model.predict(img)

def plot_images(i, predictions_array, true_label, img):
  predictions_array, true_label, img = predictions_array[i], true_label[i], img[i]
  plt.grid(False)
  plt.xticks([])
  plt.yticks([])

  plt.imshow(img, cmap = plt.cm.binary)
  predictions_label = np.argmax(predictions_array)

  plt.xlabel('{}'.format(class_names[predictions_label]))

i = 0
plt.figure()
plot_images(i, predictions, test_labels, x)
plt.show()

Только начал изучать tensorflow, прошу помочь. Спасибо.

введите сюда описание изображения

Add your own answers!

Related Questions

Django, QuerySet

2  Asked on February 3, 2021

   

Ask a Question

Get help from others!

© 2023 AnswerBun.com. All rights reserved. Sites we Love: PCI Database, UKBizDB, Menu Kuliner, Sharing RPP