Spring Boot 2 Series (45): RestTemplate Source Code Analysis, with a Custom Request and Interceptor

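Our SaaS project uses a distributed microservice architecture. Services call each other through RestTemplate, and we customized the Request that RestTemplate sends. These are my notes on how that works.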

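A custom Request is useful in many ways. For example:

- A custom request can carry security credentials.
- A custom request interceptor can implement load balancing or request proxying.

This makes RestTemplate very flexible to tailor.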

Related RestTemplate articles: Spring Boot 2 Practice Series (21): Calling Remote REST Services with RestTemplate; Spring Cloud Series (4): Client-Side Load Balancing with Ribbon

Source Code Analysis

RestTemplate is not a standalone class. Its direct superclass is InterceptingHttpAccessor, and its top-level superclass is HttpAccessor.

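InterceptingHttpAccessor is the base class for RestTemplate and other HTTP-accessing gateway helpers. It adds interceptor-related properties on top of the common properties defined in its superclass, HttpAccessor.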

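HttpAccessor is the direct superclass of InterceptingHttpAccessor. It is also a base class for RestTemplate and other HTTP-accessing gateway helpers. It defines the common properties used by operations, for example the ClientHttpRequestFactory.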

ClientHttpRequest

RestTemplate Operations

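All of RestTemplate's request methods end up in doExecute(), usually by way of execute(). This includes getForObject(), postForEntity(), exchange() and so on. The HTTP request is created inside doExecute().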

```java
@Override
@Nullable
public <T> T getForObject(String url, Class<T> responseType, Object... uriVariables) throws RestClientException {
    RequestCallback requestCallback = acceptHeaderRequestCallback(responseType);
    HttpMessageConverterExtractor<T> responseExtractor =
            new HttpMessageConverterExtractor<>(responseType, getMessageConverters(), logger);
    // Delegate to execute()
    return execute(url, HttpMethod.GET, requestCallback, responseExtractor, uriVariables);
}

@Override
@Nullable
public <T> T execute(String url, HttpMethod method, @Nullable RequestCallback requestCallback,
        @Nullable ResponseExtractor<T> responseExtractor, Object... uriVariables) throws RestClientException {
    URI expanded = getUriTemplateHandler().expand(url, uriVariables);
    // Delegate to doExecute()
    return doExecute(expanded, method, requestCallback, responseExtractor);
}

@Nullable
protected <T> T doExecute(URI url, @Nullable HttpMethod method, @Nullable RequestCallback requestCallback,
        @Nullable ResponseExtractor<T> responseExtractor) throws RestClientException {

    Assert.notNull(url, "URI is required");
    Assert.notNull(method, "HttpMethod is required");
    ClientHttpResponse response = null;
    try {
        // Create the request
        ClientHttpRequest request = createRequest(url, method);
        if (requestCallback != null) {
            // Let the callback prepare the request (headers, body)
            requestCallback.doWithRequest(request);
        }
        // Execute the request
        response = request.execute();
        // Handle the response (the error handler is consulted here)
        handleResponse(url, method, response);
        // Extract and return the response body as the target type
        return (responseExtractor != null ? responseExtractor.extractData(response) : null);
    }
    catch (IOException ex) {
        String resource = url.toString();
        String query = url.getRawQuery();
        resource = (query != null ? resource.substring(0, resource.indexOf('?')) : resource);
        throw new ResourceAccessException("I/O error on " + method.name() +
                " request for \"" + resource + "\": " + ex.getMessage(), ex);
    }
    finally {
        if (response != null) {
            response.close();
        }
    }
}
```
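The flow of doExecute() can be made concrete with a minimal, stdlib-only Java sketch of its template-method structure. The sketch does five things in order:

1. Creates the request.
2. Lets a callback customize it.
3. Executes it.
4. Extracts the result.
5. Always closes the response in `finally`.

`DoExecuteSketch` and its nested `Request`/`Response` types are hypothetical stand-ins, not Spring APIs. The error-handling step (`handleResponse()`) is left out.

```java
import java.io.Closeable;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Function;

// Hypothetical, simplified model of RestTemplate#doExecute(); not Spring code.
final class DoExecuteSketch {

    /** Stand-in for ClientHttpResponse. */
    static final class Response implements Closeable {
        final String body;
        boolean closed;

        Response(String body) { this.body = body; }

        @Override
        public void close() { closed = true; }
    }

    /** Stand-in for ClientHttpRequest; execute() fakes a server reply. */
    static final class Request {
        final String uri;
        final List<String> headers = new ArrayList<>();

        Request(String uri) { this.uri = uri; }

        Response execute() { return new Response("GET " + uri + " " + headers); }
    }

    /** Remembers the last response so callers can check that it was closed. */
    static Response lastResponse;

    static <T> T doExecute(String uri, Consumer<Request> requestCallback,
                           Function<Response, T> responseExtractor) {
        Response response = null;
        try {
            Request request = new Request(uri);          // 1. createRequest()
            if (requestCallback != null) {
                requestCallback.accept(request);          // 2. requestCallback.doWithRequest()
            }
            response = request.execute();                 // 3. request.execute()
            lastResponse = response;
            // 4. handleResponse() would run the error handler here (omitted)
            return responseExtractor != null ? responseExtractor.apply(response) : null; // 5. extractData()
        } finally {
            if (response != null) {
                response.close();                         // 6. always release the response
            }
        }
    }

    public static void main(String[] args) {
        String body = doExecute("/users", req -> req.headers.add("Accept: application/json"), r -> r.body);
        System.out.println(body);                  // GET /users [Accept: application/json]
        System.out.println(lastResponse.closed);   // true
    }
}
```

In this sketch, as in the source above, the extractor reads the body before `finally` closes the response. Anything the extractor returns therefore must not depend on the response staying open.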

ClientHttpRequest

The request is created in RestTemplate's doExecute() method: `ClientHttpRequest request = createRequest(url, method)`.

  1. Creating the request calls createRequest() in the top-level abstract superclass, HttpAccessor.

    ```java
    public abstract class HttpAccessor {

        protected final Log logger = HttpLogging.forLogName(getClass());
        // Default request factory: SimpleClientHttpRequestFactory
        private ClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();

        /**
         * Set the request factory.
         * Custom requests rely on this method.
         */
        public void setRequestFactory(ClientHttpRequestFactory requestFactory) {
            Assert.notNull(requestFactory, "ClientHttpRequestFactory must not be null");
            this.requestFactory = requestFactory;
        }

        /**
         * Return the request factory.
         */
        public ClientHttpRequestFactory getRequestFactory() {
            return this.requestFactory;
        }

        /**
         * Create a ClientHttpRequest via the ClientHttpRequestFactory.
         */
        protected ClientHttpRequest createRequest(URI url, HttpMethod method) throws IOException {
            // Obtain the request factory and let it create the request.
            // Note: for a RestTemplate this calls the getRequestFactory() override in the subclass
            // InterceptingHttpAccessor, which in turn still falls back to HttpAccessor's own method.
            ClientHttpRequest request = getRequestFactory().createRequest(url, method);
            if (logger.isDebugEnabled()) {
                logger.debug("HTTP " + method.name() + " " + url);
            }
            return request;
        }
    }
    ```

InterceptingHttpAccessor

InterceptingHttpAccessor, the direct superclass of RestTemplate, provides the interceptor-related settings.

```java
public abstract class InterceptingHttpAccessor extends HttpAccessor {

    private final List<ClientHttpRequestInterceptor> interceptors = new ArrayList<>();

    @Nullable
    private volatile ClientHttpRequestFactory interceptingRequestFactory;

    /**
     * Set the interceptors.
     */
    public void setInterceptors(List<ClientHttpRequestInterceptor> interceptors) {
        if (this.interceptors != interceptors) {
            this.interceptors.clear();
            this.interceptors.addAll(interceptors);
            AnnotationAwareOrderComparator.sort(this.interceptors);
        }
    }

    /**
     * Return the interceptors.
     */
    public List<ClientHttpRequestInterceptor> getInterceptors() {
        return this.interceptors;
    }

    /**
     * Set the request factory (and drop the cached intercepting factory).
     */
    @Override
    public void setRequestFactory(ClientHttpRequestFactory requestFactory) {
        super.setRequestFactory(requestFactory);
        this.interceptingRequestFactory = null;
    }

    /**
     * Overrides the method of the abstract superclass HttpAccessor.
     */
    @Override
    public ClientHttpRequestFactory getRequestFactory() {
        List<ClientHttpRequestInterceptor> interceptors = getInterceptors();
        if (!CollectionUtils.isEmpty(interceptors)) {
            // Interceptors are present: wrap the underlying factory (created once, then cached)
            ClientHttpRequestFactory factory = this.interceptingRequestFactory;
            if (factory == null) {
                factory = new InterceptingClientHttpRequestFactory(super.getRequestFactory(), interceptors);
                this.interceptingRequestFactory = factory;
            }
            return factory;
        }
        else {
            // No interceptors: use the superclass (HttpAccessor) factory directly
            return super.getRequestFactory();
        }
    }
}
```
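The caching logic in getRequestFactory() is easy to miss. The wrapping factory is created lazily, reused on later calls, and discarded whenever setRequestFactory() is called. The stdlib-only sketch below models that behavior.

All names in it are hypothetical:

- `Accessor` stands in for HttpAccessor.
- `InterceptingAccessor` stands in for InterceptingHttpAccessor.
- Interceptors are reduced to names.
- Factories are reduced to functions that return a description string.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

// Hypothetical model of HttpAccessor / InterceptingHttpAccessor; not Spring code.
final class InterceptingFactorySketch {

    /** Stand-in for ClientHttpRequestFactory. */
    interface RequestFactory {
        String createRequest(String uri);
    }

    /** Plays the role of HttpAccessor. */
    static class Accessor {
        // Default factory, like SimpleClientHttpRequestFactory
        private RequestFactory requestFactory = uri -> "simple:" + uri;

        public void setRequestFactory(RequestFactory requestFactory) {
            this.requestFactory = Objects.requireNonNull(requestFactory, "RequestFactory must not be null");
        }

        public RequestFactory getRequestFactory() {
            return requestFactory;
        }

        protected String createRequest(String uri) {
            // Virtual call: a subclass override of getRequestFactory() is used here
            return getRequestFactory().createRequest(uri);
        }
    }

    /** Plays the role of InterceptingHttpAccessor. */
    static class InterceptingAccessor extends Accessor {
        private final List<String> interceptors = new ArrayList<>();
        private volatile RequestFactory interceptingRequestFactory;
        int wrappersCreated; // counts how often a new wrapper was built

        public List<String> getInterceptors() {
            return interceptors;
        }

        @Override
        public void setRequestFactory(RequestFactory requestFactory) {
            super.setRequestFactory(requestFactory);
            interceptingRequestFactory = null; // the cached wrapper still points at the old factory
        }

        @Override
        public RequestFactory getRequestFactory() {
            if (interceptors.isEmpty()) {
                return super.getRequestFactory();
            }
            RequestFactory factory = interceptingRequestFactory;
            if (factory == null) {
                RequestFactory target = super.getRequestFactory();
                List<String> chain = interceptors; // the live list, as in the Spring source above
                factory = uri -> "intercepted" + chain + "->" + target.createRequest(uri);
                interceptingRequestFactory = factory;
                wrappersCreated++;
            }
            return factory;
        }
    }

    public static void main(String[] args) {
        InterceptingAccessor accessor = new InterceptingAccessor();
        System.out.println(accessor.createRequest("/a"));   // simple:/a
        accessor.getInterceptors().add("auth");
        System.out.println(accessor.createRequest("/a"));   // intercepted[auth]->simple:/a
        accessor.setRequestFactory(uri -> "okhttp:" + uri);
        System.out.println(accessor.createRequest("/a"));   // intercepted[auth]->okhttp:/a
    }
}
```

The Spring source passes the live interceptor list (getInterceptors()) into the wrapper. In this model that means interceptors added later via getInterceptors().add(...) are still seen by an already cached wrapper.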

InterceptingClientHttpRequestFactory

```java
public class InterceptingClientHttpRequestFactory extends AbstractClientHttpRequestFactoryWrapper {

    private final List<ClientHttpRequestInterceptor> interceptors;

    /**
     * Constructor: store the wrapped factory and the interceptors.
     */
    public InterceptingClientHttpRequestFactory(ClientHttpRequestFactory requestFactory,
            @Nullable List<ClientHttpRequestInterceptor> interceptors) {

        super(requestFactory);
        this.interceptors = (interceptors != null ? interceptors : Collections.emptyList());
    }

    /**
     * Create the request.
     */
    @Override
    protected ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod, ClientHttpRequestFactory requestFactory) {
        return new InterceptingClientHttpRequest(requestFactory, this.interceptors, uri, httpMethod);
    }
}
```

SimpleClientHttpRequestFactory

SimpleClientHttpRequestFactory is the default, simple implementation of ClientHttpRequestFactory. It is built on the JDK's java.net.HttpURLConnection.

```java
public class SimpleClientHttpRequestFactory implements ClientHttpRequestFactory, AsyncClientHttpRequestFactory {

    private static final int DEFAULT_CHUNK_SIZE = 4096;


    @Nullable
    private Proxy proxy;

    private boolean bufferRequestBody = true;

    private int chunkSize = DEFAULT_CHUNK_SIZE;

    private int connectTimeout = -1;

    private int readTimeout = -1;

    private boolean outputStreaming = true;

    @Nullable
    private AsyncListenableTaskExecutor taskExecutor;


    /**
     * Set the Proxy to use for this request factory.
     */
    public void setProxy(Proxy proxy) {
        this.proxy = proxy;
    }

    /**
     * Whether this request factory should buffer the request body (ClientHttpRequest#getBody()) internally.
     * Default is true. When sending large amounts of data via POST or PUT, setting this to false
     * is recommended so as not to run out of memory.
     */
    public void setBufferRequestBody(boolean bufferRequestBody) {
        this.bufferRequestBody = bufferRequestBody;
    }

    /**
     * Set the number of bytes to write in each chunk when the request body is not buffered locally.
     */
    public void setChunkSize(int chunkSize) {
        this.chunkSize = chunkSize;
    }

    /**
     * Set the connect timeout (in milliseconds).
     */
    public void setConnectTimeout(int connectTimeout) {
        this.connectTimeout = connectTimeout;
    }

    /**
     * Set the read timeout (in milliseconds).
     */
    public void setReadTimeout(int readTimeout) {
        this.readTimeout = readTimeout;
    }

    /**
     * Whether the underlying URLConnection may be put into "output streaming" mode. Default is true.
     */
    public void setOutputStreaming(boolean outputStreaming) {
        this.outputStreaming = outputStreaming;
    }

    /**
     * Set the task executor for this request factory; required for creating asynchronous requests.
     */
    public void setTaskExecutor(AsyncListenableTaskExecutor taskExecutor) {
        this.taskExecutor = taskExecutor;
    }

    /**
     * Create a request.
     * Under the hood this uses java.net.HttpURLConnection, a subclass of URLConnection.
     */
    @Override
    public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException {
        HttpURLConnection connection = openConnection(uri.toURL(), this.proxy);
        prepareConnection(connection, httpMethod.name());

        if (this.bufferRequestBody) {
            return new SimpleBufferingClientHttpRequest(connection, this.outputStreaming);
        }
        else {
            return new SimpleStreamingClientHttpRequest(connection, this.chunkSize, this.outputStreaming);
        }
    }

    /**
     * Create an asynchronous request.
     */
    @Override
    public AsyncClientHttpRequest createAsyncRequest(URI uri, HttpMethod httpMethod) throws IOException {
        Assert.state(this.taskExecutor != null, "Asynchronous execution requires TaskExecutor to be set");

        HttpURLConnection connection = openConnection(uri.toURL(), this.proxy);
        prepareConnection(connection, httpMethod.name());

        if (this.bufferRequestBody) {
            return new SimpleBufferingAsyncClientHttpRequest(
                    connection, this.outputStreaming, this.taskExecutor);
        }
        else {
            return new SimpleStreamingAsyncClientHttpRequest(
                    connection, this.chunkSize, this.outputStreaming, this.taskExecutor);
        }
    }

    /**
     * Open and return a connection to the given URL.
     * The default implementation uses the proxy set via setProxy(java.net.Proxy), if any.
     */
    protected HttpURLConnection openConnection(URL url, @Nullable Proxy proxy) throws IOException {
        URLConnection urlConnection = (proxy != null ? url.openConnection(proxy) : url.openConnection());
        if (!HttpURLConnection.class.isInstance(urlConnection)) {
            throw new IllegalStateException("HttpURLConnection required for [" + url + "] but got: " + urlConnection);
        }
        return (HttpURLConnection) urlConnection;
    }

    /**
     * Prepare the connection: set parameters such as connectTimeout, readTimeout and the HTTP method.
     */
    protected void prepareConnection(HttpURLConnection connection, String httpMethod) throws IOException {
        if (this.connectTimeout >= 0) {
            // Connect timeout
            connection.setConnectTimeout(this.connectTimeout);
        }
        if (this.readTimeout >= 0) {
            // Read timeout
            connection.setReadTimeout(this.readTimeout);
        }
        // Enable reading from the connection
        connection.setDoInput(true);

        if ("GET".equals(httpMethod)) {
            // Follow redirects automatically
            connection.setInstanceFollowRedirects(true);
        }
        else {
            connection.setInstanceFollowRedirects(false);
        }

        if ("POST".equals(httpMethod) || "PUT".equals(httpMethod) ||
                "PATCH".equals(httpMethod) || "DELETE".equals(httpMethod)) {
            // Enable writing a request body to the connection
            connection.setDoOutput(true);
        }
        else {
            connection.setDoOutput(false);
        }
        // Set the HTTP method
        connection.setRequestMethod(httpMethod);
    }
}
```
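The connection setup in prepareConnection() can be reproduced with the JDK alone. The sketch below is a hypothetical helper, not a Spring method: `PrepareConnectionSketch.open` applies the same rules to a java.net.HttpURLConnection. openConnection() only creates the connection object, so nothing is sent over the network until the connection is actually used.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.HttpURLConnection;
import java.net.URI;
import java.net.URLConnection;

// Hypothetical helper mirroring SimpleClientHttpRequestFactory#openConnection + #prepareConnection.
final class PrepareConnectionSketch {

    static HttpURLConnection open(String url, String httpMethod, int connectTimeout, int readTimeout) {
        try {
            // openConnection() does not connect yet; it only builds the connection object
            URLConnection urlConnection = URI.create(url).toURL().openConnection();
            if (!(urlConnection instanceof HttpURLConnection)) {
                throw new IllegalStateException("HttpURLConnection required for [" + url + "] but got: " + urlConnection);
            }
            HttpURLConnection connection = (HttpURLConnection) urlConnection;
            if (connectTimeout >= 0) {
                connection.setConnectTimeout(connectTimeout); // milliseconds
            }
            if (readTimeout >= 0) {
                connection.setReadTimeout(readTimeout);       // milliseconds
            }
            connection.setDoInput(true);
            // Follow redirects automatically only for GET
            connection.setInstanceFollowRedirects("GET".equals(httpMethod));
            // Only methods that may carry a body get an output stream
            connection.setDoOutput("POST".equals(httpMethod) || "PUT".equals(httpMethod)
                    || "PATCH".equals(httpMethod) || "DELETE".equals(httpMethod));
            // Throws ProtocolException (an IOException) for methods HttpURLConnection rejects, e.g. PATCH
            connection.setRequestMethod(httpMethod);
            return connection;
        } catch (IOException ex) {
            throw new UncheckedIOException(ex);
        }
    }

    public static void main(String[] args) {
        HttpURLConnection get = open("http://example.com/users", "GET", 1000, 3000);
        System.out.println(get.getRequestMethod() + " redirects=" + get.getInstanceFollowRedirects()
                + " doOutput=" + get.getDoOutput()); // GET redirects=true doOutput=false
    }
}
```

The last comment in the helper has a practical consequence. Although prepareConnection() lists PATCH, the JDK's HttpURLConnection rejects it. Spring's own Javadoc for patchForObject() points to the Apache HttpComponents or OkHttp request factories for PATCH.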

ClientHttpRequestFactory

ClientHttpRequestFactory is an abstract factory interface for requests. It supports several built-in implementations and custom ones, such as:

- the default SimpleClientHttpRequestFactory
- OkHttp3ClientHttpRequestFactory
- HttpComponentsClientHttpRequestFactory

  1. RestTemplate provides a constructor that takes a ClientHttpRequestFactory argument, which makes custom requests possible.

    In other words, you write a custom ClientHttpRequestFactory implementation and override its createRequest() method.

    ```java
    public RestTemplate(ClientHttpRequestFactory requestFactory) {
        this();
        // Ends up in HttpAccessor#setRequestFactory(), which sets the requestFactory field
        // and replaces the default new SimpleClientHttpRequestFactory()
        setRequestFactory(requestFactory);
    }
    ```

    RestTemplate's direct superclass is InterceptingHttpAccessor, which adds support for request interceptors. Its top-level superclass is HttpAccessor.

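    InterceptingHttpAccessor overrides HttpAccessor's setRequestFactory() method. The override still delegates to the superclass implementation and also clears the cached intercepting factory.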

    ```java
    @Override
    public void setRequestFactory(ClientHttpRequestFactory requestFactory) {
        // Delegate to the superclass
        super.setRequestFactory(requestFactory);
        this.interceptingRequestFactory = null;
    }
    ```

    HttpAccessor's setRequestFactory() method then replaces the default implementation:

    ```java
    public abstract class HttpAccessor {
        protected final Log logger = HttpLogging.forLogName(getClass());

        private ClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();

        public void setRequestFactory(ClientHttpRequestFactory requestFactory) {
            Assert.notNull(requestFactory, "ClientHttpRequestFactory must not be null");
            this.requestFactory = requestFactory;
        }

        // ........ omitted ........
    }
    ```
  2. RestTemplateBuilder's fluent API provides requestFactory() methods for supplying a specific ClientHttpRequestFactory.

    ```java
    public RestTemplateBuilder requestFactory(Class<? extends ClientHttpRequestFactory> requestFactory) {
        Assert.notNull(requestFactory, "RequestFactory must not be null");
        return requestFactory(() -> createRequestFactory(requestFactory));
    }

    /**
     * Create an instance of the given ClientHttpRequestFactory class.
     */
    private ClientHttpRequestFactory createRequestFactory(Class<? extends ClientHttpRequestFactory> requestFactory) {
        try {
            // Create the instance by invoking the no-arg constructor via reflection
            Constructor<?> constructor = requestFactory.getDeclaredConstructor();
            constructor.setAccessible(true);
            return (ClientHttpRequestFactory) constructor.newInstance();
        }
        catch (Exception ex) {
            throw new IllegalStateException(ex);
        }
    }

    /**
     * Return a new RestTemplateBuilder with the property set.
     * Note the requestFactorySupplier argument
     * (Supplier is a functional interface that supplies a value on demand).
     */
    public RestTemplateBuilder requestFactory(Supplier<ClientHttpRequestFactory> requestFactorySupplier) {
        Assert.notNull(requestFactorySupplier, "RequestFactory Supplier must not be null");
        return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri, this.messageConverters,
                requestFactorySupplier, this.uriTemplateHandler, this.errorHandler, this.basicAuthentication,
                this.restTemplateCustomizers, this.requestFactoryCustomizer, this.interceptors);
    }

    /**
     * Build a RestTemplate instance.
     */
    public RestTemplate build() {
        return build(RestTemplate.class);
    }

    /**
     * Build a RestTemplate of the given type.
     */
    public <T extends RestTemplate> T build(Class<T> restTemplateClass) {
        return configure(BeanUtils.instantiateClass(restTemplateClass));
    }

    /**
     * Configure the RestTemplate:
     * messageConverters, errorHandler, ...
     */
    public <T extends RestTemplate> T configure(T restTemplate) {
        // Configure the ClientHttpRequestFactory
        configureRequestFactory(restTemplate);
        if (!CollectionUtils.isEmpty(this.messageConverters)) {
            restTemplate.setMessageConverters(new ArrayList<>(this.messageConverters));
        }
        if (this.uriTemplateHandler != null) {
            restTemplate.setUriTemplateHandler(this.uriTemplateHandler);
        }
        if (this.errorHandler != null) {
            restTemplate.setErrorHandler(this.errorHandler);
        }
        if (this.rootUri != null) {
            RootUriTemplateHandler.addTo(restTemplate, this.rootUri);
        }
        if (this.basicAuthentication != null) {
            restTemplate.getInterceptors().add(this.basicAuthentication);
        }
        // Add the interceptors
        restTemplate.getInterceptors().addAll(this.interceptors);
        if (!CollectionUtils.isEmpty(this.restTemplateCustomizers)) {
            for (RestTemplateCustomizer customizer : this.restTemplateCustomizers) {
                customizer.customize(restTemplate);
            }
        }
        return restTemplate;
    }

    /**
     * Configure the ClientHttpRequestFactory.
     */
    private void configureRequestFactory(RestTemplate restTemplate) {
        ClientHttpRequestFactory requestFactory = null;
        if (this.requestFactorySupplier != null) {
            // Obtain the supplied ClientHttpRequestFactory
            requestFactory = this.requestFactorySupplier.get();
        }
        else if (this.detectRequestFactory) {
            // Otherwise pick one based on the HTTP libraries on the classpath
            requestFactory = new ClientHttpRequestFactorySupplier().get();
        }
        if (requestFactory != null) {
            if (this.requestFactoryCustomizer != null) {
                this.requestFactoryCustomizer.accept(requestFactory);
            }
            restTemplate.setRequestFactory(requestFactory);
        }
    }
    ```
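requestFactory(Class) does not create the factory right away. It wraps reflective instantiation in a Supplier, and configureRequestFactory() only calls that Supplier inside build().

The stdlib-only sketch below reproduces that pattern. `ReflectiveSupplierSketch.supplierFor` is a hypothetical helper. StringBuilder merely stands in for a ClientHttpRequestFactory class with a no-arg constructor.

```java
import java.lang.reflect.Constructor;
import java.util.function.Supplier;

// Hypothetical helper mirroring RestTemplateBuilder#requestFactory(Class) + #createRequestFactory(Class).
final class ReflectiveSupplierSketch {

    static <T> Supplier<T> supplierFor(Class<? extends T> type) {
        // Nothing is instantiated here; the lambda runs only when get() is called
        return () -> {
            try {
                Constructor<? extends T> constructor = type.getDeclaredConstructor(); // no-arg constructor
                constructor.setAccessible(true); // as in the builder, also reaches non-public constructors
                return constructor.newInstance();
            } catch (ReflectiveOperationException ex) {
                throw new IllegalStateException(ex);
            }
        };
    }

    public static void main(String[] args) {
        Supplier<CharSequence> supplier = supplierFor(StringBuilder.class);
        // Each get() (i.e. each build() in the builder) yields a fresh instance
        System.out.println(supplier.get() != supplier.get()); // true
    }
}
```

The builder behaves the same way under this model:

- A problem such as a missing no-arg constructor surfaces when build() runs, not when requestFactory(...) is called.
- Every build() call gets its own factory instance.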

ClientHttpRequestInterceptor

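ClientHttpRequestInterceptor is a functional interface for intercepting client-side HTTP requests. Implementations can be registered with RestTemplate#setInterceptors(). They can modify the outgoing request and the incoming response.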

```java
@FunctionalInterface
public interface ClientHttpRequestInterceptor {

    /**
     * Intercept the request and return the response.
     * ClientHttpRequestExecution lets the interceptor pass the request and response
     * on to the next entity in the chain.
     */
    ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution)
            throws IOException;

}
```

InterceptingHttpAccessor

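RestTemplate's abstract superclass InterceptingHttpAccessor provides a setInterceptors() method for registering custom interceptors; it takes a list of ClientHttpRequestInterceptor. A custom HTTP request interceptor therefore has to implement the ClientHttpRequestInterceptor interface.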

```java
public abstract class InterceptingHttpAccessor extends HttpAccessor {

    private final List<ClientHttpRequestInterceptor> interceptors = new ArrayList<>();

    @Nullable
    private volatile ClientHttpRequestFactory interceptingRequestFactory;


    /**
     * Set the ClientHttpRequestInterceptor interceptors.
     * They are used in getRequestFactory() below.
     */
    public void setInterceptors(List<ClientHttpRequestInterceptor> interceptors) {
        // Take getInterceptors() List as-is when passed in here
        if (this.interceptors != interceptors) {
            this.interceptors.clear();
            this.interceptors.addAll(interceptors);
            AnnotationAwareOrderComparator.sort(this.interceptors);
        }
    }

    /**
     * Return the interceptors.
     */
    public List<ClientHttpRequestInterceptor> getInterceptors() {
        return this.interceptors;
    }

    /**
     * Set the RequestFactory.
     */
    @Override
    public void setRequestFactory(ClientHttpRequestFactory requestFactory) {
        super.setRequestFactory(requestFactory);
        this.interceptingRequestFactory = null;
    }

    /**
     * Return the ClientHttpRequestFactory.
     * If interceptors are present, create an InterceptingClientHttpRequestFactory;
     * the factory it wraps still comes from HttpAccessor#getRequestFactory() in the top-level superclass.
     */
    @Override
    public ClientHttpRequestFactory getRequestFactory() {
        List<ClientHttpRequestInterceptor> interceptors = getInterceptors();
        if (!CollectionUtils.isEmpty(interceptors)) {
            ClientHttpRequestFactory factory = this.interceptingRequestFactory;
            if (factory == null) {
                // Wrap the superclass factory and pass in the interceptors
                factory = new InterceptingClientHttpRequestFactory(super.getRequestFactory(), interceptors);
                this.interceptingRequestFactory = factory;
            }
            return factory;
        }
        else {
            return super.getRequestFactory();
        }
    }

}
```

InterceptingClientHttpRequestFactory

An HTTP request factory with interceptor support: it creates requests that carry the interceptors.

```java
public class InterceptingClientHttpRequestFactory extends AbstractClientHttpRequestFactoryWrapper {

    private final List<ClientHttpRequestInterceptor> interceptors;

    /**
     * Create an HTTP request factory with interceptors.
     */
    public InterceptingClientHttpRequestFactory(ClientHttpRequestFactory requestFactory,
            @Nullable List<ClientHttpRequestInterceptor> interceptors) {

        super(requestFactory);
        this.interceptors = (interceptors != null ? interceptors : Collections.emptyList());
    }

    /**
     * Create an HTTP request that applies the interceptors.
     */
    @Override
    protected ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod, ClientHttpRequestFactory requestFactory) {
        return new InterceptingClientHttpRequest(requestFactory, this.interceptors, uri, httpMethod);
    }
}
```

InterceptingClientHttpRequest

InterceptingClientHttpRequest is the request type that supports and applies request interceptors.

```java
class InterceptingClientHttpRequest extends AbstractBufferingClientHttpRequest {

    private final ClientHttpRequestFactory requestFactory;
    // Interceptors to apply
    private final List<ClientHttpRequestInterceptor> interceptors;

    private HttpMethod method;

    private URI uri;


    protected InterceptingClientHttpRequest(ClientHttpRequestFactory requestFactory,
            List<ClientHttpRequestInterceptor> interceptors, URI uri, HttpMethod method) {

        this.requestFactory = requestFactory;
        this.interceptors = interceptors;
        this.method = method;
        this.uri = uri;
    }


    @Override
    public HttpMethod getMethod() {
        return this.method;
    }

    @Override
    public String getMethodValue() {
        return this.method.name();
    }

    @Override
    public URI getURI() {
        return this.uri;
    }

    @Override
    protected final ClientHttpResponse executeInternal(HttpHeaders headers, byte[] bufferedOutput) throws IOException {
        // Inner class that walks the interceptor chain
        InterceptingRequestExecution requestExecution = new InterceptingRequestExecution();
        // Execute the request
        return requestExecution.execute(this, bufferedOutput);
    }

    /**
     * Inner class: its execute() method runs the interceptors and finally sends the request.
     */
    private class InterceptingRequestExecution implements ClientHttpRequestExecution {
        // Iterator over the request interceptors
        private final Iterator<ClientHttpRequestInterceptor> iterator;

        public InterceptingRequestExecution() {
            this.iterator = interceptors.iterator();
        }

        @Override
        public ClientHttpResponse execute(HttpRequest request, byte[] body) throws IOException {
            if (this.iterator.hasNext()) {
                // ClientHttpRequestInterceptor is an interface.
                // If another interceptor remains, call its intercept() method.
                ClientHttpRequestInterceptor nextInterceptor = this.iterator.next();
                // So a custom request interceptor has to implement intercept(), and inside it send the
                // request by calling execute() on the execution object it receives (this object).
                // Note: passing 'this' makes each interceptor re-enter this method; the shared iterator
                // advances until every interceptor has run, and only then is the else branch reached.
                return nextInterceptor.intercept(request, body, this);
            }
            else {
                HttpMethod method = request.getMethod();
                Assert.state(method != null, "No standard HTTP method");
                // All interceptors done: create the real request with the wrapped factory
                ClientHttpRequest delegate = requestFactory.createRequest(request.getURI(), method);
                request.getHeaders().forEach((key, value) -> delegate.getHeaders().addAll(key, value));
                if (body.length > 0) {
                    if (delegate instanceof StreamingHttpOutputMessage) {
                        StreamingHttpOutputMessage streamingOutputMessage = (StreamingHttpOutputMessage) delegate;
                        streamingOutputMessage.setBody(outputStream -> StreamUtils.copy(body, outputStream));
                    }
                    else {
                        StreamUtils.copy(body, delegate.getBody());
                    }
                }
                // Execute the real request
                return delegate.execute();
            }
        }
    }
}
```
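The key trick in InterceptingRequestExecution is that it passes itself (this) to every interceptor. Each call to execution.execute() therefore advances the same iterator.

The stdlib-only sketch below models that chain of responsibility. The names are hypothetical simplifications:

- `Request` stands in for HttpRequest.
- `Interceptor` stands in for ClientHttpRequestInterceptor.
- `ChainExecution` stands in for InterceptingRequestExecution.
- The "real request" is only a trace entry.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical model of InterceptingClientHttpRequest.InterceptingRequestExecution; not Spring code.
final class InterceptorChainSketch {

    static final class Request {
        final String uri;
        final Map<String, String> headers = new TreeMap<>();

        Request(String uri) { this.uri = uri; }
    }

    /** Stand-in for ClientHttpRequestExecution. */
    @FunctionalInterface
    interface Execution {
        String execute(Request request);
    }

    /** Stand-in for ClientHttpRequestInterceptor. */
    @FunctionalInterface
    interface Interceptor {
        String intercept(Request request, Execution execution);
    }

    static final class ChainExecution implements Execution {
        private final Iterator<Interceptor> iterator;
        private final List<String> trace;

        ChainExecution(List<Interceptor> interceptors, List<String> trace) {
            this.iterator = interceptors.iterator();
            this.trace = trace;
        }

        @Override
        public String execute(Request request) {
            if (iterator.hasNext()) {
                // Hand 'this' to the interceptor; its execute() call re-enters here with the iterator advanced
                return iterator.next().intercept(request, this);
            }
            // No interceptors left: "send" the real request
            trace.add("send " + request.uri + " " + request.headers);
            return "200 OK";
        }
    }

    public static void main(String[] args) {
        List<String> trace = new ArrayList<>();
        List<Interceptor> interceptors = List.of(
                (req, ex) -> { trace.add("A before"); req.headers.put("X-A", "1"); String r = ex.execute(req); trace.add("A after"); return r; },
                (req, ex) -> { trace.add("B before"); String r = ex.execute(req); trace.add("B after"); return r; });
        String result = new ChainExecution(interceptors, trace).execute(new Request("/users"));
        // 200 OK [A before, B before, send /users {X-A=1}, B after, A after]
        System.out.println(result + " " + trace);
    }
}
```

Two consequences follow from the shared iterator:

- An interceptor that returns without calling execution.execute() short-circuits the chain, so the request is never sent.
- An interceptor that calls execute() more than once, for example to retry, skips every interceptor registered after it on the second call, because the iterator is already exhausted.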

Example: Spring's BasicAuthenticationInterceptor performs authentication by intercepting requests. It implements the ClientHttpRequestInterceptor interface and overrides intercept().

```java
public class BasicAuthenticationInterceptor implements ClientHttpRequestInterceptor {

    private final String username;

    private final String password;

    @Nullable
    private final Charset charset;

    public BasicAuthenticationInterceptor(String username, String password) {
        this(username, password, null);
    }

    public BasicAuthenticationInterceptor(String username, String password, @Nullable Charset charset) {
        Assert.doesNotContain(username, ":", "Username must not contain a colon");
        this.username = username;
        this.password = password;
        this.charset = charset;
    }

    /**
     * Implementation of intercept().
     */
    @Override
    public ClientHttpResponse intercept(
            HttpRequest request, byte[] body, ClientHttpRequestExecution execution) throws IOException {

        HttpHeaders headers = request.getHeaders();
        if (!headers.containsKey(HttpHeaders.AUTHORIZATION)) {
            headers.setBasicAuth(this.username, this.password, this.charset);
        }
        // Calling execute() here goes back into InterceptingRequestExecution#execute()
        // (inside InterceptingClientHttpRequest), which runs the next interceptor or sends the request
        return execution.execute(request, body);
    }
}
```
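The header that setBasicAuth() produces is the Base64 encoding of `username:password`, prefixed with `Basic ` (RFC 7617). The stdlib-only sketch below computes it with java.util.Base64. Like the interceptor above, it leaves an existing Authorization header alone.

Assumptions:

- `BasicAuthSketch` is hypothetical.
- Defaulting to ISO-8859-1 when no charset is given follows my reading of the HttpHeaders#encodeBasicAuth documentation. Treat it as an assumption to verify against your Spring version.

```java
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical model of BasicAuthenticationInterceptor#intercept + HttpHeaders#setBasicAuth.
final class BasicAuthSketch {

    static String encodeBasicAuth(String username, String password, Charset charset) {
        if (username.contains(":")) {
            throw new IllegalArgumentException("Username must not contain a colon");
        }
        Charset cs = (charset != null ? charset : StandardCharsets.ISO_8859_1); // assumed default
        byte[] credentials = (username + ":" + password).getBytes(cs);
        return "Basic " + Base64.getEncoder().encodeToString(credentials);
    }

    /** Adds the header only when the request does not already carry one. */
    static void apply(Map<String, String> headers, String username, String password) {
        headers.putIfAbsent("Authorization", encodeBasicAuth(username, password, null));
    }

    public static void main(String[] args) {
        // Case-insensitive keys, like HTTP header names
        Map<String, String> headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
        apply(headers, "user", "123456");
        System.out.println(headers); // {Authorization=Basic dXNlcjoxMjM0NTY=}
    }
}
```

Base64 is an encoding, not encryption. Anyone who sees the header can recover the password, so Basic authentication should only be sent over HTTPS.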

RestTemplateBuilder

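RestTemplateBuilder provides interceptors() methods that accept request interceptors. When build() creates the RestTemplate instance and applies the configuration, the interceptors are added to the RestTemplate's interceptors property, which actually lives in the superclass InterceptingHttpAccessor.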

```java
public class RestTemplateBuilder {
    // ..... other fields omitted .....

    private final Set<ClientHttpRequestInterceptor> interceptors;

    // ..... constructors omitted .....

    /**
     * Set the message converters (so custom message converters can be supplied).
     */
    public RestTemplateBuilder messageConverters(
            HttpMessageConverter<?>... messageConverters) {
        Assert.notNull(messageConverters, "MessageConverters must not be null");
        return messageConverters(Arrays.asList(messageConverters));
    }

    /**
     * Use the default message converters.
     */
    public RestTemplateBuilder defaultMessageConverters() {
        return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri,
                Collections.unmodifiableSet(
                        new LinkedHashSet<>(new RestTemplate().getMessageConverters())),
                this.requestFactorySupplier, this.uriTemplateHandler, this.errorHandler,
                this.basicAuthentication, this.restTemplateCustomizers,
                this.requestFactoryCustomizer, this.interceptors);
    }

    /**
     * Set the interceptors (replaces any previously defined interceptors).
     */
    public RestTemplateBuilder interceptors(
            ClientHttpRequestInterceptor... interceptors) {
        Assert.notNull(interceptors, "interceptors must not be null");
        return interceptors(Arrays.asList(interceptors));
    }

    /**
     * Called by the method above.
     */
    public RestTemplateBuilder interceptors(
            Collection<ClientHttpRequestInterceptor> interceptors) {
        Assert.notNull(interceptors, "interceptors must not be null");
        return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri,
                this.messageConverters, this.requestFactorySupplier,
                this.uriTemplateHandler, this.errorHandler, this.basicAuthentication,
                this.restTemplateCustomizers, this.requestFactoryCustomizer,
                Collections.unmodifiableSet(new LinkedHashSet<>(interceptors)));
    }

    /**
     * build() creates a RestTemplate of the given type and configures it.
     */
    public <T extends RestTemplate> T build(Class<T> restTemplateClass) {
        // Instantiate the RestTemplate via reflection
        return configure(BeanUtils.instantiateClass(restTemplateClass));
    }

    /**
     * Called by the method above.
     */
    public <T extends RestTemplate> T configure(T restTemplate) {
        // Configure the request factory
        configureRequestFactory(restTemplate);
        if (!CollectionUtils.isEmpty(this.messageConverters)) {
            restTemplate.setMessageConverters(new ArrayList<>(this.messageConverters));
        }
        if (this.uriTemplateHandler != null) {
            restTemplate.setUriTemplateHandler(this.uriTemplateHandler);
        }
        if (this.errorHandler != null) {
            restTemplate.setErrorHandler(this.errorHandler);
        }
        if (this.rootUri != null) {
            RootUriTemplateHandler.addTo(restTemplate, this.rootUri);
        }
        if (this.basicAuthentication != null) {
            restTemplate.getInterceptors().add(this.basicAuthentication);
        }
        // Add the request interceptors
        restTemplate.getInterceptors().addAll(this.interceptors);
        if (!CollectionUtils.isEmpty(this.restTemplateCustomizers)) {
            for (RestTemplateCustomizer customizer : this.restTemplateCustomizers) {
                customizer.customize(restTemplate);
            }
        }
        return restTemplate;
    }
}
```
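  1. When a RestTemplate built by RestTemplateBuilder executes an HTTP request, it first has to obtain a request factory.

  2. While obtaining the request factory, it checks whether any interceptors are registered. If so, it creates an InterceptingClientHttpRequestFactory and passes the interceptors to it.

  3. When InterceptingClientHttpRequestFactory creates the request, it hands the interceptors to the new request instance. The interceptors then run when execute() is called.

    Note: the last two steps run exactly the same code as interceptors set directly on a RestTemplate.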
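RestTemplateBuilder is immutable. Every configuration method returns a new builder instead of modifying the current one. interceptors(...) replaces the whole interceptor set, while additionalInterceptors(...), also part of RestTemplateBuilder, appends to it.

The stdlib-only sketch below models just that copy-on-write behavior, using interceptor names. `BuilderSketch` is hypothetical, not the real class, and its build() only returns the configured list.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

// Hypothetical model of RestTemplateBuilder's immutable, copy-on-write configuration.
final class BuilderSketch {
    private final Set<String> interceptors;

    BuilderSketch() {
        this(Collections.emptySet());
    }

    private BuilderSketch(Set<String> interceptors) {
        this.interceptors = interceptors;
    }

    /** Replaces all previously configured interceptors (like interceptors(...)). */
    BuilderSketch interceptors(String... names) {
        return new BuilderSketch(Collections.unmodifiableSet(new LinkedHashSet<>(Arrays.asList(names))));
    }

    /** Appends to the configured interceptors (like additionalInterceptors(...)). */
    BuilderSketch additionalInterceptors(String... names) {
        Set<String> merged = new LinkedHashSet<>(interceptors);
        merged.addAll(Arrays.asList(names));
        return new BuilderSketch(Collections.unmodifiableSet(merged));
    }

    /** Stand-in for build(): the "template" is just the ordered interceptor list. */
    List<String> build() {
        return new ArrayList<>(interceptors);
    }

    public static void main(String[] args) {
        BuilderSketch base = new BuilderSketch();
        BuilderSketch withAuth = base.interceptors("auth");
        System.out.println(base.build());                                     // []
        System.out.println(withAuth.interceptors("trace").build());           // [trace]
        System.out.println(withAuth.additionalInterceptors("trace").build()); // [auth, trace]
    }
}
```

Because of this design, the return value must be used. A standalone `builder.interceptors(x);` followed by `builder.build()` silently drops the interceptor.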

Customizing RestTemplate

Defining a RestTemplate

  1. RestTemplateBuilder

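    To call remote services, you can use Spring Framework's RestTemplate. RestTemplate instances usually need to be customized before use, so Spring Boot does not auto-configure a single RestTemplate bean. Instead, it auto-configures a RestTemplateBuilder bean for building RestTemplate instances. The auto-configured RestTemplateBuilder ensures that sensible HttpMessageConverters are applied to the RestTemplate instances it builds.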

    Following the analysis above, a RestTemplate can be created in either of these ways:

    ```java
    RestTemplate template = new RestTemplate(new HttpComponentsClientHttpRequestFactory());
    // or
    RestTemplate template2 = new RestTemplateBuilder()
            .setConnectTimeout(Duration.ofMillis(1000))
            .setReadTimeout(Duration.ofMillis(1000))
            .build();
    ```

    RestTemplateBuilder provides many useful methods for configuring a RestTemplate quickly. For example, HTTP Basic authentication support can be added like this:

    ```java
    new RestTemplateBuilder().basicAuthentication("user", "123456").build();
    ```

    Another way to create a RestTemplate instance is to instantiate it directly with new, as follows.

  2. RestTemplate

    ```java
    @Bean
    public RestTemplate restTemplate(DiscoveryClient discoveryClient, CloudTemplateProperties properties) {
        // Custom HTTP request factory
        RestTemplate restTemplate = new RestTemplate(new CloudClientHttpRequestFactory(properties));
        // Custom request interceptor
        CloudHttpRequestInterceptor httpRequestInterceptor = new CloudHttpRequestInterceptor(discoveryClient, properties);
        List<ClientHttpRequestInterceptor> chRequestInterceptor = new ArrayList<>(1);
        // Register it
        chRequestInterceptor.add(httpRequestInterceptor);
        restTemplate.setInterceptors(chRequestInterceptor);
        return restTemplate;
    }
    ```

Defining the HTTP Request

  1. The default ClientHttpRequestFactory

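    By default, RestTemplate uses java.net.HttpURLConnection to execute requests. You can switch to a different HTTP library through an implementation of the ClientHttpRequestFactory interface, for example Apache HttpComponents, Netty, or OkHttp.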

    Example: using Apache HttpComponents

    ```java
    RestTemplate template = new RestTemplate(new HttpComponentsClientHttpRequestFactory());
    ```
  2. Example of a custom request factory

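    Depending on your business needs, either extend SimpleClientHttpRequestFactory or implement the ClientHttpRequestFactory interface, and override createRequest(). When extending SimpleClientHttpRequestFactory, you can also override prepareConnection().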

    ```java
    public class CloudClientHttpRequestFactory extends SimpleClientHttpRequestFactory {

        private static final Logger logger = LoggerFactory.getLogger(CloudClientHttpRequestFactory.class);

        private CloudTemplateProperties propt;

        public CloudClientHttpRequestFactory(CloudTemplateProperties tmpProp) {
            this.propt = tmpProp;
            if (tmpProp != null) {
                logger.info("ConnectTimeout = {},ReadTimeout={}", tmpProp.getConnectTimeout(), tmpProp.getReadTimeout());
                if (tmpProp.getConnectTimeout() > 0) {
                    logger.info("ConnectTimeout set");
                    super.setConnectTimeout(tmpProp.getConnectTimeout());
                }
                if (tmpProp.getReadTimeout() > 0) {
                    logger.info("ReadTimeout set");
                    super.setReadTimeout(tmpProp.getReadTimeout());
                }
            }
        }

        @Override
        public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException {
            // Re-parse the URI from its ASCII form so non-ASCII characters are percent-encoded
            URI rsv_uri = uri;
            try {
                rsv_uri = new URI(uri.toASCIIString());
            } catch (URISyntaxException ex) {
                logger.error("system error: ", ex);
            }
            return super.createRequest(rsv_uri, httpMethod);
        }

        @Override
        protected void prepareConnection(HttpURLConnection conn, String method) throws IOException {
            super.prepareConnection(conn, method);
            // Project-wide timeouts and default headers
            conn.setConnectTimeout(propt.getConnectTimeout());
            conn.setReadTimeout(propt.getReadTimeout());
            conn.setRequestProperty(HttpHeaders.USER_AGENT, propt.getUserAgent());
            if (StringUtils.isNotEmpty(propt.getAccept())) {
                conn.setRequestProperty(HttpHeaders.ACCEPT, propt.getAccept());
            }
            if (StringUtils.isNotEmpty(propt.getContentType()) && HttpMethod.POST.matches(method)) {
                conn.setRequestProperty(HttpHeaders.CONTENT_TYPE, propt.getContentType());
            }

            // Auth propagation, in priority order: 1. token from TemplateXAuthHolder ("notoken" = send none),
            // 2. the incoming request's X-Auth-Token header, 3. the incoming request's Cookie header
            Object xAuth = TemplateXAuthHolder.getXAuth();
            if ("notoken".equals(xAuth)) {
                return;
            }
            if (null != xAuth) {
                conn.setRequestProperty("X-Auth-Token", xAuth.toString());
                logger.info(" - request url[X-Auth-Token] from app: {}", conn.getURL());
                return;
            }

            HttpServletRequest req = UserContext.getHttpRequest();
            if (null == req) {
                return;
            }
            String xAuthToken = req.getHeader("X-Auth-Token");
            if (StringUtils.isNotEmpty(xAuthToken)) {
                conn.setRequestProperty("X-Auth-Token", xAuthToken);
                logger.info(" - request url[X-Auth-Token]: {}", conn.getURL());
                return;
            }
            String cookie = req.getHeader("Cookie");
            if (StringUtils.isNotEmpty(cookie)) {
                conn.setRequestProperty("Cookie", cookie);
                logger.debug(" - request url[Cookie]: {}", conn.getURL());
                return;
            }
        }
    }
    ```
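The header logic in CloudClientHttpRequestFactory#prepareConnection() boils down to a priority order:

1. An explicit token from TemplateXAuthHolder. The value "notoken" means send nothing.
2. The incoming request's X-Auth-Token header.
3. The incoming request's Cookie header.

The sketch below pulls that decision out into a pure function so it can be tested without a servlet container. It also includes the toASCIIString() normalization from createRequest().

Assumptions:

- `AuthPropagationSketch` is hypothetical.
- TemplateXAuthHolder and UserContext are the project's own classes, not shown in this article; plain parameters replace them here.

```java
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Map;
import java.util.Optional;

// Hypothetical, servlet-free model of the header and URI handling in CloudClientHttpRequestFactory.
final class AuthPropagationSketch {

    /**
     * Returns the header (name, value) to add to the outgoing request, if any.
     * holderToken plays the role of TemplateXAuthHolder.getXAuth();
     * incomingHeaders plays the role of the current HttpServletRequest (null = no request bound).
     * Note: HttpServletRequest#getHeader is case-insensitive, this map lookup is not.
     */
    static Optional<Map.Entry<String, String>> choose(Object holderToken, Map<String, String> incomingHeaders) {
        if ("notoken".equals(holderToken)) {
            return Optional.empty();                      // explicitly suppressed
        }
        if (holderToken != null) {
            return Optional.of(Map.entry("X-Auth-Token", holderToken.toString()));
        }
        if (incomingHeaders == null) {
            return Optional.empty();                      // not running inside an HTTP request
        }
        String token = incomingHeaders.get("X-Auth-Token");
        if (token != null && !token.isEmpty()) {
            return Optional.of(Map.entry("X-Auth-Token", token));
        }
        String cookie = incomingHeaders.get("Cookie");
        if (cookie != null && !cookie.isEmpty()) {
            return Optional.of(Map.entry("Cookie", cookie));
        }
        return Optional.empty();
    }

    /** Same idea as createRequest(): re-parse from toASCIIString() so non-ASCII characters are percent-encoded. */
    static URI toAscii(URI uri) {
        try {
            return new URI(uri.toASCIIString());
        } catch (URISyntaxException ex) {
            return uri; // the original logs the error and keeps the unmodified URI
        }
    }
}
```

There is a design option here, not what the project above does: the same rule could live in a ClientHttpRequestInterceptor instead of prepareConnection(). It would then be independent of the HttpURLConnection-based factory and keep working after a switch to HttpComponents or OkHttp.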

Defining the Request Interceptor

A custom request interceptor implements the intercept() method of the ClientHttpRequestInterceptor interface. Client-side load balancers such as Ribbon rely on the same mechanism.

```java
public class CloudHttpRequestInterceptor implements ClientHttpRequestInterceptor {

    private static final Logger logger = LoggerFactory.getLogger(CloudHttpRequestInterceptor.class);

    private AtomicInteger nextServerCyclicCounter = new AtomicInteger(0);

    private DiscoveryClient discoveryClient;
    private int maxAttempts;
    private Map<String, Object[]> services = new HashMap<String, Object[]>();

    public CloudHttpRequestInterceptor(DiscoveryClient discoveryClient, CloudTemplateProperties propt) {
        this.discoveryClient = discoveryClient;
        this.maxAttempts = propt.getMaxAttempts();
        Map<String, String> svrs = propt.getServices();
        if (null != svrs) {
            for (Entry<String, String> item : svrs.entrySet()) {
                String[] service = item.getValue().split(":"); // host:port
                services.put(item.getKey(), new Object[] { service[0], service.length > 1 ? Integer.parseInt(service[1]) : 8080 });
            }
        }
    }

    @Override
    public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution) throws IOException {
        HttpRequestWrapper requestWrapper = new HttpRequestWrapper(request) {
            @Override
            public URI getURI() {
                URI uri = request.getURI();
                try {
                    String host = uri.getHost();
                    // 1. Statically configured host:port for this service name
                    Object[] service = services.get(host);
                    if (service != null) {
                        return new URI(uri.getScheme(), uri.getUserInfo(), (String) service[0], (int) service[1],
                                uri.getPath(), uri.getQuery(), uri.getFragment());
                    }
                    if (null == discoveryClient) {
                        return uri;
                    }
                    ServiceInstance instance = getInstances(host);
                    if (null == instance) {
                        // 2. No registry instance: fall back to the original URL
                        return uri;
                    } else {
                        // 3. Route to an instance obtained from the service registry
                        return new URI(uri.getScheme(), uri.getUserInfo(), instance.getHost(), instance.getPort(),
                                uri.getPath(), uri.getQuery(), uri.getFragment());
                    }
                } catch (URISyntaxException e) {
                    logger.error("URISyntaxException", e);
                    return uri;
                }
            }
        };

        try {
            ClientHttpResponse resp = null;
            // Retries call execution.execute() again; interceptors registered after this one
            // are skipped on retries because the chain's iterator is already exhausted.
            for (int i = 0; i < maxAttempts; i++) {
                try {
                    resp = execution.execute(requestWrapper, body);
                    if (resp.getStatusCode() == HttpStatus.OK) {
                        break;
                    }
                    if (i < maxAttempts - 1) {
                        resp.close(); // release the non-OK response before retrying
                    }
                } catch (IOException ex) {
                    logger.error(" Http Request error: ", ex);
                    // Give up on the last attempt (the original hard-coded i == 2)
                    if (i == maxAttempts - 1) {
                        throw new IOException("Service request failed after " + maxAttempts + " attempts", ex);
                    }
                }
            }
            return resp;
        } finally {
            // Always clear the thread-bound token, even when the request fails
            TemplateXAuthHolder.remove();
        }
    }

    /**
     * Look up a service instance in the registry by service name.
     * @param serviceId
     * @return ServiceInstance
     */
    private ServiceInstance getInstances(String serviceId) {
        List<ServiceInstance> instances = discoveryClient.getInstances(serviceId);
        if (CollectionUtils.isEmpty(instances)) {
            logger.warn("No service instances found in the registry! Service name: {}", serviceId);
            return null;
        }
        logger.info("Service from registry: {}, online instances: {}", serviceId, instances.size());
        int nextServerIndex = incrementAndGetModulo(instances.size());
        return instances.get(nextServerIndex);
    }

    /**
     * Round-robin strategy: retry the CAS until it succeeds
     * (returning 0 on contention, as before, skews traffic toward the first instance).
     * @param modulo
     */
    private int incrementAndGetModulo(int modulo) {
        for (;;) {
            int current = nextServerCyclicCounter.get();
            int next = (current + 1) % modulo;
            if (nextServerCyclicCounter.compareAndSet(current, next)) {
                return next;
            }
        }
    }
}
```
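Two pieces make this interceptor a client-side load balancer: the round-robin index and the host rewrite in getURI(). The stdlib-only sketch below extracts both so they can be unit-tested.

- The round-robin counter is lock-free and retries its compare-and-set in a loop. This is the approach Ribbon's RoundRobinRule takes.
- The host rewrite uses the java.net.URI constructor to swap the service name for an instance's host and port.

`LoadBalancingSketch` and its method names are hypothetical. Instances are modeled as plain "host:port" strings.

```java
import java.net.URI;
import java.net.URISyntaxException;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical sketch of round-robin selection plus URI host rewriting.
final class LoadBalancingSketch {

    private final AtomicInteger nextServerCyclicCounter = new AtomicInteger(0);

    /** Lock-free round robin: retry the CAS instead of falling back to index 0 under contention. */
    int incrementAndGetModulo(int modulo) {
        for (;;) {
            int current = nextServerCyclicCounter.get();
            int next = (current + 1) % modulo;
            if (nextServerCyclicCounter.compareAndSet(current, next)) {
                return next;
            }
        }
    }

    /** Pick the next instance in round-robin order. */
    String choose(List<String> instances) {
        return instances.get(incrementAndGetModulo(instances.size()));
    }

    /** Replace the service-name host with a concrete instance, keeping path, query and fragment. */
    static URI rewrite(URI uri, String host, int port) {
        try {
            return new URI(uri.getScheme(), uri.getUserInfo(), host, port,
                    uri.getPath(), uri.getQuery(), uri.getFragment());
        } catch (URISyntaxException ex) {
            throw new IllegalArgumentException(ex);
        }
    }

    public static void main(String[] args) {
        LoadBalancingSketch lb = new LoadBalancingSketch();
        List<String> instances = List.of("10.0.0.5:8081", "10.0.0.6:8081", "10.0.0.7:8081");
        // Prints 10.0.0.6:8081, 10.0.0.7:8081, 10.0.0.5:8081, 10.0.0.6:8081 (the counter starts at 0)
        for (int i = 0; i < 4; i++) {
            System.out.println(lb.choose(instances));
        }
        // Prints http://10.0.0.5:8081/api/users?id=1
        System.out.println(rewrite(URI.create("http://user-service/api/users?id=1"), "10.0.0.5", 8081));
    }
}
```

One caveat follows from using uri.getHost() as the service name. java.net.URI does not accept underscores in host names: for `http://user_service/api`, getHost() returns null. Service IDs containing underscores would therefore never match the services map or the registry lookup.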


Original post: http://blog.gxitsky.com/2019/11/22/SpringBoot-45-RestTemplate-source/

Author: 光星

Published: 2019-11-22

Updated: 2022-06-17
