[MPI] MPI Matrix Multiplication


[2024 Summer GeekPie HPC Lecture Notes] Catalog of Notes


This is also exercise ex2-5 from the 2024 GeekPie summer HPC lectures.

Code

// mpi.h first: some MPI C++ bindings (MPICH, Intel MPI) reject being included
// after <stdio.h> has defined SEEK_SET.
#include <mpi.h>

#include <chrono>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

// Directory (relative to the current working directory) from which main reads
// "<filename>.data" and into which rank 0 writes "<filename>.ans".
#define DATA_PATH "./attachment/"

// serial
// Serial reference kernel: accumulates the matrix product a * b into c, i.e.
// c[i][k] += sum_j a[i][j] * b[j][k], with every matrix stored row-major
// (a is n1 x n2, b is n2 x n3, c is n1 x n3).
// c is NOT cleared first: the caller must zero it to obtain a plain product.
void naive_mul(double *a, double *b, double *c, uint64_t n1, uint64_t n2, uint64_t n3) {
    // uint64_t indices match the uint64_t bounds (no sign-compare mismatch and
    // no int overflow once a dimension exceeds INT_MAX).
    for (uint64_t i = 0; i < n1; i++) {
        const double *a_row = a + i * n2;
        double *c_row = c + i * n3;
        for (uint64_t j = 0; j < n2; j++) {
            // i-j-k order: the innermost loop streams b and c contiguously.
            const double a_ij = a_row[j];
            const double *b_row = b + j * n3;
            for (uint64_t k = 0; k < n3; k++) {
                c_row[k] += a_ij * b_row[k];
            }
        }
    }
}

// THIS IS THE FUNCTION IMPLEMENTED
/* -----------------TODO---------------- */
// Parallel matrix multiplication c = a * b over MPI_COMM_WORLD, all matrices
// row-major (a is n1 x n2, b is n2 x n3, c is n1 x n3). It must be called by
// every rank (it contains a collective MPI_Bcast) with the same n1, n2, n3.
//
// Rank 0 acts as the master. It broadcasts b and sends each worker rank
// r = 1..size-1 a contiguous block of rows of a (block sizes differ by at most
// one row). It then receives the matching rows of c. Afterwards only rank 0's c
// holds the result. On worker ranks a and c are left untouched, but b is
// overwritten with rank 0's b. With a single process, rank 0 computes the whole
// product itself.
//
// Element counts are passed to MPI as int, so each message is limited to
// INT_MAX doubles (an MPI count-type limitation).
void mul(double *a, double *b, double *c, uint64_t n1, uint64_t n2, uint64_t n3) {
    int rank, size;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    if (size == 1) {
        // No worker ranks exist, and the row split below would divide by zero.
        // Compute locally instead. naive_mul accumulates, so clear c first to
        // keep the same "c = a * b" result as the distributed path.
        for (uint64_t idx = 0; idx < n1 * n3; idx++) {
            c[idx] = 0.0;
        }
        naive_mul(a, b, c, n1, n2, n3);
        return;
    }

    // 1. Every rank needs all of b.
    MPI_Bcast(b, static_cast<int>(n2 * n3), MPI_DOUBLE, 0, MPI_COMM_WORLD);

    // Split the n1 rows over the workers; the first `remainder` workers take
    // one extra row. Master and workers derive the same split from n1, so no
    // separate size message is needed. Worker w (0-based) is MPI rank w + 1.
    const uint64_t workers = static_cast<uint64_t>(size - 1);
    const uint64_t chunk = n1 / workers;
    const uint64_t remainder = n1 % workers;
    auto first_row = [&](uint64_t w) { return w * chunk + (w < remainder ? w : remainder); };
    auto row_count = [&](uint64_t w) { return chunk + (w < remainder ? 1 : 0); };

    if (rank == 0) {
        // 2. Hand each worker its block of rows of a.
        for (uint64_t w = 0; w < workers; w++) {
            MPI_Send(a + first_row(w) * n2, static_cast<int>(row_count(w) * n2), MPI_DOUBLE,
                     static_cast<int>(w + 1), 0, MPI_COMM_WORLD);
        }
        // 3. Collect each worker's rows of c directly into place (tag = worker rank).
        for (uint64_t w = 0; w < workers; w++) {
            MPI_Recv(c + first_row(w) * n3, static_cast<int>(row_count(w) * n3), MPI_DOUBLE,
                     static_cast<int>(w + 1), static_cast<int>(w + 1), MPI_COMM_WORLD,
                     MPI_STATUS_IGNORE);
        }
    } else {
        const uint64_t rows = row_count(static_cast<uint64_t>(rank - 1));
        std::vector<double> local_a(rows * n2);
        // Must start at zero: naive_mul accumulates into its output.
        std::vector<double> local_c(rows * n3, 0.0);

        MPI_Recv(local_a.data(), static_cast<int>(local_a.size()), MPI_DOUBLE, 0, 0,
                 MPI_COMM_WORLD, MPI_STATUS_IGNORE);
        naive_mul(local_a.data(), b, local_c.data(), rows, n2, n3);
        MPI_Send(local_c.data(), static_cast<int>(local_c.size()), MPI_DOUBLE, 0, rank,
                 MPI_COMM_WORLD);
    }
}
/* -----------------TODO---------------- */

// Entry point, run under MPI with one argument <filename>.
//
// Every rank reads DATA_PATH "<filename>.data". The file holds three uint64_t
// dimensions n1, n2, n3, followed by a (n1 x n2) and then b (n2 x n3) as raw
// doubles. Each rank calls mul(). Rank 0 then writes c (n1 x n3 raw doubles) to
// DATA_PATH "<filename>.ans" and prints the wall-clock time of the whole
// read/compute/write region. Returns nonzero if the input cannot be read or the
// output cannot be written.
int main(int argc, char **argv) {
    if (argc != 2) {
        std::cout << "Usage: " << argv[0] << " <filename>" << std::endl;
        return 1;
    }
    MPI_Init(NULL, NULL);
    double start = MPI_Wtime();

    int rank;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    int exit_code = 0;

    // ###########################################
    {
        const std::string filename = argv[1];
        const std::string in_path = std::string(DATA_PATH) + filename + ".data";

        // Check every read. An unopenable file used to dereference a null
        // FILE*, and a truncated file left a/b partly uninitialised.
        uint64_t n1 = 0, n2 = 0, n3 = 0;
        FILE *fi = fopen(in_path.c_str(), "rb");
        bool ok = fi != NULL &&
                  fread(&n1, sizeof n1, 1, fi) == 1 &&
                  fread(&n2, sizeof n2, 1, fi) == 1 &&
                  fread(&n3, sizeof n3, 1, fi) == 1;

        std::vector<double> a, b, c;
        if (ok) {
            a.resize(n1 * n2);
            b.resize(n2 * n3);
            c.assign(n1 * n3, 0.0);
            ok = fread(a.data(), sizeof(double), a.size(), fi) == a.size() &&
                 fread(b.data(), sizeof(double), b.size(), fi) == b.size();
        }
        if (fi != NULL) {
            fclose(fi);
        }

        if (!ok) {
            // NOTE(review): every rank opens the same path, so they are expected
            // to fail together. If only some ranks fail, the others would block
            // inside mul().
            std::cerr << "Error: cannot read " << in_path << std::endl;
            exit_code = 1;
        } else {
            mul(a.data(), b.data(), c.data(), n1, n2, n3);

            if (rank == 0) {
                const std::string out_path = std::string(DATA_PATH) + filename + ".ans";
                FILE *fo = fopen(out_path.c_str(), "wb");
                bool written = fo != NULL &&
                               fwrite(c.data(), sizeof(double), c.size(), fo) == c.size();
                // fclose flushes buffered data, so its result matters too.
                if (fo != NULL && fclose(fo) != 0) {
                    written = false;
                }
                if (!written) {
                    std::cerr << "Error: cannot write " << out_path << std::endl;
                    exit_code = 1;
                }
            }
        }
    } // a, b and c are released here, still inside the timed region.

    // ###########################################

    double end = MPI_Wtime();
    if (rank == 0 && exit_code == 0) {
        std::cout << "Time: " << end - start << std::endl;
    }
    MPI_Finalize();
    return exit_code;
}