Commit d39da102 authored by Robert Knight's avatar Robert Knight

Implement token-based authentication for the WebSocket

Supply the access token to the WebSocket via a query param.

This method is used to send the token because the WebSocket constructor
does not allow setting custom headers. See
https://github.com/hypothesis/product-backlog/issues/154 for context.

An alternative that was tried initially was embedding a username and
password in the URL via `wss://user:password@host/` syntax but that
turned out not to be supported by IE/Edge and required the server to fail
the initial request with a 401 response.

Fixes hypothesis/product-backlog#126
parent df778bfe
'use strict'; 'use strict';
var queryString = require('query-string');
var uuid = require('node-uuid'); var uuid = require('node-uuid');
var events = require('./events'); var events = require('./events');
...@@ -19,8 +20,8 @@ var Socket = require('./websocket'); ...@@ -19,8 +20,8 @@ var Socket = require('./websocket');
* @param settings - Application settings * @param settings - Application settings
*/ */
// @ngInject // @ngInject
function Streamer($rootScope, annotationMapper, annotationUI, groups, session, function Streamer($rootScope, annotationMapper, annotationUI, auth,
settings) { groups, session, settings) {
// The randomly generated session UUID // The randomly generated session UUID
var clientId = uuid.v4(); var clientId = uuid.v4();
...@@ -149,11 +150,25 @@ function Streamer($rootScope, annotationMapper, annotationUI, groups, session, ...@@ -149,11 +150,25 @@ function Streamer($rootScope, annotationMapper, annotationUI, groups, session,
} }
var _connect = function () { var _connect = function () {
var url = settings.websocketUrl;
// If we have no URL configured, don't do anything. // If we have no URL configured, don't do anything.
if (!url) { if (!settings.websocketUrl) {
return; return Promise.resolve();
}
return auth.tokenGetter().then(function (token) {
var url;
if (token) {
// Include the access token in the URL via a query param. This method
// is used to send credentials because the `WebSocket` constructor does
// not support setting the `Authorization` header directly as we do for
// other API requests.
var parsedURL = new URL(settings.websocketUrl);
var queryParams = queryString.parse(parsedURL.search);
queryParams.access_token = token;
parsedURL.search = queryString.stringify(queryParams);
url = parsedURL.toString();
} else {
url = settings.websocketUrl;
} }
socket = new Socket(url); socket = new Socket(url);
...@@ -167,14 +182,17 @@ function Streamer($rootScope, annotationMapper, annotationUI, groups, session, ...@@ -167,14 +182,17 @@ function Streamer($rootScope, annotationMapper, annotationUI, groups, session,
messageType: 'client_id', messageType: 'client_id',
value: clientId, value: clientId,
}); });
}).catch(function (err) {
console.error('Failed to fetch token for WebSocket authentication', err);
});
}; };
var connect = function () { var connect = function () {
if (socket) { if (socket) {
return; return Promise.resolve();
} }
_connect(); return _connect();
}; };
var reconnect = function () { var reconnect = function () {
...@@ -182,7 +200,7 @@ function Streamer($rootScope, annotationMapper, annotationUI, groups, session, ...@@ -182,7 +200,7 @@ function Streamer($rootScope, annotationMapper, annotationUI, groups, session,
socket.close(); socket.close();
} }
_connect(); return _connect();
}; };
function applyPendingUpdates() { function applyPendingUpdates() {
......
...@@ -42,9 +42,10 @@ var fixtures = { ...@@ -42,9 +42,10 @@ var fixtures = {
// the most recently created FakeSocket instance // the most recently created FakeSocket instance
var fakeWebSocket = null; var fakeWebSocket = null;
function FakeSocket() { function FakeSocket(url) {
fakeWebSocket = this; // eslint-disable-line consistent-this fakeWebSocket = this; // eslint-disable-line consistent-this
this.url = url;
this.messages = []; this.messages = [];
this.didClose = false; this.didClose = false;
...@@ -67,6 +68,7 @@ inherits(FakeSocket, EventEmitter); ...@@ -67,6 +68,7 @@ inherits(FakeSocket, EventEmitter);
describe('Streamer', function () { describe('Streamer', function () {
var fakeAnnotationMapper; var fakeAnnotationMapper;
var fakeAnnotationUI; var fakeAnnotationUI;
var fakeAuth;
var fakeGroups; var fakeGroups;
var fakeRootScope; var fakeRootScope;
var fakeSession; var fakeSession;
...@@ -79,6 +81,7 @@ describe('Streamer', function () { ...@@ -79,6 +81,7 @@ describe('Streamer', function () {
fakeRootScope, fakeRootScope,
fakeAnnotationMapper, fakeAnnotationMapper,
fakeAnnotationUI, fakeAnnotationUI,
fakeAuth,
fakeGroups, fakeGroups,
fakeSession, fakeSession,
fakeSettings fakeSettings
...@@ -88,6 +91,12 @@ describe('Streamer', function () { ...@@ -88,6 +91,12 @@ describe('Streamer', function () {
beforeEach(function () { beforeEach(function () {
var emitter = new EventEmitter(); var emitter = new EventEmitter();
fakeAuth = {
tokenGetter: function () {
return Promise.resolve('dummy-access-token');
},
};
fakeRootScope = { fakeRootScope = {
$apply: function (callback) { $apply: function (callback) {
callback(); callback();
...@@ -132,9 +141,11 @@ describe('Streamer', function () { ...@@ -132,9 +141,11 @@ describe('Streamer', function () {
it('should not create a websocket connection if websocketUrl is not provided', function () { it('should not create a websocket connection if websocketUrl is not provided', function () {
fakeSettings = {}; fakeSettings = {};
createDefaultStreamer(); createDefaultStreamer();
activeStreamer.connect();
return activeStreamer.connect().then(function () {
assert.isNull(fakeWebSocket); assert.isNull(fakeWebSocket);
}); });
});
it('should not create a websocket connection', function () { it('should not create a websocket connection', function () {
createDefaultStreamer(); createDefaultStreamer();
...@@ -148,44 +159,79 @@ describe('Streamer', function () { ...@@ -148,44 +159,79 @@ describe('Streamer', function () {
it('should send the client ID on connection', function () { it('should send the client ID on connection', function () {
createDefaultStreamer(); createDefaultStreamer();
activeStreamer.connect(); return activeStreamer.connect().then(function () {
assert.equal(fakeWebSocket.messages.length, 1); assert.equal(fakeWebSocket.messages.length, 1);
assert.equal(fakeWebSocket.messages[0].messageType, 'client_id'); assert.equal(fakeWebSocket.messages[0].messageType, 'client_id');
assert.equal(fakeWebSocket.messages[0].value, activeStreamer.clientId); assert.equal(fakeWebSocket.messages[0].value, activeStreamer.clientId);
}); });
});
describe('#connect()', function () { describe('#connect()', function () {
it('should create a websocket connection', function () { it('should create a websocket connection', function () {
createDefaultStreamer(); createDefaultStreamer();
activeStreamer.connect(); return activeStreamer.connect().then(function () {
assert.ok(fakeWebSocket); assert.ok(fakeWebSocket);
}); });
});
it('should include credentials in the URL if the client has an access token', function () {
createDefaultStreamer();
return activeStreamer.connect().then(function () {
assert.equal(fakeWebSocket.url, 'ws://example.com/ws?access_token=dummy-access-token');
});
});
it('should preserve query params when adding access token to URL', function () {
fakeSettings.websocketUrl = 'ws://example.com/ws?foo=bar';
createDefaultStreamer();
return activeStreamer.connect().then(function () {
assert.equal(fakeWebSocket.url, 'ws://example.com/ws?access_token=dummy-access-token&foo=bar');
});
});
it('should not include credentials in the URL if the client has no access token', function () {
fakeAuth.tokenGetter = function () {
return Promise.resolve(null);
};
createDefaultStreamer();
return activeStreamer.connect().then(function () {
assert.equal(fakeWebSocket.url, 'ws://example.com/ws');
});
});
it('should not close any existing socket', function () { it('should not close any existing socket', function () {
var oldWebSocket;
createDefaultStreamer(); createDefaultStreamer();
activeStreamer.connect(); return activeStreamer.connect().then(function () {
var oldWebSocket = fakeWebSocket; oldWebSocket = fakeWebSocket;
activeStreamer.connect(); return activeStreamer.connect();
}).then(function () {
assert.ok(!oldWebSocket.didClose); assert.ok(!oldWebSocket.didClose);
assert.ok(!fakeWebSocket.didClose); assert.ok(!fakeWebSocket.didClose);
}); });
}); });
});
describe('#reconnect()', function () { describe('#reconnect()', function () {
it('should close the existing socket', function () { it('should close the existing socket', function () {
var oldWebSocket;
createDefaultStreamer(); createDefaultStreamer();
activeStreamer.connect();
var oldWebSocket = fakeWebSocket; return activeStreamer.connect().then(function () {
activeStreamer.reconnect(); oldWebSocket = fakeWebSocket;
return activeStreamer.reconnect();
}).then(function () {
assert.ok(oldWebSocket.didClose); assert.ok(oldWebSocket.didClose);
assert.ok(!fakeWebSocket.didClose); assert.ok(!fakeWebSocket.didClose);
}); });
}); });
});
describe('annotation notifications', function () { describe('annotation notifications', function () {
beforeEach(function () { beforeEach(function () {
createDefaultStreamer(); createDefaultStreamer();
activeStreamer.connect(); return activeStreamer.connect();
}); });
context('when the app is the stream', function () { context('when the app is the stream', function () {
...@@ -271,7 +317,7 @@ describe('Streamer', function () { ...@@ -271,7 +317,7 @@ describe('Streamer', function () {
describe('#applyPendingUpdates', function () { describe('#applyPendingUpdates', function () {
beforeEach(function () { beforeEach(function () {
createDefaultStreamer(); createDefaultStreamer();
activeStreamer.connect(); return activeStreamer.connect();
}); });
it('applies pending updates', function () { it('applies pending updates', function () {
...@@ -307,7 +353,7 @@ describe('Streamer', function () { ...@@ -307,7 +353,7 @@ describe('Streamer', function () {
beforeEach(function () { beforeEach(function () {
createDefaultStreamer(); createDefaultStreamer();
activeStreamer.connect(); return activeStreamer.connect();
}); });
unroll('discards pending updates when #event occurs', function (testCase) { unroll('discards pending updates when #event occurs', function (testCase) {
...@@ -330,19 +376,19 @@ describe('Streamer', function () { ...@@ -330,19 +376,19 @@ describe('Streamer', function () {
describe('when the focused group changes', function () { describe('when the focused group changes', function () {
it('clears pending updates and deletions', function () { it('clears pending updates and deletions', function () {
createDefaultStreamer(); createDefaultStreamer();
activeStreamer.connect(); return activeStreamer.connect().then(function () {
fakeWebSocket.notify(fixtures.createNotification); fakeWebSocket.notify(fixtures.createNotification);
fakeRootScope.$broadcast(events.GROUP_FOCUSED); fakeRootScope.$broadcast(events.GROUP_FOCUSED);
assert.equal(activeStreamer.countPendingUpdates(), 0); assert.equal(activeStreamer.countPendingUpdates(), 0);
}); });
}); });
});
describe('session change notifications', function () { describe('session change notifications', function () {
it('updates the session when a notification is received', function () { it('updates the session when a notification is received', function () {
createDefaultStreamer(); createDefaultStreamer();
activeStreamer.connect(); return activeStreamer.connect().then(function () {
var model = { var model = {
groups: [{ groups: [{
id: 'new-group', id: 'new-group',
...@@ -355,15 +401,17 @@ describe('Streamer', function () { ...@@ -355,15 +401,17 @@ describe('Streamer', function () {
assert.ok(fakeSession.update.calledWith(model)); assert.ok(fakeSession.update.calledWith(model));
}); });
}); });
});
describe('reconnections', function () { describe('reconnections', function () {
it('resends configuration messages when a reconnection occurs', function () { it('resends configuration messages when a reconnection occurs', function () {
createDefaultStreamer(); createDefaultStreamer();
activeStreamer.connect(); return activeStreamer.connect().then(function () {
fakeWebSocket.messages = []; fakeWebSocket.messages = [];
fakeWebSocket.emit('open'); fakeWebSocket.emit('open');
assert.equal(fakeWebSocket.messages.length, 1); assert.equal(fakeWebSocket.messages.length, 1);
assert.equal(fakeWebSocket.messages[0].messageType, 'client_id'); assert.equal(fakeWebSocket.messages[0].messageType, 'client_id');
}); });
}); });
});
}); });
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment