You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

72 lines
2.1KB

  1. #include <stdio.h>
  2. #include <stdlib.h>
  3. #include <time.h>
  4. #include <math.h>
  5. #include "genann.h"
  6. int main(int argc, char *argv[])
  7. {
  8. printf("GENANN example 2.\n");
  9. printf("Train a small ANN to the XOR function using random search.\n");
  10. srand(time(0));
  11. /* Input and expected out data for the XOR function. */
  12. const double input[4][2] = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
  13. const double output[4] = {0, 1, 1, 0};
  14. int i;
  15. /* New network with 2 inputs,
  16. * 1 hidden layer of 2 neurons,
  17. * and 1 output. */
  18. genann *ann = genann_init(2, 1, 2, 1);
  19. double err;
  20. double last_err = 1000;
  21. int count = 0;
  22. do {
  23. ++count;
  24. if (count % 1000 == 0) {
  25. /* We're stuck, start over. */
  26. genann_randomize(ann);
  27. last_err = 1000;
  28. }
  29. genann *save = genann_copy(ann);
  30. /* Take a random guess at the ANN weights. */
  31. for (i = 0; i < ann->total_weights; ++i) {
  32. ann->weight[i] += ((double)rand())/RAND_MAX-0.5;
  33. }
  34. /* See how we did. */
  35. err = 0;
  36. err += pow(*genann_run(ann, input[0]) - output[0], 2.0);
  37. err += pow(*genann_run(ann, input[1]) - output[1], 2.0);
  38. err += pow(*genann_run(ann, input[2]) - output[2], 2.0);
  39. err += pow(*genann_run(ann, input[3]) - output[3], 2.0);
  40. /* Keep these weights if they're an improvement. */
  41. if (err < last_err) {
  42. genann_free(save);
  43. last_err = err;
  44. } else {
  45. genann_free(ann);
  46. ann = save;
  47. }
  48. } while (err > 0.01);
  49. printf("Finished in %d loops.\n", count);
  50. /* Run the network and see what it predicts. */
  51. printf("Output for [%1.f, %1.f] is %1.f.\n", input[0][0], input[0][1], *genann_run(ann, input[0]));
  52. printf("Output for [%1.f, %1.f] is %1.f.\n", input[1][0], input[1][1], *genann_run(ann, input[1]));
  53. printf("Output for [%1.f, %1.f] is %1.f.\n", input[2][0], input[2][1], *genann_run(ann, input[2]));
  54. printf("Output for [%1.f, %1.f] is %1.f.\n", input[3][0], input[3][1], *genann_run(ann, input[3]));
  55. genann_free(ann);
  56. return 0;
  57. }