// plugins-core/socket/plugin.ts

import {
Plugin,
EventListener,
} from '@squeebot/core/lib/plugin';
import { logger } from '@squeebot/core/lib/core';
import path from 'path';
import fs from 'fs/promises';
import tls, { TLSSocket } from 'tls';
import net, { Server, Socket } from 'net';
import { randomBytes } from 'crypto';
import EventEmitter from 'events';
/**
 * TLS settings of a socket server.
 *
 * `key`, `cert`, `ca`, `dhparam`, `crl` and `pfx` are file paths whose
 * contents are loaded when the server starts; every other entry except
 * `enabled` is forwarded to `tls.createServer` unchanged.
 */
interface TLSConfiguration {
  /** Serve TLS instead of plain TCP. */
  enabled: boolean;
  /** Forwarded to the TLS server: reject clients with unverifiable certificates. */
  rejectUnauthorized: boolean;
  /** Forwarded to the TLS server: request a certificate from connecting clients. */
  requestCert: boolean;
  /** Path to the private key file. */
  key?: string;
  /** Path to the certificate file. */
  cert?: string;
  /** Path to the CA certificate file. */
  ca?: string;
  /** Path to a Diffie-Hellman parameters file. */
  dhparam?: string;
  /** Path to a certificate revocation list file. */
  crl?: string;
  /** Path to a PFX/PKCS#12 bundle file. */
  pfx?: string;
  /** Any further option, passed through as-is. */
  [x: string]: string | boolean | undefined;
}
/**
 * One JSON message exchanged over a control socket, in either direction.
 */
interface SocketMessage {
  /** Command name, e.g. `status`, `quit` or `intent`. */
  command?: string;
  /** Client ID assigned by the server. */
  id?: string;
  /** Result status, e.g. `OK` or `ERROR`. */
  status?: string;
  /** Positional arguments. */
  arguments?: unknown[];
  /** Any other field (clients may also send `args` or `argument`). */
  [x: string]: unknown;
}
/**
 * Configuration of one named socket server.
 */
interface SocketConfiguration {
  /** Server name; used as the registry key and in log prefixes. */
  name: string;
  /** Name of the owning plugin; its servers are stopped when it unloads. */
  plugin?: string;
  /**
   * Client addresses allowed to connect. An empty list allows everyone;
   * the entry `localhost` also admits `127.0.0.1` and `::1`.
   */
  authorizedIPs: string[];
  /** Optional TLS settings. */
  tls?: TLSConfiguration;
  /** Value passed to `server.listen()`; `null` or `false` disables listening. */
  bind: number | string | null | false;
}
/**
 * Shape of the plugin's configuration: a list of socket server definitions.
 * NOTE(review): not referenced elsewhere in this file; presumably read by the
 * plugin loader — confirm.
 */
interface PluginConfiguration {
  sockets: Array<SocketConfiguration>;
}
/** A plain or TLS client socket tagged with its server-assigned ID. */
type SocketWithID = (Socket | TLSSocket) & { id: string };

/**
 * Generate a random ID: 8 random bytes rendered as 16 lowercase hex characters.
 */
const makeID = (): string => {
  const bytes = randomBytes(8);
  return bytes.toString('hex');
};
/**
 * A control socket server speaking newline-delimited JSON over TCP or TLS.
 *
 * Every message in either direction is one JSON object followed by `\r\n`.
 * Each authorized client gets a random ID on connect and is greeted with
 * `{ id, status: 'OK' }`. Built-in commands:
 * - `status` (or no command): reply `{ id, status: 'OK' }`
 * - `quit`: close the connection
 * - `intent`: set (first argument, if a non-empty string) and report the
 *   client's intent
 * Any other command is emitted as a `message` event with a reply callback.
 *
 * Emitted events:
 * - `connect` (id)
 * - `disconnect` (id) — after a client sends `quit`
 * - `intent` (id, intent)
 * - `message` (request, { id, intent }, reply)
 * - `close` — whenever `stop()` is called
 */
class SocketServer extends EventEmitter {
  /**
   * Longest unterminated input (string length) kept per client while waiting
   * for the rest of a message; anything longer is discarded with an error.
   */
  private static readonly MAX_PENDING_INPUT = 1024 * 1024;

  private server: Server | null = null;
  private sockets = new Set<SocketWithID>();
  private intents = new Map<string, string>();
  public alive = false;

  constructor(public config: SocketConfiguration) {
    super();
  }

  /** Prefix used in every log line of this server. */
  get logName() {
    return `socket:${this.config.name}`;
  }

  /**
   * Stop listening and destroy every connected client.
   * Always emits `close` and marks the server as not alive.
   */
  public stop() {
    if (this.server) {
      logger.log('[%s] Stopping socket server..', this.logName);
      this.server.close();
      this.server = null;
      for (const sock of this.sockets) {
        sock.destroy();
      }
      this.sockets.clear();
      this.intents.clear();
    }
    this.emit('close');
    this.alive = false;
  }

  /**
   * (Re)start the server, stopping it first if it is already running.
   * Resolves once listening, or immediately when `bind` is disabled.
   */
  public async start() {
    if (this.alive) {
      this.stop();
    }
    await this.createSocket();
  }

  /** Send a message to every connected client. */
  public send(message: SocketMessage): void {
    for (const socket of this.sockets) {
      this.sendToClient(socket, message);
    }
  }

  /** Send a message to the client with the given ID; no-op if it is not connected. */
  public sendToID(id: string, message: SocketMessage): void {
    for (const socket of this.sockets) {
      if (socket.id === id) {
        this.sendToClient(socket, message);
        return;
      }
    }
  }

  /** Send a message to every connected client whose intent equals `intent`. */
  public sendIntent(intent: string, message: SocketMessage): void {
    for (const socket of this.sockets) {
      if (this.intents.get(socket.id) === intent) {
        this.sendToClient(socket, message);
      }
    }
  }

  /** Serialize `message` and write it to one client, if it is still writable. */
  private sendToClient(socket: SocketWithID, message: SocketMessage): void {
    // Replies can arrive after the client has gone; writing to an ended
    // socket would raise an 'error' on it.
    if (socket.destroyed || !socket.writable) return;
    socket.write(JSON.stringify(message) + '\r\n');
  }

  /**
   * Handle one parsed request from a client.
   * @param socket client the request came from
   * @param req parsed JSON request object
   */
  private handleClientLine(socket: SocketWithID, req: SocketMessage): void {
    if (!req.command || req.command === 'status') {
      this.sendToClient(socket, { id: socket.id, status: 'OK' });
      return;
    }

    // Client requests graceful close
    if (req.command === 'quit') {
      socket.end();
      this.intents.delete(socket.id);
      this.emit('disconnect', socket.id);
      return;
    }

    // Arguments may be sent as `args`, `argument` or `arguments`, either as a
    // list or as a single value.
    const argField = req.args || req.argument || req.arguments;
    let args: unknown[] = [];
    if (argField) {
      args = Array.isArray(argField) ? argField : [argField];
    }

    // Set (when a non-empty string is given) and report the intent
    if (req.command === 'intent') {
      const requested = args[0];
      if (typeof requested === 'string' && requested !== '') {
        this.intents.set(socket.id, requested);
        this.emit('intent', socket.id, requested);
      }
      this.sendToClient(socket, {
        id: socket.id,
        command: 'intent',
        arguments: [this.intents.get(socket.id)],
      });
      return;
    }

    try {
      const intent = this.intents.get(socket.id);
      this.emit(
        'message',
        { ...req, arguments: args },
        { id: socket.id, intent },
        (response: SocketMessage) => this.sendToClient(socket, response)
      );
    } catch (err) {
      // Pass the stack as an argument so '%' characters in it are not
      // interpreted as format specifiers.
      logger.error(
        '[%s] Error in message handler: %s',
        this.logName,
        err instanceof Error ? err.stack : String(err)
      );
    }
  }

  /**
   * Parse one `\r\n`-delimited line and dispatch it. Malformed input is
   * answered with `{ status: 'ERROR', arguments: [reason] }`.
   */
  private handleRawLine(socket: SocketWithID, line: string): void {
    if (line === '') return;
    try {
      const req: unknown = JSON.parse(line);
      if (typeof req !== 'object' || req === null || Array.isArray(req)) {
        throw new Error('Message must be a JSON object');
      }
      this.handleClientLine(socket, req as SocketMessage);
    } catch (error) {
      this.sendToClient(socket, {
        status: 'ERROR',
        arguments: [error instanceof Error ? error.message : String(error)],
      });
    }
  }

  /** Whether `text` is a complete JSON document. */
  private static parsesAsJSON(text: string): boolean {
    try {
      JSON.parse(text);
      return true;
    } catch {
      return false;
    }
  }

  /**
   * Check a remote address (IPv4-mapped prefix already stripped) against
   * `authorizedIPs`. An empty or missing list admits everyone; `localhost` in
   * the list admits `127.0.0.1` and `::1`. When a list is configured, an
   * unknown address is rejected.
   */
  private isAuthorized(addr: string | undefined): boolean {
    const allowed = this.config.authorizedIPs;
    if (!allowed?.length) return true;
    if (!addr) return false;
    if (allowed.includes(addr)) return true;
    return allowed.includes('localhost') && (addr === '::1' || addr === '127.0.0.1');
  }

  /**
   * Accept or reject a new connection and wire up its event handlers.
   * @param incoming freshly accepted plain or TLS socket
   */
  private handleIncoming(incoming: Socket | TLSSocket): void {
    const socket = incoming as SocketWithID;
    socket.id = makeID();

    // Strip the IPv4-mapped IPv6 prefix so plain IPv4 entries match
    let addr = socket.remoteAddress;
    if (addr?.startsWith('::ffff:')) {
      addr = addr.substring(7);
    }

    if (!this.isAuthorized(addr)) {
      logger.warn('[%s] Unauthorized connection made from %s', this.logName, addr);
      socket.destroy();
      return;
    }

    this.sockets.add(socket);

    // Without an 'error' listener, a connection error (e.g. ECONNRESET) is
    // thrown as an uncaught exception and takes down the process.
    socket.on('error', (err) => {
      logger.warn(
        '[%s] Client from %s (%s) error: %s',
        this.logName,
        addr,
        socket.id,
        err.message
      );
    });

    // 'close' fires for every kind of disconnect (graceful end, error,
    // destroy), unlike 'end', so the client is always forgotten.
    socket.once('close', () => {
      logger.log('[%s] Client from %s (%s) disconnected.', this.logName, addr, socket.id);
      this.sockets.delete(socket);
      this.intents.delete(socket.id);
    });

    logger.log('[%s] Client from %s (%s) connected.', this.logName, addr, socket.id);
    socket.setEncoding('utf8');
    this.sendToClient(socket, { id: socket.id, status: 'OK' });

    // TCP delivers a byte stream, not messages: keep any unterminated tail
    // until the rest of it arrives.
    let pending = '';
    socket.on('data', (data: string | Buffer) => {
      pending += data.toString();
      const lines = pending.split('\r\n');
      pending = lines.pop() ?? '';
      for (const line of lines) {
        this.handleRawLine(socket, line);
      }
      if (pending === '') return;
      if (SocketServer.parsesAsJSON(pending)) {
        // Compatibility: clients that omit the final "\r\n" are still answered
        const line = pending;
        pending = '';
        this.handleRawLine(socket, line);
      } else if (pending.length > SocketServer.MAX_PENDING_INPUT) {
        pending = '';
        this.sendToClient(socket, {
          status: 'ERROR',
          arguments: ['Message too long'],
        });
      }
    });

    this.emit('connect', socket.id);
  }

  /**
   * Register `server` as this instance's server and listen on `bind`.
   * Rejects on an error raised before listening succeeds; later errors are
   * only logged.
   */
  private bindServer(server: Server, bind: number | string, label: string): Promise<void> {
    this.server = server;
    return new Promise((resolve, reject) => {
      server.on('error', (e) => {
        logger.error('[%s] %s error: %s', this.logName, label, e.message);
        if (!this.alive) reject(e);
      });
      server.listen(bind, () => {
        logger.log('[%s] %s listening on %s', this.logName, label, bind.toString());
        this.alive = true;
        resolve();
      });
    });
  }

  /**
   * Create the plain or TLS server described by the configuration and start
   * listening. Resolves immediately when `bind` is `null`/`undefined`/`false`.
   */
  private async createSocket(): Promise<void> {
    const c = this.config;
    if (c.bind == null || c.bind === false) {
      return;
    }
    const bind = c.bind;
    const tlsConfig = c.tls;

    if (tlsConfig && tlsConfig.enabled) {
      if (!tlsConfig.rejectUnauthorized && !c.authorizedIPs?.length) {
        logger.warn(
          '[%s] [SECURITY WARNING] !!! YOUR CONTROL SOCKET IS (STILL) INSECURE !!!',
          this.logName
        );
        logger.warn(
          '[%s] [SECURITY WARNING] You have enabled TLS, ' +
            'but you do not have any form of access control configured.',
          this.logName
        );
        logger.warn(
          '[%s] [SECURITY WARNING] In order to secure your control socket, ' +
            'either enable invalid certificate rejection (rejectUnauthorized) or set a ' +
            'list of authorized IPs.',
          this.logName
        );
        logger.warn(
          '[%s] [SECURITY WARNING] !!! YOUR CONTROL SOCKET IS (STILL) INSECURE !!!',
          this.logName
        );
      }

      // Both loading the files and building the TLS context (bad key/cert
      // data) can fail; either way the start() promise must reject.
      let secureServer: Server;
      try {
        const options = await parseTLSConfig(tlsConfig);
        secureServer = tls.createServer(options, (socket) => this.handleIncoming(socket));
      } catch (err) {
        logger.error(
          '[%s] Secure socket listen failed: %s',
          this.logName,
          err instanceof Error ? err.message : String(err)
        );
        throw err;
      }
      return this.bindServer(secureServer, bind, 'Secure socket');
    }

    if (!c.authorizedIPs?.length) {
      logger.warn(
        '[%s] [SECURITY WARNING] !!! YOUR SOCKET IS INSECURE !!!',
        this.logName
      );
      logger.warn(
        '[%s] [SECURITY WARNING] You do not have any form of access control configured.',
        this.logName
      );
      logger.warn(
        '[%s] [SECURITY WARNING] In order to secure your control socket, ' +
          'either enable TLS with certificate verification or set a list of authorized IPs.',
        this.logName
      );
      logger.warn(
        '[%s] [SECURITY WARNING] !!! YOUR SOCKET IS INSECURE !!!',
        this.logName
      );
    }

    return this.bindServer(
      net.createServer((socket) => this.handleIncoming(socket)),
      bind,
      'Socket'
    );
  }
}
/** TLS option keys whose configured value is a file path to load from disk. */
const match: readonly string[] = ['key', 'cert', 'ca', 'dhparam', 'crl', 'pfx'];

/**
 * Turn a TLSConfiguration into an options object for `tls.createServer`.
 *
 * Keys listed in `match` are file paths: each is resolved against the current
 * working directory and replaced by the file's contents (a Buffer). Unset path
 * entries (`undefined`, `null`, `false` or `''`) are skipped, since they are
 * optional. Every other key except `enabled` is copied through unchanged.
 * Files are read in parallel; key order in the result follows the input.
 *
 * @param tlsconfig TLS section of a socket configuration
 * @returns options object for the TLS server
 * @throws TypeError if a path entry is set to a non-string value
 * @throws the `fs.readFile` error if a referenced file cannot be read
 */
async function parseTLSConfig(
  tlsconfig: TLSConfiguration
): Promise<Record<string, unknown>> {
  // Validate every entry before starting any I/O, so an early throw cannot
  // leave file reads running with nobody handling their rejection.
  const tasks: Array<[string, () => Promise<unknown>]> = [];
  for (const key of Object.keys(tlsconfig)) {
    if (key === 'enabled') {
      continue;
    }
    const value = tlsconfig[key];
    if (!match.includes(key)) {
      tasks.push([key, async () => value]);
      continue;
    }
    if (value == null || value === false || value === '') {
      continue;
    }
    if (typeof value !== 'string') {
      throw new TypeError(`TLS option "${key}" must be a file path string`);
    }
    const file = path.resolve(value);
    tasks.push([key, () => fs.readFile(file)]);
  }

  const values = await Promise.all(tasks.map(([, load]) => load()));
  const result: Record<string, unknown> = {};
  tasks.forEach(([key], i) => {
    result[key] = values[i];
  });
  return result;
}
/**
 * Squeebot plugin that manages named SocketServer instances.
 *
 * A server may name an owning plugin (`config.plugin`); it is stopped when
 * that plugin is unloaded. All servers are stopped when this plugin unloads.
 */
class SocketPlugin extends Plugin {
  private plugins = new Map<string, unknown>();
  private servers = new Map<string, SocketServer>();

  /** Generate a random ID: 8 random bytes as 16 lowercase hex characters. */
  public makeID = makeID;

  /**
   * Stop and unregister a socket server by name. Unknown names are ignored.
   * @param name Socket server name
   */
  public killServer(name: string) {
    const server = this.servers.get(name);
    if (!server) return;
    server.stop();
    this.servers.delete(name);
  }

  /**
   * Create and register a socket server.
   * @param config Socket server configuration; `config.name` must be unique
   * @param start Start the server immediately (default `true`)
   * @returns The new server, or `undefined` if a server with this name
   *   already exists
   * @throws whatever `SocketServer.start()` rejects with; the failed server is
   *   stopped and unregistered first so the name can be used again
   */
  public async createServer(config: SocketConfiguration, start = true) {
    if (this.servers.has(config.name)) return;
    const server = new SocketServer(config);
    this.servers.set(config.name, server);
    if (start) {
      try {
        await server.start();
      } catch (err) {
        // A dead server left registered would make every later createServer()
        // call for this name silently return undefined.
        server.stop();
        if (this.servers.get(config.name) === server) {
          this.servers.delete(config.name);
        }
        throw err;
      }
    }
    return server;
  }

  /**
   * Get an existing socket server by name
   * @param name Socket Server name
   * @returns Existing Socket Server or undefined
   */
  public getServer(name: string) {
    return this.servers.get(name);
  }

  /** On this plugin's own unload: stop every server, then report unloaded. */
  @EventListener('pluginUnload')
  public unloadEventHandler(plugin: string | Plugin): void {
    if (plugin !== this.name && plugin !== this) return;
    logger.debug('[%s]', this.name, 'shutting down..');
    for (const name of Array.from(this.servers.keys())) {
      this.killServer(name);
    }
    this.plugins.clear();
    this.emit('pluginUnloaded', this);
  }

  /** When any plugin has unloaded, stop the servers it owned. */
  @EventListener('pluginUnloaded')
  public unloadedEventHandler(plugin: string | Plugin): void {
    const name = typeof plugin === 'string' ? plugin : plugin.manifest.name;
    this.plugins.delete(name);
    this.killPluginServers(name);
  }

  /** Stop every server whose `config.plugin` names the given plugin. */
  private killPluginServers(plugin: string) {
    for (const [name, server] of Array.from(this.servers)) {
      if (server.config.plugin && server.config.plugin === plugin) {
        this.killServer(name);
      }
    }
  }
}
module.exports = SocketPlugin;