Haskell - Neural Network Back-Error Propagation

Posted on by Kenny Cason
tags = [ haskell, functional programming, artificial intelligence, λ ]

After going through various tutorials I decided to try to build something a bit more complicated: converting my Java implementation of a Back-Error Propagation Neural Network into Haskell. There appears to be a small bug somewhere in the calculations…

I uploaded most of my Haskell examples to GitHub, found here

Main.hs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import NN
import Utils
import Node
import Layer


-- | Load 'trainInput' into the network, run one feed-forward pass, and
-- print the resulting output values.
testInput nn trainInput =
    print . getOutput . feedForward $ setInput nn trainInput


-- | Build a fresh network (createNN 2 10 1 2.5 — presumably 2 inputs,
-- 10 hidden nodes, 1 output; confirm against NN.hs), seed it with
-- 'trainInput', run 3000 training steps toward 'teacherSignals', then
-- print the trained network's output for the same input.
train trainInput teacherSignals = do
    let freshNN   = setInput (createNN 2 10 1 2.5) trainInput
        trainedNN = trainStep freshNN trainInput teacherSignals 3000
    testInput trainedNN trainInput

-- | Train and test a fresh network on each pattern in turn.
-- Note: each call to 'train' builds its own network, so only one
-- pattern is ever trained at a time.
main = mapM_ runCase
    [ ("testing values [1.0, 1.0] => 1.0", [1.0, 1.0], [1.0])
    , ("testing values [0.0, 0.0] => 0.0", [0.0, 0.0], [0.0])
    , ("testing values [1.0, 0.0] => 0.0", [1.0, 0.0], [0.0])
    , ("testing values [0.0, 1.0] => 0.0", [0.0, 1.0], [0.0])
    ]
  where
    -- Announce the case, then train/test a fresh network on it.
    runCase (label, input, teacher) = do
        print label
        train input teacher

This yields the following, incorrect, but close output:

"testing values [1.0, 1.0] => 1.0"
[0.9834379896449783]
"testing values [0.0, 0.0] => 0.0"
[0.9241418199787566]
"testing values [1.0, 0.0] => 0.0"
[0.5027090669395176]
"testing values [0.0, 1.0] => 0.0"
[0.5027090669395176]

Node.hs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
module Node
    (Node(..)
    ,numWeights
    ,createNode
    ,compareNode
    ,sigmoidNodeValue
    ,clearNodeValue
    )
where

import Utils

-- | A single neuron: its activation 'value' plus the 'weights' of its
-- outgoing connections into the next (child) layer's nodes.
data Node = Node { value::Double, weights::[Double] } deriving Show

-- sigmoidNodeValue()
-- | Squash a node's activation through the sigmoid function.
sigmoidNodeValue :: Node -> Node
sigmoidNodeValue n = n { value = sigmoid (value n) }


-- clearNodeValue()
-- | Reset a node's activation to zero, keeping its weights intact.
-- Uses record-update syntax (consistent with 'sigmoidNodeValue') rather
-- than the positional constructor, so the code survives any future
-- reordering of Node's fields.
clearNodeValue :: Node -> Node
clearNodeValue node = node { value = 0.0 }


-- createNode()
-- | Build a node with zero activation and 'numNodes' outgoing weights,
-- each initialised to 'defaultWeight'.
createNode :: Int -> Double -> Node
createNode numNodes defaultWeight =
    Node { value = 0.0
         , weights = replicate numNodes defaultWeight
         }


-- numWeights()
-- | Number of outgoing weights attached to a node.
numWeights :: Node -> Int
numWeights = length . weights


-- compareNode()
-- | Absolute difference between two nodes' activations.
compareNode :: Node -> Node -> Double
compareNode n1 n2 = abs (value n2 - value n1)

Layer.hs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
module Layer
    (Layer(..)
    ,createLayer
    ,createEmptyLayer
    ,calculateErrors
    ,calculateOutputErrors
    ,adjustWeights
    ,clearLayerValues
    ,calculateNodeValues
    ,sigmoidLayerValues
    ,isOutputLayer
    ,getErrors
)
where

import Utils
import Node

-- | A layer of the network: its nodes, the error term computed for each
-- node during back-propagation, the teacher (target) signal for each
-- node, and the learning rate applied when adjusting weights.
data Layer = Layer {  
                nodes :: [Node]
                ,errors :: [Double]
                ,teacherSignals :: [Double]
                ,learningRate :: Double
            } deriving Show


-- | Build a row of identical nodes, each with 'numWeightsPerNode'
-- outgoing weights initialised to 0.5.
createNodeRow :: Int -> Int -> [Node]
createNodeRow numNodes numWeightsPerNode =
    replicate numNodes (createNode numWeightsPerNode 0.5)


-- | Build a layer of 'numNodes' nodes (each with 'numWeightsPerNode'
-- outgoing weights) and zeroed error and teacher-signal vectors.
createLayer :: Int -> Int -> Double -> Layer
createLayer numNodes numWeightsPerNode theLearningRate =
    Layer
        { nodes = createNodeRow numNodes numWeightsPerNode
        , errors = zeroes
        , teacherSignals = zeroes
        , learningRate = theLearningRate
        }
  where
    zeroes = replicate numNodes 0.0


-- | A zero-node placeholder layer (e.g. the "child" of the output
-- layer). Type signature added: every top-level binding should carry
-- one, and without it the monomorphism restriction makes the inferred
-- type needlessly fragile.
createEmptyLayer :: Layer
createEmptyLayer = createLayer 0 0 0

-- calculateErrors()
-- | Back-propagated error reaching 'node': the dot product of the
-- child layer's error terms with this node's outgoing weights.
sumError :: Node -> Layer -> Double
sumError node childLayer =
    sum $ zipWith (*) (errors childLayer) (weights node)

-- | Hidden-node error term: the back-propagated error scaled by the
-- sigmoid derivative v * (1 - v) of the node's activation.
calculateNodeErrors :: Node -> Layer -> Double
calculateNodeErrors node childLayer = backError * v * (1.0 - v)
  where
    v = value node
    backError = sumError node childLayer

-- | Fill in the error term of every node in 'layer'. The output layer
-- uses its teacher signals; hidden layers back-propagate from
-- 'childLayer'.
calculateErrors :: Layer -> Layer -> Layer
calculateErrors layer childLayer
    | isOutputLayer layer = calculateOutputErrors layer
    | otherwise =
        layer { errors = [calculateNodeErrors n childLayer | n <- nodes layer] }


-- calculateOutputErrors()
-- | Output-node error term: (target - actual) scaled by the sigmoid
-- derivative of the node's activation.
calculateOutputNodeError :: Node -> Double -> Double
calculateOutputNodeError node teacherSignal = (teacherSignal - v) * v * (1.0 - v)
  where
    v = value node

-- | Fill in the error term of every output node from its teacher
-- signal. The original wrapped 'calculateOutputNodeError' in a
-- redundant lambda; it is exactly the binary function 'zipWith' wants.
calculateOutputErrors :: Layer -> Layer
calculateOutputErrors layer = layer
    { errors = zipWith calculateOutputNodeError
                       (nodes layer)
                       (teacherSignals layer)
    }


-- adjustWeights()
-- | Gradient-descent update for a single weight:
-- w' = w + rate * errorTerm * activation.
-- Parameters renamed so the error term no longer shadows
-- Prelude's 'error'.
adjustWeightValue :: Double -> Double -> Double -> Double -> Double
adjustWeightValue activation weight errTerm rate =
    weight + rate * errTerm * activation

-- | Update every outgoing weight of 'node' against the corresponding
-- error term of the child layer.
adjustNodeWeight :: Node -> Layer -> Double -> Node
adjustNodeWeight node childLayer rate =
    node { weights = zipWith update (weights node) (errors childLayer) }
  where
    update w err = adjustWeightValue (value node) w err rate

-- | Apply the weight update to every node in 'layer', driven by the
-- error terms already computed for 'childLayer'.
adjustWeights :: Layer -> Layer -> Layer
adjustWeights layer childLayer =
    layer { nodes = [adjustNodeWeight n childLayer rate | n <- nodes layer] }
  where
    rate = learningRate layer


-- clearAllValues()
-- | Zero the activation of every node in the layer; weights are kept.
clearLayerValues :: Layer -> Layer
clearLayerValues layer = layer { nodes = map clearNodeValue (nodes layer) }


-- calculateNodeValues()
-- | For each child node i, the weighted input sum over this layer:
-- sum over nodes n of (value n * weights n !! i).
-- The original used 'foldl1', which is partial and crashed on a layer
-- with no nodes (e.g. 'createEmptyLayer'); an empty layer now yields [].
sumOfWeightsValues :: Layer -> [Double]
sumOfWeightsValues layer =
    case [multConstList (value n) (weights n) | n <- nodes layer] of
        []         -> []
        (row:rows) -> foldl (zipWith (+)) row rows

-- | Overwrite a child node's activation with its weighted input sum.
updateChildNodeValue :: Double -> Node -> Node
updateChildNodeValue weightedValue childNode =
    childNode { value = weightedValue }

-- | Feed this layer's activations forward into 'childLayer', setting
-- each child node's (pre-sigmoid) value.
calculateNodeValues :: Layer -> Layer -> Layer
calculateNodeValues layer childLayer =
    childLayer
        { nodes = zipWith updateChildNodeValue
                          (sumOfWeightsValues layer)
                          (nodes childLayer)
        }

-- sigmoidLayerValues()
-- | Apply the sigmoid activation to every node in the layer.
-- The original wrapped 'sigmoidNodeValue' in a redundant lambda.
sigmoidLayerValues :: Layer -> Layer
sigmoidLayerValues layer = layer { nodes = map sigmoidNodeValue (nodes layer) }

-- isOutputLayer()
-- | A layer is the output layer when its nodes have no outgoing
-- weights. The original called 'head' (via getFirstNode) and crashed on
-- a layer with no nodes such as 'createEmptyLayer'; an empty layer is
-- now treated as an output layer.
isOutputLayer :: Layer -> Bool
isOutputLayer layer = case nodes layer of
    []         -> True
    (node : _) -> null (weights node)


-- getFirstNode()
-- | First node of the layer.
-- NOTE(review): partial — 'head' crashes on a layer with no nodes
-- (e.g. 'createEmptyLayer'); callers must guarantee non-emptiness.
getFirstNode :: Layer -> Node
getFirstNode layer = head (nodes layer)

-- getErrors()
-- | Accessor for the layer's back-propagation error terms.
getErrors :: Layer -> [Double]
getErrors = errors

NN.hs (note: the listing below is an accidental duplicate of Layer.hs above — see the linked GitHub repository for the actual NN.hs, which defines createNN, setInput, feedForward, trainStep, and getOutput)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
-- NOTE(review): this listing is byte-for-byte identical to the Layer.hs
-- listing above ('module Layer', same definitions) even though the post
-- labels it "NN.hs". The NN.hs source — which Main.hs needs for
-- createNN, setInput, feedForward, trainStep and getOutput — appears to
-- have been pasted incorrectly; see the linked GitHub repository for
-- the real file. Kept verbatim here; see the annotated copy above for
-- per-function documentation.
module Layer
    (Layer(..)
    ,createLayer
    ,createEmptyLayer
    ,calculateErrors
    ,calculateOutputErrors
    ,adjustWeights
    ,clearLayerValues
    ,calculateNodeValues
    ,sigmoidLayerValues
    ,isOutputLayer
    ,getErrors
)
where

import Utils
import Node

data Layer = Layer {  
                nodes :: [Node]
                ,errors :: [Double]
                ,teacherSignals :: [Double]
                ,learningRate :: Double
            } deriving Show


createNodeRow :: Int -> Int -> [Node]
createNodeRow numNodes numWeightsPerNode = replicate numNodes (createNode numWeightsPerNode 0.5)


createLayer :: Int -> Int -> Double -> Layer
createLayer numNodes numWeightsPerNode theLearningRate =
        Layer {
              nodes = (createNodeRow numNodes numWeightsPerNode)
              ,errors = (replicate numNodes 0.0)
              ,teacherSignals = (replicate numNodes 0.0)
              ,learningRate = theLearningRate
        }


createEmptyLayer = createLayer 0 0 0

-- calculateErrors()
sumError :: Node -> Layer -> Double
sumError node childLayer = sum (zipWith (*) (errors childLayer) (weights node))

calculateNodeErrors :: Node -> Layer -> Double
calculateNodeErrors node childLayer = (sumError node childLayer) * (value node) * (1.0 - (value node))

calculateErrors :: Layer -> Layer -> Layer
calculateErrors layer childLayer | isOutputLayer layer = calculateOutputErrors layer
                                 | otherwise = layer {
                                            errors = map (\node -> calculateNodeErrors node childLayer) (nodes layer)
                                        }


-- calculateOutputErrors()
calculateOutputNodeError :: Node -> Double -> Double
calculateOutputNodeError node teacherSignal =
                                (teacherSignal - (value node)) * (value node) * (1.0 - (value node))

calculateOutputErrors :: Layer -> Layer
calculateOutputErrors layer = layer {
                                errors = zipWith (\node teacherSignal ->
                                                        calculateOutputNodeError node teacherSignal)
                                                                                     (nodes layer)
                                                                                     (teacherSignals layer)
                            }


-- adjustWeights()
adjustWeightValue :: Double -> Double -> Double -> Double -> Double
adjustWeightValue value weight error learningRate =  weight + (learningRate * error * value)

adjustNodeWeight :: Node -> Layer -> Double -> Node
adjustNodeWeight node childLayer learningRate = node {
                                                 weights = zipWith
                                                      (\weight error ->
                                                              adjustWeightValue (value node) weight error learningRate)
                                                                            (weights node)
                                                                            (errors childLayer)
                                               }

adjustWeights :: Layer -> Layer -> Layer
adjustWeights layer childLayer = layer {
                                    nodes = map (\node -> adjustNodeWeight
                                                                    node
                                                                    childLayer
                                                                    (learningRate layer))
                                                                                   (nodes layer)
                                }


-- clearAllValues()
clearLayerValues :: Layer -> Layer
clearLayerValues layer = layer { nodes = (map clearNodeValue (nodes layer)) }


-- calculateNodeValues()
sumOfWeightsValues :: Layer -> [Double]
sumOfWeightsValues layer = foldl1 (zipWith (+))
                               [multConstList (value node) (weights node)
                               | node <- (nodes layer)]

updateChildNodeValue :: Double -> Node -> Node
updateChildNodeValue weightedValue childNode = childNode {
                                                value = weightedValue
                                             }

calculateNodeValues :: Layer -> Layer -> Layer
calculateNodeValues layer childLayer = childLayer {
                                        nodes = zipWith
                                                    updateChildNodeValue
                                                            (sumOfWeightsValues layer)
                                                            (nodes childLayer)
                                     }

-- sigmoidLayerValues()
sigmoidLayerValues :: Layer -> Layer
sigmoidLayerValues layer = layer { nodes = map (\node -> sigmoidNodeValue node) (nodes layer) }

-- isOutputLayer()
isOutputLayer :: Layer -> Bool
isOutputLayer layer = null (weights (getFirstNode layer))


-- getFirstNode()
getFirstNode :: Layer -> Node
getFirstNode layer = head (nodes layer)

-- getErrors()
getErrors :: Layer -> [Double]
getErrors layer = (errors layer)

Utils.hs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
module Utils
    (sigmoid
    ,listProduct
    ,listSquared
    ,listSum
    ,sumList
    ,multConstList
    ,addConstList
    )
where

-- sigmoid()
-- | Logistic sigmoid: 1 / (1 + e^(-x)), mapping all reals into (0, 1).
-- Uses Prelude's 'exp' directly instead of the former unexported
-- top-level helper 'e = exp 1' combined with '**'.
sigmoid :: Double -> Double
sigmoid x = 1 / (1 + exp (-x))


-- listProduct()
-- | Element-wise product of two lists (length of the shorter input).
-- Explicit signature added; it matches the previously inferred type, so
-- existing callers are unaffected.
listProduct :: Num a => [a] -> [a] -> [a]
listProduct = zipWith (*)


-- listSum()
-- | Element-wise sum of two lists (length of the shorter input).
-- Explicit signature added; it matches the previously inferred type, so
-- existing callers are unaffected.
listSum :: Num a => [a] -> [a] -> [a]
listSum = zipWith (+)


-- listSquared()
-- | Square every element of a list.
listSquared :: [Double] -> [Double]
listSquared xs = [x * x | x <- xs]


-- multConstList()
-- | Scale every element of a list by a constant factor.
-- (Parameter renamed so it no longer shadows Prelude's 'const'.)
multConstList :: Double -> [Double] -> [Double]
multConstList k xs = map (k *) xs


-- addConstList()
-- | Add a constant offset to every element of a list.
-- (Parameter renamed so it no longer shadows Prelude's 'const'.)
addConstList :: Double -> [Double] -> [Double]
addConstList k xs = map (k +) xs


-- sumList()
-- | Total of a list of doubles. The original hand-rolled this with a
-- lazy 'foldl' (a thunk-building anti-pattern); Prelude's 'sum' gives
-- the same result for Double.
sumList :: [Double] -> Double
sumList = sum
comments powered by Disqus