#include <float.h>
#include <math.h>
#include <stdio.h>

#define N 4

int main() {

double A[N][N] = {
    {9, 2, 1, 1},
    {18, 4, 2, 2},
    {1, -1, 3, 2},
    {2, 1, 1, 5}
};

double b[N] = {10, 25, 7, 8};

double L[N][N] = {0};
double U[N][N];

double y[N] = {0};
double x[N] = {0};

double LU[N][N] = {0};
double Ax[N] = {0};

int i, j, k;

/* Copy A into U */
for(i = 0; i < N; i++) {
for(j = 0; j < N; j++) {
U[i][j] = A[i][j];
}
}

/* Initialize L as identity matrix */
for(i = 0; i < N; i++) {
L[i][i] = 1.0;
}

/* LU decomposition */
for(i = 0; i < N - 1; i++) {

for(j = i + 1; j < N; j++) {

double factor = U[j][i] / U[i][i];

L[j][i] = factor;

for(k = 0; k < N; k++) {
U[j][k] = U[j][k] - factor * U[i][k];
}
}
}

/* Forward substitution: Ly = b */
for(i = 0; i < N; i++) {

y[i] = b[i];

for(j = 0; j < i; j++) {
y[i] -= L[i][j] * y[j];
}
}

/* Back substitution: Ux = y */
for(i = N - 1; i >= 0; i--) {

x[i] = y[i];

for(j = i + 1; j < N; j++) {
x[i] -= U[i][j] * x[j];
}

x[i] /= U[i][i];
}

/* Display L matrix */
printf("L行列:\n");

for(i = 0; i < N; i++) {
for(j = 0; j < N; j++) {
printf("%8.3f ", L[i][j]);
}
printf("\n");
}

/* Display U matrix */
printf("\nU行列:\n");

for(i = 0; i < N; i++) {
for(j = 0; j < N; j++) {
printf("%8.3f ", U[i][j]);
}
printf("\n");
}

/* Display y vector */
printf("\nyベクトル:\n");

for(i = 0; i < N; i++) {
printf("%8.3f\n", y[i]);
}

/* Display solution x */
printf("\n連立方程式の解 x:\n");

for(i = 0; i < N; i++) {
printf("x%d = %8.3f\n", i + 1, x[i]);
}

/* Check LU = A */
for(i = 0; i < N; i++) {
for(j = 0; j < N; j++) {

LU[i][j] = 0;

for(k = 0; k < N; k++) {
LU[i][j] += L[i][k] * U[k][j];
}
}
}

printf("\nLU行列（LU = A の確認）:\n");

for(i = 0; i < N; i++) {
for(j = 0; j < N; j++) {
printf("%8.3f ", LU[i][j]);
}
printf("\n");
}

/* Check Ax = b */
for(i = 0; i < N; i++) {

Ax[i] = 0;

for(j = 0; j < N; j++) {
Ax[i] += A[i][j] * x[j];
}
}

printf("\nAxベクトル（Ax = b の確認）:\n");

for(i = 0; i < N; i++) {
printf("%8.3f\n", Ax[i]);
}

printf("\n元のbベクトル:\n");

for(i = 0; i < N; i++) {
printf("%8.3f\n", b[i]);
}

return 0;
}
