I have huge .csv files (the biggest one, 72mb) that I need to load to my Android App. They contain matrices e.g [7500x1000], which I later use to multiply another one. Because their are too big to read them at one time () I'm reading line after line and perform some multiplication. I tested this on Nexus 4 and It took something about 15 minutes to do everything. When I copied the code from Android project to JAVA one(the only difference is that in Android i'm using Bitmap and in JAVA BufferedImage) and ran this as java application it took few seconds. Do you have any idea why is so and how to shorten this time? Here Is my code:
Android:
public class NetworkThread extends Thread {
private static boolean threadIsWorking = false;
private Context context;
private SQLiteHandler sql;
private static NetworkThread REFERENCE = null;
private String w1 = "w1.csv";
private String w2 = "w2.csv";
private String w3 = "w3.csv";
private String w4 = "w4.csv";
String NEURON_PATH = "/data/data/com.ZPIProject/neuronFiles/";
public static NetworkThread getInstance(Context context, SQLiteHandler sql) {
if (REFERENCE == null)
REFERENCE = new NetworkThread(context, sql);
return REFERENCE;
}
private NetworkThread(Context context, SQLiteHandler sql) {
this.context = context;
this.sql = sql;
}
#Override
public void start() {
if (!threadIsWorking) {
threadIsWorking = true;
try {
Log.i("MATRIX", "THREAD ID: " + Thread.currentThread().getId());
Bitmap bit = BitmapFactory.decodeResource(context.getResources(), R.drawable.aparat_small);
double[] imageInfo = ImageUtils.getImageInfo(bit);
copyNeuronWeightsToMemoryOnPhone();
double[] m4 = multiplyImageWithNeuronWeights(imageInfo);
List<WeightWithId> best10 = findBest10Results(m4);
for (int i = 0; i < best10.size(); i++) {
Log.e("IDS", best10.get(i).getId() + "");
}
} catch (Exception e) {
e.printStackTrace();
}
}
}
private List<WeightWithId> findBest10Results(double[] m4) throws SQLException, CloneNotSupportedException {
List<WeightWithId> best10 = new ArrayList<WeightWithId>(20);
for (int j = 0; j < 19; j++) {
Map<Integer, double[]> readDescriptions = sql.getDescriptionForShoeId(j * 100, j * 100 + 100);
List<WeightWithId> distances = new ArrayList<WeightWithId>(100);
for (Integer i : readDescriptions.keySet()) {
double dist = dist(m4, readDescriptions.get(i));
distances.add(new WeightWithId(i, dist));
}
Collections.sort(distances);
for (int i = 0; i < 10; i++) {
best10.add(distances.get(i).clone());
}
Collections.sort(best10);
if (best10.size() > 10) {
for (int i = 10; i < best10.size(); i++) {
best10.remove(i);
}
}
}
return best10;
}
private double[] multiplyImageWithNeuronWeights(double[] imageInfo) throws IOException {
Log.i("MATRIX MULTIPLY", "M1");
double[] m1 = MatrixMaker.multiplyMatrixWithCsvMatrix(NEURON_PATH + w1, imageInfo, 1000, false);
Log.i("MATRIX MULTIPLY", "M2");
double[] m2 = MatrixMaker.multiplyMatrixWithCsvMatrix(NEURON_PATH + w2, m1, 500, false);
Log.i("MATRIX MULTIPLY", "M3");
double[] m3 = MatrixMaker.multiplyMatrixWithCsvMatrix(NEURON_PATH + w3, m2, 250, false);
Log.i("MATRIX MULTIPLY", "M4");
return MatrixMaker.multiplyMatrixWithCsvMatrix(NEURON_PATH + w4, m3, 50, true);
}
private void copyNeuronWeightsToMemoryOnPhone() throws IOException {
Log.i("MATRIX COPY", "W1");
CopyUtils.copyFileIntoDevice(NEURON_PATH, w1, context);
Log.i("MATRIX COPY", "W2");
CopyUtils.copyFileIntoDevice(NEURON_PATH, w2, context);
Log.i("MATRIX COPY", "W3");
CopyUtils.copyFileIntoDevice(NEURON_PATH, w3, context);
Log.i("MATRIX COPY", "W4");
CopyUtils.copyFileIntoDevice(NEURON_PATH, w4, context);
}
private double dist(double[] a, double[] b) {
double result = 0;
for (int i = 0; i < a.length; i++)
result += (a[i] - b[i]) * (a[i] - b[i]);
return result;
}
}
Matrix Operations:
public class MatrixMaker {
public static double[] multiplyImageInfoWithCsvMatrix(String csvFilePath, double[] imageInfo, int expectedSize, boolean lastOne) throws IOException {
InputStream inputStream = new FileInputStream(new File(csvFilePath));
InputStreamReader inputStreamReader = new InputStreamReader(inputStream);
BufferedReader bufferedReader = new BufferedReader(inputStreamReader);
String line;
//counter - to which position calculated value should be setted
int counter = 0;
//result array
double[] multipliedImageInfoWithCsvMatrix = new double[expectedSize + 1];
//declaration of array for read line (parsed into double array)
double[] singleLineFromCsv = new double[imageInfo.length];
while ((line = bufferedReader.readLine()) != null) {
//splitting and parsing values from read line
String[] splitted = line.split(",");
for (int i = 0; i < splitted.length; i++) {
singleLineFromCsv[i] = Double.valueOf(splitted[i]);
}
//multiply imageInfo array with single row from csv file
multipliedImageInfoWithCsvMatrix[counter] = multiplyOneRow(imageInfo, singleLineFromCsv);
//ugly flag 'lastOne' needed to other business case
if (!lastOne) {
multipliedImageInfoWithCsvMatrix[counter] = 1d / (1 + Math.exp(-multipliedImageInfoWithCsvMatrix[counter]));
}
counter++;
//logging progress
if (counter % 100 == 0)
Log.i("MULTIPLY PROGRESS", counter + " " + System.currentTimeMillis());
}
if (!lastOne)
multipliedImageInfoWithCsvMatrix[expectedSize] = 1d;
bufferedReader.close();
inputStream.close();
return multipliedImageInfoWithCsvMatrix;
}
private static double multiplyOneRow(double first[], double second[]) {
double result = 0d;
for (int i = 0; i < first.length; i++) {
result += first[i] * second[i];
}
return result;
}
}
onCreate in one of Activity
public void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
try {
final SQLiteHandler sql = new SQLiteHandler(this);
sql.init();
NetworkThread.getInstance(this, sql).start();
} catch (IOException e) {
e.printStackTrace();
} catch (SQLException e) {
e.printStackTrace();
}
JAVA:
Multiply equivalent of Matrix Maker
import java.io.*;
public class Multiply {
public static double[] multiplyMatrixWithCsvMatrix(String csvFilePath, double[] imageInfo,
int expectedSize, boolean lastOne) throws IOException {
InputStream inputStream = new FileInputStream(new File(csvFilePath));
InputStreamReader inputStreamReader = new InputStreamReader(inputStream);
BufferedReader bufferedReader = new BufferedReader(inputStreamReader);
String line;
// counter - to which position calculated value should be setted
int counter = 0;
// result array
double[] multipliedImageInfoWithCsvMatrix = new double[expectedSize + 1];
// declaration of array for read line (parsed into double array)
double[] singleLineFromCsv = new double[imageInfo.length];
while ((line = bufferedReader.readLine()) != null) {
// splitting and parsing values from read line
String[] splitted = line.split(",");
for (int i = 0; i < splitted.length; i++) {
singleLineFromCsv[i] = Double.valueOf(splitted[i]);
}
// multiply imageInfo array with single row from csv file
multipliedImageInfoWithCsvMatrix[counter] = multiplyOneRow(imageInfo, singleLineFromCsv);
// ugly flag 'lastOne' needed to other business case
if (!lastOne) {
multipliedImageInfoWithCsvMatrix[counter] = 1d / (1 + Math
.exp(-multipliedImageInfoWithCsvMatrix[counter]));
}
counter++;
// logging progress
if (counter % 100 == 0)
System.out.println(counter + " " + System.currentTimeMillis());
}
if (!lastOne)
multipliedImageInfoWithCsvMatrix[expectedSize] = 1d;
bufferedReader.close();
inputStream.close();
return multipliedImageInfoWithCsvMatrix;
}
private static double multiplyOneRow(double first[], double second[]) {
double result = 0d;
for (int i = 0; i < first.length; i++) {
result += first[i] * second[i];
}
return result;
}
}
Executable class
public class MRunner {
private static final int RED_INDEX = 0;
private static final int GREEN_INDEX = 2500;
private static final int BLUE_INDEX = 5000;
private static final String w1 = "w1.csv";
private static final String NEURON_PATH = "C:\\Users\\Vortim\\Desktop\\";
public static void main(String[] args) {
new Thread(new Runnable() {
#Override
public void run() {
try {
Multiply.multiplyMatrixWithCsvMatrix(NEURON_PATH + w1, getImageInfo(ImageIO
.read(new File("C:\\Users\\Vortim\\git\\MatrixMaker\\but.png"))), 1000,
false);
} catch (IOException e) {
e.printStackTrace();
}
}
}).start();
}
public static double[] getImageInfo(BufferedImage source) {
double[] result = new double[7501];
int width = source.getWidth();
int height = source.getHeight();
for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
if (x == 0) {
result[y + RED_INDEX] = 255d;
result[y + GREEN_INDEX] = 255d;
result[y + BLUE_INDEX] = 255d;
} else if (y == 0) {
result[x * 50 + RED_INDEX] = 255d;
result[x * 50 + GREEN_INDEX] = 255d;
result[x * 50 + BLUE_INDEX] = 255d;
} else {
int pixel = source.getRGB(x, y);
double r = (pixel) >> 16 & 0xff;
double g = (pixel) >> 8 & 0xff;
double b = (pixel) & 0xff;
result[x * y + RED_INDEX] = 1d - (r / 255d);
result[x * y + GREEN_INDEX] = 1d - (g / 255d);
result[x * y + BLUE_INDEX] = 1d - (b / 255d);
}
}
}
result[7500] = 1;
return result;
}
}
Related
I have implemented this java app communicating with a remote 8 modem server and when I request for an image from a security camera (error free or not) the outcome is a 1 byte image (basically a small white square). Can you help me find the error?
Here is my code:
import java.io.*;
import java.util.Scanner;
import ithakimodem.Modem;
public class g {
private static Scanner scanner = new Scanner(System.in);
private static String EchoCode = "E3369";
private static String noImageErrorscode = "M2269";
private static String imageErrorsCode = "G6637";
private static String GPS_Code = "P7302";
private static String ACK_Code = "Q2591";
private static String NACK_Code = "R4510";
public static void main(String[] args) throws IOException, InterruptedException {
int timeout = 2000;
int speed = 80000;
Modem modem = new Modem(speed);
modem.setTimeout(timeout);
modem.write("ATD2310ITHAKI\r".getBytes());
getConsole(modem);
modem.write(("test\r").getBytes());
getConsole(modem);
while (true) {
System.out.println("\nChoose one of the options below :");
System.out.print("Option 0: Exit\n" + "Option 1: Echo packages\n" + "Option 2: Image with no errors\n" + "Option 3: Image with errors\n" + "Option 4: Tracks of GPS\n"
+ "Option 5: ARQ\n" );
System.out.print("Insert option : ");
int option = scanner.nextInt();
System.out.println("");
switch (option) {
case 0: {
// Exit
System.out.println("\nExiting...Goodbye stranger");
modem.close();
scanner.close();
return;
}
case 1: {
// Echo
getEcho(modem, EchoCode);
break;
}
case 2: {
// Image with no errors
getImage(modem, noImageErrorscode, "no_error_image.jpg");
break;
}
case 3: {
// Image with errors
getImage(modem, imageErrorsCode, "image_with_errors.jpg");
}
case 4: {
// GPS
getGPS(modem, GPS_Code);
break;
}
case 5: {
// ARQ
getARQ(modem, ACK_Code, NACK_Code);
break;
}
default:
System.out.println("Try again please\n");
}
}
}
public static String getConsole(Modem modem) {
int l;
StringBuilder stringBuilder= new StringBuilder();
String string = null;
while (true) {
try {
l = modem.read();
if (l == -1) {
break;
}
System.out.print((char) l);
stringBuilder.append((char) l);
string = stringBuilder.toString();
} catch (Exception exc) {
break;
}
}
System.out.println("");
return string;
}
private static void getImage(Modem modem, String password, String fileName) throws IOException {
int l;
boolean flag;
FileOutputStream writer = new FileOutputStream(fileName, false);
flag = modem.write((password + "\r").getBytes());
System.out.println("\nReceiving " + fileName + "...");
if (!flag) {
System.out.println("Code error or end of connection");
writer.close();
return;
}
while (true) {
try {
l = modem.read();
writer.write((char) l);
writer.flush();
if (l == -1) {
System.out.println("END");
break;
}
}
catch (Exception exc) {
break;
}
}
writer.close();
}
private static void getGPS(Modem modem, String password) throws IOException {
int rows = 99;
String Rcode = "";
Rcode = password + "R=10200" + rows;
System.out.println("Executing R parameter = " + Rcode + " (XPPPPLL)");
modem.write((Rcode + "\r").getBytes());
String stringConsole = getConsole(modem);
if (stringConsole == null) {
System.out.println("Code error or end of connection");
return;
}
String[] stringRows = stringConsole.split("\r\n");
if (stringRows[0].equals("n.a")) {
System.out.println("Error : Packages not received");
return;
}
System.out.println("**TRACES**\n");
float l1 = 0, l2 = 0;
int difference = 8;
difference *= 100 / 60;
int traceNumber = 7;
String[] traces = new String[traceNumber + 1];
int tracesCounter = 0, flag = 0;
for (int i = 0; i < rows; i++) {
if (stringRows[i].startsWith("$GPGGA")) {
if (flag == 0) {
String str = stringRows[i].split(",")[1];
l1 = Integer.valueOf(str.substring(0, 6)) * 100 / 60;
flag = 1;
}
String str = stringRows[i].split(",")[1];
l2 = Integer.valueOf(str.substring(0, 6)) * 100 / 60;
if (Math.abs(l2 - l1) >= difference) {
traces[tracesCounter] = stringRows[i];
if (tracesCounter == traceNumber)
break;
tracesCounter++;
l1 = l2;
}
}
}
for (int i = 0; i < traceNumber; i++) {
System.out.println(traces[i]);
}
String w = "", T_cd_fnl = password + "T=";
int p = 1;
System.out.println();
for (int i = 0; i < traceNumber; i++) {
String[] strSplit = traces[i].split(",");
System.out.print("T parameter = ");
String str1 = strSplit[4].substring(1, 3);
String str2 = strSplit[4].substring(3, 5);
String str3= String.valueOf(Integer.parseInt(strSplit[4].substring(6, 10)) * 60 / 100).substring(0, 2);
String str4= strSplit[2].substring(0, 2);
String str5= strSplit[2].substring(2, 4);
String str6= String.valueOf(Integer.parseInt(strSplit[2].substring(5, 9)) * 60 / 100).substring(0, 2);
w = str1 + str2 + str3 + str4 + str5 + str6 + "T";
p = p + 5;
System.out.println(w);
T_cd_fnl = T_cd_fnl + w + "=";
}
T_cd_fnl = T_cd_fnl.substring(0, T_cd_fnl.length() - 2);
System.out.println("\nSending code: " + T_cd_fnl);
getImage(modem, T_cd_fnl, "traces GPS.jpg");
}
private static void getEcho(Modem modem, String strl) throws InterruptedException, IOException {
FileOutputStream echo_time_writer = new FileOutputStream("echoTimes.txt", false);
FileOutputStream echo_counter_writer = new FileOutputStream("echoCounter.txt", false);
int h;
int runtime = 5, packetCounter = 0;
String str = "";
System.out.println("\nRuntime (mins) = " + runtime + "\n");
long start_time = System.currentTimeMillis();
long stop_time = start_time + 60 * 1000 * runtime;
long send_time = 0, receiveTime = 0;
String time = "", ClockTime = "";
echo_time_writer.write("Clock Time\tSystem Time\r\n".getBytes());
while (System.currentTimeMillis() <= stop_time) {
packetCounter++;
send_time = System.currentTimeMillis();
modem.write((strl + "\r").getBytes());
while (true) {
try {
h = modem.read();
System.out.print((char) h);
str += (char) h;
if ( h== -1) {
System.out.println("\nCode error or end of connection");
return;
}
if (str.endsWith("PSTOP")) {
receiveTime = System.currentTimeMillis();
ClockTime = str.substring(18, 26) + "\t";
time = String.valueOf((receiveTime - send_time) + "\r\n");
echo_time_writer.write(ClockTime.getBytes());
echo_time_writer.write(time.getBytes());
echo_time_writer.flush();
str = "";
break;
}
} catch (Exception e) {
break;
}
}
System.out.println("");
}
echo_counter_writer.write(("\r\nRuntime: " + String.valueOf(runtime)).getBytes());
echo_counter_writer.write(("\r\nPackets Received: " + String.valueOf(packetCounter)).getBytes());
echo_counter_writer.close();
echo_time_writer.close();
}
private static void getARQ(Modem modem, String strl, String nack_code) throws IOException, InterruptedException {
FileOutputStream arq_time_writer = new FileOutputStream("ARQtimes.txt", false);
FileOutputStream arq_counter_writer = new FileOutputStream("ARQcounter.txt", false);
int runtime = 5;
int xor = 1;
int f = 1;
int m;
int packageNumber = 0;
int iterationNumber = 0;
int[] nack_times_counter = new int[15];
int nack_to_package = 0;
String time = "", clock_time = "", s = "";
long start_time = System.currentTimeMillis();
long stop_time = start_time + 60 * 1000 * runtime;
long send_time = 0, receive_time = 0;
System.out.printf("Runtime (mins) = " + runtime + "\n");
arq_time_writer.write("Clock Time\tSystem Time\tPacket Resends\r\n".getBytes());
while (System.currentTimeMillis() <= stop_time) {
if (xor == f) {
packageNumber++;
nack_times_counter[nack_to_package]++;
nack_to_package = 0;
send_time = System.currentTimeMillis();
modem.write((strl + "\r").getBytes());
} else {
iterationNumber++;
nack_to_package++;
modem.write((nack_code + "\r").getBytes());
}
while (true) {
try {
m = modem.read();
System.out.print((char) m);
s += (char) m;
if (m == -1) {
System.out.println("\nCode error or end of connection");
return;
}
if (s.endsWith("PSTOP")) {
receive_time = System.currentTimeMillis();
break;
}
} catch (Exception e) {
break;
}
}
System.out.println("");
String[] string = s.split("<");
string = string[1].split(">");
f = Integer.parseInt(string[1].substring(1, 4));
xor = string[0].charAt(0) ^ string[0].charAt(1);
for (int i = 2; i < 16; i++) {
xor = xor ^ string[0].charAt(i);
}
if (xor == f) {
System.out.println("Packet ok");
receive_time = System.currentTimeMillis();
time = String.valueOf((receive_time - send_time) + "\t");
clock_time = s.substring(18, 26) + "\t";
arq_time_writer.write(clock_time.getBytes());
arq_time_writer.write(time.getBytes());
arq_time_writer.write((String.valueOf(nack_to_package) + "\r\n").getBytes());
arq_time_writer.flush();
} else {
xor = 0;
}
s = "";
}
arq_counter_writer.write(("\r\nRuntime: " + String.valueOf(runtime)).getBytes());
arq_counter_writer.write("\r\nPackets Received (ACK): ".getBytes());
arq_counter_writer.write(String.valueOf(packageNumber).getBytes());
arq_counter_writer.write("\r\nPackets Resent (NACK): ".getBytes());
arq_counter_writer.write(String.valueOf(iterationNumber).getBytes());
arq_counter_writer.write("\r\nNACK Time Details".getBytes());
for (int i = 0; i < nack_times_counter.length; i++) {
arq_counter_writer.write(("\r\n" + i + ":\t" + nack_times_counter[i]).getBytes());
}
arq_counter_writer.close();
arq_counter_writer.close();
System.out.println("Packets Received: " + packageNumber);
System.out.println("Packets Resent: " + iterationNumber);
System.out.println("\n\nFile arqTimes.txt is created.");
System.out.println("File arqCounter.txt is created.");
}
}
I know the problem is most probably in the getImage() function but I haven't figured it out yet.
Recently, I started trying to train a neural network using backpropagation. The network structure is 784-512-10, and I used the Sigmoid activation function. When I tested a single-layer network on the MNIST dataset, I got around 90%. My results are around 86% with this multi-layer network, is this normal? Did I get the backpropagation part wrong?
Here is my code:
import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.StandardCopyOption;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;
import java.util.Scanner;
public class NeuralNetwork{
public static double learningRate = 0.01;
public static int epoch = 15;
public static int ROWS = 28;
public static int COLUMNS = 28;
public static int INPUT = ROWS * COLUMNS;
public static int outNum = 10;
public static int hiddenNum = 512;
public static double[][] weights2 = new double[outNum][hiddenNum];
public static double[] bias2 = new double[outNum];
public static double[][] weights1 = new double[hiddenNum][INPUT];
public static double[] bias1 = new double[outNum];
private static final double TRAININGSIZE = 10;
public static double[][] inputs = new double[outNum][INPUT];
private static final double[][] target = new double[outNum][outNum];
private static final ArrayList<String> filenames = new ArrayList<>();
private static final ArrayList<Integer> yetDone = new ArrayList<>();
public static double[] actual = new double[outNum];
public static Random rand = new SecureRandom();
public static Scanner input = new Scanner(System.in);
public static void main(String[]args) throws Exception {
System.out.println("1. Learn the network");
System.out.println("2. Guess a number");
System.out.println("3. Guess file");
System.out.println("4. Guess All Numbers");
System.out.println("5. Guess image");
switch (input.nextInt()){
case 1:
learn();
break;
case 2:
guess();
break;
case 3:
guessFile();
break;
case 4:
guessAll();
break;
}
}
public static void guessAll() throws IOException, ClassNotFoundException {
System.out.println("Recognizing...");
/*
for(int x = 1; x < 60000; x++){
filenames.add("data/" + String.format("%05d",x) + ".txt");
}
ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream("network.ser")));
Layers lay = (Layers) ois.readObject();
int correct = 0;
for (String z : filenames) {
double[] a = scan(z,0);
correct += getBestGuess(sigmoid(lay.step(a))) == actual[0] ? 1 : 0;
}
System.out.println("Training: " + correct + " / " + filenames.size() + " correct.");
filenames.clear();
*/
for(int x = 60000; x < 70000; x++){
filenames.add("data/" + String.format("%05d",x) + ".txt");
}
ObjectInputStream oiss = new ObjectInputStream(new BufferedInputStream(new FileInputStream("network1.ser")));
Layers lays1 = (Layers) oiss.readObject();
ObjectInputStream oiss2 = new ObjectInputStream(new BufferedInputStream(new FileInputStream("network2.ser")));
Layers lays2 = (Layers) oiss2.readObject();
int corrects = 0;
for (String z : filenames) {
double[] a = scan(z,0);
corrects += getBestGuess(sigmoid(lays2.step(sigmoid(lays1.step(a))))) == actual[0] ? 1 : 0;
}
System.out.println("Testing: " + corrects + " / " + filenames.size() + " correct.");
System.out.println("Done!");
}
public static void makeList(){
for(int index = 0; index < TRAININGSIZE; index++){
int indices = rand.nextInt(yetDone.size() - 1) + 1;
filenames.add("data/" + String.format("%05d",yetDone.get(indices)) + ".txt");
yetDone.remove(indices);
}
prepareData();
for(int indices = 0; indices < outNum; indices++) {
for(int index = 0; index < outNum; index++){
target[indices][index] = 0;
}
target[indices][(int)actual[indices]] = 1;
}
}
public static void prepareData(){
for(int index = 0; index < outNum; index++){
try {
inputs[index] = scan(filenames.get(index), index);
} catch (FileNotFoundException ex) {
ex.printStackTrace();
}
}
}
public static double[] scan(String filename, int index) throws FileNotFoundException {
Scanner in = new Scanner(new File(filename));
double[] a = new double[INPUT];
for(int i = 0; i < INPUT; i++){
a[i] = in.nextDouble() / 255;
}
actual[index] = in.nextDouble();
return a;
}
public static void guessFile() throws IOException, ClassNotFoundException {
System.out.print("Enter Filename: ");
double[] a = scan(input.next(), 0);
ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream("network.ser")));
Layers lay = (Layers) ois.readObject();
double[] results = lay.step(a);
System.out.println("This is a " + getBestGuess(sigmoid(results)) + "!");
System.out.println(Arrays.toString(results));
}
public static double guess(double[] a) throws IOException, ClassNotFoundException {
ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream("network.ser")));
Layers lay = (Layers) ois.readObject();
double[] results = lay.step(a);
return getBestGuess(sigmoid(results));
}
public static void guess() throws IOException, ClassNotFoundException {
ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream("network.ser")));
System.out.println("Input number: ");
Layers lay = (Layers) ois.readObject();
double[] a = new double[INPUT];
for(int index = 0; index < a.length; index++){
a[index] = input.nextInt();
}
double[] results = lay.step(a);
System.out.println("This is a " + getBestGuess(sigmoid(results)) + "!");
System.out.println(Arrays.toString(sigmoid(results)));
}
public static void learn() {
System.out.println("Learning...");
initialise(weights2, outNum, hiddenNum);
initialise(bias2);
initialise(weights1,hiddenNum, INPUT);
initialise(bias1);
Layers lay2 = new Layers(weights2, bias2, outNum, hiddenNum);
Layers lay1 = new Layers(weights1, bias1, hiddenNum, INPUT);
double[] result2 = new double[lay2.outNum];
double[] result1 = new double[lay1.outNum];
double[] a2;
double[] a1;
double cost = 0;
double sumFinal;
for(int x = 0; x < epoch; x++) {
yetDone.clear();
for(int y = 0; y < 60000; y++){
yetDone.add(y);
}
for (int ind = 0; ind < 200; ind++) {
filenames.clear();
makeList();
for (int n = 0; n < lay2.outNum; n++) {
a1 = inputs[n]; //number
result1 = sigmoid(lay1.step(a1));
a2 = result1;
result2 = sigmoid(lay2.step(a2));
for (int i = 0; i < lay2.outNum; i++) {
for (int j = 0; j < lay2.INPUT; j++) {
weights2[i][j] += learningRate * a2[j] * (target[n][i] - result2[i]);
cost += Math.pow((target[n][i] - result2[i]), 2);
}
}
for(int i = 0; i < lay1.outNum; i++){
for(int j = 0; j < lay1.INPUT; j++){
sumFinal = 0;
for(int k = 0; k < lay2.outNum; k++){
// weight * derivSigma(outputHiddenLayer) * 2(out - expected)
sumFinal += result1[k] * (1 - result1[k]) * 2 * (result2[k] - target[n][k]); // * weights2[k][i]
}
weights1[i][j] -= learningRate * a1[j] * sumFinal * result1[i] * (1 - result1[i]);
}
}
}
lay1.update(weights1, bias1);
lay2.update(weights2, bias2);
}
System.out.println("Epoch " + x + ": " + cost);
cost = 0;
}
System.out.println(Arrays.toString(result1));
System.out.println(Arrays.toString(result2));
for(double[] arr : inputs) {
System.out.println("This is a " + getBestGuess(sigmoid(lay2.step(sigmoid(lay1.step(arr))))) + "!");
}
try (ObjectOutputStream oos = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream("network1.ser")))) {
oos.writeObject(lay1);
} catch (IOException ex) {
ex.printStackTrace();
}
try (ObjectOutputStream oos = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream("network2.ser")))) {
oos.writeObject(lay2);
} catch (IOException ex) {
ex.printStackTrace();
}
System.out.println("Done! Saved to file.");
}
public static double sigmoid(double x){
return 1 / (1 + Math.exp(-x));
}
public static double[] sigmoid(double[] weights){
for(int index = 0; index < weights.length; index++){
weights[index] = sigmoid(weights[index]);
}
return weights;
}
public static void initialise(double[] bias){
Random random = new Random();
for(int index = 0; index < bias.length; index++){
bias[index] = random.nextGaussian();
}
}
public static void initialise(double[][] weights, int outNum, int INPUT){
Random random = new Random();
for(int index = 0; index < outNum; index++){
for(int indice = 0; indice < INPUT; indice++){
weights[index][indice] = random.nextGaussian();
}
}
}
public static int getBestGuess(double[] result){
double k = Integer.MIN_VALUE;
double index = 0;
int current = 0;
for(double a : result){
if(k < a){
k = a;
index = current;
}
current++;
}
return (int)index;
}
}
class Layers implements Serializable {
private static final long serialVersionUID = 8L;
double[][] weights;
double[] bias;
int outNum;
int INPUT;
public Layers(double[][] weights, double[] bias, int outNum, int INPUT){
this.weights = weights;
this.bias = bias;
this.outNum = outNum;
this.INPUT = INPUT;
}
public void update(double[][] weights, double[] bias){
this.weights = weights;
this.bias = bias;
}
public double[] step(double[] aa){
double[] out = new double[outNum];
for (int index = 0; index < outNum; index++) {
for (int indices = 0; indices < INPUT; indices++) {
out[index] += weights[index][indices] * aa[indices];
}
}
return out;
}
}
Thanks in advance!
This is my BFS algorithm code.
I can calculate the shortest path distance, but somehow I am not able to display the shortest path.
Like for example, I calculated the shortest path from source to destination is 2. But i would like to also display the pathway. (PlaceA -> PlaceB -> PlaceC) for example.
May i know how do i display out the shortest path out using Java?
Please do help me! Thank you!
public static void main(String[] args) {
// TODO Auto-generated method stub
chooseNumOfVertices();
chooseNumOfEdges();
generateGraph();
in.close();
}
private static void generateGraph() {
int source = 0;
int destination = 0;
int edge = chooseEdge;
// TODO Auto-generated method stub
ArrayList<LinkedList<Integer>> graph = new ArrayList<LinkedList<Integer>>();
for (int i = 0; i < limit; i++) {
LinkedList<Integer> vertex = new LinkedList<Integer>();
vertex.add(i);
graph.add(vertex);
}
while (chooseEdge > 0) {
// Randomize the value
int value = new Random().nextInt(cityMapping.size());
int value2 = new Random().nextInt(cityMapping.size());
//
if (value != value2 && !graph.get(value).contains(value2) && !graph.get(value2).contains(value)) {
graph.get(value).add(value2);
graph.get(value2).add(value);
chooseEdge--;
}
}
// Printing out the Nodes
for (int i = 0; i < graph.size(); i++) {
// Return each LinkedList nodes
// System.out.println(graph.get(i));
for (int j = 0; j < graph.get(i).size(); j++) {
// Return each individual nodes inside LinkedList
for (Entry<Integer, String> entry : cityMapping.entrySet()) {
if (entry.getKey() == graph.get(i).get(j)) {
//System.out.print(graph.get(i).get(j) + "-> ");
System.out.print(graph.get(i).get(j) + ". " + entry.getValue() + " -> ");
}
}
}
System.out.println();
}
do {
for (
int i = 0; i < limit; i++) {
int[] newArray = new int[limit];
distance.add(newArray);
predecessor.add(newArray);
}
long time = System.nanoTime();
System.out.println("Searching BFS");
System.out.println("--------------------------------------------");
for (int i = 0; i < limit; i++) {
BFS(graph, i);
}
long CPUTime = (System.nanoTime() - time);
System.out.println("CPU Time for BFS for " + limit + "vertices and " + edge + "edges (in ns): " + CPUTime);
System.out.print("Enter -1 to exit! Enter source vertex (between 0 to " + (limit - 1) + ") : ");
source = in.nextInt();
if (source == -1) {
System.out.print("System terminating...");
break;
}
System.out.print("Enter destination vertex (between 0 to " + (limit - 1) + ") : ");
destination = in.nextInt();
System.out.println("Distance from " + source + " to " + destination + " is: " +
getDistance(source, destination));
System.out.println("The Predecessor of the path from " + source + " to " + destination + " is: "
+ getPredecessor(source, destination));
} while (source != -1);
}
private static void BFS(ArrayList<LinkedList<Integer>> graph, int i) {
// TODO Auto-generated method stub
boolean[] mark = new boolean[graph.size()];
Queue<Integer> L = new ArrayBlockingQueue<Integer>(graph.size()); //Queue
L.add(i);
mark[i] = true;
Arrays.fill(predecessor.get(i), -1);
Arrays.fill(distance.get(i), -1);
distance.get(i)[i] = 0;
while (!L.isEmpty()) {
int vertex = L.remove();
for (int i1 = 0; i1 < graph.get(vertex).size(); i1++) {
int v = graph.get(vertex).get(i1);
if (!mark[v]) {
mark[v] = true;
predecessor.get(i)[v] = vertex;
L.add(v);
distance.get(i)[v] = distance.get(i)[predecessor.get(i)[v]] + 1;
}
}
}
}
public static int getDistance(int start, int end) {
return (distance.get(start)[end]);
}
public static int getPredecessor(int start, int end) {
return (predecessor.get(start)[end]);
}
private static void chooseNumOfEdges() {
System.out.println("Please input the number of Edges:");
chooseEdge = in.nextInt();
}
// Number of Vertices
private static void chooseNumOfVertices() {
in = new Scanner(System.in);
System.out.println("Please input the number of Vertices:");
limit = in.nextInt();
// Read CSV
List<String[]> content = readCsvFile();
// Map each number to a city name
cityMapping = new HashMap<>();
for (int i = 0; i < limit; i++) {
cityMapping.put(i, content.get(i)[0]);
}
// System.out.println(cityMapping);
}
// Read CSV file
public static List<String[]> readCsvFile() {
String csvFile = "./Lab 4/country.csv";
BufferedReader br = null;
ArrayList<String> names = new ArrayList<String>();
List<String[]> content = new ArrayList<>();
String cvsSplitBy = ",";
try {
String line = "";
br = new BufferedReader(new FileReader(csvFile));
while ((line = br.readLine()) != null) {
content.add(line.split(cvsSplitBy));
}
} catch (Exception e) {
e.printStackTrace();
}
Random r = new Random();
Collections.shuffle(content);
return content;
}
}
import java.io.*;
public class Point {
private double x;
private double y;
public Point(double x_coord, double y_coord) {
x = x_coord;
y = y_coord;
}
}
public class PointArray {
private Point points[];
public PointArray(FileInputStream fileIn) throws IOException {
try {
BufferedReader inputStream = new BufferedReader(new InputStreamReader(fileIn));
int numberOfPoints = Integer.parseInt(inputStream.readLine())...points = new Point[numberOfPoints];
int i = 0;
String line;
while ((line = inputStream.readLine()) != null) {
System.out.print(line);
double x = Double.parseDouble(line.split(" ")[0]);
double y = Double.parseDouble(line.split(" ")[1]);
points[i] = new Point(x, y);
i++;
}
inputStream.close();
} catch (IOException e) {
System.out.println("Error");
System.exit(0);
}
}
}
public String toString() {
String format = "{";
for (int i = 0; i < points.length; i++) {
if (i < points.length - 1) format = format + points[i] + ", ";
else format = format + points[i];
}
format = format + "}";
return format;
}
public static void main(String[] args) {
FileInputStream five = new FileInputStream(new File("fivePoints.txt"));
PointArray fivePoints = new PointArray(five);
System.out.println(fivePoints.toString());
}
The txt file fivePoints is as shown below:
5
2 7
3 5
11 17
23 19
150 1
The first number in the first line means the number of points in the text file.
There is an error when I read the file. The output I want to get is {(2.0, 7.0), (3.0, 5.0), (11.0,17.0), (23.0, 19.0), (150.0, 1.0)}. How can I fix it?
Add a toString method to your Point which formats the Point as x, y
public class Point {
//...
#Override
public String toString() {
return x + ", " + y;
}
}
Move your toString method so it's defined in PointArray, you might also consider making use of StringJoiner to make your life simpler
public class PointArray {
//...
#Override
public String toString() {
StringJoiner sj = new StringJoiner(", ", "{", "}");
for (int i = 0; i < points.length; i++) {
sj.add("(" + points[i].toString() + ")");
}
return sj.toString();
}
}
I have the following little problem... I have this code that uses the method OpenFile() of one class ReadData to read a .txt file and also I have another class ArraysTZones used to create an object that stores 3 arrays (data1,data2,data3) and 3 integers (total1,total2,total3) returned by the method OpenFile(). The problem is that when I try to display each array (data1,data2,data3) using the method getArray() of ArrayTZones it stops and displays the error NullPointerException. Anyone knows how could I fix this?
public static void main (String args[]) throws IOException {
String fileName = ".//data.txt";
int[] def = new int[180];
try {
ReadData file = new ReadData(fileName);
ArraysTZones summaryatz = new ArraysTZones();
summaryatz = file.OpenFile();
for (int i = 0; i < 180; i++)
System.out.print (summaryatz.getArray1()[i] + " ");
System.out.println ("");
System.out.println (summaryatz.getTotal1());
for (int i = 0; i < 180; i++)
System.out.print (summaryatz.getArray2()[i] + " ");
System.out.println ("");
System.out.println (summaryatz.getTotal2());
for (int i = 0; i < 180; i++)
System.out.print (summaryatz.getArray3()[i] + " ");
System.out.println ("");
System.out.println (summaryatz.getTotal3());
}
catch (IOException e) {
System.out.println(e.getMessage());
}
}
Heres OpenFile()
public ArraysTZones OpenFile() throws IOException {
FileReader reader = new FileReader(path);
BufferedReader textReader = new BufferedReader(reader);
int numberOfTimeZones = 3;
int[] data1 = new int[180];
int[] data2 = new int[180];
int[] data3 = new int[180];
int total1 = 0;
int total2 = 0;
int total3 = 0;
ArraysTZones atz = new ArraysTZones();
for (int i = 0; i < numberOfTimeZones; i++){
if (i == 0) {
String firstTimeZone = textReader.readLine();
String[] val = firstTimeZone.split ("\\s+");
for (int u = 0; u < val.length; u++)
{
int stats = (int)(Math.ceil(Math.abs(Double.parseDouble(val[u]))));
total1 += stats;
data1[u] = stats;
}
total1= total1/180;
atz.setTotal1(total1);
atz.setArray1(data1);
}
else
if (i == 1) {
String secondTimeZone = textReader.readLine();
String[] val = secondTimeZone.split ("\\s+");
for (int u = 0; u < val.length; u++)
{
int stats = (int)(Math.ceil(Math.abs(Double.parseDouble(val[u]))));
total2 += stats;
data2[u] = stats;
}
total2= total2/180;
atz.setTotal2(total2);
atz.setArray2(data2);
}
else {
String thirdTimeZone = textReader.readLine();
String[] val = thirdTimeZone.split ("\\s+");
for (int u = 0; u < val.length; u++)
{
int stats = (int)(Math.ceil(Math.abs(Double.parseDouble(val[u]))));
total3 += stats;
data3[u] = stats;
}
total3= total3/180;
atz.setTotal3(total3);
atz.setArray3(data3);
}
}
textReader.close();
return atz;
}
The getArray()
public int[] getArray1 () {
return data1;
}
And setArray()
public void setArray1 (int[] farray) {
int[] data1 = new int[180];
//int[] farray = new int[180];
data1 = farray;
}
The problem seems to be here
public void setArray1 (int[] farray)
{
int[] data1 = new int[180];
//int[] farray = new int[180];
data1 = farray;
}
You're declaring a new variable called data1 and storing the content of farray to it.
After that method is done, that variable will be removed, due to his scope.
Remove int[] from the line int[] data1 = new int[180]; (or just remove the whole line .. it is unnecessary) and your data will be stored in the correct variable that was declared for the class.
public void setArray1 (int[] farray) {
data1 = farray;
}
You have to initialize ArraysTZones