summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--fs/cifs/cifsglob.h19
-rw-r--r--fs/cifs/smb2ops.c13
-rw-r--r--fs/cifs/smb2transport.c14
-rw-r--r--fs/cifs/transport.c6
4 files changed, 49 insertions, 3 deletions
diff --git a/fs/cifs/cifsglob.h b/fs/cifs/cifsglob.h
index 33264f9a9665..1b25e6e95d45 100644
--- a/fs/cifs/cifsglob.h
+++ b/fs/cifs/cifsglob.h
@@ -236,6 +236,8 @@ struct smb_version_operations {
 	int * (*get_credits_field)(struct TCP_Server_Info *, const int);
 	unsigned int (*get_credits)(struct mid_q_entry *);
 	__u64 (*get_next_mid)(struct TCP_Server_Info *);
+	void (*revert_current_mid)(struct TCP_Server_Info *server,
+				   const unsigned int val);
 	/* data offset from read response message */
 	unsigned int (*read_data_offset)(char *);
 	/*
@@ -771,6 +773,22 @@ get_next_mid(struct TCP_Server_Info *server)
 	return cpu_to_le16(mid);
 }
 
+static inline void
+revert_current_mid(struct TCP_Server_Info *server, const unsigned int val)
+{
+	if (server->ops->revert_current_mid)
+		server->ops->revert_current_mid(server, val);
+}
+
+static inline void
+revert_current_mid_from_hdr(struct TCP_Server_Info *server,
+			    const struct smb2_sync_hdr *shdr)
+{
+	unsigned int num = le16_to_cpu(shdr->CreditCharge);
+
+	return revert_current_mid(server, num > 0 ? num : 1);
+}
+
 static inline __u16
 get_mid(const struct smb_hdr *smb)
 {
@@ -1423,6 +1441,7 @@ struct mid_q_entry {
 	struct kref refcount;
 	struct TCP_Server_Info *server;	/* server corresponding to this mid */
 	__u64 mid;		/* multiplex id */
+	__u16 credits;		/* number of credits consumed by this mid */
 	__u32 pid;		/* process id */
 	__u32 sequence_number;  /* for CIFS signing */
 	unsigned long when_alloc;  /* when mid was created */
diff --git a/fs/cifs/smb2ops.c b/fs/cifs/smb2ops.c
index 16704941b969..ea56b1cdbdde 100644
--- a/fs/cifs/smb2ops.c
+++ b/fs/cifs/smb2ops.c
@@ -219,6 +219,15 @@ smb2_get_next_mid(struct TCP_Server_Info *server)
 	return mid;
 }
 
+static void
+smb2_revert_current_mid(struct TCP_Server_Info *server, const unsigned int val)
+{
+	spin_lock(&GlobalMid_Lock);
+	if (server->CurrentMid >= val)
+		server->CurrentMid -= val;
+	spin_unlock(&GlobalMid_Lock);
+}
+
 static struct mid_q_entry *
 smb2_find_mid(struct TCP_Server_Info *server, char *buf)
 {
@@ -3560,6 +3569,7 @@ struct smb_version_operations smb20_operations = {
 	.get_credits = smb2_get_credits,
 	.wait_mtu_credits = cifs_wait_mtu_credits,
 	.get_next_mid = smb2_get_next_mid,
+	.revert_current_mid = smb2_revert_current_mid,
 	.read_data_offset = smb2_read_data_offset,
 	.read_data_length = smb2_read_data_length,
 	.map_error = map_smb2_to_linux_error,
@@ -3655,6 +3665,7 @@ struct smb_version_operations smb21_operations = {
 	.get_credits = smb2_get_credits,
 	.wait_mtu_credits = smb2_wait_mtu_credits,
 	.get_next_mid = smb2_get_next_mid,
+	.revert_current_mid = smb2_revert_current_mid,
 	.read_data_offset = smb2_read_data_offset,
 	.read_data_length = smb2_read_data_length,
 	.map_error = map_smb2_to_linux_error,
@@ -3751,6 +3762,7 @@ struct smb_version_operations smb30_operations = {
 	.get_credits = smb2_get_credits,
 	.wait_mtu_credits = smb2_wait_mtu_credits,
 	.get_next_mid = smb2_get_next_mid,
+	.revert_current_mid = smb2_revert_current_mid,
 	.read_data_offset = smb2_read_data_offset,
 	.read_data_length = smb2_read_data_length,
 	.map_error = map_smb2_to_linux_error,
@@ -3856,6 +3868,7 @@ struct smb_version_operations smb311_operations = {
 	.get_credits = smb2_get_credits,
 	.wait_mtu_credits = smb2_wait_mtu_credits,
 	.get_next_mid = smb2_get_next_mid,
+	.revert_current_mid = smb2_revert_current_mid,
 	.read_data_offset = smb2_read_data_offset,
 	.read_data_length = smb2_read_data_length,
 	.map_error = map_smb2_to_linux_error,
diff --git a/fs/cifs/smb2transport.c b/fs/cifs/smb2transport.c
index 7b351c65ee46..63264db78b89 100644
--- a/fs/cifs/smb2transport.c
+++ b/fs/cifs/smb2transport.c
@@ -576,6 +576,7 @@ smb2_mid_entry_alloc(const struct smb2_sync_hdr *shdr,
 		     struct TCP_Server_Info *server)
 {
 	struct mid_q_entry *temp;
+	unsigned int credits = le16_to_cpu(shdr->CreditCharge);
 
 	if (server == NULL) {
 		cifs_dbg(VFS, "Null TCP session in smb2_mid_entry_alloc\n");
@@ -586,6 +587,7 @@ smb2_mid_entry_alloc(const struct smb2_sync_hdr *shdr,
 	memset(temp, 0, sizeof(struct mid_q_entry));
 	kref_init(&temp->refcount);
 	temp->mid = le64_to_cpu(shdr->MessageId);
+	temp->credits = credits > 0 ? credits : 1;
 	temp->pid = current->pid;
 	temp->command = shdr->Command; /* Always LE */
 	temp->when_alloc = jiffies;
@@ -674,13 +676,18 @@ smb2_setup_request(struct cifs_ses *ses, struct smb_rqst *rqst)
 	smb2_seq_num_into_buf(ses->server, shdr);
 
 	rc = smb2_get_mid_entry(ses, shdr, &mid);
-	if (rc)
+	if (rc) {
+		revert_current_mid_from_hdr(ses->server, shdr);
 		return ERR_PTR(rc);
+	}
+
 	rc = smb2_sign_rqst(rqst, ses->server);
 	if (rc) {
+		revert_current_mid_from_hdr(ses->server, shdr);
 		cifs_delete_mid(mid);
 		return ERR_PTR(rc);
 	}
+
 	return mid;
 }
 
@@ -695,11 +702,14 @@ smb2_setup_async_request(struct TCP_Server_Info *server, struct smb_rqst *rqst)
 	smb2_seq_num_into_buf(server, shdr);
 
 	mid = smb2_mid_entry_alloc(shdr, server);
-	if (mid == NULL)
+	if (mid == NULL) {
+		revert_current_mid_from_hdr(server, shdr);
 		return ERR_PTR(-ENOMEM);
+	}
 
 	rc = smb2_sign_rqst(rqst, server);
 	if (rc) {
+		revert_current_mid_from_hdr(server, shdr);
 		DeleteMidQEntry(mid);
 		return ERR_PTR(rc);
 	}
diff --git a/fs/cifs/transport.c b/fs/cifs/transport.c
index 53532bd3f50d..9544eb99b5a2 100644
--- a/fs/cifs/transport.c
+++ b/fs/cifs/transport.c
@@ -647,6 +647,7 @@ cifs_call_async(struct TCP_Server_Info *server, struct smb_rqst *rqst,
 	cifs_in_send_dec(server);
 
 	if (rc < 0) {
+		revert_current_mid(server, mid->credits);
 		server->sequence_number -= 2;
 		cifs_delete_mid(mid);
 	}
@@ -868,6 +869,7 @@ compound_send_recv(const unsigned int xid, struct cifs_ses *ses,
 	for (i = 0; i < num_rqst; i++) {
 		midQ[i] = ses->server->ops->setup_request(ses, &rqst[i]);
 		if (IS_ERR(midQ[i])) {
+			revert_current_mid(ses->server, i);
 			for (j = 0; j < i; j++)
 				cifs_delete_mid(midQ[j]);
 			mutex_unlock(&ses->server->srv_mutex);
@@ -897,8 +899,10 @@ compound_send_recv(const unsigned int xid, struct cifs_ses *ses,
 	for (i = 0; i < num_rqst; i++)
 		cifs_save_when_sent(midQ[i]);
 
-	if (rc < 0)
+	if (rc < 0) {
+		revert_current_mid(ses->server, num_rqst);
 		ses->server->sequence_number -= 2;
+	}
 
 	mutex_unlock(&ses->server->srv_mutex);