Function Values using Differential Evolution - java

How can I use differential evolution to find the maximum value of the function f(x) = -x(x+1) on the interval from -500 to 500? I need this for a chess program I am making. I have begun researching Differential Evolution and am still finding it quite difficult to understand, let alone use in a program. Can anyone please help me by introducing me to the algorithm in a simple way and possibly giving some example pseudo-code for such a program?

First of all, sorry for the late reply.
I bet that you won't know the derivatives of the function that you'll be trying to max, that's why you want to use the Differential Evolution algorithm and not something like the Newton-Raphson method.
I found a great link that explains Differential Evolution in a straightforward manner: http://web.as.uky.edu/statistics/users/viele/sta705s06/diffev.pdf.
On the first page, there is a section with an explanation of the algorithm:
Let each generation of points consist of n points, with j terms in each.
Initialize an array with size j. Add a number j of distinct random x values from -500 to 500, the interval you are considering right now. Ideally, you would know around where the maximum value would be, and you would make it more probable for your x values to be there.
For each j, randomly select two points yj,1 and yj,2 uniformly from the set of points x^(m).
Construct a candidate point cj = x^(m)_j + α(yj,1 − yj,2). Basically the two y values involve picking a random direction and distance, and the candidate is found by adding that random direction and distance (scaled by α) to the current value.
Hmmm... This is a bit more complicated. Iterate through the array you made in the last step. For each x value, pick two random elements of the array (yj1 and yj2). Construct a candidate x value with cx = x + α(yj1 − yj2), where you choose your α. You can try experimenting with different values of α.
Check which one gives the larger function value, the candidate or the x value at j. If f(candidate) is larger, replace the x value at j with the candidate.
Do this all until all of the values in the array are more or less similar.
Ta-da, any of the values in the array will approximate the x value at which the function attains its maximum. Just to reduce randomness (or maybe this is not important...), average them all together.
The more stringent you make the about method (the convergence check in the code below), the better the approximation you will get, but the more time it will take.
For example, instead of Math.abs(a - b) <= alpha /10, I would do Math.abs(a - b) <= alpha /10000 to get a better approximation.
You will get a good approximation of the value that you want.
Happy coding!
Code I wrote for this response:
public class DifferentialEvolution {
public static final double alpha = 0.001;
public static double evaluate(double x) {
return -x*(x+1);
}
public static double max(int N) { // N is initial array size.
double[] xs = new double[N];
for(int j = 0; j < N; j++) {
xs[j] = Math.random()*1000.0 - 500.0; // Number from -500 to 500.
}
boolean done = false;
while(!done) {
for(int j = 0; j < N; j++) {
double yj1 = xs[(int)(Math.random()*N)]; // This might include xs[j], but that shouldn't be a problem.
double yj2 = xs[(int)(Math.random()*N)]; // It will only slow things down a bit.
double cj = xs[j] + alpha*(yj1-yj2);
if(evaluate(cj) > evaluate(xs[j])) {
xs[j] = cj;
}
}
double average = average(xs); // Edited
done = true;
for(int j = 0; j < N; j++) { // Edited
if(!about(xs[j], average)) { // Edited
done = false;
break;
}
}
}
return average(xs);
}
public static double average(double[] values) {
double sum = 0;
for(int i = 0; i < values.length; i++) {
sum += values[i];
}
return sum/values.length;
}
public static boolean about(double a, double b) {
if(Math.abs(a - b) <= alpha /10000) { // This should work.
return true;
}
return false;
}
public static void main(String[] args) {
long t = System.currentTimeMillis();
System.out.println(max(3));
System.out.println("Time (Milliseconds): " + (System.currentTimeMillis() - t));
}
}
If you have any questions after reading this, feel free to ask them in the comments. I'll do my best to help.

Related

Neural Network In Java Failing to Back Propagate

I have written code for a neural network, but when I train my network it does not produce the desired output (the network is not learning, and sometimes I get NaN values when training). What is wrong with my backpropagation algorithm? Attached below is how I derived the formulas for the weight and bias gradients, respectively. Full code can be found here.
public double[][] predict(double[][] input) {
if(input.length != this.activations.get(0).length || input[0].length != this.activations.get(0)[0].length) {
throw new IllegalArgumentException("Prediction Error!");
}
this.activations.set(0, input);
for(int i = 1; i < this.activations.size(); i++) {
this.activations.set(i, this.sigmoid(this.add(this.multiply(this.weights.get(i-1), this.activations.get(i-1)), this.biases.get(i-1))));
}
return this.activations.get(this.n-1);
}
public void train(double[][] input, double[][] target) {
//calculate activations
this.predict(input);
//calculate weight gradients
for(int l = 0; l < this.weightGradients.size(); l++) {
for(int i = 0; i < this.weightGradients.get(l).length; i++) {
for(int j = 0; j < this.weightGradients.get(l)[0].length; j++) {
this.weightGradients.get(l)[i][j] = this.gradientOfWeight(l, i, j, target);
}
}
}
//calculated bias gradients
for(int l = 0; l < this.biasGradients.size(); l++) {
for(int i = 0; i < this.biasGradients.get(l).length; i++) {
for(int j = 0; j < this.biasGradients.get(l)[0].length; j++) {
this.biasGradients.get(l)[i][j] = this.gradientOfBias(l, i, j, target);
}
}
}
//apply gradient
for(int i = 0; i < this.weights.size(); i++) {
this.weights.set(i, this.subtract(this.weights.get(i), this.weightGradients.get(i)));
}
for(int i = 0; i < this.biases.size(); i++) {
this.biases.set(i, this.subtract(this.biases.get(i), this.biasGradients.get(i)));
}
}
private double gradientOfWeight(int l, int i, int j, double[][] t) { //when referring to A, use l+1 because A[0] is input vector, n-1 because n starts at 1
double z = (this.activations.get(l + 1)[i][0] * (1.0 - this.activations.get(l + 1)[i][0]) * this.activations.get(l)[j][0]);
if((l + 1) < (this.n - 1)) {
double sum = 0.0;
for(int k = 0; k < this.weights.get(l + 1).length; k++) {
sum += this.gradientOfWeight(l + 1, k, i, t)*this.weights.get(l + 1)[k][i];
}
return ((z * sum) / this.activations.get(l + 1)[i][0]);
} else if((l + 1) == (this.n - 1)) {
return 2.0 * (this.activations.get(l + 1)[i][0] - t[i][0]) * z;
}
throw new IllegalArgumentException("Weight Gradient Calculation Error!");
}
The amount of math that's involved in this question combined with the lack of data/reproduction of your code makes it nearly impossible to answer the original question of "where is my NaN".
Instead, I would propose you reconsider this question to be a simpler one, "How can I tell where a value like NaN is coming from in my code".
If you can run your code in an IDE, most of them will support conditional breakpoints. i.e. breakpoints that will pause your code whenever a variable reaches a value. In your case, I would recommend running your code in your preferred IDE with a conditional breakpoint detecting a value is NaN.
You can read more about how to set one up in this SO thread, which also covers how to check a double for NaN:
Eclipse Debugger doesn't stop at conditional breakpoint
Another follow-up consideration is to think WHERE you need to put these breakpoints. The short answer is to put them wherever a double is computed, because any of these computations might introduce the NaN.
To that effect, I make the following two recommendations:
First, put a breakpoint where you currently compute doubles to see if NaN's come from these computations. That would be these two variables:
double z = ...
double sum = ...
Second, refactor your calls to gradientOfWeight to return into a temporary variable, and then put a similar breakpoint on THOSE interim computations.
So instead of
this.weightGradients.get(l)[i][j] = this.gradientOfWeight(l, i, j, target);
You would have:
double interimComputationToListenForNaNon = this.gradientOfWeight(l, i, j, target);
this.weightGradients.get(l)[i][j] = interimComputationToListenForNaNon;
Having these interim variables is more of a convenience to give you an easy way to monitor the computation without changing the call in any significant way. There may be a smarter way to do that without requiring an interim variable, but this one seems the easiest to monitor and explain.
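If a debugger is not handy, the same idea can be sketched in code. The helper below is hypothetical (NaNGuard and check are names made up for this illustration, not part of the asker's network): it wraps a computed double and fails fast when the value is NaN or infinite. Note that comparing with == against Double.NaN never matches, so Double.isNaN is used, and that 0.0 / 0.0 is one common way a NaN appears.

```java
final class NaNGuard {
    // Hypothetical helper (an assumption for illustration): fail fast on NaN or infinity.
    static double check(String label, double value) {
        if (Double.isNaN(value) || Double.isInfinite(value)) {
            throw new ArithmeticException(label + " is not finite: " + value);
        }
        return value;
    }

    public static void main(String[] args) {
        // A normal value passes through unchanged.
        System.out.println(check("z", 0.25 * 0.75));

        // Dividing zero by zero in floating point yields NaN rather than an exception.
        double zero = 0.0;
        double ratio = zero / zero;
        System.out.println(Double.isNaN(ratio));

        // The guard turns the silent NaN into an exception with a label.
        try {
            check("ratio", ratio);
        } catch (ArithmeticException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

Wrapping each interim result, for example check("gradientOfWeight", ...), would point at the first computation that goes wrong instead of the last one that shows it.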
The NaN you see may be due to underflow; in that case, consider using the BigDecimal class instead of double for higher precision. For background, see a BigDecimal sample-use tutorial and the BigDecimal API reference.

how can i take the derivative of the softmax output in back-prop

So I am new to ML and trying to make a simple "library" so I can learn more about neural networks.
My question:
According to my understanding I have to take the derivative of each layer according to their activation function so I can calculate their deltas and adjust their weights etc...
For ReLU, sigmoid, tanh, it's super simple to implement them in Java (which is the language I am using BTW)
But to go from output to the input I have to start from (obviously) the output which has an activation function of softmax.
So do I have to take the derivative of the output layer as well or does it just apply to every other layer?
If I do have to get the derivative, how can I implement it in Java?
Thanks.
I have read a lot of pages explaining the derivative of the softmax function, but they were really complicated for me. As I said, I have just started to learn ML and I didn't want to use a library off the shelf, so here I am.
This is the class I store my activation functions.
public class ActivationFunction {
public static double tanh(double val) {
return Math.tanh(val);
}
public static double sigmoid(double val) {
return 1 / (1 + Math.exp(-val)); // parentheses are required: 1 / 1 + e^-x evaluates as 1 + e^-x
}
public static double relu(double val) {
return Math.max(val, 0);
}
public static double leaky_relu(double val) {
double result = 0;
if (val > 0) result = val;
else result = val * 0.01;
return result;
}
public static double[] softmax(double[] array) {
double max = max(array);
for (int i = 0; i < array.length; i++) {
array[i] = array[i] - max;
}
double sum = 0;
double[] result = new double[array.length];
for (int i = 0; i < array.length; i++) {
sum += Math.exp(array[i]);
}
for (int i = 0; i < result.length; i++) {
result[i] = Math.exp(array[i]) / sum;
}
return result;
}
public static double dTanh(double x) {
double tan = Math.tanh(x);
return 1 - tan * tan; // d/dx tanh(x) = 1 - tanh(x)^2
}
public static double dSigmoid(double x) {
return x * (1 - x);
}
public static double dRelu(double x) {
double result;
if (x > 0) result = 1;
else result = 0;
return result;
}
public static double dLeaky_Relu(double x) {
double result;
if (x > 0) result = 1;
else if (x < 0) result = 0.01;
else result = 0;
return result;
}
private static double max(double[] array) {
double result = Double.NEGATIVE_INFINITY; // Double.MIN_VALUE is the smallest positive double, not the most negative one
for (int i = 0; i < array.length; i++) {
if (array[i] > result) result = array[i];
}
return result;
}
}
I am expecting to get the answer for the question: Do I need the derivative of softmax or not?
If so how can I implement it?
A short answer to your first question is yes, you need to compute the derivative of softmax.
The longer version involves some computation. To implement backpropagation you train your network with a first-order optimization algorithm, which requires the partial derivatives of the cost function w.r.t. the weights, i.e.:
∂C/∂wij
However, since you are using softmax for your last layer, it is very likely that you are going to optimize a cross-entropy cost function while training your neural network, namely:
C = −Σj tj·log(aj)
where tj is a target value and aj is a softmax result for class j.
Softmax itself represents a probability distribution over n classes:
aj = exp(zj) / Σk exp(zk)
where all of the z's are simple sums of the outputs of the previous layer's activation function times the corresponding weights:
zj = Σi w(n)ij · a(n−1)i
where n is the index of the layer, i is the index of a neuron in the previous layer, and j is the index of a neuron in our softmax layer.
So in order to take partial derivatives with respect to any of these weights, one should calculate:
∂C/∂wij = Σk (∂C/∂ak)·(∂ak/∂zj)·(∂zj/∂wij)
where the second partial derivative ∂ak/∂zj is indeed the softmax derivative and can be computed in the following way:
∂ak/∂zj = ak·(δkj − aj), with δkj = 1 if k = j and 0 otherwise
But if you try to compute the aforementioned sum term of the derivative of the cost function w.r.t. the weights, you will get:
Σk (∂C/∂ak)·(∂ak/∂zj) = Σk (−tk/ak)·ak·(δkj − aj) = −tj + aj·Σk tk = aj − tj (using Σk tk = 1)
So in this particular case the final result of the computation turns out to be quite neat: it is a simple difference between the outputs of the network and the target values. All you need to compute this sum term of partial derivatives is just:
∂C/∂zj = aj − tj
So, to answer your second question: you can combine the partial derivative of the cross-entropy cost function w.r.t. the output activation (i.e. softmax) with the partial derivative of the output activation w.r.t. zj, which results in a short and clear implementation. In non-vectorized form it will look like this:
for (int i = 0; i < lenOfClasses; ++i)
{
dCdz[i] = a[i] - t[i]; // network output minus target
}
And subsequently you can use dCdz for backpropagating to the rest of the layers of the neural network.
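As a sanity check, here is a minimal sketch that compares the closed-form gradient with a numerical one. It assumes the standard definitions (softmax output a, cross-entropy C = −Σ t·log a, one-hot or otherwise normalized targets t); under those assumptions the gradient of C with respect to z is a − t, and its negative, t − a, is the descent direction. The class and method names are made up for this illustration.

```java
final class SoftmaxGradient {
    // Numerically stable softmax that does not modify its input.
    static double[] softmax(double[] z) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : z) max = Math.max(max, v);
        double sum = 0.0;
        double[] a = new double[z.length];
        for (int i = 0; i < z.length; i++) {
            a[i] = Math.exp(z[i] - max);
            sum += a[i];
        }
        for (int i = 0; i < z.length; i++) a[i] /= sum;
        return a;
    }

    // Cross-entropy cost C = -sum_j t_j * log(a_j) for pre-activations z.
    static double crossEntropy(double[] z, double[] t) {
        double[] a = softmax(z);
        double c = 0.0;
        for (int i = 0; i < z.length; i++) c -= t[i] * Math.log(a[i]);
        return c;
    }

    // Closed-form gradient dC/dz_j = a_j - t_j (valid when the targets sum to 1).
    static double[] analyticGradient(double[] z, double[] t) {
        double[] a = softmax(z);
        double[] g = new double[z.length];
        for (int i = 0; i < z.length; i++) g[i] = a[i] - t[i];
        return g;
    }

    // Central finite differences, used only to check the closed form.
    static double[] numericGradient(double[] z, double[] t, double h) {
        double[] g = new double[z.length];
        for (int i = 0; i < z.length; i++) {
            double[] plus = z.clone();
            double[] minus = z.clone();
            plus[i] += h;
            minus[i] -= h;
            g[i] = (crossEntropy(plus, t) - crossEntropy(minus, t)) / (2 * h);
        }
        return g;
    }

    public static void main(String[] args) {
        double[] z = {1.0, 2.0, 0.5};
        double[] t = {0.0, 1.0, 0.0};
        double[] exact = analyticGradient(z, t);
        double[] approx = numericGradient(z, t, 1e-5);
        for (int i = 0; i < z.length; i++) {
            System.out.printf("dC/dz[%d]: analytic=%.6f numeric=%.6f%n", i, exact[i], approx[i]);
        }
    }
}
```

The two columns should agree to several decimal places, which is a cheap way to catch a sign error before wiring the gradient into the rest of backpropagation.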

how to improve this code?

I have written code that expresses a number as a sum of squares (terms raised to the power of 2); the code is attached below.
The problem is that the output should be of minimum length.
I am getting output such as 3^2+1^2+1^2+1^2, which is not of minimum length.
I need to output in this format:
package com.algo;
import java.util.Scanner;
public class GetInputFromUser {
public static void main(String[] args) {
// TODO Auto-generated method stub
int n;
Scanner in = new Scanner(System.in);
System.out.println("Enter an integer");
n = in.nextInt();
System.out.println("The result is:");
algofunction(n);
}
public static int algofunction(int n1)
{
int r1 = 0;
int r2 = 0;
int r3 = 0;
//System.out.println("n1: "+n1);
r1 = (int) Math.sqrt(n1);
r2 = (int) Math.pow(r1, 2);
// System.out.println("r1: "+r1);
//System.out.println("r2: "+r2);
System.out.print(r1+"^2");
r3 = n1-r2;
//System.out.println("r3: "+r3);
if (r3 == 0)
return 1;
if(r3 == 1)
{
System.out.print("+1^2");
return 1;
}
else {
System.out.print("+");
algofunction(r3);
return 1;
}
}
}
Dynamic programming is all about defining the problem in such a way that if you knew the answer to a smaller version of the original, you could use that to answer the main problem more quickly/directly. It's like applied mathematical induction.
In your particular problem, we can define MinLen(n) as the minimum length representation of n. Next, say, since we want to solve MinLen(12), suppose we already knew the answer to MinLen(1), MinLen(2), MinLen(3), ..., MinLen(11). How could we use the answer to those smaller problems to figure out MinLen(12)? This is the other half of dynamic programming - figuring out how to use the smaller problems to solve the bigger one. It doesn't help you if you come up with some smaller problem, but have no way of combining them back together.
For this problem, we can make the simple statement, "For 12, its minimum length representation DEFINITELY has either 1^2, 2^2, or 3^2 in it." And in general, the minimum length representation of n will have some square less than or equal to n as a part of it. There is probably a better statement you can make, which would improve the runtime, but I'll say that it is good enough for now.
This statement means that MinLen(12) = 1^2 + MinLen(11), OR 2^2 + MinLen(8), OR 3^2 + MinLen(3). You check all of them and select the best one, and now you save that as MinLen(12). Now, if you want to solve MinLen(13), you can do that too.
Advice when solo:
The way I would test this kind of program myself is to plug in 1, 2, 3, 4, 5, etc, and see the first time it goes wrong. Additionally, any assumptions I happen to have thought were a good idea, I question: "Is it really true that the largest square number less than n will be in the representation of MinLen(n)?"
Your code:
r1 = (int) Math.sqrt(n1);
r2 = (int) Math.pow(r1, 2);
embodies that assumption (a greedy assumption), but it is wrong, as you've clearly seen with the answer for MinLen(12).
Instead you want something more like this:
public ArrayList<Integer> minLen(int n)
{
// base case of recursion
if (n == 0)
return new ArrayList<Integer>();
ArrayList<Integer> best = null;
int bestInt = -1;
for (int i = 1; i*i <= n; ++i)
{
// Check what happens if we use i^2 as part of our representation
ArrayList<Integer> guess = minLen(n - i*i);
// If we haven't selected a 'best' yet (best == null)
// or if our new guess is better than the current choice (guess.size() < best.size())
// update our choice of best
if (best == null || guess.size() < best.size())
{
best = guess;
bestInt = i;
}
}
best.add(bestInt);
return best;
}
Then, once you have your list, you can sort it (no guarantees that it came in sorted order), and print it out the way you want.
Lastly, you may notice that for larger values of n (1000 may be too large) that you plug in to the above recursion, it will start going very slow. This is because we are constantly recalculating all the small subproblems - for example, we figure out MinLen(3) when we call MinLen(4), because 4 - 1^2 = 3. But we figure it out twice for MinLen(7) -> 3 = 7 - 2^2, but 3 also is 7 - 1^2 - 1^2 - 1^2 - 1^2. And it gets much worse the larger you go.
The solution to this, which lets you solve up to n = 1,000,000 or more, very quickly, is to use a technique called Memoization. This means that once we figure out MinLen(3), we save it somewhere, let's say a global location to make it easy. Then, whenever we would try to recalculate it, we check the global cache first to see if we already did it. If so, then we just use that, instead of redoing all the work.
import java.util.*;
class SquareRepresentation
{
private static HashMap<Integer, ArrayList<Integer>> cachedSolutions;
public static void main(String[] args)
{
cachedSolutions = new HashMap<Integer, ArrayList<Integer>>();
for (int j = 100000; j < 100001; ++j)
{
ArrayList<Integer> answer = minLen(j);
Collections.sort(answer);
Collections.reverse(answer);
for (int i = 0; i < answer.size(); ++i)
{
if (i != 0)
System.out.printf("+");
System.out.printf("%d^2", answer.get(i));
}
System.out.println();
}
}
public static ArrayList<Integer> minLen(int n)
{
// base case of recursion
if (n == 0)
return new ArrayList<Integer>();
// new base case: problem already solved once before
if (cachedSolutions.containsKey(n))
{
// It is a bit tricky though, because we need to be careful!
// See how below that we are modifying the 'guess' array we get in?
// That means we would modify our previous solutions! No good!
// So here we need to return a copy
ArrayList<Integer> ans = cachedSolutions.get(n);
ArrayList<Integer> copy = new ArrayList<Integer>();
for (int i: ans) copy.add(i);
return copy;
}
ArrayList<Integer> best = null;
int bestInt = -1;
// THIS IS WRONG, can you figure out why it doesn't work?:
// for (int i = 1; i*i <= n; ++i)
for (int i = (int)Math.sqrt(n); i >= 1; --i)
{
// Check what happens if we use i^2 as part of our representation
ArrayList<Integer> guess = minLen(n - i*i);
// If we haven't selected a 'best' yet (best == null)
// or if our new guess is better than the current choice (guess.size() < best.size())
// update our choice of best
if (best == null || guess.size() < best.size())
{
best = guess;
bestInt = i;
}
}
best.add(bestInt);
// check... not needed unless you coded wrong
int sum = 0;
for (int i = 0; i < best.size(); ++i)
{
sum += best.get(i) * best.get(i);
}
if (sum != n)
{
throw new RuntimeException(String.format("n = %d, sum=%d, arr=%s\n", n, sum, best));
}
// New step: Save the solution to the global cache
cachedSolutions.put(n, best);
// Same deal as before... if you don't return a copy, you end up modifying your previous solutions
//
ArrayList<Integer> copy = new ArrayList<Integer>();
for (int i: best) copy.add(i);
return copy;
}
}
It took my program around 5 s to run for n = 100,000. Clearly there is more to be done if we want it to be faster, and to solve for larger n. The main issue now is that in storing the entire list of results of previous answers, we use up a lot of memory. And all of that copying! There is more you can do, like storing only an integer and a pointer to the subproblem, but I'll let you do that.
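For reference, here is one possible sketch of that last idea, written bottom-up instead of recursively. It is an assumption about what "an integer and a pointer to the subproblem" could look like, not the answerer's own code: count[v] holds the minimum number of squares for v, and last[v] holds one root i such that i^2 belongs to a best representation, so v - i*i acts as the pointer to the subproblem. The names MinSquares, count and last are made up for this illustration.

```java
import java.util.ArrayList;
import java.util.List;

final class MinSquares {
    // Returns the roots of a minimum-length sum of squares for n (n >= 0).
    static List<Integer> minLen(int n) {
        int[] count = new int[n + 1]; // count[v] = fewest squares summing to v
        int[] last = new int[n + 1];  // last[v]  = root of one square used for v
        for (int v = 1; v <= n; v++) {
            count[v] = Integer.MAX_VALUE;
            for (int i = 1; i * i <= v; i++) {
                // Using i^2 leaves the already-solved subproblem v - i^2.
                if (count[v - i * i] + 1 < count[v]) {
                    count[v] = count[v - i * i] + 1;
                    last[v] = i;
                }
            }
        }
        // Walk the chain of subproblems to rebuild one optimal representation.
        List<Integer> roots = new ArrayList<>();
        for (int v = n; v > 0; v -= last[v] * last[v]) {
            roots.add(last[v]);
        }
        return roots;
    }

    public static void main(String[] args) {
        StringBuilder out = new StringBuilder();
        for (int root : minLen(12)) {
            if (out.length() > 0) out.append('+');
            out.append(root).append("^2");
        }
        System.out.println(out); // 2^2+2^2+2^2
    }
}
```

This uses two int arrays of size n + 1 and no list copying, so memory grows linearly with n and the work is roughly n times sqrt(n) inner steps.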
And by the by, 1000 = 30^2 + 10^2.

Memoization of a Recursive Search

I am trying to solve a problem in which you have to count the number of possible bar codes you can make given specific parameters. I solved the problem recursively and am able to get the correct answer every time. However, my program is dreadfully slow. I tried to rectify this using a technique I read about called memoization but my program still crawls when given certain input (ex: 10, 10, 10). Here's the code in java.
Does anybody have any idea what I'm doing wrong here?
import java.util.Scanner;
//f(n, k, m) = sum (1 .. m) f(n - i, k - 1, m)
public class BarCode {
public static int[][] memo;
public static int count(int units, int bars, int width) {
int sum = 0;
if (units >= 0 && memo[units][bars] != -1) //if the value has already been calculated return that value
return memo[units][bars];
for (int i = 1; i <= width; ++i) {
if (units == 0 && bars == 0)
return 1;
else if (bars == 0)
return 0;
else {
sum += count(units - i, bars - 1, width);
}
}
if (units > -1)
memo[units][bars] = sum;
return sum;
}
public static void main(String[] args) {
Scanner in = new Scanner(System.in);
//while (in.hasNext()) {
int num = in.nextInt();
int bars = in.nextInt();
int width = in.nextInt();
memo = new int[51][51];
for (int i = 0; i < memo.length; ++i) {
for (int j = 0; j < memo.length; ++j)
memo[i][j] = -1;
}
int sum = 0;
sum += count(num, bars, width);
System.out.println(sum);
//}
in.close();
}
}
TL;DR: My memoization of a recursive search is too slow. Help!
You exclude all results from count calls with units < 0 from memoization:
if (units > -1)
memo[units][bars] = sum;
This leads to a lot of unnecessary calls to count for these values.
To include all cases, you could use a HashMap with a key generated from units and bars values. I used a string generated from units and bars like this:
import java.util.HashMap;
import java.util.Map;
import java.util.Scanner;
//f(n, k, m) = sum (1 .. m) f(n - i, k - 1, m)
public class BarCode {
public static Map<String, Integer> memo = new HashMap<>();
public static int count(int units, int bars, int width) {
int sum = 0;
final String key = units + " " + bars;
Integer memoSum = memo.get(key);
if (memoSum != null) {
return memoSum.intValue();
}
for (int i = 1; i <= width; ++i) {
if (units == 0 && bars == 0)
return 1;
else if (bars == 0)
return 0;
else {
sum += count(units - i, bars - 1, width);
}
}
memo.put(key, Integer.valueOf(sum));
return sum;
}
public static void main(String[] args) {
Scanner in = new Scanner(System.in);
int num = in.nextInt();
int bars = in.nextInt();
int width = in.nextInt();
memo = new HashMap<>();
int sum = 0;
sum += count(num, bars, width);
System.out.println(sum);
in.close();
}
}
For example, this brings the number of calls to count down from over 6 million to 4,150 for the input values "10 10 10" with 415 entries saved in the Map.
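An alternative sketch that keeps the array-based memo: since every bar is at least one unit wide, any call with units < 0 can only ever return 0, so it can be cut off immediately instead of being cached. This is a variation under that assumption (and assuming width >= 1 and non-negative inputs), not the code measured above; the class name BarCodeCount and the use of long for the counts are choices made for this illustration.

```java
import java.util.Arrays;

final class BarCodeCount {
    private static long[][] memo;

    // Number of ways to split `units` into exactly `bars` bars, each 1..width units wide.
    static long count(int units, int bars, int width) {
        memo = new long[units + 1][bars + 1];
        for (long[] row : memo) Arrays.fill(row, -1L); // -1 marks "not computed yet"
        return solve(units, bars, width);
    }

    private static long solve(int units, int bars, int width) {
        if (units < 0) return 0;                   // overshot: no valid bar code
        if (bars == 0) return units == 0 ? 1 : 0;  // all bars placed
        if (memo[units][bars] != -1) return memo[units][bars];
        long sum = 0;
        for (int i = 1; i <= width; i++) {
            sum += solve(units - i, bars - 1, width);
        }
        memo[units][bars] = sum;
        return sum;
    }

    public static void main(String[] args) {
        System.out.println(count(7, 4, 3));
        System.out.println(count(10, 10, 10));
    }
}
```

With the early return, each (units, bars) pair is computed once, so the work is bounded by roughly units * bars * width additions.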
Your memoization implementation looks to be valid. It might help some, but the real problem here is your choice of algorithm.
From my cursory inspection of your code, on average a call to your count method will loop width times, and each time it loops it goes a layer deeper by calling count again. It also looks like it's going to loop down bars layers deeper from the first layer. If my asymptotic analysis a few fingers of scotch in is correct, this would result in an algorithm with O(width^bars) runtime complexity. As you increase your input parameters, especially bars, the number of steps your application needs to take in order to calculate your answer will increase greatly (exponentially, in the case of bars).
Your memoization will reduce the number of duplicate calculations needed, but each value being memoized will still need to be calculated at least once for the memoization to help. So with or without the memoization, you're still dealing with a non-polynomial time complexity, and that always spells bad performance.
You might want to consider looking for a more efficient approach. Instead of trying to count the number of bar code combinations, perhaps try using combinatorics to try to calculate it. For example, I could try to figure out the number of lowercase character strings (using only chars a-z) I can make for a string of length n by generating all of them and counting how many of them there are, but that will have an exponential time complexity and will not be performant. On the other hand, I know basic combinatorics tells me that the formula for the number of strings I can create is 26^n (26 choices in each position, and n positions), which the computer can easily evaluate quickly.
Look for a similar approach for computing the number of bar codes.

Improving a prime sieve algorithm

I'm trying to make a decent Java program that generates the primes from 1 to N (mainly for Project Euler problems).
At the moment, my algorithm is as follows:
Initialise an array of booleans (or a bitarray if N is sufficiently large) so they're all false, and an array of ints to store the primes found.
Set an integer, s equal to the lowest prime, (ie 2)
While s is <= sqrt(N)
Set all multiples of s (starting at s^2) to true in the array/bitarray.
Find the next smallest index in the array/bitarray which is false, use that as the new value of s.
Endwhile.
Go through the array/bitarray, and for every value that is false, put the corresponding index in the primes array.
Now, I've tried skipping over numbers not of the form 6k + 1 or 6k + 5, but that only gives me a ~2x speed-up, whilst I've seen programs run orders of magnitude faster than mine (albeit with very convoluted code), such as the one here.
What can I do to improve?
Edit: Okay, here's my actual code (for N of 1E7):
int l = 10000000, n = 2, sqrt = (int) Math.sqrt(l);
boolean[] nums = new boolean[l + 1];
int[] primes = new int[664579];
while(n <= sqrt){
for(int i = 2 * n; i <= l; nums[i] = true, i += n);
for(n++; nums[n]; n++);
}
for(int i = 2, k = 0; i < nums.length; i++) if(!nums[i]) primes[k++] = i;
Runs in about 350ms on my 2.0GHz machine.
While s is <= sqrt(N)
One mistake people often make in such algorithms is not precomputing the square root.
while (s <= Math.sqrt(N)) {
is much, much slower than
int limit = (int) Math.sqrt(N);
while (s <= limit) {
But generally speaking, Eiko is right in his comment. If you want people to offer low-level optimisations, you have to provide code.
update Ok, now about your code.
You may notice that the number of iterations in your code is just a little bigger than 'l'. (You can put a counter inside the first 'for' loop; it will be just 2-3 times bigger.) And, obviously, the complexity of your solution can't be less than O(l) (you can't have fewer than 'l' iterations).
What can make a real difference is accessing memory effectively. Note that the guy who wrote that article tries to reduce storage size not just because he's memory-greedy. Making arrays compact lets you use the cache better and thus increase speed.
I just replaced boolean[] with a bit-packed int[] and achieved an immediate 2x speed gain (and 8x less memory). And I didn't even try to do it efficiently.
update2
That's easy. You just replace every assignment a[i] = true with a[i/32] |= 1 << (i%32) and each read operation a[i] with (a[i/32] & (1 << (i%32))) != 0. And boolean[] a with int[] a, obviously.
From the first replacement it should be clear how it works: if f(i) is true, then there's a bit 1 in an integer number a[i/32], at position i%32 (int in Java has exactly 32 bits, as you know).
You can go further and replace i/32 with i >> 5, i%32 with i&31. You can also precompute all 1 << j for each j between 0 and 31 in array.
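Put together, the replacements above could look like the following sketch. It is only an illustration of the bit-packing idea, not the exact code that was timed; the class and method names are made up, and it assumes limit stays comfortably below Integer.MAX_VALUE so that the loop counters do not overflow.

```java
final class BitPackedSieve {
    // Counts the primes in [2, limit] using one bit per number in an int[].
    static int countPrimes(int limit) {
        int[] composite = new int[(limit >> 5) + 1]; // bit i set means i is composite
        for (int n = 2; (long) n * n <= limit; n++) {
            // Read operation: (a[i/32] & (1 << (i%32))) != 0, written with shifts and masks.
            if ((composite[n >> 5] & (1 << (n & 31))) != 0) continue;
            for (int i = n * n; i <= limit; i += n) {
                // Write operation: a[i/32] |= 1 << (i%32).
                composite[i >> 5] |= 1 << (i & 31);
            }
        }
        int count = 0;
        for (int i = 2; i <= limit; i++) {
            if ((composite[i >> 5] & (1 << (i & 31))) == 0) count++;
        }
        return count;
    }

    public static void main(String[] args) {
        System.out.println(countPrimes(100));
        System.out.println(countPrimes(10_000_000));
    }
}
```

For N = 1E7 the flag array shrinks from about 10 MB of booleans to about 1.25 MB of ints, which is what lets more of it stay in cache.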
But sadly, I don't think you could get close to C with this in Java. Not to mention that the guy uses many other tricky optimizations, and I agree that his code would've been worth a lot more if he had commented it.
Using the BitSet will use less memory. The Sieve algorithm is rather trivial, so you can simply "set" the bit positions on the BitSet, and then iterate to determine the primes.
Did you also make the array smaller while skipping numbers not of the form 6k+1 and 6k+5?
I only tested with ignoring numbers of the form 2k and that gave me ~4x speed up (440 ms -> 120 ms):
int l = 10000000, n = 1, sqrt = (int) Math.sqrt(l);
int m = l/2;
boolean[] nums = new boolean[m + 1];
int[] primes = new int[664579];
int i, k;
while (n <= sqrt) {
int x = (n<<1)+1;
for (i = n+x; i <= m; nums[i] = true, i+=x);
for (n++; nums[n]; n++);
}
primes[0] = 2;
for (i = 1, k = 1; i < nums.length; i++) {
if (!nums[i])
primes[k++] = (i<<1)+1;
}
The following is from my Project Euler library. It's a slight variation of the Sieve of Eratosthenes. I'm not sure, but I think it's called the Euler Sieve.
1) It uses a BitSet (so 1/8th the memory)
2) It only uses the BitSet for odd numbers (another 1/2, hence 1/16th)
Note: The inner loop (for multiples) begins at "n*n" rather than "2*n", and only multiples at increments of "2*n" are crossed off, hence the speed-up.
private void beginSieve(int mLimit)
{
primeList = new BitSet(mLimit>>1);
primeList.set(0,primeList.size(),true);
int sqroot = (int) Math.sqrt(mLimit);
primeList.clear(0);
for(int num = 3; num <= sqroot; num+=2)
{
if( primeList.get(num >> 1) )
{
int inc = num << 1;
for(int factor = num * num; factor < mLimit; factor += inc)
{
//if( ((factor) & 1) == 1)
//{
primeList.clear(factor >> 1);
//}
}
}
}
}
and here's the function to check if a number is prime...
public boolean isPrime(int num)
{
if( num < maxLimit)
{
if( (num & 1) == 0)
return ( num == 2);
else
return primeList.get(num>>1);
}
return false;
}
You could do the step of "putting the corresponding index in the primes array" while you are detecting them, taking out a run through the array, but that's about all I can think of right now.
I wrote a simple sieve implementation recently for the fun of it using BitSet (everyone says not to, but it's the best off the shelf way to store huge data efficiently). The performance seems to be pretty good to me, but I'm still working on improving it.
import java.util.BitSet;
public class HelloWorld {
private static int LIMIT = 2140000000;//Integer.MAX_VALUE broke things.
private static BitSet marked;
public static void main(String[] args) {
long startTime = System.nanoTime();
init();
sieve();
long estimatedTime = System.nanoTime() - startTime;
System.out.println((float)estimatedTime/1000000000); //23.835363 seconds
System.out.println(marked.size()); //1070000000 ~= 127MB
}
private static void init()
{
double size = LIMIT * 0.5 - 1;
marked = new BitSet();
marked.set(0,(int)size, true);
}
private static void sieve()
{
int i = 0;
int cur = 0;
int add = 0;
int pos = 0;
while(((i<<1)+1)*((i<<1)+1) < LIMIT)
{
pos = i;
if(marked.get(pos++))
{
cur = pos;
add = (cur<<1);
pos += add*cur + cur - 1;
while(pos < marked.length() && pos > 0)
{
marked.clear(pos++);
pos += add;
}
}
i++;
}
}
private static void readPrimes()
{
int pos = 0;
while(pos < marked.length())
{
if(marked.get(pos++))
{
System.out.print((pos<<1)+1);
System.out.print("-");
}
}
}
}
With smaller LIMITs (say 10,000,000 which took 0.077479s) we get much faster results than the OP.
I bet java's performance is terrible when dealing with bits...
Algorithmically, the link you point out should be sufficient
Have you tried googling, e.g. for "java prime numbers". I did and dug up this simple improvement:
http://www.anyexample.com/programming/java/java_prime_number_check_%28primality_test%29.xml
Surely, you can find more at google.
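Here is my code for the Sieve of Eratosthenes, and this is actually the most efficient that I could do:
public class Sieve {
static final int MAX = 1000000;
static int[] p = new int[MAX]; // p[i] != 0 means i is composite
static int[] prime = new int[MAX / 10]; // room for all primes below MAX
static void sieve()
{
int i, j, k = 1;
p[0] = p[1] = 1;
prime[0] = 2;
for (i = 3; i * i <= MAX; i += 2)
{
if (p[i] != 0) // Java has no implicit int-to-boolean conversion
continue;
for (j = i * i; j < MAX; j += 2 * i) // odd multiples only, starting at i^2
p[j] = 1;
}
for (i = 3; i < MAX; i += 2)
{
if (p[i] == 0)
prime[k++] = i;
}
}
}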
