Linear Regression

Last Updated: Jun 22, 2016

In statistics, linear regression is a statistical analysis method used to model the relationship between two or more variables. Unlike classification algorithms, which predict discrete labels, regression algorithms predict continuous values. The linear regression algorithm defines the loss function as the sum of squared errors over the sample set, and solves for the weight vector by minimizing this loss function. The most commonly used method is gradient descent:

  • Initialize the weight vector, and set the descent (learning) rate and the number of iterations (or the convergence condition for the iteration);
  • For each sample, calculate the squared error;
  • Sum the squared errors over all samples and update the weights according to the descent rate;
  • Repeat the iterations until convergence. A plain-Java sketch of this loop is shown below.
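
The same loop can be written in a few lines of plain Java, independent of MaxCompute Graph. The sketch below is illustrative only: the tiny data set and the class name GradientDescentSketch are made up, and the learning rate of 0.07 and the iteration count of 1500 simply mirror the sample parameters used in the source code further down.

public class GradientDescentSketch {
  public static void main(String[] args) {
    // training samples: features (x1, x2) with labels y generated from y = 1 + 2*x1 + 3*x2
    double[][] x = { { 1, 1 }, { 2, 0 }, { 0, 2 }, { 3, 1 } };
    double[] y = { 6, 5, 7, 10 };
    int m = x.length;    // number of samples
    int n = x[0].length; // number of feature dimensions

    double[] theta = new double[n + 1]; // theta[0] is the intercept; initialized to all 0
    double alpha = 0.07;                // descent (learning) rate
    int maxIter = 1500;                 // iteration count

    for (int iter = 0; iter < maxIter; iter++) {
      double[] grad = new double[n + 1];
      double loss = 0.0;
      // for each sample, accumulate its squared error and its gradient contribution
      for (int i = 0; i < m; i++) {
        double err = theta[0];
        for (int j = 0; j < n; j++)
          err += theta[j + 1] * x[i][j];
        err -= y[i];
        loss += err * err;
        grad[0] += err;
        for (int j = 0; j < n; j++)
          grad[j + 1] += err * x[i][j];
      }
      // update the weights according to the descent rate
      for (int j = 0; j <= n; j++)
        theta[j] -= alpha / m * grad[j];
      if (iter % 300 == 0)
        System.out.println("iteration " + iter + ", loss = " + loss / (2 * m));
    }
    System.out.println("theta = " + java.util.Arrays.toString(theta));
  }
}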

Source Code

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.graph.Aggregator;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.WorkerContext;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.io.DoubleWritable;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.NullWritable;
import com.aliyun.odps.io.Tuple;
import com.aliyun.odps.io.Writable;
import com.aliyun.odps.io.WritableRecord;

/**
 * LinearRegression input: y,x1,x2,x3,......
 *
 * @author shiwan.ch
 * @update jiasheng.tjs running parameters are like: tjs_lr_in tjs_lr_out 1500 2 0.07
 */
public class LinearRegression {

  /** Aggregator value: the shared state of one gradient descent iteration. */
  public static class GradientWritable implements Writable {
    Tuple lastTheta;     // weights of the previous iteration
    Tuple currentTheta;  // weights used in the current iteration
    Tuple tmpGradient;   // gradient accumulated over the samples of the current iteration
    LongWritable count;  // iteration counter
    DoubleWritable lost; // loss of the current iteration

    @Override
    public void readFields(DataInput in) throws IOException {
      lastTheta = new Tuple();
      lastTheta.readFields(in);
      currentTheta = new Tuple();
      currentTheta.readFields(in);
      tmpGradient = new Tuple();
      tmpGradient.readFields(in);
      count = new LongWritable();
      count.readFields(in);
      /* update 1: add a variable to store the loss at every iteration */
      lost = new DoubleWritable();
      lost.readFields(in);
    }

    @Override
    public void write(DataOutput out) throws IOException {
      lastTheta.write(out);
      currentTheta.write(out);
      tmpGradient.write(out);
      count.write(out);
      lost.write(out);
    }
  }
  /** Each vertex holds one training sample and feeds it to the aggregator. */
  public static class LinearRegressionVertex extends
      Vertex<LongWritable, Tuple, NullWritable, NullWritable> {

    @Override
    public void compute(
        ComputeContext<LongWritable, Tuple, NullWritable, NullWritable> context,
        Iterable<NullWritable> messages) throws IOException {
      context.aggregate(getValue());
    }
  }

  /** Loads every input record as one vertex whose value is the tuple (y, x1, ..., xn). */
  public static class LinearRegressionVertexReader extends
      GraphLoader<LongWritable, Tuple, NullWritable, NullWritable> {

    @Override
    public void load(LongWritable recordNum, WritableRecord record,
        MutationContext<LongWritable, Tuple, NullWritable, NullWritable> context)
        throws IOException {
      LinearRegressionVertex vertex = new LinearRegressionVertex();
      vertex.setId(recordNum);
      vertex.setValue(new Tuple(record.getAll()));
      context.addVertexRequest(vertex);
    }
  }
  public static class LinearRegressionAggregator extends
      Aggregator<GradientWritable> {

    @SuppressWarnings("rawtypes")
    @Override
    public GradientWritable createInitialValue(WorkerContext context)
        throws IOException {
      if (context.getSuperstep() == 0) {
        /* set initial value, all 0 */
        GradientWritable grad = new GradientWritable();
        grad.lastTheta = new Tuple();
        grad.currentTheta = new Tuple();
        grad.tmpGradient = new Tuple();
        grad.count = new LongWritable(1);
        grad.lost = new DoubleWritable(0.0);
        int n = (int) Long.parseLong(context.getConfiguration()
            .get("Dimension"));
        for (int i = 0; i < n; i++) {
          grad.lastTheta.append(new DoubleWritable(0));
          grad.currentTheta.append(new DoubleWritable(0));
          grad.tmpGradient.append(new DoubleWritable(0));
        }
        return grad;
      } else {
        return (GradientWritable) context.getLastAggregatedValue(0);
      }
    }

    /* prediction error of one sample: theta(0) + sum_j theta(j)*x(j) - y */
    public static double vecMul(Tuple value, Tuple theta) {
      double sum = 0.0;
      for (int j = 1; j < value.size(); j++)
        sum += Double.parseDouble(value.get(j).toString())
            * Double.parseDouble(theta.get(j).toString());
      Double tmp = Double.parseDouble(theta.get(0).toString()) + sum
          - Double.parseDouble(value.get(0).toString());
      return tmp;
    }
    @Override
    public void aggregate(GradientWritable gradient, Object value)
        throws IOException {
      /*
       * performed on each vertex, i.e. each sample i: accumulate the gradient
       * contribution of sample i for every dimension j
       */
      double tmpVar = vecMul((Tuple) value, gradient.currentTheta);
      /*
       * update 2: the local worker aggregate() performs like merge() below, so
       * the variable gradient denotes the previously aggregated value
       */
      gradient.tmpGradient.set(0, new DoubleWritable(
          ((DoubleWritable) gradient.tmpGradient.get(0)).get() + tmpVar));
      /* accumulate the squared error; overwriting it here would keep only the last sample */
      gradient.lost.set(gradient.lost.get() + Math.pow(tmpVar, 2));
      for (int j = 1; j < gradient.tmpGradient.size(); j++)
        gradient.tmpGradient.set(j, new DoubleWritable(
            ((DoubleWritable) gradient.tmpGradient.get(j)).get() + tmpVar
                * Double.parseDouble(((Tuple) value).get(j).toString())));
    }
    @Override
    public void merge(GradientWritable gradient, GradientWritable partial)
        throws IOException {
      /* perform SumAll on each dimension for all samples. */
      Tuple master = (Tuple) gradient.tmpGradient;
      Tuple part = (Tuple) partial.tmpGradient;
      for (int j = 0; j < gradient.tmpGradient.size(); j++) {
        DoubleWritable s = (DoubleWritable) master.get(j);
        s.set(s.get() + ((DoubleWritable) part.get(j)).get());
      }
      gradient.lost.set(gradient.lost.get() + partial.lost.get());
    }
    @SuppressWarnings("rawtypes")
    @Override
    public boolean terminate(WorkerContext context, GradientWritable gradient)
        throws IOException {
      /*
       * 1. calculate the new theta 2. judge the diff between the last step and
       * this step; if it is smaller than the threshold, stop the iteration
       */
      gradient.lost = new DoubleWritable(gradient.lost.get()
          / (2 * context.getTotalNumVertices()));
      /*
       * we can print the loss to make sure the algorithm is running in the
       * right direction (for debugging)
       */
      System.out.println(gradient.count + " lost:" + gradient.lost);
      Tuple tmpGradient = gradient.tmpGradient;
      System.out.println("tmpGra" + tmpGradient);
      Tuple lastTheta = gradient.lastTheta;
      Tuple tmpCurrentTheta = new Tuple(gradient.currentTheta.size());
      System.out.println(gradient.count + " terminate_start_last:" + lastTheta);
      double alpha = 0.07; // learning rate
      // alpha = Double.parseDouble(context.getConfiguration().get("Alpha"));
      /* perform theta(j) = theta(j) - alpha / M * tmpGradient(j) */
      long M = context.getTotalNumVertices();
      /*
       * update 3: add (/M) to the update. The original code forgot this step
       */
      for (int j = 0; j < lastTheta.size(); j++) {
        tmpCurrentTheta.set(j, new DoubleWritable(
            Double.parseDouble(lastTheta.get(j).toString())
                - alpha / M
                * Double.parseDouble(tmpGradient.get(j).toString())));
      }
      System.out.println(gradient.count + " terminate_start_current:"
          + tmpCurrentTheta);
      // judge whether convergence has been reached.
      double diff = 0.00d;
      for (int j = 0; j < gradient.currentTheta.size(); j++)
        diff += Math.pow(((DoubleWritable) tmpCurrentTheta.get(j)).get()
            - ((DoubleWritable) lastTheta.get(j)).get(), 2);
      if (/* Math.sqrt(diff) < 0.00000000005d || */
          Long.parseLong(context.getConfiguration().get("Max_Iter_Num")) == gradient.count.get()) {
        context.write(gradient.currentTheta.toArray());
        return true;
      }
      gradient.lastTheta = tmpCurrentTheta;
      gradient.currentTheta = tmpCurrentTheta;
      gradient.count.set(gradient.count.get() + 1);
      int n = (int) Long.parseLong(context.getConfiguration().get("Dimension"));
      /*
       * update 4: Important!!! Remember this step. Graph won't reset the
       * initial value of the global variables at the beginning of each iteration
       */
      for (int i = 0; i < n; i++) {
        gradient.tmpGradient.set(i, new DoubleWritable(0));
      }
      /* likewise, clear the accumulated loss before the next iteration */
      gradient.lost = new DoubleWritable(0.0);
      return false;
    }
  }
  public static void main(String[] args) throws IOException {
    GraphJob job = new GraphJob();
    job.setGraphLoaderClass(LinearRegressionVertexReader.class);
    job.setRuntimePartitioning(false);
    job.setNumWorkers(3);
    job.setVertexClass(LinearRegressionVertex.class);
    job.setAggregatorClass(LinearRegressionAggregator.class);
    job.addInput(TableInfo.builder().tableName(args[0]).build());
    job.addOutput(TableInfo.builder().tableName(args[1]).build());
    job.setMaxIteration(Integer.parseInt(args[2])); // number of iterations
    job.setInt("Max_Iter_Num", Integer.parseInt(args[2]));
    job.setInt("Dimension", Integer.parseInt(args[3])); // dimension
    job.setFloat("Alpha", Float.parseFloat(args[4])); // learning rate
    long start = System.currentTimeMillis();
    job.run();
    System.out.println("Job Finished in "
        + (System.currentTimeMillis() - start) / 1000.0 + " seconds");
  }
}
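
The job reads its parameters from the command line in the order given in the header comment: input table, output table, maximum number of iterations, dimension, and learning rate. Assuming the compiled class has been packaged into a jar and uploaded as a resource, an invocation from the MaxCompute console might look like the following; the jar name and table names are placeholders taken from the header comment, not fixed values:

jar -libjars linear_regression.jar -classpath linear_regression.jar LinearRegression tjs_lr_in tjs_lr_out 1500 2 0.07;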