You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

example4.c 3.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. #include <stdio.h>
  2. #include <stdlib.h>
  3. #include <time.h>
  4. #include <string.h>
  5. #include <math.h>
  6. #include "genann.h"
  7. /* This example is to illustrate how to use GENANN.
  8. * It is NOT an example of good machine learning techniques.
  9. */
  /* Path of the iris CSV file, opened relative to the working directory. */
  const char *iris_data = "example/iris.data";

  /* Data set filled in by load_data():
   *   input   - 4 feature values per sample, stored row after row
   *   class   - 3 one-hot label values per sample, stored row after row
   *   samples - number of rows held in the two arrays above */
  double *input;
  double *class;
  int samples;

  /* Label strings as they appear in the data file; the position of a name
   * in this table is the output unit that represents that class. */
  const char *class_names[] = {
      "Iris-setosa",
      "Iris-versicolor",
      "Iris-virginica",
  };
  14. void load_data() {
  15. /* Load the iris data-set. */
  16. FILE *in = fopen("example/iris.data", "r");
  17. if (!in) {
  18. printf("Could not open file: %s\n", iris_data);
  19. exit(1);
  20. }
  21. /* Loop through the data to get a count. */
  22. char line[1024];
  23. while (!feof(in) && fgets(line, 1024, in)) {
  24. ++samples;
  25. }
  26. fseek(in, 0, SEEK_SET);
  27. printf("Loading %d data points from %s\n", samples, iris_data);
  28. /* Allocate memory for input and output data. */
  29. input = malloc(sizeof(double) * samples * 4);
  30. class = malloc(sizeof(double) * samples * 3);
  31. /* Read the file into our arrays. */
  32. int i, j;
  33. for (i = 0; i < samples; ++i) {
  34. double *p = input + i * 4;
  35. double *c = class + i * 3;
  36. c[0] = c[1] = c[2] = 0.0;
  37. if (fgets(line, 1024, in) == NULL) {
  38. perror("fgets");
  39. exit(1);
  40. }
  41. char *split = strtok(line, ",");
  42. for (j = 0; j < 4; ++j) {
  43. p[j] = atof(split);
  44. split = strtok(0, ",");
  45. }
  46. split[strlen(split)-1] = 0;
  47. if (strcmp(split, class_names[0]) == 0) {c[0] = 1.0;}
  48. else if (strcmp(split, class_names[1]) == 0) {c[1] = 1.0;}
  49. else if (strcmp(split, class_names[2]) == 0) {c[2] = 1.0;}
  50. else {
  51. printf("Unknown class %s.\n", split);
  52. exit(1);
  53. }
  54. /* printf("Data point %d is %f %f %f %f -> %f %f %f\n", i, p[0], p[1], p[2], p[3], c[0], c[1], c[2]); */
  55. }
  56. fclose(in);
  57. }
  58. int main(int argc, char *argv[])
  59. {
  60. printf("GENANN example 4.\n");
  61. printf("Train an ANN on the IRIS dataset using backpropagation.\n");
  62. srand(time(0));
  63. /* Load the data from file. */
  64. load_data();
  65. /* 4 inputs.
  66. * 1 hidden layer(s) of 4 neurons.
  67. * 3 outputs (1 per class)
  68. */
  69. genann *ann = genann_init(4, 1, 4, 3);
  70. int i, j;
  71. int loops = 5000;
  72. /* Train the network with backpropagation. */
  73. printf("Training for %d loops over data.\n", loops);
  74. for (i = 0; i < loops; ++i) {
  75. for (j = 0; j < samples; ++j) {
  76. genann_train(ann, input + j*4, class + j*3, .01);
  77. }
  78. /* printf("%1.2f ", xor_score(ann)); */
  79. }
  80. int correct = 0;
  81. for (j = 0; j < samples; ++j) {
  82. const double *guess = genann_run(ann, input + j*4);
  83. if (class[j*3+0] == 1.0) {if (guess[0] > guess[1] && guess[0] > guess[2]) ++correct;}
  84. else if (class[j*3+1] == 1.0) {if (guess[1] > guess[0] && guess[1] > guess[2]) ++correct;}
  85. else if (class[j*3+2] == 1.0) {if (guess[2] > guess[0] && guess[2] > guess[1]) ++correct;}
  86. else {printf("Logic error.\n"); exit(1);}
  87. }
  88. printf("%d/%d correct (%0.1f%%).\n", correct, samples, (double)correct / samples * 100.0);
  89. genann_free(ann);
  90. free(input);
  91. free(class);
  92. return 0;
  93. }