I would like to multiply matrices very fast on a single core. I have looked around the web, came across a few algorithms, and found that Strassen's algorithm is the only one that people actually implement. I looked at a few examples and arrived at the solution below. I made a simple benchmark which generates two randomly filled 500x500 matrices. Strassen's algorithm took 18 seconds, whereas the high-school algorithm was done in 0.4 seconds. Other people reported very promising results after implementing the algorithm, so what is wrong with mine, and how can I make it quicker?
/**
 * Returns the product {@code C = A * B} (with {@code A = this}) computed with Strassen's
 * seven-multiplication recursion.
 *
 * <p>Both operands are split into four quadrants of size {@code ceil(dim / 2)}; cells that fall
 * outside an operand are read as 0 through {@code saveGet}, so odd dimensions are padded
 * implicitly. Once any dimension of the current sub-problem is at most the leaf size, the
 * ordinary {@code times} routine is used.
 *
 * @param B        right-hand operand; its row count must equal the column count of this matrix
 * @param LEAFSIZE dimension at or below which the recursion falls back to {@code times};
 *                 values below 1 are treated as 1
 * @return a new matrix with {@code A.M} rows and {@code B.N} columns
 * @throws RuntimeException if the inner dimensions do not match
 */
private Matrix strassenTimes(Matrix B, int LEAFSIZE) {
    Matrix A = this;
    // A product only needs matching inner dimensions; the operands need not have equal shapes.
    if (A.N != B.M) throw new RuntimeException("Illegal matrix dimensions.");

    // Halving can never reach a leaf size below 1, which would recurse without end.
    int leaf = Math.max(LEAFSIZE, 1);
    if (A.N <= leaf || A.M <= leaf || B.N <= leaf) {
        return A.times(B);
    }

    // Quadrant sizes, rounded up so that odd dimensions are covered (missing cells read as 0).
    int newAcols = (A.N + 1) / 2;
    int newArows = (A.M + 1) / 2;
    Matrix a11 = new Matrix(newArows, newAcols);
    Matrix a12 = new Matrix(newArows, newAcols);
    Matrix a21 = new Matrix(newArows, newAcols);
    Matrix a22 = new Matrix(newArows, newAcols);

    int newBcols = (B.N + 1) / 2;
    int newBrows = (B.M + 1) / 2;
    Matrix b11 = new Matrix(newBrows, newBcols);
    Matrix b12 = new Matrix(newBrows, newBcols);
    Matrix b21 = new Matrix(newBrows, newBcols);
    Matrix b22 = new Matrix(newBrows, newBcols);

    for (int i = 1; i <= newArows; i++) {
        for (int j = 1; j <= newAcols; j++) {
            a11.setElement(i, j, A.saveGet(i, j));                       // top left
            a12.setElement(i, j, A.saveGet(i, j + newAcols));            // top right
            a21.setElement(i, j, A.saveGet(i + newArows, j));            // bottom left
            a22.setElement(i, j, A.saveGet(i + newArows, j + newAcols)); // bottom right
        }
    }
    for (int i = 1; i <= newBrows; i++) {
        for (int j = 1; j <= newBcols; j++) {
            b11.setElement(i, j, B.saveGet(i, j));                       // top left
            b12.setElement(i, j, B.saveGet(i, j + newBcols));            // top right
            b21.setElement(i, j, B.saveGet(i + newBrows, j));            // bottom left
            b22.setElement(i, j, B.saveGet(i + newBrows, j + newBcols)); // bottom right
        }
    }

    // The seven Strassen products.
    Matrix p1 = a11.add(a22).strassenTimes(b11.add(b22), leaf);   // (a11 + a22)(b11 + b22)
    Matrix p2 = a21.add(a22).strassenTimes(b11, leaf);            // (a21 + a22) b11
    Matrix p3 = a11.strassenTimes(b12.minus(b22), leaf);          // a11 (b12 - b22)
    Matrix p4 = a22.strassenTimes(b21.minus(b11), leaf);          // a22 (b21 - b11)
    Matrix p5 = a11.add(a12).strassenTimes(b22, leaf);            // (a11 + a12) b22
    Matrix p6 = a21.minus(a11).strassenTimes(b11.add(b12), leaf); // (a21 - a11)(b11 + b12)
    Matrix p7 = a12.minus(a22).strassenTimes(b21.add(b22), leaf); // (a12 - a22)(b21 + b22)

    // Recombine into the four quadrants of the result.
    Matrix c12 = p3.add(p5);                      // c12 = p3 + p5
    Matrix c21 = p2.add(p4);                      // c21 = p2 + p4
    Matrix c11 = p1.add(p4).add(p7).minus(p5);    // c11 = p1 + p4 + p7 - p5
    Matrix c22 = p1.add(p3).add(p6).minus(p2);    // c22 = p1 + p3 + p6 - p2

    // Copy the quadrants into C, dropping any padding row/column.
    int rows = c11.nrRows();
    int cols = c11.nrColumns();
    Matrix C = new Matrix(A.M, B.N);
    for (int i = 1; i <= A.M; i++) {
        for (int j = 1; j <= B.N; j++) {
            int el;
            if (i <= rows) {
                el = (j <= cols) ? c11.get(i, j) : c12.get(i, j - cols);
            } else {
                // The column offset is the quadrant WIDTH (cols), not its height.
                el = (j <= cols) ? c21.get(i - rows, j) : c22.get(i - rows, j - cols);
            }
            C.setElement(i, j, el);
        }
    }
    return C;
}
The little benchmark has the following code:
// Benchmark: fill two 500x500 matrices with random values in [0, 20), multiply them with
// both algorithms, print the wall-clock time of each and the number of differing cells.
final int AM = 500;
final int AN = 500;
final int BM = AN;
final int BN = 500;
Matrix a = new Matrix(AM, AN);
Matrix b = new Matrix(BM, BN);
Random random = new Random();
// Matrix cells are addressed 1-based through setElement.
for (int row = 1; row <= AM; row++) {
    for (int col = 1; col <= AN; col++) {
        a.setElement(row, col, random.nextInt(20));
    }
}
for (int row = 1; row <= BM; row++) {
    for (int col = 1; col <= BN; col++) {
        b.setElement(row, col, random.nextInt(20));
    }
}

// NOTE(review): strassenTimes as posted takes a leaf-size argument; presumably a
// one-argument overload exists elsewhere -- confirm.
System.out.println("strassen: A x B");
long strassenStart = System.currentTimeMillis();
Matrix c = a.strassenTimes(b);
long strassenElapsed = System.currentTimeMillis() - strassenStart;
System.out.println("time = " + strassenElapsed);

System.out.println("normal: A x B");
long plainStart = System.currentTimeMillis();
Matrix d = a.times(b);
long plainElapsed = System.currentTimeMillis() - plainStart;
System.out.println("time = " + plainElapsed);

System.out.println("nr of different elements = " + c.compare(d));
With the following results:
strassen: A x B
time = 18372
normal: A x B
time = 308
nr of different elements = 0
I know it's a lot of code, but I would be very happy if you guys could help me out ;)
EDIT 1:
For the sake of completeness, I am adding some methods that are used by the above code.
/**
 * Returns the element at the 1-based position ({@code r}, {@code c}).
 *
 * @param r row, from 1 to {@code nrRows()}
 * @param c column, from 1 to {@code nrColumns()}
 * @return the stored value
 * @throws ArrayIndexOutOfBoundsException if the position lies outside the matrix
 */
public int get(int r, int c) {
    if (c > nrColumns() || r > nrRows() || c <= 0 || r <= 0) {
        // This is a read access, so the message says "get" rather than "set".
        throw new ArrayIndexOutOfBoundsException("matrix is of size (" +
                nrRows() + ", " + nrColumns() + "), but tries to get element(" + r + ", " + c + ")");
    }
    return content[r - 1][c - 1];
}
/**
 * Bounds-tolerant read: returns the element at the 1-based position ({@code r}, {@code c}),
 * or 0 when that position lies outside the matrix.
 */
private int saveGet(int r, int c) {
    boolean outside = c > nrColumns() || r > nrRows() || c <= 0 || r <= 0;
    return outside ? 0 : content[r - 1][c - 1];
}
/**
 * Stores {@code n} at the 1-based position ({@code r}, {@code c}).
 *
 * @throws ArrayIndexOutOfBoundsException if the position lies outside the matrix
 */
public void setElement(int r, int c, int n) {
    boolean outside = c > nrColumns() || r > nrRows() || c <= 0 || r <= 0;
    if (outside) {
        String message = "matrix is of size (" + nrRows() + ", " + nrColumns()
                + "), but tries to set element(" + r + ", " + c + ")";
        throw new ArrayIndexOutOfBoundsException(message);
    }
    content[r - 1][c - 1] = n;
}
/**
 * Element-wise sum of this matrix and {@code B}, returned as a new matrix.
 *
 * @throws RuntimeException if the two matrices do not have the same shape
 */
public Matrix add(Matrix B) {
    boolean sameShape = (B.M == this.M) && (B.N == this.N);
    if (!sameShape) {
        throw new RuntimeException("Illegal matrix dimensions.");
    }
    Matrix sum = new Matrix(M, N);
    for (int row = 0; row < M; row++) {
        for (int col = 0; col < N; col++) {
            sum.content[row][col] = this.content[row][col] + B.content[row][col];
        }
    }
    return sum;
}
I should have chosen another leaf size for Strassen's algorithm. Therefore I did a little experiment. It seems that a leaf size of 256 works best with the code included in the question. Below is a plot of the running time for different leaf sizes, each measured on a random matrix of size 1025 x 1025.
I have compared Strassen's algorithm with leaf size 256 against the trivial algorithm for matrix multiplication, to see if it is actually an improvement. It turned out to be an improvement; see below the results on random matrices of different sizes (in steps of 10, repeated 50 times for each size).
Below the code for the trivial algorithm for matrix multiplication:
/**
 * Returns the product {@code C = A * B} (with {@code A = this}) using the schoolbook algorithm.
 *
 * <p>The loops run in i-k-j order: the innermost loop then moves along one row of {@code B}
 * and one row of {@code C} instead of jumping down a column of {@code B}. The sums are the
 * same as with i-j-k order because integer addition is associative and commutative even
 * when it overflows.
 *
 * @param B right-hand operand; its row count must equal the column count of this matrix
 * @return a new matrix with {@code A.M} rows and {@code B.N} columns
 * @throws RuntimeException if the inner dimensions do not match
 */
public Matrix times(Matrix B) {
    Matrix A = this;
    if (A.N != B.M) throw new RuntimeException("Illegal matrix dimensions.");
    Matrix C = new Matrix(A.M, B.N);
    for (int i = 0; i < C.M; i++) {
        for (int k = 0; k < A.N; k++) {
            int aik = A.content[i][k]; // invariant for the whole inner loop
            for (int j = 0; j < C.N; j++) {
                C.content[i][j] += (aik * B.content[k][j]);
            }
        }
    }
    return C;
}
I still think other improvements can be made to the implementation, but it turned out that the leaf size is a very important factor. All experiments were done on a machine running Ubuntu 14.04 with the following specifications:
CPU: Intel(R) Core(TM) i7-2600K CPU @ 3.40GHz
Memory: 2 x 4GB DDR3 1333 MHz
Related
I have been stuck on this for 3 days; I have stepped through it in the debugger and I still don't see the error.
The equation to solve for $x$ is $y \equiv a^{x} \pmod{p}$.
To get started, we choose $m = k = \lfloor\sqrt{p}\rfloor + 1$.
Then we lay out two rows:
First: $(y,\ a y,\ a^{2} y,\ \dots,\ a^{m-1} y) \bmod p$.
Second: $(a^{m},\ a^{2m},\ \dots,\ a^{km}) \bmod p$.
Then we look for the first value that occurs in both rows and write down both indices ($i$ for the second row, $j$ for the first). The answer should be
$x = im - j$, and the equality $a^{im} \equiv a^{j} y \pmod{p}$ must also hold.
// Baby-step/giant-step search for x with a^x = y (mod p).
// A match a^(j*m) = y * a^h (mod p) means x = j*m - h.
BigInteger p = new BigInteger("61");
BigInteger m = p.sqrt().add(BigInteger.ONE);
BigInteger k = m; // the scheme uses m = k = floor(sqrt(p)) + 1
BigInteger a = new BigInteger("2");
BigInteger y = new BigInteger("45");
ArrayList<BigInteger> array1 = new ArrayList<>();
// Row 1 (baby steps): y * a^i mod p for i = 0 .. m-1; the list index equals the exponent i.
for (BigInteger i = BigInteger.ZERO; i.compareTo(m) < 0; i = i.add(BigInteger.ONE)) {
    // modPow keeps every intermediate value below p instead of building a^i in full.
    BigInteger temp = y.multiply(a.modPow(i, p)).mod(p);
    System.out.println(temp);
    array1.add(temp);
}
System.out.println("---------------------------------------------------------");
System.out.println("---------------------------------------------------------");
System.out.println("---------------------------------------------------------");
// Row 2 (giant steps): a^(j*m) mod p for j = 1 .. k. The bound is inclusive so that
// j*m - h can reach every exponent up to m*k >= p.
for (BigInteger j = BigInteger.ONE; j.compareTo(k) <= 0; j = j.add(BigInteger.ONE)) {
    BigInteger temp = a.modPow(j.multiply(m), p);
    System.out.println("temp = " + temp);
    for (int h = 0; h < array1.size(); h++) {
        if (array1.get(h).equals(temp)) {
            // Both sides of the matched identity, reduced mod p.
            System.out.println(a.modPow(j.multiply(m), p));
            System.out.println(a.modPow(BigInteger.valueOf(h), p).multiply(y).mod(p));
            System.out.println("h = " + h + " m = " + m + " j = " + j);
            // h indexes the baby-step row and j the giant-step row, so x = j*m - h.
            return j.multiply(m).subtract(BigInteger.valueOf(h));
        }
    }
}
// No exponent found.
return new BigInteger("-1");
}
I implemented the Held-Karp TSP algorithm in Java to solve a 25-city TSP problem.
The program works with 4 cities.
When it runs with 25 cities, it does not finish even after several hours. I used jVisualVM to find the hotspot; after some optimization it now shows
that 98% of the time is spent in the actual computation rather than in Map.contains or Map.get.
So I would like your advice; here is the code:
/**
 * Runs the subset dynamic programme for the tour that starts at node 0: for every set size
 * m = 2 .. totalNodes and every end point j of every m-set, it stores the cheapest cost of
 * reaching j through the remaining members of the set. Only the costs of the previous set
 * size are kept. Finally {@code calcLastStop} closes the tour and prints the result, and the
 * elapsed time in minutes is printed.
 *
 * <p>NOTE(review): {@code previousCosts.indexOf(key)} is a linear scan over an ArrayList and
 * is executed in the innermost loop; a hash-based lookup would avoid it, but that depends on
 * {@code BitSetEndPointID} having a hashCode consistent with equals -- confirm.
 *
 * @throws Exception declared by the original signature; presumably raised by the
 *                   {@code SetUtil3} helpers -- confirm
 */
private void solve() throws Exception {
    long beginTime = System.currentTimeMillis();
    int counter = 0;
    List<BitSetEndPointID> previousCosts;
    List<BitSetEndPointID> currentCosts;
    // maximum number of elements is c(n,[n/2])
    // To calculate the m-sets' costs only the (m-1)-sets' costs have to be kept.
    List<BitSetEndPointID> lastKeys = new ArrayList<BitSetEndPointID>();
    int m;
    if (totalNodes < 10) {
        // for test data, generate the sets on the fly
        SetUtil3.generateMSet(totalNodes);
    }
    // m = 1: the set {0} ending at node 0 costs nothing
    BitSet beginSet = new BitSet();
    beginSet.set(0);
    previousCosts = new ArrayList<BitSetEndPointID>(1);
    BitSetEndPointID beginner = new BitSetEndPointID(beginSet, 0);
    beginner.setCost(0f);
    previousCosts.add(beginner);
    // for m = 2 to totalNodes
    for (m = 2; m <= totalNodes; m++) { // sum(m=2..n 's C(n,m)*(m-1)(m-1)) ==> O(n^2 * 2^n)
        // pick m elements from total nodes, the element id is the index of nodeCoordinates;
        // the first node is always present
        BitSet[] msets;
        if (totalNodes < 10) {
            msets = SetUtil3.msets[m - 1];
        } else {
            // for a real data set, read the sets from a serialized file
            msets = SetUtil3.getMsets(totalNodes, m-1);
        }
        currentCosts = new ArrayList<BitSetEndPointID>(msets.length);
        for (BitSet mset : msets) { // C(n,m) msets
            int[] candidates = allSetBits(mset, m);
            // mset is a BitSet which makes sure begin point 0 comes first, so end point
            // candidates begin at index 1; candidates[0] is always begin point 0
            for (int i = 1; i < candidates.length; i++) { // m-1 bits are set
                // the new last point j must not be the begin point 0
                int j = candidates[i];
                // middleNodes = mset - {j}
                BitSet middleNodes = (BitSet) mset.clone();
                middleNodes.clear(j);
                // loop through all possible points which are second to last
                // and take min(A[S-{j},k] + k->j), k != j
                float min = Float.MAX_VALUE;
                int k;
                for (int ki = 0; ki < candidates.length; ki++) { // m-1 calculations
                    k = candidates[ki];
                    if (k == j) continue;
                    float middleCost = 0;
                    BitSetEndPointID key = new BitSetEndPointID(middleNodes, k);
                    int index = previousCosts.indexOf(key);
                    if (index != -1) {
                        middleCost = previousCosts.get(index).getCost();
                    } else if (k == 0 && !middleNodes.equals(beginSet)) {
                        // a path may end at the start node only when it is the one-node set
                        continue;
                    } else {
                        System.out.println("middleCost not found!");
                        continue;
                    }
                    float lastCost = distances[k][j];
                    float cost = middleCost + lastCost;
                    if (cost < min) {
                        min = cost;
                    }
                    counter++;
                    if (counter % 500000 == 0) {
                        try {
                            // sleep is static; call it on the class, not through an instance
                            Thread.sleep(100);
                        } catch (InterruptedException iex) {
                            System.out.println("Who dares interrupt my precious sleep?!");
                            // Restore the interrupt flag so callers can still observe it.
                            Thread.currentThread().interrupt();
                        }
                    }
                }
                // record the cost for the chosen mset and last point j
                BitSetEndPointID key = new BitSetEndPointID(mset, j);
                key.setCost(min);
                currentCosts.add(key);
            }
        }
        previousCosts = currentCosts;
        System.out.println("...");
    }
    calcLastStop(lastKeys, previousCosts);
    System.out.println(" cost " + (System.currentTimeMillis() - beginTime) / 60000 + " minutes.");
}
/**
 * Final step: closes every stored path by adding the edge from its end point back to
 * node 0, and prints the smallest resulting total as "final result".
 *
 * @param lastKeys not read by this method
 * @param costs    entries whose cost and end point are combined with {@code distances}
 */
private void calcLastStop(List<BitSetEndPointID> lastKeys, List<BitSetEndPointID> costs) {
    // last step, calculate the min(A[S={1..n},k] + k->1)
    float best = Float.MAX_VALUE;
    for (BitSetEndPointID entry : costs) {
        float closedTour = entry.getCost() + distances[entry.lastPointID][0];
        if (closedTour < best) {
            best = closedTour;
        }
    }
    System.out.println("final result: " + best);
}
You can speed up your code by using arrays of primitives (they are likely to have a better memory layout than a list of objects) and by operating on bitmasks directly (without bitsets or other objects). Here is some code (it generates a random graph, but you can easily change it so that it reads your graph):
import java.io.*;
import java.util.*;
/**
 * Bitmask dynamic programme for the travelling-salesman tour on a random complete graph.
 * {@code dp[v][mask]} is the cheapest cost of a path that starts at node 0, visits exactly
 * the nodes in {@code mask} and ends at {@code v}.
 */
class Main {
    /** Marks a state that has not been reached. */
    final static float INF = 1e10f;

    /**
     * Builds a random symmetric distance matrix, runs the DP and prints the cost of the
     * cheapest closed tour through all nodes.
     *
     * @param args optional: {@code args[0]} is the number of nodes (1..30); defaults to 25
     * @throws IllegalArgumentException if the node count is outside 1..30
     */
    public static void main(String[] args) {
        final int n = args.length > 0 ? Integer.parseInt(args[0]) : 25;
        // 1 << n must stay a positive int because it is used as an array length.
        if (n < 1 || n > 30) {
            throw new IllegalArgumentException("node count must be in 1..30 but was " + n);
        }
        final int stateCount = 1 << n;

        float[][] dist = new float[n][n];
        Random random = new Random();
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                dist[i][j] = dist[j][i] = random.nextFloat();
            }
        }

        float[][] dp = new float[n][stateCount];
        for (int i = 0; i < dp.length; i++) {
            Arrays.fill(dp[i], INF);
        }
        dp[0][1] = 0.0f; // only node 0 visited, standing on node 0

        for (int mask = 1; mask < stateCount; mask++) {
            for (int lastNode = 0; lastNode < n; lastNode++) {
                if ((mask & (1 << lastNode)) == 0) {
                    continue;
                }
                float reached = dp[lastNode][mask];
                if (reached >= INF) {
                    continue; // unreachable state: extending it can never improve anything
                }
                for (int nextNode = 0; nextNode < n; nextNode++) {
                    int bit = 1 << nextNode;
                    if ((mask & bit) != 0) {
                        continue;
                    }
                    float candidate = reached + dist[lastNode][nextNode];
                    if (candidate < dp[nextNode][mask | bit]) {
                        dp[nextNode][mask | bit] = candidate;
                    }
                }
            }
        }

        // Close the tour by returning to node 0 from every possible last node.
        double res = INF;
        for (int lastNode = 0; lastNode < n; lastNode++) {
            res = Math.min(res, dist[lastNode][0] + dp[lastNode][stateCount - 1]);
        }
        System.out.println(res);
    }
}
It takes only a couple of minutes to complete on my computer:
time java Main
...
real 2m5.546s
user 2m2.264s
sys 0m1.572s
Closed. This question needs details or clarity. It is not currently accepting answers.
Want to improve this question? Add details and clarify the problem by editing this post.
Closed 8 years ago.
Improve this question
This is a program that calculates the areas of the triangles formed by six points
and reports the type of each triangle.
import java.util.Scanner;
/**
 * Reads six integer points, prints the type of every triangle that three of them form,
 * and then prints the largest and the smallest triangle area.
 *
 * <p>All geometry is done on exact integer values (squared side lengths and the cross
 * product), so the classification does not depend on floating-point rounding.
 */
public class assignment {

    /**
     * Entry point: reads 12 integers (x and y of six points) from standard input and
     * prints the classification table followed by the maximum and minimum area.
     */
    public static void main(String[] args) {
        Scanner input = new Scanner(System.in);
        int[][] x = new int[6][2];
        System.out.println("Please enter six points,");
        for (int i = 0; i < x.length; i++) {
            for (int j = 0; j < x[i].length; j++) {
                x[i][j] = input.nextInt();
            }
        }
        printRule();
        System.out.print("\n\t\t\t\tTypes of Triangles\n");
        printRule();
        System.out.println();
        System.out.println("Point1\t\tPoint2\t\tPoint3\t\tType of Triangle");
        printRule();
        System.out.println();
        for (int j = 0; j < x.length - 2; j++) {
            for (int k = j + 1; k < x.length - 1; k++) {
                for (int l = k + 1; l < x.length; l++) {
                    String y = classify(x[j], x[k], x[l]);
                    System.out.println("(" + x[j][0] + "," + x[j][1] + ")\t\t" + "(" + x[k][0] + "," + x[k][1] + ")\t\t" + "(" + x[l][0] + "," + x[l][1] + ")\t\t" + y);
                }
            }
        }
        double[] areas = area(x);
        double largest = areas[0];
        double smallest = areas[0];
        for (double value : areas) {
            largest = Math.max(largest, value);
            smallest = Math.min(smallest, value);
        }
        System.out.println("Maximum area of triangle =" + largest);
        System.out.println("Minimum area of triangle =" + smallest);
    }

    /**
     * Computes the area of every triangle formed by three of the given points, in the
     * same j &lt; k &lt; l order in which {@code main} lists them.
     *
     * @param points array of {x, y} pairs
     * @return one area per point triple; collinear triples yield 0
     */
    public static double[] area(int[][] points) {
        int count = points.length;
        double[] areas = new double[count * (count - 1) * (count - 2) / 6];
        int next = 0;
        for (int j = 0; j < count - 2; j++) {
            for (int k = j + 1; k < count - 1; k++) {
                for (int l = k + 1; l < count; l++) {
                    // Half the absolute cross product; exact, unlike Heron's formula,
                    // which can go slightly negative under the root for flat triangles.
                    areas[next++] = Math.abs(cross(points[j], points[k], points[l])) / 2.0;
                }
            }
        }
        return areas;
    }

    /** Names the triangle type of p, q, r, or "Non-triangle" when they are collinear. */
    private static String classify(int[] p, int[] q, int[] r) {
        if (cross(p, q, r) == 0) {
            return "Non-triangle";
        }
        long pq = squaredDistance(p, q);
        long pr = squaredDistance(p, r);
        long qr = squaredDistance(q, r);
        long longest = Math.max(pq, Math.max(pr, qr));
        // Pythagoras on squared lengths: the two shorter squares add up to the longest.
        boolean rightAngled = (pq + pr + qr - longest) == longest;
        boolean isosceles = pq == pr || pr == qr || pq == qr;
        if (rightAngled) {
            return isosceles ? "Right-angled\tIsosceles" : "Right-angled\tScalene";
        }
        return isosceles ? "Isosceles" : "Scalene";
    }

    /** Cross product (q - p) x (r - p); zero exactly when the three points are collinear. */
    private static long cross(int[] p, int[] q, int[] r) {
        long ux = (long) q[0] - p[0];
        long uy = (long) q[1] - p[1];
        long vx = (long) r[0] - p[0];
        long vy = (long) r[1] - p[1];
        return ux * vy - uy * vx;
    }

    /** Squared Euclidean distance between two points. */
    private static long squaredDistance(int[] p, int[] q) {
        long dx = (long) p[0] - q[0];
        long dy = (long) p[1] - q[1];
        return dx * dx + dy * dy;
    }

    /** Prints the 60-dash separator used in the table header, without a line break. */
    private static void printRule() {
        for (int i = 0; i < 30; i++) {
            System.out.print("--");
        }
    }
}
I get the error: '.class' expected.
I am a beginner; can someone help me, please?
Thanks a lot.
There are quite a few problems with the code you posted. Things like assigning to undefined variables such as the following in the area method:
a = Math.sqrt(Math.pow((x[j][0]-x[k][0]),2)+Math.pow((x[j][1]-x[k][1]),2));
You're also referencing the variable x in that statement which is not the correct variable name. The return type and value returned don't match, and there's a syntax problem in the return.
A lot of these problems will be pointed out using an IDE like Eclipse. I think you would do yourself a favor to learn coding in Eclipse.
Here is a revision of your code that runs. I'm not sure about correctness of what you're trying to do:
/**
 * Reads six integer points, prints the type of every triangle that three of them form,
 * and then prints the largest triangle area.
 *
 * <p>Side comparisons use exact squared integer lengths, so the classification does not
 * depend on floating-point rounding of square roots.
 */
public class assignment {

    /**
     * Entry point: reads 12 integers (x and y of six points) from standard input and
     * prints the classification table followed by the maximum area.
     */
    public static void main(String[] args) {
        Scanner input = new Scanner(System.in);
        int[][] x = new int[6][2];
        System.out.println("Please enter six points,");
        for (int i = 0; i < x.length; i++) {
            for (int j = 0; j < x[i].length; j++) {
                x[i][j] = input.nextInt();
            }
        }
        printRule();
        System.out.print("\n\t\t\t\tTypes of Triangles\n");
        printRule();
        System.out.println();
        System.out.println("Point1\t\tPoint2\t\tPoint3\t\tType of Triangle");
        printRule();
        System.out.println();
        for (int j = 0; j < x.length - 2; j++) {
            for (int k = j + 1; k < x.length - 1; k++) {
                for (int l = k + 1; l < x.length; l++) {
                    String y = classify(x[j], x[k], x[l]);
                    System.out.println("(" + x[j][0] + "," + x[j][1] + ")\t\t" + "(" + x[k][0] + "," + x[k][1] + ")\t\t" + "(" + x[l][0] + "," + x[l][1] + ")\t\t" + y);
                }
            }
        }
        System.out.print("Maximum area of triangle =" + area(x));
    }

    /**
     * Returns the largest area among all triangles formed by three of the given points.
     *
     * @param n array of {x, y} pairs
     * @return the maximum area, or 0 when there are fewer than three points or all
     *         triples are collinear
     */
    public static double area(int[][] n) {
        double max = 0;
        for (int j = 0; j < n.length - 2; j++) {
            for (int k = j + 1; k < n.length - 1; k++) {
                for (int l = k + 1; l < n.length; l++) {
                    // Half the absolute cross product is the triangle's area (not its
                    // semi-perimeter) and needs no square roots.
                    double current = Math.abs(cross(n[j], n[k], n[l])) / 2.0;
                    if (current > max) {
                        max = current;
                    }
                }
            }
        }
        return max;
    }

    /** Names the triangle type of p, q, r, or "Non-triangle" when they are collinear. */
    private static String classify(int[] p, int[] q, int[] r) {
        if (cross(p, q, r) == 0) {
            return "Non-triangle";
        }
        long pq = squaredDistance(p, q);
        long pr = squaredDistance(p, r);
        long qr = squaredDistance(q, r);
        long longest = Math.max(pq, Math.max(pr, qr));
        // Pythagoras on squared lengths: the two shorter squares add up to the longest.
        boolean rightAngled = (pq + pr + qr - longest) == longest;
        boolean isosceles = pq == pr || pr == qr || pq == qr;
        if (rightAngled) {
            return isosceles ? "Right-angled\tIsosceles" : "Right-angled\tScalene";
        }
        return isosceles ? "Isosceles" : "Scalene";
    }

    /** Cross product (q - p) x (r - p); zero exactly when the three points are collinear. */
    private static long cross(int[] p, int[] q, int[] r) {
        long ux = (long) q[0] - p[0];
        long uy = (long) q[1] - p[1];
        long vx = (long) r[0] - p[0];
        long vy = (long) r[1] - p[1];
        return ux * vy - uy * vx;
    }

    /** Squared Euclidean distance between two points. */
    private static long squaredDistance(int[] p, int[] q) {
        long dx = (long) p[0] - q[0];
        long dy = (long) p[1] - q[1];
        return dx * dx + dy * dy;
    }

    /** Prints the 60-dash separator used in the table header, without a line break. */
    private static void printRule() {
        for (int i = 0; i < 30; i++) {
            System.out.print("--");
        }
    }
}
I have n bags of candies such that no two bags have the same number of candies inside (i.e. it's a set A[] = {a0,a1,a2,...,ai,...,aj} where ai != aj).
I know how many candies is in each bag and the total number M of candies I have.
I need to divide the bags among three children so that the candies are distributed as fairly as possible (i.e. each child gets as close to M/3 as possible).
Needless to say, I may not tear into the bags to even out the counts -- then the question would be trivial.
Does anyone have any thoughts how to solve this -- preferably in Java?
EDIT:
the interviewer wanted me to use a 2-D array to solve the problem: the first kid gets x, the second kid y, the third gets the rest: S[x][y].
This after I tried following:
1] sort array n lg n
2] starting with largest remaining bag, give bag to kid with fewest candy.
Here is my solution for partitioning to two children (it is the correct answer). Maybe it will help with getting the 3-way partition.
/**
 * Two-way split: returns the largest bag-sum that does not exceed half of {@code M} and
 * can be formed from a subset of {@code A}. One child gets that amount, the other the rest.
 *
 * @param A candies per bag
 * @param M upper bound of the sums that are tracked (the caller passes the total)
 */
int evenlyForTwo(int[] A, int M) {
    // reachable[s] is true when some subset of the bags seen so far sums to s
    boolean[] reachable = new boolean[M + 1];
    reachable[0] = true; // empty set
    for (int bag : A) {
        // walk downwards so that each bag is used at most once
        for (int sum = M; sum >= bag; sum--) {
            if (reachable[sum]) {
                continue;
            }
            reachable[sum] = reachable[sum - bag];
        }
    }
    int best = M / 2;
    while (!reachable[best]) {
        best--;
    }
    return best; // one kid gets best, the other the rest
}
The problem you describe is known as the 3-Partition problem and is known to be NP-hard. The problem is discussed a bit on MathOverflow. You might find some of the pointers there of some value.
Here is a little solution, crude but gives correct results. And you can even change the number of children, bags, etc.
/**
 * Demo that hands randomly sized bags of candy to a number of children and then evens the
 * shares out by repeatedly swapping one bag between two children.
 */
public class BagOfCandies {

    /** Runs ten rounds with freshly generated bags and prints each child's share. */
    static public void main(String...args) {
        final int rounds = 10;
        final int children = 3;
        final int bagTotal = children + (int) (Math.random() * 10);

        int round = 0;
        while (round < rounds) {
            // Strictly increasing sizes, so no two bags hold the same number of candies.
            int[] bags = new int[bagTotal];
            int candyTotal = 0;
            int size = 0;
            for (int idx = 0; idx < bags.length; idx++) {
                size += 1 + (int) (Math.random() * 2);
                bags[idx] = size;
                candyTotal += size;
            }

            shuffle(bags); // completely optional! It works regardless

            boolean[][] owners = divideBags(bags, children);

            System.out.println("Bags of candy : " + Arrays.toString(bags) + " = " + bags.length);
            System.out.println("Total calculated candies is " + candyTotal);
            int handedOut = 0;
            for (int child = 0; child < children; child++) {
                boolean[] owned = owners[child];
                int share = countSumBags(bags, owned);
                System.out.println("Child " + (child+1) + " = " + share + " --> " + Arrays.toString(owned));
                handedOut += share;
            }
            System.out.println("For a total of " + handedOut + " candies");
            System.out.println("----------------");
            round++;
        }
    }

    /** Swaps two randomly chosen positions, once per array element. */
    static private void shuffle(int[] bags) {
        final int len = bags.length;
        for (int pass = 0; pass < len; pass++) {
            int first = (int) Math.floor(Math.random() * len);
            int second = (int) Math.floor(Math.random() * len);
            int held = bags[first];
            bags[first] = bags[second];
            bags[second] = held;
        }
    }

    /**
     * Distributes the bags: deals them out in turn, then keeps swapping single bags between
     * pairs of children while the summed pairwise differences still change.
     *
     * @return {@code result[child][bag]} is true when that child owns that bag
     */
    static private boolean[][] divideBags(int[] bags, int childCount) {
        final int bagCount = bags.length;
        // A new boolean array is already all false.
        boolean[][] owns = new boolean[childCount][bagCount];

        // Initial deal: consecutive bags go to consecutive children.
        for (int start = 0; start < bagCount; start += childCount) {
            for (int bag = start, child = 0; child < childCount && bag < bagCount; bag++, child++) {
                owns[child][bag] = true;
            }
        }

        if (childCount == 1) {
            return owns; // shortcut here
        }

        int previousTotal;
        int total = 1;
        do {
            previousTotal = total;
            total = 0;
            // compare the children in pairs
            for (int left = 0; left < childCount - 1; left++) {
                for (int right = left + 1; right < childCount; right++) {
                    int leftSum = countSumBags(bags, owns[left]);
                    int rightSum = countSumBags(bags, owns[right]);
                    int gap = Math.abs(leftSum - rightSum);
                    // a difference of less than 2 is negligible
                    if (gap > 1) {
                        // look for the swap of one bag each that narrows the gap most
                        int give = -1;
                        int take = -1;
                        boolean found = false;
                        for (int i = 0; i < bagCount - 1; i++) {
                            for (int j = i; j < bagCount; j++) {
                                if (owns[left][i] && owns[right][j]) {
                                    int after = gapAfterSwap(leftSum, rightSum, bags[i], bags[j]);
                                    if (after < gap) {
                                        give = i;
                                        take = j;
                                        gap = after;
                                        found = true;
                                    }
                                }
                                if (owns[left][j] && owns[right][i]) {
                                    int after = gapAfterSwap(leftSum, rightSum, bags[j], bags[i]);
                                    if (after < gap) {
                                        give = j;
                                        take = i;
                                        gap = after;
                                        found = true;
                                    }
                                }
                            }
                        }
                        if (found) {
                            owns[left][give] = false;
                            owns[left][take] = true;
                            owns[right][give] = true;
                            owns[right][take] = false;
                        }
                    }
                    total += gap;
                }
            }
        } while (total != previousTotal);

        return owns;
    }

    /** Gap between two shares after the first gives away {@code given} and receives {@code received}. */
    static private int gapAfterSwap(int firstSum, int secondSum, int given, int received) {
        return Math.abs((firstSum - given + received) - (secondSum + given - received));
    }

    /** Adds up the sizes of the bags whose flag in {@code t} is set. */
    static private int countSumBags(int[] bags, boolean[] t) {
        int total = 0;
        int idx = 0;
        while (idx < t.length) {
            if (t[idx]) {
                total += bags[idx];
            }
            idx++;
        }
        return total;
    }
}
I don't know if this is the result you were looking for, but it seems to be, from my understanding of the question.
I am trying to implement the rectangular/lattice multiplication in Java. For those who don't know, this is a short tutorial.
I tried some methods wherein I used a single array to store multiplication of two digits and sigma-append zeroes to it. Once all the numbers are multiplied, I pick two elements from the array and then add the sigma value and fetch two more numbers and again perform the same thing until all the numbers are fetched.
The logic works fine but I am not able to find the exact number of zeroes that I should maintain, since for every different set of numbers (4 digits * 3 digits) I get different number of zeroes.
Could somebody please help?
I liked the tutorial; it is very neat. So I wanted to implement it, but without doing your project work for you. I came up with a lousy, quick-and-dirty implementation that violates many of the design rules I practice myself. I used arrays to save the digit-by-digit multiplication results and pretty much followed what the tutorial said. I never had to count the number of 0's, and I am not sure what sigma-appending is, so I cannot answer that. Lastly, there is a bug in the code which shows up when the two numbers have a different number of digits. Here is the source code - feel free to edit and use any part. I think an easy fix would be to prepend 0's to the smaller number so that both numbers have the same digit count, and not to display the corresponding rows/columns. That means more bookkeeping, but that's up to you.
import java.util.*;
import java.awt.*;
import java.awt.event.*;
import javax.swing.*;
public class Lattice extends JPanel implements ActionListener {
protected Font axisFont, rectFont, carrFont;
protected Color boxColor = new Color (25, 143, 103), gridColor = new Color (78, 23, 211),
diagColor = new Color (93, 192, 85), fontColor = new Color (23, 187, 98),
carrColor = new Color (162, 34, 19);
protected int nDigitP, nDigitQ, dSize = 60,
m1, m2, lastCarry, iResult[],
xDigits[], yDigits[], prodTL[][], prodBR[][];
/**
 * Builds the lattice model for the product p x q: splits both operands into decimal
 * digits (least significant first), pre-computes the tens and units digit of every
 * digit-by-digit product, derives the three fonts and sets the preferred size.
 *
 * @param p    first operand; its digits label the columns
 * @param q    second operand; its digits label the rows
 * @param font base font from which the axis, cell and carry fonts are derived
 */
public Lattice (int p, int q, Font font) {
    // Math.ceil(Math.log10(v)) undercounts for 1, 10, 100, ... and breaks for 0,
    // so the digits are counted by repeated division instead.
    nDigitP = digitCount (p); xDigits = new int[nDigitP];
    nDigitQ = digitCount (q); yDigits = new int[nDigitQ];
    prodTL = new int[nDigitP][nDigitQ]; prodBR = new int[nDigitP][nDigitQ];
    m1 = p; m2 = q; // To display in report
    int np = p, nq = q, size = font.getSize(); // Save the digits in array
    for (int i = 0 ; i < nDigitP ; i++) {
        xDigits[i] = np % 10;
        np /= 10;
    }
    for (int i = 0 ; i < nDigitQ ; i++) {
        yDigits[i] = nq % 10;
        nq /= 10;
    }
    for (int i = 0 ; i < nDigitP ; i++) { // Cell products as upper/lower matrix
        for (int j = 0 ; j < nDigitQ ; j++) {
            int prod = xDigits[i] * yDigits[j];
            prodTL[i][j] = prod / 10; // tens digit
            prodBR[i][j] = prod % 10; // units digit
        }
    }
    axisFont = font.deriveFont (Font.PLAIN, size+8.0f);
    rectFont = font.deriveFont (Font.PLAIN, size+4.0f);
    carrFont = font.deriveFont (Font.PLAIN);
    setPreferredSize (new Dimension ((nDigitP+2)*dSize, (nDigitQ+2)*dSize));
}
/**
 * Number of decimal digits of {@code value}; 0 counts as one digit. Negative operands are
 * not supported by the lattice and, as before, yield an empty digit array (0 digits).
 */
private static int digitCount (int value) {
    if (value < 0) {
        return 0;
    }
    int digits = 1;
    for (int rest = value / 10 ; rest != 0 ; rest /= 10) {
        digits++;
    }
    return digits;
}
// Draws the whole lattice: a white background, the grid lines with the operand digits as
// axis labels, the tens/units digit of every cell product, the digit sums along the
// diagonals, the diagonal lines, the non-zero carries and finally the outer box.
// Side effects: stores the computed result digits (most significant first) in iResult and
// the carry of the last diagonal in lastCarry, which displayResults reads.
public void paint (Graphics g) {
int w = getWidth(), h = getHeight();
Graphics2D g2 = (Graphics2D) g; // To make diagonal lines smooth
g2.setPaint (Color.white);
g2.fillRect (0,0,w,h);
int dx = (int) Math.round (w/(2.0+nDigitP)), // Grid spacing to position
dy = (int) Math.round (h/(2.0+nDigitQ)); // the lines and the digits
g2.setRenderingHint (RenderingHints.KEY_ANTIALIASING,
RenderingHints.VALUE_ANTIALIAS_ON);
g2.setRenderingHint (RenderingHints.KEY_INTERPOLATION,
RenderingHints.VALUE_INTERPOLATION_BILINEAR);
g2.setFont (axisFont);
FontMetrics fm = g2.getFontMetrics();
// Column i is placed counting from the right edge, so digit 0 (least significant)
// ends up rightmost.
for (int i = 0 ; i < nDigitP ; i++) { // Grid || Y-axis and labels on axis
int px = w - (i+1)*dx;
g2.setPaint (gridColor);
if (i > 0)
g2.drawLine (px, dy, px, h-dy);
String str = /*i + */"" + xDigits[i];
int strw = fm.stringWidth (str);
g2.setPaint (fontColor);
g2.drawString (str, px-dx/2-strw/2, 4*dy/5);
}
// Row i is placed counting from the bottom edge, so digit 0 ends up lowest.
for (int i = 0 ; i < nDigitQ ; i++) { // Grid || X-axis and labels on axis
int py = h - (i+1)*dy;
g2.setPaint (gridColor);
if (i > 0)
g2.drawLine (dx, py, w-dx, py);
String str = /*i + */"" + yDigits[i];
int strw = fm.stringWidth (str);
g2.setPaint (fontColor);
g2.drawString (str, w-dx+2*dx/5-strw/2, py-dy/2+10);
}
g2.setFont (rectFont);
fm = g2.getFontMetrics(); // Upper/Lower traingular product matrix
// Each cell shows the tens digit (prodTL) towards its top-left and the units digit
// (prodBR) towards its bottom-right.
for (int i = 0 ; i < nDigitP ; i++) {
for (int j = 0 ; j < nDigitQ ; j++) {
int px = w - (i+1)*dx;
int py = h - (j+1)*dy;
String strT = "" + prodTL[i][j];
int strw = fm.stringWidth (strT);
g2.drawString (strT, px-3*dx/4-strw/2, py-3*dy/4+5);
String strB = "" + prodBR[i][j];
strw = fm.stringWidth (strB);
g2.drawString (strB, px-dx/4-strw/2, py-dy/4+5);
}}
g2.setFont (axisFont);
fm = g2.getFontMetrics();
int carry = 0;
// cVector collects the carry and iVector the result digit of every diagonal k.
Vector cVector = new Vector(), iVector = new Vector();
// One pass per diagonal. Each pass starts at cell (k/2, k/2) and walks away from it in
// two directions, alternately adding tens (prodTL) and units (prodBR) digits.
// NOTE(review): the bound uses max(nDigitP, nDigitQ) for both axes; this looks like the
// place where operands with different digit counts go wrong -- confirm.
for (int k = 0 ; k < 2*Math.max (nDigitP, nDigitQ) ; k++) {
int dSum = carry, i = k/2, j = k/2;
//System.out.println ("k="+k);
if ((k % 2) == 0) { // even k
if (k/2 < nDigitP && k/2 < nDigitQ)
dSum += prodBR[k/2][k/2];
// go right and top
for (int c = 0 ; c < k ; c++) {
if (--i < 0)
break;
if (i < nDigitP && j < nDigitQ)
dSum += prodTL[i][j];
//System.out.println (" >> TL (i,j) = (" + i+","+j+")");
if (++j == nDigitQ)
break;
if (i < nDigitP && j < nDigitQ)
dSum += prodBR[i][j];
//System.out.println (" >> BR (i,j) = (" + i+","+j+")");
}
// go bottom and left
i = k/2; j = k/2;
for (int c = 0 ; c < k ; c++) {
if (--j < 0)
break;
if (i < nDigitP && j < nDigitQ)
dSum += prodTL[i][j];
//System.out.println (" >> TL (i,j) = (" + i+","+j+")");
if (++i == nDigitP)
break;
if (i < nDigitP && j < nDigitQ)
dSum += prodBR[i][j];
//System.out.println (" >> BR (i,j) = (" + i+","+j+")");
}} else { // odd k
if (k/2 < nDigitP && k/2 < nDigitQ)
dSum += prodTL[k/2][k/2];
// go top and right
for (int c = 0 ; c < k ; c++) {
if (++j == nDigitQ)
break;
if (i < nDigitP && j < nDigitQ)
dSum += prodBR[i][j];
//System.out.println (" >> BR (i,j) = (" + i+","+j+")");
if (--i < 0)
break;
if (i < nDigitP && j < nDigitQ)
dSum += prodTL[i][j];
//System.out.println (" >> TL (i,j) = (" + i+","+j+")");
}
i = k/2; j = k/2;
// go left and bottom
for (int c = 0 ; c < k ; c++) {
if (++i == nDigitP)
break;
if (i < nDigitP && j < nDigitQ)
dSum += prodBR[i][j];
//System.out.println (" >> BR (i,j) = (" + i+","+j+")");
if (--j < 0)
break;
if (i < nDigitP && j < nDigitQ)
dSum += prodTL[i][j];
//System.out.println (" >> TL (i,j) = (" + i+","+j+")");
}}
// The units digit of the diagonal sum is the result digit; the rest carries over
// into the next diagonal.
int digit = dSum % 10; carry = dSum / 10;
cVector.addElement (new Integer (carry));
iVector.addElement (new Integer (digit));
String strD = "" + digit;
int strw = fm.stringWidth (strD);
// The first nDigitP result digits are drawn below the grid, the others to its left.
if (k < nDigitP) {
int px = w - (k+1)*dx - 4*dx/5, py = h-dy + fm.getHeight();
g2.drawString (strD, px-strw/2, py);
} else {
int px = dx - 12, py = h - (k-nDigitP+1)*dy - dy/4;
g2.drawString (strD, px-strw/2, py+5);
}} // End k-loop
// Diagonal guide lines: one set starting on the top edge, one on the right edge.
g2.setPaint (diagColor);
for (int i = 0 ; i < nDigitP ; i++) {
int xt = (i+1) * dx,
yb = (i+2) * dy;
g2.drawLine (xt, dy, 0, yb);
}
for (int i = 0 ; i < nDigitQ ; i++) {
int xb = (i + nDigitP - nDigitQ) * dx,
yr = (i+1) * dy;
g2.drawLine (w-dx, yr, xb, h);
}
// System.out.println ("carry Vector has " + cVector.size() + " elements");
g2.setFont (carrFont);
g2.setPaint (carrColor);
fm = g2.getFontMetrics();
// Draw every non-zero carry; lastCarry ends up holding the carry of the final diagonal.
for (int k = 0 ; k < 2*Math.max (nDigitP, nDigitQ) ; k++) {
carry = ((Integer) cVector.elementAt (k)).intValue();
lastCarry = carry; // To display
if (carry == 0)
continue;
String strC = "" + carry;
int strw = fm.stringWidth (strC),
px = w-dx-5-strw/2, // Const X while going Up
py = dy + fm.getHeight(); // Const Y while going Left
if (k < (nDigitQ-1))
py = h-(k+3)*dy + dy/5 + fm.getHeight();
else
px = w - (k-nDigitQ+2) * dx - dx/2 - strw/2;
g2.drawString (strC, px, py);
}
int n = iVector.size(); // Save the vector content to display later
iResult = new int[n];
// iVector holds the least significant digit first; iResult stores it reversed.
for (int i = 0 ; i < n ; i++)
iResult[i] = ((Integer) iVector.elementAt (n-i-1)).intValue();
g2.setPaint (boxColor); g2.drawRect (dx, dy, w-2*dx, h-2*dy);
}
/**
 * Shows the multiplication result in an information dialog.
 * The message is the two operands (m1, m2) followed by lastCarry, when it is
 * non-zero, and then every entry of iResult in array order.
 */
private void displayResults () {
    // The carry is only shown when there is one; otherwise it contributes nothing.
    String carryText = (lastCarry != 0) ? ("" + lastCarry) : "";

    StringBuilder message = new StringBuilder ();
    message.append ("Lattice: " + m1 + " \u00D7 " + m2 + " = " + carryText);
    for (int digit : iResult) {
        message.append (digit);
    }

    JOptionPane.showMessageDialog (
            null,
            message.toString (),
            "Lattice Multiplier",
            JOptionPane.INFORMATION_MESSAGE);
}
/**
 * Builds the button bar: a one-row grid with one button per caption in bNames.
 * Every button reports its clicks to this object.
 *
 * @return the newly created panel containing the buttons
 */
public JPanel getButtonPanel () {
    JPanel buttonBar = new JPanel (new GridLayout (1, bNames.length));
    for (String caption : bNames) {
        JButton button = new JButton (caption);
        button.addActionListener (this);
        buttonBar.add (button);
    }
    return buttonBar;
}
/** Button captions, also used as action commands: index 0 exits, index 1 shows the result. */
private static final String[] bNames = {"Close", "Result"};

/**
 * Dispatches a button click by its action command: the first caption in bNames
 * terminates the JVM with status 0, the second opens the result dialog.
 * Any other command is ignored.
 *
 * @param e the event delivered by the clicked button
 */
public void actionPerformed (ActionEvent e) {
    final String command = e.getActionCommand ();
    if (command.equals (bNames[0])) {
        System.exit (0);
    } else if (command.equals (bNames[1])) {
        displayResults ();
    }
}
/**
 * Entry point. Asks for the multiplicand and multiplier in a confirm dialog,
 * then shows a Lattice component with its button panel in a frame.
 * Exits with status 2 when the dialog is not confirmed with OK, and with
 * status 1 (after an error dialog) when a NumberFormatException is raised.
 *
 * @param args command-line arguments (not used)
 */
public static void main (String[] args) {
    JTextField multiplicandField = new JTextField ();
    JTextField multiplierField = new JTextField ();

    // Two-column form: labels in column 0, text fields in column 1.
    JPanel form = new JPanel (new GridBagLayout ());
    GridBagConstraints gbc = new GridBagConstraints ();
    gbc.insets = new Insets (2, 2, 2, 2);
    gbc.fill = GridBagConstraints.HORIZONTAL;
    gbc.gridx = 0;
    gbc.gridy = GridBagConstraints.RELATIVE;
    gbc.anchor = GridBagConstraints.EAST;

    JLabel multiplicandLabel = new JLabel ("Multiplicand", JLabel.TRAILING);
    form.add (multiplicandLabel, gbc);
    JLabel multiplierLabel = new JLabel ("Multiplier", JLabel.TRAILING);
    form.add (multiplierLabel, gbc);

    gbc.gridx = 1;
    gbc.weightx = 1.0;
    form.add (multiplicandField, gbc);
    form.add (multiplierField, gbc);

    JFrame frame = new JFrame ("Lattice Multiplication");
    int answer = JOptionPane.showConfirmDialog (
            frame,
            form,
            "Enter numbers to multiply",
            JOptionPane.OK_CANCEL_OPTION,
            JOptionPane.QUESTION_MESSAGE);
    if (answer != JOptionPane.OK_OPTION) {
        System.exit (2);
        return;
    }

    try {
        int m = Integer.parseInt (multiplicandField.getText ());
        int n = Integer.parseInt (multiplierField.getText ());
        // The lattice is given the font of the last label that was created.
        Lattice lattice = new Lattice (m, n, multiplierLabel.getFont ());
        frame.add (lattice.getButtonPanel (), "South");
        frame.add (lattice, "Center");
        frame.setDefaultCloseOperation (JFrame.EXIT_ON_CLOSE);
        frame.pack ();
        frame.setVisible (true);
    } catch (NumberFormatException nex) {
        JOptionPane.showMessageDialog (
                frame,
                "Invalid numbers to multiply",
                "Lattice Multiplier Error",
                JOptionPane.ERROR_MESSAGE);
        System.exit (1);
    }
}
} // end of the enclosing class
This works for any pair of numbers, e.g. a 2-digit multiplicand with a 4-digit multiplier.
// Lattice multiplication of two decimal digit strings, numberM and numberN.
// numberM, numberN and sign are declared outside this fragment.
// Steps: (1) pad the shorter string with leading zeros, (2) turn both into
// least-significant-first digit arrays, (3) write the ones/tens digits of every
// digit product into tempArray, one shifted row per digit of m, (4) add the
// rows column by column with carry, (5) print the result via displayResult.

// Step 1: left-pad the shorter operand with "0" so both strings have equal length.
if(numberM.length()!=numberN.length()){
int mLen = numberM.length();
int nLen = numberN.length();
if(numberM.length()>numberN.length()){
for(int i=0;i<mLen-nLen;i++){
numberN = "0" + numberN;
}
}
else
{
for(int i=0;i<nLen-mLen;i++){
numberM = "0" + numberM;
}
}
}
// One result slot per input character of both (padded) operands.
int result[] = new int[numberN.length()+numberM.length()];
// Step 2: reverse the strings so that index 0 holds the last (units) character.
String numberRN = new StringBuffer(numberN).reverse().toString();
String numberRM = new StringBuffer(numberM).reverse().toString();
int n[] = new int[numberN.length()];
int m[] = new int[numberM.length()];
int size_of_array = 0;
// A single loop bounded by numberN.length() fills both arrays; this relies on
// the padding above having made the two lengths equal.
for(int i=0;i<numberN.length();i++){
n[i] = Integer.parseInt((new Character(numberRN.charAt(i))).toString());
m[i] = Integer.parseInt((new Character(numberRM.charAt(i))).toString());
}
//System.out.println("Numbers are:");
//displayArray(n,"n");
//displayArray(m,"m");
// Each row of the lattice occupies a stride of 4 * m.length slots (kept in soa);
// tempArray holds m.length such strides.
// NOTE(review): with empty operands tempArray has length 0 and the write to
// index 0 below would throw ArrayIndexOutOfBoundsException - confirm callers
// never pass empty strings.
size_of_array = (m.length*2)*2;
int soa = size_of_array;
int tempArray[] = new int[size_of_array*m.length];
//System.out.println("Size of tempArray ="+tempArray.length);
// Step 3: place the partial products in the single array.
// Row 0 starts at index 1 (slot 0 is set to 0). After row i, index is moved to
// (i+1)*soa + i + oddValue; oddValue starts at 3 and grows by one per row, so
// row r begins 2*r+1 slots into its stride, i.e. each row is shifted two slots
// further than the previous one.
int oddValue =3;
int index = 0;
tempArray[index++] = 0;
for(int i=0;i<m.length;i++){
for(int j=0;j<n.length;j++){
//System.out.println("n["+j+"]="+n[j]+" and m["+i+"]="+m[i]);
// Ones digit of the digit product, immediately followed by its tens digit.
tempArray[index++] = (n[j] * m[i]) % 10;
tempArray[index] = (n[j] * m[i]) / 10;
//System.out.println("tempArray["+(index-1)+"]="+tempArray[index-1]+" tempArray["+(index)+"]="+tempArray[index]);
index++;
}
//System.out.println("index before appending zero="+index);
// Jump to the start of the next stride, then apply the per-row shift.
size_of_array=(i+1)*soa;
index = size_of_array;
//System.out.println("index after appending zero="+index);
//index+=i+oddArray[i];
index+=i+oddValue;
oddValue++;
//System.out.println("After move index="+index);
}
//System.out.println("tempArray full");
//displayArray(tempArray,"tempArray");
// Step 4: add the rows column by column and propagate the carry.
// Output digit i is the sum, over all m.length rows, of the slot pair
// (ind, ind+1) taken at the same offset in every stride.
int count=0;int ind = 0;int res = 0;int carry =0;int tempInd = 0;
for(int i=0;i<m.length*2;i++){
// Remember the column offset in row 0 so the next digit can start two slots later.
tempInd = ind;
for(int k=0;k<m.length;k++){
//System.out.println(tempArray[ind]+"---"+tempArray[ind+1]);
// res accumulates across rows; carry is cleared right after its first use so
// the carry from the previous digit is added exactly once.
res = tempArray[ind] + tempArray[ind+1] + res + carry;
ind = ind + soa ;
carry = 0;
//System.out.println("res="+res+" ind="+ind+" soa="+soa);
}
//System.out.println("----------------");
// Store the digit least-significant first and keep the overflow as carry.
result[count++] = res %10;
carry = res /10;
res = 0;
ind = tempInd+2;
}
//displayArray(result,"result");
// Step 5: print the digits (stored least-significant first) with the sign prefix.
displayResult(result,sign);
}
/**
 * Prints "The result is ", the sign and the digits of the result to standard
 * output, without a trailing newline.
 * The digits are stored least-significant first, so they are written from the
 * last index down to index 0. Leading zero digits are skipped (the array is
 * sized for the widest possible product, so they occur for most inputs), but
 * one digit is always kept so a zero product prints as "0".
 *
 * @param result digits of the product, least-significant digit at index 0
 * @param sign   text printed directly before the digits (e.g. "" or "-")
 */
private static void displayResult(int[] result, String sign) {
    StringBuilder line = new StringBuilder("The result is ").append(sign);

    // Find the most significant non-zero digit, stopping at index 0 at the latest.
    int top = result.length - 1;
    while (top > 0 && result[top] == 0) {
        top--;
    }

    for (int i = top; i >= 0; i--) {
        line.append(result[i]);
    }
    System.out.print(line);
}
/**
 * Debug helper: prints every element of an array on one line of standard
 * output, each formatted as " name[index]-value", followed by a line break.
 *
 * @param tempArray the array to dump
 * @param name      label printed in front of every index
 */
static void displayArray(int[] tempArray, String name) {
    StringBuilder line = new StringBuilder();
    int position = 0;
    for (int value : tempArray) {
        line.append(" ").append(name).append("[").append(position).append("]-").append(value);
        position++;
    }
    System.out.println(line);
}