本文整理汇总了Python中torch.unsqueeze函数的典型用法代码示例。如果您正苦于以下问题:Python unsqueeze函数的具体用法?Python unsqueeze怎么用?Python unsqueeze使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了unsqueeze函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: __call__
def __call__(self, grid):
    """Prepend normalized coordinate channels (CoordConv) to a 5-D grid.

    Args:
        grid: Tensor of shape (batch, channels, X, Y, Z).

    Returns:
        Tensor of shape (batch, 3 + channels, X, Y, Z), or
        (batch, 4 + channels, X, Y, Z) when ``self.with_r`` is true.
        Channels 0-2 hold the x/y/z index of every cell mapped linearly onto
        [-1, 1]; the optional channel 3 holds the distance from the origin
        divided by its maximum.
    """
    batch_size, _, grid_dimX, grid_dimY, grid_dimZ = grid.size()
    k = 1.0
    full_shape = (grid_dimX, grid_dimY, grid_dimZ)

    def _axis(n, view_shape):
        # Map indices 0..n-1 onto [-1, 1]. For n == 1 the old (n - 1)
        # denominator was zero and produced NaN; clamping it to 1 puts the
        # single sample at -1, which matches torch.linspace(-1, 1, 1).
        denom = max(n - 1.0, 1.0)
        ramp = torch.arange(n, dtype=torch.float32, device=grid.device)
        return (2.0 * k * ramp / denom - 1.0).view(view_shape).expand(full_shape)

    x_coords = _axis(grid_dimX, (-1, 1, 1))
    y_coords = _axis(grid_dimY, (1, -1, 1))
    z_coords = _axis(grid_dimZ, (1, 1, -1))
    coords = torch.stack((x_coords, y_coords, z_coords), dim=0)
    if self.with_r:
        rs = ((x_coords ** 2) + (y_coords ** 2) + (z_coords ** 2)) ** 0.5
        rs = k * rs / torch.max(rs)
        rs = torch.unsqueeze(rs, dim=0)
        coords = torch.cat((coords, rs), dim=0)
    # One copy of the coordinate channels per batch element.
    coords = torch.unsqueeze(coords, dim=0).repeat(batch_size, 1, 1, 1, 1)
    grid = torch.cat((coords.to(grid.device), grid), dim=1)
    return grid
开发者ID:caskeep,项目名称:3D-SIS,代码行数:25,代码来源:coord_conv3d.py
示例2: bootstrapped_cross_entropy2d
def bootstrapped_cross_entropy2d(input, target, K, weight=None, size_average=True):
    """Bootstrapped (top-K hardest pixel) cross-entropy for 2-D segmentation.

    Each image in the batch is handled independently. Its per-pixel NLL
    losses are computed, the ``K`` largest are averaged, and the per-image
    values are then averaged over the batch.

    Args:
        input: Logits of shape (N, C, H, W).
        target: Integer labels of shape (N, H, W). Pixels with a negative
            label are dropped. Pixels labelled 250 (``ignore_index``)
            contribute zero loss but still occupy a top-K slot.
        K: Number of hardest pixels kept per image. It is clamped to the
            number of available pixels, so small or sparsely labelled images
            no longer make ``topk`` raise.
        weight: Optional per-class weight forwarded to ``F.nll_loss``.
        size_average: Unused; kept for backward compatibility.

    Returns:
        Scalar tensor with the batch-averaged bootstrapped loss.
    """
    batch_size = input.size()[0]

    def _bootstrap_xentropy_single(input, target, K, weight=None, size_average=True):
        n, c, h, w = input.size()
        log_p = F.log_softmax(input, dim=1)
        # (n, c, h, w) -> (n*h*w, c): one row of class log-probs per pixel.
        log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
        # Keep only rows whose label is non-negative.
        log_p = log_p[target.view(n * h * w, 1).repeat(1, c) >= 0]
        log_p = log_p.view(-1, c)
        mask = target >= 0
        target = target[mask]
        # reduction='none' replaces the deprecated reduce=False/size_average pair.
        loss = F.nll_loss(log_p, target, weight=weight, ignore_index=250,
                          reduction='none')
        k = min(K, loss.numel())
        if k == 0:
            # No labelled pixels: contribute zero while keeping the graph connected.
            return loss.sum()
        topk_loss, _ = loss.topk(k)
        return topk_loss.sum() / k

    loss = 0.0
    # Bootstrap from each image not entire batch
    for i in range(batch_size):
        loss += _bootstrap_xentropy_single(input=torch.unsqueeze(input[i], 0),
                                           target=torch.unsqueeze(target[i], 0),
                                           K=K,
                                           weight=weight,
                                           size_average=size_average)
    return loss / float(batch_size)
开发者ID:clavichord93,项目名称:pytorch-semseg,代码行数:29,代码来源:loss.py
示例3: forward
def forward(self, image_feat_variable,
            input_question_variable, input_answers=None, **kwargs):
    """Score answers from image features attended by the question.

    Every model in ``self.question_embedding_models`` embeds the question,
    and the embeddings are concatenated along dim 1. Image features are
    pooled with weights from ``self.image_attention_model``.
    ``image_feat_variable`` may be one batched tensor or a list holding one
    feature tensor per sample. The question and image embeddings pass
    through their own non-linear layers, are multiplied elementwise, and
    the product is fed to ``self.classifier``.

    ``input_answers`` and ``kwargs`` are accepted but unused.
    """
    question_embedding = torch.cat(
        [q_model(input_question_variable) for q_model in self.question_embedding_models],
        dim=1)

    if not isinstance(image_feat_variable, list):
        attention = self.image_attention_model(image_feat_variable, question_embedding)
        image_embedding = torch.sum(attention * image_feat_variable, dim=1)
    else:
        # Variable-size features: attend one sample at a time, using a
        # batch of one for both the question row and the features.
        pooled = []
        for position, feat in enumerate(image_feat_variable):
            question_row = question_embedding[position, :].unsqueeze(0)
            weights = self.image_attention_model(feat.unsqueeze(0), question_row)
            pooled.append(torch.sum(weights * feat, dim=1))
        image_embedding = torch.cat(pooled, dim=0)

    joint = self.nonLinear_question(question_embedding) * self.nonLinear_image(image_embedding)
    return self.classifier(joint)
开发者ID:xiaojie18,项目名称:pythia,代码行数:31,代码来源:top_down_bottom_up_model.py
示例4: ycrcb_to_rgb_torch
def ycrcb_to_rgb_torch(input_tensor, delta=0.5):
    """Convert a YCrCb image batch laid out as (N, 3, H, W) to RGB.

    Channel 0 is read as Y, channel 1 as Cr and channel 2 as Cb. ``delta``
    is subtracted from both chroma channels before the linear transform
    (the 0.5 default presumably targets values in [0, 1]).

    Returns:
        Tensor of shape (N, 3, H, W) with channels ordered R, G, B.
    """
    luma = input_tensor[:, 0, :, :]
    cr_centered = input_tensor[:, 1, :, :] - delta
    cb_centered = input_tensor[:, 2, :, :] - delta
    red = luma + 1.403 * cr_centered
    green = luma - 0.714 * cr_centered - 0.344 * cb_centered
    blue = luma + 1.773 * cb_centered
    return torch.stack((red, green, blue), dim=1)
开发者ID:stereomatchingkiss,项目名称:blogCodes2,代码行数:7,代码来源:color_converter.py
示例5: predict
def predict(self, wm, s, a, ls):
    """Classify a batch with the aspect-conditioned BiLSTM, without gradients.

    Args:
        wm: Weight matrix handed to ``create_emb_layer``; ``self.embedding``
            is rebuilt from it on every call.
        s: Sentence token ids. Positions 0..39 of every row are read, so
            presumably shape (batch, >= 40) -- TODO confirm.
        a: Aspect token ids, one row per sentence.
        ls: Per-sentence lengths. Positions ``j < ls[i] - 1`` get the
            averaged aspect vector appended; later positions get zeros.

    Returns:
        ``self.softmax(self.fc(...))`` applied to the last two layers of the
        LSTM's final hidden state, concatenated.
    """
    # The original called .cuda() unconditionally and so crashed on
    # CPU-only machines; fall back to CPU when CUDA is unavailable.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    max_len = 40       # sentence positions fed to the LSTM
    aspect_dim = 100   # width of a single word/aspect embedding
    with torch.no_grad():
        self.embedding, _ = create_emb_layer(wm)
        s_embedded = self.embedding(s)
        a_embedded = self.embedding(a)
        # Average the aspect embedding
        a_new_embedded = torch.zeros(len(s), 1, aspect_dim)
        for i in range(len(a_embedded)):
            # NOTE(review): torch.nonzero on the 2-D (tokens x dim) aspect
            # embedding counts non-zero *elements*, not non-padding tokens,
            # so this "average" is scaled down by roughly the embedding
            # width. Kept as-is because trained weights may depend on it.
            n_nonzero = len(torch.nonzero(a_embedded[i]))
            if n_nonzero:
                a_new_embedded[i] = torch.unsqueeze(torch.sum(a_embedded[i], 0) / n_nonzero, 0)
        a_embedded = a_new_embedded
        embedded = torch.zeros(len(s), max_len, 2 * aspect_dim, device=device)
        # Concatenate each word in sentence with aspect vector
        zero_tag = torch.zeros(aspect_dim, device=device)
        for i in range(len(s_embedded)):
            aspect_vec = torch.squeeze(a_embedded[i].to(device), 0)
            for j in range(max_len):
                tag = aspect_vec if j < (ls[i] - 1) else zero_tag
                embedded[i][j] = torch.cat((s_embedded[i][j].to(device), tag), 0)
        out, (h, c) = self.lstm(embedded)
        # Last two layers of the final hidden state (forward + backward for
        # a single-layer bidirectional LSTM -- assumption, confirm config).
        hidden = self.dropout(torch.cat((h[-2, :, :], h[-1, :, :]), dim=1))
        hidden2pred = self.fc(hidden)
        pred = self.softmax(hidden2pred)
    return pred
开发者ID:bearcave9,项目名称:Weekend-Projects,代码行数:30,代码来源:absa_models.py
示例6: _morph_face
def _morph_face(self, face, expresion):
    """Render ``face`` under the expression condition ``expresion``.

    The face array goes through ``PIL.Image.fromarray`` and
    ``self._transform``. The expression is divided by 5.0, and both are
    given a leading batch axis of 1. The same condition is used as the real
    and the desired condition, and the ``'concat'`` visual from
    ``self._model.forward`` is returned.
    """
    face_batch = self._transform(Image.fromarray(face)).unsqueeze(0)
    cond_batch = torch.from_numpy(expresion / 5.0).unsqueeze(0)
    model_input = {
        'real_img': face_batch,
        'real_cond': cond_batch,
        'desired_cond': cond_batch,
        'sample_id': torch.FloatTensor(),
        'real_img_path': [],
    }
    self._model.set_input(model_input)
    visuals, _ = self._model.forward(keep_data_for_visuals=False, return_estimates=True)
    return visuals['concat']
开发者ID:iGuaZi,项目名称:GANimation,代码行数:7,代码来源:test.py
示例7: outer
def outer(vec1, vec2=None):
    '''Batch support for vectors outer products.

    This function is broadcast-able, so you can provide batched vec1 or
    batched vec2 or both.

    Args:
        vec1: A vector of size (Size1,) or batched (..., Size1).
        vec2: A vector of size (Size2,) or batched (..., Size2);
            if vec2 is None, vec2 = vec1.

    Returns:
        The outer product of vec1 and vec2 with shape (..., Size1, Size2).
        The leading (batch) dimensions are broadcast together, e.g.
        (Batch, Size1, Size2) for 2-D inputs and (Size1, Size2) for 1-D ones.
    '''
    if vec2 is None:
        vec2 = vec1
    # Each output element is a single product, so a broadcast elementwise
    # multiply is exact. It covers every case of the old ger/bmm/matmul
    # dispatch, and also batches with more than one leading dim, where
    # torch.bmm raised.
    return vec1.unsqueeze(-1) * vec2.unsqueeze(-2)
开发者ID:ModarTensai,项目名称:network_moments,代码行数:29,代码来源:ops.py
示例8: forward
def forward(self, x):
    """Run the Inception-v3 trunk and classifier on an image batch.

    Shape notes are H x W x C for the standard 299 x 299 x 3 input. When
    ``self.transform_input`` is set, each channel ``v`` becomes
    ``v * (std / 0.5) + (mean - 0.5) / 0.5`` with the per-channel constants
    below. That undoes a ``(v - mean) / std`` normalization and re-applies
    ``(v - 0.5) / 0.5``.

    Returns:
        ``(logits, aux)`` when training with ``self.aux_logits`` enabled,
        otherwise ``logits``.
    """
    if self.transform_input:
        channel_stats = ((0.229, 0.485), (0.224, 0.456), (0.225, 0.406))  # (std, mean)
        x = torch.cat(
            [torch.unsqueeze(x[:, idx], 1) * (std / 0.5) + (mean - 0.5) / 0.5
             for idx, (std, mean) in enumerate(channel_stats)],
            1)

    def run_layers(layer_names, h):
        # Apply the named sub-modules of self in order.
        for layer_name in layer_names:
            h = getattr(self, layer_name)(h)
        return h

    # 299x299x3 -> 149x149x32 -> 147x147x32 -> 147x147x64, pool -> 73x73x64
    x = run_layers(('Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3'), x)
    x = F.max_pool2d(x, kernel_size=3, stride=2)
    # 73x73x64 -> 73x73x80 -> 71x71x192, pool -> 35x35x192
    x = run_layers(('Conv2d_3b_1x1', 'Conv2d_4a_3x3'), x)
    x = F.max_pool2d(x, kernel_size=3, stride=2)
    # 35x35x192 -> 35x35x256 -> 35x35x288 -> 35x35x288 -> 17x17x768 (x5)
    x = run_layers(('Mixed_5b', 'Mixed_5c', 'Mixed_5d',
                    'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 'Mixed_6e'), x)
    if self.training and self.aux_logits:
        aux = self.AuxLogits(x)  # auxiliary head on the 17x17x768 map
    # 17x17x768 -> 8x8x1280 -> 8x8x2048 -> 8x8x2048
    x = run_layers(('Mixed_7a', 'Mixed_7b', 'Mixed_7c'), x)
    # 8x8 average -> 1x1x2048, dropout, flatten to 2048, classify
    x = F.avg_pool2d(x, kernel_size=8)
    x = F.dropout(x, training=self.training)
    logits = self.fc(x.view(x.size(0), -1))
    if self.training and self.aux_logits:
        return logits, aux
    return logits
开发者ID:choasUp,项目名称:Vision,代码行数:57,代码来源:inception.py
示例9: forward
def forward(self, output, target):
    """Focal loss with a fixed alpha of 0.25 and focusing exponent ``self.y``.

    Args:
        output: Logits with classes on dim 1, e.g. (N, C) or (N, C, H, W).
        target: Class indices shaped like ``output`` without dim 1.

    Returns:
        Scalar mean of ``alpha * (1 - p_t) ** self.y * -log(p_t)``, where
        ``p_t`` is the softmax probability of the target class.
    """
    # The class dim must be 1 because targets are gathered along dim 1.
    # The old implicit-dim calls are deprecated and chose dim 0 for 3-D
    # input, silently normalizing across the batch instead.
    P = F.softmax(output, dim=1)
    f_out = F.log_softmax(output, dim=1)
    index = torch.unsqueeze(target, 1)
    Pt = P.gather(1, index)
    focus_p = torch.pow(1 - Pt, self.y)
    alpha = 0.25
    nll_feature = -f_out.gather(1, index)
    weight_nll = alpha * focus_p * nll_feature
    loss = weight_nll.mean()
    return loss
开发者ID:shubhampachori12110095,项目名称:pytorch-cv,代码行数:10,代码来源:seg_modules.py
示例10: _mask_attentions
def _mask_attentions(attention, image_locs):
    """Zero attention weights at locations past each sample's valid count.

    Args:
        attention: Tensor of shape (batch_size, num_loc, n_att).
        image_locs: Per-sample location counts, shape (batch_size,).
            Location ``i`` of sample ``b`` is zeroed when ``i >= image_locs[b]``.

    Returns:
        ``attention`` itself, modified in place.
    """
    batch_size, num_loc, n_att = attention.data.shape
    # Build the position grid on attention's own device. The old code
    # consulted the module-global ``use_cuda`` flag, which fails whenever
    # that flag and the tensor's real device disagree.
    positions = torch.arange(0, num_loc, dtype=torch.long, device=attention.device)
    positions = torch.unsqueeze(positions, dim=0).expand(batch_size, num_loc)
    limits = torch.unsqueeze(image_locs.data, 1).expand(batch_size, num_loc)
    mask = torch.ge(positions, limits)
    mask = torch.unsqueeze(mask, 2).expand_as(attention)
    # NOTE(review): filling through ``.data`` hides the write from autograd,
    # so gradients still reach the masked positions. Kept for compatibility;
    # an out-of-place masked_fill would change the in-place contract.
    attention.data.masked_fill_(mask, 0)
    return attention
开发者ID:xiaojie18,项目名称:pythia,代码行数:11,代码来源:image_attention.py
示例11: run
def run(self):
    """Train the agent for ``NUM_EPISODE`` episodes of at most ``MAX_STEPS`` steps.

    Each transition is stored with ``self.agent.memory`` and followed by
    ``self.agent.update_q_function()``. The reward is 0 on non-final steps.
    When an episode ends, it is -1 if ``step`` reached 199 or more and +1
    otherwise. Each finished episode's step count is printed and written to
    ``result.log``. The log file and ``self.env`` are closed even if
    training raises.
    """
    complete_episodes = 0  # consecutive episodes that ended before step 199
    episode_final = False  # when True: render steps, stop after this episode
    with open('result.log', 'w') as output:
        try:
            print(self.num_states, self.num_actions)
            for episode in range(NUM_EPISODE):
                observation = self.env.reset()
                state = torch.from_numpy(observation).type(torch.FloatTensor)
                state = torch.unsqueeze(state, 0)  # add a batch axis
                for step in range(MAX_STEPS):
                    if episode_final:
                        self.env.render(mode='rgb_array')
                    action = self.agent.get_action(state, episode)
                    observation_next, _, done, _ = self.env.step(action.item())
                    state_next = torch.from_numpy(observation_next).type(torch.FloatTensor)
                    state_next = torch.unsqueeze(state_next, 0)
                    reward = torch.FloatTensor([0.0])
                    if done:
                        # A terminal transition has no successor state.
                        state_next = None
                        if 199 <= step:
                            reward = torch.FloatTensor([-1.0])
                            complete_episodes = 0
                        else:
                            reward = torch.FloatTensor([1.0])
                            complete_episodes = complete_episodes + 1
                    self.agent.memory(state, action, state_next, reward)
                    self.agent.update_q_function()
                    state = state_next
                    if done:
                        message = 'episode: {0}, step: {1}'.format(episode, step)
                        print(message)
                        output.write(message + '\n')
                        break
                if episode_final:
                    break
                if 10 <= complete_episodes:
                    print('success 10 times in sequence')
                    # Setting episode_final = True here would render one
                    # more episode and then stop; intentionally left off.
        finally:
            self.env.close()
开发者ID:y-kamiya,项目名称:machine-learning-samples,代码行数:52,代码来源:mountaincar_dqn.py
示例12: forward
def forward(self, img, qst):
    """Relational-network forward pass over every pair of feature-map cells.

    Args:
        img: Image batch for ``self.conv``. The conv output is treated as
            square, (mb, n_channels, d, d), e.g. (64 x 24 x 5 x 5).
        qst: Question encoding of shape (mb, q_dim).

    Returns:
        ``self.fcout`` applied to the f-network output computed from the
        g-features summed over all (d*d)^2 cell pairs; the batch axis is
        kept even when mb == 1.
    """
    x = self.conv(img)
    # --- g: pairwise relations ---
    mb = x.size()[0]
    n_channels = x.size()[1]
    d = x.size()[2]
    n_cells = d * d
    # (mb, n_channels, d, d) -> (mb, n_cells, n_channels)
    x_flat = x.view(mb, n_channels, n_cells).permute(0, 2, 1)
    # add coordinates
    x_flat = torch.cat([x_flat, self.coord_tensor], 2)
    # add question everywhere: (mb, q_dim) -> (mb, n_cells, 1, q_dim)
    qst = torch.unsqueeze(qst, 1)
    qst = qst.repeat(1, n_cells, 1)
    qst = torch.unsqueeze(qst, 2)
    # cast all pairs against each other
    x_i = torch.unsqueeze(x_flat, 1)              # (mb, 1, n_cells, F)
    x_i = x_i.repeat(1, n_cells, 1, 1)            # (mb, n_cells, n_cells, F)
    x_j = torch.unsqueeze(x_flat, 2)              # (mb, n_cells, 1, F)
    x_j = torch.cat([x_j, qst], 3)                # (mb, n_cells, 1, F + q_dim)
    x_j = x_j.repeat(1, 1, n_cells, 1)            # (mb, n_cells, n_cells, F + q_dim)
    # concatenate all together: (mb, n_cells, n_cells, 2F + q_dim)
    x_full = torch.cat([x_i, x_j], 3)
    # Derive the pair-feature width from the tensors instead of the old
    # hard-coded 25 / 63 / 256, so other grid sizes and widths also work.
    x_ = x_full.view(mb * n_cells * n_cells, x_full.size(3))
    x_ = F.relu(self.g_fc1(x_))
    x_ = F.relu(self.g_fc2(x_))
    x_ = F.relu(self.g_fc3(x_))
    x_ = F.relu(self.g_fc4(x_))
    # reshape again and sum over pairs -> (mb, hidden). The old trailing
    # .squeeze() also removed the batch axis when mb == 1.
    x_g = x_.view(mb, n_cells * n_cells, x_.size(1))
    x_g = x_g.sum(1)
    # --- f: answer head ---
    x_f = F.relu(self.f_fc1(x_g))
    return self.fcout(x_f)
开发者ID:SikaStar,项目名称:relational-networks,代码行数:48,代码来源:model.py
示例13: main
def main():
    """Fit ``RegreNN`` to noisy samples of y = x**2 and animate the fit.

    Runs 600 SGD steps (lr 0.05) on 100 points evenly spaced in [-1, 1].
    Every 100 steps it redraws the data, the current prediction and the
    loss.
    """
    x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # (100, 1) column
    y = x.pow(2) + 0.2 * torch.rand(x.size())
    x = Variable(x)
    y = Variable(y)
    net = RegreNN(1, 1)
    optm = torch.optim.SGD(net.parameters(), lr=0.5e-1)
    loss_func = torch.nn.MSELoss()
    plt.ion()
    for i in range(600):
        v = net(x)
        loss = loss_func(v, y)
        optm.zero_grad()
        loss.backward()
        optm.step()
        if i % 100 == 0:
            print(loss)
            plt.cla()
            plt.scatter(x.data.numpy(), y.data.numpy())
            plt.plot(x.data.numpy(), v.data.numpy(), 'r-', lw=5)
            # loss is a 0-dim tensor: loss.data[0] raises on modern PyTorch,
            # so read the Python scalar with .item().
            plt.text(0.5, 0, 'Loss=%.4f' % loss.item(), fontdict={'size': 20, 'color': 'red'})
            plt.pause(0.1)
    plt.ioff()
    plt.show()
开发者ID:chenjianhong,项目名称:machineleaning,代码行数:29,代码来源:nn_regression.py
示例14: update_parameters
def update_parameters(self, batch):
    """Run one DDPG update of the critic, the actor and both target networks.

    Args:
        batch: Transition batch whose ``state``, ``next_state``, ``action``,
            ``reward`` and ``mask`` fields are sequences of tensors that are
            concatenated along dim 0.
    """
    state_batch = Variable(torch.cat(batch.state))
    action_batch = Variable(torch.cat(batch.action))
    reward_batch = Variable(torch.cat(batch.reward))
    # NOTE(review): mask_batch is built but never used, so the TD target
    # below bootstraps from next_state even on terminal transitions. If mask
    # is 1 for non-terminal and 0 for terminal, the target should be
    # reward + gamma * mask * Q'(s', a') -- confirm before changing.
    mask_batch = Variable(torch.cat(batch.mask))

    # Compute the TD target without autograd. Variable(..., volatile=True)
    # has had no effect since PyTorch 0.4, so the old code recorded the
    # target-network forward passes in the graph.
    with torch.no_grad():
        next_state_batch = torch.cat(batch.next_state)
        next_action_batch = self.actor_target(next_state_batch)
        next_state_action_values = self.critic_target(next_state_batch, next_action_batch)
        reward_batch = torch.unsqueeze(reward_batch, 1)
        expected_state_action_batch = reward_batch + (self.gamma * next_state_action_values)

    # Critic step: regress Q(s, a) onto the TD target.
    self.critic_optim.zero_grad()
    state_action_batch = self.critic(state_batch, action_batch)
    value_loss = MSELoss(state_action_batch, expected_state_action_batch)
    value_loss.backward()
    self.critic_optim.step()

    # Actor step: minimize -Q(s, actor(s)).
    self.actor_optim.zero_grad()
    policy_loss = -self.critic(state_batch, self.actor(state_batch))
    policy_loss = policy_loss.mean()
    policy_loss.backward()
    self.actor_optim.step()

    # Move the target networks toward the online ones (soft_update is
    # defined elsewhere; presumably Polyak averaging with rate tau).
    soft_update(self.actor_target, self.actor, self.tau)
    soft_update(self.critic_target, self.critic, self.tau)
开发者ID:GuanyuGao,项目名称:thermal_project,代码行数:31,代码来源:ddpg.py
示例15: loss
def loss(anchors, data, pred, threshold):
    """Compute the per-term detection losses for one batch.

    Args:
        anchors: Anchor boxes; moved to the device of ``pred['iou']`` with
            ``utils.ensure_device`` before use.
        data: Ground-truth dict; its ``'yx_min'`` and ``'yx_max'`` entries
            are passed to ``iou_match`` and ``fit_positive``.
        pred: Network outputs. Keys read here: ``'iou'``, ``'feature'``
            (only its last two sizes, used as rows/cols), ``'yx_min'``,
            ``'yx_max'``, ``'center_offset'``, ``'size_norm'`` and, if
            present, ``'logits'``.
        threshold: Matched-IoU cutoff below which a non-positive entry
            counts as background.

    Returns:
        ``(loss, debug)``. ``loss`` maps ``'foreground'``, ``'background'``,
        ``'center'``, ``'size'`` and optionally ``'cls'`` to tensors, each
        divided by the element count of the ``positive`` mask. ``debug``
        holds the matched ``iou``, the matched ``data`` and the
        ``positive``/``negative`` masks.
    """
    iou = pred['iou']
    # GPU index of the predictions when CUDA is available, else None.
    device_id = iou.get_device() if torch.cuda.is_available() else None
    rows, cols = pred['feature'].size()[-2:]
    # iou_match is defined elsewhere. From its use below, the 2nd item is the
    # matched IoU and the 4th a dict of matched ground truth ('yx_min',
    # 'yx_max', 'cls'). iou_matrix is unused here.
    iou_matrix, _iou, _, _data = iou_match(pred['yx_min'].data, pred['yx_max'].data, data)
    anchors = utils.ensure_device(anchors, device_id)
    # Presumably a boolean mask of the slots responsible for a ground-truth
    # box -- TODO confirm against fit_positive.
    positive = fit_positive(rows, cols, *(data[key] for key in 'yx_min, yx_max'.split(', ')), anchors)
    # Background: not positive and matched IoU below the threshold.
    negative = ~positive & (_iou < threshold)
    # Regression targets built from the matched ground truth.
    _center_offset, _size_norm = fill_norm(*(_data[key] for key in 'yx_min, yx_max'.split(', ')), anchors)
    # Legacy (pre-0.4) Variable wrapping; a no-op on modern PyTorch.
    positive, negative, _iou, _center_offset, _size_norm, _cls = (torch.autograd.Variable(t) for t in (positive, negative, _iou, _center_offset, _size_norm, _data['cls']))
    # Mask with a trailing singleton axis for indexing the per-box tensors.
    # NOTE(review): modern PyTorch requires a boolean index mask to match the
    # indexed shape exactly; this may rely on legacy mask broadcasting --
    # confirm the trailing dims of center_offset/size_norm/logits.
    _positive = torch.unsqueeze(positive, -1)
    loss = {}
    # iou
    # size_average=False is the deprecated spelling of reduction='sum'.
    loss['foreground'] = F.mse_loss(iou[positive], _iou[positive], size_average=False)
    # square() is defined elsewhere; presumably an elementwise square, which
    # pushes predicted IoU toward 0 on background slots.
    loss['background'] = torch.sum(square(iou[negative]))
    # bbox
    loss['center'] = F.mse_loss(pred['center_offset'][_positive], _center_offset[_positive], size_average=False)
    loss['size'] = F.mse_loss(pred['size_norm'][_positive], _size_norm[_positive], size_average=False)
    # cls
    if 'logits' in pred:
        logits = pred['logits']
        if len(_cls.size()) > 3:
            # Class targets carry an extra (class) axis: regress softmax
            # probabilities onto them.
            loss['cls'] = F.mse_loss(F.softmax(logits, -1)[_positive], _cls[_positive], size_average=False)
        else:
            # NOTE(review): cross_entropy uses its default mean reduction while
            # the other terms are sums, yet every term is divided by cnt below
            # -- confirm this scaling of the cls term is intended.
            loss['cls'] = F.cross_entropy(logits[_positive].view(-1, logits.size(-1)), _cls[positive].view(-1))
    # normalize
    # cnt is the total number of elements in the positive mask (every slot,
    # not just the positive ones).
    cnt = float(np.multiply.reduce(positive.size()))
    for key in loss:
        loss[key] /= cnt
    return loss, dict(iou=_iou, data=_data, positive=positive, negative=negative)
开发者ID:codealphago,项目名称:yolo2-pytorch,代码行数:30,代码来源:__init__.py
示例16: plot_means
def plot_means(ax, model, data, xlimits=(-6, 6), ylimits=(-6, 6),
               numticks=101, cmap=None, alpha=1., legend=False, n_samps=10, cs_to_use=None):
    """Scatter-plot the encoder means of the first ``n_samps`` items of ``data``.

    Each item is passed to ``model.encode`` as a batch of one. The first two
    coordinates of the returned ``mean`` are plotted with ``plt.scatter``,
    and the ticks of ``ax`` are removed.

    ``xlimits``, ``ylimits``, ``numticks``, ``cmap``, ``legend`` and
    ``cs_to_use`` are accepted for call compatibility but unused. The
    meshgrid previously built from them was never read, so it was dropped.
    """
    n_samps = min(n_samps, len(data))
    means = []
    for samp_i in range(n_samps):
        if samp_i % 1000 == 0:
            # Was a Python 2 print statement (a SyntaxError on Python 3).
            print(samp_i)
        mean, _ = model.encode(Variable(torch.unsqueeze(data[samp_i], 0)))
        means.append([float(mean.data[0][0]), float(mean.data[0][1])])
    # With no data there is nothing to scatter; the old code raised an
    # IndexError on means.T[0] here.
    if means:
        means = np.array(means)
        plt.scatter(means.T[0], means.T[1], s=.1, alpha=alpha)
    ax.set_yticks([])
    ax.set_xticks([])
    plt.gca().set_aspect('equal', adjustable='box')
开发者ID:chriscremer,项目名称:Other_Code,代码行数:28,代码来源:plotting_functions.py
示例17: plot_isocontours_expected
def plot_isocontours_expected(ax, model, data, xlimits=(-6, 6), ylimits=(-6, 6),
                              numticks=101, cmap=None, alpha=1., legend=False):
    """Contour-plot the average density implied by up to 10 items of ``data``.

    For each item, ``model.encode`` returns ``mean`` and ``logvar``. The
    value ``exp(lognormal4(z, mean, logvar))`` is evaluated on a
    ``numticks`` x ``numticks`` grid spanning ``xlimits`` x ``ylimits``,
    and the results are averaged. (``lognormal4`` is defined elsewhere;
    presumably a diagonal-Gaussian log-density.)

    Returns:
        The averaged values as a (numticks, numticks) NumPy array.

    Raises:
        ValueError: if ``data`` is empty (previously an UnboundLocalError).
    """
    x = np.linspace(*xlimits, num=numticks)
    y = np.linspace(*ylimits, num=numticks)
    X, Y = np.meshgrid(x, y)
    # Grid points as a (numticks**2, 2) float tensor of (x, y) pairs.
    grid_points = torch.from_numpy(np.concatenate([np.atleast_2d(X.ravel()), np.atleast_2d(Y.ravel())]).T).type(torch.FloatTensor)
    n_samps = min(10, len(data))
    if n_samps == 0:
        raise ValueError('plot_isocontours_expected needs at least one data item')
    sum_of_all = None
    for samp_i in range(n_samps):
        if samp_i % 1000 == 0:
            # Was a Python 2 print statement (a SyntaxError on Python 3).
            print(samp_i)
        mean, logvar = model.encode(Variable(torch.unsqueeze(data[samp_i], 0)))
        log_density = lognormal4(torch.Tensor(grid_points), torch.squeeze(mean.data), torch.squeeze(logvar.data))
        density = torch.exp(log_density)
        sum_of_all = density if sum_of_all is None else sum_of_all + density
    avg_of_all = sum_of_all / n_samps
    Z = avg_of_all.view(X.shape).numpy()
    cs = plt.contour(X, Y, Z, cmap=cmap, alpha=alpha)
    if legend:
        nm, lbl = cs.legend_elements()
        plt.legend(nm, lbl, fontsize=4)
    ax.set_yticks([])
    ax.set_xticks([])
    plt.gca().set_aspect('equal', adjustable='box')
    return Z
开发者ID:chriscremer,项目名称:Other_Code,代码行数:57,代码来源:plotting_functions.py
示例18: _forward_rnn
def _forward_rnn(cell, input_, length, hx):
    """Unroll ``cell`` over the leading (time) dimension of ``input_``.

    Args:
        cell: Callable ``cell(input_=x_t, hx=(h, c)) -> (h_next, c_next)``.
        input_: Tensor whose dim 0 is time.
        length: Unused; per-sequence length masking is disabled.
        hx: Initial ``(h, c)`` state tuple.

    Returns:
        ``(output, hx)``. ``output`` stacks every ``h_next`` along a new
        dim 0 (one entry per time step), and ``hx`` is the final
        ``(h, c)`` state.

    Raises:
        ValueError: if ``input_`` has no time steps.
    """
    seq_len = input_.size(0)
    if seq_len == 0:
        raise ValueError('_forward_rnn needs at least one time step')
    # Collect the per-step hidden states and stack once at the end. The old
    # code called torch.unsqueeze on a tuple (TypeError) and then
    # torch.stack on a tensor.
    outputs = []
    for i in range(seq_len):
        h_next, c_next = cell(input_=input_[i], hx=hx)
        outputs.append(h_next)
        hx = (h_next, c_next)
    output = torch.stack(outputs, 0)
    return output, hx
开发者ID:Joyce94,项目名称:sentence_classification,代码行数:18,代码来源:bnlstm.py
示例19: forward
def forward(self, x):
    """Apply ``conv1`` -> ReLU -> ``conv2`` to 2-D or 3-D features via 1x1 maps.

    Args:
        x: Either (N, k, dim), i.e. k feature vectors per sample, or (N, dim).

    Returns:
        (N, k, out_dim) for 3-D input and (N, out_dim) for 2-D input, where
        out_dim is the channel count produced by ``self.conv2``.

    Raises:
        ValueError: if ``x`` is neither 2-D nor 3-D (previously an
            UnboundLocalError on ``x_reshape``).
    """
    ndim = x.dim()
    if ndim == 3:
        # N x k x dim -> N x dim x k x 1
        x_reshape = torch.unsqueeze(x.permute(0, 2, 1), 3)
    elif ndim == 2:
        # N x dim -> N x dim x 1 x 1
        x_reshape = torch.unsqueeze(torch.unsqueeze(x, 2), 3)
    else:
        raise ValueError('expected a 2-D or 3-D input, got %d-D' % ndim)
    iatt_conv1 = self.conv1(x_reshape)  # N x hidden_dim x * x 1
    iatt_relu = F.relu(iatt_conv1)
    iatt_conv2 = self.conv2(iatt_relu)  # N x out_dim x * x 1
    if ndim == 3:
        # N x out_dim x k x 1 -> N x k x out_dim
        return torch.squeeze(iatt_conv2, 3).permute(0, 2, 1)
    # N x out_dim x 1 x 1 -> N x out_dim
    return torch.squeeze(torch.squeeze(iatt_conv2, 3), 2)
开发者ID:xiaojie18,项目名称:pythia,代码行数:18,代码来源:post_combine_transform.py
示例20: __getitem__
def __getitem__(self, index):
    """Load sample ``index`` at every short-side size in ``self.imgSize``.

    Returns:
        dict with
          'img_ori':   the full-resolution image after the channel flip,
          'img_data':  one tensor per size, each with a leading batch axis
                       of 1 (presumably (1, 3, H_s, W_s) if img_transform
                       preserves shape -- confirm),
          'seg_label': long label tensor with a leading axis of 1 and
                       labels shifted down by one,
          'info':      the record's 'fpath_img'.
    """
    this_record = self.list_sample[index]
    # load image and label
    image_path = os.path.join(self.root_dataset, this_record['fpath_img'])
    segm_path = os.path.join(self.root_dataset, this_record['fpath_segm'])
    img = imread(image_path, mode='RGB')
    # imread(mode='RGB') yields RGB, so reversing the channel axis produces
    # BGR (the old comment said "BGR to RGB", which is backwards).
    # NOTE(review): presumably img_transform expects BGR statistics -- confirm.
    img = img[:, :, ::-1]
    segm = imread(segm_path)
    ori_height, ori_width, _ = img.shape
    img_resized_list = []
    for this_short_size in self.imgSize:
        # Scale the short side to this_short_size, unless that would push
        # the long side past imgMaxSize.
        scale = min(this_short_size / float(min(ori_height, ori_width)),
                    self.imgMaxSize / float(max(ori_height, ori_width)))
        target_height, target_width = int(ori_height * scale), int(ori_width * scale)
        # to avoid rounding in network
        target_height = round2nearest_multiple(target_height, self.padding_constant)
        target_width = round2nearest_multiple(target_width, self.padding_constant)
        # resize; copy() hands OpenCV a contiguous array instead of the
        # negative-stride channel-flipped view
        img_resized = cv2.resize(img.copy(), (target_width, target_height))
        # image to float, HWC -> CHW
        img_resized = img_resized.astype(np.float32)
        img_resized = img_resized.transpose((2, 0, 1))
        img_resized = self.img_transform(torch.from_numpy(img_resized))
        img_resized = torch.unsqueeze(img_resized, 0)
        img_resized_list.append(img_resized)
    # np.int (an alias of the builtin int) was removed in NumPy 1.24; use an
    # explicit 64-bit integer type before converting to a long tensor.
    segm = torch.from_numpy(segm.astype(np.int64)).long()
    batch_segms = torch.unsqueeze(segm, 0)
    batch_segms = batch_segms - 1  # label from -1 to 149
    output = dict()
    output['img_ori'] = img.copy()
    output['img_data'] = [x.contiguous() for x in img_resized_list]
    output['seg_label'] = batch_segms.contiguous()
    output['info'] = this_record['fpath_img']
    return output
开发者ID:zyxunh,项目名称:semantic-segmentation-pytorch,代码行数:44,代码来源:dataset.py
注:本文中的torch.unsqueeze函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论