单线程即通常情况下的串行办法
package prime.witersen;
import java.util.*;
import java.lang.*;
// Command-line execution example: java MatmultD 6 < mat500.txt
// 6 is the number of threads to use
// < mat500.txt supplies the file containing the two matrices as standard input
//
// In eclipse, set the argument value and file input by using the menu [Run]->[Run Configurations]->{[Arguments],
// [Common->Input File]}.
// Original JAVA source code:
// http://stackoverflow.com/questions/21547462/how-to-multiply-2-dimensional-arrays-matrix-multiplication
/**
 * Single-threaded (serial) matrix multiplication driver.
 *
 * <p>Reads two integer matrices from standard input, multiplies them, prints the
 * product together with the sum of its entries, and reports how long the
 * multiplication took. Each matrix is encoded as two integers (row count, column
 * count) followed by the entries in row-major order.
 */
public class MatmultD {
    private static final Scanner sc = new Scanner(System.in);

    /**
     * Program entry point.
     *
     * @param args optional single argument: a thread count. This serial version
     *             only echoes it in the timing report; it defaults to 1.
     */
    public static void main(String[] args) {
        int thread_no = (args.length == 1) ? Integer.parseInt(args[0]) : 1;

        int[][] a = readMatrix();
        int[][] b = readMatrix();

        long startTime = System.currentTimeMillis();
        int[][] c = multMatrix(a, b);
        long endTime = System.currentTimeMillis();

        // multMatrix signals incompatible dimensions with null; report that
        // instead of failing with a NullPointerException inside printMatrix.
        if (c == null) {
            System.err.println("Cannot multiply: the column count of the first matrix"
                    + " does not match the row count of the second matrix.");
            return;
        }

        printMatrix(c);
        System.out.printf("[thread_no]:%2d , [Time]:%4d ms\n", thread_no, endTime - startTime);
    }

    /**
     * Reads one matrix from standard input.
     *
     * @return a new {@code rows x cols} matrix filled in row-major order
     */
    public static int[][] readMatrix() {
        int rows = sc.nextInt();
        int cols = sc.nextInt();
        int[][] result = new int[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                result[i][j] = sc.nextInt();
            }
        }
        return result;
    }

    /**
     * Prints a matrix header, every entry, and the sum of all entries.
     *
     * <p>A matrix with zero rows is printed as {@code Matrix[0][0]} with a sum of 0.
     *
     * @param mat the matrix to print; the width of the first row is used for every row
     */
    public static void printMatrix(int[][] mat) {
        int rows = mat.length;
        // Guard the first-row lookup so a matrix without rows does not throw.
        int columns = (rows == 0) ? 0 : mat[0].length;
        System.out.println("Matrix[" + rows + "][" + columns + "]");

        // Accumulate in a long: the total of many int entries can exceed the int range.
        long sum = 0;
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < columns; j++) {
                System.out.printf("%4d ", mat[i][j]);
                sum += mat[i][j];
            }
            System.out.println();
        }
        System.out.println();
        System.out.println("Matrix Sum = " + sum + "\n");
    }

    /**
     * Multiplies {@code a} (m x n) by {@code b} (n x p).
     *
     * @param a left operand
     * @param b right operand
     * @return the m x p product; a {@code 0 x 0} matrix when {@code a} has no rows;
     *         {@code null} when the column count of {@code a} differs from the row
     *         count of {@code b}
     */
    public static int[][] multMatrix(int a[][], int b[][]) {
        if (a.length == 0) {
            return new int[0][0];
        }
        if (a[0].length != b.length) {
            return null; // invalid dims
        }
        int m = a.length;
        int n = b.length;
        // b has no rows when n == 0, so its width cannot be read from b[0].
        int p = (n == 0) ? 0 : b[0].length;

        int[][] ans = new int[m][p];
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < p; j++) {
                for (int k = 0; k < n; k++) {
                    ans[i][j] += a[i][k] * b[k][j];
                }
            }
        }
        return ans;
    }
}
多线程计算方法
package prime.witersen;
import java.util.Scanner;
import java.util.concurrent.CountDownLatch;
/**
 * Compares multithreaded and serial integer matrix multiplication.
 *
 * <p>Reads two matrices from standard input (row count, column count, then the
 * entries in row-major order), multiplies them once with a configurable number of
 * worker threads and once serially, and prints the elapsed time of each run.
 */
public class Matrix {
    private static final Scanner sc = new Scanner(System.in);

    /**
     * Program entry point.
     *
     * @param args optional single argument: the number of worker threads (default 1)
     * @throws InterruptedException     if the main thread is interrupted while
     *                                  waiting for the workers
     * @throws IllegalArgumentException if the thread count is below 1 or the
     *                                  matrix dimensions are incompatible
     */
    public static void main(String[] args) throws InterruptedException {
        // Number of worker threads.
        int thread_no = (args.length == 1) ? Integer.parseInt(args[0]) : 1;
        if (thread_no < 1) {
            throw new IllegalArgumentException("thread count must be >= 1, got " + thread_no);
        }

        // Read both operands.
        int[][] A = readMatrix();
        int[][] B = readMatrix();

        // Time the parallel computation.
        long startTime = System.currentTimeMillis();
        multiplyParallel(A, B, thread_no);
        long endTime = System.currentTimeMillis();
        System.out.println("并行计算总耗时:" + (endTime - startTime) + " ms");

        // Time the serial computation.
        startTime = System.currentTimeMillis();
        multiplySerial(A, B);
        endTime = System.currentTimeMillis();
        System.out.println("串行计算总耗时:" + (endTime - startTime) + " ms");
    }

    /**
     * Multiplies {@code A} by {@code B} using {@code threadCount} worker threads,
     * each handling a contiguous band of result rows.
     *
     * <p>The band height is the row count divided by the thread count, rounded up,
     * so every row is covered even when the division is not exact. Workers whose
     * band starts past the last row do no multiplication.
     *
     * @param A           left operand
     * @param B           right operand
     * @param threadCount number of worker threads to start; must be at least 1
     * @return the product matrix
     * @throws InterruptedException     if interrupted while waiting for the workers
     * @throws IllegalArgumentException if {@code threadCount < 1} or the column
     *                                  count of {@code A} differs from the row count of {@code B}
     */
    static int[][] multiplyParallel(int[][] A, int[][] B, int threadCount) throws InterruptedException {
        if (threadCount < 1) {
            throw new IllegalArgumentException("thread count must be >= 1, got " + threadCount);
        }
        checkDimensions(A, B);

        int rows = A.length;
        int cols = (B.length == 0) ? 0 : B[0].length;
        int[][] result = new int[rows][cols];

        // Round up so that the remainder rows are not dropped.
        int gap = (rows + threadCount - 1) / threadCount;
        CountDownLatch latch = new CountDownLatch(threadCount);

        TestThread[] threads = new TestThread[threadCount];
        for (int i = 0; i < threadCount; i++) {
            threads[i] = new TestThread(A, B, i, gap, result, latch);
            threads[i].setName((i + 1) + "");
        }
        for (TestThread worker : threads) {
            worker.start();
        }
        // Wait until every worker has counted down.
        latch.await();
        return result;
    }

    /**
     * Multiplies {@code A} by {@code B} on the calling thread.
     *
     * @param A left operand
     * @param B right operand
     * @return the product matrix
     * @throws IllegalArgumentException if the column count of {@code A} differs
     *                                  from the row count of {@code B}
     */
    static int[][] multiplySerial(int[][] A, int[][] B) {
        checkDimensions(A, B);
        int rows = A.length;
        int inner = B.length;
        int cols = (inner == 0) ? 0 : B[0].length;
        int[][] result = new int[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                for (int k = 0; k < inner; k++) {
                    result[i][j] += A[i][k] * B[k][j];
                }
            }
        }
        return result;
    }

    /** Rejects operand pairs whose inner dimensions do not match. */
    private static void checkDimensions(int[][] A, int[][] B) {
        if (A.length > 0 && A[0].length != B.length) {
            throw new IllegalArgumentException("incompatible dimensions: A has " + A[0].length
                    + " columns but B has " + B.length + " rows");
        }
    }

    /**
     * Worker that computes result rows {@code index * gap} (inclusive) up to
     * {@code (index + 1) * gap} (exclusive), clamped to the number of rows in
     * {@code A}, and then counts down the shared latch.
     */
    static class TestThread extends Thread {
        private final int[][] A;
        private final int[][] B;
        private final int index;
        private final int gap;
        private final int[][] result;
        private final CountDownLatch latch;

        /**
         * @param A      left operand
         * @param B      right operand
         * @param index  zero-based position of this worker's row band
         * @param gap    number of rows per band
         * @param result matrix this worker writes its rows into
         * @param latch  latch counted down once when this worker finishes
         */
        public TestThread(int[][] A, int[][] B, int index, int gap, int[][] result, CountDownLatch latch) {
            this.A = A;
            this.B = B;
            this.index = index;
            this.gap = gap;
            this.result = result;
            this.latch = latch;
        }

        /** Computes this worker's band of rows and prints how long it took. */
        @Override
        public void run() {
            try {
                // Per-thread timing starts here.
                long startTime = System.currentTimeMillis();

                int firstRow = index * gap;
                // Clamp so the final band never runs past the last row.
                int endRow = Math.min(firstRow + gap, A.length);
                int cols = (B.length == 0) ? 0 : B[0].length;
                for (int i = firstRow; i < endRow; i++) {
                    for (int j = 0; j < cols; j++) {
                        for (int k = 0; k < B.length; k++) {
                            result[i][j] += A[i][k] * B[k][j];
                        }
                    }
                }

                // Per-thread timing ends here.
                long endTime = System.currentTimeMillis();
                System.out.println("单线程" + Thread.currentThread().getName() + "耗时:" + (endTime - startTime) + " ms");
            } finally {
                // Always release the waiter, even if the computation above threw.
                latch.countDown();
            }
        }
    }

    /**
     * Reads one matrix from standard input.
     *
     * @return a new {@code rows x cols} matrix filled in row-major order
     */
    public static int[][] readMatrix() {
        int rows = sc.nextInt();
        int cols = sc.nextInt();
        int[][] result = new int[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                result[i][j] = sc.nextInt();
            }
        }
        return result;
    }

    /**
     * Prints a labelled matrix header, every entry, and the sum of all entries.
     *
     * <p>A matrix with zero rows is printed with dimensions {@code [0][0]} and sum 0.
     *
     * @param mat the matrix to print; the width of the first row is used for every row
     * @param str label printed in front of the header
     */
    public static void printMatrix(int[][] mat, String str) {
        int rows = mat.length;
        // Guard the first-row lookup so a matrix without rows does not throw.
        int columns = (rows == 0) ? 0 : mat[0].length;
        System.out.println(str + "矩阵[" + rows + "][" + columns + "]");

        // Accumulate in a long: the total of many int entries can exceed the int range.
        long sum = 0;
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < columns; j++) {
                System.out.printf("%4d ", mat[i][j]);
                sum += mat[i][j];
            }
            System.out.println();
        }
        System.out.println();
        System.out.println("矩阵和 = " + sum + "\n");
    }
}
原创文章,作者:witersen,如若转载,请注明出处:https://www.witersen.com