Позвольте мне вызвать функцию, которую я ищу " magic_combine
", которая может объединить непрерывные измерения тензора, которые я им даю. Для более конкретного, я хочу, чтобы он делал следующее:
a = torch.zeros(1, 2, 3, 4, 5, 6)
b = a.magic_combine(2, 5) # combine dimension 2, 3, 4
print(b.size()) # should be (1, 2, 60, 6)
Я знаю, что torch.view()
может делать аналогичную вещь. Но мне просто интересно, есть ли более элегантный способ достижения цели?
Я не уверен, что вы имеете в виду "более элегантный способ", но Tensor.view()
имеет преимущество не перераспределять данные для представления (исходный тензор и представление имеют одни и те же данные), что делает эту операцию довольно легкий вес.
Как уже упоминалось в @UmangGupta, однако довольно просто перенести эту функцию для достижения того, что вы хотите, например:
import torch
def magic_combine(x, dim_begin, dim_end):
combined_shape = list(x.shape[:dim_begin]) + [-1] + list(x.shape[dim_end:])
return x.view(combined_shape)
a = torch.zeros(1, 2, 3, 4, 5, 6)
b = magic_combine(a, 2, 5) # combine dimension 2, 3, 4
print(b.size())
# torch.Size([1, 2, 60, 6])