-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcompute.comp
121 lines (96 loc) · 2.76 KB
/
compute.comp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#version 450
// Values of the `mode` push constant; main() dispatches on them.
#define MODE_CALCULATE 0
#define MODE_PRESENT 1
// One cell of the 2D neuron grid (stored in neuronBuffer).
struct Neuron {
float charge;
};
// Per-dispatch parameters set by the host.
//   dTime: time step; scales both the charge update and the weight update in calculateMode().
//   mode:  MODE_CALCULATE or MODE_PRESENT; selects the pass run by main().
layout(push_constant) uniform pushConstants {
float dTime;
uint mode;
};
// Shared parameters.
//   resolution:        output image size in pixels; presentMode() skips invocations outside it.
//   networkDimensions: neuron grid width/height; neurons are indexed row-major as y * width + x.
//   zoom:              view transform used by presentMode(): cell = (pixel + zoom.xy) / zoom.z.
//                      zoom.w is not read by this shader.
layout(std140, binding=0) uniform ubo {
ivec2 resolution;
ivec2 networkDimensions;
vec4 zoom;
};
// Output image; only written (via imageStore) by presentMode().
layout(rgba32f, binding=1) uniform image2D imgOut;
// Neuron state, read and written by calculateMode() and read by presentMode().
// Presumably sized networkDimensions.x * networkDimensions.y by the host — TODO confirm.
layout(std430, binding=2) buffer neuronBuffer {
Neuron neurons[];
};
// Connection weights, 8 per neuron, addressed through weightIndex().
layout(std430, binding=3) buffer weightBuffer {
float weights[];
};
// 32x32 = 1024 invocations per workgroup.
// NOTE(review): Vulkan only guarantees maxComputeWorkGroupInvocations >= 128, so this
// size may exceed the device limit on some hardware — confirm against target devices.
layout(local_size_x=32, local_size_y=32, local_size_z=1) in;
// Offsets {dx, dy} of the 8 surrounding cells, starting at +x and stepping around
// the cell. The index into this table is the `neighbor` slot passed to weightIndex(),
// so reordering it would remap every stored weight.
const int neighbors[8][2] = {
{1,0},
{1,-1},
{0,-1},
{-1,-1},
{-1,0},
{-1,1},
{0,1},
{1,1}
};
// Scale factor on the weight update in calculateMode(). At 0.0 the update adds
// zero, which effectively disables learning: weights are rewritten with their old values.
const float learnRate = 0.0;
// Flat index into `weights` for the connection from `neuron` to its
// `neighbor`-th neighbor (0..7, the slot of the neighbors table).
// Each neuron owns a contiguous run of 8 weights.
// `tier` is accepted but currently ignored, so every tier maps to the same slots.
uint weightIndex(uint tier, uint neuron, uint neighbor) {
    const uint weightsPerNeuron = 8u;
    uint firstSlot = weightsPerNeuron * neuron;
    return firstSlot + neighbor;
}
// One simulation step for the neuron at this invocation's (x, y) grid cell.
// The neuron's new charge is its old charge plus dTime times the weighted sum of
// its in-grid neighbors' charges, clamped to [0, 1]. Each connection weight is
// also nudged by (neighborCharge * ownCharge * dTime * learnRate).
void calculateMode() {
    // The bounds test gates the reads and the write separately, so that every
    // invocation in the workgroup reaches barrier(). GLSL requires barrier() to be
    // in uniform control flow; calling it inside the per-invocation bounds check
    // is undefined for workgroups that straddle the grid edge.
    bool inBounds =
        gl_GlobalInvocationID.x < uint(networkDimensions.x) &&
        gl_GlobalInvocationID.y < uint(networkDimensions.y);

    uint neuronIndex = 0u;
    float newCharge = 0.0;

    if (inBounds) {
        // Row-major grid: index = y * width + x.
        neuronIndex = gl_GlobalInvocationID.y * uint(networkDimensions.x) + gl_GlobalInvocationID.x;
        float initialCharge = neurons[neuronIndex].charge;
        float weightedInput = 0.0;

        for (int i = 0; i < 8; i++) {
            int x = neighbors[i][0] + int(gl_GlobalInvocationID.x);
            int y = neighbors[i][1] + int(gl_GlobalInvocationID.y);
            // Neighbors outside the grid contribute nothing (no wrap-around).
            if (x < 0 || x >= networkDimensions.x || y < 0 || y >= networkDimensions.y) continue;

            uint slot = weightIndex(0u, neuronIndex, uint(i));
            float oldWeight = weights[slot];
            float neighborCharge = neurons[x + networkDimensions.x * y].charge;
            weightedInput += oldWeight * neighborCharge;

            // Each neuron owns its own weight slots, so this write does not race.
            float weightChange = neighborCharge * initialCharge;
            weights[slot] = oldWeight + weightChange * dTime * learnRate;
        }

        // Clamp locally and publish once, so no other invocation can observe an
        // out-of-range intermediate charge in the buffer.
        newCharge = clamp(initialCharge + weightedInput * dTime, 0.0, 1.0);
    }

    // All neighbor reads in this workgroup finish before any charge is overwritten.
    // NOTE(review): barrier() only orders invocations within one workgroup; cells on
    // a tile edge still read neighbors in other workgroups without ordering. A fully
    // race-free step would need separate input/output charge buffers (host change).
    barrier();

    if (inBounds) {
        neurons[neuronIndex].charge = newCharge;
    }
}
// Maps a charge to a display color: `a` (red) at x <= 0, `b` (green) at x >= 1,
// and a linear blend in between.
vec4 falseColor(float x) {
    vec4 a = vec4(1,0,0,1);
    vec4 b = vec4(0,1,0,1);
    if(x <= 0) return a;
    if(x >= 1) return b;
    // mix(a, b, x) == a*(1-x) + b*x: the weight on `a` is 1 at x = 0 and the weight
    // on `b` is 1 at x = 1, so the blend is continuous with the endpoint returns above.
    return mix(a, b, x);
}
// Writes one pixel of imgOut per invocation. The pixel is mapped to a neuron cell
// by cell = floor((pixel + zoom.xy) / zoom.z). Pixels that land outside the grid
// are drawn black; pixels inside the grid get falseColor(charge).
void presentMode() {
    if(
        gl_GlobalInvocationID.x < resolution.x &&
        gl_GlobalInvocationID.y < resolution.y
    ) {
        // floor() rather than int() truncation. Truncation rounds toward zero, so
        // scaled positions in (-1, 0) would map to cell 0 and draw row/column 0 at
        // double width whenever the view is panned past the origin.
        ivec2 cell = ivec2(floor((vec2(gl_GlobalInvocationID.xy) + zoom.xy) / zoom.z));
        vec4 color;
        if(cell.x >= networkDimensions.x || cell.x < 0 || cell.y >= networkDimensions.y || cell.y < 0) {
            color = vec4(0,0,0,1);
        } else {
            color = falseColor(neurons[cell.x + cell.y * networkDimensions.x].charge);
        }
        imageStore(imgOut, ivec2(gl_GlobalInvocationID.xy), color);
    }
}
// Entry point. The push-constant `mode` (uniform across the dispatch) chooses
// the pass: simulation step or rendering. Any other value makes the dispatch a no-op.
void main() {
    if (mode == MODE_CALCULATE) {
        calculateMode();
    } else if (mode == MODE_PRESENT) {
        presentMode();
    }
}