深度学习模型移植-替换gather算子

郑重承诺,本文章提供代码保证能够运行可用,若不能用可留言,看到了一定帮忙解决!

很多边缘计算芯片不支持复杂的算子,比如gather算子。

这就需要使用简单的矩阵计算以及一些简单的torch操作来进行替换。

下面分享一些我在工作中用到的,能够替换gather的函数:

def torch_max_replace_gather(att_weights_prob, dim, ind_k):
    b,c,d,h,w = att_weights_prob.shape
    b,c,d1,h,w = ind_k.shape
    att_weights_prob = att_weights_prob.transpose(4,dim).reshape(b*c*h*w,d).unsqueeze(1).expand(-1,d1,-1)
    att_topk1 = ind_k.transpose(4,dim).reshape(b*c*h*w,d1).unsqueeze(2).expand(-1,-1,d)
    index_array = torch.arange(d).unsqueeze(0).unsqueeze(0).repeat(b*c*h*w,d1,1).to(att_weights_prob.device)
    att_topk1 = index_array - att_topk1
    att_topk1 =att_topk1*att_topk1
    att_topk1[att_topk1>0]=1
    att_topk1 = 1 - att_topk1
    min_val= torch.min(att_weights_prob, dim=dim)[0] - 0.1
    att_weights_prob = att_weights_prob - min_val.unsqueeze(2)
    att_weights_prob = att_weights_prob*att_topk1
    max_val = torch.max(att_weights_prob, dim=dim)[0]
    att_topk1 = (max_val + min_val).reshape(b,c,w,h,d1).transpose(4,2).transpose(3,4).transpose(3,4)
    return att_topk1

文章来源地址https://uudwc.com/A/6XjOk

原文地址:https://blog.csdn.net/u011231598/article/details/131441946

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请联系站长进行投诉反馈,一经查实,立即删除!

上一篇 2023年06月28日 21:37
软考高级系统架构设计师(四) 计算机网络3物联网&云计算
下一篇 2023年06月28日 21:37