郑重承诺,本文章提供代码保证能够运行可用,若不能用可留言,看到了一定帮忙解决!
很多边缘计算芯片不支持复杂的算子,比如gather算子。
这就需要使用简单的矩阵计算以及一些简单的torch操作来进行替换。
下面分享一些我在工作中用到的,能够替换gather的函数:
文章来源:https://uudwc.com/A/6XjOk
def torch_max_replace_gather(att_weights_prob, dim, ind_k):
b,c,d,h,w = att_weights_prob.shape
b,c,d1,h,w = ind_k.shape
att_weights_prob = att_weights_prob.transpose(4,dim).reshape(b*c*h*w,d).unsqueeze(1).expand(-1,d1,-1)
att_topk1 = ind_k.transpose(4,dim).reshape(b*c*h*w,d1).unsqueeze(2).expand(-1,-1,d)
index_array = torch.arange(d).unsqueeze(0).unsqueeze(0).repeat(b*c*h*w,d1,1).to(att_weights_prob.device)
att_topk1 = index_array - att_topk1
att_topk1 =att_topk1*att_topk1
att_topk1[att_topk1>0]=1
att_topk1 = 1 - att_topk1
min_val= torch.min(att_weights_prob, dim=dim)[0] - 0.1
att_weights_prob = att_weights_prob - min_val.unsqueeze(2)
att_weights_prob = att_weights_prob*att_topk1
max_val = torch.max(att_weights_prob, dim=dim)[0]
att_topk1 = (max_val + min_val).reshape(b,c,w,h,d1).transpose(4,2).transpose(3,4).transpose(3,4)
return att_topk1
文章来源地址https://uudwc.com/A/6XjOk