[MPI] MPI Matrix Multiplication
2024-07-19
4 min read
This is also an exercise (ex2-5) of the 2024 summer HPC lectures.
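The approach: rank 0 acts purely as a coordinator. It broadcasts all of b, splits the rows of a into contiguous, nearly equal blocks, sends one block to each worker rank 1..size-1, and finally gathers the corresponding rows of c back. Each worker just runs the serial kernel on its own block.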
Code
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <mpi.h>
#include <string>
#define DATA_PATH "./attachment/"
// serial reference kernel: c (n1 x n3) += a (n1 x n2) * b (n2 x n3), all row-major
void naive_mul(double *a, double *b, double *c, uint64_t n1, uint64_t n2, uint64_t n3) {
for (uint64_t i = 0; i < n1; i++) {
for (uint64_t j = 0; j < n2; j++) {
for (uint64_t k = 0; k < n3; k++) {
c[i * n3 + k] += a[i * n2 + j] * b[j * n3 + k]; // c[i][k] += a[i][j] * b[j][k]
}
}
}
}
// The function between the TODO markers below is the part implemented for this exercise.
/* -----------------TODO---------------- */
void mul(double *a, double *b, double *c, uint64_t n1, uint64_t n2, uint64_t n3) {
// Parallel matrix multiplication: rank 0 hands out row blocks of a, the workers compute, and rank 0 gathers the corresponding rows of c.
int rank, size;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &size);
// distribute contiguous blocks of rows of a (and the matching rows of c) to the
// worker processes 1..size-1; rank 0 only coordinates, so at least 2 processes
// are required (otherwise worker_size below is 0 and the division faults)
if (rank == 0) {
// master process
// 1. Broadcast matrix b to all processes
MPI_Bcast(b, n2 * n3, MPI_DOUBLE, 0, MPI_COMM_WORLD);
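// (MPI_Bcast is a collective, so every rank must call it; the matching call is in the worker branch below)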
// 2. Split the rows of a and send each worker its block (point-to-point sends, not MPI_Scatter)
int worker_size = size - 1;
int chunk_size = n1 / worker_size;
int remainder = n1 % worker_size;
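// split the n1 rows as evenly as possible: every worker gets chunk_size rows,
// and the first `remainder` workers get one extra; [start, end) below is the
// block of global row indices owned by worker worker_i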
for (int worker_i = 0; worker_i < worker_size; worker_i++) {
// the destination rank is worker_i + 1 (rank 0 is the master)
int start = worker_i * chunk_size + (worker_i < remainder ? worker_i : remainder);
int end = start + chunk_size + (worker_i < remainder ? 1 : 0);
int send_size = (end - start) * n2;
MPI_Send(&send_size, 1, MPI_INT, worker_i + 1, 0, MPI_COMM_WORLD);
MPI_Send(a + start * n2, send_size, MPI_DOUBLE, worker_i + 1, 0, MPI_COMM_WORLD);
}
// 3. Receive the computed rows of c back from every worker
for (int worker_i = 0; worker_i < worker_size; worker_i++) {
int start = worker_i * chunk_size + (worker_i < remainder ? worker_i : remainder);
int end = start + chunk_size + (worker_i < remainder ? 1 : 0);
int worker_n1 = end - start;
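// the block lands directly at this worker's global row offset in c;
// the tag worker_i + 1 matches the tag the worker used (its own rank)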
MPI_Recv(c + start * n3, worker_n1 * n3, MPI_DOUBLE, worker_i + 1, worker_i + 1, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
}
} else {
// worker process
// 1. Receive matrix b from master process
MPI_Bcast(b, n2 * n3, MPI_DOUBLE, 0, MPI_COMM_WORLD);
// 2. Receive part of matrix a from master process
int recv_size;
MPI_Recv(&recv_size, 1, MPI_INT, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
double *recv_a = (double *)malloc(recv_size * sizeof(double));
MPI_Recv(recv_a, recv_size, MPI_DOUBLE, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
// 3. Calculate the result
// naive_mul(recv_a, b, c, recv_size / n2, n2, n3);
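// the commented-out variant would accumulate into the worker's local (and
// otherwise unused) full-size c starting at row 0; computing into a buffer of
// exactly worker_n1 x n3 keeps ownership and the send below clearer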
int worker_n1 = recv_size / n2;
double *worker_c = (double *)malloc(worker_n1 * n3 * sizeof(double));
naive_mul(recv_a, b, worker_c, worker_n1, n2, n3);
// 4. Send the result back to master process
MPI_Send(worker_c, worker_n1 * n3, MPI_DOUBLE, 0, rank, MPI_COMM_WORLD);
free(recv_a);
free(worker_c);
}
}
/* -----------------TODO---------------- */
int main(int argc, char **argv) {
if (argc != 2) {
std::cout << "Usage: " << argv[0] << " <filename>" << std::endl;
return 1;
}
MPI_Init(&argc, &argv);
double start = MPI_Wtime();
// ###########################################
std::string filename = argv[1];
uint64_t n1, n2, n3;
FILE *fi;
int rank, size;
size_t dump; // swallows the unused fread/fwrite return values
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &size);
fi = fopen((std::string(DATA_PATH) + filename + ".data").c_str(), "rb");
if (fi == NULL) {
std::cout << "Cannot open " << DATA_PATH << filename << ".data" << std::endl;
MPI_Abort(MPI_COMM_WORLD, 1);
}
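// file layout: three uint64_t dimensions (n1, n2, n3), then a as n1*n2 doubles
// and b as n2*n3 doubles, all row-major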
dump = fread(&n1, 1, 8, fi);
dump = fread(&n2, 1, 8, fi);
dump = fread(&n3, 1, 8, fi);
double *a = (double *)malloc(n1 * n2 * 8);
double *b = (double *)malloc(n2 * n3 * 8);
double *c = (double *)malloc(n1 * n3 * 8);
dump = fread(a, 1, n1 * n2 * 8, fi);
dump = fread(b, 1, n2 * n3 * 8, fi);
fclose(fi);
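// note: every rank reads the whole input file above, so b is already present
// everywhere and the broadcast inside mul() re-sends the same data (redundant
// but harmless)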
for (uint64_t i = 0; i < n1; i++) {
for (uint64_t k = 0; k < n3; k++) {
c[i * n3 + k] = 0;
}
}
mul(a, b, c, n1, n2, n3);
if (rank == 0) {
fi = fopen((std::string(DATA_PATH) + filename + ".ans").c_str(), "wb");
dump = fwrite(c, 1, n1 * n3 * 8, fi);
fclose(fi);
}
free(a);
free(b);
free(c);
// ###########################################
double end = MPI_Wtime();
if (rank == 0) {
std::cout << "Time: " << end - start << std::endl;
}
MPI_Finalize();
return 0;
}
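To build and run this you need an MPI toolchain (OpenMPI or MPICH); something along the lines of `mpicxx -O2 main.cpp -o main` followed by `mpirun -np 4 ./main small` should do, where `main` and `small` are just placeholder names for the binary and for whatever dataset sits in `./attachment/`. Remember to launch at least two processes, since rank 0 does no computation itself.

The lecture's data files aren't reproduced here, but the layout can be read off the fread calls in main: three uint64_t dimensions, then a and b as raw doubles. Under that assumption, a throwaway generator for test inputs might look like the sketch below (the sizes and the default name small are arbitrary, and ./attachment/ has to exist already):

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <string>

int main(int argc, char **argv) {
    // Placeholder sizes and name, not the course's actual datasets.
    uint64_t n1 = 256, n2 = 256, n3 = 256;
    std::string name = (argc > 1) ? argv[1] : "small";
    FILE *fo = fopen(("./attachment/" + name + ".data").c_str(), "wb");
    if (fo == NULL) {
        perror("fopen"); // ./attachment/ must already exist
        return 1;
    }
    // Header: the three dimensions as raw uint64_t values.
    fwrite(&n1, sizeof(uint64_t), 1, fo);
    fwrite(&n2, sizeof(uint64_t), 1, fo);
    fwrite(&n3, sizeof(uint64_t), 1, fo);
    // Body: a (n1*n2 doubles) followed by b (n2*n3 doubles), row-major.
    for (uint64_t i = 0; i < n1 * n2 + n2 * n3; i++) {
        double x = (double)rand() / RAND_MAX;
        fwrite(&x, sizeof(double), 1, fo);
    }
    fclose(fo);
    return 0;
}

A plain single-process run of naive_mul on the same input then gives a reference to diff the resulting .ans file against.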