示例如下,需要注意以下两点:
- 类上需要加上@Scope("prototype")注解,否则从Spring容器中取到的ClickServer是单例的,所有WebSocket连接会共用同一个实例
- @ServerEndpoint的configurator属性所指定的配置类需要实现ApplicationContextAware(配置方式与下面代码一致),这样ClickServer中才能成功注入Spring管理的对象
package com.xiaogang.websocketdemo.web.socket;
import com.alibaba.fastjson.JSON;
import com.xiaogang.websocketdemo.config.ClickServerEndpointConf;
import com.xiaogang.websocketdemo.dto.ApiResult;
import com.xiaogang.websocketdemo.dto.UserActionCount;
import com.xiaogang.websocketdemo.runnable.SendAllUserClickMsg;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Scope;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ThreadPoolExecutor;
/**
* @author xiaogang
* @date 2018/11/30 11:12
*/
@Slf4j
@Component
@Scope("prototype")
@ServerEndpoint(value = "/socket/order/msg/{userName}",configurator = ClickServerEndpointConf.class)
@Order(11111)
public class ClickServer {
@Autowired
private volatile ThreadPoolExecutor threadPoolExecutor;
public ClickServer() {
System.out.println("ClickServer.ClickServer");
}
/**
* 和某个客户端会话的唯一关联
*/
private volatile Session session;
/**
* 整个ClickServer的全局对象
*/
// private ServerEndpointConfig serverEndpointConfig;
public static Set<Session> sessions = new HashSet<>();
private String userName;
@OnOpen
public void onOpen(Session session,@PathParam("userName")String userName) throws IOException {
ApiResult result = new ApiResult();
this.session = session;
this.userName = userName;
log.info("用户:{}链接成功",userName);
try {
validUserUnique(userName);
} catch (Exception e) {
e.printStackTrace();
result.setCode(0);
result.setMsg(e.getMessage());
String jsonString = JSON.toJSONString(result);
this.session.getBasicRemote().sendText(jsonString);
this.session.close();
return;
}
// this.serverEndpointConfig = (ServerEndpointConfig) endpointConfig;
Map<String, Object> userProperties = this.session.getUserProperties();
UserActionCount userActionCount = new UserActionCount();
userActionCount.setUserName(userName);
userActionCount.setActionCount(0);
userProperties.put("UserActionCount", userActionCount);
sessions.add(session);
}
private void validUserUnique(String userName) {
Assert.isTrue(!StringUtils.isEmpty(userName),"用户名不能为空");
Assert.isTrue(!(userName.length() > 6),"用户名长度不能超过6个字符");
for (Session openSession : sessions) {
UserActionCount userActionCount = (UserActionCount) openSession.getUserProperties().get("UserActionCount");
String existuserName = userActionCount.getUserName();
Assert.isTrue(!existuserName.equals(userName),"用户已经报名,无法重复报名");
}
}
@OnMessage
public void onMessage(String msg) throws IOException {
log.info("接收到消息:{}",msg);
if (msg.equals("restart")) {
for (Session session : sessions) {
Map<String, Object> userProperties = session.getUserProperties();
UserActionCount userActionCount = (UserActionCount) userProperties.get("UserActionCount");
userActionCount.setActionCount(0);
userProperties.put("UserActionCount",userActionCount);
ApiResult result = new ApiResult();
result.setCode(1);
result.setMsg("ok");
result.setData(new ArrayList());
String jsonString = JSON.toJSONString(result);
session.getBasicRemote().sendText(jsonString);
}
for (Session session : sessions) {
session.close();
}
}else{
addCount();
sendAllUserClickData();
}
}
@OnClose
public void onClose() throws IOException {
if (this.session == null) {
sendAllUserClickData();
return;
}
boolean open = this.session.isOpen();
if (open) {
this.session.close();
}
sessions.remove(this.session);
sendAllUserClickData();
}
@OnError
public void onError(Throwable throwable) throws IOException {
throwable.printStackTrace();
boolean open = this.session.isOpen();
if (open) {
this.session.close();
}
}
private void sendAllUserClickData() throws IOException {
log.info("给所有用户发送所有用户的点击次数");
Set<UserActionCount> userActionCounts = obtainAllUserClickData();
SendAllUserClickMsg sendAllUserClickMsg = new SendAllUserClickMsg(sessions, userActionCounts);
threadPoolExecutor.execute(sendAllUserClickMsg);
}
private void addCount() {
log.info("为用户:{}添加一次点击次数",userName);
UserActionCount userActionCount = (UserActionCount) this.session.getUserProperties().get("UserActionCount");
int actionCount = userActionCount.getActionCount();
userActionCount.setActionCount(actionCount + 1);
this.session.getUserProperties().put("UserActionCount",userActionCount);
log.info("点击次数添加完成");
}
private Set<UserActionCount> obtainAllUserClickData() throws IOException {
log.info("获取所有用户的点击数据");
Set<UserActionCount> allUserClickData = new HashSet();
for (Session openSession : sessions) {
if (!openSession.isOpen()) {
openSession.close();
sessions.remove(openSession);
}
UserActionCount userActionCount = (UserActionCount) openSession.getUserProperties().get("UserActionCount");
allUserClickData.add(userActionCount);
}
log.info("所有用户的点击次数获取完成:{}",allUserClickData);
return allUserClickData;
}
}
package com.xiaogang.websocketdemo.config;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.annotation.Configuration;
import javax.websocket.server.ServerEndpointConfig.Configurator;
/**
 * WebSocket endpoint configurator that obtains endpoint instances from the Spring
 * {@link ApplicationContext}, so that endpoints using it (e.g. {@code ClickServer}) get their
 * {@code @Autowired} fields injected and their bean scope honoured.
 *
 * <p>The context is held in a static field. The WebSocket container instantiates the configurator
 * class named in {@code @ServerEndpoint} itself, separately from this Spring-managed bean that
 * receives {@link #setApplicationContext}.
 *
 * @author xiaogang
 * @date 2018/11/30 16:03
 */
@Configuration
public class ClickServerEndpointConf extends Configurator implements ApplicationContextAware {

    private static volatile ApplicationContext applicationContext;

    /**
     * Looks up the endpoint instance for {@code clazz} in the Spring context.
     *
     * @param clazz the endpoint class requested by the container
     * @return the bean of type {@code clazz}
     * @throws InstantiationException if the Spring context has not been set yet, or if Spring
     *                                cannot provide the bean (the original exception is the cause)
     */
    @Override
    public <T> T getEndpointInstance(Class<T> clazz) throws InstantiationException {
        ApplicationContext context = applicationContext;
        if (context == null) {
            throw new InstantiationException(
                    "Spring ApplicationContext is not available yet; cannot create endpoint "
                            + clazz.getName());
        }
        try {
            return context.getBean(clazz);
        } catch (BeansException e) {
            // Report failures through the exception type the Configurator contract declares.
            InstantiationException failure = new InstantiationException(
                    "Failed to obtain endpoint " + clazz.getName() + " from Spring: " + e.getMessage());
            failure.initCause(e);
            throw failure;
        }
    }

    /**
     * Stores the Spring context statically so container-created configurator instances can use it.
     *
     * @param applicationContext the context this bean runs in
     */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        ClickServerEndpointConf.applicationContext = applicationContext;
    }
}