#include <bits/stdc++.h>
// Shorthand integer type names (textual macros, not typedefs).
#define LL long long
#define ULL unsigned long long
// mem(a, v): memset the whole object `a` to byte value v. Uses sizeof(a), so it
// only clears the full array when `a` is a true array, not a pointer.
#define mem(i, j) memset(i, j, sizeof(i))
// rep: ascending loop i = j, j+1, ..., k (inclusive).
// dep: descending loop i = k, k-1, ..., j (inclusive).
// NOTE(review): the loop variable is always `int`, and the bound arguments are
// substituted without parentheses, so complex bound expressions may misparse.
#define rep(i, j, k) for(int i = j; i <= k; i++)
#define dep(i, j, k) for(int i = k; i >= j; i--)
#define pb push_back
// NOTE(review): replaces every token `make` in the file with `make_pair`.
#define make make_pair
// Large int sentinel (~1.06e9). Every byte is 0x3f, so mem(a, 0x3f) fills an
// int array with INF.
#define INF 0x3f3f3f3f
#define inf LLONG_MAX
// Expands to the call acos(-1), i.e. pi as a double.
#define PI acos(-1)
#define fir first
#define sec second
// Lowest set bit of x (two's-complement trick x & -x).
#define lb(x) ((x) & (-(x)))
using namespace std;
// Capacity of the factorial table `fac`; the largest usable index is N - 1.
constexpr int N = 1000005;
// Prime modulus. Primality is what makes ksm(x, mod - 2) a modular inverse
// (Fermat's little theorem), which C() relies on.
constexpr LL mod = 998244353;
// Modular exponentiation by repeated squaring. Returns a^b mod `mod`, always
// in [0, mod).
// The base is reduced into [0, mod) first. Without that reduction:
//   - a*a overflows LL once |a| exceeds ~3.04e9;
//   - a negative `a` yields a negative result, because C++ `%` truncates
//     toward zero.
// After the reduction, every product below is at most (mod - 1)^2 < 2^63.
// Precondition: b >= 0. The loop terminates by right-shifting b to 0. With
// the arithmetic shift used by mainstream compilers (and mandated since
// C++20), a negative b never reaches 0.
LL ksm(LL a, LL b) {
    a %= mod;
    if (a < 0) a += mod;
    LL res = 1LL;
    while (b) {
        if (b & 1) res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}
// Factorial table: fac[i] = i! mod `mod`. It must be filled up to the largest
// index C() will read; solve() fills it for 0 <= i <= n.
LL fac[N];
// Binomial coefficient C(n, m) mod `mod`, computed as n! / (m! * (n-m)!).
// The denominator is inverted via ksm(x, mod - 2), which relies on `mod`
// being prime.
// Returns 0 when m lies outside [0, n]. That is the combinatorial convention,
// and it avoids reading fac[] at a negative index.
// Precondition: fac[] is populated up to index n.
LL C(LL n, LL m) {
    if (m < 0 || m > n) return 0LL;
    return fac[n] * ksm(fac[m] * fac[n - m] % mod, mod - 2) % mod;
}
void solve() {
LL n, k;
scanf("%lld %lld", &n, &k);
fac[0] = 1LL;
for(LL i = 1LL; i <= n; i++) fac[i] = fac[i - 1] * i % mod;
if(k == 0LL) {
printf("%lld\n", fac[n]);
return ;
}
if(k >= n) {
puts("0");
return ;
}
LL ans = 0LL;
LL flag = 1LL;
rep(i, 0, n - k) {
ans = (ans + flag * C(n - k, n - k - i) * ksm(n - k - i, n) % mod + mod) % mod;
flag = -flag;
}
ans = 2LL * C(n, n - k) * ans % mod;
printf("%lld\n", ans);
}
// Entry point: handles exactly one test case per run.
// (For inputs that start with a test count T, read T and call solve() T times.)
int main() {
    solve();
}