File size: 15,149 Bytes
f724cf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
//
// -----------------------------------------------------------------------------
// The proprietary software and information contained in this file is
// confidential and may only be used by an authorized person under a valid
// licensing agreement from Arm Limited or its affiliates.
//
// Copyright (C) 2025. Arm Limited or its affiliates. All rights reserved.
//
// This entire notice must be reproduced on all copies of this file and
// copies of this file may only be made by an authorized person under a valid
// licensing agreement from Arm Limited or its affiliates.
// -----------------------------------------------------------------------------
//
#version 460
#extension GL_EXT_shader_8bit_storage : require
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_float32 : require
#extension GL_GOOGLE_include_directive : enable

// defines
// Identifiers for the supported upscale factors. Only SCALE_2_0X has an
// implementation in this shader; any other SCALE_MODE reaches the
// `#error "Unsupported SCALE_MODE"` further down.
#define SCALE_1_0X 0
#define SCALE_1_3X 1
#define SCALE_1_5X 2
#define SCALE_2_0X 3

// settings
// HISTORY_CATMULL: when defined, history is resampled with the 5-tap
// Catmull-Rom filter (LoadHistoryCatmull) instead of a single textureLod
// fetch (LoadHistory).
#define HISTORY_CATMULL
// SCALE_MODE: which SCALE_* factor is compiled in.
#define SCALE_MODE SCALE_2_0X

// includes
#include "typedefs.h"
#include "common.h"
#include "kernel_lut.h"

// inputs
// Trailing comments give the reference resolution and format of each resource.
// Bounds checks in the code use the _InputDims/_OutputDims push constants.
// _ColourTex: current-frame colour, gathered by LoadAndFilterColour.
// _MotionVectorTex: per-pixel motion in UV units, read by LoadMotion.
// _HistoryTex: previous output, resampled by LoadHistory / LoadHistoryCatmull.
layout (set=0, binding=0) uniform mediump   sampler2D _ColourTex;               // 540p  | R11G11B10 32bpp
layout (set=0, binding=1) uniform mediump   sampler2D _MotionVectorTex;         // 540p  | RG16_FLOAT 32bpp
layout (set=0, binding=2) uniform mediump   sampler2D _HistoryTex;              // 1080p | R11G11B10 32bpp
// _K0Tensor.._K3Tensor: four quantized kernel slices (4 values each), combined
// into per-pixel KPN tap weights by LoadKPNWeight.
layout (set=0, binding=3) uniform lowp      sampler2D _K0Tensor;                // 540p  | R8G8B8A8_SNORM 32bpp | Tensor->Texture Alias (Linear)
layout (set=0, binding=4) uniform lowp      sampler2D _K1Tensor;                // 540p  | R8G8B8A8_SNORM 32bpp | Tensor->Texture Alias (Linear)
layout (set=0, binding=5) uniform lowp      sampler2D _K2Tensor;                // 540p  | R8G8B8A8_SNORM 32bpp | Tensor->Texture Alias (Linear)
layout (set=0, binding=6) uniform lowp      sampler2D _K3Tensor;                // 540p  | R8G8B8A8_SNORM 32bpp | Tensor->Texture Alias (Linear)
// _TemporalTensor: quantized temporal parameters (.x -> theta, .y -> alpha),
// read by LoadTemporalParameters.
layout (set=0, binding=7) uniform lowp      sampler2D _TemporalTensor;          // 540p  | R8G8B8A8_SNORM 32bpp | Tensor->Texture Alias (Linear)
// _NearestDepthCoordTex: 8-bit code for the nearest-depth neighbour offset,
// decoded by LoadNearestDepthOffset.
layout (set=0, binding=8) uniform lowp      sampler2D _NearestDepthCoordTex;    // 540p  | R8_UNORM 8bpp

// outputs
// Final upsampled colour, written once per output pixel by WriteUpsampledColour.
layout (set=1, binding=0, r11f_g11f_b10f) uniform writeonly mediump image2D _UpsampledColourOut; // 1080p | R11G11B10 32bpp

// push-constants
// Per-dispatch parameters. Offsets are explicit, so the host-side struct must
// match this layout byte-for-byte.
layout(push_constant, std430) uniform PushConstants {
    // ─────────────── 8-byte aligned ───────────────
    layout(offset =  0) int32_t2 _OutputDims;        //  8 B
    layout(offset =  8) int32_t2 _InputDims;         //  8 B
    layout(offset = 16) float2   _InvOutputDims;     //  8 B
    layout(offset = 24) float2   _InvInputDims;      //  8 B
    layout(offset = 32) float2   _Scale;             //  8 B
    layout(offset = 40) float2   _InvScale;          //  8 B

    // ─────────────── 4-byte aligned ───────────────
    layout(offset = 48) int16_t2 _IndexModulo;       //  4 B (kernel-LUT tile size; .x is the row width)
    layout(offset = 52) half2    _QuantParams;       //  4 B (shared fallback for every *_QuantParams)
    layout(offset = 56) int16_t2 _LutOffset;         //  4 B (added to the output pixel before LUT tiling)
    layout(offset = 60) half2    _ExposurePair;      //  4 B (.x = exposure, .y = inverse exposure)
    layout(offset = 64) half2    _HistoryPad;        //  4 B (.x = "not history reset" flag, .y = unused)
    layout(offset = 68) half2    _MotionThreshPad;   //  4 B (.x = motion, .y = unused)
    layout(offset = 72) int32_t  _Padding0;          //  4 B (explicit pad)
    // NOTE(review): the fields above already end at byte 72, which is 8-byte
    // aligned, so _Padding0 is not needed for alignment inside this block. It
    // grows the block to 76 bytes, presumably to mirror the host-side struct;
    // confirm before removing.
                                                     // Total: **76 bytes**
};

// Convenience mapping for accessing push constants
#define _Exposure        _ExposurePair.x
#define _InvExposure     _ExposurePair.y
#define _NotHistoryReset _HistoryPad.x
#define _MotionThresh    _MotionThreshPad.x

// Quantization Parameters
// inside: `./parameters.json`
// these values are embedded inside the TOSA file and learnt during QAT
//
// Each *_QuantParams below may be predefined before this point (e.g. as a
// compile-time constant). Otherwise it falls back to the push-constant
// _QuantParams.xy, so all five tensors share the same parameters.

#ifndef _K0QuantParams
    // outputs - activation_post_process_45["SNORM"]
    #define _K0QuantParams _QuantParams.xy
#endif
#ifndef _K1QuantParams
    // outputs - activation_post_process_50["SNORM"]
    #define _K1QuantParams _QuantParams.xy
#endif
#ifndef _K2QuantParams
    // outputs - activation_post_process_55["SNORM"]
    #define _K2QuantParams _QuantParams.xy
#endif
#ifndef _K3QuantParams
    // outputs - activation_post_process_60["SNORM"]
    #define _K3QuantParams _QuantParams.xy
#endif
#ifndef _TemporalQuantParams
    // outputs - activation_post_process_65["SNORM"]
    #define _TemporalQuantParams _QuantParams.xy
#endif


// methods

// Fetches the two-channel motion vector stored at integer texel `pixel`
// (mip 0, no filtering) and narrows it to half precision.
half2 LoadMotion(int32_t2 pixel)
{
    return half2(texelFetch(_MotionVectorTex, pixel, 0).xy);
}


// Single fetch of the previous frame's colour at normalised coordinate `uv`
// from mip 0. Filtering follows the _HistoryTex sampler state, which is not
// visible here.
half3 LoadHistory(float2 uv)
{
    return half3(textureLod(_HistoryTex, uv, 0.0).xyz);
}

// Resamples _HistoryTex at `uv` with a 5-tap approximation of a 4x4 Catmull-Rom
// filter. The four corner taps are dropped. Per axis, the two middle taps are
// folded into one fetch placed at a fractional offset.
// NOTE(review): that folding is only exact if _HistoryTex is sampled with
// linear filtering; the sampler state is not visible here, so confirm.
// The result is renormalised by the sum of the five weights. If any channel
// comes out negative (Catmull-Rom ringing), it is clamped to the min/max of
// the five taps.
half3 LoadHistoryCatmull(float2 uv)
{
    //------------------------------------------------------------------------------------
    // 1) Compute Catmull–Rom weights
    //------------------------------------------------------------------------------------
    // History has output resolution, so work in output texel units.
    // baseFloor is the centre of the texel at or above-left of the sample point;
    // the filter footprint spans texels baseFloor + {-1, 0, +1, +2}.
    float2 scaledUV = uv * _OutputDims;
    float2 baseFloor = floor(scaledUV - 0.5) + 0.5;

    // f in [0, 1): position of the sample point past baseFloor.
    half2 f  = half2(scaledUV - baseFloor);
    half2 f2 = f * f;
    half2 f3 = f2 * f;

    // Catmull–Rom basis
    half2 w0 = f2 - 0.5HF * (f3 + f);
    half2 w1 = 1.5HF * f3 - 2.5HF * f2 + 1.0HF;
    half2 w3 = 0.5HF * (f3 - f2);
    half2 w2 = (1.0HF - w0) - w1 - w3; // = 1 - (w0 + w1 + w3)

    // Combine w1 and w2 for center axis
    // (w1 + w2 = 1 + f(1 - f)/2 >= 1, so the divisions below are safe).
    // Index 0/1/2 here means left(up) / merged centre / right(down) of the
    // collapsed 3-wide cross.
    half2 w12 = w1 + w2;
    half wx0  = w0.x, wy0  = w0.y;
    half wx1  = w12.x, wy1 = w12.y;
    half wx2  = w3.x, wy2  = w3.y;

    // Final weights for the cross sample layout
    half wUp     = wx1 * wy0;   // center in X, up in Y
    half wDown   = wx1 * wy2;   // center in X, down in Y
    half wLeft   = wx0 * wy1;   // left   in X, center in Y
    half wRight  = wx2 * wy1;   // right  in X, center in Y
    half wCenter = wx1 * wy1;   // center in X, center in Y

    // Fractional offsets for the center: a single (bilinear) fetch at
    // baseFloor + w2/(w1+w2) stands in for the weighted pair of middle taps.
    half dx = w2.x / wx1;
    half dy = w2.y / wy1;

    //------------------------------------------------------------------------------------
    // 2) Gather the 5 taps
    //------------------------------------------------------------------------------------
    // alpha = 1 on every tap, so accum.w below accumulates the weight sum.
    half4 left   = half4(LoadHistory((baseFloor + float2(-1.0, dy))  * _InvOutputDims ), 1.HF);
    half4 up     = half4(LoadHistory((baseFloor + float2(dx,  -1.0)) * _InvOutputDims ), 1.HF);
    half4 center = half4(LoadHistory((baseFloor + float2(dx,  dy))   * _InvOutputDims ), 1.HF);
    half4 right  = half4(LoadHistory((baseFloor + float2(2.0, dy))   * _InvOutputDims ), 1.HF);
    half4 down   = half4(LoadHistory((baseFloor + float2(dx,  2.0))  * _InvOutputDims ), 1.HF);

    //------------------------------------------------------------------------------------
    // 3) Accumulate and track min/max
    //------------------------------------------------------------------------------------
    half4 accum = up    * wUp     +
                  left  * wLeft   +
                  center* wCenter +
                  right * wRight  +
                  down  * wDown;
    half3 cmin3 = min(up.rgb, 
                  min(left.rgb,
                  min(center.rgb,
                  min(right.rgb, down.rgb))));
    half3 cmax3 = max(up.rgb, 
                  max(left.rgb,
                  max(center.rgb,
                  max(right.rgb, down.rgb))));

    //------------------------------------------------------------------------------------
    // 4) Final color
    //------------------------------------------------------------------------------------
    // Dropping the corner taps means the weights no longer sum to 1;
    // divide by the accumulated weight sum to renormalise.
    half3 color = accum.rgb * rcp(accum.w);

    // dering in the case where we have negative values, we don't do this all the time
    // as it can impose unnecessary blurring on the output
    return any(lessThan(color, half3(0.HF)))
         ? clamp(color, cmin3, cmax3)
         : color;
}


// Reads the 8-bit nearest-depth code stored at `pixel` and decodes it, via
// DecodeNearestDepthCoord, into an integer texel offset in {-1,0,1}^2.
int32_t2 LoadNearestDepthOffset(int32_t2 pixel)
{
    // The R8_UNORM texel holds code / 255. Undo the normalisation and round
    // to the nearest integer code.
    half encoded = half(texelFetch(_NearestDepthCoordTex, pixel, 0).x);
    int32_t code8 = int32_t(encoded * 255.0 + 0.5);

    // Map the code back to a neighbour offset in {-1,0,1}^2.
    return DecodeNearestDepthCoord(code8);
}


// Reprojects the previous frame's output to the current pixel.
//   uv          - output pixel-centre UV
//   input_pixel - input-resolution pixel covering `uv`; motion is read at
//                 (or next to) it
//   onscreen    - out: 1 when the reprojected UV lies in [0, 1)^2, else 0
// Returns the resampled history colour scaled by _Exposure and passed
// through SafeColour.
half3 LoadWarpedHistory(float2 uv, int32_t2 input_pixel, out half onscreen)
{
    // Dilate motion vectors with previously calculated nearest depth coordinate
    int32_t2 nearest_offset = LoadNearestDepthOffset(input_pixel);

    // At the image border the neighbour offset can point one texel outside the
    // input. texelFetch with out-of-range coordinates returns undefined data,
    // so clamp to the valid texel range, as LoadAndFilterColour does for its
    // KPN taps. In-range coordinates are unaffected.
    int32_t2 motion_pixel = clamp(input_pixel + nearest_offset, int32_t2(0), _InputDims - int32_t2(1));
    half2 motion = LoadMotion(motion_pixel);

    // Suppress very small motion - no need to resample.
    // Motion is in UV units: scale it to output pixels and compare its squared
    // length with _MotionThresh. The 0/1 multiplier avoids a branch.
    half2  motion_pix = motion * half2(_OutputDims);
    motion *= half(dot(motion_pix, motion_pix) > _MotionThresh);

    // UV coordinates in previous frame to resample history
    float2 reproj_uv = uv - float2(motion);

    // Mask to flag whether the motion vector is resampling from valid location onscreen
    onscreen = half(
        all(greaterThanEqual(reproj_uv, float2(0.0))) &&
        all(lessThan(reproj_uv, float2(1.0)))
    );

#ifdef HISTORY_CATMULL
    half3 warped_history = LoadHistoryCatmull(reproj_uv);
#else
    half3 warped_history = LoadHistory(reproj_uv);
#endif

    return SafeColour(warped_history * _Exposure);
}

#if SCALE_MODE == SCALE_2_0X
/*
    Optimised special case pattern for applying 4x4 kernel to
    sparse jitter-aware 2x2 upsampled image
*/


// Returns the four dequantized KPN tap weights for LUT entry `lut_idx` at `uv`.
// Each kernel tensor holds four values. The taps for one LUT entry are
// interleaved across the tensors:
//   0 -> (K0.x, K2.x, K0.z, K2.z)    1 -> (K1.x, K3.x, K1.z, K3.z)
//   2 -> (K0.y, K2.y, K0.w, K2.w)    otherwise -> (K1.y, K3.y, K1.w, K3.w)
half4 LoadKPNWeight(float2 uv, int16_t lut_idx)
{
    half4 slice0 = Dequantize(half4(textureLod(_K0Tensor, uv, 0)), _K0QuantParams);
    half4 slice1 = Dequantize(half4(textureLod(_K1Tensor, uv, 0)), _K1QuantParams);
    half4 slice2 = Dequantize(half4(textureLod(_K2Tensor, uv, 0)), _K2QuantParams);
    half4 slice3 = Dequantize(half4(textureLod(_K3Tensor, uv, 0)), _K3QuantParams);

    if (lut_idx == 0) { return half4(slice0.x, slice2.x, slice0.z, slice2.z); }
    if (lut_idx == 1) { return half4(slice1.x, slice3.x, slice1.z, slice3.z); }
    if (lut_idx == 2) { return half4(slice0.y, slice2.y, slice0.w, slice2.w); }
    return half4(slice1.y, slice3.y, slice1.w, slice3.w);
}


// Applies the 4-tap KPN filter to the current-frame colour for one output pixel.
//   output_pixel - integer output pixel
//   uv           - output pixel-centre UV (used to read the kernel tensors)
//   col_to_accum - out: the exposure-scaled centre-tap colour with alpha 1 when
//                  this LUT entry's centre tap has zero offset; all zeros otherwise
// Returns the weight-normalised filtered colour (exposure-scaled).
half3 LoadAndFilterColour(int32_t2 output_pixel, float2 uv, out half4 col_to_accum)
{    
    //-------------------------------------------------------------------
    // 1. Compute indexes, load correct pattern from LUT for given thread
    //-------------------------------------------------------------------
    // Output pixel centre in output texel units.
    float2 out_tex = float2(output_pixel) + 0.5f;
    
    // Compute the LUT index for this pixel
    // NOTE(review): GLSL leaves `%` undefined when an operand is negative; this
    // assumes output_pixel + _LutOffset >= 0 — confirm against the host side.
    // int16_t(_IndexModulo) is a scalar constructor from a vector, so it takes
    // _IndexModulo.x (the tile width): lut_idx = y * width + x, row-major.
    int16_t2 tiled_idx = (int16_t2(output_pixel) + _LutOffset) % int16_t2(_IndexModulo);
    int16_t lut_idx = tiled_idx.y * int16_t(_IndexModulo) + tiled_idx.x;
    KernelTile lut = kernelLUT[lut_idx];

    //------------------------------------------------------------------
    // 2. Apply KPN 
    //------------------------------------------------------------------
    // Dequantize the kernel weights
    // Clamping to [EPS, 1] keeps every weight positive (assuming EPS > 0), so
    // the weight sum used for normalisation below cannot be zero.
    half4 kpn_weights = clamp(LoadKPNWeight(uv, lut_idx), half4(EPS), half4(1.HF));

    // Calculate tap locations
    // (out_tex + per-tap LUT offset) scaled into input texels, then clamped to
    // the valid input range so texelFetch never reads out of bounds.
    int16_t4 tap_x = clamp(int16_t4(floor((float4(out_tex.x) + float4(lut.dx)) * _InvScale.x)), int16_t4(0), int16_t4(_InputDims.x - 1));
    int16_t4 tap_y = clamp(int16_t4(floor((float4(out_tex.y) + float4(lut.dy)) * _InvScale.y)), int16_t4(0), int16_t4(_InputDims.y - 1));

    // Gather taps
    // Column i holds tap i as (exposure-scaled rgb, 1).
    f16mat4x4 interm;
    interm[0] = half4(SafeColour(half3(texelFetch(_ColourTex, int16_t2(tap_x[0], tap_y[0]), 0).rgb) * half3(_Exposure)), 1.HF);
    interm[1] = half4(SafeColour(half3(texelFetch(_ColourTex, int16_t2(tap_x[1], tap_y[1]), 0).rgb) * half3(_Exposure)), 1.HF);
    interm[2] = half4(SafeColour(half3(texelFetch(_ColourTex, int16_t2(tap_x[2], tap_y[2]), 0).rgb) * half3(_Exposure)), 1.HF);
    interm[3] = half4(SafeColour(half3(texelFetch(_ColourTex, int16_t2(tap_x[3], tap_y[3]), 0).rgb) * half3(_Exposure)), 1.HF);

    // Special case: grab the accumulation pixel, when it corresponds to current thread
    // match is 1 when the centre tap's LUT offset is (0, 0), else 0; it also
    // becomes col_to_accum.a.
    half match = half(lut.dx[CENTER_TAP] == 0 && lut.dy[CENTER_TAP] == 0);
    col_to_accum = interm[CENTER_TAP] * match;

    // Apply filter
    // matrix * vector = sum_i interm[i] * kpn_weights[i]: .rgb is the weighted
    // colour sum and .w the weight sum (each column has alpha 1).
    half4 out_colour = interm * kpn_weights;

    return half3(out_colour.rgb * rcp(out_colour.w));
}
#else
    #error "Unsupported SCALE_MODE"
#endif // SCALE_MODE == SCALE_2_0X


// Reads the network's temporal parameters at `uv` and maps them to blend factors.
//   theta - history rectification weight: the dequantized .x, multiplied by
//           _NotHistoryReset so it becomes 0 on a history reset
//   alpha - accumulation weight: 0.05 + 0.35 * (dequantized .y), which spans
//           [0.05, 0.4] assuming the dequantized value lies in [0, 1]
void LoadTemporalParameters(float2 uv, out half theta, out half alpha)
{
    half2 raw = Dequantize(half2(textureLod(_TemporalTensor, uv, 0).rg), _TemporalQuantParams);

    alpha = raw.y * 0.35HF + 0.05HF;
    theta = raw.x * _NotHistoryReset;
}


// Passes `colour` through SafeColour and stores it at `pixel` in the output
// image with alpha fixed at 1.0.
void WriteUpsampledColour(int32_t2 pixel, half3 colour)
{
    imageStore(_UpsampledColourOut, pixel, half4(SafeColour(colour), 1.0));
}


// entry-point
// One invocation per output pixel, in 16x16 workgroups.
layout(local_size_x = 16, local_size_y = 16) in;
void main()
{
    int32_t2 dst_px = int32_t2(gl_GlobalInvocationID.xy);

    // Invocations that fall outside the output image do nothing.
    if (any(greaterThanEqual(dst_px, _OutputDims)))
    {
        return;
    }

    // Pixel-centre UV in the output, and the input pixel that covers it.
    float2 dst_uv = (float2(dst_px) + 0.5) * _InvOutputDims;
    int32_t2 src_px = int32_t2(dst_uv * _InputDims);

    //-------------------------------------------------------------------------
    // 1) Warp history
    //-------------------------------------------------------------------------
    half  history_onscreen;
    half3 warped_history = LoadWarpedHistory(dst_uv, src_px, history_onscreen);

    //-------------------------------------------------------------------------
    // 2) KPN filter → filtered colour (+ centre sample to accumulate)
    //-------------------------------------------------------------------------
    half4 centre_sample;
    half3 filtered = LoadAndFilterColour(dst_px, dst_uv, centre_sample);

    //-------------------------------------------------------------------------
    // 3) Load temporal parameters
    //-------------------------------------------------------------------------
    half theta, alpha;
    LoadTemporalParameters(dst_uv, theta, alpha);

    //-------------------------------------------------------------------------
    // 4) Rectify history, force reset when offscreen
    //-------------------------------------------------------------------------
    half  history_weight = theta * history_onscreen;
    half3 rectified      = lerp(filtered, warped_history, history_weight);

    //-------------------------------------------------------------------------
    // 5) Accumulate new sample (in tonemapped space)
    //-------------------------------------------------------------------------
    // centre_sample.a is 0 when LoadAndFilterColour found no zero-offset centre
    // tap, in which case nothing new is accumulated.
    half  accum_weight = alpha * centre_sample.a;
    half3 accumulated  = lerp(Tonemap(rectified), Tonemap(centre_sample.rgb), accum_weight);

    //-------------------------------------------------------------------------
    // 6) Inverse tonemap + exposure and write output
    //-------------------------------------------------------------------------
    WriteUpsampledColour(dst_px, InverseTonemap(accumulated) * _InvExposure);
}