A general-purpose image processing model built without any math or machine learning libraries.
The ConvolutionalNeuralNetwork class manages initializing, saving, and retrieving network parameters; calculating feedforward output; determining the loss for known inputs; single and batch training using gradient descent; and storing network state for analysis or chaining.
```cpp
/* Create Convolutional Neural Network Instance */
ConvolutionalNeuralNetwork myCNN(
    Dimensions(1, Shape(28, 28)), // input image dimensions: 1 channel of size 28x28
    {
        // convolution layer with 8 kernels of size 3x3 using stride=1, padding=1
        new ConvolutionLayerParameters(8, Shape(3, 3), 1, 1),
        // relu activation layer
        new ActivationLayerParameters(RELU),
        // convolution layer with 16 kernels of size 3x3 using stride=1, padding=1
        new ConvolutionLayerParameters(16, Shape(3, 3), 1, 1),
        // max pool layer with window size 2x2 using stride=2, padding=0
        new PoolLayerParameters(MAX, Shape(2, 2), 2, 0)
    },
    {
        HiddenLayerParameters(64, RELU),  // hidden layer with 64 nodes using relu activation
        HiddenLayerParameters(10, LINEAR) // output layer with 10 nodes using linear activation
    },
    SOFTMAX,                  // normalization function
    CATEGORICAL_CROSS_ENTROPY // loss function
);
```

```cpp
/* Initialize Parameters */
myCNN.initializeRandomFeatureLayerParameters();
// initial weight range -0.1 to 0.1, initial bias range -0.1 to 0.1
myCNN.initializeRandomHiddenLayerParameters(-0.1, 0.1, -0.1, 0.1);
```

```cpp
/* Load Parameters From A File */
myCNN.load("path/to/learnedParameters.json");
```

```cpp
/* Save Parameters To A File */
myCNN.save("path/to/learnedParameters.json");
```

```cpp
/* Use TensorDataPoint For Training */
std::vector<TensorDataPoint> trainingDataPoints;
trainingDataPoints.emplace_back(
    Tensor(/* ...28x28x1 data... */), // input
    Matrix(/* ...10x1 data... */)     // expected output
);
// ... add other training data ...
```

```cpp
/* Train Using A Single TensorDataPoint */
float learningRate = 0.1;
myCNN.train(trainingDataPoints[0], learningRate);
```

```cpp
/* Train Using A Batch */
float learningRate = 0.05;
myCNN.batchTrain(trainingDataPoints, learningRate);
```

```cpp
/* Make A Prediction */
Tensor input = Tensor(/* ...28x28x1 data... */);
Matrix output = myCNN.calculateFeedForwardOutput(input);
```

```cpp
/* Calculate The Loss For A Known Point */
Matrix expectedOutput = Matrix(/* ...10x1 data... */);
float loss = myCNN.calculateLoss(input, expectedOutput);
```
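The SOFTMAX and CATEGORICAL_CROSS_ENTROPY options are not defined above; the sketch below shows the conventional definitions those names usually refer to, using plain `std::vector` in place of the library's Matrix type. It is an illustration of the arithmetic, not the library's actual implementation.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Softmax normalization: exponentiate each raw output and divide by the sum,
// shifting by the maximum value first for numerical stability.
std::vector<float> softmax(const std::vector<float>& raw) {
    float maxValue = *std::max_element(raw.begin(), raw.end());
    std::vector<float> normalized(raw.size());
    float sum = 0.0f;
    for (std::size_t i = 0; i < raw.size(); i++) {
        normalized[i] = std::exp(raw[i] - maxValue);
        sum += normalized[i];
    }
    for (float& value : normalized) value /= sum;
    return normalized;
}

// Categorical cross-entropy: -sum(expected * log(predicted)).
float categoricalCrossEntropy(const std::vector<float>& predicted,
                              const std::vector<float>& expected) {
    float loss = 0.0f;
    for (std::size_t i = 0; i < predicted.size(); i++) {
        loss -= expected[i] * std::log(predicted[i] + 1e-9f); // epsilon avoids log(0)
    }
    return loss;
}
```

With a one-hot expected output, the loss reduces to the negative log of the probability assigned to the correct class.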
The following configuration was used for the digit classification results below:

```cpp
ConvolutionalNeuralNetwork cnn = ConvolutionalNeuralNetwork(
    Dimensions(1, Shape(28, 28)),
    {
        new ConvolutionLayerParameters(8, Shape(3, 3), 1, 1),
        new PoolLayerParameters(AVG, Shape(2, 2), 2, 0),
        new ActivationLayerParameters(TANH),
        new ConvolutionLayerParameters(16, Shape(3, 3), 1, 1),
        new PoolLayerParameters(AVG, Shape(2, 2), 2, 0)
    },
    {
        HiddenLayerParameters(64, RELU),
        HiddenLayerParameters(10, LINEAR)
    },
    SOFTMAX,
    CATEGORICAL_CROSS_ENTROPY
);
cnn.initializeRandomFeatureLayerParameters();
cnn.initializeRandomHiddenLayerParameters(-0.1, 0.1, -0.1, 0.1);
```

The model was trained on a dataset of 60k images, in randomly partitioned batches of 1200 images.
The loss and accuracy were calculated against a separate testing dataset of 10k images.
| Epoch | 0 | 1 | 2 | 3 | 4 |
|---|---|---|---|---|---|
| Loss | 2.32982 | 0.401953 | 0.258774 | 0.21641 | 0.181177 |
| Accuracy | 0.0961 | 0.874 | 0.9189 | 0.9315 | 0.9442 |
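The batching and evaluation described above might look roughly like the loop sketched below. The shuffling, batch slicing, and the TensorDataPoint member names (`.input`, `.expectedOutput`) are assumptions on my part; only `batchTrain` and `calculateLoss` come from the API shown earlier.

```cpp
#include <algorithm>
#include <cstddef>
#include <random>
#include <vector>
// #include "ConvolutionalNeuralNetwork.h" // header name assumed; adjust to the actual one

void trainAndEvaluate(ConvolutionalNeuralNetwork& cnn,
                      std::vector<TensorDataPoint>& trainingData,   // 60k labelled images
                      const std::vector<TensorDataPoint>& testData, // 10k held-out images
                      int epochs, std::size_t batchSize, float learningRate) {
    std::mt19937 rng(42);
    for (int epoch = 0; epoch < epochs; epoch++) {
        // Randomly partition the training data into batches of `batchSize` (1200 above).
        std::shuffle(trainingData.begin(), trainingData.end(), rng);
        for (std::size_t start = 0; start < trainingData.size(); start += batchSize) {
            std::size_t end = std::min(start + batchSize, trainingData.size());
            std::vector<TensorDataPoint> batch(trainingData.begin() + start,
                                               trainingData.begin() + end);
            cnn.batchTrain(batch, learningRate);
        }
        // Average test loss; accuracy would compare the arg-max of
        // calculateFeedForwardOutput(point.input) with the expected label.
        float totalLoss = 0.0f;
        for (const TensorDataPoint& point : testData) {
            totalLoss += cnn.calculateLoss(point.input, point.expectedOutput);
        }
        float averageLoss = totalLoss / testData.size();
        (void)averageLoss; // report or log per epoch
    }
}
```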
Visualizing the outputs of the different layers gives insight into how certain features of the data are used.
Each channel roughly represents one feature that a convolutional kernel identifies.
To explore this, we can take snapshots of the output channels throughout the feedforward process and render each one as a grayscale image.
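One way such a snapshot might be produced, sketched under the assumption that a channel's output can be read out of the network as a 2D grid of floats (the class stores network state for analysis, but the exact accessor isn't shown above), is to rescale the values to 0-255 and write them out as a plain PGM image:

```cpp
#include <algorithm>
#include <fstream>
#include <string>
#include <vector>

// Write one channel (a 2D grid of floats) as a grayscale PGM image,
// rescaling its values into the 0-255 range.
void writeChannelAsPGM(const std::vector<std::vector<float>>& channel,
                       const std::string& path) {
    float lo = channel[0][0], hi = channel[0][0];
    for (const auto& row : channel)
        for (float v : row) { lo = std::min(lo, v); hi = std::max(hi, v); }
    float range = (hi - lo > 0.0f) ? (hi - lo) : 1.0f;

    std::ofstream out(path);
    out << "P2\n" << channel[0].size() << " " << channel.size() << "\n255\n";
    for (const auto& row : channel) {
        for (float v : row) out << static_cast<int>(255.0f * (v - lo) / range) << " ";
        out << "\n";
    }
}
```

The tables below show these snapshots, layer by layer, for one sample of each digit.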
Each channel of the first convolution layer's output roughly reflects a unique low-level feature in the image (vertical top line, bottom-right corner, etc.).
| Digit | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
|---|---|---|---|---|---|---|---|---|---|---|
| Channel 0 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 1 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 2 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 3 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 4 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 5 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 6 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 7 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
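To make "a feature that a kernel identifies" concrete, the toy sketch below (independent of the library's internals) convolves a hand-written 3x3 vertical-edge kernel over a single-channel image; its output is large wherever the image brightens from left to right, which is the kind of response each learned channel above produces for its own feature. Unlike the padded convolution layers configured earlier, this sketch uses no padding.

```cpp
#include <cstddef>
#include <vector>

// Valid 2D convolution (cross-correlation, as is conventional in CNNs)
// of a 3x3 kernel over a single-channel image.
std::vector<std::vector<float>> convolve3x3(const std::vector<std::vector<float>>& image,
                                            const float kernel[3][3]) {
    std::size_t outRows = image.size() - 2, outCols = image[0].size() - 2;
    std::vector<std::vector<float>> output(outRows, std::vector<float>(outCols, 0.0f));
    for (std::size_t r = 0; r < outRows; r++)
        for (std::size_t c = 0; c < outCols; c++)
            for (int i = 0; i < 3; i++)
                for (int j = 0; j < 3; j++)
                    output[r][c] += kernel[i][j] * image[r + i][c + j];
    return output;
}

// A hand-written vertical-edge kernel: responds strongly where pixel intensity
// changes from left to right, e.g. along the left stroke of a digit.
const float verticalEdgeKernel[3][3] = {
    {-1.0f, 0.0f, 1.0f},
    {-1.0f, 0.0f, 1.0f},
    {-1.0f, 0.0f, 1.0f}
};
```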
After the first pooling layer, the channels retain most of the same feature information in a more compressed image.
| Digit | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
|---|---|---|---|---|---|---|---|---|---|---|
| Channel 0 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 1 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 2 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 3 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 4 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 5 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 6 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 7 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
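The compression seen above comes from the 2x2 average-pooling windows with stride 2: each output value is the mean of a 2x2 block, halving each spatial dimension (28x28 down to 14x14 here). A minimal sketch of that operation on a single channel:

```cpp
#include <cstddef>
#include <vector>

// 2x2 average pooling with stride 2: each output pixel is the mean of a
// 2x2 block of the input, halving the width and height (e.g. 28x28 -> 14x14).
std::vector<std::vector<float>> averagePool2x2(const std::vector<std::vector<float>>& channel) {
    std::size_t outRows = channel.size() / 2, outCols = channel[0].size() / 2;
    std::vector<std::vector<float>> pooled(outRows, std::vector<float>(outCols, 0.0f));
    for (std::size_t r = 0; r < outRows; r++)
        for (std::size_t c = 0; c < outCols; c++)
            pooled[r][c] = (channel[2 * r][2 * c]     + channel[2 * r][2 * c + 1] +
                            channel[2 * r + 1][2 * c] + channel[2 * r + 1][2 * c + 1]) / 4.0f;
    return pooled;
}
```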
After the activation layer, there is a strong contrast between feature and non-feature regions.
| Digit | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
|---|---|---|---|---|---|---|---|---|---|---|
| Channel 0 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 1 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 2 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 3 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 4 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 5 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 6 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 7 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
Each channel of the second convolution layer captures a higher-order feature in the image (extended lines, abstract patterns, etc.).
| Digit | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
|---|---|---|---|---|---|---|---|---|---|---|
| Channel 0 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 1 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 2 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 3 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 4 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 5 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 6 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 7 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 8 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 9 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 10 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 11 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 12 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 13 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 14 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 15 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
After the second pooling layer, the higher-level feature data is compressed further and passed into the fully connected hidden layers.
| Digit | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
|---|---|---|---|---|---|---|---|---|---|---|
| Channel 0 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 1 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 2 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 3 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 4 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 5 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 6 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 7 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 8 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 9 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 10 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 11 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 12 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 13 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 14 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
| Channel 15 | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() | ![]() |
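For reference, the snapshot sizes follow from the standard output-size formula for convolution and pooling layers, assuming the kernel size, stride, and padding parameters are interpreted in the conventional way. Tracing it through the MNIST-style configuration above:

```cpp
// Square-layer output size: out = (in + 2 * padding - kernel) / stride + 1
int outputSize(int in, int kernel, int stride, int padding) {
    return (in + 2 * padding - kernel) / stride + 1;
}

// Tracing the digit-classification architecture above:
// input                 ->  1 channel  of 28x28
// conv 8,  3x3, s1, p1  ->  8 channels of 28x28   (outputSize(28, 3, 1, 1) == 28)
// avg pool 2x2, s2, p0  ->  8 channels of 14x14   (outputSize(28, 2, 2, 0) == 14)
// tanh                  ->  8 channels of 14x14   (element-wise, shape unchanged)
// conv 16, 3x3, s1, p1  -> 16 channels of 14x14   (outputSize(14, 3, 1, 1) == 14)
// avg pool 2x2, s2, p0  -> 16 channels of 7x7     (outputSize(14, 2, 2, 0) == 7)
// flatten               -> 16 * 7 * 7 = 784 values into the 64-node hidden layer,
//                          then 10 outputs normalized by SOFTMAX
```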

























































































































































































































































































































































































































































































































































































