// Buffer device address support: StateData below is declared as a
// buffer_reference (a raw pointer into device memory).
#extension GL_EXT_buffer_reference : require
#extension GL_EXT_buffer_reference2 : require

// Shorthand for the (storage-class semantics, ordering semantics) argument
// pairs passed to atomicLoad/atomicStore on buffer storage.
#define ACQUIRE gl_StorageSemanticsBuffer, gl_SemanticsAcquire
#define RELEASE gl_StorageSemanticsBuffer, gl_SemanticsRelease

// Publication states for one partition of the decoupled look-back scan:
//   FLAG_NOT_READY       - nothing published yet
//   FLAG_AGGREGATE_READY - the partition's local aggregate is visible
//   FLAG_PREFIX_READY    - the partition's inclusive prefix is visible
#define FLAG_NOT_READY 0u
#define FLAG_AGGREGATE_READY 1u
#define FLAG_PREFIX_READY 2u
|
|
// Per-partition state record, addressed through a buffer reference.
// `aggregate` holds the partition's local sum, `prefix` its inclusive
// prefix over all preceding partitions, and `flag` (one of the FLAG_*
// values) publishes which of the two data fields is valid; the flag is
// always accessed with atomicLoad/atomicStore plus ACQUIRE/RELEASE so the
// plain stores to aggregate/prefix are ordered against it.
layout(buffer_reference, buffer_reference_align = T_ALIGN) nonprivate buffer StateData {
    DTYPE aggregate;
    DTYPE prefix;
    uint flag;
};
|
|
// Scratch for the shared-memory scan over per-thread aggregates.
shared DTYPE sh_scratch[WG_SIZE];
// Exclusive prefix of this partition, broadcast by the last thread.
shared DTYPE sh_prefix;
// Partition index for this workgroup, broadcast from a single thread.
shared uint sh_part_ix;
// Look-back flag value, broadcast from the thread that polls it.
shared uint sh_flag;
|
|
// Single-pass "decoupled look-back" prefix sum over one partition.
//
// Each workgroup scans PARTITION_SIZE (= WG_SIZE * N_ROWS) elements of
// `src` and writes the inclusive scan to `dst`:
//   1. each thread sequentially scans its N_ROWS elements;
//   2. the per-thread aggregates are scanned in shared memory;
//   3. the last thread publishes this partition's aggregate (and, for
//      partition 0, its prefix) with a release store on the flag;
//   4. partitions > 0 look back over predecessors to build their
//      exclusive prefix; if a predecessor's flag is still
//      FLAG_NOT_READY the last thread recomputes that predecessor's
//      aggregate directly from `src` so the loop cannot deadlock;
//   5. the exclusive prefix is broadcast and added to every output.
//
// dst_stride / src_stride are element strides applied to the flat index.
// NOTE(review): `state` is declared outside this chunk; assumed to be a
// StateData array with every flag initialized to FLAG_NOT_READY before
// dispatch — confirm at the dispatch site.
void prefix_sum(DataBuffer dst, uint dst_stride, DataBuffer src, uint src_stride)
{
    DTYPE local[N_ROWS];

    // Broadcast the partition index to the whole workgroup.
    // BUG FIX: the guard must use the *local* invocation id.
    // `gl_GlobalInvocationID.x == 0` is true only in workgroup 0, so every
    // other workgroup read sh_part_ix uninitialized. (The shared broadcast
    // is kept even though part_ix currently equals gl_WorkGroupID.x, so a
    // dynamic partition counter can be dropped in later.)
    if (gl_LocalInvocationID.x == 0)
        sh_part_ix = gl_WorkGroupID.x;
    barrier();
    uint part_ix = sh_part_ix;

    // First element of this thread's row within the partition.
    uint ix = part_ix * PARTITION_SIZE + gl_LocalInvocationID.x * N_ROWS;

    // Phase 1: sequential inclusive scan of this thread's N_ROWS elements.
    local[0] = src.v[ix*src_stride];
    for (uint i = 1; i < N_ROWS; i++)
        local[i] = local[i - 1] + src.v[(ix + i)*src_stride];

    // Phase 2: shared-memory scan of the per-thread aggregates; after the
    // loop, thread t holds the inclusive sum of aggregates 0..t and the
    // last thread holds the whole partition's aggregate.
    DTYPE agg = local[N_ROWS - 1];
    sh_scratch[gl_LocalInvocationID.x] = agg;
    for (uint i = 0; i < LG_WG_SIZE; i++) {
        barrier();
        if (gl_LocalInvocationID.x >= (1u << i))
            agg += sh_scratch[gl_LocalInvocationID.x - (1u << i)];
        barrier();
        sh_scratch[gl_LocalInvocationID.x] = agg;
    }

    // Phase 3: publish this partition's results. Partition 0 has no
    // predecessors, so its aggregate is already its inclusive prefix.
    if (gl_LocalInvocationID.x == WG_SIZE - 1) {
        state[part_ix].aggregate = agg;
        if (part_ix == 0)
            state[0].prefix = agg;
    }
    // The release on the flag orders it after the plain stores above, so
    // an acquiring reader that sees the flag also sees the data.
    if (gl_LocalInvocationID.x == WG_SIZE - 1) {
        uint flag = part_ix == 0 ? FLAG_PREFIX_READY : FLAG_AGGREGATE_READY;
        atomicStore(state[part_ix].flag, flag, gl_ScopeDevice, RELEASE);
    }

    // Phase 4: look back over predecessor partitions, accumulating their
    // aggregates (in thread WG_SIZE-1 only) until a published prefix is
    // found. `exclusive` is meaningful only in that thread until it is
    // broadcast through sh_prefix below.
    DTYPE exclusive = DTYPE(0);
    if (part_ix != 0) {
        uint look_back_ix = part_ix - 1;

        DTYPE their_agg;
        uint their_ix = 0;   // progress of the manual fallback recompute
        while (true) {
            // One thread polls the predecessor's flag with acquire...
            if (gl_LocalInvocationID.x == WG_SIZE - 1)
                sh_flag = atomicLoad(state[look_back_ix].flag, gl_ScopeDevice, ACQUIRE);
            // ...and broadcasts it to the workgroup. The second barrier
            // keeps sh_flag stable until every thread has read it.
            barrier();
            uint flag = sh_flag;
            barrier();

            if (flag == FLAG_PREFIX_READY) {
                // Predecessor's full prefix is available: finish here.
                if (gl_LocalInvocationID.x == WG_SIZE - 1) {
                    DTYPE their_prefix = state[look_back_ix].prefix;
                    exclusive = their_prefix + exclusive;
                }
                break;
            } else if (flag == FLAG_AGGREGATE_READY) {
                // Only the aggregate is available: fold it in and keep
                // walking back.
                if (gl_LocalInvocationID.x == WG_SIZE - 1) {
                    their_agg = state[look_back_ix].aggregate;
                    exclusive = their_agg + exclusive;
                }
                look_back_ix--;
                their_ix = 0;
                continue;
            }

            // FLAG_NOT_READY fallback: rather than spin, recompute the
            // predecessor's aggregate one element at a time from src.
            if (gl_LocalInvocationID.x == WG_SIZE - 1) {
                DTYPE m = src.v[(look_back_ix * PARTITION_SIZE + their_ix)*src_stride];
                if (their_ix == 0)
                    their_agg = m;
                else
                    their_agg += m;
                their_ix++;
                if (their_ix == PARTITION_SIZE) {
                    // Finished recomputing this partition's aggregate.
                    exclusive = their_agg + exclusive;
                    if (look_back_ix == 0) {
                        // Reached the front: exclusive prefix is complete.
                        sh_flag = FLAG_PREFIX_READY;
                    } else {
                        look_back_ix--;
                        their_ix = 0;
                    }
                }
            }
            barrier();
            flag = sh_flag;
            barrier();
            if (flag == FLAG_PREFIX_READY)
                break;
        }

        // Publish this partition's inclusive prefix and broadcast the
        // exclusive prefix to the rest of the workgroup.
        if (gl_LocalInvocationID.x == WG_SIZE - 1) {
            DTYPE inclusive_prefix = exclusive + agg;
            sh_prefix = exclusive;
            state[part_ix].prefix = inclusive_prefix;
        }
        if (gl_LocalInvocationID.x == WG_SIZE - 1)
            atomicStore(state[part_ix].flag, FLAG_PREFIX_READY, gl_ScopeDevice, RELEASE);
    }

    // Phase 5: pick up the broadcast exclusive prefix (barrier also makes
    // sh_prefix visible) and write the final scanned values.
    barrier();
    if (part_ix != 0)
        exclusive = sh_prefix;

    // Row prefix = partition-exclusive prefix + sum of earlier threads.
    DTYPE row = exclusive;
    if (gl_LocalInvocationID.x > 0)
        row += sh_scratch[gl_LocalInvocationID.x - 1];

    for (uint i = 0; i < N_ROWS; i++)
        dst.v[(ix + i)*dst_stride] = row + local[i];
}
|
|