Memoization of a Recursive Search - java

I am trying to solve a problem in which you have to count the number of possible bar codes you can make given specific parameters. I solved the problem recursively and get the correct answer every time. However, my program is dreadfully slow. I tried to fix this using a technique I read about called memoization, but my program still crawls on certain input (e.g. 10, 10, 10). Here's the code in Java.
Does anybody have any idea what I'm doing wrong here?
import java.util.Scanner;
//f(n, k, m) = sum (1 .. m) f(n - i, k - 1, m)
public class BarCode {
public static int[][] memo;
public static int count(int units, int bars, int width) {
int sum = 0;
if (units >= 0 && memo[units][bars] != -1) //if the value has already been calculated return that value
return memo[units][bars];
for (int i = 1; i <= width; ++i) {
if (units == 0 && bars == 0)
return 1;
else if (bars == 0)
return 0;
else {
sum += count(units - i, bars - 1, width);
}
}
if (units > -1)
memo[units][bars] = sum;
return sum;
}
public static void main(String[] args) {
Scanner in = new Scanner(System.in);
//while (in.hasNext()) {
int num = in.nextInt();
int bars = in.nextInt();
int width = in.nextInt();
memo = new int[51][51];
for (int i = 0; i < memo.length; ++i) {
for (int j = 0; j < memo.length; ++j)
memo[i][j] = -1;
}
int sum = 0;
sum += count(num, bars, width);
System.out.println(sum);
//}
in.close();
}
}
TL;DR: My memoization of a recursive search is too slow. Help!

You exclude all results from count calls with units < 0 from memoization:
if (units > -1)
memo[units][bars] = sum;
This leads to a lot of unnecessary calls to count for these values.
To include all cases, you could use a HashMap with a key generated from units and bars values. I used a string generated from units and bars like this:
import java.util.HashMap;
import java.util.Map;
import java.util.Scanner;
//f(n, k, m) = sum (1 .. m) f(n - i, k - 1, m)
public class BarCode {
public static Map<String, Integer> memo = new HashMap<>();
public static int count(int units, int bars, int width) {
int sum = 0;
final String key = units + " " + bars;
Integer memoSum = memo.get(key);
if (memoSum != null) {
return memoSum.intValue();
}
for (int i = 1; i <= width; ++i) {
if (units == 0 && bars == 0)
return 1;
else if (bars == 0)
return 0;
else {
sum += count(units - i, bars - 1, width);
}
}
memo.put(key, Integer.valueOf(sum));
return sum;
}
public static void main(String[] args) {
Scanner in = new Scanner(System.in);
int num = in.nextInt();
int bars = in.nextInt();
int width = in.nextInt();
memo = new HashMap<>();
int sum = 0;
sum += count(num, bars, width);
System.out.println(sum);
in.close();
}
}
For example, this brings the number of calls to count down from over 6 million to 4,150 for the input values "10 10 10" with 415 entries saved in the Map.
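As an alternative to the string-keyed map, here is a minimal sketch that keeps the original array. It assumes that a call with negative units can simply return 0 (no bar code has a negative width), so those calls never recurse and never need a memo slot. The class name BarCodeMemo, the solve helper, and the use of long for the counts are my own choices, not part of the question.

```java
import java.util.Arrays;

class BarCodeMemo {
    // memo[units][bars] caches count(units, bars, width); -1 means "not computed yet".
    private static long[][] memo;

    static long count(int units, int bars, int width) {
        if (units < 0) return 0;                  // overshot the total width (assumed base case)
        if (bars == 0) return units == 0 ? 1 : 0; // no bars left: valid only if no units remain
        if (memo[units][bars] != -1) return memo[units][bars];
        long sum = 0;
        for (int i = 1; i <= width; ++i) {
            sum += count(units - i, bars - 1, width);
        }
        memo[units][bars] = sum;
        return sum;
    }

    // Sizes and resets the memo table, then runs the recursive count.
    static long solve(int units, int bars, int width) {
        memo = new long[units + 1][bars + 1];
        for (long[] row : memo) Arrays.fill(row, -1);
        return count(units, bars, width);
    }

    public static void main(String[] args) {
        System.out.println(solve(7, 4, 3));    // 16
        System.out.println(solve(10, 10, 10)); // 1
    }
}
```

With the base cases checked before the loop, each (units, bars) pair is computed at most once, so the work is bounded by roughly units * bars * width additions.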

Your memoization implementation looks to be valid. It might help some, but the real problem here is your choice of algorithm.
From my cursory inspection of your code, a call to your count method loops width times on average, and on each iteration it goes a layer deeper by calling count again. It also looks like it recurses bars layers deep from the first layer. If my asymptotic analysis (a few fingers of scotch in) is correct, this results in an algorithm with O(width^bars) runtime complexity. As you increase your input parameters, especially bars, the number of steps your application needs to calculate the answer grows greatly (exponentially, in the case of bars).
Your memoization will reduce the number of duplicate calculations needed, but each value being memoized will still need to be calculated at least once for the memoization to help. So with or without the memoization, you're still dealing with a non-polynomial time complexity, and that always spells bad performance.
You might want to consider looking for a more efficient approach. Instead of trying to count the number of bar code combinations, perhaps try using combinatorics to try to calculate it. For example, I could try to figure out the number of lowercase character strings (using only chars a-z) I can make for a string of length n by generating all of them and counting how many of them there are, but that will have an exponential time complexity and will not be performant. On the other hand, I know basic combinatorics tells me that the formula for the number of strings I can create is 26^n (26 choices in each position, and n positions), which the computer can easily evaluate quickly.
Look for a similar approach for computing the number of bar codes.

Related

CoinChange Problem with DP in Java using 2D array

I am implementing an algorithm to solve the Coin Change problem, where given an array of coin types (e.g. int[] coinValues = {1,4,6};) and a value to achieve (e.g. int totalAmount=8;), an array is returned whose value at position 0 is the minimum number of coins needed to achieve totalAmount. The rest of the array keeps track of how many coins of each type are needed to achieve the total sum.
An example input of coins = {1,4,6} and total = 8 should return the array [3,2,0,1]. However, my code is returning [1,2,0,1].
Another example would be coins = {2,4,8,16,34,40,64} and total = 50 should return the array [2, 0, 0, 0, 1, 1, 0, 0]. My code is not returning that result.
The algorithm is implemented with 2 methods: CoinChange and CoinCount. CoinChange creates the coin matrix and CoinCount keeps track of the coins required to achieve the total sum.
package P5;
import java.util.Arrays;
public class CoinChange7 {
public static int[] CoinChange(int[] v, int sum) {
int[][] aux = new int[v.length + 1][sum + 1];
// Initialising first column with 0
for(int i = 1; i <= v.length; i++) {
aux[i][0] = 0;
}
// Implementing the recursive solution
for(int i = 1; i <= v.length-1; i++) {
for(int j = 1; j <= sum; j++) {
if(i == 1) {
if(v[1] > j) {
aux[i][0]=999999999; //instead of Integer.MAX_VALUE
} else {
aux[i][j]=1 + aux[1][j-v[1]];
}
} else {
if(v[i] > j) { //choose best option ,discard this coin or use it.
aux[i][j] = aux[i - 1][j];
} else
aux[i][j] = Math.min(aux[i-1][j],1 + aux[i][j-v[i]]);
}
}
}
int []z=CoinCount(sum,aux,v);
return z; // Return solution to the initial problem
}
public static int[] CoinCount(int A, int[][] aux, int[] d){
int coin = d.length-1;
int limit=A;
int [] typo=new int[d.length+1]; //We create solution array, that will count no of coins
for (int k=0;k<typo.length;k++) {
typo[k]=0;
}
while (coin>0 || limit>0){
if(limit-d[coin]>=0 && coin-1>=0){
if(1+aux[coin][limit-d[coin]]<aux[coin-1][limit]){
typo[coin+1]=typo[coin+1]+1;
limit=limit-d[coin];
} else {
coin=coin-1;
}
} else if(limit-d[coin]>=0) {
typo[coin+1]=typo[coin+1]+1;
limit=limit-d[coin];
} else if(coin-1>=0) {
coin=coin-1;
}
}
typo[0]= aux[d.length-1][A];
return typo; //return the final array with solutions of each coin
}
public static void main(String[] args) {
int[] coins = {1,4,6};
int sum = 8;
int[] x=CoinChange(coins,sum);
System.out.println("At least " + Arrays.toString(x) +" from set "+ Arrays.toString(coins)
+ " coins are required to make a value of " + sum);
}
}
Clarification
I don't know if you still need the answer to this question but I will try to answer it anyway.
First, there are a few things I would like to clarify. The coin change problem does not have a unique solution. If you want both the minimum number of coins used to make the change and the frequency of each coin, the result depends on the approach used to solve the problem and on the arrangement of the coins.
For example: Take the coins to be [4,6,8] and amount = 12. You'll quickly see that the minimum coins required to make this change is 2. Going by your choice of output, the following are all correct: [2,0,2,0] and [2,1,0,1].
By the way, the Coin change problem can be solved in many ways. A simple recursive DP approach in Java is here. It only returns the min coins needed to make the change at O(nlog(n)) time and O(n) space.
Another approach is by using a 2D DP matrix (same with the approach you tried using) at both O(n^2) time and space. Explanation on how to use this approach is here. Please be careful with the explanation because it is not generally correct. I noticed it's almost the same as the one you used.
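For comparison, here is a minimal bottom-up sketch that uses a single array indexed by amount and returns only the minimum number of coins. The class and method names are hypothetical, and it is not the code behind either link above. It runs in time proportional to amount times the number of coin types.

```java
import java.util.Arrays;

class MinCoins {
    // Returns the minimum number of coins that sum to amount, or -1 if no combination exists.
    static int minCoins(int[] coins, int amount) {
        int[] best = new int[amount + 1];
        Arrays.fill(best, Integer.MAX_VALUE); // MAX_VALUE marks "not reachable"
        best[0] = 0;
        for (int j = 1; j <= amount; j++) {
            for (int c : coins) {
                // Use coin c only if the remainder j - c is itself reachable.
                if (c <= j && best[j - c] != Integer.MAX_VALUE) {
                    best[j] = Math.min(best[j], best[j - c] + 1);
                }
            }
        }
        return best[amount] == Integer.MAX_VALUE ? -1 : best[amount];
    }

    public static void main(String[] args) {
        System.out.println(minCoins(new int[] {4, 6, 8}, 12)); // 2
        System.out.println(minCoins(new int[] {4, 6}, 5));     // -1
    }
}
```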
Your solution
I will mention a few things about your solution that may have affected the result.
The number of rows of the DP matrix is v.length not v.length + 1.
Based on your solution, this should not affect the result because I noticed you don't seem comfortable with zero indexes.
I think it is not necessary to initialize the first column of the DP matrix since the data type you used is int, which is 0 by default. Again, this does not affect the answer, though.
The way you filled row 1 (supposed to be the first row, but you ignored row 0) is not good and may affect the result of some solutions.
The only mistake I see there is that there is no uniform value to specify amounts (i.e. j) that cannot be solved using the single coin (i.e. v[0]). Negative numbers could have been better because any positive integer is a potential valid solution for the cell. You could use -1 (if you're going by the Leetcode instruction). This way, you'll easily know cells that contain invalid values while filling the rest of the matrix.
The way you compute aux[i][j] is wrong because you are using the wrong coins: you are using v[i] instead of v[i-1], since aux.length is one bigger than v.length.
I did not look at the CoinCount method. It looks complex for a seemingly simple problem. Please see my solution.
My Solution
public int[] change(int[] coins, int amount){
int[][] DP = new int[coins.length][amount+1];
//fill the first column with 0
//int array contains 0 by default, so this part is not necessary
/*
for (int i = 0; i < coins.length; i++) {
DP[i][0] =0;
}
*/
//fill the first row.
//At 0th row, we are trying to find the min number of coins to change j amount using only
//one coin type, i.e. coins[0] (that is the meaning of DP[0][j]).
for (int j = 1; j <= amount; j++) {
if(coins[0] > j || j % coins[0] != 0){
DP[0][j] = -1;
}else{
DP[0][j] = j /coins[0];
}
}
//iterate the rest of the unfilled DP
for (int i = 1; i < coins.length; i++) {
for (int j = 1; j < amount+1; j++) {
if(coins[i] > j){
DP[i][j] = DP[i-1][j];
}else {
int prev = DP[i-1][j];
int cur = 1+DP[i][j-coins[i]];
if(cur == 0){
DP[i][j] = DP[i-1][j];
} else if(prev == -1) {
DP[i][j] = 1 + DP[i][j - coins[i]];
}else{
DP[i][j] = Math.min(DP[i-1][j],1+DP[i][j-coins[i]]);
}
}
}
}
return countCoin(coins,amount,DP);
}
public int[] countCoin(int[] coins, int amount, int[][] DP){
int[] result = new int[coins.length+1];//The 1 added is to hold result.
int i = coins.length -1;
int j = amount;
//while the rest will contain counter for coins used.
result[0] = DP[i][j];
if(result[0] ==0 || result[0] ==-1)return result;
while (j > 0 ){
if(i-1 >= 0 && DP[i][j] == DP[i-1][j]){
i = i-1;
}else{
j = j - coins[i];
result[i+1] += 1;
}
}
return result;
}

Hashmap memoization slower than directly computing the answer

I've been playing around with the Project Euler challenges to help improve my knowledge of Java. In particular, I wrote the following code for problem 14, which asks you to find the longest Collatz chain which starts at a number below 1,000,000. It works on the assumption that subchains are incredibly likely to arise more than once, and by storing them in a cache, no redundant calculations are done.
Collatz.java:
import java.util.HashMap;
public class Collatz {
private HashMap<Long, Integer> chainCache = new HashMap<Long, Integer>();
public void initialiseCache() {
chainCache.put((long) 1, 1);
}
private long collatzOp(long n) {
if(n % 2 == 0) {
return n/2;
}
else {
return 3*n +1;
}
}
public int collatzChain(long n) {
if(chainCache.containsKey(n)) {
return chainCache.get(n);
}
else {
int count = 1 + collatzChain(collatzOp(n));
chainCache.put(n, count);
return count;
}
}
}
ProjectEuler14.java:
public class ProjectEuler14 {
public static void main(String[] args) {
Collatz col = new Collatz();
col.initialiseCache();
long limit = 1000000;
long temp = 0;
long longestLength = 0;
long index = 1;
for(long i = 1; i < limit; i++) {
temp = col.collatzChain(i);
if(temp > longestLength) {
longestLength = temp;
index = i;
}
}
System.out.println(index + " has the longest chain, with length " + longestLength);
}
}
This works. And according to the "measure-command" command from Windows Powershell, it takes roughly 1708 milliseconds (1.708 seconds) to execute.
However, after reading through the forums, I noticed that some people, who had written seemingly naive code which calculates each chain from scratch, seemed to be getting much better execution times than me. I (conceptually) took one of the answers and translated it into Java:
NaiveProjectEuler14.java:
public class NaiveProjectEuler14 {
public static void main(String[] args) {
int longest = 0;
int numTerms = 0;
int i;
long j;
for (i = 1; i <= 10000000; i++) {
j = i;
int currentTerms = 1;
while (j != 1) {
currentTerms++;
if (currentTerms > numTerms){
numTerms = currentTerms;
longest = i;
}
if (j % 2 == 0){
j = j / 2;
}
else{
j = 3 * j + 1;
}
}
}
System.out.println("Longest: " + longest + " (" + numTerms + ").");
}
}
On my machine, this also gives the correct answer, but it does so in 502 milliseconds (0.502 seconds), roughly a third of the time my original program takes. At first I thought that maybe there was a small overhead in creating a HashMap, and that the times taken were too small to draw any conclusions. However, if I increase the upper limit from 1,000,000 to 10,000,000 in both programs, NaiveProjectEuler14 takes 4709 milliseconds (4.709 seconds), whilst ProjectEuler14 takes a whopping 25324 milliseconds (25.324 seconds)!
Why does ProjectEuler14 take so long? The only explanation I can fathom is that storing huge amounts of pairs in the HashMap data structure is adding a huge overhead, but I can't see why that should be the case. I've also tried recording the number of (key, value) pairs stored during the course of the program (2,168,611 pairs for the 1,000,000 case, and 21,730,849 pairs for the 10,000,000 case) and supplying a little over that number to the HashMap constructor so that it only has to resize itself at most once, but this does not seem to affect the execution times.
Does anyone have any rationale for why the memoized version is a lot slower?
There are several reasons for that unfortunate reality:
Calling containsKey and then get does two lookups; do a single get and check for null instead.
The code goes through an extra method call for every step.
The map stores wrapper objects (Integer, Long) instead of primitive types.
The JIT compiler, which translates byte code to machine code, can optimize plain arithmetic much more aggressively.
The cache does not save a large share of the work, unlike with Fibonacci.
A comparable version would be:
public static void main(String[] args) {
int longest = 0;
int numTerms = 0;
Map<Long, Integer> map = new HashMap<>();
for (int i = 1; i <= 10000000; i++) {
long j = i;
int currentTerms = 1;
while (j != 1) {
currentTerms++;
if (j % 2 == 0){
j = j / 2;
// Maybe check the map only here
Integer m = map.get(j);
if (m != null) {
// j itself is already counted in currentTerms
currentTerms += m - 1;
break;
}
}
else{
j = 3 * j + 1;
}
}
// Key by the start value; a long key is needed to match Map<Long, Integer>
map.put((long) i, currentTerms);
if (currentTerms > numTerms){
numTerms = currentTerms;
longest = i;
}
}
System.out.println("Longest: " + longest + " (" + numTerms + ").");
}
This does not really do adequate memoization. For increasing parameters, not checking the map after 3*j+1 somewhat decreases the misses (but might also skip memoized values).
Memoization pays off when each call does heavy calculation. If the function takes long because of deep recursion rather than calculation, the memoization overhead per function call counts negatively.
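To illustrate the boxing and hashing points above, here is a sketch that swaps the HashMap for a plain int[] cache. It rests on my own assumptions: only start values below the limit are cached, and larger intermediate values are simply recomputed. The class and method names are hypothetical, and I have not benchmarked it against either program in the question.

```java
class CollatzArrayCache {
    // Returns {start, chainLength} for the longest chain starting below limit (limit >= 2).
    static int[] longestChain(int limit) {
        int[] cache = new int[limit]; // cache[n] = number of terms in the chain from n; 0 = unknown
        cache[1] = 1;
        int bestStart = 1;
        int bestLength = 1;
        for (int start = 2; start < limit; start++) {
            long n = start;
            int steps = 0;
            // Walk until we land on a value whose chain length is already cached.
            while (n >= limit || cache[(int) n] == 0) {
                n = (n % 2 == 0) ? n / 2 : 3 * n + 1;
                steps++;
            }
            int length = steps + cache[(int) n];
            cache[start] = length;
            if (length > bestLength) {
                bestLength = length;
                bestStart = start;
            }
        }
        return new int[] {bestStart, bestLength};
    }

    public static void main(String[] args) {
        int[] result = longestChain(1_000_000);
        System.out.println(result[0] + " has the longest chain, with length " + result[1]);
    }
}
```

Because every start value below the current one is already cached, the inner loop stops as soon as the chain drops below its own starting point, without any boxing or hash lookups.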

Implementing Euclid's Algorithm in Java

I've been trying to implement Euclid's algorithm in Java for 2 numbers or more. The problem with my code is that
a) It works fine for 2 numbers, but returns the correct value multiple times when more than 2 numbers are entered. My guess is that this is probably because of the return statements in my code.
b) I don't quite understand how it works. Though I coded it myself, I don't quite understand how the return statements are working.
import java.util.*;
public class GCDCalc {
static int big, small, remainder, gcd;
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
// Remove duplicates from the arraylist containing the user input.
ArrayList<Integer> listofnum = new ArrayList();
System.out.println("GCD Calculator");
System.out.println("Enter the number of values you want to calculate the GCD of: ");
int counter = sc.nextInt();
for (int i = 0; i < counter; i++) {
System.out.println("Enter #" + (i + 1) + ": ");
int val = sc.nextInt();
listofnum.add(val);
}
// Sorting algorithm.
// This removed the need of conditional statements(we don't have to
// check if the 1st number is greater than the 2nd element
// before applying Euclid's algorithm.
// The outer loop ensures that the maximum number of swaps are occurred.
// It ensures the implementation of the swapping process as many times
// as there are numbers in the array.
for (int i = 0; i < listofnum.size(); i++) {
// The inner loop performs the swapping.
for (int j = 1; j < listofnum.size(); j++) {
if (listofnum.get(j - 1) > listofnum.get(j)) {
int dummyvar = listofnum.get(j);
int dummyvar2 = listofnum.get(j - 1);
listofnum.set(j - 1, dummyvar);
listofnum.set(j, dummyvar2);
}
}
}
// nodup contains the array containing the userinput,without any
// duplicates.
ArrayList<Integer> nodup = new ArrayList();
// Remove duplicates.
for (int i = 0; i < listofnum.size(); i++) {
if (!nodup.contains(listofnum.get(i))) {
nodup.add(listofnum.get(i));
}
}
// Since the array is sorted in ascending order,we can easily determine
// which of the indexes has the bigger and smaller values.
small = nodup.get(0);
big = nodup.get(1);
remainder = big % small;
if (nodup.size() == 2) {
recursion(big, small, remainder);
} else if (nodup.size() > 2) {
largerlist(nodup, big, small, 2);
} else // In the case,the array only consists of one value.
{
System.out.println("GCD: " + nodup.get(0));
}
}
// recursive method.
public static int recursion(int big, int small, int remainder) {
remainder = big % small;
if (remainder == 0) {
System.out.println(small);
} else {
int dummyvar = remainder;
big = small;
small = dummyvar;
recursion(big, small, remainder);
}
return small;
}
// Method to deal with more than 2 numbers.
public static void largerlist(ArrayList<Integer> list, int big, int small, int counter) {
remainder = big % small;
gcd = recursion(big, small, remainder);
if (counter == list.size()) {
} else if (counter != list.size()) {
big = gcd;
small = list.get(counter);
counter++;
largerlist(list, gcd, small, counter);
}
}
}
I apologize in advance for any formatting errors etc.
Any suggestions would be appreciated. Thanks!
I think these two assignments are the wrong way around
big = gcd;
small = list.get(counter);
and then big is not used:
largerlist(list, gcd, small, counter);
Also you've used static variables, which is usually a problem.
I suggest removing static/global variables and generally don't reuse variables.
Edit: Oh yes, return. You've ignored the return value of the recursion method when called from the recursion method. That shouldn't matter as you are printing out instead of returning the value, but such solutions break when, say, you want to use the function more than once.
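A minimal sketch of what the advice above could look like: no static state, and the result is returned rather than printed. The class and method names are hypothetical. It assumes non-negative inputs, and it needs no sorting or duplicate removal because Euclid's algorithm swaps the operands by itself when the first is smaller.

```java
import java.util.Arrays;
import java.util.List;

class GcdSketch {
    // Euclid's algorithm for two non-negative ints; operand order does not matter.
    static int gcd(int a, int b) {
        while (b != 0) {
            int r = a % b;
            a = b;
            b = r;
        }
        return a;
    }

    // Folds the two-argument gcd over the whole list.
    static int gcd(List<Integer> values) {
        int result = 0; // gcd(0, x) == x, so 0 is a neutral starting value
        for (int v : values) {
            result = gcd(result, v);
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(gcd(Arrays.asList(12, 18, 24))); // 6
    }
}
```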

how to improve this code?

I have written code that expresses a number as a sum of squares, and I am attaching it below.
But the problem is that the output should be of minimum length.
I am getting output such as 3^2+1^2+1^2+1^2, which is not of minimum length.
I need to output in this format:
package com.algo;
import java.util.Scanner;
public class GetInputFromUser {
public static void main(String[] args) {
// TODO Auto-generated method stub
int n;
Scanner in = new Scanner(System.in);
System.out.println("Enter an integer");
n = in.nextInt();
System.out.println("The result is:");
algofunction(n);
}
public static int algofunction(int n1)
{
int r1 = 0;
int r2 = 0;
int r3 = 0;
//System.out.println("n1: "+n1);
r1 = (int) Math.sqrt(n1);
r2 = (int) Math.pow(r1, 2);
// System.out.println("r1: "+r1);
//System.out.println("r2: "+r2);
System.out.print(r1+"^2");
r3 = n1-r2;
//System.out.println("r3: "+r3);
if (r3 == 0)
return 1;
if(r3 == 1)
{
System.out.print("+1^2");
return 1;
}
else {
System.out.print("+");
algofunction(r3);
return 1;
}
}
}
Dynamic programming is all about defining the problem in such a way that if you knew the answer to a smaller version of the original, you could use that to answer the main problem more quickly/directly. It's like applied mathematical induction.
In your particular problem, we can define MinLen(n) as the minimum length representation of n. Next, say, since we want to solve MinLen(12), suppose we already knew the answer to MinLen(1), MinLen(2), MinLen(3), ..., MinLen(11). How could we use the answer to those smaller problems to figure out MinLen(12)? This is the other half of dynamic programming - figuring out how to use the smaller problems to solve the bigger one. It doesn't help you if you come up with some smaller problem, but have no way of combining them back together.
For this problem, we can make the simple statement, "For 12, its minimum length representation DEFINITELY has either 1^2, 2^2, or 3^2 in it." And in general, the minimum length representation of n will have some square less than or equal to n as a part of it. There is probably a better statement you can make, which would improve the runtime, but I'll say that it is good enough for now.
This statement means that MinLen(12) = 1^2 + MinLen(11), OR 2^2 + MinLen(8), OR 3^2 + MinLen(3). You check all of them and select the best one, and now you save that as MinLen(12). Now, if you want to solve MinLen(13), you can do that too.
Advice when solo:
The way I would test this kind of program myself is to plug in 1, 2, 3, 4, 5, etc, and see the first time it goes wrong. Additionally, any assumptions I happen to have thought were a good idea, I question: "Is it really true that the largest square number less than n will be in the representation of MinLen(n)?"
Your code:
r1 = (int) Math.sqrt(n1);
r2 = (int) Math.pow(r1, 2);
embodies that assumption (a greedy assumption), but it is wrong, as you've clearly seen with the answer for MinLen(12).
Instead you want something more like this:
public ArrayList<Integer> minLen(int n)
{
// base case of recursion
if (n == 0)
return new ArrayList<Integer>();
ArrayList<Integer> best = null;
int bestInt = -1;
for (int i = 1; i*i <= n; ++i)
{
// Check what happens if we use i^2 as part of our representation
ArrayList<Integer> guess = minLen(n - i*i);
// If we haven't selected a 'best' yet (best == null)
// or if our new guess is better than the current choice (guess.size() < best.size())
// update our choice of best
if (best == null || guess.size() < best.size())
{
best = guess;
bestInt = i;
}
}
best.add(bestInt);
return best;
}
Then, once you have your list, you can sort it (no guarantees that it came in sorted order), and print it out the way you want.
Lastly, you may notice that for larger values of n (1000 may be too large) that you plug in to the above recursion, it will start going very slow. This is because we are constantly recalculating all the small subproblems - for example, we figure out MinLen(3) when we call MinLen(4), because 4 - 1^2 = 3. But we figure it out twice for MinLen(7) -> 3 = 7 - 2^2, but 3 also is 7 - 1^2 - 1^2 - 1^2 - 1^2. And it gets much worse the larger you go.
The solution to this, which lets you solve up to n = 1,000,000 or more, very quickly, is to use a technique called Memoization. This means that once we figure out MinLen(3), we save it somewhere, let's say a global location to make it easy. Then, whenever we would try to recalculate it, we check the global cache first to see if we already did it. If so, then we just use that, instead of redoing all the work.
import java.util.*;
class SquareRepresentation
{
private static HashMap<Integer, ArrayList<Integer>> cachedSolutions;
public static void main(String[] args)
{
cachedSolutions = new HashMap<Integer, ArrayList<Integer>>();
for (int j = 100000; j < 100001; ++j)
{
ArrayList<Integer> answer = minLen(j);
Collections.sort(answer);
Collections.reverse(answer);
for (int i = 0; i < answer.size(); ++i)
{
if (i != 0)
System.out.printf("+");
System.out.printf("%d^2", answer.get(i));
}
System.out.println();
}
}
public static ArrayList<Integer> minLen(int n)
{
// base case of recursion
if (n == 0)
return new ArrayList<Integer>();
// new base case: problem already solved once before
if (cachedSolutions.containsKey(n))
{
// It is a bit tricky though, because we need to be careful!
// See how below that we are modifying the 'guess' array we get in?
// That means we would modify our previous solutions! No good!
// So here we need to return a copy
ArrayList<Integer> ans = cachedSolutions.get(n);
ArrayList<Integer> copy = new ArrayList<Integer>();
for (int i: ans) copy.add(i);
return copy;
}
ArrayList<Integer> best = null;
int bestInt = -1;
// THIS IS WRONG, can you figure out why it doesn't work?:
// for (int i = 1; i*i <= n; ++i)
for (int i = (int)Math.sqrt(n); i >= 1; --i)
{
// Check what happens if we use i^2 as part of our representation
ArrayList<Integer> guess = minLen(n - i*i);
// If we haven't selected a 'best' yet (best == null)
// or if our new guess is better than the current choice (guess.size() < best.size())
// update our choice of best
if (best == null || guess.size() < best.size())
{
best = guess;
bestInt = i;
}
}
best.add(bestInt);
// check... not needed unless you coded wrong
int sum = 0;
for (int i = 0; i < best.size(); ++i)
{
sum += best.get(i) * best.get(i);
}
if (sum != n)
{
throw new RuntimeException(String.format("n = %d, sum=%d, arr=%s\n", n, sum, best));
}
// New step: Save the solution to the global cache
cachedSolutions.put(n, best);
// Same deal as before... if you don't return a copy, you end up modifying your previous solutions
//
ArrayList<Integer> copy = new ArrayList<Integer>();
for (int i: best) copy.add(i);
return copy;
}
}
It took my program around 5s to run for n = 100,000. Clearly there is more to be done if we want it to be faster, and to solve for larger n. The main issue now is that in storing the entire list of results of previous answers, we use up a lot of memory. And all of that copying! There is more you can do, like storing only an integer and a pointer to the subproblem, but I'll let you do that.
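One possible reading of the "integer and a pointer" idea, sketched with hypothetical names: for each value keep only the best count and the root of one square that achieves it, then walk those choices back to rebuild the list. It is written bottom-up rather than recursively, which is my own substitution and also avoids deep recursion.

```java
import java.util.ArrayList;
import java.util.List;

class SquareSumTable {
    // Returns the roots of a minimum-length sum of squares equal to n (n >= 0).
    static List<Integer> minLen(int n) {
        int[] count = new int[n + 1];  // count[v] = fewest squares summing to v
        int[] choice = new int[n + 1]; // choice[v] = root of one square used in a best answer for v
        for (int v = 1; v <= n; v++) {
            count[v] = Integer.MAX_VALUE;
            for (int i = 1; i * i <= v; i++) {
                if (count[v - i * i] + 1 < count[v]) {
                    count[v] = count[v - i * i] + 1;
                    choice[v] = i;
                }
            }
        }
        // Follow the stored choices back down to 0 to recover the representation.
        List<Integer> roots = new ArrayList<>();
        for (int v = n; v > 0; v -= choice[v] * choice[v]) {
            roots.add(choice[v]);
        }
        return roots;
    }

    public static void main(String[] args) {
        System.out.println(minLen(12));
        System.out.println(minLen(1000));
    }
}
```

This keeps two int arrays of size n + 1 instead of a list per cached value, and no copying is needed.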
And by the by, 1000 = 30^2 + 10^2.

Improving a prime sieve algorithm

I'm trying to make a decent Java program that generates the primes from 1 to N (mainly for Project Euler problems).
At the moment, my algorithm is as follows:
Initialise an array of booleans (or a bitarray if N is sufficiently large) so they're all false, and an array of ints to store the primes found.
Set an integer, s equal to the lowest prime, (ie 2)
While s is <= sqrt(N)
Set all multiples of s (starting at s^2) to true in the array/bitarray.
Find the next smallest index in the array/bitarray which is false, use that as the new value of s.
Endwhile.
Go through the array/bitarray, and for every value that is false, put the corresponding index in the primes array.
Now, I've tried skipping over numbers not of the form 6k + 1 or 6k + 5, but that only gives me a ~2x speed up, whilst I've seen programs run orders of magnitudes faster than mine (albeit with very convoluted code), such as the one here
What can I do to improve?
Edit: Okay, here's my actual code (for N of 1E7):
int l = 10000000, n = 2, sqrt = (int) Math.sqrt(l);
boolean[] nums = new boolean[l + 1];
int[] primes = new int[664579];
while(n <= sqrt){
for(int i = 2 * n; i <= l; nums[i] = true, i += n);
for(n++; nums[n]; n++);
}
for(int i = 2, k = 0; i < nums.length; i++) if(!nums[i]) primes[k++] = i;
Runs in about 350ms on my 2.0GHz machine.
While s is <= sqrt(N)
One mistake people often make in such algorithms is not precomputing the square root.
while (s <= Math.sqrt(N)) {
is much, much slower than
int limit = (int) Math.sqrt(N);
while (s <= limit) {
But generally speaking, Eiko is right in his comment. If you want people to offer low-level optimisations, you have to provide code.
update Ok, now about your code.
You may notice that the number of iterations in your code is just a little bigger than 'l' (you can put a counter inside the first 'for' loop; it will be just 2-3 times bigger). And, obviously, the complexity of your solution can't be less than O(l) (you can't have fewer than 'l' iterations).
What can make real difference is accessing memory effectively. Note that guy who wrote that article tries to reduce storage size not just because he's memory-greedy. Making compact arrays allows you to employ cache better and thus increase speed.
I just replaced boolean[] with a bit-packed int[] and got an immediate 2x speed gain (and 8x less memory). And I didn't even try to do it efficiently.
update2
That's easy. You just replace every assignment a[i] = true with a[i/32] |= 1 << (i%32) and each read operation a[i] with (a[i/32] & (1 << (i%32))) != 0. And boolean[] a with int[] a, obviously.
From the first replacement it should be clear how it works: if f(i) is true, then there's a bit 1 in an integer number a[i/32], at position i%32 (int in Java has exactly 32 bits, as you know).
You can go further and replace i/32 with i >> 5, i%32 with i&31. You can also precompute all 1 << j for each j between 0 and 31 in array.
But sadly, I don't think in Java you could get close to C in this. Not to mention that the guy uses many other tricky optimizations, and I agree that his code would've been worth a lot more if he had commented it.
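Put together, the replacements described above (i >> 5 for the word index, 1 << (i & 31) for the bit) might look like the sketch below. The class and method names are hypothetical, and it assumes the limit is far enough below Integer.MAX_VALUE that i += n cannot overflow.

```java
class BitPackedSieve {
    // Flags are packed 32 per int: bit (i & 31) of word (i >> 5) is set when i is composite.
    static int[] sieve(int limit) {
        int[] composite = new int[(limit >> 5) + 1];
        for (int n = 2; (long) n * n <= limit; n++) {
            if ((composite[n >> 5] & (1 << (n & 31))) != 0) {
                continue; // n is already crossed off, so its multiples are too
            }
            for (int i = n * n; i <= limit; i += n) {
                composite[i >> 5] |= 1 << (i & 31);
            }
        }
        return composite;
    }

    static boolean isPrime(int[] composite, int i) {
        return i >= 2 && (composite[i >> 5] & (1 << (i & 31))) == 0;
    }

    static int countPrimes(int limit) {
        int[] composite = sieve(limit);
        int count = 0;
        for (int i = 2; i <= limit; i++) {
            if (isPrime(composite, i)) count++;
        }
        return count;
    }

    public static void main(String[] args) {
        System.out.println(countPrimes(10_000_000)); // 664579
    }
}
```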
Using the BitSet will use less memory. The Sieve algorithm is rather trivial, so you can simply "set" the bit positions on the BitSet, and then iterate to determine the primes.
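As a rough sketch of that suggestion (the class and method names are my own), a set bit can mark a composite number, and nextClearBit can be used to step from one prime to the next:

```java
import java.util.BitSet;

class BitSetSieve {
    // Returns a BitSet in which bit i is set when i (2 <= i <= limit) is composite.
    static BitSet sieve(int limit) {
        BitSet composite = new BitSet(limit + 1);
        // nextClearBit(n + 1) jumps straight to the next number not yet crossed off.
        for (int n = 2; (long) n * n <= limit; n = composite.nextClearBit(n + 1)) {
            for (int i = n * n; i <= limit; i += n) {
                composite.set(i);
            }
        }
        return composite;
    }

    static int countPrimes(int limit) {
        BitSet composite = sieve(limit);
        int count = 0;
        for (int i = composite.nextClearBit(2); i <= limit; i = composite.nextClearBit(i + 1)) {
            count++;
        }
        return count;
    }

    public static void main(String[] args) {
        System.out.println(countPrimes(100)); // 25
    }
}
```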
Did you also make the array smaller while skipping numbers not of the form 6k+1 and 6k+5?
I only tested with ignoring numbers of the form 2k and that gave me ~4x speed up (440 ms -> 120 ms):
int l = 10000000, n = 1, sqrt = (int) Math.sqrt(l);
int m = l/2;
boolean[] nums = new boolean[m + 1];
int[] primes = new int[664579];
int i, k;
while (n <= sqrt) {
int x = (n<<1)+1;
for (i = n+x; i <= m; nums[i] = true, i+=x);
for (n++; nums[n]; n++);
}
primes[0] = 2;
for (i = 1, k = 1; i < nums.length; i++) {
if (!nums[i])
primes[k++] = (i<<1)+1;
}
The following is from my Project Euler library. It's a slight variation of the Sieve of Eratosthenes. I'm not sure, but I think it's called the Euler Sieve.
1) It uses a BitSet (so 1/8th the memory)
2) It only uses the BitSet for odd numbers (another 1/2, hence 1/16th)
Note: The inner loop (for multiples) begins at "n*n" rather than "2*n", and only multiples at increments of "2*n" are crossed off, hence the speed-up.
private void beginSieve(int mLimit)
{
primeList = new BitSet(mLimit>>1);
primeList.set(0,primeList.size(),true);
int sqroot = (int) Math.sqrt(mLimit);
primeList.clear(0);
for(int num = 3; num <= sqroot; num+=2)
{
if( primeList.get(num >> 1) )
{
int inc = num << 1;
for(int factor = num * num; factor < mLimit; factor += inc)
{
//if( ((factor) & 1) == 1)
//{
primeList.clear(factor >> 1);
//}
}
}
}
}
and here's the function to check if a number is prime...
public boolean isPrime(int num)
{
if( num < maxLimit)
{
if( (num & 1) == 0)
return ( num == 2);
else
return primeList.get(num>>1);
}
return false;
}
You could do the step of "putting the corresponding index in the primes array" while you are detecting them, taking out a run through the array, but that's about all I can think of right now.
I wrote a simple sieve implementation recently for the fun of it using BitSet (everyone says not to, but it's the best off the shelf way to store huge data efficiently). The performance seems to be pretty good to me, but I'm still working on improving it.
import java.util.BitSet;
public class HelloWorld {
private static int LIMIT = 2140000000;//Integer.MAX_VALUE broke things.
private static BitSet marked;
public static void main(String[] args) {
long startTime = System.nanoTime();
init();
sieve();
long estimatedTime = System.nanoTime() - startTime;
System.out.println((float)estimatedTime/1000000000); //23.835363 seconds
System.out.println(marked.size()); //1070000000 ~= 127MB
}
private static void init()
{
double size = LIMIT * 0.5 - 1;
marked = new BitSet();
marked.set(0,(int)size, true);
}
private static void sieve()
{
int i = 0;
int cur = 0;
int add = 0;
int pos = 0;
while(((i<<1)+1)*((i<<1)+1) < LIMIT)
{
pos = i;
if(marked.get(pos++))
{
cur = pos;
add = (cur<<1);
pos += add*cur + cur - 1;
while(pos < marked.length() && pos > 0)
{
marked.clear(pos++);
pos += add;
}
}
i++;
}
}
private static void readPrimes()
{
int pos = 0;
while(pos < marked.length())
{
if(marked.get(pos++))
{
System.out.print((pos<<1)+1);
System.out.print("-");
}
}
}
}
With smaller LIMITs (say 10,000,000 which took 0.077479s) we get much faster results than the OP.
I bet java's performance is terrible when dealing with bits...
Algorithmically, the link you point out should be sufficient
Have you tried googling, e.g. for "java prime numbers"? I did and dug up this simple improvement:
http://www.anyexample.com/programming/java/java_prime_number_check_%28primality_test%29.xml
Surely, you can find more at google.
Here is my code for the Sieve of Eratosthenes, and this is actually the most efficient that I could do:
static final int MAX = 1000000;
// p[n] == 1 marks n as composite; only odd indexes are examined
static int[] p = new int[MAX];
static int[] prime = new int[MAX/10];
static void sieve()
{
int i,j,k=1;
p[0]=p[1]=1;
prime[0]=2;
for(i=3;i*i<=MAX;i+=2)
{
// Java needs an explicit comparison; an int is not a boolean
if(p[i]!=0)
continue;
for(j=i*i;j<MAX;j+=2*i)
p[j]=1;
}
for(i=3;i<MAX;i+=2)
{
if(p[i]==0)
prime[k++]=i;
}
}
