overlap_filter
overlap_filter
函数的通俗含义是根据给定的掩码(filter_mask
),逐层地过滤掉主掩码(mask
)中的某些值。具体来说,该函数会从最后一个通道开始,逐层检查并根据对应层的过滤掩码,将前面的通道中对应位置的值设置为零。
结果没有完全看懂
import numpy as np
def overlap_filter(mask, filter_mask):
# 获取 mask 的通道数、高度和宽度
C, _, _ = mask.shape
# 从最后一个通道开始,逐个通道向前遍历
for c in range(C - 1, -1, -1):
# 创建一个过滤器,根据 filter_mask 的当前通道是否不为 0 来确定
filter = np.repeat((filter_mask[c] != 0)[None, :], c, axis=0)
# 将 mask 中前 c 个通道对应位置的值设置为 0
mask[:c][filter] = 0
# 返回修改后的 mask
return mask
# 创建整齐的示例数据
C, H, W = 4, 2, 2 # 通道数、高度和宽度
mask = np.zeros((C, H, W), dtype=np.uint8)
filter_mask = np.zeros((C, H, W), dtype=np.uint8)
# 为 mask 填充整齐的数据
for c in range(C):
mask[c] = c + 1 # 每个通道填充相同的值,方便观察
# 为 filter_mask 填充整齐的数据
filter_mask[0, 0, 0] = 1 # 仅在第一个通道的左上角设置为 1
filter_mask[1, :, 1] = 1 # 在第二个通道的右边一列设置为 1
filter_mask[2, 1, :] = 1 # 在第三个通道的底边一行设置为 1
print("Original mask:")
print(mask)
print("\nFilter mask:")
print(filter_mask)
# 调用 overlap_filter 函数
filtered_mask = overlap_filter(mask.copy(), filter_mask)
print("\nFiltered mask:")
print(filtered_mask)
结果:
Original mask:
[[[1 1]
[1 1]]
[[2 2]
[2 2]]
[[3 3]
[3 3]]
[[4 4]
[4 4]]]
Filter mask:
[[[1 0]
[0 0]]
[[0 1]
[0 1]]
[[0 0]
[1 1]]
[[0 0]
[0 0]]]
Filtered mask:
[[[1 0]
[0 0]]
[[2 2]
[0 0]]
[[3 3]
[3 3]]
[[4 4]
[4 4]]]