Java多线程-矩阵乘法计算

单线程即通常情况下的串行办法

package prime.witersen;

import java.util.*;
import java.lang.*;

// Command-line execution example: java MatmultD 6 < mat500.txt
// 6 means the number of threads to use
// < mat500.txt means the file that contains two matrices is given as standard input
//
// In eclipse, set the argument value and file input by using the menu [Run]->[Run Configurations]->{[Arguments],
// [Common->Input File]}.

// Original JAVA source code:
// http://stackoverflow.com/questions/21547462/how-to-multiply-2-dimensional-arrays-matrix-multiplication
/**
 * Serial (single-threaded) matrix multiplication driver.
 *
 * <p>Reads two integer matrices from standard input (each given as
 * {@code rows cols} followed by {@code rows * cols} values), multiplies them,
 * prints the product and reports the elapsed multiplication time.
 */
public class MatmultD {
    private static Scanner sc = new Scanner(System.in);

    /**
     * Program entry point.
     *
     * @param args optional single argument: a thread count. This class always
     *             multiplies serially; the value is only echoed in the report line.
     */
    public static void main(String[] args) {
        int thread_no = (args.length == 1) ? Integer.parseInt(args[0]) : 1;

        int a[][] = readMatrix();
        int b[][] = readMatrix();

        long startTime = System.currentTimeMillis();
        int[][] c = multMatrix(a, b);
        long endTime = System.currentTimeMillis();

        // multMatrix signals incompatible dimensions with null; report it instead of
        // letting printMatrix fail with a NullPointerException.
        if (c == null) {
            System.err.println("Cannot multiply: first matrix has " + a[0].length
                    + " columns but second matrix has " + b.length + " rows");
            System.exit(1);
        }

        // printMatrix(a);
        // printMatrix(b);
        printMatrix(c);

        System.out.printf("[thread_no]:%2d , [Time]:%4d ms\n", thread_no, endTime - startTime);
    }

    /**
     * Reads one matrix from standard input.
     *
     * @return a {@code rows x cols} matrix filled in row-major order from the input
     */
    public static int[][] readMatrix() {
        int rows = sc.nextInt();
        int cols = sc.nextInt();
        int[][] result = new int[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                result[i][j] = sc.nextInt();
            }
        }
        return result;
    }

    /**
     * Prints a matrix header, every element, and the sum of all elements.
     *
     * @param mat the matrix to print; a matrix with zero rows is printed as
     *            {@code Matrix[0][0]} with a sum of 0
     */
    public static void printMatrix(int[][] mat) {
        int rows = mat.length;
        // A matrix without rows has no first row to take the column count from.
        int columns = (rows == 0) ? 0 : mat[0].length;
        System.out.println("Matrix[" + rows + "][" + columns + "]");
        int sum = 0;
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < columns; j++) {
                System.out.printf("%4d ", mat[i][j]);
                sum += mat[i][j];
            }
            System.out.println();
        }
        System.out.println();
        System.out.println("Matrix Sum = " + sum + "\n");
    }

    /**
     * Multiplies {@code a} (m x n) by {@code b} (n x p).
     *
     * @param a left operand
     * @param b right operand
     * @return the m x p product; an empty {@code int[0][0]} when {@code a} has no
     *         rows; {@code null} when the column count of {@code a} differs from
     *         the row count of {@code b}
     */
    public static int[][] multMatrix(int a[][], int b[][]) {// a[m][n], b[n][p]
        if (a.length == 0)
            return new int[0][0];
        if (a[0].length != b.length)
            return null; // invalid dims

        int n = a[0].length;
        int m = a.length;
        // When n == 0, b has no rows, so the product has zero columns.
        int p = (n == 0) ? 0 : b[0].length;
        int ans[][] = new int[m][p];

        for (int i = 0; i < m; i++) {
            int[] aRow = a[i];
            int[] ansRow = ans[i];
            for (int j = 0; j < p; j++) {
                int acc = 0;
                for (int k = 0; k < n; k++) {
                    acc += aRow[k] * b[k][j];
                }
                ansRow[j] = acc;
            }
        }
        return ans;
    }
}

多线程计算方法

package prime.witersen;

import java.util.Scanner;
import java.util.concurrent.CountDownLatch;

/**
 * Compares multi-threaded and serial integer matrix multiplication.
 *
 * <p>Reads two matrices from standard input (each given as {@code rows cols}
 * followed by {@code rows * cols} values), multiplies them once with a
 * configurable number of worker threads and once serially, and prints the
 * elapsed time of both runs.
 */
public class Matrix {

    private static Scanner sc = new Scanner(System.in);

    /**
     * Program entry point.
     *
     * @param args optional single argument: the number of worker threads (default 1)
     * @throws InterruptedException     if interrupted while waiting for the workers
     * @throws IllegalArgumentException if the thread count is below 1 or the
     *                                  matrix dimensions are incompatible
     */
    public static void main(String[] args) throws InterruptedException {
        long startTime, endTime;

        // Number of worker threads.
        int thread_no = (args.length == 1) ? Integer.parseInt(args[0]) : 1;

        // Read both operands.
        int[][] A = readMatrix();
        int[][] B = readMatrix();

        // printMatrix(A, "原矩阵A");
        // printMatrix(B, "原矩阵B");

        // Parallel run.
        startTime = System.currentTimeMillis();
        int[][] parallel_result = multiplyParallel(A, B, thread_no);
        endTime = System.currentTimeMillis();

        System.out.println("并行计算总耗时:" + (endTime - startTime) + " ms");

        // printMatrix(parallel_result, "并行结果");

        // Serial run.
        startTime = System.currentTimeMillis();
        int[][] serial_result = multiplySerial(A, B);
        endTime = System.currentTimeMillis();

        System.out.println("串行计算总耗时:" + (endTime - startTime) + " ms");

        // printMatrix(serial_result, "串行结果");
    }

    /**
     * Multiplies {@code A} by {@code B} using {@code threadCount} worker threads,
     * each handling a contiguous band of rows of the result.
     *
     * @param A           left operand (m x n)
     * @param B           right operand (n x p)
     * @param threadCount number of worker threads to start; must be at least 1
     * @return the m x p product
     * @throws InterruptedException     if interrupted while waiting for the workers
     * @throws IllegalArgumentException if {@code threadCount < 1} or the column
     *                                  count of {@code A} differs from the row count of {@code B}
     */
    public static int[][] multiplyParallel(int[][] A, int[][] B, int threadCount)
            throws InterruptedException {
        if (threadCount < 1) {
            throw new IllegalArgumentException("Thread count must be at least 1, got " + threadCount);
        }
        requireMultipliable(A, B);

        int rows = A.length;
        int cols = (B.length == 0) ? 0 : B[0].length;
        int[][] result = new int[rows][cols];

        // Ceiling division: with floor division the last (rows % threadCount) rows
        // would belong to no worker and stay zero.
        int gap = rows / threadCount + ((rows % threadCount == 0) ? 0 : 1);

        CountDownLatch latch = new CountDownLatch(threadCount);

        TestThread[] threads = new TestThread[threadCount];
        for (int i = 0; i < threadCount; i++) {
            threads[i] = new TestThread(A, B, i, gap, result, latch);
            threads[i].setName((i + 1) + "");
        }

        for (int i = 0; i < threadCount; i++) {
            threads[i].start();
        }

        // Wait until every worker has counted down.
        latch.await();
        return result;
    }

    /**
     * Multiplies {@code A} by {@code B} on the calling thread.
     *
     * @param A left operand (m x n)
     * @param B right operand (n x p)
     * @return the m x p product
     * @throws IllegalArgumentException if the column count of {@code A} differs
     *                                  from the row count of {@code B}
     */
    public static int[][] multiplySerial(int[][] A, int[][] B) {
        requireMultipliable(A, B);

        int rows = A.length;
        int inner = B.length;
        int cols = (inner == 0) ? 0 : B[0].length;
        int[][] result = new int[rows][cols];

        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                for (int k = 0; k < inner; k++)
                    result[i][j] += A[i][k] * B[k][j];
            }
        }
        return result;
    }

    // Rejects operand pairs whose inner dimensions do not match.
    private static void requireMultipliable(int[][] A, int[][] B) {
        if (A.length > 0 && A[0].length != B.length) {
            throw new IllegalArgumentException("Cannot multiply: first matrix has " + A[0].length
                    + " columns but second matrix has " + B.length + " rows");
        }
    }

    /**
     * Worker that computes the rows {@code [index * gap, (index + 1) * gap)} of
     * the product, clamped to the number of rows of {@code A}, and writes them
     * into the shared result array.
     */
    static class TestThread extends Thread {
        private final int[][] A;
        private final int[][] B;
        private final int index;
        private final int gap;
        private final int[][] result;
        private final CountDownLatch latch;

        /**
         * @param A      left operand
         * @param B      right operand
         * @param index  zero-based position of this worker among all workers
         * @param gap    number of rows assigned to each worker
         * @param result shared output array this worker writes its rows into
         * @param latch  latch counted down once when this worker finishes
         */
        public TestThread(int[][] A, int[][] B, int index, int gap, int[][] result, CountDownLatch latch) {
            this.A = A;
            this.B = B;
            this.index = index;
            this.gap = gap;
            this.result = result;
            this.latch = latch;
        }

        // Computes this worker's band of rows and prints how long it took.
        @Override
        public void run() {
            try {
                long startTime = System.currentTimeMillis();

                int rowCount = A.length;
                // Clamp the band so workers past the last row do nothing instead of
                // indexing out of bounds; long arithmetic avoids int overflow.
                int from = (int) Math.min((long) index * gap, (long) rowCount);
                int to = (int) Math.min((long) from + gap, (long) rowCount);
                int inner = B.length;
                int cols = (inner == 0) ? 0 : B[0].length;

                for (int i = from; i < to; i++) {
                    for (int j = 0; j < cols; j++) {
                        for (int k = 0; k < inner; k++)
                            result[i][j] += A[i][k] * B[k][j];
                    }
                }

                long endTime = System.currentTimeMillis();

                System.out.println("单线程" + Thread.currentThread().getName() + "耗时:" + (endTime - startTime) + " ms");
            } finally {
                // Always release the latch, otherwise a failing worker would leave
                // the waiting thread blocked forever.
                latch.countDown();
            }
        }
    }

    /**
     * Reads one matrix from standard input.
     *
     * @return a {@code rows x cols} matrix filled in row-major order from the input
     */
    public static int[][] readMatrix() {
        int rows = sc.nextInt();
        int cols = sc.nextInt();
        int[][] result = new int[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                result[i][j] = sc.nextInt();
            }
        }
        return result;
    }

    /**
     * Prints a labelled matrix header, every element, and the sum of all elements.
     *
     * @param mat the matrix to print; a matrix with zero rows is printed with
     *            dimensions {@code [0][0]} and a sum of 0
     * @param str label placed in front of the header
     */
    public static void printMatrix(int[][] mat, String str) {
        int rows = mat.length;
        // A matrix without rows has no first row to take the column count from.
        int columns = (rows == 0) ? 0 : mat[0].length;
        System.out.println(str + "矩阵[" + rows + "][" + columns + "]");
        int sum = 0;
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < columns; j++) {
                System.out.printf("%4d ", mat[i][j]);
                sum += mat[i][j];
            }
            System.out.println();
        }
        System.out.println();
        System.out.println("矩阵和 = " + sum + "\n");
    }
}

原创文章,作者:witersen,如若转载,请注明出处:https://www.witersen.com

(4)
witersen的头像witersen
上一篇 2021年4月21日 上午12:40
下一篇 2021年5月31日 下午9:31

相关推荐

  • 【9-2】表达式语言 购物车

    工程目录 一、说明 本实验基于实验1-2,不同之处为,本实验用servlet实现商品的添加和删除、使用表达式语言实现页面数据的展示 数据库表也与实验1-2相同 二、文件 1、工程文…

    2020年12月6日
    1.5K0

发表回复

登录后才能评论