Я загружаю изображения серой шкалы партиями с помощью ImageDataGenerator
. Мне нужно скопировать содержимое каждого изображения шкалы серого в 3 канала. Я пробовал следующий код, но кажется, что он не работает:
def grayscale_to_rgb(images, channel_axis=-1):
images= K.expand_dims(images, axis=channel_axis)
tiling = [1] * 4 # 4 dimensions: B, H, W, C
tiling[channel_axis] *= 3
images= K.tile(images, tiling)
return images
train_images_orign= grayscale_to_rgb(train_images_orign)
valid_images_orign= grayscale_to_rgb(valid_images_orign)
test_images_orign= grayscale_to_rgb(test_images_orign)
x_train, y_train = next(train_images_orign)
x_valid, y_valid = next(valid_images_orign)
x_test, y_test = next(test_images_orign)
В каком направлении я должен попытаться это сделать?
Обновление. Оказывается, что функция load_img
в load_img
реализована таким образом, что если цветовой режим загружаемого изображения и заданный аргумент color_mode
(который по умолчанию равен 'RGB'
), то оно будет преобразовано в заданный color_mode
. В результате в этом случае изображение в градациях серого будет автоматически преобразовано в RGB.
Вы также можете использовать аргумент preprocessing_function
для ImageDataGenerator
(при условии, что вы используете color_mode='grayscale'
, в противном случае применяется примечание выше):
import numpy as np
def gray_to_rgb(img):
return np.repeat(img, 3, 2)
generator = ImageDataGenerator(..., preprocessing_function=gray_to_rgb)
train_gen = generator.flow_from_directory(color_mode='grayscale', ...)
Просто обратите внимание, что эта функция применяется после любого увеличения изображения:
preprocessing_function: функция, которая будет подразумеваться на каждом входе. Функция будет работать после изменения размера изображения и его увеличения. Функция должна принимать один аргумент: одно изображение (тензор Numpy с рангом 3) и должно выводить тензор Numpy с одинаковой формой.
Я думаю, что у меня есть лучшее решение, которое заключается в том, чтобы написать оберточный слой
class MyPreprocess( Layer ) :
def call( self, inputs ) :
# expand your input from gray scale to rgb
# if your inputs.shape = (None,None,1)
fake_rgb = K.concatenate( [inputs for i in range(3)], axis=-1 )
fake_rgb = K.cast( fake_rgb, 'float32' )
# else use K.stack( [inputs for i in range(3)], axis=-1 )
# preprocess for uint8 image
x = preprocess_input( fake_rgb )
return x
def compute_output_shape( self, input_shape ) :
return input_shape[:3] + (3,)
gray_in = Input(shape=(None,None,1), name='gray_uint8')
tensor_in = MyPreprocess(name='preproc')( gray_in )
pred_out = PretrainedModel( tensor_in )
new_model = Model( inputs=gray_in, outputs=pred_out )
Таким образом, new_model
может быть непосредственно использован для прогнозирования изображения unit8
.
У меня нет прямого способа конвертировать его с помощью ImageDataGenerator, но у меня была такая же проблема несколько дней назад, и косвенно вы можете использовать opencv2 для преобразования его в rgb, тогда я прочитал его с использованием imageio to numpy.
import cv2
cv2.imread('path/to/img/a.png') # look at glob for reading from folder
cv2.cvtColor(gray,cv2.COLOR_GRAY2RGB)
cv2.imwrite('path/to/output/a.png')
import imageio
import skimage.transform as transform
image_size = 64
dimensions=4
array = []
for image_path in glob.glob("path/to/output/*.png"):
try:
im = imageio.imread(image_path)
array = [*array, transform.resize(im, (image_size, image_size, dimensions))]
array_names = [*array_names, image_path.split("output/")[1].split(".png")[0]]
except ValueError:
""
np.array(array).shape
Expand_dims можно использовать следующим образом с помощью методов-методах tenorflow 1.10.1, которые тесно связаны с ImageDataGenerator, чтобы сделать то же самое, но вам придется каким-то образом преобразовать его из gray2rgb, что может улучшить этот ответ.
preprocess_input = tf.keras.applications.resnet50.preprocess_input
preds = model.predict(preprocess_input(np.expand_dims(array[0], axis=0))) # where model is some keras model
Для декодирования вывода потребуется следующее:
decode_predictions = tf.keras.applications.resnet50.decode_predictions
decode_predictions(preds, top=3)
Если это не решит вашу проблему или по крайней мере предоставит шаблон для решений, оставьте комментарий, и я обновлю его соответственно :)