fork download
import numpy as np
  2.  
  3. # Define your activation function (you mentioned no sigmoid, likely linear)
  4. def activation_function(x):
  5. # No transformation (identity function)
  6. return x
  7.  
  8. # Initialize weights with small random values
  9. def initialize_weights(input_size, output_size):
  10. # Proper small initialization for faster convergence
  11. return np.random.randn(input_size, output_size) * 0.1
  12.  
  13. # Forward pass
  14. def forward_pass(input_vector, weights, biases):
  15. return np.dot(input_vector, weights) + biases
  16.  
  17. # Training function
  18. def train_neural_network(input_data, target_data, learning_rate=0.001, max_epochs=2000, error_threshold=1e-5):
  19. input_size = len(input_data[0])
  20. output_size = len(target_data[0])
  21.  
  22. # Initialize weights and biases
  23. weights = initialize_weights(input_size, output_size)
  24. biases = np.zeros(output_size)
  25.  
  26. # Training loop
  27. for epoch in range(max_epochs):
  28. total_error = 0
  29. for i in range(len(input_data)):
  30. input_vector = input_data[i]
  31. target_vector = target_data[i]
  32.  
  33. # Forward pass
  34. output = forward_pass(input_vector, weights, biases)
  35. error = target_vector - output
  36. total_error += np.sum(error ** 2)
  37.  
  38. # Backpropagation (gradient descent)
  39. weight_gradient = np.outer(input_vector, error) # Gradient w.r.t weights
  40. bias_gradient = error # Gradient w.r.t bias
  41.  
  42. # Gradient clipping to avoid exploding gradients
  43. weight_gradient = np.clip(weight_gradient, -10, 10)
  44. bias_gradient = np.clip(bias_gradient, -10, 10)
  45.  
  46. # Check for NaN in gradients and skip update if found
  47. if np.isnan(np.sum(weight_gradient)) or np.isnan(np.sum(bias_gradient)):
  48. print("NaN detected, skipping update.")
  49. continue
  50.  
  51. # Update weights and biases
  52. weights += learning_rate * weight_gradient
  53. biases += learning_rate * bias_gradient
  54.  
  55. # Early stopping: check if the error is below the threshold
  56. if total_error < error_threshold:
  57. print(f"Converged at epoch {epoch} with error: {total_error}")
  58. break
  59.  
  60. # Optional: Print status for every 100 epochs
  61. if epoch % 100 == 0:
  62. print(f"Epoch {epoch}, Error: {total_error}")
  63.  
  64. return weights, biases
  65.  
  66. # Test the network with the provided truth tables
  67. def test_neural_network(weights, biases, input_data):
  68. predictions = []
  69. for input_vector in input_data:
  70. output = forward_pass(input_vector, weights, biases)
  71. predictions.append(output)
  72. return predictions
  73.  
  74. # Define the truth tables
  75. input_data = [
  76. [0, 0, 0, 0],
  77. [0, 0, 0, 1],
  78. [0, 0, 1, 0],
  79. [0, 0, 1, 1],
  80. [0, 1, 0, 0],
  81. [0, 1, 0, 1],
  82. [0, 1, 1, 0],
  83. [0, 1, 1, 1],
  84. [1, 0, 0, 0],
  85. [1, 0, 0, 1],
  86. [1, 0, 1, 0],
  87. [1, 0, 1, 1],
  88. [1, 1, 0, 0],
  89. [1, 1, 0, 1],
  90. [1, 1, 1, 0],
  91. [1, 1, 1, 1]
  92. ]
  93.  
  94. # Corresponding targets (for the XOR problem)
  95. target_data = [
  96. [0, 0, 0, 0],
  97. [0, 0, 0, 1],
  98. [0, 0, 1, 0],
  99. [0, 0, 1, 1],
  100. [0, 1, 0, 0],
  101. [0, 1, 0, 1],
  102. [0, 1, 1, 0],
  103. [0, 1, 1, 1],
  104. [1, 0, 0, 0],
  105. [1, 0, 0, 1],
  106. [1, 0, 1, 0],
  107. [1, 0, 1, 1],
  108. [1, 1, 0, 0],
  109. [1, 1, 0, 1],
  110. [1, 1, 1, 0],
  111. [1, 1, 1, 1]
  112. ]
  113.  
  114. # Train the neural network
  115. weights, biases = train_neural_network(input_data, target_data, learning_rate=0.001)
  116.  
  117. # Test the neural network
  118. predictions = test_neural_network(weights, biases, input_data)
  119.  
  120. # Display the results
  121. for i, (input_vector, target_vector, prediction) in enumerate(zip(input_data, target_data, predictions)):
  122. print(f"Table {i+1}: Input: {input_vector}, Target: {target_vector}, Prediction: {prediction}, Error: {np.abs(np.array(target_vector) - np.array(prediction))}")
  123.  
Success #stdin #stdout 3.56s 29176KB
stdin
Standard input is empty
stdout
Epoch 0, Error: 30.001706882323475
Epoch 100, Error: 5.691814090306985
Epoch 200, Error: 2.6755049196225897
Epoch 300, Error: 1.2893255968426054
Epoch 400, Error: 0.6394796565970304
Epoch 500, Error: 0.32878488413761187
Epoch 600, Error: 0.1762628377505781
Epoch 700, Error: 0.09882417777673828
Epoch 800, Error: 0.057896426582148025
Epoch 900, Error: 0.03528749909466773
Epoch 1000, Error: 0.022227717355829655
Epoch 1100, Error: 0.01436559134865939
Epoch 1200, Error: 0.0094626399996172
Epoch 1300, Error: 0.0063180931826432776
Epoch 1400, Error: 0.004258313244813741
Epoch 1500, Error: 0.002888422488023718
Epoch 1600, Error: 0.0019676269958325155
Epoch 1700, Error: 0.0013441898983668644
Epoch 1800, Error: 0.0009200155534099208
Epoch 1900, Error: 0.000630474187688474
Table 1: Input: [0, 0, 0, 0], Target: [0, 0, 0, 0], Prediction: [0.00557628 0.00708622 0.0046906  0.00512291], Error: [0.00557628 0.00708622 0.0046906  0.00512291]
Table 2: Input: [0, 0, 0, 1], Target: [0, 0, 0, 1], Prediction: [0.00317719 0.00404943 0.00271049 1.0026057 ], Error: [0.00317719 0.00404943 0.00271049 0.0026057 ]
Table 3: Input: [0, 0, 1, 0], Target: [0, 0, 1, 0], Prediction: [0.00323726 0.00404888 1.00241762 0.00294885], Error: [0.00323726 0.00404888 0.00241762 0.00294885]
Table 4: Input: [0, 0, 1, 1], Target: [0, 0, 1, 1], Prediction: [8.38161393e-04 1.01208912e-03 1.00043752e+00 1.00043163e+00], Error: [0.00083816 0.00101209 0.00043752 0.00043163]
Table 5: Input: [0, 1, 0, 0], Target: [0, 1, 0, 0], Prediction: [0.00320225 1.00372398 0.00267845 0.00296315], Error: [0.00320225 0.00372398 0.00267845 0.00296315]
Table 6: Input: [0, 1, 0, 1], Target: [0, 1, 0, 1], Prediction: [8.03152844e-04 1.00068719e+00 6.98338575e-04 1.00044593e+00], Error: [0.00080315 0.00068719 0.00069834 0.00044593]
Table 7: Input: [0, 1, 1, 0], Target: [0, 1, 1, 0], Prediction: [8.63223516e-04 1.00068664e+00 1.00040547e+00 7.89081134e-04], Error: [0.00086322 0.00068664 0.00040547 0.00078908]
Table 8: Input: [0, 1, 1, 1], Target: [0, 1, 1, 1], Prediction: [-0.00153587  0.99764985  0.99842537  0.99827187], Error: [0.00153587 0.00235015 0.00157463 0.00172813]
Table 9: Input: [1, 0, 0, 0], Target: [1, 0, 0, 0], Prediction: [1.00288204 0.00405951 0.00270566 0.00296291], Error: [0.00288204 0.00405951 0.00270566 0.00296291]
Table 10: Input: [1, 0, 0, 1], Target: [1, 0, 0, 1], Prediction: [1.00048294e+00 1.02271712e-03 7.25549842e-04 1.00044570e+00], Error: [0.00048294 0.00102272 0.00072555 0.0004457 ]
Table 11: Input: [1, 0, 1, 0], Target: [1, 0, 1, 0], Prediction: [1.00054301e+00 1.02216817e-03 1.00043269e+00 7.88848089e-04], Error: [0.00054301 0.00102217 0.00043269 0.00078885]
Table 12: Input: [1, 0, 1, 1], Target: [1, 0, 1, 1], Prediction: [ 0.99814392 -0.00201462  0.99845258  0.99827163], Error: [0.00185608 0.00201462 0.00154742 0.00172837]
Table 13: Input: [1, 1, 0, 0], Target: [1, 1, 0, 0], Prediction: [1.00050801e+00 1.00069727e+00 6.93506195e-04 8.03147341e-04], Error: [0.00050801 0.00069727 0.00069351 0.00080315]
Table 14: Input: [1, 1, 0, 1], Target: [1, 1, 0, 1], Prediction: [ 0.99810891  0.99766048 -0.0012866   0.99828593], Error: [0.00189109 0.00233952 0.0012866  0.00171407]
Table 15: Input: [1, 1, 1, 0], Target: [1, 1, 1, 0], Prediction: [ 0.99816898  0.99765993  0.99842053 -0.00137092], Error: [0.00183102 0.00234007 0.00157947 0.00137092]
Table 16: Input: [1, 1, 1, 1], Target: [1, 1, 1, 1], Prediction: [0.99576988 0.99462314 0.99644043 0.99611187], Error: [0.00423012 0.00537686 0.00355957 0.00388813]