Linear regression

Last Updated: Oct 31, 2017

In statistics, linear regression is a method for modeling the relationship between two or more variables. Unlike classification algorithms, which predict discrete labels, regression algorithms predict continuous values. Linear regression defines its loss function as the sum of squared errors over the sample set and computes the weight vector by minimizing that loss.
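
With the hypothesis $h_\theta(x) = \theta_0 + \theta_1 x_1 + \dots + \theta_n x_n$ fitted over $m$ samples $(x^{(i)}, y^{(i)})$, the loss takes the standard least-squares form; the $1/2m$ factor matches how the code below divides the accumulated squared error by 2 * M:

$$J(\theta) = \frac{1}{2m} \sum_{i=1}^{m} \bigl( h_\theta(x^{(i)}) - y^{(i)} \bigr)^2$$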

A common solution is gradient descent, which:

  • Initializes the weight vector and sets the learning rate and the number of iterations (or the convergence criteria).
  • Calculates the squared error of each sample.
  • Sums the squared errors and updates the weights based on the learning rate.
  • Repeats the iterations until convergence is reached (a minimal single-machine sketch of this loop follows the list).
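
The sketch below shows that loop on a single machine, assuming plain Java arrays and batch updates; the names (GradientDescentSketch, fit, x, y, alpha) are illustrative and not part of the ODPS API used afterwards.

public class GradientDescentSketch {
  // x[i] holds the features of sample i; y[i] is its target value.
  static double[] fit(double[][] x, double[] y, double alpha, int iterations) {
    int m = x.length;                   // number of samples
    int n = x[0].length;                // number of features
    double[] theta = new double[n + 1]; // theta[0] is the intercept
    for (int it = 0; it < iterations; it++) {
      double[] grad = new double[n + 1];
      for (int i = 0; i < m; i++) {
        // residual = h_theta(x(i)) - y(i)
        double h = theta[0];
        for (int j = 0; j < n; j++) h += theta[j + 1] * x[i][j];
        double residual = h - y[i];
        // accumulate this sample's contribution to each gradient dimension
        grad[0] += residual;
        for (int j = 0; j < n; j++) grad[j + 1] += residual * x[i][j];
      }
      // theta(j) = theta(j) - alpha / m * grad(j)
      for (int j = 0; j <= n; j++) theta[j] -= alpha / m * grad[j];
    }
    return theta;
  }
}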

Source code
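
In the implementation below, the GraphLoader turns each input record into a vertex that holds one sample. Every superstep, each vertex's compute() hands its sample to the aggregator: aggregate() accumulates each sample's contribution to the gradient on its worker, merge() sums the partial gradients across workers, and terminate() applies the weight update and decides whether another iteration is needed.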

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.graph.Aggregator;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.WorkerContext;
import com.aliyun.odps.io.DoubleWritable;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.NullWritable;
import com.aliyun.odps.io.Tuple;
import com.aliyun.odps.io.Writable;
import com.aliyun.odps.io.WritableRecord;

/**
 * LinearRegression input: y,x1,x2,x3,...
 */
public class LinearRegression {

  public static class GradientWritable implements Writable {
    Tuple lastTheta;     // theta from the previous iteration
    Tuple currentTheta;  // theta used in the current iteration
    Tuple tmpGradient;   // gradient accumulated during the current iteration
    LongWritable count;  // iteration counter
    DoubleWritable lost; // loss accumulated during the current iteration

    @Override
    public void readFields(DataInput in) throws IOException {
      lastTheta = new Tuple();
      lastTheta.readFields(in);
      currentTheta = new Tuple();
      currentTheta.readFields(in);
      tmpGradient = new Tuple();
      tmpGradient.readFields(in);
      count = new LongWritable();
      count.readFields(in);
      /* update 1: add a variable to store the loss at every iteration */
      lost = new DoubleWritable();
      lost.readFields(in);
    }

    @Override
    public void write(DataOutput out) throws IOException {
      lastTheta.write(out);
      currentTheta.write(out);
      tmpGradient.write(out);
      count.write(out);
      lost.write(out);
    }
  }

  public static class LinearRegressionVertex extends
      Vertex<LongWritable, Tuple, NullWritable, NullWritable> {
    @Override
    public void compute(
        ComputeContext<LongWritable, Tuple, NullWritable, NullWritable> context,
        Iterable<NullWritable> messages) throws IOException {
      // Each vertex holds one sample; hand it to the aggregator every superstep.
      context.aggregate(getValue());
    }
  }

  public static class LinearRegressionVertexReader extends
      GraphLoader<LongWritable, Tuple, NullWritable, NullWritable> {
    @Override
    public void load(LongWritable recordNum, WritableRecord record,
        MutationContext<LongWritable, Tuple, NullWritable, NullWritable> context)
        throws IOException {
      // One input record (y, x1, x2, ...) becomes one vertex.
      LinearRegressionVertex vertex = new LinearRegressionVertex();
      vertex.setId(recordNum);
      vertex.setValue(new Tuple(record.getAll()));
      context.addVertexRequest(vertex);
    }
  }

  public static class LinearRegressionAggregator extends
      Aggregator<GradientWritable> {

    @SuppressWarnings("rawtypes")
    @Override
    public GradientWritable createInitialValue(WorkerContext context)
        throws IOException {
      if (context.getSuperstep() == 0) {
        /* Set the initial value: all zeros. */
        GradientWritable grad = new GradientWritable();
        grad.lastTheta = new Tuple();
        grad.currentTheta = new Tuple();
        grad.tmpGradient = new Tuple();
        grad.count = new LongWritable(1);
        grad.lost = new DoubleWritable(0.0);
        int n = (int) Long.parseLong(context.getConfiguration().get("Dimension"));
        for (int i = 0; i < n; i++) {
          grad.lastTheta.append(new DoubleWritable(0));
          grad.currentTheta.append(new DoubleWritable(0));
          grad.tmpGradient.append(new DoubleWritable(0));
        }
        return grad;
      } else {
        return (GradientWritable) context.getLastAggregatedValue(0);
      }
    }

    public static double vecMul(Tuple value, Tuple theta) {
      /* Partial computation: h_theta(x(i)) - y(i) for one sample. */
      /* value holds one sample, and value(0) is y. */
      double sum = 0.0;
      for (int j = 1; j < value.size(); j++)
        sum += Double.parseDouble(value.get(j).toString())
            * Double.parseDouble(theta.get(j).toString());
      return Double.parseDouble(theta.get(0).toString()) + sum
          - Double.parseDouble(value.get(0).toString());
    }

    @Override
    public void aggregate(GradientWritable gradient, Object value)
        throws IOException {
      /*
       * Runs once per vertex, i.e. once per sample i, accumulating that
       * sample's contribution to every dimension j of the gradient.
       */
      double tmpVar = vecMul((Tuple) value, gradient.currentTheta);
      /*
       * update 2: the local worker's aggregate() behaves like merge() below,
       * so the gradient argument holds the previously aggregated value.
       */
      gradient.tmpGradient.set(0, new DoubleWritable(
          ((DoubleWritable) gradient.tmpGradient.get(0)).get() + tmpVar));
      // Accumulate the squared error rather than overwrite it, so that lost
      // reflects every sample on this worker, not just the last one.
      gradient.lost.set(gradient.lost.get() + Math.pow(tmpVar, 2));
      /*
       * Calculate (h_theta(x(i)) - y(i)) * x(i)(j) for sample i, for each
       * dimension j.
       */
      for (int j = 1; j < gradient.tmpGradient.size(); j++)
        gradient.tmpGradient.set(j, new DoubleWritable(
            ((DoubleWritable) gradient.tmpGradient.get(j)).get() + tmpVar
                * Double.parseDouble(((Tuple) value).get(j).toString())));
    }

    @Override
    public void merge(GradientWritable gradient, GradientWritable partial)
        throws IOException {
      /* Sum each dimension's partial gradients over all workers. */
      Tuple master = gradient.tmpGradient;
      Tuple part = partial.tmpGradient;
      for (int j = 0; j < gradient.tmpGradient.size(); j++) {
        DoubleWritable s = (DoubleWritable) master.get(j);
        s.set(s.get() + ((DoubleWritable) part.get(j)).get());
      }
      gradient.lost.set(gradient.lost.get() + partial.lost.get());
    }

    @SuppressWarnings("rawtypes")
    @Override
    public boolean terminate(WorkerContext context, GradientWritable gradient)
        throws IOException {
      /*
       * 1. Calculate the new theta.
       * 2. Compare this step's theta with the last step's; if the difference
       *    is smaller than the threshold, stop iterating.
       */
      gradient.lost = new DoubleWritable(gradient.lost.get()
          / (2 * context.getTotalNumVertices()));
      /*
       * Printing the loss lets us confirm that the algorithm is moving in the
       * right direction (for debugging).
       */
      System.out.println(gradient.count + " lost:" + gradient.lost);
      Tuple tmpGradient = gradient.tmpGradient;
      System.out.println("tmpGra" + tmpGradient);
      Tuple lastTheta = gradient.lastTheta;
      Tuple tmpCurrentTheta = new Tuple(gradient.currentTheta.size());
      System.out.println(gradient.count + " terminate_start_last:" + lastTheta);
      double alpha = 0.07; // learning rate
      // alpha = Double.parseDouble(context.getConfiguration().get("Alpha"));
      /* Perform theta(j) = theta(j) - alpha / M * tmpGradient(j). */
      long M = context.getTotalNumVertices();
      /*
       * update 3: divide by M here; the original code missed this step.
       */
      for (int j = 0; j < lastTheta.size(); j++) {
        tmpCurrentTheta.set(j, new DoubleWritable(
            Double.parseDouble(lastTheta.get(j).toString())
                - alpha / M * Double.parseDouble(tmpGradient.get(j).toString())));
      }
      System.out.println(gradient.count + " terminate_start_current:"
          + tmpCurrentTheta);
      // Check for convergence (the distance threshold is commented out, so
      // only the iteration cap below ends the job).
      double diff = 0.00d;
      for (int j = 0; j < gradient.currentTheta.size(); j++)
        diff += Math.pow(((DoubleWritable) tmpCurrentTheta.get(j)).get()
            - ((DoubleWritable) lastTheta.get(j)).get(), 2);
      if (/* Math.sqrt(diff) < 0.00000000005d || */
          Long.parseLong(context.getConfiguration().get("Max_Iter_Num"))
              == gradient.count.get()) {
        context.write(gradient.currentTheta.toArray());
        return true;
      }
      gradient.lastTheta = tmpCurrentTheta;
      gradient.currentTheta = tmpCurrentTheta;
      gradient.count.set(gradient.count.get() + 1);
      int n = (int) Long.parseLong(context.getConfiguration().get("Dimension"));
      /*
       * update 4: Important! Graph does not reset global variables at the
       * beginning of each iteration, so zero the gradient manually.
       */
      for (int i = 0; i < n; i++) {
        gradient.tmpGradient.set(i, new DoubleWritable(0));
      }
      return false;
    }
  }

  public static void main(String[] args) throws IOException {
    GraphJob job = new GraphJob();
    job.setGraphLoaderClass(LinearRegressionVertexReader.class);
    job.setRuntimePartitioning(false);
    job.setNumWorkers(3);
    job.setVertexClass(LinearRegressionVertex.class);
    job.setAggregatorClass(LinearRegressionAggregator.class);
    job.addInput(TableInfo.builder().tableName(args[0]).build());
    job.addOutput(TableInfo.builder().tableName(args[1]).build());
    job.setMaxIteration(Integer.parseInt(args[2])); // number of iterations
    job.setInt("Max_Iter_Num", Integer.parseInt(args[2]));
    job.setInt("Dimension", Integer.parseInt(args[3])); // dimension
    job.setFloat("Alpha", Float.parseFloat(args[4])); // learning rate
    long start = System.currentTimeMillis();
    job.run();
    System.out.println("Job Finished in "
        + (System.currentTimeMillis() - start) / 1000.0 + " seconds");
  }
}
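
The job takes five positional arguments, in order: the input table, the output table, the maximum number of iterations (passed both to setMaxIteration and to the Max_Iter_Num configuration checked in terminate()), the dimension n of the weight vector (theta(0) is the intercept, so n equals the number of columns in the input table), and the learning rate. Note that terminate() as written hard-codes alpha = 0.07; uncomment the line that reads the Alpha configuration value to use the fifth argument instead. Each input record is one sample in the form y,x1,x2,...,xn.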