439 lines
11 KiB
TypeScript
439 lines
11 KiB
TypeScript
import {
|
|
Plugin,
|
|
EventListener,
|
|
} from '@squeebot/core/lib/plugin';
|
|
|
|
import { logger } from '@squeebot/core/lib/core';
|
|
|
|
import path from 'path';
|
|
import fs from 'fs/promises';
|
|
import tls, { TLSSocket } from 'tls';
|
|
import net, { Server, Socket } from 'net';
|
|
import { randomBytes } from 'crypto';
|
|
import EventEmitter from 'events';
|
|
|
|
/**
 * TLS settings for a control socket.
 *
 * The file fields (`key`, `cert`, `ca`, `dhparam`, `crl`, `pfx`) hold paths
 * whose contents are loaded from disk; every other field except `enabled` is
 * handed to `tls.createServer` as-is.
 */
interface TLSConfiguration {
  /** Serve over TLS instead of plain TCP. */
  enabled: boolean;

  /** Path to the private key file. */
  key?: string;
  /** Path to the certificate file. */
  cert?: string;
  /** Path to the CA certificate file. */
  ca?: string;
  /** Path to the Diffie-Hellman parameters file. */
  dhparam?: string;
  /** Path to the certificate revocation list file. */
  crl?: string;
  /** Path to a PFX/PKCS#12 bundle. */
  pfx?: string;

  /** Node `tls` server option, passed through. */
  rejectUnauthorized: boolean;
  /** Node `tls` server option, passed through. */
  requestCert: boolean;

  /** Any additional `tls.createServer` option. */
  [option: string]: string | boolean | undefined;
}
|
|
|
|
/**
 * A single JSON message exchanged over a control socket, in either direction.
 * Fields beyond the well-known ones below are permitted.
 */
interface SocketMessage {
  /** ID of the client this message concerns. */
  id?: string;
  /** Command name (sent by clients; echoed by the server for e.g. `intent`). */
  command?: string;
  /** Reply status such as `OK` or `ERROR`. */
  status?: string;
  /** Positional arguments accompanying `command` or `status`. */
  arguments?: unknown[];
  /** Arbitrary extra fields. */
  [field: string]: unknown;
}
|
|
|
|
/** Configuration for one named control socket server. */
interface SocketConfiguration {
  /** Unique server name; also used as the log prefix (`socket:<name>`). */
  name: string;
  /**
   * Where to listen: a port number or a path/address string.
   * `null` or `false` means the server does not listen at all.
   */
  bind: false | null | string | number;
  /** Client IP allow-list; the entry `localhost` also admits `::1` and `127.0.0.1`. */
  authorizedIPs: Array<string>;
  /** Optional TLS settings; plain TCP is used when absent or disabled. */
  tls?: TLSConfiguration;
  /** Name of the plugin owning this server, if any. */
  plugin?: string;
}
|
|
|
|
/** Plugin-level configuration: the list of socket servers to manage. */
interface PluginConfiguration {
  /** One entry per socket server. */
  sockets: Array<SocketConfiguration>;
}
|
|
|
|
/** A plain or TLS client socket tagged with the server-assigned client ID. */
type SocketWithID = (Socket & { id: string }) | (TLSSocket & { id: string });
|
|
|
|
/** Generate a random identifier: 8 random bytes rendered as 16 hex characters. */
const makeID = (): string => {
  const bytes = randomBytes(8);
  return bytes.toString('hex');
};
|
|
|
|
/**
 * A control socket server speaking line-delimited JSON over TCP or TLS.
 *
 * Every message in either direction is a JSON object followed by `\r\n`.
 * Built-in client commands are `status` (or no command), `quit` and `intent`;
 * anything else is forwarded to `message` listeners.
 *
 * Events:
 * - `connect` (clientId): an authorized client connected.
 * - `disconnect` (clientId): a client sent `quit` or its connection closed.
 * - `intent` (clientId, intent): a client declared its intent.
 * - `message` (request, { id, intent }, reply): a non-built-in command arrived;
 *   call `reply(response)` to answer the sender.
 * - `close`: `stop()` was called.
 */
class SocketServer extends EventEmitter {
  /**
   * Maximum amount of buffered, not-yet-delimited input per client. Guards
   * memory against a client that never sends a line delimiter.
   */
  private static readonly maxPendingInput = 1024 * 1024;

  private server: Server | null = null;

  private sockets = new Set<SocketWithID>();

  /** Client ID -> intent declared via the `intent` command. */
  private intents = new Map<string, string>();

  public alive = false;

  constructor(public config: SocketConfiguration) {
    super();
  }

  /** Prefix used for this server's log lines. */
  get logName() {
    return `socket:${this.config.name}`;
  }

  /**
   * Stop listening and destroy every connected client. Always emits `close`
   * and marks the server as not alive, even if it was never started.
   */
  public stop() {
    if (this.server) {
      logger.log('[%s] Stopping socket server..', this.logName);
      this.server.close();
      this.server = null;
      for (const sock of this.sockets) {
        sock.destroy();
      }
      // Cleared before the sockets' 'close' events fire, so a shutdown does
      // not produce a per-client 'disconnect' event.
      this.sockets.clear();
      this.intents.clear();
    }
    this.emit('close');
    this.alive = false;
  }

  /**
   * (Re)start the server; a running instance is stopped first.
   * Rejects if the TLS files cannot be loaded or listening fails.
   */
  public async start() {
    if (this.alive) {
      this.stop();
    }
    await this.createSocket();
  }

  /** Broadcast a message to every connected client. */
  public send(message: SocketMessage): void {
    this.sockets.forEach((socket) => this.sendToClient(socket, message));
  }

  /** Send a message to a single client by ID; unknown IDs are ignored. */
  public sendToID(id: string, message: SocketMessage): void {
    for (const client of this.sockets) {
      if (client.id === id) {
        this.sendToClient(client, message);
        return;
      }
    }
  }

  /** Send a message to every client whose declared intent equals `intent`. */
  public sendIntent(intent: string, message: SocketMessage): void {
    for (const [id, clientIntent] of this.intents) {
      if (clientIntent !== intent) continue;
      this.sendToID(id, message);
    }
  }

  /**
   * Write one framed message to a client. Skipped when the socket can no
   * longer be written to (ended/destroyed), avoiding write-after-end errors.
   */
  private sendToClient(socket: SocketWithID, message: SocketMessage): void {
    if (!socket.writable) return;
    socket.write(JSON.stringify(message) + '\r\n');
  }

  /** Forget a client and announce its departure; no-op if already dropped. */
  private dropClient(socket: SocketWithID): void {
    if (!this.sockets.delete(socket)) return;
    this.intents.delete(socket.id);
    this.emit('disconnect', socket.id);
  }

  /** Parse one received line and dispatch it; failures are reported to the client. */
  private processLine(socket: SocketWithID, line: string): void {
    try {
      const req: unknown = JSON.parse(line);
      if (typeof req !== 'object' || req === null || Array.isArray(req)) {
        throw new Error('Message must be a JSON object');
      }
      this.handleClientLine(socket, req as SocketMessage);
    } catch (error) {
      this.sendToClient(socket, {
        status: 'ERROR',
        arguments: [error instanceof Error ? error.message : String(error)],
      });
    }
  }

  /** Handle one parsed client request. */
  private handleClientLine(socket: SocketWithID, req: SocketMessage): void {
    if (!req.command || req.command === 'status') {
      this.sendToClient(socket, { id: socket.id, status: 'OK' });
      return;
    }

    // Client requests graceful close
    if (req.command === 'quit') {
      socket.end();
      this.dropClient(socket);
      return;
    }

    // Arguments may arrive as `args`, `argument` or `arguments`; a single
    // non-array value is wrapped into a one-element list.
    let args: unknown[] = [];
    const argField = req.args || req.argument || req.arguments;
    if (argField) {
      args = Array.isArray(argField) ? argField : [argField];
    }

    // Set (when an argument is given) and then report the client's intent
    if (req.command === 'intent') {
      const requested = args[0];
      if (requested) {
        const intent = String(requested);
        this.intents.set(socket.id, intent);
        this.emit('intent', socket.id, intent);
      }

      this.sendToClient(socket, {
        id: socket.id,
        command: 'intent',
        arguments: [this.intents.get(socket.id)],
      });
      return;
    }

    try {
      this.emit(
        'message',
        { ...req, arguments: args },
        { id: socket.id, intent: this.intents.get(socket.id) },
        (response: SocketMessage) => this.sendToClient(socket, response)
      );
    } catch (err) {
      // A throwing listener must not take down the connection handler.
      logger.error(
        '[%s] Error in message handler: %s',
        this.logName,
        err instanceof Error ? err.stack : String(err)
      );
    }
  }

  /**
   * Check a client address against the allow-list. With no allow-list every
   * address is accepted; with one, an unknown (undefined) address is rejected.
   */
  private isAuthorized(addr: string | undefined): boolean {
    const allowed = this.config.authorizedIPs;
    if (!allowed || !allowed.length) return true;
    if (!addr) return false;
    if (allowed.indexOf(addr) !== -1) return true;
    return allowed.indexOf('localhost') !== -1 && (addr === '::1' || addr === '127.0.0.1');
  }

  /** True when `text` is already a complete JSON document. */
  private static isCompleteJSON(text: string): boolean {
    try {
      JSON.parse(text);
      return true;
    } catch {
      return false;
    }
  }

  /** Accept (or reject) a new connection and wire up its event handlers. */
  private handleIncoming(incoming: Socket | TLSSocket): void {
    const socket = incoming as SocketWithID;
    socket.id = makeID();

    let addr = socket.remoteAddress;
    // IPv4 clients on a dual-stack listener show up as '::ffff:a.b.c.d'.
    if (addr?.indexOf('::ffff:') === 0) {
      addr = addr.substring(7);
    }

    if (!this.isAuthorized(addr)) {
      logger.warn('[%s] Unauthorized connection made from %s', this.logName, addr);
      socket.destroy();
      return;
    }

    this.sockets.add(socket);

    // Without an 'error' listener a reset connection (ECONNRESET, EPIPE, ...)
    // is an unhandled 'error' event, which crashes the process.
    socket.on('error', (err: Error) => {
      logger.warn('[%s] Client from %s (%s) error: %s', this.logName, addr, socket.id, err.message);
    });

    // 'close' follows both a clean end and an error, so it is the single
    // cleanup point.
    socket.once('close', () => {
      logger.log('[%s] Client from %s (%s) disconnected.', this.logName, addr, socket.id);
      this.dropClient(socket);
    });

    logger.log('[%s] Client from %s (%s) connected.', this.logName, addr, socket.id);
    socket.setEncoding('utf8');
    this.sendToClient(socket, { id: socket.id, status: 'OK' });

    let pending = '';
    socket.on('data', (data: string | Buffer) => {
      pending += data.toString();
      const lines = pending.split('\r\n');
      // TCP does not preserve message boundaries: whatever follows the last
      // delimiter may be the first part of a message still in flight.
      pending = lines.pop() ?? '';
      // ...unless it is already a complete JSON document, so clients that
      // omit the trailing delimiter (or send a bare '\n') keep working.
      if (pending.trim() !== '' && SocketServer.isCompleteJSON(pending)) {
        lines.push(pending);
        pending = '';
      }
      if (pending.length > SocketServer.maxPendingInput) {
        pending = '';
        this.sendToClient(socket, { status: 'ERROR', arguments: ['Message too long'] });
      }

      for (const line of lines) {
        // The client may have quit earlier in this same chunk.
        if (!this.sockets.has(socket)) return;
        if (line === '') continue;
        this.processLine(socket, line);
      }
    });

    this.emit('connect', socket.id);
  }

  /** Log the warning shown for a TLS socket without any access control. */
  private warnInsecureTLS(): void {
    logger.warn(
      '[%s] [SECURITY WARNING] !!! YOUR CONTROL SOCKET IS (STILL) INSECURE !!!',
      this.logName
    );
    logger.warn(
      '[%s] [SECURITY WARNING] You have enabled TLS, ' +
        'but you do not have any form of access control configured.',
      this.logName
    );
    logger.warn(
      '[%s] [SECURITY WARNING] In order to secure your control socket, ' +
        'either enable client certificate verification (requestCert and rejectUnauthorized) or set a ' +
        'list of authorized IPs.',
      this.logName
    );
    logger.warn(
      '[%s] [SECURITY WARNING] !!! YOUR CONTROL SOCKET IS (STILL) INSECURE !!!',
      this.logName
    );
  }

  /** Log the warning shown for a plain TCP socket without an IP allow-list. */
  private warnInsecurePlain(): void {
    logger.warn('[%s] [SECURITY WARNING] !!! YOUR SOCKET IS INSECURE !!!', this.logName);
    logger.warn(
      '[%s] [SECURITY WARNING] You do not have any form of access control configured.',
      this.logName
    );
    logger.warn(
      '[%s] [SECURITY WARNING] In order to secure your control socket, ' +
        'either enable TLS with certificate verification or set a list of authorized IPs.',
      this.logName
    );
    logger.warn('[%s] [SECURITY WARNING] !!! YOUR SOCKET IS INSECURE !!!', this.logName);
  }

  /**
   * Start listening on `bind`. Resolves once listening. Rejects on an error
   * raised before that point (and forgets the failed server); later errors
   * are only logged.
   */
  private listen(server: Server, bind: number | string, label: string): Promise<void> {
    this.server = server;
    return new Promise((resolve, reject) => {
      server.on('error', (e: Error) => {
        logger.error('[%s] %s error: %s', this.logName, label, e.message);
        if (!this.alive) {
          if (this.server === server) this.server = null;
          reject(e);
        }
      });
      server.listen(bind, () => {
        logger.log('[%s] %s listening on %s', this.logName, label, bind.toString());
        this.alive = true;
        resolve();
      });
    });
  }

  /** Create the TCP or TLS server described by the config and start listening. */
  private async createSocket(): Promise<void> {
    const c = this.config;
    if (c.bind == null || c.bind === false) {
      return;
    }
    const bind = c.bind;
    const hasAllowList = !!c.authorizedIPs && c.authorizedIPs.length > 0;

    if (c.tls && c.tls.enabled) {
      // Node only verifies client certificates when requestCert is true;
      // rejectUnauthorized (default true) has no effect otherwise.
      const verifiesClients = c.tls.requestCert === true && c.tls.rejectUnauthorized !== false;
      if (!verifiesClients && !hasAllowList) {
        this.warnInsecureTLS();
      }

      let options: Record<string, unknown>;
      try {
        options = await parseTLSConfig(c.tls);
      } catch (err) {
        logger.error(
          '[%s] Secure socket listen failed: %s',
          this.logName,
          err instanceof Error ? err.message : String(err)
        );
        throw err;
      }

      const secureServer = tls.createServer(options, (socket) => this.handleIncoming(socket));
      await this.listen(secureServer, bind, 'Secure socket');
      return;
    }

    if (!hasAllowList) {
      this.warnInsecurePlain();
    }

    const plainServer = net.createServer((socket) => this.handleIncoming(socket));
    await this.listen(plainServer, bind, 'Socket');
  }
}
|
|
|
|
/** TLS option keys whose configured value is a file path to load, not a literal. */
const match = ['key', 'cert', 'ca', 'dhparam', 'crl', 'pfx'];

/**
 * Convert a TLS configuration into an options object for `tls.createServer`.
 *
 * `enabled` is dropped. Keys listed in `match` are treated as file paths and
 * replaced by the file contents (Buffer). An unset, empty or non-string value
 * for such a key is omitted instead of being resolved, which would otherwise
 * throw or point at the working directory. All other keys are copied as-is.
 *
 * @param tlsconfig TLS configuration from the socket config
 * @returns Options suitable for `tls.createServer`
 * @throws If any referenced file cannot be read
 */
async function parseTLSConfig(
  tlsconfig: TLSConfiguration
): Promise<Record<string, unknown>> {
  const result: Record<string, unknown> = {};
  const reads: Promise<void>[] = [];

  for (const key in tlsconfig) {
    if (key === 'enabled') {
      continue;
    }

    const value = tlsconfig[key];
    if (!match.includes(key)) {
      result[key] = value;
      continue;
    }

    if (typeof value !== 'string' || value === '') {
      continue;
    }

    // Independent files: read them concurrently.
    reads.push(
      fs.readFile(path.resolve(value)).then((contents) => {
        result[key] = contents;
      })
    );
  }

  await Promise.all(reads);
  return result;
}
|
|
|
|
/**
 * Squeebot plugin that creates and tracks named control socket servers.
 * A server may be tagged with an owning plugin (`config.plugin`); its servers
 * are torn down when that plugin is unloaded.
 */
class SocketPlugin extends Plugin {
  /** Cleared/pruned on unload; not populated within this class. */
  private plugins = new Map<string, unknown>();

  private servers = new Map<string, SocketServer>();

  /** Generate a random ID: 8 random bytes as 16 hex characters. */
  public makeID = makeID;

  /**
   * Kill a socket server by name
   * @param name Socket server name
   */
  public killServer(name: string) {
    const server = this.servers.get(name);
    if (!server) return;
    server.stop();
    this.servers.delete(name);
  }

  /**
   * Create a socket server
   * @param config Socket Server configuration
   * @param start Start the socket automatically on creation
   * @returns Socket Server, or undefined if a server with this name already exists
   * @throws If `start` is true and the server fails to start; the failed server
   *   is not kept registered, so the name can be reused.
   */
  public async createServer(config: SocketConfiguration, start = true) {
    if (this.servers.has(config.name)) return;
    const server = new SocketServer(config);
    this.servers.set(config.name, server);
    if (start) {
      try {
        await server.start();
      } catch (err) {
        // Otherwise the dead server would block every later createServer call
        // for this name (they would silently return undefined).
        if (this.servers.get(config.name) === server) {
          this.servers.delete(config.name);
        }
        throw err;
      }
    }
    return server;
  }

  /**
   * Get an existing socket server by name
   * @param name Socket Server name
   * @returns Existing Socket Server or undefined
   */
  public getServer(name: string) {
    return this.servers.get(name);
  }

  /** On this plugin's own unload: stop all servers and signal completion. */
  @EventListener('pluginUnload')
  public unloadEventHandler(plugin: string | Plugin): void {
    if (plugin === this.name || plugin === this) {
      logger.debug('[%s]', this.name, 'shutting down..');

      // Deleting from a Map while iterating it is well-defined in JS.
      for (const [key] of this.servers) {
        this.killServer(key);
      }

      this.plugins.clear();
      this.emit('pluginUnloaded', this);
    }
  }

  /** When any plugin has unloaded, stop the servers it owned. */
  @EventListener('pluginUnloaded')
  public unloadedEventHandler(plugin: string | Plugin): void {
    if (typeof plugin !== 'string') {
      plugin = plugin.manifest.name;
    }
    this.plugins.delete(plugin);
    this.killPluginServers(plugin);
  }

  /** Stop every server whose `config.plugin` equals `plugin`. */
  private killPluginServers(plugin: string) {
    for (const [name, server] of this.servers) {
      if (!server.config.plugin || server.config.plugin !== plugin) continue;
      this.killServer(name);
    }
  }
}
|
|
|
|
// The plugin class itself is the CommonJS module value (not a named/default ES export).
// NOTE(review): presumably required this way by the squeebot plugin loader — confirm before changing.
module.exports = SocketPlugin;
|